# CIFAR10

In [None]:
from torch.utils.data import DataLoader, random_split
from torchinfo import summary
from torchvision import datasets, transforms
import gc
import lightning as L
import lightning.pytorch.callbacks as callbacks
import lightning.pytorch.loggers as loggers
import math
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as nnf

from model import SwinTransformer2D, SwinTransformerConfig2D
from notebooks.utils.configs import MODEL_CONFIGS
from notebooks.utils.reference import SwinTransformerReference2D

# Start from a clean slate: collect unreachable Python objects first (so any
# tensors they still reference are released), then ask PyTorch to release the
# GPU memory it is holding in its cache.
gc.collect()
torch.cuda.empty_cache()

## Directories

In [None]:
# Root folder for everything this notebook reads or writes.
BASE_DIR = "./notebooks/cifar10"

# Dataset download location, CSV training logs and model checkpoints all live
# directly below the root folder.
DATA_DIR, LOGS_DIR, CKPT_DIR = (os.path.join(BASE_DIR, leaf) for leaf in ("data", "logs", "ckpts"))

# Create the folders up front; already existing ones are left untouched.
for _path in (BASE_DIR, DATA_DIR, LOGS_DIR, CKPT_DIR):
    os.makedirs(_path, exist_ok=True)

## Hyperparameters

In [None]:
# Samples per batch for the train/val/test DataLoaders (also used as the batch
# dimension of the dummy input when printing model summaries).
batch_size = 128
# Number of training epochs; passed to the Trainer and to the OneCycleLR
# schedule so that both agree on the total number of steps.
num_epochs = 100
# Learning rate given to AdamW and used as the peak (max_lr) of OneCycleLR.
learning_rate = 3e-4
# Weight decay coefficient given to AdamW.
weight_decay = 0.05

## Dataset setup

In [None]:
%%script echo "SKIP"
torch.manual_seed(0)

train_dataset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=transforms.ToTensor())
test_dataset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=transforms.ToTensor())
train_dataset, val_dataset = random_split(train_dataset, [45000, 5000])

train_data = torch.stack([img_t for img_t, _ in train_dataset], dim=3)
val_data = torch.stack([img_t for img_t, _ in val_dataset], dim=3)
test_data = torch.stack([img_t for img_t, _ in test_dataset], dim=3)

train_mean, train_std = train_data.view(3, -1).mean(dim=1), train_data.view(3, -1).std(dim=1)
val_mean, val_std = val_data.view(3, -1).mean(dim=1), val_data.view(3, -1).std(dim=1)
test_mean, test_std = test_data.view(3, -1).mean(dim=1), test_data.view(3, -1).std(dim=1)

print(f"Train mean: {train_mean}, std: {train_std}")
print(f"Val mean: {val_mean}, std: {val_std}")
print(f"Test mean: {test_mean}, std: {test_std}")

In [None]:
# Human-readable class names, indexed by the integer label that the CIFAR10
# dataset returns for a sample (used for plot titles in the analysis section).
LABELS = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

In [None]:
torch.manual_seed(0)

# The 50k official training images are split into 45k train / 5k validation.
# The train and validation subsets are taken from two dataset objects (they
# need different transforms), so the split is performed twice. Both calls MUST
# see the same permutation, otherwise the "validation" indices are not the
# complement of the training indices and validation images leak into training.
# A freshly seeded, explicit generator per call guarantees identical splits;
# relying on the global RNG does not, because the first call advances it.
split_lengths = [45000, 5000]
split_seed = 0

# NOTE(review): the per-split statistics below are hard-coded results of the
# skipped statistics cell; re-run that cell if the split or seed changes.
train_mean = (0.4912, 0.4820, 0.4464)
train_std = (0.2472, 0.2436, 0.2617)

# Training pipeline: normalise, then apply light geometric augmentation.
transform_train = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(train_mean, train_std),
        transforms.RandomPerspective(distortion_scale=0.1, p=0.5),
        transforms.RandomRotation(15),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
)

train_dataset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=transform_train)
train_dataset, _ = random_split(
    train_dataset, split_lengths, generator=torch.Generator().manual_seed(split_seed)
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)


val_mean = (0.4932, 0.4836, 0.4474)
val_std = (0.2459, 0.2422, 0.2608)

