<a href="https://colab.research.google.com/github/evanjiang943/2021-Fall/blob/master/6_s898_Fall_2023_hw2_transfer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Hyperparameter transfer

## Basic imports

In [None]:
import math
import time
import numpy as np

from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms

## Question 1: Spectral norm of Gaussian matrix

In [None]:
d = 100

# Use the GPU when one is present, but fall back to the CPU so this cell still
# runs on machines without CUDA (the computation itself is device-independent).
device = "cuda" if torch.cuda.is_available() else "cpu"

# d x d matrix with i.i.d. standard normal entries.
M = torch.randn(size=(d,d), device=device)
# ord=2 selects the spectral norm, i.e. the largest singular value of M.
spec_norm = torch.linalg.matrix_norm(M, ord=2)

# Measured spectral norm, followed by sqrt(4*d) = 2*sqrt(d) for comparison.
print(spec_norm.item())
print(math.sqrt(4*d))

## Question 2: Spectral norm of orthogonal matrix

In [None]:
from torch.nn.init import orthogonal_

d = 100

# Use the GPU when one is present, but fall back to the CPU so this cell still
# runs on machines without CUDA (the computation itself is device-independent).
device = "cuda" if torch.cuda.is_available() else "cpu"

# The zero matrix is only a placeholder: orthogonal_ overwrites it in place.
M = torch.zeros(size=(d,d), device=device)
orthogonal_(M) # this line resamples M to be a random semi-orthogonal matrix
# ord=2 selects the spectral norm, i.e. the largest singular value of M.
spec_norm = torch.linalg.matrix_norm(M, ord=2)

print(spec_norm.item())

## Question 3: Power iteration

In [None]:
def spectral_norm(A, n_steps=10):
    """Estimate the spectral norm (largest singular value) of a matrix by power iteration.

    Each step normalises ``v`` and then applies ``A^T A`` to it: for a 1-D ``v``,
    ``A @ v`` is the vector ``A v`` and ``(A v) @ A`` is ``A^T (A v)``. The norm of
    the result approaches the largest eigenvalue of ``A^T A``, whose square root is
    the largest singular value of ``A``.

    Args:
        A: 2-D floating-point tensor of shape ``(m, n)``.
        n_steps: number of power-iteration steps; more steps give a tighter
            estimate. With ``n_steps=0`` no iteration is performed and the
            returned value is not a meaningful estimate.

    Returns:
        A 0-dim tensor on ``A``'s device and in ``A``'s dtype holding the estimate.
        An all-zero ``A`` yields ``0`` rather than NaN.
    """
    # Create the start vector in A's dtype as well as on A's device; a default
    # float32 vector would make the matmul fail for float64/float16 inputs.
    v = torch.randn(A.shape[1], device=A.device, dtype=A.dtype)
    # Lower bound for the normalising constant: if v collapses to exactly zero
    # (e.g. A is the zero matrix) this turns 0/0 = NaN into 0/tiny = 0. Clamping
    # on-device avoids a host-device synchronisation that an `if` would need.
    tiny = torch.finfo(v.dtype).tiny
    for _ in range(n_steps):
        # Out-of-place division so no tensor is mutated behind the caller's back.
        v = v / v.norm().clamp_min(tiny)
        v = A @ v @ A
    return v.norm().sqrt()

d = 2000

# Use the GPU when one is present, but fall back to the CPU so this cell still
# runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
M = torch.randn(size=(d,d), device=device)

def _wait_for_device():
    """Block until all queued GPU work has finished (no-op on the CPU)."""
    if device == "cuda":
        torch.cuda.synchronize()

# CUDA kernels are launched asynchronously, so the clock must only be read once
# the device is idle; otherwise the interval can miss the actual computation.
# The first wait also keeps the generation of M out of the first measurement.
_wait_for_device()
t0 = time.perf_counter()
spec_norm = spectral_norm(M)
_wait_for_device()
print(time.perf_counter()-t0, spec_norm.item())

# The .item() call above already synchronised, so the device is idle here.
t0 = time.perf_counter()
spec_norm = torch.linalg.matrix_norm(M, ord=2)
_wait_for_device()
print(time.perf_counter()-t0, spec_norm.item())

## Question 4: Learning rate transfer across width and depth
You only need to modify the two lines of code marked TODO.

In [None]:
batch_size = 128

# Three-component mean/std handed to Normalize below
# (presumably CIFAR-10 per-channel statistics -- TODO confirm the source).
mean = (0.4914, 0.4822, 0.4465)
std = (0.2023, 0.1994, 0.2010)

# Convert each image to a tensor, then standardise it with the values above.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Both splits live under ./data and are downloaded on first use.
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Only the training loader shuffles; both use pinned host memory.
train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
)
test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
)

## Define the MLP architecture

