In [1]:
from models.image_classification.vanilla_vit import ViT
from models.image_classification.swin_transformer import SwinTransformer

# Loading Data
from utils.load_data import get_train_test_loaders
from utils.args import get_args

# PyTorch
import torch
import torch.nn as nn

# Visualization
from utils.visualization import plot_patches
from utils.visualization import plot_attention_maps
import matplotlib.pyplot as plt
import seaborn as sns

# Swin Transformer

In [2]:
# Build the CIFAR-100 loaders used by the Swin Transformer section.
train_loader, val_loader, test_loader = get_train_test_loaders(
    dataset_name="cifar100",
    batch_size=256,
    val_split=0.2,
    num_workers=4,
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Build the Swin model from the "swin_tiny_cifar100" preset returned by get_args.
args = get_args("swin_tiny_cifar100")
swin_tiny = SwinTransformer(
    patch_size=args["patch_size"],
    embed_dim=args["embed_dim"],
    depths=args["depths"],
    num_heads=args["num_heads"],
    window_size=args["window_size"],
    mlp_ratio=args["mlp_ratio"],
    dropout=args["dropout"],
    attention_dropout=args["attention_dropout"],
    stochastic_depth_prob=args["stochastic_depth_prob"],
    num_classes=args["num_classes"],
)

# Run a single forward pass on the CPU with one training batch.
swin_tiny.to("cpu")

# next(iter(...)) pulls exactly one batch and raises if the loader yields
# nothing, rather than silently leaving `images` / `outputs` unbound (or stale
# from an earlier run of this cell) as a `for ... break` loop would.
images, labels = next(iter(train_loader))
images = images.to("cpu")
labels = labels.to("cpu")

# NOTE(review): if `outputs` is only inspected and never backpropagated through,
# wrapping this call in `torch.no_grad()` would avoid building the autograd
# graph -- confirm against later cells before changing.
outputs = swin_tiny(images)

# Vanilla ViT

In [6]:
# Rebuild the CIFAR-100 loaders for the Vanilla ViT section; this is the same
# call, with the same arguments, as the one in the Swin Transformer section.
train_loader, val_loader, test_loader = get_train_test_loaders(
    dataset_name="cifar100",
    batch_size=256,
    val_split=0.2,
    num_workers=4,
)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
# Build the ViT model from the "vit_tiny_cifar100" preset returned by get_args.
# Note: this rebinds the notebook-level `args` used by the Swin section.
args = get_args("vit_tiny_cifar100")
vanilla_vit = ViT(
    image_size=args["image_size"],
    patch_size=args["patch_size"],
    num_layers=args["num_layers"],
    num_heads=args["num_heads"],
    hidden_dim=args["hidden_dim"],
    mlp_dim=args["mlp_dim"],
    dropout=args["dropout"],
    attention_dropout=args["attention_dropout"],
    num_classes=args["num_classes"],
)

# Run a single forward pass on the CPU with one training batch.
vanilla_vit.to("cpu")

# next(iter(...)) pulls exactly one batch and raises if the loader yields
# nothing, rather than silently leaving the variables below unbound (or stale
# from an earlier run of this cell) as a `for ... break` loop would.
images, labels = next(iter(train_loader))
images = images.to("cpu")
labels = labels.to("cpu")

# The model's forward call returns five values, unpacked here by position.
# NOTE(review): if these tensors are only visualised, wrapping the call in
# `torch.no_grad()` would avoid building the autograd graph -- confirm that no
# later cell backpropagates through them before changing.
patches, embeddings, encoder_output, class_token_output, output = vanilla_vit(images)