# Validation pipeline: no augmentation.
# NOTE(review): validation/test data are normalised with their own statistics
# rather than the training statistics; the values are close, but confirm this
# is intended.
transform_val = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(val_mean, val_std),
    ]
)
val_dataset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=transform_val)
_, val_dataset = random_split(
    val_dataset, split_lengths, generator=torch.Generator().manual_seed(split_seed)
)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=8)


test_mean = (0.4942, 0.4851, 0.4504)
test_std = (0.2467, 0.2429, 0.2616)

# Test pipeline: official CIFAR-10 test split, no augmentation.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(test_mean, test_std),
    ]
)

test_dataset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=transform_test)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=8)

## Model setup

In [None]:
class SwinModel(nn.Module):
    """Swin Transformer backbone with a 10-class classification head.

    The head takes the last element of the backbone's output sequence,
    layer-normalises it over its last dimension, averages over dimension 1
    and maps the result to 10 class scores with a linear layer.
    """

    def __init__(self, config):
        """Build the backbone from ``config`` and size the head to match it."""
        super().__init__()
        self.swin = SwinTransformer2D(config)
        # Width of the final backbone stage; both the norm and the linear
        # classifier operate on features of this size.
        feature_dim = self.swin.out_channels[-1]
        self.norm = nn.LayerNorm(feature_dim)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.flatten = nn.Flatten(1)
        self.fc = nn.Linear(feature_dim, 10)
        self.config = config

    def forward(self, x):
        """Return unnormalised class scores for the input batch ``x``."""
        # Only the last output of the backbone feeds the classifier.
        # Assumes it is laid out as (batch, tokens, channels) - TODO confirm.
        features = self.swin(x)[-1]
        normed = self.norm(features)
        # AdaptiveAvgPool1d pools the last dimension, so swap dims 1 and 2 to
        # average over what was dimension 1.
        pooled = self.avgpool(normed.transpose(1, 2))
        return self.fc(self.flatten(pooled))

