Comparing Mobile Net V2 performance with and without memory+ layer

In [1]:
import sys
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torch.profiler import profile, record_function, ProfilerActivity
import timm
import math
import torch.nn.functional as F

In [2]:
class MemoryPlusLayer(nn.Module):

    def __init__(self, d_model, memory_slots, top_k = 32):
        # Define your memory mechanism here
        # Using Berges et al. (2024) "Memory Layers at scale" as a reference for the memory layer design

        super().__init__()

        self.key_dim = d_model // 2
        self.subkey_dim = self.key_dim // 2
        self.value_dim = d_model # <-- NOTE: May experiment with this value, as it may affect performance and memory usage.

        # Total memory_slots = |C| * |C'|. Sub-key matrices have sqrt(memory_slots) rows.
        self.num_subkeys = math.isqrt(memory_slots)
        assert self.num_subkeys ** 2 == memory_slots, f"memory_slots (n = {memory_slots}) must be a perfect square."

        # Query MLP
        self.query = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.SiLU(), # <-- Should ideally match whatever the base models FFN activation function is.
            nn.Linear(d_model * 4, self.subkey_dim)
        )

        # Sub-Key Matrix One and Two
        # NOTE: Don't use nn.linear here, due to sparse key retrieval mechanism in forward pass.
        self.subkey_one = nn.Parameter(torch.empty(self.num_subkeys, self.subkey_dim, dtype=torch.float32))
        self.subkey_two = nn.Parameter(torch.empty(self.num_subkeys, self.subkey_dim, dtype=torch.float32))
        nn.init.uniform_(self.subkey_one, a = -1, b = 1)
        nn.init.uniform_(self.subkey_two, a = -1, b = 1)

        # Value Matrix
        self.values = nn.Parameter(torch.empty(memory_slots, self.value_dim, dtype=torch.float32))
        nn.init.normal_(self.values, std=0.02)  # apparently from lample et al 2019, CAN't FIND ITS REFERENCE

        # Weight Matrix One
        self.W1 = nn.Linear(d_model, self.value_dim, bias=False)

        # Weight Matrix Two
        self.W2 = nn.Linear(self.value_dim, d_model, bias=False)

        # Silu Activation Function
        self.silu = nn.SiLU()

        # QK-Normalisation,
        # NOTE:I think its more a general backbone design choice for memory layer, potentially place this after residual connection as we are using interleaved architecture (at end of this gated memory layer)
        """
        NOTE: This is a technique used to stabilize training and improve convergence in transformer models.
        """
        self.qk_norm = nn.RMSNorm(self.subkey_dim)

        # Top-K Selection
        """
        NOTE: Can experiment with this value, as it may affect performance and memory usage.
        """
        self.top_k = top_k

        # Softmax
        self.softmax = nn.Softmax(dim=-1)


    def lookup_memory(self, query):

        # 1. Apply normalisation for cosine similarity style lookup
        k1 = self.qk_norm(self.subkey_one)
        k2 = self.qk_norm(self.subkey_two)

        # 2. Get similarity subkey scores with query
        sim_scores_1 = query @ k1.T
        sim_scores_2  = query @ k2.T
        all_scores = sim_scores_1.unsqueeze(-1) + sim_scores_2.unsqueeze(-2)

        # 3. Cartesian Product Search:
        all_scores = all_scores.view(*all_scores.shape[:-2], -1)

        # 4. Select the final top-k combinations
        top_k_scores, top_k_indices = torch.topk(all_scores, self.top_k, dim=-1)

        # 5. Retrieve Values and Aggregate
        s = self.softmax(top_k_scores)

        # 6. Gather Values and Aggregate: NOTE: Using EmbeddingBag!
        # TODO: Make CUDA kernel to quicken EmbeddingBag solution
        flat_indices = top_k_indices.view(-1, self.top_k)
        flat_weights = s.view(-1, self.top_k)
        y_flat = F.embedding_bag(flat_indices, self.values, per_sample_weights=flat_weights, mode='sum')

        return y_flat.view(*query.shape[:-1], self.value_dim)

    def forward(self, x):

        q = self.query(x)
        q = self.qk_norm(q)

        y = self.lookup_memory(q)

        m_plus = self.silu(self.W1(x))
        m_plus = y * m_plus
        m_plus = self.W2(m_plus)

        return m_plus

