In [3]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from introduction_code import * 
import matplotlib.pyplot as plt
import random

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cpu')

In [4]:
import lightning as L
# Global seed for this notebook. Per the Lightning docs, seed_everything seeds
# Python `random`, NumPy and torch; workers=True also derives per-worker seeds
# for DataLoader worker processes. Kept as the last expression so the cell
# still displays the seed that was set.
SEED = 628
L.seed_everything(SEED, workers=True)

Seed set to 628


628

In [6]:
def labels_to_one_hot(labels, num_classes=10):
    """Encode integer class labels as a float one-hot matrix.

    Args:
        labels: integer (int64) tensor holding one class index per row; it is
            viewed as a column of indices, one per output row.
        num_classes: width of the one-hot rows (default 10).

    Returns:
        A ``(labels.size(0), num_classes)`` tensor on ``labels.device`` with a
        single 1 in each row at the label's column and 0 elsewhere.
    """
    rows = labels.size(0)
    column_index = labels.view(-1, 1)
    encoded = torch.zeros(rows, num_classes, device=labels.device)
    # scatter_ writes in place and returns the same tensor.
    return encoded.scatter_(1, column_index, 1)
# Only ToTensor is applied: no normalisation or augmentation.
transform = transforms.Compose([transforms.ToTensor()])
# Download (if missing) and load both MNIST splits under ./data, train split first.
train_dataset, test_dataset = [
    torchvision.datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
]

class OneHotMNISTDataset(torch.utils.data.Dataset):
    """Wrap an ``(image, label)`` dataset so labels are returned one-hot encoded.

    Args:
        mnist_dataset: indexable dataset supporting ``len()`` whose items are
            ``(image, label)`` pairs with an integer class label (for example
            a torchvision MNIST split).
        num_classes: length of each one-hot label vector. Defaults to 10,
            the number of MNIST digit classes, so existing callers are unchanged.
    """

    def __init__(self, mnist_dataset, num_classes=10):
        self.mnist_dataset = mnist_dataset
        self.num_classes = num_classes

    def __len__(self):
        return len(self.mnist_dataset)

    def __getitem__(self, idx):
        """Return ``(image, one_hot)``.

        The image is passed through untouched. ``one_hot`` is a float tensor
        of length ``num_classes`` with a 1 at the label's index.
        """
        image, label = self.mnist_dataset[idx]
        label_one_hot = torch.zeros(self.num_classes)
        label_one_hot[label] = 1
        return image, label_one_hot

# Mini-batch size, also reused by the loaders built in the next cell.
batch_size = 64
# One-hot targets let the training loss compare network outputs with label vectors directly.
train_dataset_onehot = OneHotMNISTDataset(train_dataset)
train_loader = DataLoader(
    train_dataset_onehot,
    batch_size=batch_size,
    shuffle=True,  # new order every epoch, drawn from torch's global RNG
)
# Show the loader together with its number of batches per epoch.
(train_loader, len(train_loader))

(<torch.utils.data.dataloader.DataLoader at 0x262a07db4d0>, 938)

In [None]:
# Limit the size of gradient_loader to prevent memory problems: only a random
# subset of the one-hot training set is handed to the kernel (NTK) fit.
subset_size = 1000
subset_indices = torch.randperm(len(train_dataset_onehot))[:subset_size]
# Give Subset plain Python ints so the wrapped dataset never receives
# 0-dim tensors as indices.
small_train_dataset = Subset(train_dataset_onehot, subset_indices.tolist())
gradient_loader = DataLoader(small_train_dataset, batch_size=batch_size, shuffle=False)
# The test set keeps its integer class labels (no one-hot wrapper); accuracy is
# computed by comparing argmax predictions against them.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Model time: SingleLayerMLP (from introduction_code) sized for flattened
# 28x28 MNIST images and 10 output classes.
input_size = 28*28
output_size = 10
hidden_size = 128
num_layers = 1

model = SingleLayerMLP(input_size, output_size, hidden_size, num_layers)
model.to(device)

# Optimizer and loss function
learning_rate = 0.02
momentum = 0.9
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
# Mean squared error averaged over every element (reduction='mean'); equal to
# ((x - y)**2).mean() for same-shaped inputs, but a named module instead of a lambda.
criterion = torch.nn.MSELoss()

In [8]:
# Stage 1: train the MLP with SGD on the one-hot MSE objective.
num_epochs = 1
print("Training the model")
for epoch in range(num_epochs):
    model.train()
    for images, labels_one_hot in train_loader:
        # Flatten each image into an input_size-length feature vector.
        images = images.to(device).view(-1, input_size)
        labels_one_hot = labels_one_hot.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels_one_hot)
        loss.backward()
        optimizer.step()

# Stage 2: use GaussianFit to compute the NTK approximation from the small
# gradient_loader subset.
# NOTE(review): noise_var=0.0 presumably means no diagonal regularisation of
# the kernel system — confirm in GaussianFit.
kernel_model = GaussianFit(model, device, noise_var=0.0)
print("Fitting Gaussian model (computing NTK)")
kernel_model.fit(gradient_loader, optimizer, MSELoss_batch)

# Stage 3: classification accuracy of the kernel predictor on the MNIST test set.
kernel_model.eval()
correct = 0
total = 0
# NOTE(review): deliberately not wrapped in torch.no_grad(); an NTK predictor
# presumably needs autograd through `model` for the test inputs — confirm in
# GaussianFit before disabling gradients here.
for images, labels in test_loader:
    images = images.to(device).view(-1, input_size)
    labels = labels.to(device)

    # Get preds from the kernel model: index of the largest output per row.
    # argmax avoids the legacy `.data` attribute and does not rebind `_`,
    # which IPython reserves for the last cell output.
    outputs = kernel_model(images)
    predicted = outputs.argmax(dim=1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

assert total > 0, "test_loader produced no samples; accuracy is undefined"
accuracy = correct / total * 100
print(f'NTK acc on MNIST: {accuracy:.2f}%')


Training the model
Fitting Gaussian model (computing NTK)
NTK acc on MNIST: 79.94%
