In [1]:
import torch
import numpy as np
from scipy.sparse.linalg import lsqr, LinearOperator, cg#, minres
import matplotlib.pyplot as plt
from torch import nn
from torch.utils.data import DataLoader, Subset
from torchvision import datasets
from torchvision.transforms import Compose, ToTensor, v2
from torch.optim import lr_scheduler
from functorch import make_functional, vmap, vjp, jvp, jacrev
import torch.nn.functional as F
import time



In [2]:
# Problem sizes: number of samples kept from each MNIST split, and the number
# of network outputs (one per digit class).
N_TRAIN, N_TEST, N_OUTPUT = 1000, 100, 10

# Each split is loaded from (or downloaded to) data/MNIST, converted to tensors
# by ToTensor, and truncated to its first N_TRAIN / N_TEST samples.
training_data = Subset(
    datasets.MNIST(root="data/MNIST", train=True, download=True, transform=ToTensor()),
    range(N_TRAIN),
)
test_data = Subset(
    datasets.MNIST(root="data/MNIST", train=False, download=True, transform=ToTensor()),
    range(N_TEST),
)
    

def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader``.

    For every batch the loss is evaluated, back-propagated, and followed by an
    optimizer step; gradients are cleared after the step. The loss of every
    tenth batch (starting with the first) is printed together with the number
    of samples processed so far.

    Args:
        dataloader: iterable of ``(inputs, targets)`` batches exposing a
            ``.dataset`` attribute with a length.
        model: module called on the batch inputs.
        loss_fn: callable mapping ``(predictions, targets)`` to a scalar loss.
        optimizer: optimizer providing ``step()`` and ``zero_grad()``.
    """
    size = len(dataloader.dataset)
    # Training mode only matters for layers such as dropout or batch norm.
    model.train()
    for step, (inputs, targets) in enumerate(dataloader):
        batch_loss = loss_fn(model(inputs), targets)

        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % 10 != 0:
            continue
        loss = batch_loss.item()
        current = (step + 1) * len(inputs)
        print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [4]:
class CNN(nn.Module):
    """Small convolutional classifier: two conv/ReLU/max-pool stages and a linear head.

    For 1x28x28 inputs the feature maps are 8x12x12 after the first stage and
    3x4x4 after the second, which matches the 3 * 4 * 4 inputs of ``fc1``.
    The network returns the 10 raw outputs of ``fc1`` (no softmax).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=3, kernel_size=5)
        self.fc1 = nn.Linear(in_features=3 * 4 * 4, out_features=10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map an image ``(C, H, W)`` or a batch ``(N, C, H, W)`` to its 10 outputs."""
        features = x
        for conv in (self.conv1, self.conv2):
            features = self.pool(self.relu(conv(features)))
        # Flatten only the last three axes (C, H, W) so that both batched and
        # unbatched inputs are accepted.
        flat = torch.flatten(features, features.dim() - 3)
        return self.fc1(flat)
        
    
# Fix PyTorch's RNG seed before building the network so that the random weight
# initialisation is identical after every kernel restart; without it the
# training curve and all downstream numbers change from run to run.
torch.manual_seed(0)
cnn = CNN()
print("Number of parameters p = {}".format(sum(p.numel() for p in cnn.parameters() if p.requires_grad)))

Number of parameters p = 1301


In [5]:
# Hyper-parameters for plain mini-batch SGD.
learning_rate, batch_size, epochs = 1e-1, 50, 50

# Default DataLoader settings are used, so batches follow the dataset order.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=cnn.parameters(), lr=learning_rate)

# Each epoch is one training pass followed by an evaluation on the test subset.
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(dataloader=train_dataloader, model=cnn, loss_fn=loss_fn, optimizer=optimizer)
    test_loop(dataloader=test_dataloader, model=cnn, loss_fn=loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 2.303186  [   50/ 1000]
loss: 2.288059  [  550/ 1000]
Test Accuracy: 8.0%, Avg loss: 2.243985 

Epoch 2
-------------------------------
loss: 2.247828  [   50/ 1000]
loss: 2.178858  [  550/ 1000]
Test Accuracy: 52.0%, Avg loss: 1.636130 

Epoch 3
-------------------------------
loss: 1.730775  [   50/ 1000]
loss: 1.503000  [  550/ 1000]
Test Accuracy: 73.0%, Avg loss: 0.807350 

Epoch 4
-------------------------------
loss: 0.952932  [   50/ 1000]
loss: 1.180496  [  550/ 1000]
Test Accuracy: 80.0%, Avg loss: 0.616059 

Epoch 5
-------------------------------
loss: 0.744069  [   50/ 1000]
loss: 0.914102  [  550/ 1000]
Test Accuracy: 84.0%, Avg loss: 0.531699 

Epoch 6
-------------------------------
loss: 0.600971  [   50/ 1000]
loss: 0.749042  [  550/ 1000]
Test Accuracy: 85.0%, Avg loss: 0.476241 

Epoch 7
-------------------------------
loss: 0.503416  [   50/ 1000]
loss: 0.677698  [  550/ 1000]
Test Accuracy: 87.0%, Avg loss: 0.434928 



In [6]:
## Empirical neural tangent kernel (NTK): per-sample parameter gradients and matrix-vector products
def flatten_extend_gradient(parameters):
    """Return the ``.grad`` entries of all ``parameters`` as one flat Python list.

    Parameters are visited in the given order and each gradient is flattened
    in row-major order, so the result has one scalar per parameter entry.
    Every parameter must already have a populated ``.grad``.
    """
    return [
        entry
        for parameter in parameters
        for entry in parameter.grad.detach().numpy().flatten()
    ]

def gradient_model(model,xi,c):
    """Gradient of the ``c``-th model output at ``xi`` w.r.t. the trainable parameters.

    Args:
        model: ``torch.nn.Module``; at least one parameter must require grad.
        xi: a single input such that ``model(xi)[c]`` is a scalar tensor.
        c: index of the output component to differentiate.

    Returns:
        1-D NumPy array holding the flattened gradients of all parameters with
        ``requires_grad=True``, concatenated in ``model.parameters()`` order.
        Trainable parameters that do not influence the output contribute zeros,
        so the length always equals the number of trainable parameter entries.
    """
    # Reset gradients through the model itself rather than through a global
    # optimizer: this stays correct for any model passed in and guarantees
    # that gradients from earlier calls are not accumulated.
    model.zero_grad()
    model(xi)[c].backward()
    chunks = []
    for parameter in model.parameters():
        if not parameter.requires_grad:
            continue
        grad = parameter.grad
        if grad is None:
            # Parameter was not part of the graph for this output.
            grad = torch.zeros_like(parameter)
        chunks.append(grad.detach().numpy().ravel())
    # One concatenate call instead of building a Python list of scalars.
    return np.concatenate(chunks)

def ntk_single(x1,x2,model,c):
    """Empirical NTK entry for the pair ``(x1, x2)`` and output index ``c``.

    Computed as the inner product of the two parameter-gradient vectors
    returned by ``gradient_model`` (evaluated for ``x1`` first, then ``x2``).
    """
    grad_first, grad_second = (gradient_model(model, sample, c) for sample in (x1, x2))
    return grad_first @ grad_second.transpose()

def _ntk_input(item):
    """Return the input tensor of a dataset item, with a leading channel axis if it is 2-D."""
    if isinstance(item, (tuple, list)):
        # Dataset items are (input, label) pairs; only the input is needed.
        item = item[0]
    if item.dim() == 2:
        # Iterating over a (1, H, W) tensor yields (H, W) slices; restore the
        # channel axis expected by the convolutional model.
        item = item.unsqueeze(0)
    return item


def ntk_matrix(X1,X2,model,c):
    """Empirical NTK matrix between two collections of inputs for output index ``c``.

    Args:
        X1, X2: sized iterables whose items are either input tensors or
            ``(input, label)`` pairs (e.g. a dataset, or a single ``(1, H, W)``
            tensor, which is treated as a collection of one sample).
        model: module differentiated by ``gradient_model``.
        c: index of the output component.

    Returns:
        ``(len(X1), len(X2))`` NumPy array whose entry ``[i, j]`` is the inner
        product of the parameter gradients at ``X1[i]`` and ``X2[j]``.

    Every entry is filled (the earlier version stopped after the first column
    and left the rest uninitialised). The gradients of ``X2`` are computed
    once and kept in memory, so the cost is ``len(X1) + len(X2)`` gradient
    evaluations instead of two per pair.
    """
    n1 = len(X1)
    n2 = len(X2)
    Kappa = np.empty((n1,n2))
    if n1 == 0 or n2 == 0:
        return Kappa
    # Rows are the gradient vectors of the X2 samples: shape (n2, p).
    grads_2 = np.stack([gradient_model(model, _ntk_input(x2), c) for x2 in X2])
    for i1, x1 in enumerate(X1):
        grad_1 = gradient_model(model, _ntk_input(x1), c)
        Kappa[i1, :] = grads_2 @ grad_1
    return Kappa

def MVP_JTX(v,model,X_training,c):
    """Compute ``J^T v``, where row ``i`` of ``J`` is the gradient at the ``i``-th training input.

    Args:
        v: indexable with one weight per item of ``X_training``.
        model: module differentiated by ``gradient_model``.
        X_training: iterable of ``(input, label)`` pairs.
        c: index of the output component.

    Returns:
        ``(p, 1)`` NumPy array, ``p`` being the number of trainable parameter entries.
    """
    n_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
    weighted_sum = np.zeros((n_params, 1))
    for index, sample in enumerate(X_training):
        inputs, _ = sample
        column = gradient_model(model, inputs, c).reshape((n_params, 1))
        weighted_sum += v[index] * column
    return weighted_sum

def MVP_JX(v,model,X_training,c):
    """Compute ``J v``, where row ``i`` of ``J`` is the gradient at the ``i``-th training input.

    Args:
        v: array-like with ``p`` entries (any shape that reshapes to ``(p,)``),
            ``p`` being the number of trainable parameter entries of ``model``.
        model: module differentiated by ``gradient_model``.
        X_training: sized iterable of ``(input, label)`` pairs.
        c: index of the output component.

    Returns:
        ``(len(X_training), 1)`` NumPy array of the per-sample inner products.
    """
    p = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # Size the result from the data actually passed in, not from a global.
    n = len(X_training)
    mvp = np.zeros((n,1))
    v = np.asarray(v).reshape(p)
    for i,(xi,_) in enumerate(X_training):
        g = gradient_model(model,xi,c).reshape(p)
        # 1-D dot product yields a true scalar; assigning a (1, 1) array to a
        # single element is deprecated in NumPy.
        mvp[i,0] = g @ v
    return mvp

def MVP_JJT(v,model,X_training,c):
    """Apply ``J J^T`` to ``v`` without forming the matrix.

    ``MVP_JTX`` first maps ``v`` to parameter space (``J^T v``); ``MVP_JX``
    then maps that result back to one value per training sample.
    """
    return MVP_JX(MVP_JTX(v, model, X_training, c), model, X_training, c)

In [7]:
import solvers as solv

In [None]:
# NOTE(review): this cell was never executed and duplicates the cell that
# follows it; one of the two should be removed.
# Both loops exit after their first pass, so only the first test sample and
# output index c = 0 are processed.
for i, (xi, yi) in enumerate(test_data):
    for c in range(N_OUTPUT):
        # Right-hand side: kernel values between the training inputs and xi.
        kappa_xX = ntk_matrix(training_data, xi, cnn, c)
        b = kappa_xX

        # Matrix-free operator v -> J J^T v over the training subset.
        def mvp(v):
            return MVP_JJT(v=v, model=cnn, X_training=training_data, c=c)

        A = LinearOperator((N_TRAIN, N_TRAIN), matvec=mvp)
        # presumably solv.CR iteratively solves A x = b -- confirm in solvers.py
        x = solv.CR(A, b, rtol=1e-7, maxit=50)
        break
    break
kappa_xx = ntk_single(xi, xi, cnn, c)
uq = kappa_xx - b.T @ x

In [8]:
# Both loops exit after their first pass, so only the first test sample and
# output index c = 0 are processed.
for i, (xi, yi) in enumerate(test_data):
    for c in range(N_OUTPUT):
        # Right-hand side: kernel values between the training inputs and xi.
        kappa_xX = ntk_matrix(training_data, xi, cnn, c)
        b = kappa_xX

        # Matrix-free operator v -> J J^T v over the training subset.
        def mvp(v):
            return MVP_JJT(v=v, model=cnn, X_training=training_data, c=c)

        A = LinearOperator((N_TRAIN, N_TRAIN), matvec=mvp)
        # presumably solv.CR iteratively solves A x = b -- confirm in solvers.py
        x = solv.CR(A, b, rtol=1e-7, maxit=50)
        break
    break
# Kernel of the test sample with itself minus the correction built from the solve.
kappa_xx = ntk_single(xi, xi, cnn, c)
uq = kappa_xx - b.T @ x

-------------------------------------
 ite  |  |rk|/|b|  | |Ark|/|Ab|
-------------------------------------
  1   |  1.00e+00  |  1.00e+00 
  2   |  3.48e-01  |  9.04e-02 
  3   |  1.21e-01  |  1.72e-02 
  4   |  5.46e-02  |  6.07e-03 
  5   |  3.19e-02  |  1.56e-03 
  6   |  2.20e-02  |  1.40e-03 
  7   |  1.47e-02  |  5.11e-04 
  8   |  1.03e-02  |  4.18e-04 
  9   |  7.06e-03  |  9.59e-05 
 10   |  5.71e-03  |  1.40e-04 
-------------------------------------
 ite  |  |rk|/|b|  | |Ark|/|Ab|
-------------------------------------
 11   |  4.85e-03  |  9.21e-05 
 12   |  3.96e-03  |  4.13e-04 
 13   |  3.92e-03  |  2.40e-04 
 14   |  3.45e-03  |  3.62e-05 
 15   |  2.62e-03  |  2.81e-05 
 16   |  2.15e-03  |  1.71e-05 
 17   |  1.85e-03  |  1.26e-05 
 18   |  1.58e-03  |  1.09e-05 
 19   |  1.47e-03  |  1.19e-04 
 20   |  1.32e-03  |  9.42e-06 
-------------------------------------
 ite  |  |rk|/|b|  | |Ark|/|Ab|
-------------------------------------
 21   |  1.16e-03  |  3.20e-05 
 22 

In [9]:
# Show the computed quantity next to the raw network outputs for the same sample.
print(uq, cnn(xi), sep="\n")

[[3.50427673]]
tensor([  0.2595,   1.5719,  15.7995,  18.2848, -17.5986,  -6.8794, -18.9326,
         27.6804,  -3.4832,  -4.5646], grad_fn=<AddBackward0>)
