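We run the MNIST example from Liu et al., adapted for output scaling: the network output is multiplied by a factor $\alpha$. For each $\alpha$ we compare the empirical neural tangent kernels (NTKs) of saved training checkpoints via the kernel distance.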

In [39]:
from collections import defaultdict
from itertools import islice
import random
import time
import os
from pathlib import Path
import math

import numpy as np
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torchvision

import matplotlib.pyplot as plt
import seaborn as sns

In [40]:
# String options accepted by the configuration cell, mapped to the torch
# classes they select. Values are classes (not instances); callers instantiate.
optimizer_dict = dict(
    AdamW=torch.optim.AdamW,
    Adam=torch.optim.Adam,
    SGD=torch.optim.SGD,
)

# Activation-name -> nn.Module class; called with no arguments per layer.
activation_dict = dict(
    ReLU=nn.ReLU,
    Tanh=nn.Tanh,
    Sigmoid=nn.Sigmoid,
    GELU=nn.GELU,
)

# Loss-name -> loss class; compute_loss instantiates it with reduction='sum'.
loss_function_dict = dict(
    MSE=nn.MSELoss,
    CrossEntropy=nn.CrossEntropyLoss,
)

In [41]:
# Experiment configuration.
# NOTE(review): optimizer, lr, weight_decay, optimization_steps, log_freq and
# dtype are not read by any cell in this notebook; presumably they record the
# training run that produced the checkpoints — TODO confirm.
train_points = 960
optimization_steps = 100001
batch_size = 16
loss_function = 'MSE'   # 'MSE' or 'CrossEntropy'
optimizer = 'AdamW'     # 'AdamW' or 'Adam' or 'SGD'
lr = 1e-3
initialization_scale = 8.0
download_directory = "../data"
weight_decay = 0
depth = 3              # the number of nn.Linear modules in the model
width = 200
activation = 'ReLU'     # 'ReLU' or 'Tanh' or 'Sigmoid' or 'GELU'

log_freq = math.ceil(optimization_steps / 150)

# Seed every RNG the notebook draws from (test-subset selection, fresh model
# initialization) so Restart & Run All reproduces the same results.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dtype = torch.float64

In [42]:
# Resolve the configured activation name to its nn.Module class. Validate first
# so an unsupported name fails with a readable message, not a bare KeyError.
assert activation in activation_dict, f"Unsupported activation function: {activation}"
activation_fn = activation_dict[activation]

def create_mlp(depth, width, activation, alpha=1.0):
    """Build an MLP with ``depth`` Linear layers whose output is scaled by ``alpha``.

    Inputs are flattened (``nn.Flatten``) and must give 28 * 28 = 784 features
    per example. Architecture:
    ``Flatten -> Linear(784, width) -> act -> [Linear(width, width) -> act] * (depth - 2)
    -> Linear(width, 10)``. With ``depth == 1`` it is just ``Flatten -> Linear(784, 10)``.

    Layers are appended in the order Flatten, Linear, act, Linear, ..., so the
    wrapped module's state_dict keys are ``mlp.1.weight``, ``mlp.1.bias``,
    ``mlp.3.weight``, ...

    Args:
        depth: Number of ``nn.Linear`` layers; must be >= 1.
        width: Hidden-layer width.
        activation: Zero-argument callable returning an activation module
            (e.g. ``nn.ReLU``). One fresh instance is created per hidden layer;
            the output layer has no activation.
        alpha: Factor multiplying the network output.

    Returns:
        An ``OutputScaledMLP`` wrapping an ``nn.Sequential``, moved to the
        module-level ``device``.

    Raises:
        ValueError: if ``depth < 1``.
    """
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")

    in_sizes = [28 * 28] + [width] * (depth - 1)
    out_sizes = [width] * (depth - 1) + [10]
    layers = [nn.Flatten()]
    for index, (n_in, n_out) in enumerate(zip(in_sizes, out_sizes)):
        layers.append(nn.Linear(n_in, n_out))
        if index < depth - 1:  # no activation after the 10-way output layer
            layers.append(activation())

    class OutputScaledMLP(nn.Module):
        """Wraps ``mlp`` and multiplies its output by the constant ``alpha``."""

        def __init__(self, mlp, alpha):
            super().__init__()
            self.mlp = mlp
            self.alpha = alpha  # plain attribute: not a parameter, not in state_dict

        def forward(self, x):
            return self.alpha * self.mlp(x)

    return OutputScaledMLP(nn.Sequential(*layers), alpha).to(device)

In [43]:
def cycle(iterable):
    """Yield the items of ``iterable`` repeatedly, re-iterating it on every pass.

    Unlike ``itertools.cycle``, items are not cached: each pass iterates
    ``iterable`` afresh. An iterable whose contents change between passes
    (e.g. a DataLoader with ``shuffle=True``) is therefore re-read rather than
    replayed.

    The generator stops as soon as a full pass yields nothing, instead of
    spinning forever without producing a value. That happens for an empty
    iterable, or for a one-shot iterator once it is exhausted.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            return

def compute_accuracy(network, dataset, device, N=2000, batch_size=50):
    """Return the fraction of examples for which argmax(network(x)) equals the label.

    Evaluation runs under ``torch.no_grad()`` on a shuffled DataLoader. Only
    ``min(len(dataset), N) // batch_size`` full batches are consumed, so the
    result is an estimate from a random subset. The batch size is capped at
    that subset size.

    Args:
        network: callable mapping a batch of inputs to per-class scores
            (class axis is dim 1).
        dataset: dataset yielding ``(x, label)`` pairs.
        device: where inputs and labels are moved before comparison.
        N: maximum number of examples to evaluate.
        batch_size: evaluation batch size.

    Returns:
        Accuracy as a Python float.
    """
    with torch.no_grad():
        n_eval = min(len(dataset), N)
        bs = min(batch_size, n_eval)
        loader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=True)
        n_correct, n_seen = 0, 0
        for inputs, targets in islice(loader, n_eval // bs):
            hits = network(inputs.to(device)).argmax(dim=1) == targets.to(device)
            n_correct += hits.sum()
            n_seen += inputs.size(0)
        return (n_correct / n_seen).item()

def compute_loss(network, dataset, loss_function, device, N=2000, batch_size=50):
    """Return the mean per-example loss of ``network`` on a random subset of ``dataset``.

    The loss class is looked up in ``loss_function_dict`` and built with
    ``reduction='sum'``. Per-batch sums are accumulated and divided by the
    number of examples seen.
    - 'CrossEntropy': labels are used as class indices.
    - 'MSE': labels are turned into one-hot rows of a 10x10 identity matrix,
      so the network is expected to output 10 values.

    Only ``min(len(dataset), N) // batch_size`` full batches of a shuffled
    DataLoader are evaluated, under ``torch.no_grad()``.

    Returns:
        Mean loss as a Python float.
    """
    with torch.no_grad():
        n_eval = min(len(dataset), N)
        bs = min(batch_size, n_eval)
        loader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=True)
        criterion = loss_function_dict[loss_function](reduction='sum')
        eye = torch.eye(10, 10).to(device)
        # How to turn integer labels into the target each loss expects; any
        # other loss name contributes nothing to the sum.
        to_target = {
            'CrossEntropy': lambda t: t.to(device),
            'MSE': lambda t: eye[t],
        }.get(loss_function)

        loss_sum, seen = 0, 0
        for inputs, targets in islice(loader, n_eval // bs):
            preds = network(inputs.to(device))
            if to_target is not None:
                loss_sum += criterion(preds, to_target(targets)).item()
            seen += len(targets)
        return loss_sum / seen

In [44]:
# Load MNIST from `download_directory`. download=False: the files must already
# be present there.
train = torchvision.datasets.MNIST(root=download_directory, train=True,
    transform=torchvision.transforms.ToTensor(), download=False)
test = torchvision.datasets.MNIST(root=download_directory, train=False,
    transform=torchvision.transforms.ToTensor(), download=False)

# Training split: the first `train_points` examples, iterated in a fixed order.
# NOTE(review): presumably the same subset the checkpoints were trained on — confirm.
train = torch.utils.data.Subset(train, range(train_points))
train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=False)

# Test split: a random subset of `test_points` examples. It is drawn from a
# generator seeded with `seed`, so the subset is reproducible and does not
# depend on how much global RNG state earlier cells consumed. shuffle=False
# keeps the sample order identical for every model, which the kernel
# distances require.
test_points = 960
subset_indices = np.random.default_rng(seed).choice(len(test), test_points, replace=False)
test_subset = torch.utils.data.Subset(test, subset_indices)
test_loader = torch.utils.data.DataLoader(test_subset, batch_size=batch_size, shuffle=False)

assert activation in activation_dict, f"Unsupported activation function: {activation}"
activation_fn = activation_dict[activation]

In [45]:
def compute_ntk_at_epoch(model, dataloader):
    """Average of the per-batch empirical NTKs of ``model`` over ``dataloader``.

    For each batch, every sample contributes one Jacobian row. The row is the
    gradient, w.r.t. all model parameters (flattened and concatenated), of the
    *sum* of that sample's outputs (``grad_outputs`` is all ones). The batch
    kernel is ``J @ J.T`` of shape (batch, batch). The batch kernels are summed
    and divided by the number of batches.

    NOTE(review): entry (i, j) mixes different sample pairs from different
    batches. Use ``compute_ntk_full_at_epoch`` for the kernel over the whole
    dataset.

    Side effect: sets ``requires_grad = True`` on every parameter of ``model``.

    Args:
        model: torch module. Inputs are moved to the device of its parameters.
        dataloader: iterable of ``(inputs, labels)`` batches. Every batch must
            have the same size.

    Returns:
        The averaged kernel as a NumPy array of shape (batch, batch).

    Raises:
        ValueError: if ``dataloader`` yields no batches, or batch sizes differ.
    """
    for param in model.parameters():
        param.requires_grad = True
    params = list(model.parameters())
    model_device = params[0].device

    ntk_sum = None
    num_batches = 0

    for inputs, _ in dataloader:
        inputs = inputs.to(model_device)

        model.zero_grad()
        outputs = model(inputs)
        rows = []
        for output in outputs:
            # The same forward graph is differentiated once per sample, so it
            # must be retained. create_graph is unnecessary because these
            # gradients are never differentiated again, and it would only
            # build extra graph memory.
            gradients = torch.autograd.grad(
                output, params, grad_outputs=torch.ones_like(output), retain_graph=True
            )
            rows.append(torch.cat([g.reshape(-1) for g in gradients]))
        jacobian = torch.stack(rows)

        ntk_batch = (jacobian @ jacobian.t()).detach().cpu()
        if ntk_sum is None:
            ntk_sum = ntk_batch
        elif ntk_sum.shape != ntk_batch.shape:
            raise ValueError(
                f"batch kernels have different shapes {tuple(ntk_sum.shape)} and "
                f"{tuple(ntk_batch.shape)}; all batches must have the same size"
            )
        else:
            ntk_sum += ntk_batch
        num_batches += 1

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")

    return (ntk_sum / num_batches).numpy()

def kernel_distance(Kt1, Kt2):
    """
    Kernel distance between two NTK matrices, one minus their Frobenius cosine similarity:

        d(K1, K2) = 1 - <K1, K2>_F / (||K1||_F * ||K2||_F)

    The result is 0 when the kernels are equal up to a positive scale factor
    and 1 when they are Frobenius-orthogonal. It is unchanged by rescaling
    either argument. Zero-norm inputs are not guarded against.

    Args:
        Kt1: NTK matrix at time t1
        Kt2: NTK matrix at time t2, same shape as Kt1

    Returns:
        The kernel distance as a NumPy scalar.
    """
    def _frobenius_norm(K):
        return np.sqrt(np.sum(K**2))

    alignment = np.sum(Kt1 * Kt2) / (_frobenius_norm(Kt1) * _frobenius_norm(Kt2))
    return 1 - alignment


def compute_ntk_full_at_epoch(model, dataloader):
    '''
    Compute the empirical NTK of ``model`` over every sample yielded by ``dataloader``.

    Each sample contributes one Jacobian row. The row is the gradient, w.r.t.
    all model parameters (flattened and concatenated), of the *sum* of that
    sample's outputs (``grad_outputs`` is all ones). With J the stacked rows of
    all samples, the kernel is ``J @ J.T``, shape (num_samples, num_samples).
    Rows and columns follow the dataloader's iteration order.

    J (num_samples x num_parameters) is held in memory until the kernel is formed.

    Side effect: sets ``requires_grad = True`` on every parameter of ``model``.

    Args:
        model: torch module. Inputs are moved to the device of its parameters.
        dataloader: iterable of ``(inputs, labels)`` batches.

    Returns:
        ``(ntk, labels)`` as NumPy arrays. ``labels`` are concatenated in the
        same order as the kernel's rows.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    '''
    for param in model.parameters():
        param.requires_grad = True
    params = list(model.parameters())
    model_device = params[0].device

    all_jacobians = []
    all_labels = []

    for inputs, labels in dataloader:
        inputs = inputs.to(model_device)

        model.zero_grad()
        outputs = model(inputs)
        rows = []
        for output in outputs:
            # retain_graph: the batch's forward graph is reused for every
            # sample. create_graph is not needed (the Jacobian is never
            # differentiated) and would keep a graph alive for every stored row.
            gradients = torch.autograd.grad(
                output, params, grad_outputs=torch.ones_like(output), retain_graph=True
            )
            rows.append(torch.cat([g.reshape(-1) for g in gradients]))

        all_jacobians.append(torch.stack(rows))
        all_labels.append(labels)

    if not all_jacobians:
        raise ValueError("dataloader yielded no batches")

    full_jacobian = torch.cat(all_jacobians, dim=0)
    full_labels = torch.cat(all_labels, dim=0)

    ntk = torch.mm(full_jacobian, full_jacobian.t())

    return ntk.detach().cpu().numpy(), full_labels.cpu().numpy()


def compute_pairwise_distances(models, dataloader, ntk_fn=None, distance_fn=None):
    """Symmetric matrix of kernel distances between the NTKs of ``models``.

    Each model's kernel is computed exactly once, not once per pair. Entry
    (i, j) is ``distance_fn(K_i, K_j)``. Only the upper triangle (including
    the diagonal) is evaluated and mirrored.

    The kernels are only comparable if ``dataloader`` yields the same samples
    in the same order on every pass (e.g. ``shuffle=False``).

    Args:
        models: sequence of models.
        dataloader: data passed to ``ntk_fn`` for every model.
        ntk_fn: ``(model, dataloader) -> (kernel, labels)``; defaults to
            ``compute_ntk_full_at_epoch``.
        distance_fn: ``(K1, K2) -> float``; defaults to ``kernel_distance``.

    Returns:
        ``np.ndarray`` of shape (len(models), len(models)).
    """
    if ntk_fn is None:
        ntk_fn = compute_ntk_full_at_epoch
    if distance_fn is None:
        distance_fn = kernel_distance

    ntks = [ntk_fn(model, dataloader)[0] for model in models]
    n = len(ntks)
    distance_matrix = np.zeros((n, n))

    for i in range(n):
        for j in range(i, n):
            distance = distance_fn(ntks[i], ntks[j])
            distance_matrix[i, j] = distance
            distance_matrix[j, i] = distance

    return distance_matrix

def plot_heatmap(distance_matrix, param_list, save_path):
    """Draw ``distance_matrix`` as a seaborn heatmap and save it to ``save_path``.

    ``param_list`` is interpolated into the title. Both axes are labelled
    'Model Checkpoint' and ticks are matrix indices. The figure is closed after
    saving, so nothing is left open in the session.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(distance_matrix, annot=False, cmap='viridis', ax=ax)
    ax.set_title(f'Kernel Distance Heatmap for {param_list}')
    ax.set_xlabel('Model Checkpoint')
    ax.set_ylabel('Model Checkpoint')
    fig.savefig(save_path)
    plt.close(fig)

In [46]:
# For each output scale alpha: rebuild the model at every saved training step,
# compute pairwise NTK kernel distances on the test and train subsets, and
# save a heatmap (.pdf) and the raw matrix (.npy) per split.
alphas = [0.5, 0.001]
steps = [1, 10, 100, 500, 1000, 5000, 10000, 20000, 30000, 40000, 50000, 60000, 70000, 80000, 90000, 100000]

# NOTE(review): machine-specific location; point this at your checkpoint folder.
checkpoint_directory = Path('/data/cici/Geometry/new/MNIST/checkpoints')

for alpha in alphas:
    print(alpha)

    # Build every model with the alpha it was trained with. (kernel_distance
    # is invariant to rescaling a kernel, and the NTK of alpha*f is alpha**2
    # times that of f, so this does not change the saved distances.)
    # Step-0 snapshot: a fresh initialization with all parameters multiplied
    # by initialization_scale.
    # NOTE(review): this is an independent random draw. It matches the
    # training run's initialization only if the same RNG state was used —
    # confirm.
    mlp0 = create_mlp(depth, width, activation_fn, alpha)
    with torch.no_grad():
        for p in mlp0.parameters():
            p.mul_(initialization_scale)
    MLPmodels = [mlp0]

    for step in steps:
        # map_location lets GPU-saved checkpoints load on a CPU-only machine.
        # torch.load unpickles the file: only load checkpoints you trust.
        checkpoint = torch.load(checkpoint_directory / f'alpha_{alpha}_{step}.pth', map_location=device)
        mlp = create_mlp(depth, width, activation_fn, alpha)
        mlp.load_state_dict(checkpoint)
        MLPmodels.append(mlp)

    for split, loader in (('test', test_loader), ('train', train_loader)):
        distance_matrix = compute_pairwise_distances(MLPmodels, loader)
        plot_heatmap(distance_matrix, f'MNIST, alpha = {alpha}', f'./K-dist_{alpha}_{split}.pdf')
        np.save(f'K-dist_{alpha}_{split}.npy', distance_matrix)


0.5
0.001
