# Neural-Tangent-Kernel
A Python (PyTorch) implementation of the neural tangent kernel (NTK).

Original NTK paper: https://arxiv.org/abs/1806.07572
Blog post: https://jackhmiller.github.io/My-DS-Blog/2021/10/02/NTK.html

In [None]:
import torch
from torch import optim, nn
import copy
import warnings
from pylab import *
import imageio
warnings.filterwarnings('ignore')

In [None]:
# Global matplotlib styling applied to every figure drawn afterwards.
rcParams.update({
    'figure.figsize': (12, 9),
    'axes.grid': True,
    'font.size': 20,
    'lines.linewidth': 3,
})

# Colors of the default property cycle, in cycle order.
DEFAULT_COLORS = rcParams['axes.prop_cycle'].by_key()['color']

In [None]:
#!pip install "jax[cpu]===0.3.14" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver

In [None]:
# Activation names accepted by simple_net(act=...) mapped to nn.Module classes.
act_dict = dict(relu=nn.ReLU, tanh=nn.Tanh)

In [None]:
class ZeroOutput(nn.Module):
    """Zero the output of a model by subtracting out a frozen copy of it.

    A deep copy of ``model`` is taken at construction time and never trained.
    The forward pass returns ``model(x) - copy(x)``. For a deterministic model,
    this is zero until ``model``'s parameters move away from their initial
    values.

    NOTE: the snapshot is intentionally not a registered submodule (see
    ``__init__``). Consequently ``.to()``/``.cuda()`` on this wrapper will not
    move it.
    """

    def __init__(self, model):
        super().__init__()
        snapshot = copy.deepcopy(model).eval()
        # The snapshot is a fixed reference and must never be trained.
        # Disabling requires_grad stops autograd from building a graph through
        # it. It also stops .grad tensors from accumulating on parameters that
        # no optimizer or zero_grad() call ever sees.
        for param in snapshot.parameters():
            param.requires_grad_(False)
        # Wrapped in a plain list so nn.Module does not register it as a
        # submodule. Its parameters therefore stay out of parameters() and
        # state_dict(), i.e. out of the optimizer and the NTK feature vector.
        self.init_model = [snapshot]

        self.model = model

    def forward(self, x):
        return self.model(x) - self.init_model[0](x)

In [None]:
class Scale(nn.Module):
    """Wrap a model and multiply everything it outputs by a constant ``alpha``."""

    def __init__(self, model, alpha):
        super().__init__()
        # Registering ``model`` as an attribute makes it a submodule, so its
        # parameters are exposed through this wrapper's parameters().
        self.model, self.alpha = model, alpha

    def forward(self, x):
        raw = self.model(x)
        return self.alpha * raw

In [None]:
def simple_net(width,
               bias=True,
               zero_output=True,
               alpha=1,
               hidden_layers=1,
               act='relu',
               **kwargs):
    """Build a fully connected network mapping 1-d inputs to 1-d outputs.

    The layer stack is::

        Linear(1, width), act,
        [Linear(width, width), act] * (hidden_layers - 1),
        Linear(width, 1)

    Parameters
    ----------
    width : int
        Number of units in every hidden layer.
    bias : bool
        Whether each ``nn.Linear`` layer has a bias term.
    zero_output : bool
        If True, wrap the stack in ``ZeroOutput`` so the output at construction
        time is zero.
    alpha : float
        Constant the output is multiplied by (via ``Scale``).
    hidden_layers : int
        Number of hidden layers, each followed by the activation. Must be >= 1.
    act : str
        Activation name; a key of ``act_dict`` ('relu' or 'tanh').
    **kwargs
        Accepted and ignored.

    Returns
    -------
    Scale
        The (optionally ``ZeroOutput``-wrapped) network scaled by ``alpha``.

    Raises
    ------
    ValueError
        If ``hidden_layers`` is smaller than 1. Previously such values silently
        produced a one-hidden-layer network.
    KeyError
        If ``act`` is not a key of ``act_dict``.
    """
    if hidden_layers < 1:
        raise ValueError(f"hidden_layers must be >= 1, got {hidden_layers}")
    activation = act_dict[act]

    layers = [nn.Linear(1, width, bias=bias), activation()]
    for _ in range(hidden_layers - 1):
        layers.extend([nn.Linear(width, width, bias=bias), activation()])
    layers.append(nn.Linear(width, 1, bias=bias))
    model = nn.Sequential(*layers)

    if zero_output:
        model = ZeroOutput(model)
    return Scale(model, alpha)

