# CIFAR10

In [None]:
from torch.utils.data import DataLoader, random_split
from torchinfo import summary
from torchvision import datasets, transforms
import gc
import lightning as L
import lightning.pytorch.callbacks as callbacks
import lightning.pytorch.loggers as loggers
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import os
import pandas as pd
import torch
import torch.nn as nn
import yaml
import math

# Root directory for this notebook's dataset cache, logs and checkpoints.
BASE_DIR = "./notebooks/cifar10"

# Start from a clean slate: collect unreachable Python objects first, then ask
# PyTorch to release cached-but-unused CUDA memory.
gc.collect()
torch.cuda.empty_cache()

In [None]:
from model import SwinTransformer2D, SwinTransformerConfig2D, PatchMode, RelativePositionalEmeddingMode
from notebooks.cifar10.reference import SwinTransformerReference2D

## Hyperparameters

In [None]:
# Optimisation settings shared by every model in the comparison.
batch_size, num_epochs = 128, 1
# Peak learning rate for AdamW/OneCycleLR and AdamW's decoupled weight decay.
learning_rate, weight_decay = 3e-4, 0.05

## Data setup

In [None]:
# Per-channel normalisation statistics shared by the train and test pipelines.
# NOTE(review): these std values are the widely copied (0.2023, 0.1994, 0.2010);
# the per-pixel std of the CIFAR-10 training set is closer to
# (0.247, 0.243, 0.261) — confirm which statistics are intended.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: light geometric augmentation, then tensor conversion
# and normalisation.
transform_train = transforms.Compose([
    transforms.RandomPerspective(distortion_scale=0.1, p=0.5),
    transforms.RandomRotation(15),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

# Evaluation pipeline: no augmentation, same normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

In [None]:
from torch.utils.data import Subset

data_root = f"{BASE_DIR}/.data"

# Two views of the same 50k training images: an augmented one for training
# and an un-augmented one for validation. Splitting only the augmented
# dataset would make the validation Subset inherit `transform_train`.
train_full = datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform_train)
train_full_eval = datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform_test)
test_dataset = datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform_test)

# Fixed generator so the train/val split is identical across runs; the
# `L.seed_everything` call in the training cell happens only after this split.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_split = random_split(train_full, [45000, 5000], generator=split_generator)
# Same validation indices, served through the evaluation transform.
val_dataset = Subset(train_full_eval, val_split.indices)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=8)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=8)

## Model setup

