In [32]:
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import Subset, DataLoader

import torch
import torch.nn as nn
from functorch import make_functional, vmap, vjp, jvp, jacrev
# Prefer the GPU when one is available, but fall back to the CPU so that the
# device string is always usable on the machine running the notebook.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

class DenseNN(nn.Module):
    """
    Fully connected neural network with a single ReLU hidden layer.

    Args:
        num_hidden_units: width of the hidden layer.
        input_dim: size of the last dimension of the input. Defaults to 784,
            the value that used to be hard-coded.
        output_dim: number of outputs. Defaults to 10, the value that used
            to be hard-coded.
    """
    def __init__(self, num_hidden_units, input_dim=784, output_dim=10):
        super().__init__()
        self.num_hidden_units = num_hidden_units
        # The attribute names (l1, activation_fun, l2) and their creation
        # order are kept as-is: they fix the parameter names and ordering.
        self.l1 = nn.Linear(input_dim, num_hidden_units)
        self.activation_fun = nn.ReLU()
        self.l2 = nn.Linear(num_hidden_units, output_dim)

    def forward(self, x):
        """Return ``l2(ReLU(l1(x)))`` for ``x`` with last dimension ``input_dim``."""
        hidden = self.activation_fun(self.l1(x))
        return self.l2(hidden)
    
#x_train = torch.randn(2, 784, device=device)


# Seed PyTorch's global RNG so that the DataLoader shuffle order and the
# network's random initialisation are repeatable across "Restart & Run All".
torch.manual_seed(12345)

# Convert every image to a tensor and flatten it into a 1-D vector.
transform = transforms.Compose([
    transforms.ToTensor(),
    torch.flatten,
])
data_rng = np.random.RandomState(12345)

dataset1 = datasets.MNIST('./mnist_data', train=True, download=True,
                          transform=transform)
dataset2 = datasets.MNIST('./mnist_data', train=False, download=True,
                          transform=transform)
# Keep a fixed random subset of 500 training examples (sampled without
# replacement with the seeded NumPy generator above).
dataset1 = Subset(dataset1, data_rng.choice(len(dataset1), 500, replace=False))

# One batch holds the whole training subset, so the first batch is all of it.
train_loader = DataLoader(dataset1, batch_size=len(dataset1), shuffle=True)
train_data, train_labels = next(iter(train_loader))
train_data = train_data.to(device)

# Only the first batch of 1000 test examples is used.
test_loader = DataLoader(dataset2, batch_size=1000, shuffle=False)
test_data, test_labels = next(iter(test_loader))
test_data = test_data.to(device)

# Build the network on the same device as the data and split it into a
# functional form plus its parameters.
net = DenseNN(5).to(device)
fnet, params = make_functional(net)

  warn_deprecated('make_functional', 'torch.func.functional_call')


In [33]:
def fnet_single(params, x):
    """Evaluate the functional network ``fnet`` on one unbatched example.

    A leading axis of length one is added to ``x`` before the call and
    squeezed out of the result afterwards.
    """
    batch_of_one = x.unsqueeze(0)
    output = fnet(params, batch_of_one)
    return output.squeeze(0)

def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute="full"):
    """
    Empirical neural tangent kernel between two batches, obtained by
    contracting per-example Jacobians of ``fnet_single`` w.r.t. ``params``.

    Args:
        fnet_single: callable ``(params, x) -> output`` for ONE unbatched
            example. The contraction below treats each per-example Jacobian
            as (output, flattened-parameter), i.e. a 1-D output is expected.
        params: sequence of parameter tensors passed as the first argument
            of ``fnet_single``; the Jacobian is taken with respect to these.
        x1: first batch of inputs, batched along dim 0 (size N).
        x2: second batch of inputs, batched along dim 0 (size M).
        compute: which part of the kernel to return:
            'full'     -> shape (N, M, O, O)
            'trace'    -> shape (N, M), trace over the output dimensions
            'diagonal' -> shape (N, M, O), diagonal over the output dimensions

    Returns:
        The kernel tensor, summed over all parameter tensors.

    Raises:
        ValueError: if ``compute`` is not one of the three supported modes.
    """
    einsum_exprs = {
        'full': 'Naf,Mbf->NMab',
        'trace': 'Naf,Maf->NM',
        'diagonal': 'Naf,Maf->NMa',
    }
    # Validate before the expensive Jacobian computation, and with a real
    # exception: a bare `assert` disappears under `python -O`.
    if compute not in einsum_exprs:
        raise ValueError(
            f"compute must be one of {sorted(einsum_exprs)}, got {compute!r}"
        )
    einsum_expr = einsum_exprs[compute]

    def per_example_jacobians(x):
        # One Jacobian per parameter tensor, batched over the examples in x;
        # all parameter dimensions are flattened into a single trailing axis.
        jac = vmap(jacrev(fnet_single), (None, 0))(params, x)
        return [j.flatten(2) for j in jac]

    # Compute J(x1)
    jac1 = per_example_jacobians(x1)
    # Compute J(x2); when both arguments are the same tensor object (Gram
    # matrix of one batch) the Jacobians are identical, so reuse them.
    jac2 = jac1 if x2 is x1 else per_example_jacobians(x2)

    # Compute J(x1) @ J(x2).T per parameter tensor, then sum the contributions.
    result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])
    return result.sum(0)

