In [271]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define the model
class TwoLayerModel(nn.Module):
    """Two bias-free linear layers applied back to back, with no activation between them."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Registered in this order (layer1, then layer2), which fixes the parameter order.
        self.layer1 = nn.Linear(input_size, hidden_size, bias=False)
        self.layer2 = nn.Linear(hidden_size, output_size, bias=False)

    def forward(self, x):
        """Feed ``x`` through ``layer1`` and then ``layer2``; return the final output."""
        hidden = self.layer1(x)
        return self.layer2(hidden)

# Example usage
input_size = 10
hidden_size = 5
output_size = 3

# Build the model exactly once, so the instance printed here is the same one
# the later cells differentiate. (Previously it was constructed twice and the
# printed instance was thrown away.)
# NOTE: no RNG seed is set, so the weights and inputs below -- and therefore
# the eigenvalues printed further down -- change from run to run.
model = TwoLayerModel(input_size, hidden_size, output_size)
print(model)

# Generate some dummy data: two random samples stacked into a (2, input_size) batch.
inputs1 = torch.randn(1, input_size)
inputs2 = torch.randn(1, input_size)
inputs = torch.cat([inputs1, inputs2], dim=0)

TwoLayerModel(
  (layer1): Linear(in_features=10, out_features=5, bias=False)
  (layer2): Linear(in_features=5, out_features=3, bias=False)
)


In [272]:
# Compute the empirical NTK directly with torch, as a sanity check for the library result below
from functorch import make_functional, vmap, vjp, jvp, jacrev
# `fnet` is invoked below as fnet(params, batch); `params` are the values we differentiate against.
fnet, params = make_functional(model)


def fnet_single(params, x):
    """Evaluate ``fnet`` on one un-batched sample.

    A leading dimension of size 1 is added to ``x`` before the call and
    dimension 0 is squeezed from the result afterwards.
    """
    batch_of_one = x.unsqueeze(0)
    batched_out = fnet(params, batch_of_one)
    return batched_out.squeeze(0)
def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):
    """Empirical NTK between two batches, computed as ``J(x1) @ J(x2).T``.

    Args:
        fnet_single: callable ``(params, x) -> output`` for a single sample.
            The einsum below requires each flattened Jacobian to be rank 3,
            i.e. the per-sample output must be a 1-D tensor.
        params: what the Jacobian is taken with respect to. May be a
            tuple/list of tensors, a dict of tensors, or a single tensor.
        x1: first batch of inputs, batched along dim 0 (size N).
        x2: second batch of inputs, batched along dim 0 (size M).

    Returns:
        Tensor of shape ``(N, M, a, b)`` where ``a`` and ``b`` index the
        outputs for ``x1`` and ``x2`` respectively, summed over all
        parameter tensors.
    """

    def flat_jacobians(x):
        # Per-sample Jacobian w.r.t. `params` (jacrev's default argnums=0),
        # vmapped over the batch dimension of `x` only.
        jac = vmap(jacrev(fnet_single), (None, 0))(params, x)
        if isinstance(jac, torch.Tensor):
            # A bare tensor of params yields a bare Jacobian; wrap it so we do
            # not accidentally iterate over its batch dimension.
            leaves = [jac]
        elif isinstance(jac, dict):
            # Dict params yield a dict of Jacobians; iterating it directly
            # would give the keys, not the tensors.
            leaves = list(jac.values())
        else:
            leaves = list(jac)
        # Collapse every parameter dimension into one: (batch, out, n_param_elems).
        return [j.flatten(2) for j in leaves]

    # Compute J(x1)
    jac1 = flat_jacobians(x1)

    # Compute J(x2); when both sides are the very same tensor the Jacobian is
    # identical, so reuse it instead of differentiating a second time.
    jac2 = jac1 if x2 is x1 else flat_jacobians(x2)

    # Compute J(x1) @ J(x2).T per parameter tensor, then sum the contributions.
    per_param = [torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)]
    return torch.stack(per_param).sum(0)

# Kernel of the batch against itself; indexed (sample_i, sample_j, out_a, out_b).
ntk_matrix = empirical_ntk_jacobian_contraction(fnet_single, params, inputs, inputs)
# Derive the flattened size from the kernel itself instead of hard-coding the
# batch size, so this keeps working when `inputs` has a different number of rows.
n_rows = ntk_matrix.shape[0] * ntk_matrix.shape[2]
n_cols = ntk_matrix.shape[1] * ntk_matrix.shape[3]
# (i, j, a, b) -> (i, a, j, b) -> 2-D matrix whose rows/cols are (sample, output) pairs.
ntk_matrix = torch.einsum('abij->aibj', ntk_matrix).reshape(n_rows, n_cols).detach()
# Keep the three largest eigenvalues, largest first.
ntk_eigenvalues = torch.linalg.eigvalsh(ntk_matrix).sort(descending=True).values[:3]
print(ntk_eigenvalues) # from torch computing the ntk matrix

tensor([11.3608,  9.1758,  7.1037])


In [273]:
# NOTE(review): `kernel_eigenvalues` is not defined or imported in the visible
# cells -- presumably it comes from the library under test; confirm its import.
ntk_eigenvalues = kernel_eigenvalues(
    model,
    inputs,
    cross_entropy=False,
    print_progress=False,
    top_n=3,
    tol=1e-6,
)
# Convert each returned eigenvalue to a 0-d tensor and stack them into one 1-D tensor.
eig_scalars = [torch.tensor(eig.item()) for eig in ntk_eigenvalues]
ntk_eigenvalues = torch.stack(eig_scalars)
print(ntk_eigenvalues) # from the library

tensor([11.3608,  9.1758,  7.1037])
