In [180]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define the model
class TwoLayerModel(nn.Module):
    """Two stacked bias-free linear layers with no activation in between.

    Args:
        input_size: number of input features of the first layer.
        hidden_size: output width of the first layer / input width of the second.
        output_size: number of output features of the second layer.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Both layers are created with bias=False, so the only parameters are
        # the two weight matrices.
        self.layer1 = nn.Linear(input_size, hidden_size, bias=False)
        self.layer2 = nn.Linear(hidden_size, output_size, bias=False)

    def forward(self, x):
        """Apply layer1 and then layer2 to ``x`` and return the result."""
        return self.layer2(self.layer1(x))

# Example usage
input_size = 10
hidden_size = 5
output_size = 1

# Seed the global RNG so the weight initialisation and the random probe inputs
# below are the same on every Restart & Run All.
torch.manual_seed(0)

# Build the model exactly once, so the instance that is printed is the same
# instance used by the cells that follow.
model = TwoLayerModel(input_size, hidden_size, output_size)
print(model)

# Generate some dummy data: two single-row inputs, plus both rows stacked
# along dim 0 into one batch.
inputs1 = torch.randn(1, input_size)
inputs2 = torch.randn(1, input_size)
inputs = torch.cat([inputs1, inputs2], dim=0)

TwoLayerModel(
  (layer1): Linear(in_features=10, out_features=5, bias=False)
  (layer2): Linear(in_features=5, out_features=1, bias=False)
)


In [186]:
# Compute the empirical NTK with functorch as a sanity check
from functorch import make_functional, vmap, vjp, jvp, jacrev
# Split `model` into a callable that is invoked as fnet(params, inputs) and
# the parameters it is invoked with.
fnet, params = make_functional(model)


def fnet_single(params, x):
    """Evaluate ``fnet`` on a single sample.

    A leading axis of size one is added to ``x`` before the call, and axis 0
    of the result is squeezed away again afterwards.
    """
    batched = x.unsqueeze(0)
    output = fnet(params, batched)
    return output.squeeze(0)
def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):
    """Contract per-sample parameter Jacobians of two batches into a kernel.

    Args:
        fnet_single: callable invoked as ``fnet_single(params, x)`` on one
            sample; ``jacrev`` differentiates it with respect to ``params``.
        params: iterable collection of parameter tensors, shared by all samples.
        x1: first batch; samples are taken along axis 0.
        x2: second batch; samples are taken along axis 0.

    Returns:
        Tensor produced by the ``'Naf,Mbf->NMab'`` contraction, summed over
        all parameter tensors.
    """

    def _flat_jacobians(batch):
        # Map over axis 0 of the batch only; `params` is passed unmapped.
        per_param = vmap(jacrev(fnet_single), (None, 0))(params, batch)
        # Collapse every axis from index 2 onward into one trailing axis.
        # NOTE(review): assumes axes 0 and 1 are (sample, output), i.e. that
        # fnet_single returns a 1-D output -- confirm for other models.
        return [jac.flatten(2) for jac in per_param]

    # J(x1) and J(x2), one flattened Jacobian per parameter tensor.
    left = _flat_jacobians(x1)
    right = _flat_jacobians(x2)

    # J(x1) @ J(x2).T per parameter tensor, then summed across parameters.
    contributions = []
    for jac_left, jac_right in zip(left, right):
        contributions.append(torch.einsum('Naf,Mbf->NMab', jac_left, jac_right))
    return torch.stack(contributions).sum(0)
def _ntk_entry(xa, xb):
    """Return the empirical NTK between ``xa`` and ``xb``, squeezed then flattened."""
    kernel = empirical_ntk_jacobian_contraction(fnet_single, params, xa, xb)
    return kernel.squeeze().flatten()


# Cross term and the two self terms for the two probe inputs.
m12 = _ntk_entry(inputs1, inputs2)
m11 = _ntk_entry(inputs1, inputs1)
m22 = _ntk_entry(inputs2, inputs2)
print(m12, m11, m22)

# print(torch.linalg.eigvalsh(result), torch.linalg.eigvalsh(result).sum())

tensor(0.6454, grad_fn=<SqueezeBackward0>) tensor(3.2474, grad_fn=<SqueezeBackward0>) tensor(3.6465, grad_fn=<SqueezeBackward0>)


In [174]:

# # Forward pass
# model.zero_grad()
# outputs = model(inputs1)
# gradients = torch.autograd.grad(outputs, model.parameters(), grad_outputs=torch.ones_like(outputs), create_graph=True)
# print(gradients[0].shape, gradients[1].shape)
# flattened_gradients1 = torch.cat([grad.view(-1) for grad in gradients])

# # Forward pass
# model.zero_grad()
# outputs = model(inputs2)
# gradients = torch.autograd.grad(outputs, model.parameters(), grad_outputs=torch.ones_like(outputs), create_graph=True)
# flattened_gradients2 = torch.cat([grad.view(-1) for grad in gradients])

# m11 = flattened_gradients1@flattened_gradients1
# m22 = flattened_gradients2@flattened_gradients2
# m12 = flattened_gradients1@flattened_gradients2
# ntk_matrix = torch.tensor([[m11, m12], [m12, m22]])
# print(ntk_matrix)

# print(torch.linalg.eigvalsh(ntk_matrix).sum())

In [179]:
# NOTE(review): `kernel_eigenvalues` is not defined in the visible cells; it
# must already be defined or imported before this cell runs.
ntk_eigenvalues = kernel_eigenvalues(
    model,
    inputs,
    cross_entropy=False,
    print_progress=False,
    top_n=1,
    tol=1e-6,
)
print(ntk_eigenvalues)

[tensor(17.7620, grad_fn=<SumBackward0>)]