In [3]:
def profile_model_performance(model, device, name="Model"):
    """Profiles a single forward and backward pass to see memory/FLOP tradeoffs."""
    print(f"\n--- Profiling {name} ---")
    model.eval()
    inputs = torch.randn(1, 3, 224, 224).to(device)


     # [ProfilerActivity.CPU, ProfilerActivity.CUDA] if torch.cuda.is_available() else [ProfilerActivity.CPU],
    with profile(
        activities = [ProfilerActivity.CPU],
        record_shapes=True,
        profile_memory=True,
        with_stack=True
    ) as prof:
        with record_function("forward_pass"):
            output = model(inputs)
        with record_function("backward_pass"):
            loss = output.sum()
            loss.backward()

    # Sorted by CUDA time if available, else CPU time
    sort_by = "gpu_time_total" if torch.cuda.is_available() else "cpu_time_total"
    print(prof.key_averages().table(sort_by=sort_by, row_limit=10))

def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        print("hi")
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    return running_loss / len(loader), 100. * correct / total

def validate(model, loader, criterion, device):
    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    return running_loss / len(loader), 100. * correct / total

def run_comparison(dense_model, memory_model, train_loader, test_loader, device, epochs=5):
    # Detect device (Note: MPS for Mac is an option, but profiler support varies)

    # Fixed keys to match your storage logic
    results = {'dense': {'val_loss': [], 'val_acc': []}, 'memory': {'val_loss': [], 'val_acc': []}}
    speed_comp = {dense_model : 0.0, memory_model : 0.0}
    for name, model in [('dense', dense_model), ('memory', memory_model)]:
        print(f"\nStarting training for {name}...")
        optimizer = optim.AdamW(model.parameters(), lr=1e-4)
        criterion = nn.CrossEntropyLoss()
        start_time = time.time()
        for epoch in range(epochs):
            print(f"Epoch {epoch+1}/{epochs}")
            train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
            print(f"Epoch {epoch+1}: train loss: {train_loss}, train acc: {train_acc}")
            val_loss, val_acc = validate(model, test_loader, criterion, device)

            results[name]['val_loss'].append(val_loss)
            results[name]['val_acc'].append(val_acc)
            print(f"Epoch {epoch+1}: Val Acc {val_acc:.2f}%")

        end_time = time.time()
        speed_comp[model] = end_time - start_time



    return results, speed_comp

In [4]:
# Plotting Accuracy
def plot_results(results):
    plt.figure(figsize=(8, 5))
    plt.plot(results['dense']['acc'], label='Dense Baseline (Pre-trained ViT)')
    plt.plot(results['memory']['acc'], label='Memory+ Adapter ViT')
    plt.title('CIFAR-100 Validation Accuracy')
    plt.ylabel('Accuracy (%)')
    plt.xlabel('Epoch')
    plt.legend()
    plt.grid(True)
    plt.show()


In [5]:
# Init Device
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

device

device(type='cuda')

In [6]:
# Init Dataset
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_set = datasets.FashionMNIST(root='./data_dir', train=True, download=True, transform=transform)
test_set = datasets.FashionMNIST(root='./data_dir', train=False, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False)


100%|██████████| 26.4M/26.4M [00:02<00:00, 10.4MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 188kB/s]
100%|██████████| 4.42M/4.42M [00:01<00:00, 3.20MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 23.2MB/s]


In [7]:
# Init Models
dense_model = timm.create_model('mobilenetv2_100', pretrained=False, num_classes=10, cache_dir="./models_dir").to(device)
memory_model = timm.create_model('mobilenetv2_100', pretrained=False, num_classes=10, cache_dir = "./models_dir").to(device)

d_model = dense_model.embed_dim
memory_slots = 256**2
memory_model.blocks[6].mlp = MemoryPlusLayer(d_model=d_model, memory_slots=memory_slots).to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/14.2M [00:00<?, ?B/s]

AttributeError: 'EfficientNet' object has no attribute 'embed_dim'

In [None]:
"""VERY IMPORTANT NOTE:

Base model is 17x faster than Memory+ (1024**2 memory slots) ViT!!!!
NEED CUSTOM KERNEL FOR EMBEDDINGBAG SOLUTION TO SPEED THIS UP,
AS THIS IS THE BOTTLENECK IN THE MEMORY LAYER.

Hoever found memory slot size 256**2 to be near performance of baseline!

"""

# PROFILE MODELS BEFORE TRAINING TO SEE MEMORY/FLOP TRADEOFFS
profile_model_performance(dense_model, device, name="Dense Baseline")
profile_model_performance(memory_model, device, name="Memory+ Adapter")

In [None]:
res, speeds = run_comparison(dense_model, memory_model, train_loader, test_loader, device, epochs=5)


In [None]:
plot_results(res)

In [None]:
for m in speeds.key():
    print(f"{m} Training Time: {speeds[m]:.2f} seconds")