In [None]:
def ntk(model, x):
    """Calculate the empirical neural tangent kernel of ``model`` on inputs ``x``.

    The output ``model(x)`` is flattened in row-major order. Each scalar entry
    is differentiated with respect to all of ``model.parameters()``, and the
    flattened gradients form the rows of a feature matrix ``J``. For a model
    with a single output column, row ``i`` therefore corresponds to ``x[i]``.
    The kernel is ``J @ J.T``.

    Parameters
    ----------
    model : nn.Module
        Model whose parameters define the tangent features.
    x : torch.Tensor
        Batch of inputs accepted by ``model``.

    Returns
    -------
    features : torch.Tensor
        Shape ``(model(x).numel(), P)``, where ``P`` is the total number of
        parameter entries. Columns follow the order of ``model.parameters()``.
        Columns for parameters that do not require grad, or that the output
        does not depend on, are zero.
    tangent_kernel : torch.Tensor
        ``features @ features.T``.

    Notes
    -----
    Gradients come from ``torch.autograd.grad``, so the ``.grad`` attributes of
    the model's parameters are left untouched.
    """
    params = list(model.parameters())
    trainable = [q for q in params if q.requires_grad]
    n_params = sum(q.numel() for q in params)
    # Flatten so multi-output models get one row per scalar output. For an
    # (n, 1) output this is exactly one row per input.
    out = model(x).reshape(-1)

    rows = []
    for out_k in out:
        if trainable and out_k.requires_grad:
            # retain_graph: the same forward graph is reused for every entry.
            # allow_unused: parameters the output ignores come back as None.
            grads = iter(torch.autograd.grad(out_k, trainable,
                                             retain_graph=True,
                                             allow_unused=True))
        else:
            grads = iter([None] * len(trainable))
        pieces = []
        for q in params:
            g = next(grads) if q.requires_grad else None
            # A missing gradient (frozen or unused parameter) is zero.
            pieces.append((torch.zeros_like(q) if g is None else g).reshape(-1))
        rows.append(torch.cat(pieces))

    features = torch.stack(rows) if rows else torch.zeros(0, n_params)
    tangent_kernel = features @ features.t()
    return features, tangent_kernel

In [None]:
def gd(model, xdata, ydata,
       iters=100,
       lr=1e-3,
       alpha=1,
       eps=1e-10,
       progress_bar=False):
    """Train ``model`` by full-batch gradient descent on an alpha-normalized MSE.

    Each step minimizes ``MSELoss(model(xdata), ydata) / alpha**2`` with SGD.

    Parameters
    ----------
    model : nn.Module
        Model to train; its parameters are updated in place.
    xdata, ydata : torch.Tensor
        Inputs and targets; ``model(xdata)`` must be comparable with ``ydata``
        under ``nn.MSELoss``.
    iters : int
        Maximum number of gradient steps.
    lr : float
        SGD learning rate.
    alpha : float
        Output scale; the loss is divided by ``alpha**2``.
    eps : float
        Stop, without stepping, once the recorded loss falls below this value.
    progress_bar : bool
        Accepted for call compatibility (the notebook passes it); currently
        ignored.

    Returns
    -------
    list of float
        The un-normalized MSE (``loss * alpha**2``) recorded before each step,
        so curves for different ``alpha`` are comparable.
    """
    criterion = nn.MSELoss()
    opt = optim.SGD(model.parameters(), lr=lr)
    scale = 1 / (alpha ** 2)
    losses = []

    for _ in range(iters):
        loss = scale * criterion(model(xdata), ydata)
        mse = loss.item() * (alpha ** 2)
        losses.append(mse)
        if mse < eps:
            break
        opt.zero_grad()
        loss.backward()
        opt.step()
    return losses

