# Replicating the Study on Tiny ImageNet

## 1. Setup and Dependencies

In [None]:
!pip install -q -r requirements.txt

In [None]:
import os
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from skimage.segmentation import slic
from skimage.color import rgb2lab
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import time
import pandas as pd
from tqdm import tqdm

## 2. Data Loading and Preprocessing

### 2.1. Download and Extract Tiny ImageNet

In [None]:
# Fetch and unpack Tiny ImageNet. -nc skips the download when the archive is
# already present; -n keeps unzip from prompting to overwrite existing files
# (an interactive prompt stalls the cell) when the notebook is re-run.
!wget -nc http://cs231n.stanford.edu/tiny-imagenet-200.zip
!unzip -q -n tiny-imagenet-200.zip

### 2.2. Create a Custom Dataset Class

In [None]:
class TinyImageNetDataset(Dataset):
    """Tiny ImageNet images paired with integer class labels.

    Labels are the line indices of each WordNet id in ``wnids.txt``.

    Args:
        root_dir: Either the extracted dataset root (the directory holding
            ``wnids.txt``) or the split directory itself (``<root>/train`` or
            ``<root>/val``). Both conventions are resolved automatically.
        mode: ``'train'`` (one ``<wnid>/images`` folder per class) or ``'val'``
            (a flat ``images`` folder described by ``val_annotations.txt``).
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.

    Raises:
        ValueError: If ``mode`` is neither ``'train'`` nor ``'val'``.
        FileNotFoundError: If ``wnids.txt`` is found neither in ``root_dir``
            nor in its parent directory.
    """

    def __init__(self, root_dir, mode='train', transform=None):
        if mode not in ('train', 'val'):
            raise ValueError(f"mode must be 'train' or 'val', got {mode!r}")

        self.root_dir = root_dir
        self.mode = mode
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # Load class names and their corresponding IDs (blank lines ignored).
        with open(self._find_wnids(root_dir), 'r') as f:
            self.wnids = [line.strip() for line in f if line.strip()]
        self.wnid_to_label = {wnid: i for i, wnid in enumerate(self.wnids)}

        if mode == 'train':
            # Accept both <root> (containing train/) and <root>/train itself.
            nested = os.path.join(root_dir, 'train')
            self._index_train(nested if os.path.isdir(nested) else root_dir)
        else:
            # Accept both <root>/val (annotations right here) and <root>.
            here = os.path.isfile(os.path.join(root_dir, 'val_annotations.txt'))
            self._index_val(root_dir if here else os.path.join(root_dir, 'val'))

    @staticmethod
    def _find_wnids(root_dir):
        """Return the path of ``wnids.txt`` in ``root_dir`` or its parent."""
        parent = os.path.normpath(os.path.join(root_dir, os.pardir))
        for candidate in (root_dir, parent):
            path = os.path.join(candidate, 'wnids.txt')
            if os.path.isfile(path):
                return path
        raise FileNotFoundError(
            f"wnids.txt not found in {root_dir!r} or its parent directory")

    def _index_train(self, train_dir):
        """Collect every image under ``train_dir/<wnid>/images``."""
        # Sorted listings make the sample order independent of the filesystem.
        for wnid in sorted(os.listdir(train_dir)):
            image_dir = os.path.join(train_dir, wnid, 'images')
            if not os.path.isdir(image_dir):
                continue
            label = self.wnid_to_label[wnid]
            for image_name in sorted(os.listdir(image_dir)):
                self.image_paths.append(os.path.join(image_dir, image_name))
                self.labels.append(label)

    def _index_val(self, val_dir):
        """Collect the images listed in ``val_dir/val_annotations.txt``."""
        image_dir = os.path.join(val_dir, 'images')
        with open(os.path.join(val_dir, 'val_annotations.txt'), 'r') as f:
            for line in f:
                if not line.strip():
                    continue
                # Tab-separated; the first two fields are file name and wnid.
                parts = line.strip().split('\t')
                self.image_paths.append(os.path.join(image_dir, parts[0]))
                self.labels.append(self.wnid_to_label[parts[1]])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``."""
        # convert() produces an independent image, so the file handle can be
        # released immediately instead of waiting for garbage collection.
        with Image.open(self.image_paths[idx]) as handle:
            image = handle.convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

### 2.3. Data Augmentation and Transformation

In [None]:
# Both splits end with the same tensor conversion and per-channel
# normalization (these are the commonly used ImageNet mean/std values).
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Training prepends light augmentation: a random 64x64 crop taken after
# 4 pixels of padding, and a random horizontal flip. Validation has none.
data_transforms = dict(
    train=transforms.Compose([
        transforms.RandomCrop(64, padding=4),
        transforms.RandomHorizontalFlip(),
        _to_tensor,
        _normalize,
    ]),
    val=transforms.Compose([
        _to_tensor,
        _normalize,
    ]),
)

### 2.4. Create Datasets and DataLoaders

In [None]:
data_dir = 'tiny-imagenet-200'

# One dataset per split, each pointed at its own sub-directory and paired
# with the matching transform pipeline.
train_dataset = TinyImageNetDataset(
    os.path.join(data_dir, 'train'),
    mode='train',
    transform=data_transforms['train'],
)
val_dataset = TinyImageNetDataset(
    os.path.join(data_dir, 'val'),
    mode='val',
    transform=data_transforms['val'],
)

# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
)

### 2.5. Visualize a Sample

In [None]:
def imshow(inp, title=None):
    """Show a normalized channels-first image tensor with matplotlib.

    The tensor is moved to channels-last layout, the mean/std normalization
    is undone, and values are clipped to [0, 1] before display. ``title`` is
    set on the plot when it is not ``None``.
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])

    # (C, H, W) -> (H, W, C), which is what plt.imshow expects.
    pixels = inp.numpy().transpose((1, 2, 0))
    pixels = np.clip(channel_std * pixels + channel_mean, 0, 1)

    plt.imshow(pixels)
    if title is not None:
        plt.title(title)
    # Brief pause so the figure gets a chance to refresh.
    plt.pause(0.001)

