In [None]:
import os
import math
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision as tv
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.decomposition import PCA
from sklearn.metrics import precision_score, recall_score, f1_score

# Authenticate with Weights & Biases up front; run_experiment() below starts a
# wandb run per configuration (wandb.init / wandb.log / wandb.save).
wandb.login()

In [2]:
class GLUActivation(nn.Module):
    """Exact (erf-based) GELU non-linearity.

    Despite the class name this is not a gated linear unit: it computes
    GELU(x) = x * Phi(x), where Phi is the standard normal CDF expressed
    through the error function. The name is kept so existing code still works.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
        normal_cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
        return x * normal_cdf

In [3]:
class multiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with a fused Q/K/V projection.

    Args:
        hidden_size: model width; must be divisible by ``num_attention_heads``.
        num_attention_heads: number of attention heads.
        dropout_rate: dropout probability applied to the attention
            probabilities and to the output projection.
        qkv_bias: whether the fused Q/K/V projection has a bias term.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_attention_heads``.
    """

    def __init__(self, hidden_size, num_attention_heads, dropout_rate, qkv_bias=True):
        super().__init__()
        # Fail fast: otherwise the per-head reshape in forward() breaks with an
        # opaque shape error.
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_attention_heads ({num_attention_heads})"
            )
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = hidden_size

        # Output features are laid out as [Q | K | V], each all_head_size wide.
        self.qkv = nn.Linear(hidden_size, self.all_head_size * 3, bias=qkv_bias)
        self.attn_dropout = nn.Dropout(dropout_rate)

        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.proj_dropout = nn.Dropout(dropout_rate)

    def _split_heads(self, x):
        """Reshape [B, L, all_head_size] -> [B, heads, L, head_size]."""
        B, L, _ = x.shape
        return x.reshape(B, L, self.num_attention_heads, self.attention_head_size).transpose(1, 2)

    def forward(self, query, key, value):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: [B, N, hidden_size].
            key, value: [B, M, hidden_size]. Passing the same tensor object as
                ``query`` for both selects self-attention with one fused matmul.

        Returns:
            (out, attn_probs): ``out`` is [B, N, hidden_size]; ``attn_probs``
            is [B, heads, N, M] (after attention dropout).
        """
        B, N, _ = query.shape
        if key is query and value is query:
            # Self-attention: a single fused projection yields Q, K and V.
            qkv = (self.qkv(query)
                   .reshape(B, N, 3, self.num_attention_heads, self.attention_head_size)
                   .permute(2, 0, 3, 1, 4))
            q, k, v = qkv[0], qkv[1], qkv[2]
        else:
            # Distinct inputs: project each one with its own slice of the fused
            # weight so key/value actually contribute to the result.
            w_q, w_k, w_v = self.qkv.weight.chunk(3, dim=0)
            if self.qkv.bias is not None:
                b_q, b_k, b_v = self.qkv.bias.chunk(3, dim=0)
            else:
                b_q = b_k = b_v = None
            q = self._split_heads(F.linear(query, w_q, b_q))
            k = self._split_heads(F.linear(key, w_k, b_k))
            v = self._split_heads(F.linear(value, w_v, b_v))

        attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.attention_head_size)
        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_probs = self.attn_dropout(attn_probs)

        context = (attn_probs @ v).transpose(1, 2).reshape(B, N, self.all_head_size)
        out = self.out_proj(context)
        out = self.proj_dropout(out)
        return out, attn_probs

In [4]:
class MLP(nn.Module):
    """Position-wise feed-forward sub-layer of an encoder block.

    Linear(hidden -> intermediate) -> GLUActivation -> Dropout
    -> Linear(intermediate -> hidden) -> Dropout.
    """

    def __init__(self, hidden_size, intermediate_size, dropout_rate):
        super().__init__()
        # Creation order and attribute names are significant: they fix the
        # parameter-initialisation order and the state_dict keys.
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.activation = GLUActivation()
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x):
        expanded = self.dropout1(self.activation(self.fc1(x)))
        return self.dropout2(self.fc2(expanded))

In [5]:
class TransformerEncoderBlock(nn.Module):
    """Pre-norm Transformer encoder layer.

    x -> x + Attn(LN1(x)) -> x + MLP(LN2(x)): each sub-layer normalises its
    input and is wrapped in a residual connection.
    """

    def __init__(self, hidden_size, num_attention_heads, intermediate_size, dropout_rate, qkv_bias=True):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, eps=1e-6)
        self.attn = multiHeadAttention(
            hidden_size, num_attention_heads, dropout_rate, qkv_bias=qkv_bias
        )
        self.norm2 = nn.LayerNorm(hidden_size, eps=1e-6)
        self.mlp = MLP(hidden_size, intermediate_size, dropout_rate)

    def forward(self, x):
        # Self-attention sub-layer: the same normalised tensor is query, key
        # and value; the returned attention weights are not needed here.
        normed = self.norm1(x)
        x = x + self.attn(normed, normed, normed)[0]
        # Feed-forward sub-layer.
        return x + self.mlp(self.norm2(x))

In [6]:
class VisionTransformer(nn.Module):
    """Minimal ViT classifier: patch embedding + [CLS] token + pre-norm encoder.

    Args:
        image_size: side length of the square input; must be a multiple of
            ``patch_size``.
        patch_size: side length of each non-overlapping patch.
        num_channels: number of input channels.
        num_classes: number of output logits.
        hidden_size, num_hidden_layers, num_attention_heads, intermediate_size,
        dropout_rate, qkv_bias: encoder hyper-parameters passed to each
            ``TransformerEncoderBlock``.
    """

    def __init__(self, image_size=32, patch_size=4, num_channels=3, num_classes=10,
                 hidden_size=48, num_hidden_layers=4, num_attention_heads=4,
                 intermediate_size=192, dropout_rate=0.1, qkv_bias=True):
        super().__init__()
        assert image_size % patch_size == 0, "Image must be divisible by patch size!"
        patches_per_side = image_size // patch_size
        num_patches = patches_per_side * patches_per_side

        # Non-overlapping patch embedding: a conv whose kernel equals its stride.
        self.projection = nn.Conv2d(
            in_channels=num_channels,
            out_channels=hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
        )

        # Learnable [CLS] token plus one positional slot per patch and for CLS.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, hidden_size))
        self.dropout = nn.Dropout(dropout_rate)

        blocks = [
            TransformerEncoderBlock(
                hidden_size=hidden_size,
                num_attention_heads=num_attention_heads,
                intermediate_size=intermediate_size,
                dropout_rate=dropout_rate,
                qkv_bias=qkv_bias,
            )
            for _ in range(num_hidden_layers)
        ]
        self.encoder_blocks = nn.ModuleList(blocks)

        self.norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.classifier = nn.Linear(hidden_size, num_classes)
        self._init_weights()

    def _init_weights(self):
        """Truncated-normal (std=0.02) for Linear weights and the embeddings,
        zero biases, identity LayerNorms.

        The Conv2d patch projection keeps PyTorch's default initialisation.
        Iteration order over ``self.modules()`` is kept because the truncated
        normal draws consume the global RNG.
        """
        for m in self.modules():
            if isinstance(m, nn.LayerNorm):
                nn.init.zeros_(m.bias)
                nn.init.ones_(m.weight)
            elif isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        nn.init.trunc_normal_(self.position_embeddings, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        """Map an image batch to class logits [B, num_classes]."""
        batch_size = x.shape[0]
        # [B, hidden, H/p, W/p] -> [B, num_patches, hidden]
        tokens = self.projection(x).flatten(2).transpose(1, 2)
        cls = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls, tokens), dim=1) + self.position_embeddings
        tokens = self.dropout(tokens)

        for block in self.encoder_blocks:
            tokens = block(tokens)

        # Classify from the normalised [CLS] token (sequence position 0).
        return self.classifier(self.norm(tokens)[:, 0])

In [7]:
class CIFAR100Data:
    """CIFAR-100 train/val datasets with matching preprocessing.

    Args:
        batch_size: mini-batch size for both loaders.
        resize: (H, W) spatial size that both train and val images end up at.
        root: directory where torchvision downloads/caches the dataset.
    """

    def __init__(self, batch_size=64, resize=(32, 32), root="./data"):
        self.batch_size = batch_size
        self.resize = resize
        self.root = root

        # CIFAR-100 per-channel normalization constants
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)

        # Train: resize first so the padded random crop returns exactly
        # `resize`. Cropping the native 32x32 image to a resize[0] square would
        # give a different size from the val pipeline whenever resize != (32, 32).
        self.train_transform = transforms.Compose([
            transforms.Resize(resize),
            transforms.RandomCrop(resize, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        # Val: only resize + normalize (no augmentation).
        self.val_transform = transforms.Compose([
            transforms.Resize(resize),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        self.train = tv.datasets.CIFAR100(
            root=self.root,
            train=True,
            transform=self.train_transform,
            download=True,
        )
        self.val = tv.datasets.CIFAR100(
            root=self.root,
            train=False,
            transform=self.val_transform,
            download=True,
        )
        self.classes = self.train.classes

    def get_dataloader(self, train=True, num_workers=2):
        """Return a DataLoader over the train (shuffled) or val (ordered) split.

        Args:
            train: use the training split and shuffle it if True, else the
                validation split in order.
            num_workers: number of DataLoader worker processes (default 2).
        """
        dataset = self.train if train else self.val
        return DataLoader(
            dataset,
            self.batch_size,
            shuffle=train,
            num_workers=num_workers,
            # Pinned host memory is only useful when batches go to a CUDA device.
            pin_memory=torch.cuda.is_available(),
        )

In [8]:
class FashionMNIST(nn.Module):
    """Fashion-MNIST train/val datasets plus DataLoader helpers.

    Holds no parameters; ``nn.Module`` is only its base class.
    NOTE(review): assigning ``self.train`` below shadows ``nn.Module.train()``
    on instances, so do not call ``.train()`` / ``.eval()`` on this object.

    Args:
        batch_size: mini-batch size for both loaders.
        resize: (H, W) size images are resized to.
        root: directory where torchvision downloads/caches the dataset.
    """

    def __init__(self, batch_size=64, resize=(28, 28), root='./data'):
        super().__init__()
        self.batch_size = batch_size
        self.resize = resize
        self.root = root

        # Training augmentation: horizontal flip + color jitter.
        # NOTE(review): saturation/hue jitter presumably has little effect on
        # these single-channel images — confirm before tuning it.
        color_aug = tv.transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.25)
        train_transform = transforms.Compose([
            transforms.Resize(resize),
            transforms.RandomHorizontalFlip(),
            color_aug,
            transforms.ToTensor(),
        ])

        # No data augmentation for validation.
        val_transform = transforms.Compose([
            transforms.Resize(resize),
            transforms.ToTensor(),
        ])

        # Datasets shipped with torchvision (downloaded on first use).
        self.train = tv.datasets.FashionMNIST(root=self.root, train=True, transform=train_transform, download=True)
        self.val = tv.datasets.FashionMNIST(root=self.root, train=False, transform=val_transform, download=True)

    def text_labels(self, indices):
        """Map integer class indices to their Fashion-MNIST class names."""
        labels = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
        return [labels[i] for i in indices]

    def get_dataloader(self, train=True):
        """Return a DataLoader over the train (shuffled) or val (ordered) split.

        ``train`` defaults to True, the same default CIFAR100Data uses.
        The DataLoader yields mini-batches so the training loop does not have
        to batch samples itself.
        """
        data = self.train if train else self.val
        return torch.utils.data.DataLoader(data, self.batch_size, shuffle=train)

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self.get_dataloader(train=True)

    def val_dataloader(self):
        """Ordered loader over the validation split."""
        return self.get_dataloader(train=False)

In [9]:
def visualize_attention_maps(
    model: VisionTransformer,
    images: torch.Tensor,
    attn_maps: list,
    patch_size: int,
    device: torch.device,
    n_heads_to_show: int = 4,
):
    """Overlay first-layer CLS->patch attention of several heads on one image.

    Args:
        model: the model the maps came from; only switched to eval mode here.
        images: [B, C, H, W] batch; only ``images[0]`` is drawn.
        attn_maps: per-layer attention tensors [B, heads, N, N] with the CLS
            token at index 0 of the token axes (e.g. from get_attention_maps).
        patch_size: patch side length used by the model.
        device: unused; kept for interface compatibility.
        n_heads_to_show: maximum number of heads to plot.
    """
    model.eval()
    _, C, H, W = images.shape
    grid_h, grid_w = H // patch_size, W // patch_size

    # First image as [H, W, C]; imshow rejects [H, W, 1], so draw 1-channel
    # images as 2-D grayscale.
    img = images[0].detach().cpu().permute(1, 2, 0).numpy()
    if C == 1:
        img = img[..., 0]
    span = img.max() - img.min()
    img = (img - img.min()) / span if span > 0 else np.zeros_like(img)
    cmap = 'gray' if C == 1 else None

    first_attn = attn_maps[0][0]  # first layer, first image: [heads, N, N]
    n_show = min(n_heads_to_show, first_attn.shape[0])

    # One row: the input followed by one overlay per head.
    fig, axes = plt.subplots(1, 1 + n_show, figsize=(2 * (1 + n_show), 2.5), squeeze=False)
    axes = axes[0]

    axes[0].imshow(img, cmap=cmap, interpolation='nearest')
    axes[0].set_title('Input')
    axes[0].axis('off')

    for h in range(n_show):
        # Row 0 is the CLS query; columns 1: are its weights on each patch.
        attn_map = first_attn[h, 0, 1:].detach().float().cpu().reshape(1, 1, grid_h, grid_w)
        peak = attn_map.max()
        if peak > 0:
            attn_map = attn_map / peak

        # Upsample the patch grid to image resolution for the overlay.
        attn_map = F.interpolate(attn_map, size=(H, W), mode='bilinear', align_corners=False).squeeze().numpy()

        ax = axes[h + 1]
        ax.imshow(img, cmap=cmap, interpolation='nearest')
        ax.imshow(attn_map, alpha=0.5, interpolation='nearest')
        ax.set_title(f'Head {h}')
        ax.axis('off')

    plt.tight_layout()
    plt.show()


In [10]:
def plot_embedding_filters(model, n_components=28):
    """Plot the top principal components of the patch-embedding filters.

    Each Conv2d filter of ``model.projection`` (weight shape
    [hidden_size, channels, patch_size, patch_size]) is flattened, PCA is fit
    across filters, and each component is drawn as a patch_size x patch_size
    image: RGB for 3-channel inputs, otherwise the channel mean in grayscale.

    Args:
        model: module with a ``projection`` Conv2d (e.g. VisionTransformer).
        n_components: requested number of components; clipped to the maximum
            PCA allows, min(n_filters, channels * patch_size**2).
    """
    w = model.projection.weight.detach().cpu().numpy()
    n_filters, C, p, _ = w.shape
    w_flat = w.reshape(n_filters, -1)  # [hidden, C*p*p]

    # PCA cannot return more components than min(n_samples, n_features).
    n_comp = min(n_components, n_filters, w_flat.shape[1])
    pca = PCA(n_components=n_comp)
    pca.fit(w_flat)
    pcs = pca.components_.reshape(n_comp, C, p, p)

    cols = 7
    rows = int(np.ceil(n_comp / cols))
    # squeeze=False keeps `axes` 2-D even when a single row is needed.
    fig, axes = plt.subplots(rows, cols, figsize=(cols, rows), squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i >= n_comp:
            continue  # leave unused grid cells blank
        img = pcs[i].transpose(1, 2, 0)  # [p, p, C]
        if C == 3:
            cmap = None
        else:
            # imshow accepts [H, W] or [H, W, 3/4]; collapse other channel counts.
            img = img.mean(axis=-1)
            cmap = 'gray'
        span = img.max() - img.min()
        img = (img - img.min()) / span if span > 0 else np.zeros_like(img)
        ax.imshow(img, cmap=cmap, interpolation='nearest')

    kind = 'RGB' if C == 3 else 'Grayscale'
    plt.suptitle(f'{kind} embedding filters (first {n_comp} principal components)', y=1.02)
    plt.tight_layout()
    plt.show()

def plot_position_embedding_similarity(model):
    """Plot cosine similarity between patch position embeddings.

    One small heatmap per patch, arranged in the patch grid: the tile at
    (row, col) shows how similar that patch's position embedding is to every
    other patch's. The CLS slot (index 0) is excluded.

    Raises:
        ValueError: if the number of patch embeddings is zero or not a perfect
            square (a square patch grid is assumed).
    """
    # Learnable position embeddings without the CLS slot: [N, hidden]
    pos = model.position_embeddings.detach().cpu().numpy()[0, 1:, :]
    N = pos.shape[0]
    grid = int(round(np.sqrt(N)))
    if N == 0 or grid * grid != N:
        raise ValueError(f"expected a non-empty square patch grid, got {N} patches")

    # Cosine similarity matrix [N, N]
    norms = np.linalg.norm(pos, axis=1, keepdims=True)
    cos_sim = (pos @ pos.T) / (norms @ norms.T)

    # squeeze=False keeps `axes` 2-D even for a 1x1 grid.
    fig, axes = plt.subplots(grid, grid, figsize=(grid, grid), squeeze=False)
    im = None
    for i in range(N):
        ax = axes[i // grid, i % grid]
        im = ax.imshow(cos_sim[i].reshape(grid, grid), interpolation='nearest', vmin=-1, vmax=1)
        ax.axis('off')

    # Shared colorbar for all tiles (same vmin/vmax everywhere).
    fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6)
    plt.suptitle('Position embedding similarity', y=1.02)
    plt.tight_layout()
    plt.show()

def plot_mean_attention_distance(attn_maps, patch_size, image_height, image_width):
    """Scatter of per-head mean attention distance vs. layer depth.

    Uses the ViT-paper definition: for every query patch, the attention-weighted
    average pixel distance to all key patches, then averaged over query patches
    and over the batch. The CLS token (token index 0) is dropped on both token
    axes and the remaining patch-to-patch weights are renormalised to sum to
    one per query.

    Args:
        attn_maps: list (one entry per layer) of tensors [B, heads, N, N],
            where N = 1 + number of patches and CLS is at index 0.
        patch_size: pixel size of one patch.
        image_height, image_width: input size in pixels; they define the
            (image_height // patch_size) x (image_width // patch_size) patch grid,
            laid out row-major like the patch sequence.

    Raises:
        ValueError: if the attention maps do not match that patch grid.
    """
    _, heads, N, _ = attn_maps[0].shape
    n_patches = N - 1
    grid_h, grid_w = image_height // patch_size, image_width // patch_size
    if grid_h * grid_w != n_patches:
        raise ValueError(
            f"{n_patches} patch tokens do not match a {grid_h}x{grid_w} grid "
            f"for a {image_height}x{image_width} image with patch_size={patch_size}"
        )

    # Patch-centre coordinates (y, x) in pixels, row-major: [P, 2]
    rows, cols = np.meshgrid(np.arange(grid_h), np.arange(grid_w), indexing='ij')
    coords = (np.stack([rows.ravel(), cols.ravel()], axis=1) + 0.5) * patch_size
    # Pairwise patch-centre distances [P, P]; identical for every layer/head.
    dists = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)

    xs, ys, cs = [], [], []
    for layer_idx, layer_attn in enumerate(attn_maps):
        arr = layer_attn.detach().float().cpu().numpy()  # [B, heads, N, N]
        patch_attn = arr[:, :, 1:, 1:]
        row_mass = patch_attn.sum(axis=-1, keepdims=True)
        patch_attn = patch_attn / np.where(row_mass > 0, row_mass, 1.0)
        # [B, heads, P] per-query distance -> mean over batch and queries -> [heads]
        per_head = (patch_attn * dists).sum(axis=-1).mean(axis=(0, 2))
        for h in range(heads):
            xs.append(layer_idx)
            ys.append(float(per_head[h]))
            cs.append(h)

    fig, ax = plt.subplots()
    scatter = ax.scatter(xs, ys, c=cs)
    # num=None: one legend handle per distinct head index, so the labels line up.
    handles, _ = scatter.legend_elements(num=None)
    ax.legend(handles, [f"Head {i}" for i in range(heads)], title="Attention Heads")
    ax.set_xlabel("Network depth (layer)")
    ax.set_ylabel("Mean attention distance (pixels)")
    ax.set_title("ViT mean attention distance")
    plt.tight_layout()
    plt.show()

def get_attention_maps(model: "VisionTransformer", images: torch.Tensor):
    """Run one forward pass and collect each encoder block's attention probs.

    Args:
        model: model exposing ``encoder_blocks``, each with an ``attn``
            submodule whose forward returns ``(out, attn_probs)``.
        images: input batch passed straight to ``model``.

    Returns:
        list with one detached CPU tensor per ``attn`` call, in call order
        (for the ViT in this notebook: [B, heads, N, N] per layer).
    """
    attn_maps = []

    def _record(module, inp, out):
        # out is the (out, attn_probs) tuple returned by the attention module.
        attn_maps.append(out[1].detach().cpu())

    hooks = [block.attn.register_forward_hook(_record) for block in model.encoder_blocks]
    try:
        # Inspection only: no autograd graph is needed.
        with torch.no_grad():
            model(images)
    finally:
        # Always detach the hooks, even if the forward pass raises, so they do
        # not keep firing on later forward calls.
        for handle in hooks:
            handle.remove()
    return attn_maps

In [None]:
def _make_loaders(dataset_name, batch_size, resize):
    """Return (train_loader, val_loader, num_channels, num_classes).

    "FashionMNIST" selects Fashion-MNIST; any other name falls back to CIFAR-100.
    """
    if dataset_name == "FashionMNIST":
        data_module = FashionMNIST(batch_size=batch_size, resize=resize)
        return (data_module.train_dataloader(),
                data_module.get_dataloader(train=False), 1, 10)
    data_module = CIFAR100Data(batch_size=batch_size, resize=resize)
    return (data_module.get_dataloader(train=True),
            data_module.get_dataloader(train=False), 3, 100)


def _train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over `loader`; return (mean batch loss, accuracy)."""
    model.train()
    total_loss, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        # Clip gradients to a global L2 norm of 1.0.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        correct += (logits.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
        total_loss += loss.item()
    return total_loss / len(loader), correct / seen


@torch.no_grad()
def _evaluate(model, loader, criterion, device):
    """Evaluate on `loader`.

    Returns (mean batch loss, accuracy, predictions, labels), where the last two
    are 1-D numpy arrays over the whole split.
    """
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    all_preds, all_labels = [], []
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        total_loss += criterion(logits, labels).item()

        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        seen += labels.size(0)
        all_preds.append(preds.cpu())
        all_labels.append(labels.cpu())
    return (total_loss / len(loader), correct / seen,
            torch.cat(all_preds).numpy(), torch.cat(all_labels).numpy())


def run_experiment(
    dataset_name: str,
    patch_size: int,
    hidden_size: int,
    num_layers: int,
    num_heads: int,
    dropout_rate: float,
    lr: float,
    batch_size: int = 128,
    max_epochs: int = 25,
):
    """Train and evaluate one ViT configuration, logging to Weights & Biases.

    Each epoch logs train/val loss and accuracy plus macro precision/recall/F1
    on the validation split. After training, the state_dict is saved under
    ``weights/`` and uploaded with ``wandb.save``. The function then draws the
    patch-embedding PCA, position-embedding similarity and mean-attention-distance
    plots on the first validation batch. The wandb run is always closed,
    with a non-zero exit code if anything raised.

    Args:
        dataset_name: "FashionMNIST"; any other value selects CIFAR-100.
        patch_size: ViT patch side length. Images are resized to the largest
            multiple of it that is <= 32 px.
        hidden_size: token width (the MLP width is 4x this).
        num_layers: number of encoder blocks.
        num_heads: attention heads per block.
        dropout_rate: dropout used throughout the model.
        lr: AdamW learning rate (cosine-annealed over ``max_epochs``).
        batch_size: mini-batch size.
        max_epochs: number of training epochs.
    """
    # 1) initialize the wandb run; hyper-parameters are recorded in its config
    wandb.init(
        project="ViT-Experiments-Graphs",
        name=f"{dataset_name}_ps{patch_size}_hs{hidden_size}_nl{num_layers}_lr{lr}",
        config={
            "dataset": dataset_name,
            "patch_size": patch_size,
            "hidden_size": hidden_size,
            "num_layers": num_layers,
            "num_heads": num_heads,
            "dropout_rate": dropout_rate,
            "lr": lr,
            "batch_size": batch_size,
            "max_epochs": max_epochs,
        }
    )
    try:
        config = wandb.config

        # Largest multiple of patch_size that fits in 32 px
        # (e.g. 28 for patch 7, 32 for patch 4).
        side = config.patch_size * int(32 / config.patch_size)
        resize = (side, side)
        train_loader, val_loader, num_channels, num_classes = _make_loaders(
            config.dataset, config.batch_size, resize)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model = VisionTransformer(
            image_size=resize[0],
            patch_size=config.patch_size,
            num_channels=num_channels,
            num_classes=num_classes,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_layers,
            num_attention_heads=config.num_heads,
            intermediate_size=config.hidden_size * 4,
            dropout_rate=config.dropout_rate,
            qkv_bias=True
        ).to(device)

        # AdamW optimizer, label-smoothed loss, per-epoch cosine schedule
        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        optimizer = optim.AdamW(model.parameters(), lr=config.lr, weight_decay=0.05)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.max_epochs)

        for epoch in range(1, config.max_epochs + 1):
            train_loss, train_acc = _train_epoch(model, train_loader, criterion, optimizer, device)
            val_loss, val_acc, all_preds, all_labels = _evaluate(model, val_loader, criterion, device)

            # zero_division=0: classes that are never predicted score 0 (the
            # sklearn default value) without emitting a warning every epoch.
            precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
            recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)
            f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)

            wandb.log({
                "epoch": epoch,
                "train_loss": train_loss,
                "train_acc": train_acc,
                "val_loss": val_loss,
                "val_acc": val_acc,
                "precision": precision,
                "recall": recall,
                "f1": f1,
            })
            scheduler.step()

        # save model weights
        os.makedirs('weights', exist_ok=True)
        weight_path = f"weights/{dataset_name}_ps{patch_size}_hs{hidden_size}_nl{num_layers}_lr{lr}.pth"
        torch.save(model.state_dict(), weight_path)
        wandb.save(weight_path)

        # Diagnostics on the first validation batch; the model is still in
        # eval mode from the last _evaluate call.
        images, _ = next(iter(val_loader))
        images = images.to(device)

        # 1, PCA of patch-embedding filters
        plot_embedding_filters(model, n_components=28)

        # 2, position embedding similarity
        plot_position_embedding_similarity(model)

        # 3, mean attention distance per head and layer
        attn_maps = get_attention_maps(model, images)
        plot_mean_attention_distance(
            attn_maps,
            patch_size=config.patch_size,
            image_height=resize[0],
            image_width=resize[1]
        )
    except BaseException:
        # Close the run as failed so the next wandb.init() starts cleanly.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()

# final hyperparameter sweep: a small ViT on FashionMNIST, then a deeper ViT on
# CIFAR-100. Each entry expands into a grid over its list-valued keys.
if __name__ == "__main__":
    sweeps = [
        dict(dataset_name="FashionMNIST", patch_sizes=[7], hidden_sizes=[48],
             learning_rates=[1e-3], num_layers=4, num_heads=4,
             dropout_rate=0.1, max_epochs=40),
        dict(dataset_name="CIFAR100Data", patch_sizes=[4], hidden_sizes=[128],
             learning_rates=[5e-4], num_layers=12, num_heads=8,
             dropout_rate=0.1, max_epochs=200),
    ]

    for sweep in sweeps:
        for ps in sweep["patch_sizes"]:
            for hs in sweep["hidden_sizes"]:
                for lr in sweep["learning_rates"]:
                    run_experiment(
                        dataset_name=sweep["dataset_name"],
                        patch_size=ps,
                        hidden_size=hs,
                        num_layers=sweep["num_layers"],
                        num_heads=sweep["num_heads"],
                        dropout_rate=sweep["dropout_rate"],
                        lr=lr,
                        batch_size=128,
                        max_epochs=sweep["max_epochs"],
                    )