In [34]:
# Trace-reduced kernel (Gram) matrix of the training batch with itself.
result = empirical_ntk_jacobian_contraction(
    fnet_single, params, x1=train_data, x2=train_data, compute="trace"
)
print(result.shape)

torch.Size([500, 500])


  warn_deprecated('jacrev')
  warn_deprecated('vmap', 'torch.vmap')


In [35]:
# Display the kernel matrix computed in the previous cell.
# NOTE(review): entries equal to exactly 10.0 presumably belong to pairs of
# inputs that share no active hidden ReLU unit, leaving only the output-bias
# contribution (trace of a 10x10 identity) -- confirm.
result

tensor([[ 10.0000,  10.0000,  10.0000,  ...,  10.0000,  10.0000,  10.0000],
        [ 10.0000, 123.2491,  36.1689,  ...,  54.8940,  52.2415,  21.3331],
        [ 10.0000,  36.1689, 149.8487,  ...,  76.0656,  23.9620,  10.0000],
        ...,
        [ 10.0000,  54.8940,  76.0656,  ..., 143.0298,  41.0103,  10.0000],
        [ 10.0000,  52.2415,  23.9620,  ...,  41.0103, 107.5503,  30.5195],
        [ 10.0000,  21.3331,  10.0000,  ...,  10.0000,  30.5195,  62.5905]],
       device='cuda:0', grad_fn=<SumBackward1>)

In [36]:
from sklearn import svm

# The SVM takes the kernel matrix directly ("precomputed"), so hand it the
# training Gram matrix and the labels as NumPy arrays.
train_kernel = result.numpy(force=True)
train_targets = train_labels.numpy(force=True)

clf = svm.SVC(kernel="precomputed", C=0.1)
model = clf.fit(train_kernel, train_targets)

In [38]:
# Display the labels of the test batch as a NumPy array.
test_labels.numpy(force=True)

array([7, 2, 1, 0, 4, 1, 4, 9, 5, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9, 6,
       6, 5, 4, 0, 7, 4, 0, 1, 3, 1, 3, 4, 7, 2, 7, 1, 2, 1, 1, 7, 4, 2,
       3, 5, 1, 2, 4, 4, 6, 3, 5, 5, 6, 0, 4, 1, 9, 5, 7, 8, 9, 3, 7, 4,
       6, 4, 3, 0, 7, 0, 2, 9, 1, 7, 3, 2, 9, 7, 7, 6, 2, 7, 8, 4, 7, 3,
       6, 1, 3, 6, 9, 3, 1, 4, 1, 7, 6, 9, 6, 0, 5, 4, 9, 9, 2, 1, 9, 4,
       8, 7, 3, 9, 7, 4, 4, 4, 9, 2, 5, 4, 7, 6, 7, 9, 0, 5, 8, 5, 6, 6,
       5, 7, 8, 1, 0, 1, 6, 4, 6, 7, 3, 1, 7, 1, 8, 2, 0, 2, 9, 9, 5, 5,
       1, 5, 6, 0, 3, 4, 4, 6, 5, 4, 6, 5, 4, 5, 1, 4, 4, 7, 2, 3, 2, 7,
       1, 8, 1, 8, 1, 8, 5, 0, 8, 9, 2, 5, 0, 1, 1, 1, 0, 9, 0, 3, 1, 6,
       4, 2, 3, 6, 1, 1, 1, 3, 9, 5, 2, 9, 4, 5, 9, 3, 9, 0, 3, 6, 5, 5,
       7, 2, 2, 7, 1, 2, 8, 4, 1, 7, 3, 3, 8, 8, 7, 9, 2, 2, 4, 1, 5, 9,
       8, 7, 2, 3, 0, 4, 4, 2, 4, 1, 9, 5, 7, 7, 2, 8, 2, 6, 8, 5, 7, 7,
       9, 1, 8, 1, 8, 0, 3, 0, 1, 9, 9, 4, 1, 8, 2, 1, 2, 9, 7, 5, 9, 2,
       6, 4, 1, 5, 8, 2, 9, 2, 0, 4, 0, 0, 2, 8, 4,

In [39]:
# Kernel between every test example (rows) and every training example
# (columns), in the same "trace" form that the SVM was fitted on.
test_train_kernel = empirical_ntk_jacobian_contraction(
    fnet_single, params, test_data, train_data, compute="trace"
)
p = model.predict(test_train_kernel.numpy(force=True))

  warn_deprecated('jacrev')
  warn_deprecated('vmap', 'torch.vmap')


In [44]:
# Test accuracy: the comparison yields a boolean NumPy array whose mean is
# the fraction of correct predictions (vectorised, unlike the built-in sum).
(p == test_labels.numpy(force=True)).mean()

0.753