In [None]:
def linear_gd(A, b, x0,
              iters=100,
              lr=1e-3,
              alpha=1,
              eps=1e-10):
    """Run gradient descent on the linear model ``A @ (x - x0)``, starting at ``x0``.

    The minimized objective is ``MSELoss(A @ (x - x0), b) / alpha**2``.

    Parameters
    ----------
    A : torch.Tensor
        2-D matrix of shape ``(m, p)``.
    b : torch.Tensor
        Targets with ``m`` entries; predictions are reshaped to ``b``'s shape.
    x0 : torch.Tensor
        Starting point with ``p`` entries (e.g. shape ``(p,)`` or ``(p, 1)``).
        It is detached, so no gradient flows into the caller's tensor.
    iters : int
        Maximum number of gradient steps.
    lr : float
        SGD learning rate.
    alpha : float
        Output scale; the loss is divided by ``alpha**2``.
    eps : float
        Stop, without stepping, once the recorded loss falls below this value.

    Returns
    -------
    list of float
        The un-normalized MSE (``loss * alpha**2``) recorded before each step.

    Raises
    ------
    ValueError
        If ``A`` is not 2-D.
    """
    if A.dim() != 2:
        raise ValueError(f"A must be 2-D, got shape {tuple(A.shape)}")
    origin = x0.detach()
    x = nn.Parameter(origin.clone())
    opt = optim.SGD([x], lr=lr)
    criterion = nn.MSELoss()
    scale = 1 / (alpha ** 2)
    losses = []

    for _ in range(iters):
        out = A @ (x - origin)
        # Match predictions to the layout of b. This covers both (p,) and
        # (p, 1) starting points, and avoids the silent (m,) vs (m, 1)
        # broadcasting that squeeze() could produce.
        loss = scale * criterion(out.reshape(b.shape), b)
        mse = loss.item() * (alpha ** 2)
        losses.append(mse)
        if mse < eps:
            break
        opt.zero_grad()
        loss.backward()
        opt.step()
    return losses

In [None]:
# Per-width results, keyed by network width and filled in by the training loop.
xs, budges, losses = {}, {}, {}

# Training hyper-parameters.
eps = 1e-10         # the loop stops for a width once the loss drops below this
iters = 1000        # maximum number of outer iterations per width
steps_per_iter = 1  # gradient steps taken per outer iteration
lr = 1e-3           # SGD learning rate

In [None]:
# Two training points, each stored as a column vector of shape (2, 1).
xin = torch.tensor([[-3.0], [0.5]])
yin = torch.tensor([[2.0], [-1.0]])

In [None]:
# For each width, train a 2-hidden-layer ReLU net on (xin, yin). After every
# outer iteration, record the relative distance of all weights from
# initialization, ||w(n) - w(0)|| / ||w(0)||, along with the training losses.
for m in [10, 100, 1000]:
    f = simple_net(width=m, bias=True, alpha=1, zero_output=False, hidden_layers=2)
    # Empirical NTK features/kernel at initialization, kept for inspection.
    A0, tk0 = ntk(f, xin)
    # list(f.modules()) is [Scale, Sequential, Linear, ReLU, Linear, ReLU, Linear];
    # index 4 is the hidden width x width layer.
    weights0 = list(f.modules())[4].weight.detach().numpy().copy()
    allw0 = nn.utils.parameters_to_vector(f.parameters()).detach().numpy().copy()
    norm0 = norm(allw0)

    xvals = [0]
    budgevals = [0]
    lossvals = []
    for i in range(iters):
        ls = gd(f, xin, yin, alpha=1, iters=steps_per_iter, lr=lr)
        lossvals.extend(ls)
        allw = nn.utils.parameters_to_vector(f.parameters()).detach().numpy()
        budgevals.append(norm(allw - allw0) / norm0)
        xvals.append((i + 1) * steps_per_iter)

        if ls and ls[-1] < eps:
            break

    xs[m] = xvals
    budges[m] = budgevals
    losses[m] = lossvals

In [None]:
# Relative distance of the weights from their initial values, per width.
# An explicit figure() keeps this plot independent of any previously active
# figure, matching the training-loss cell.
figure()
title("Relative distance of weights from initialization")
for m in sorted(xs):
    plot(xs[m], budges[m], label=f"Width {m}")
xlabel("Step (n)")
ylabel(r"$\frac{\Vert w(n) -  w(0) \Vert}{\Vert w(0) \Vert}$")
legend()
show()

In [None]:
# Training loss per gradient step, one curve per network width.
figure()
title("Training loss")
for m in sorted(xs):
    step_idx = arange(len(losses[m]))
    plot(step_idx, losses[m], label=f"Width {m}")
xlabel("Step")
ylabel("Loss")
legend()
show()