In [None]:
class SwinTransformerClf(L.LightningModule):
    """Lightning module that trains an image classifier with AdamW + OneCycleLR.

    Args:
        model: either a ``SwinTransformerConfig2D`` (a ``SwinModel`` classifier is
            built from it) or an already-constructed ``nn.Module`` that maps an
            image batch to class logits.
        name: run identifier, stored as ``self.name`` and saved as a hyperparameter.
        max_epochs: number of epochs the OneCycle schedule spans.
        steps_per_epoch: optimizer steps per epoch used by the OneCycle schedule.
            ``max_epochs * steps_per_epoch`` must cover the steps actually taken,
            otherwise OneCycleLR raises when stepped past its total.
        learning_rate: peak learning rate (also AdamW's initial lr).
        weight_decay: AdamW weight decay.
    """

    def __init__(self, model, name, max_epochs=100, steps_per_epoch=300, learning_rate=3e-4, weight_decay=0.05):
        super().__init__()
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.max_epochs = max_epochs
        self.steps_per_epoch = steps_per_epoch
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.name = name
        if isinstance(model, SwinTransformerConfig2D):
            self.config = model
            # SwinTransformer2D returns per-stage features (SwinModel indexes
            # ``out[-1]``), not logits, so attach SwinModel's norm/pool/linear
            # head to get something CrossEntropyLoss can consume.
            self.model = SwinModel(model)
        else:
            self.model = model
        # The module itself is not serialisable as a hyperparameter; it must be
        # passed again when loading from a checkpoint.
        self.save_hyperparameters(ignore="model")

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch, stage, on_step):
        """Compute loss and top-1 accuracy on *batch* and log them as ``{stage}_loss``/``{stage}_acc``."""
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        acc = (y == y_hat.argmax(dim=1)).float().mean()
        self.log_dict(
            {f"{stage}_loss": loss, f"{stage}_acc": acc},
            on_step=on_step,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        return loss

    def training_step(self, batch, _):
        # Training metrics are logged per step and aggregated per epoch.
        return self._shared_step(batch, "train", on_step=True)

    def validation_step(self, batch, _):
        return self._shared_step(batch, "val", on_step=False)

    def test_step(self, batch, _):
        return self._shared_step(batch, "test", on_step=False)

    def configure_optimizers(self):
        """Return AdamW with a per-step OneCycleLR schedule peaking at ``learning_rate``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=self.learning_rate, epochs=self.max_epochs, steps_per_epoch=self.steps_per_epoch
        )
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}}

In [None]:
class SwinModel(nn.Module):
    """10-class classifier head on top of a ``SwinTransformer2D`` backbone.

    The last element of the backbone's output is layer-normalised over its
    final (channel) dimension, average-pooled over dimension 1, flattened and
    projected to 10 logits.
    """

    def __init__(self, config):
        super().__init__()
        self.swin = SwinTransformer2D(config)
        width = self.swin.out_channels[-1]
        self.norm = nn.LayerNorm(width)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.flatten = nn.Flatten(1)
        self.fc = nn.Linear(width, 10)
        self.config = config

    def forward(self, x):
        # AdaptiveAvgPool1d after transpose(1, 2) implies a 3-D last-stage output,
        # presumably (batch, tokens, channels) — TODO confirm against SwinTransformer2D.
        features = self.norm(self.swin(x)[-1])
        pooled = self.avgpool(features.transpose(1, 2))
        return self.fc(self.flatten(pooled))

In [None]:
def _cifar_swin_config(**overrides):
    """Build a ``SwinTransformerConfig2D`` for 32x32 RGB input from shared defaults.

    Keyword arguments override individual defaults (e.g. ``rpe_mode``,
    ``patch_mode``, ``block_window_size``). Fresh list objects are created on
    every call, so configs never share mutable state.
    """
    params = dict(
        input_size=(32, 32),
        in_channels=3,
        embed_dim=32,
        num_blocks=[2, 4, 4, 2],
        patch_window_size=[(2, 2)] * 4,
        block_window_size=([(4, 4)] * 3) + [(2, 2)],
        num_heads=[2, 4, 8, 16],
        drop_path=0.1,
    )
    params.update(overrides)
    return SwinTransformerConfig2D(**params)


# Relative-positional-embedding ablations.
model_embedding_none = SwinModel(_cifar_swin_config(rpe_mode=RelativePositionalEmeddingMode.NONE))
model_embedding_bias = SwinModel(_cifar_swin_config(rpe_mode=RelativePositionalEmeddingMode.BIAS))
model_embedding_context = SwinModel(_cifar_swin_config(rpe_mode=RelativePositionalEmeddingMode.CONTEXT))

# Patch-merging ablations. These previously shared one variable name, so the
# concatenate variant was silently discarded and the convolution model was
# registered (and trained) under both keys.
model_merge_concatenate = SwinModel(_cifar_swin_config(patch_mode=[PatchMode.CONCATENATE] * 4))
model_merge_convolution = SwinModel(_cifar_swin_config(patch_mode=[PatchMode.CONVOLUTION] * 4))

# Odd (3x3) attention windows in every stage.
model_odd_windows = SwinModel(_cifar_swin_config(block_window_size=[(3, 3)] * 4))

# NOTE(review): the reference model is not given a drop-path rate here, unlike the
# variants above (drop_path=0.1) — confirm the comparison is intended.
model_reference = SwinTransformerReference2D(
    num_classes=10,
    img_size=32,
    in_chans=3,
    embed_dim=32,
    depths=[2, 4, 4, 2],
    patch_size=2,
    window_size=4,
    num_heads=[2, 4, 8, 16],
)

# Run name -> architecture; names become logger/checkpoint names downstream.
models = {
    "odd_windows": model_odd_windows,
    "embedding_none": model_embedding_none,
    "embedding_bias": model_embedding_bias,
    "embedding_context": model_embedding_context,
    "merge_concatenate": model_merge_concatenate,
    "merge_convolution": model_merge_convolution,
    "reference": model_reference,
}

In [None]:
# Wrap every architecture in the same Lightning module and optimisation settings.
model_list = [
    SwinTransformerClf(
        model=net,
        name=run_name,
        max_epochs=num_epochs,
        steps_per_epoch=len(train_loader),
        learning_rate=learning_rate,
        weight_decay=weight_decay,
    )
    for run_name, net in models.items()
]

# Layer/parameter summary for one full batch of 3x32x32 images.
for clf in model_list:
    print(summary(clf, input_size=(batch_size, 3, 32, 32)))

## Model training

In [None]:
for model in model_list:
    # Release memory held over from the previous run before starting the next.
    gc.collect()
    torch.cuda.empty_cache()

    # Same seed for every run so the variants are compared under equal randomness.
    L.seed_everything(42)

    csv_logger = loggers.CSVLogger(f"{BASE_DIR}/logs", name=model.name)
    # The OneCycle scheduler steps once per batch, so record the LR per step too.
    learning_rate_monitor = callbacks.LearningRateMonitor(logging_interval="step")
    # Keep only the checkpoint with the highest validation accuracy.
    model_checkpoint = callbacks.ModelCheckpoint(
        monitor="val_acc",
        dirpath=f"{BASE_DIR}/checkpoints",
        filename=model.name,
        save_top_k=1,
        mode="max",
    )

    trainer = L.Trainer(
        max_epochs=num_epochs,
        logger=csv_logger,
        callbacks=[learning_rate_monitor, model_checkpoint],
        gradient_clip_val=1.0,
        precision="16-mixed",
    )
    trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    # Evaluate the best-val_acc checkpoint (restored into `model`) rather than
    # whatever weights the final epoch ended on.
    trainer.test(model=model, dataloaders=test_loader, ckpt_path="best")

## Analysis

In [None]:
# Output locations written by the training cell (ModelCheckpoint / CSVLogger).
ckpt_path, log_path = f"{BASE_DIR}/checkpoints", f"{BASE_DIR}/logs"

In [None]:
# One sub-directory per run name; sorted so plot order and colours are stable
# (os.listdir order is arbitrary). Stray files such as .DS_Store are ignored.
model_logs = sorted(
    p for p in (os.path.join(log_path, f) for f in os.listdir(log_path)) if os.path.isdir(p)
)
# Version directories (version_0, version_1, ...) inside each run directory.
model_versions = [
    [p for p in (os.path.join(f, v) for v in os.listdir(f)) if os.path.isdir(p)]
    for f in model_logs
]
# Most recently created version per run. Runs with no versions, or whose latest
# version lacks metrics/hparams (e.g. an interrupted run), are skipped.
model_versions_latest = [
    latest
    for latest in (max(v, key=os.path.getctime) for v in model_versions if v)
    if os.path.isfile(os.path.join(latest, "metrics.csv"))
    and os.path.isfile(os.path.join(latest, "hparams.yaml"))
]
model_csvs = [os.path.join(v, "metrics.csv") for v in model_versions_latest]
model_hparams = [os.path.join(v, "hparams.yaml") for v in model_versions_latest]

In [None]:
import copy
import re

ckpt_files = sorted(f for f in os.listdir(ckpt_path) if f.endswith(".ckpt"))

# Checkpoint file stem -> restored SwinTransformerClf.
loaded_models = {}
for ckpt in ckpt_files:
    model_name = os.path.splitext(ckpt)[0]
    # ModelCheckpoint appends "-v1", "-v2", ... when the target file already exists.
    base_name = re.sub(r"-v\d+$", "", model_name)
    if base_name not in models:
        print(f"Skipping {ckpt}: no architecture registered under '{base_name}'")
        continue
    # `model` is excluded from the saved hyperparameters, so the architecture must be
    # supplied by keyword (the second positional argument is `map_location`).
    # deepcopy leaves the in-memory instance in `models` untouched.
    model = SwinTransformerClf.load_from_checkpoint(
        os.path.join(ckpt_path, ckpt), model=copy.deepcopy(models[base_name])
    )
    loaded_models[model_name] = model

In [None]:


def _logged_rows(frame, column):
    """Return the rows of *frame* where *column* was logged; empty if the column is absent."""
    if column not in frame.columns:
        return frame.iloc[0:0]
    return frame[frame[column].notna()]


fig = plt.figure(constrained_layout=True, figsize=(12, 12), dpi=100)
fig.suptitle("Training and Validation Metrics over Epochs, and Test Accuracy")

gs = fig.add_gridspec(2, 2)
ax_acc = fig.add_subplot(gs[0, 0])
ax_loss = fig.add_subplot(gs[0, 1])
ax_test = fig.add_subplot(gs[1, :])

# BASE_COLORS without white, which would be invisible; cycled if there are more runs.
palette = [c for c in mcolors.BASE_COLORS if c != "w"]
test_names = []
test_accs = []

for i, (csv_file, hparams_file) in enumerate(zip(model_csvs, model_hparams)):
    with open(hparams_file, "r") as file:
        hparams = yaml.safe_load(file)

    print("Hyperparameters:")
    for key, value in hparams.items():
        print(f"   {key}: {value}")

    name = hparams["name"]
    col = palette[i % len(palette)]
    df = pd.read_csv(csv_file)
    df["epoch"] = df["step"] // hparams["steps_per_epoch"]

    # Dashed = training (epoch aggregate), solid = validation, same colour per model.
    for ax, metric, label in ((ax_acc, "acc", "Accuracy"), (ax_loss, "loss", "Loss")):
        train_col, val_col = f"train_{metric}_epoch", f"val_{metric}"
        train_data = _logged_rows(df, train_col)
        val_data = _logged_rows(df, val_col)
        if not train_data.empty:
            ax.plot(
                train_data["epoch"],
                train_data[train_col],
                label=f"Training {label} {name}",
                color=col,
                linestyle="--",
            )
        if not val_data.empty:
            ax.plot(
                val_data["epoch"],
                val_data[val_col],
                label=f"Validation {label} {name}",
                color=col,
                linestyle="-",
            )

    test_data = _logged_rows(df, "test_acc")
    if test_data.empty:
        continue  # this run was never tested; no bar to draw
    test_acc = test_data["test_acc"].max()
    row = len(test_names)
    ax_test.barh(row, test_acc, color=col, label=f"Test Accuracy {name}")
    ax_test.text(test_acc + 0.01, row, f"{test_acc:.2f}", va="center", ha="left")
    test_names.append(name)
    test_accs.append(test_acc)

for ax, label in ((ax_acc, "Accuracy"), (ax_loss, "Loss")):
    ax.set_xlabel("Epoch")
    ax.set_ylabel(label)
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    ax.grid(True, linestyle="--", alpha=0.7)

# Horizontal bars: accuracy on x, one row per model on y.
ax_test.set_xlabel("Test Accuracy")
ax_test.set_ylabel("Model")
ax_test.set_yticks(range(len(test_names)))
ax_test.set_yticklabels(test_names)
ax_test.grid(True, linestyle="--", alpha=0.7)
if test_accs:
    # Single reference line at the best test accuracy across all models.
    ax_test.axvline(max(test_accs), color="black", linestyle="--")
    ax_test.legend()

plt.show()
plt.close(fig)

In [None]:

for m in model_list:
    print(m.name)
    # The reference implementation has no `.swin` backbone; skip models that
    # don't expose recorded attention weights instead of crashing.
    swin = getattr(m.model, "swin", None)
    if swin is None:
        print("   skipped: model has no `swin` backbone")
        continue
    weights = getattr(swin.stages[-1].blocks[-1].attn, "attn_weights", None)
    if weights is None:
        print("   skipped: no attention weights recorded")
        continue

    # Mean over dim 1 (presumably the head axis — TODO confirm); float() because
    # 16-mixed precision can leave half-precision tensors that imshow may reject.
    attn_weights = weights.mean(dim=1).detach().float().cpu().numpy()
    count = attn_weights.shape[0]
    if count == 0:
        continue
    # ceil so the grid has room for every map when count is not a perfect square.
    split = math.ceil(math.sqrt(count))

    fig = plt.figure(constrained_layout=True, figsize=(12, 12), dpi=100)
    fig.suptitle(f"Attention Weights ({m.name})")

    gs = fig.add_gridspec(split, split)

    for i in range(count):
        ax = fig.add_subplot(gs[i // split, i % split])
        ax.imshow(attn_weights[i], cmap="hot", interpolation="nearest")
        ax.axis("off")

    plt.show()
    plt.close(fig)