In [None]:
class MLP(nn.Module):
    """Bias-free multilayer perceptron with scaled-ReLU activations.

    The network consists of ``depth`` linear layers in total: one input layer,
    ``depth - 2`` hidden ``width x width`` layers and one output layer. Every
    layer except the last is followed by ``relu(x) * sqrt(2)``.

    Args:
        depth: total number of linear layers; must be at least 2.
        width: number of hidden units in every hidden layer.
        in_features: size of the flattened input (default 3072, e.g. 3*32*32).
        out_features: number of outputs (default 10).

    Raises:
        ValueError: if ``depth`` is smaller than 2, which previously built a
            two-layer network silently.
    """

    def __init__(self, depth, width, in_features=3072, out_features=10):
        super().__init__()

        if depth < 2:
            raise ValueError(f"depth must be at least 2 (input and output layer), got {depth}")

        # Attribute names and creation order are unchanged, so parameter names
        # ("initial.weight", "layers.<i>.weight", "final.weight") stay the same.
        self.initial = nn.Linear(in_features, width, bias=False)
        self.layers = nn.ModuleList([nn.Linear(width, width, bias=False) for _ in range(depth-2)])
        self.final = nn.Linear(width, out_features, bias=False)

    def nonlinearity(self, x):
        """Apply ReLU scaled by sqrt(2).

        Defined as a method rather than a lambda attribute so that the module
        can be pickled. NOTE(review): the sqrt(2) gain presumably compensates
        for ReLU zeroing half of its inputs -- confirm.
        """
        return F.relu(x) * math.sqrt(2)

    def forward(self, x):
        """Flatten each sample of ``x`` and map it through the network."""
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(x.shape[0], -1)

        x = self.initial(x)
        x = self.nonlinearity(x)

        for layer in self.layers:
            x = layer(x)
            x = self.nonlinearity(x)

        return self.final(x)

## Define the train and test loop

In [None]:
def loop(net, train, eta):
    """Run one pass over the training or the test set and return the accuracy.

    Args:
        net: the network to evaluate (and, if ``train`` is true, to update).
        train: if true, iterate over ``train_loader`` and take one gradient step
            per batch; otherwise iterate over ``test_loader`` without updates.
        eta: step size multiplying each update; unused when ``train`` is false.

    Returns:
        The fraction of correctly classified samples over the whole pass, as a
        ``np.float64``. Every sample carries equal weight, so a smaller final
        batch no longer skews the result. An empty loader yields NaN.
    """
    dataloader  = train_loader   if train else test_loader
    description = "Training... " if train else "Testing... "

    # Follow the network's device instead of assuming CUDA.
    device = next(net.parameters()).device

    n_correct = 0
    n_seen = 0

    # Gradients are only needed for training; disabling them during evaluation
    # avoids building autograd graphs that are never used.
    with torch.set_grad_enabled(train):
        for data, target in tqdm(dataloader, desc=description):
            data, target = data.to(device), target.to(device)
            output = net(data)

            loss = output.logsumexp(dim=1).mean() - output[range(target.shape[0]),target].mean() # cross-entropy loss
            n_correct += (output.max(dim=1)[1] == target).sum().item() # correctly classified samples in this batch
            n_seen += target.shape[0]

            if train:
                loss.backward()

                # Number of parameter tensors; not used yet, kept for the TODO below.
                depth = sum(1 for p in net.parameters())
                for p in net.parameters():
                    # `update` aliases p.grad, so the scaling below happens in place.
                    update = p.grad
                    update *= 1 # TODO modify this line of code
                    p.data -= eta * update
                net.zero_grad()

    # np.float64 keeps the return type of the former np.mean(...) call and gives
    # NaN (with a warning) for an empty loader, as before.
    return np.float64(n_correct) / n_seen

## Train networks at different widths and depths

In [None]:
# (width, depth) configurations; all of them are trained with the same eta.
width_depth_pairs = [(100,3), (2000,5), (4000,7)]

for width, depth in width_depth_pairs:
    print(f"Training at {width=}, {depth=}")

    net = MLP(depth=depth, width=width)
    net = net.cuda()

    print("\nNetwork tensor shapes are:\n")
    for name, p in net.named_parameters():
        print(p.shape, '\t', name)
        # Replace the default initialisation with a random (semi-)orthogonal matrix.
        orthogonal_(p)
        p.data *= 1 # TODO modify this line of code

    # Three training passes; only the accuracy of the last one is reported.
    for epoch in range(3):
        train_acc = loop(net, train=True, eta=0.5)
    # eta plays no role during evaluation.
    test_acc = loop(net, train=False, eta=None)

    print(f"\nWe achieved train acc={train_acc:.3f} and test acc={test_acc:.3f}\n")
    print("===================================================================\n")