<a href="https://colab.research.google.com/github/davidgonmar/model-compression-exps/blob/main/svd_update_alignment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# by chatgpt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time


# -------------------------------
# Alignment Function
# -------------------------------
def svd_alignment(W, dW, k=5):
    """Return the fraction of ``dW``'s energy in ``W``'s top-``k`` singular subspace.

    ``W`` is decomposed with a reduced SVD. ``dW`` is projected onto the span
    of the first ``k`` left and right singular vectors, and the squared
    Frobenius norm of that projection is divided by the squared Frobenius
    norm of ``dW``.

    Args:
        W: Matrix (expected to be a 2-D tensor) that defines the reference
            singular subspace.
        dW: Tensor with the same shape as ``W``, for example its gradient.
        k: Number of leading singular directions to keep. Values larger than
            the number of available singular vectors simply use all of them.

    Returns:
        The ratio as a Python float. An all-zero ``dW`` yields 0.0 because a
        small epsilon keeps the denominator non-zero.

    Raises:
        ValueError: If ``k`` is negative. A negative ``k`` would otherwise be
            interpreted as a Python slice that drops trailing directions
            instead of selecting the leading ones.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    with torch.no_grad():
        U, _, Vh = torch.linalg.svd(W, full_matrices=False)
        U_k = U[:, :k]
        V_k = Vh[:k, :].T
        # U_k and V_k have orthonormal columns, so the Frobenius norm of the
        # projection U_k @ core @ V_k.T equals the norm of the small core
        # matrix itself; the full-size projection never needs to be built.
        core = U_k.T @ dW @ V_k
        alignment = (core.norm() ** 2) / (dW.norm() ** 2 + 1e-10)
        return alignment.item()


# -------------------------------
# Simple Model
# -------------------------------
class SimpleMLP(nn.Module):
    """Two-layer perceptron: 784 inputs -> 256 hidden units (ReLU) -> 10 outputs."""

    def __init__(self):
        """Create the two linear layers, ``fc1`` followed by ``fc2``."""
        super().__init__()
        self.fc1 = nn.Linear(in_features=28 * 28, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=10)

    def forward(self, x):
        """Flatten ``x`` to rows of 784 values and return the 10 output scores per row."""
        # The row width is read from the first layer, which is built with 28 * 28 inputs.
        flat = x.view(-1, self.fc1.in_features)
        hidden = F.relu(self.fc1(flat))
        scores = self.fc2(hidden)
        return scores


# -------------------------------
# Training Setup
# -------------------------------
def train(model, device, train_loader, optimizer, criterion, epochs=3, k=5):
    """Train ``model`` and print the fc1 gradient alignment for every batch.

    For each batch the loss is back-propagated, then the first parameter
    whose name contains both ``'weight'`` and ``'fc1'`` and that has a
    gradient is passed to ``svd_alignment`` together with that gradient. The
    measurement uses the raw gradient and is taken before ``optimizer.step()``
    is called.

    Args:
        model: Module to optimise; it is switched to training mode.
        device: Device the batches are moved to.
        train_loader: Iterable yielding ``(data, target)`` pairs.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
        criterion: Callable mapping ``(output, target)`` to a loss tensor.
        epochs: Number of passes over ``train_loader``.
        k: Number of singular directions forwarded to ``svd_alignment``.
    """
    model.train()
    for epoch in range(1, epochs + 1):
        for batch_idx, (data, target) in enumerate(train_loader):
            inputs = data.to(device)
            labels = target.to(device)

            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()

            # Measure alignment on the fresh gradient, before the weights move.
            with torch.no_grad():
                tracked = next(
                    (
                        (name, param)
                        for name, param in model.named_parameters()
                        if 'weight' in name and 'fc1' in name and param.grad is not None
                    ),
                    None,
                )
                # Only the first matching parameter is reported.
                if tracked is not None:
                    name, param = tracked
                    align = svd_alignment(param.data, param.grad, k=k)
                    print(f"Epoch {epoch}, Batch {batch_idx}: {name} update alignment (top-{k}) = {align:.4f}")

            optimizer.step()


# -------------------------------
# Main
# -------------------------------
def main():
    """Train a ``SimpleMLP`` on the MNIST training split for one epoch.

    CUDA is used when available, otherwise the CPU. The dataset is downloaded
    into ``./data`` when needed, batches of 64 shuffled samples are fed to
    ``train`` with SGD (learning rate 0.01, momentum 0.9), cross-entropy loss
    and ``k=5``.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda") if use_cuda else torch.device("cpu")

    # Data: MNIST training images converted to tensors.
    to_tensor = transforms.ToTensor()
    mnist_train = datasets.MNIST(
        root='./data',
        train=True,
        download=True,
        transform=to_tensor,
    )
    loader = DataLoader(mnist_train, batch_size=64, shuffle=True)

    # Model, optimizer and loss.
    net = SimpleMLP().to(device)
    sgd = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    loss_fn = nn.CrossEntropyLoss()

    train(
        model=net,
        device=device,
        train_loader=loader,
        optimizer=sgd,
        criterion=loss_fn,
        epochs=1,
        k=5,
    )


# Entry point: run the experiment only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()