In [None]:
class SwinTransformerClf(L.LightningModule):
    """Lightning wrapper that trains a Swin classifier with AdamW + OneCycleLR.

    Args:
        name: Identifier of the model; used by the notebook for log and
            checkpoint names.
        config: ``None`` (identity model), a ``SwinTransformerConfig2D``, or a
            dict of keyword arguments. A dict is first tried as a
            ``SwinTransformerConfig2D``; if that fails it is passed to
            ``SwinTransformerReference2D`` instead.
        max_epochs: Number of epochs the OneCycleLR schedule is laid out for.
        steps_per_epoch: Optimizer steps per epoch for the OneCycleLR schedule.
        learning_rate: AdamW learning rate and OneCycleLR peak learning rate.
        weight_decay: AdamW weight decay.
    """

    def __init__(
        self, name="default", config=None, max_epochs=100, steps_per_epoch=300, learning_rate=3e-4, weight_decay=0.05
    ):
        super().__init__()
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.max_epochs = max_epochs
        self.steps_per_epoch = steps_per_epoch
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.name = name
        self.config = config

        if config is None:
            self.model = nn.Identity()
        else:
            try:
                config = SwinTransformerConfig2D(**config) if isinstance(config, dict) else config
                self.model = SwinModel(config)
            except Exception:
                # Deliberate best-effort fallback to the reference model. Use
                # the configuration exactly as it was passed in: the local
                # ``config`` may already have been converted to a config
                # object above, which cannot be unpacked as keyword arguments.
                self.model = SwinTransformerReference2D(**self.config)

        self.save_hyperparameters(ignore="model")

    def forward(self, x):
        """Return class probabilities (softmax over dim 1) for the batch ``x``."""
        return nnf.softmax(self.model(x), dim=1)

    def _shared_step(self, batch, stage, on_step):
        """Compute loss and accuracy for one batch and log them as ``{stage}_*``.

        The loss is computed on the raw model outputs (logits):
        ``CrossEntropyLoss`` applies log-softmax internally, so feeding it the
        probabilities returned by ``forward`` would apply softmax twice and
        squash the gradients.
        """
        x, y = batch
        logits = self.model(x)
        loss = self.criterion(logits, y)
        # argmax of the logits equals argmax of the softmax probabilities.
        acc = (y == logits.argmax(dim=1)).float().mean()
        self.log_dict(
            {f"{stage}_loss": loss, f"{stage}_acc": acc},
            on_step=on_step,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        return loss

    def training_step(self, batch, _):
        """Log ``train_loss``/``train_acc`` per step and per epoch; return the loss."""
        return self._shared_step(batch, "train", on_step=True)

    def validation_step(self, batch, _):
        """Log epoch-level ``val_loss``/``val_acc``; return the loss."""
        return self._shared_step(batch, "val", on_step=False)

    def test_step(self, batch, _):
        """Log epoch-level ``test_loss``/``test_acc``; return the loss."""
        return self._shared_step(batch, "test", on_step=False)

    def configure_optimizers(self):
        """Return AdamW with a OneCycleLR schedule that is stepped every batch."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=self.learning_rate, epochs=self.max_epochs, steps_per_epoch=self.steps_per_epoch
        )
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}}

In [None]:
# One classifier per named configuration. The OneCycleLR schedule needs the
# number of optimizer steps per epoch, which is the length of the train loader.
models = {}
for model_name, model_config in MODEL_CONFIGS.items():
    models[model_name] = SwinTransformerClf(
        model_name, model_config, num_epochs, len(train_loader), learning_rate, weight_decay
    )

# Print a layer-by-layer summary of every model for a CIFAR-sized input batch.
for classifier in models.values():
    print(summary(classifier, (batch_size, 3, 32, 32)))

## Model training

In [None]:
# Train and test every model in turn, each with its own logger and checkpoint.
separator = "=" * 80

for name, clf in models.items():
    print(separator, flush=True)
    print(f"Training model {name}", flush=True)
    print()

    # Release memory left over from the previous model before starting.
    gc.collect()
    torch.cuda.empty_cache()

    # Same seed for every model so the runs are comparable.
    L.seed_everything(42)

    trainer = L.Trainer(
        max_epochs=num_epochs,
        # Metrics go to <LOGS_DIR>/<model name>/... as CSV.
        logger=loggers.CSVLogger(LOGS_DIR, name=clf.name),
        callbacks=[
            callbacks.LearningRateMonitor(logging_interval="epoch"),
            # Keep only the checkpoint with the highest validation accuracy.
            callbacks.ModelCheckpoint(
                monitor="val_acc",
                dirpath=CKPT_DIR,
                filename=clf.name,
                save_top_k=1,
                mode="max",
            ),
        ],
        gradient_clip_val=1.0,
        precision="16-mixed",
    )

    print(separator, flush=True)

    trainer.fit(model=clf, train_dataloaders=train_loader, val_dataloaders=val_loader)
    trainer.test(model=clf, dataloaders=test_loader)

## Analysis

In [None]:
# Two views of the CIFAR-10 test split: `ref_dataset` holds the plain tensors
# (for display), `tra_dataset` the normalised tensors (for model input).
_test_split = dict(root=DATA_DIR, train=False, download=True)
ref_dataset = datasets.CIFAR10(transform=transforms.ToTensor(), **_test_split)
tra_dataset = datasets.CIFAR10(transform=transform_test, **_test_split)

In [None]:
def _latest_metrics_csv(model_log_dir):
    """Return the metrics.csv of the most recently created run below *model_log_dir*.

    Returns None when the folder contains no run directory or when the latest
    run has no metrics.csv.
    """
    runs = [os.path.join(model_log_dir, run) for run in os.listdir(model_log_dir)]
    runs = [run for run in runs if os.path.isdir(run)]
    if not runs:
        return None
    metrics_csv = os.path.join(max(runs, key=os.path.getctime), "metrics.csv")
    return metrics_csv if os.path.isfile(metrics_csv) else None


# Map each logger folder name (the model name used during training) to the
# metrics file of its latest run. Stray files in LOGS_DIR are ignored.
logs_by_name = {}
for entry in sorted(os.listdir(LOGS_DIR)):
    entry_path = os.path.join(LOGS_DIR, entry)
    if not os.path.isdir(entry_path):
        continue
    metrics_csv = _latest_metrics_csv(entry_path)
    if metrics_csv is not None:
        logs_by_name[entry] = metrics_csv

ckpt_files = [f for f in os.listdir(CKPT_DIR) if f.endswith(".ckpt")]

# Best-effort loading: a checkpoint that cannot be restored is reported and
# skipped so that the remaining models can still be analysed.
loaded = {}
for ckpt in ckpt_files:
    try:
        model_name = os.path.splitext(ckpt)[0]
        loaded[model_name] = SwinTransformerClf.load_from_checkpoint(os.path.join(CKPT_DIR, ckpt))
    except Exception as e:
        print(e)

# `model_logs` and `models` are later paired by position, so they must be
# aligned by model name rather than sorted independently (path order and name
# order differ, and either side may have entries the other lacks). Models that
# have a log come first, in name order, with `model_logs` in the same order;
# models without a log are kept at the end so they remain available for the
# other analyses.
by_name = sorted(loaded.items(), key=lambda item: item[1].name)
with_logs = [(key, clf) for key, clf in by_name if clf.name in logs_by_name]
without_logs = [(key, clf) for key, clf in by_name if clf.name not in logs_by_name]
for key, clf in without_logs:
    print(f"No metrics log found for model {clf.name!r} (checkpoint {key!r})")

models = dict(with_logs + without_logs)
model_logs = [logs_by_name[clf.name] for _, clf in with_logs]

model_logs, [m.name for m in models.values()]

In [None]:
def _plot_curve(ax, df, column, label, color, linestyle):
    """Plot *column* against the epoch for all rows where it was logged."""
    if column not in df.columns:
        return
    rows = df[df[column].notna()]
    ax.plot(rows["epoch"], rows[column], label=label, color=color, linestyle=linestyle)


fig = plt.figure(constrained_layout=True, figsize=(12, 12), dpi=100)
fig.suptitle("Training, Validation, and Test Accuracy over Epochs")

gs = fig.add_gridspec(2, 2)
ax_acc = fig.add_subplot(gs[0, 0])
ax_loss = fig.add_subplot(gs[0, 1])
ax_test = fig.add_subplot(gs[1, :])

# One colour per model. White is dropped because it is invisible on the
# default background, and the palette is cycled so that any number of models
# can be plotted.
palette = [c for c in mcolors.BASE_COLORS if c != "w"]

max_acc = 0
test_rows = []  # (bar position, model name) for every model with a test result
for i, (logs, clf) in enumerate(zip(model_logs, models.values())):
    df = pd.read_csv(logs)
    col = palette[i % len(palette)]

    # Dashed = training, solid = validation, same colour per model.
    _plot_curve(ax_acc, df, "train_acc_epoch", f"Training Accuracy {clf.name}", col, "--")
    _plot_curve(ax_acc, df, "val_acc", f"Validation Accuracy {clf.name}", col, "-")
    _plot_curve(ax_loss, df, "train_loss_epoch", f"Training Loss {clf.name}", col, "--")
    _plot_curve(ax_loss, df, "val_loss", f"Validation Loss {clf.name}", col, "-")

    if "test_acc" in df.columns and df["test_acc"].notna().any():
        # A log may contain several test rows; show the best one as one bar.
        test_acc = df["test_acc"].max()
        ax_test.barh(i, test_acc, color=col, label=f"Test Accuracy {clf.name}")
        test_rows.append((i, clf.name))
        max_acc = max(max_acc, test_acc)

for ax, ylabel in ((ax_acc, "Accuracy"), (ax_loss, "Loss")):
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    # Only draw a legend when something was plotted (avoids a warning).
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    ax.grid(True, linestyle="--", alpha=0.7)

if test_rows:
    # Horizontal bars: accuracy runs along x, one row per model along y.
    ax_test.set_xlabel("Accuracy")
    ax_test.set_ylabel("Model")
    ax_test.set_yticks([row for row, _ in test_rows])
    ax_test.set_yticklabels([model_name for _, model_name in test_rows])
    ax_test.legend()
    ax_test.grid(True, linestyle="--", alpha=0.7)

    # Mark the best test accuracy across all models.
    ax_test.axvline(max_acc, color="black", linestyle="--")
    ax_test.text(max_acc + 0.005, test_rows[-1][0], f"{max_acc:.2f}", va="center", ha="left")

plt.show()
plt.close()

In [None]:
# Index of the test sample to inspect.
idx = 0

# Un-normalised image for display; channels are moved last for imshow.
image, label = ref_dataset[idx]
image = image.permute(1, 2, 0)

plt.imshow(image)
plt.title(LABELS[label])
plt.axis("off")
plt.show()
plt.close()

# Run the normalised version of the same sample through every model.
# The prediction itself is not used here.
# NOTE(review): presumably the forward pass populates the `attn_weights` that
# the attention visualisation reads afterwards - confirm against the model.
image_model = tra_dataset[idx][0]
for clf in models.values():
    clf.eval()
    with torch.no_grad():
        # Send the input to whatever device the model lives on instead of
        # assuming CUDA, so the cell also works for CPU-loaded checkpoints.
        y_hat = clf(image_model.unsqueeze(0).to(clf.device))

In [None]:
# Visualise the attention weights of the first block of the first stage.
for name, clf in models.items():
    if not isinstance(clf.model, SwinModel):
        continue
    print(name)

    attn_weights = getattr(clf.model.swin.stages[0].blocks[0].attn, "attn_weights", None)
    if attn_weights is None:
        print("No attention weights available; run a forward pass first.")
        continue

    # Sum over dim 1 and show one heat map per entry along dim 0.
    # NOTE(review): assumes attn_weights is (windows, heads, N, N) - confirm.
    attn_maps = attn_weights.sum(dim=1)
    # Count the dimension that is actually indexed below, so every entry is
    # shown and no index can run past the end.
    count = attn_maps.shape[0]
    if count == 0:
        continue

    # Smallest near-square grid that holds all maps. For a perfect square this
    # is sqrt x sqrt; otherwise the grid grows instead of overflowing.
    cols = math.ceil(math.sqrt(count))
    rows = math.ceil(count / cols)

    fig = plt.figure(constrained_layout=True, figsize=(12, 12), dpi=100)
    fig.suptitle("Attention Weights")

    gs = fig.add_gridspec(rows, cols)
    for i in range(count):
        ax = fig.add_subplot(gs[i // cols, i % cols])
        ax.set_title(f"Window {i}")
        ax.imshow(attn_maps[i].detach().cpu().numpy(), cmap="hot", interpolation="nearest")
        ax.axis("off")

    plt.show()
    # Release the figure so figures do not accumulate across models.
    plt.close()

In [None]:
# Visualise the relative positional embedding table(s) of the first block of
# the first stage.
for name, clf in models.items():
    if not isinstance(clf.model, SwinModel):
        continue
    print(name, flush=True)

    attn = clf.model.swin.stages[0].blocks[0].attn
    if attn.bias_mode:
        # Single table, drawn transposed (relative position along x).
        tables = [attn.embedding_table]
        factor = 2
        transposed = True
    elif attn.context_mode:
        # Separate query and key tables, drawn one above the other.
        tables = [attn.embedding_table_q, attn.embedding_table_k]
        factor = 10
        transposed = False
    else:
        # Neither mode is active, so there is nothing to plot.
        continue

    # Figure aspect ratio is derived from the table that is actually plotted
    # (the first one). Previously it always read `rel_embed_table`, which is
    # not defined in context mode.
    num_positions, embed_dim = tables[0].weight.shape[0], tables[0].weight.shape[1]
    ratio = num_positions / embed_dim

    fig = plt.figure(constrained_layout=True, figsize=(ratio * factor, factor), dpi=100)
    fig.suptitle("Relative Positional Embeddings")
    gs = fig.add_gridspec(len(tables), 1)

    for row, table in enumerate(tables):
        weights = table.weight.detach().cpu().numpy()
        ax = fig.add_subplot(gs[row, 0])
        ax.set_xticks([])
        ax.set_yticks([])
        if transposed:
            ax.imshow(weights.T, cmap="hot", interpolation="nearest")
            ax.set_ylabel("Embedding Dimension")
            ax.set_xlabel("Relative Position")
        else:
            ax.imshow(weights, cmap="hot", interpolation="nearest")
            ax.set_xlabel("Embedding Dimension")
            ax.set_ylabel("Relative Position")

    plt.show()
    plt.close()