# Pull a single batch from the training loader.
inputs, classes = next(iter(train_loader))

# Tile the whole batch into one image grid.
out = torchvision.utils.make_grid(inputs)

# Title the grid with the WordNet id of every image in the batch.
imshow(out, title=[train_dataset.wnids[label] for label in classes.tolist()])

## 3. SPPP (Super-Pixel based Patch Pooling) Module

In [None]:
class SPPP(torch.nn.Module):
    """Super-pixel based patch pooling.

    Each image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches that are linearly embedded. SLIC superpixels are computed on the
    image, every patch is assigned to the superpixel covering the largest
    share of its pixels (patches whose best share is below ``tau`` are
    dropped), and the embeddings of all patches assigned to one superpixel are
    averaged into a single token. A small MLP maps each token's normalized
    (row, col) patch centroid to a positional encoding.

    Args:
        n_segments: Approximate number of superpixels requested from SLIC.
        compactness: SLIC ``compactness`` argument.
        patch_size: Side length of the square patches, in pixels.
        embed_dim: Width of the token and positional embeddings.
        tau: Minimum fraction of a patch's pixels that its dominant
            superpixel must cover for the patch to be kept.
        mean: Per-channel mean the input was normalized with. Together with
            ``std`` it is used to recover RGB values in [0, 1] before the
            Lab conversion. The defaults match the normalization used by this
            notebook's transforms. Pass ``None`` (for either) to hand the raw
            tensor values to SLIC unchanged.
        std: Per-channel standard deviation the input was normalized with.
    """

    def __init__(self, n_segments=16, compactness=0.1, patch_size=4, embed_dim=768, tau=0.5,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        super().__init__()
        self.n_segments = n_segments
        self.compactness = compactness
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.tau = tau
        # Kept as NumPy arrays because de-normalization happens on the NumPy
        # copy of the image that is handed to scikit-image.
        self.pixel_mean = None if mean is None else np.asarray(mean, dtype=np.float32)
        self.pixel_std = None if std is None else np.asarray(std, dtype=np.float32)
        self.patch_embed = torch.nn.Linear(patch_size * patch_size * 3, embed_dim)
        self.pos_mlp = torch.nn.Sequential(
            torch.nn.Linear(2, embed_dim // 4),
            torch.nn.GELU(),
            torch.nn.Linear(embed_dim // 4, embed_dim)
        )

    def forward(self, x):
        """Pool patch embeddings per superpixel.

        Args:
            x: Image batch of shape ``(batch, 3, height, width)``.

        Returns:
            A pair ``(tokens, pos)`` of tensors, both shaped
            ``(batch, S_max, embed_dim)``, where ``S_max`` is the largest
            number of superpixel tokens produced by any image in the batch.
            Images with fewer tokens are zero-padded at the end in both
            tensors, so a padded position contributes an all-zero token.
        """
        # NOTE: This implementation iterates over the batch, which is a performance bottleneck.
        # A production-grade implementation would require a vectorized or batched SLIC algorithm.
        batch_size, channels, height, width = x.shape
        num_patches_h = height // self.patch_size
        num_patches_w = width // self.patch_size
        num_patches = num_patches_h * num_patches_w

        # Patchify the input: (B, C, H, W) -> (B, N, C * p * p)
        patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        patches = patches.contiguous().view(batch_size, channels, num_patches, self.patch_size, self.patch_size)
        patches = patches.permute(0, 2, 1, 3, 4).reshape(batch_size, num_patches, -1)

        patch_embeddings = self.patch_embed(patches)

        tokens_per_image = []
        pos_per_image = []
        for i in range(batch_size):
            segments = self._segment(x[i])
            assignment = self._assign_patches(segments, num_patches_h, num_patches_w)
            tokens, centroids = self._pool_image(
                patch_embeddings[i], assignment, num_patches_h, num_patches_w)
            tokens_per_image.append(tokens)
            # One batched MLP call per image instead of one call per superpixel.
            pos_per_image.append(self.pos_mlp(centroids))

        # The number of tokens differs between images, so a plain torch.stack
        # would fail; pad to the longest sequence in the batch first.
        return self._pad_and_stack(tokens_per_image), self._pad_and_stack(pos_per_image)

    def _segment(self, image):
        """Return the ``(H, W)`` SLIC label map for one ``(C, H, W)`` tensor."""
        rgb = image.detach().permute(1, 2, 0).cpu().float().numpy()
        if self.pixel_mean is not None and self.pixel_std is not None:
            # rgb2lab expects RGB in [0, 1], not mean/std-normalized values.
            rgb = np.clip(rgb * self.pixel_std + self.pixel_mean, 0.0, 1.0)
        # The image is already Lab here; without convert2lab=False SLIC would
        # treat the 3-channel input as RGB and convert it a second time.
        return slic(rgb2lab(rgb), n_segments=self.n_segments, compactness=self.compactness,
                    start_label=0, convert2lab=False)

    def _assign_patches(self, segments, num_patches_h, num_patches_w):
        """Map every patch to its dominant superpixel label, or -1 if below ``tau``."""
        size = self.patch_size
        assignment = -np.ones(num_patches_h * num_patches_w, dtype=int)
        for patch_idx in range(num_patches_h * num_patches_w):
            top = (patch_idx // num_patches_w) * size
            left = (patch_idx % num_patches_w) * size
            labels, counts = np.unique(segments[top:top + size, left:left + size], return_counts=True)
            best = counts.argmax()
            if counts[best] / (size * size) >= self.tau:
                assignment[patch_idx] = labels[best]
        return assignment

    def _pool_image(self, patch_embeddings, assignment, num_patches_h, num_patches_w):
        """Average one image's ``(N, D)`` patch embeddings per superpixel.

        Returns ``(tokens, centroids)`` with shapes ``(S, D)`` and ``(S, 2)``;
        ``S`` is 0 when no patch was assigned to any superpixel.
        """
        device = patch_embeddings.device
        pooled = []
        centroids = []
        for label in np.unique(assignment[assignment != -1]):
            members = np.flatnonzero(assignment == label)
            index = torch.as_tensor(members, dtype=torch.long, device=device)
            pooled.append(patch_embeddings[index].mean(dim=0))
            # Centroid in patch-grid coordinates, scaled by the grid size.
            centroids.append([
                float((members // num_patches_w).mean() / num_patches_h),
                float((members % num_patches_w).mean() / num_patches_w),
            ])

        if pooled:
            tokens = torch.stack(pooled)
        else:
            tokens = patch_embeddings.new_zeros((0, patch_embeddings.shape[-1]))
        centroid_tensor = torch.tensor(centroids, dtype=torch.float32, device=device).reshape(-1, 2)
        return tokens, centroid_tensor

    @staticmethod
    def _pad_and_stack(sequences):
        """Zero-pad ``(S_i, D)`` tensors to a common length and stack to ``(B, S_max, D)``."""
        longest = max(seq.shape[0] for seq in sequences)
        first = sequences[0]
        stacked = first.new_zeros((len(sequences), longest, first.shape[-1]))
        for row, seq in enumerate(sequences):
            stacked[row, :seq.shape[0]] = seq
        return stacked

def visualize_spp(image, segments, patch_size=4):
    """Visualize the SLIC segmentation and patch merging.

    The left panel draws the superpixel boundaries on ``image``. The right
    panel paints every ``patch_size`` x ``patch_size`` patch with a random
    colour chosen per superpixel, where a patch belongs to the superpixel
    label that occurs most often inside it.

    Args:
        image: ``(H, W, 3)`` array to draw the boundaries on.
        segments: ``(H, W)`` array of non-negative integer superpixel labels.
        patch_size: Side length of the square patches, in pixels.
    """
    from skimage.segmentation import mark_boundaries

    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    ax1.imshow(mark_boundaries(image, segments))
    ax1.set_title('SLIC Segmentation')
    ax1.axis('off')

    num_patches_h = image.shape[0] // patch_size
    num_patches_w = image.shape[1] // patch_size
    num_patches = num_patches_h * num_patches_w

    # Majority vote: each patch takes the most frequent label among its pixels.
    patch_to_superpixel = np.zeros(num_patches, dtype=int)
    for patch_idx in range(num_patches):
        top = (patch_idx // num_patches_w) * patch_size
        left = (patch_idx % num_patches_w) * patch_size
        patch_segment = segments[top:top + patch_size, left:left + patch_size]
        patch_to_superpixel[patch_idx] = np.bincount(patch_segment.ravel()).argmax()

    # Always a float canvas: with zeros_like(image) an integer-typed image
    # would truncate the random colours in [0, 1) to black.
    merged_image = np.zeros(tuple(image.shape[:2]) + (3,), dtype=float)
    for sp_idx in np.unique(patch_to_superpixel):
        color = np.random.rand(3,)
        for patch_idx in np.where(patch_to_superpixel == sp_idx)[0]:
            top = (patch_idx // num_patches_w) * patch_size
            left = (patch_idx % num_patches_w) * patch_size
            merged_image[top:top + patch_size, left:left + patch_size] = color

    ax2.imshow(merged_image)
    ax2.set_title('Merged Patches based on Superpixels')
    ax2.axis('off')
    plt.show()

# Get a sample image and undo the mean/std normalization so that the Lab
# conversion below receives RGB values in [0, 1].
sample_image, _ = train_dataset[0]
sample_image_for_viz = sample_image.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
sample_image_for_viz = std * sample_image_for_viz + mean
sample_image_for_viz = np.clip(sample_image_for_viz, 0, 1)

# The image is converted to Lab explicitly, so SLIC must be told not to
# convert again: by default it treats a 3-channel input as RGB and applies
# its own RGB -> Lab conversion.
segments = slic(rgb2lab(sample_image_for_viz), n_segments=16, compactness=0.1, start_label=0,
                convert2lab=False)
visualize_spp(sample_image_for_viz, segments)

## 4. LLA (Light Latent Attention) Module

In [None]:
class LightLatentAttention(torch.nn.Module):
    """Multi-head cross-attention from learned latent tokens to the input.

    Queries come from ``latent_len`` learned tokens, keys and values from the
    input sequence, so the output always has ``latent_len`` tokens regardless
    of the input length.
    """

    def __init__(self, embed_dim=768, num_heads=12, latent_len=16):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.latent_len = latent_len

        # Query / key / value projections followed by the output projection.
        self.q = torch.nn.Linear(embed_dim, embed_dim)
        self.k = torch.nn.Linear(embed_dim, embed_dim)
        self.v = torch.nn.Linear(embed_dim, embed_dim)
        self.out = torch.nn.Linear(embed_dim, embed_dim)
        # Learned queries, shared by every sample in a batch.
        self.latent_tokens = torch.nn.Parameter(torch.randn(1, latent_len, embed_dim))

    def _split_heads(self, projected, batch_size):
        """Reshape ``(B, T, D)`` into ``(B, heads, T, head_dim)``."""
        per_head = projected.reshape(batch_size, -1, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Attend from the latent tokens to ``x`` of shape ``(B, T, embed_dim)``.

        Returns a tensor of shape ``(B, latent_len, embed_dim)``.
        """
        n = x.shape[0]
        latents = self.latent_tokens.expand(n, -1, -1)

        queries = self._split_heads(self.q(latents), n)
        keys = self._split_heads(self.k(x), n)
        values = self._split_heads(self.v(x), n)

        # Scaled dot-product attention over the input tokens.
        scores = (queries @ keys.transpose(-2, -1)) / (self.head_dim ** 0.5)
        weights = torch.softmax(scores, dim=-1)

        # Merge the heads back into a single embedding dimension.
        mixed = (weights @ values).transpose(1, 2)
        mixed = mixed.reshape(n, self.latent_len, self.embed_dim)
        return self.out(mixed)

## 5. Vision Transformer Models

In [None]:
from torch import nn

class ViTBlock(nn.Module):
    """One pre-norm transformer block operating on ``(batch, tokens, embed_dim)``.

    With ``use_lla=False`` the block is standard self-attention followed by an
    MLP, each wrapped in a residual connection. With ``use_lla=True`` the
    attention is ``LightLatentAttention``; its output has ``latent_len``
    tokens, so no residual is taken around the attention and only the MLP is
    residual.

    Args:
        embed_dim: Token width.
        num_heads: Number of attention heads.
        latent_len: Number of latent tokens (only used when ``use_lla``).
        use_lla: Select latent attention instead of self-attention.
        mlp_ratio: Hidden width of the MLP relative to ``embed_dim``.
        drop_rate: Dropout probability in the MLP and in self-attention.
    """

    def __init__(self, embed_dim=768, num_heads=12, latent_len=16, use_lla=False, mlp_ratio=4.0, drop_rate=0.1):
        super().__init__()
        self.use_lla = use_lla
        self.norm1 = nn.LayerNorm(embed_dim)
        if use_lla:
            self.attn = LightLatentAttention(embed_dim, num_heads, latent_len)
        else:
            # batch_first=True is required: inputs are (batch, tokens, dim).
            # With the default (tokens, batch, dim) layout the attention would
            # mix information across the images of a batch instead of across
            # the tokens of one image.
            self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=drop_rate, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        hidden_features = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_features),
            nn.GELU(),
            nn.Dropout(drop_rate),
            nn.Linear(hidden_features, embed_dim),
            nn.Dropout(drop_rate)
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, tokens, embed_dim)``."""
        if self.use_lla:
            x_latents = self.attn(self.norm1(x))
            return x_latents + self.mlp(self.norm2(x_latents))

        x_norm = self.norm1(x)
        # The attention weights are discarded, so skip computing them.
        attn_output, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)
        x_after_attn = x + attn_output
        return x_after_attn + self.mlp(self.norm2(x_after_attn))

class VisionTransformer(nn.Module):
    """ViT classifier with optional SPPP tokenization and LLA blocks.

    Tokenization is either a strided convolution with a learned positional
    embedding and a class token (``use_sppp=False``) or the ``SPPP`` module
    (``use_sppp=True``), in which case a class token is prepended only when
    ``use_lla`` is also set. The tokens then pass through ``depth`` blocks and
    a final LayerNorm. The classifier head reads the mean over tokens when
    SPPP is used without LLA, and the first token otherwise.
    """

    def __init__(self, img_size=64, patch_size=4, num_classes=200, embed_dim=768, depth=12, num_heads=12,
                 use_sppp=False, use_lla=False, n_segments=16, latent_len=16, drop_rate=0.1):
        super().__init__()
        self.use_sppp = use_sppp
        self.use_lla = use_lla
        self.patch_size = patch_size
        self.embed_dim = embed_dim

        if not use_sppp:
            # Regular grid of patches plus one slot for the class token.
            grid_tokens = (img_size // patch_size) ** 2
            self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
            self.pos_embed = nn.Parameter(torch.zeros(1, grid_tokens + 1, embed_dim))
        else:
            self.sppp = SPPP(n_segments=n_segments, patch_size=patch_size, embed_dim=embed_dim)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        layers = []
        for _ in range(depth):
            layers.append(ViTBlock(embed_dim, num_heads, latent_len, use_lla, drop_rate=drop_rate))
        self.blocks = nn.ModuleList(layers)

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def _tokenize(self, x):
        """Turn an image batch into a token sequence with positions added."""
        batch = x.shape[0]
        if not self.use_sppp:
            grid = self.patch_embed(x).flatten(2).transpose(1, 2)
            with_cls = torch.cat((self.cls_token.expand(batch, -1, -1), grid), dim=1)
            return with_cls + self.pos_embed

        tokens, positions = self.sppp(x)
        tokens = tokens + positions
        if self.use_lla:
            tokens = torch.cat((self.cls_token.expand(batch, -1, -1), tokens), dim=1)
        return tokens

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        tokens = self.pos_drop(self._tokenize(x))

        for block in self.blocks:
            tokens = block(tokens)
        tokens = self.norm(tokens)

        # SPPP without LLA has no class token, so average over all tokens;
        # every other configuration reads the first token.
        mean_pool = self.use_sppp and not self.use_lla
        pooled = tokens.mean(dim=1) if mean_pool else tokens[:, 0]
        return self.head(pooled)

## 6. Training and Evaluation Loop

In [None]:
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast

def train_model(model, criterion, optimizer, train_loader, val_loader, num_epochs=50):
    """Train ``model`` and evaluate it on the validation set after every epoch.

    Runs on ``cuda:0`` when CUDA is available and on the CPU otherwise. Mixed
    precision (autocast plus gradient scaling) is used for the training steps
    on CUDA only.

    Args:
        model: Network to train; it is moved to the selected device.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over the model's parameters.
        train_loader: Loader yielding ``(inputs, labels)`` training batches.
        val_loader: Loader yielding ``(inputs, labels)`` validation batches.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        ``(model, history, training_time, peak_memory)`` where ``history`` maps
        ``'train_loss'``, ``'train_acc'``, ``'val_loss'`` and ``'val_acc'`` to
        per-epoch lists of floats, ``training_time`` is the wall-clock
        duration in minutes (validation included) and ``peak_memory`` is the
        peak allocated CUDA memory in GiB during this call (0 without CUDA).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    on_cuda = device.type == 'cuda'
    model.to(device)
    # Gradient scaling only makes sense with CUDA autocast; disabled it is a
    # transparent pass-through, so the training step below stays the same.
    scaler = GradScaler(enabled=on_cuda)
    if on_cuda:
        # Peak statistics are process-wide; without a reset the value reported
        # for this run would include every model trained earlier.
        torch.cuda.reset_peak_memory_stats(device)

    best_acc = 0.0
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    start_time = time.time()

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        # Training phase
        model.train()
        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in tqdm(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            with autocast(enabled=on_cuda):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            preds = outputs.argmax(dim=1)
            # The loss is a batch mean; weight it by the batch size so the
            # epoch average is exact even when the last batch is smaller.
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc = running_corrects / len(train_loader.dataset)
        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc)
        print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        # Validation phase
        model.eval()
        running_loss = 0.0
        running_corrects = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, labels)

                preds = outputs.argmax(dim=1)
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

        epoch_loss = running_loss / len(val_loader.dataset)
        epoch_acc = running_corrects / len(val_loader.dataset)
        history['val_loss'].append(epoch_loss)
        history['val_acc'].append(epoch_acc)
        print(f'Val Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        best_acc = max(best_acc, epoch_acc)

    end_time = time.time()
    training_time = (end_time - start_time) / 60
    peak_memory = torch.cuda.max_memory_allocated(device) / (1024**3) if on_cuda else 0

    print(f'\nTraining complete in {training_time:.2f} minutes')
    print(f'Best val Acc: {best_acc:.4f}')
    print(f'Peak memory usage: {peak_memory:.2f} GB')

    return model, history, training_time, peak_memory

## 7. Run Experiments and Collect Results

In [None]:
def measure_inference_time(model, val_loader):
    """Return the average forward-pass time per image, in seconds.

    The model is moved to ``cuda:0`` when CUDA is available (CPU otherwise),
    switched to eval mode, and run once over ``val_loader`` without gradients.
    Only the forward pass is timed; host-to-device transfer is excluded.

    Args:
        model: Network to time.
        val_loader: Iterable yielding ``(inputs, labels)`` batches.

    Raises:
        ValueError: If ``val_loader`` yields no images.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    on_cuda = device.type == 'cuda'
    model.to(device)
    model.eval()
    total_time = 0.0
    num_images = 0
    with torch.no_grad():
        for inputs, _ in val_loader:
            inputs = inputs.to(device)
            # CUDA kernels are launched asynchronously: wait for pending work
            # before starting the clock and for the forward pass before
            # stopping it, otherwise only the launch overhead is measured.
            if on_cuda:
                torch.cuda.synchronize(device)
            start_time = time.perf_counter()
            _ = model(inputs)
            if on_cuda:
                torch.cuda.synchronize(device)
            total_time += time.perf_counter() - start_time
            num_images += inputs.size(0)
    if num_images == 0:
        raise ValueError('val_loader yielded no images; cannot compute a per-image time')
    return total_time / num_images

# The four variants under comparison: every on/off combination of the SPPP
# tokenizer and the LLA attention blocks.
model_configs = {
    'Baseline ViT': dict(use_sppp=False, use_lla=False),
    'ViT + SPPP': dict(use_sppp=True, use_lla=False),
    'ViT + LLA': dict(use_sppp=False, use_lla=True),
    'ViT + SPPP + LLA': dict(use_sppp=True, use_lla=True),
}

num_epochs = 50
results = {}

# Train each variant from scratch with identical optimizer settings, then
# time its inference on the validation loader and record the metrics.
for name, config in model_configs.items():
    print(f'\n--- Training {name} ---')

    model = VisionTransformer(num_classes=200, **config)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)

    trained_model, history, training_time, peak_memory = train_model(
        model,
        criterion,
        optimizer,
        train_loader,
        val_loader,
        num_epochs=num_epochs,
    )
    inference_time = measure_inference_time(trained_model, val_loader)

    results[name] = dict(
        history=history,
        training_time=training_time,
        peak_memory=peak_memory,
        inference_time=inference_time,
        best_val_acc=max(history['val_acc']),
    )


## 8. Visualize Results and Conclude

In [None]:
# One comparison figure per validation metric (accuracy first, then loss),
# with one curve per model variant.
for metric_key, legend_suffix, figure_title, y_label in (
    ('val_acc', 'Val Acc', 'Validation Accuracy Comparison', 'Accuracy'),
    ('val_loss', 'Val Loss', 'Validation Loss Comparison', 'Loss'),
):
    plt.figure(figsize=(12, 5))
    for name, result in results.items():
        plt.plot(result['history'][metric_key], label=f'{name} {legend_suffix}')
    plt.title(figure_title)
    plt.xlabel('Epochs')
    plt.ylabel(y_label)
    plt.legend()
    plt.grid(True)
    plt.show()

# Summary table: one row per variant, with every number pre-formatted as text.
summary_data = []
for name, result in results.items():
    row = {
        'Model': name,
        'Top-1 Accuracy': f"{result['best_val_acc']:.4f}",
        'Training Time (min)': f"{result['training_time']:.2f}",
        'Inference Time (s/img)': f"{result['inference_time']:.6f}",
        'Peak Memory (GB)': f"{result['peak_memory']:.2f}",
    }
    summary_data.append(row)

summary_df = pd.DataFrame(summary_data)
print(summary_df)

## 9. Conclusion

This notebook re-implements the core components of the study, namely the SPPP and LLA modules, and compares them against a baseline ViT on the Tiny ImageNet dataset. The summary table above reports top-1 accuracy, training time, per-image inference time, and peak memory usage for each variant. If the replication holds, these results should show the efficiency gains in training time, inference speed, and memory usage reported in the original paper, while maintaining competitive accuracy.