# **NTK matrix of a two-layer network**

In [4]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
%matplotlib inline


class Net(nn.Module):
    """Two-layer fully connected network with a ReLU hidden activation."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        """Create the two linear layers: input -> hidden and hidden -> output."""
        super().__init__()
        # Layers are created in this order so parameter initialisation draws
        # from the random generator in the same sequence as before.
        self.lin1 = nn.Linear(input_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply ``lin1``, a ReLU, and then ``lin2`` to ``x``."""
        hidden = self.lin1(x).relu()
        return self.lin2(hidden)

def get_grads(net):
    """Return the current gradients of ``net`` as one flat NumPy vector.

    Parameters are visited in ``net.parameters()`` order; any parameter whose
    ``.grad`` is ``None`` is skipped.

    Args:
        net: A ``torch.nn.Module`` on which ``backward()`` has already run.

    Returns:
        A freshly allocated 1-D ``numpy.ndarray`` holding all gradients.

    Raises:
        ValueError: If no parameter has a gradient (``np.concatenate`` is
            given an empty list).
    """
    flat = [
        # ``.cpu()`` lets this work for parameters that live on another
        # device; ``np.concatenate`` below copies, so no ``clone()`` is needed.
        p.grad.detach().cpu().numpy().ravel()
        for p in net.parameters()
        if p.grad is not None
    ]
    return np.concatenate(flat)


def _grad_model(model, i, inputs=None):
    """Return the flattened parameter gradient of ``model`` at sample ``i``.

    Args:
        model: The network to differentiate.
        i: Index of the sample inside ``inputs``.
        inputs: Indexable collection of samples.  When omitted, the
            module-level ``x`` is used, which keeps the two-argument call
            form ``_grad_model(model, i)`` working.

    Returns:
        1-D ``numpy.ndarray`` of gradients, as produced by ``get_grads``.

    Note:
        ``backward()`` is called without an explicit gradient, so the model
        output for a single sample must contain exactly one element.
    """
    if inputs is None:
        inputs = x  # legacy fallback to the module-level data
    # Clear gradients left over from the previous sample so they do not
    # accumulate into this one.
    model.zero_grad()
    out = model(inputs[i])
    out.backward()
    return get_grads(model)


def compute_ntk(model, x):
    """Compute the empirical NTK Gram matrix of ``model`` on the samples ``x``.

    Entry ``(i, j)`` is the inner product between the parameter gradient of
    the model output at ``x[i]`` and the one at ``x[j]``.

    Args:
        model: The network whose kernel is evaluated.
        x: Indexable collection of samples supporting ``len``; ``x[i]`` is
            fed to ``model`` one sample at a time.

    Returns:
        ``numpy.ndarray`` of shape ``(len(x), len(x))`` and dtype ``float64``.
    """
    num_samples = len(x)
    if num_samples == 0:
        return np.zeros((0, 0))
    # One backward pass per sample; each row is that sample's gradient vector.
    jacobian = np.stack(
        [_grad_model(model, i, x) for i in range(num_samples)]
    ).astype(np.float64)
    # All pairwise inner products in a single matrix product.
    return jacobian @ jacobian.T


# Problem size: dimensionality of each input and number of samples.
input_dim = 20
num_data = 100
# Two-layer network with 10 hidden units and a single output.
model = Net(input_dim, 10, 1)
# Input data: random Gaussian samples, for demonstration only.
x = torch.randn(num_data, input_dim)
ntk = compute_ntk(model, x)

In [3]:
# Heat map of the NTK Gram matrix of the plain two-layer network.
heatmap = plt.imshow(ntk, cmap='hot')
plt.colorbar(heatmap)
plt.show()

# **NTK matrix of a two-layer network with a Hadamard product**

In [2]:
class Net_HP(nn.Module):
    """Two-layer network whose hidden layer is a Hadamard product of two ReLU branches."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        """Create two parallel input -> hidden layers and one hidden -> output layer."""
        super().__init__()
        # Layers are created in this order so parameter initialisation draws
        # from the random generator in the same sequence as before.
        self.lin1 = nn.Linear(input_dim, hidden_dim)
        self.lin2 = nn.Linear(input_dim, hidden_dim)
        self.lin3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Multiply the two ReLU branches element-wise, then apply ``lin3``."""
        branch_a = torch.relu(self.lin1(x))
        branch_b = torch.relu(self.lin2(x))
        return self.lin3(branch_a * branch_b)


# Same input dimension, hidden width and data as the plain network, so the
# two kernels are directly comparable.
model_HP = Net_HP(input_dim, 10, 1)
ntk_HP = compute_ntk(model_HP, x)
# Show the kernel of the Hadamard-product model (``ntk_HP``), not the plain
# network's ``ntk``.
plt.imshow(ntk_HP, cmap='hot')
plt.colorbar()

In [1]:
# Are the two NTKs the same, given that the data are the same?
# Plot the element-wise difference between the two kernels to find out.
ntk_gap = ntk - ntk_HP
plt.imshow(ntk_gap, cmap='hot')
plt.colorbar()