### Three-Layer NTK

#### [NTK paper](https://papers.nips.cc/paper/2019/file/c4ef9c39b300931b69a36fb3dbb8d60e-Paper.pdf)

#### [REU overleaf doc](https://www.overleaf.com/project/606a0e1c8d7e5e3b95b62e9a)

#### [GitHub Repo for this project](https://github.com/genglinliu/NTK-study)

In [1]:
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

import warnings
warnings.filterwarnings('ignore')

In [2]:
"""
activation functions and their derivatives
"""

def relu_k(x, k):
    """ReLU raised to the k-th power: ``max(x, 0) ** k``.

    The max is taken *before* the power so that every negative input maps to
    0 for any k > 0. Computing ``max(x**k, 0)`` instead returns x**k > 0 for
    negative x when k is even, and NaN when k is fractional.

    Parameters
    ----------
    x : array_like
        Pre-activation values.
    k : number
        Exponent applied to the rectified input.
    """
    return np.maximum(x, 0) ** k

def d_relu_k(x, k):
    """Derivative of ``relu_k`` with respect to x.

    Returns ``k * x**(k - 1)`` where x > 0 and 0 elsewhere (including x == 0).
    Non-positive entries are masked *before* the power is taken. A naive
    ``k * x**(k-1) * (x > 0)`` evaluates x**(k-1) at negative x, which is NaN
    for fractional k, and NaN * 0 is still NaN.

    Always returns a float ndarray (0-d for scalar input).
    """
    x = np.asarray(x)
    positive = x > 0
    # Substitute 1.0 at non-positive entries so the power below never sees a
    # negative base (NaN) or a zero base with a negative exponent (inf).
    safe_base = np.where(positive, x, 1.0)
    return np.where(positive, k * safe_base ** (k - 1), 0.0)

def relu(x):
    """Rectified linear unit: element-wise ``max(0, x)`` (NaN propagates)."""
    return np.maximum(0, x)

def d_relu(x):
    """Derivative of ReLU: 1.0 where x > 0, else 0.0 (the kink at 0 maps to 0)."""
    is_positive = x > 0
    return is_positive * 1.0

def sin(x, k=None):
    """Sine activation, applied element-wise.

    ``k`` is accepted and ignored. It exists so this function can be called
    as ``activation(u, k)`` exactly like ``relu_k``.
    """
    del k  # unused; present only for signature compatibility
    return np.sin(x)

def cos(x, k=None):
    """Cosine activation, applied element-wise.

    Also serves as the derivative of ``sin``. ``k`` is accepted and ignored so
    the call signature matches ``relu_k``.
    """
    del k  # unused; present only for signature compatibility
    return np.cos(x)

def d_cos(x, k=None):
    """Derivative of the cosine activation: ``-sin(x)``, element-wise.

    ``k`` is accepted and ignored so the call signature matches ``d_relu_k``.
    """
    del k  # unused; present only for signature compatibility
    return np.negative(np.sin(x))

In [3]:
def init_inputs(num_inputs=200):
    """Place ``num_inputs`` points on the unit circle.

    The angles are evenly spaced over [0, 2*pi] by ``np.linspace``; they are
    not randomly sampled. Both endpoints are included, so the first and last
    points coincide (up to floating-point rounding).

    Returns
    -------
    ndarray of shape (2, num_inputs)
        Row 0 holds cos(theta) and row 1 holds sin(theta); each column is one
        input point.
    """
    angles = np.linspace(0.0, 2 * np.pi, num=num_inputs)
    return np.stack((np.cos(angles), np.sin(angles)), axis=0)

In [4]:
# check kernel symmetry and positive definitiveness
def check(matrix, tol=1e-10):
    if not np.all(np.abs(matrix-matrix.T) < tol):
        print("warning: kernel is not symmetric")
    if not np.all(np.linalg.eigvals(matrix) >= -tol):
        print("warning: kernel is not positive semi-definite")

#### Entry-wise Computation

In [83]:
"""
see page 3-4 of the paper

STEP 0
    x and x' are 2-dimensional points of shape (2,)
    K_0(x, x') = sigma_0(x, x') = <x, x'> = (x^T).dot(x')

STEP 1a
    B_1 = [sigma_0(x, x),  sigma_0(x, x'),
           sigma_0(x, x'), sigma_0(x', x')]
           
STEP 1b
    sigma_1(x, x') = 2 E(u, v)~N(0, B_1) [activation(u) activation(v)]
    sigma_1'(x, x') = 2 E(u, v)~N(0, B_1) [activation'(u) activation'(v)]

STEP 1c
    K_1(x, x') = sigma_1(x, x') + K_0(x, x') sigma_1'(x, x')

"""
#------------
# step 0
#------------
# Example pair of inputs for the entry-wise demo (x' is an identical copy of x).
x = np.array([1, 2])
y = x.copy()

def K_0(x, y):
    """Layer-0 kernel: the plain inner product ``K_0(x, x') = x^T x'``.

    For 2-D inputs holding one point per column this is the full Gram matrix.
    """
    return x.T @ y

def sigma_0(x, y):
    """Layer-0 covariance ``sigma_0(x, x') = <x, x'> = (x^T).dot(x')``.

    Numerically identical to ``K_0``; kept separate to mirror the paper's notation.
    """
    x_transposed = x.T
    return x_transposed.dot(y)

#------------
# step 1a
#------------
def B_1(x, y):
    """Covariance matrix of (u, v) for step 1b.

    Assembles the block matrix
        [[sigma_0(x, x),  sigma_0(x, x')],
         [sigma_0(x', x), sigma_0(x', x')]]

    For single 1-D points every entry is a scalar and the result has shape
    (2, 2). If x and x' are 2-D arrays with one point per column (e.g. the
    (2, n) output of ``init_inputs``), each entry is an (n, n) Gram block and
    the result is (2n, 2n).

    The lower-left block uses ``sigma_0(y, x)`` (the transpose of
    ``sigma_0(x, y)``) so the matrix is symmetric even when x != x' in the
    batched case. For 1-D inputs the two are equal.
    """
    return np.block([
        [sigma_0(x, x), sigma_0(x, y)],
        [sigma_0(y, x), sigma_0(y, y)],
    ])

#------------
# step 1b
#------------
def sigma_1_and_prime(x, y, activation, k, num_samples=1000, rng=None):
    """Monte Carlo estimates of sigma_1(x, x') and sigma_1'(x, x') (step 1b).

    Draws (u, v) ~ N(0, B_1(x, y)) and returns
        sigma_1  = 2 * mean(act(u) * act(v))
        sigma_1' = 2 * mean(act'(u) * act'(v))
    using two independent batches of samples, one per expectation.

    Only the first two coordinates of each sample are used. This assumes x
    and y are single points, so that ``B_1`` is 2x2.

    Parameters
    ----------
    x, y : ndarray
        Input points, passed to ``B_1``.
    activation : str
        One of 'relu_k', 'sin', 'cos'.
    k : number
        Power used by relu_k; ignored by sin/cos.
    num_samples : int, optional
        Samples per expectation (default 1000).
    rng : numpy.random.Generator, optional
        Source of randomness. Defaults to the global ``np.random`` state.

    Returns
    -------
    (sigma_1, sigma_1_prime) : tuple of float

    Raises
    ------
    KeyError
        If ``activation`` is not a recognised name.
    """
    activation_map = {
        'relu_k': [relu_k, d_relu_k],
        'sin': [sin, cos],      # d/du sin(u) = cos(u)
        'cos': [cos, d_cos],    # d/du cos(u) = -sin(u)
    }
    # Fail before drawing any samples if the activation is unknown.
    try:
        activation_func, d_activation = activation_map[activation]
    except KeyError:
        raise KeyError(
            f"unknown activation {activation!r}; "
            f"expected one of {sorted(activation_map)}"
        ) from None

    cov = B_1(x, y)
    # The Gaussian lives in (u, v)-space, so its mean has one entry per row of
    # the covariance. Sizing it by x.shape[0] only worked because the example
    # inputs happen to be 2-D.
    mean = np.zeros(cov.shape[0])
    sampler = np.random if rng is None else rng

    sample_1 = sampler.multivariate_normal(mean, cov, num_samples)  # (num_samples, 2)
    u1, v1 = sample_1[:, 0], sample_1[:, 1]

    sample_2 = sampler.multivariate_normal(mean, cov, num_samples)  # (num_samples, 2)
    u2, v2 = sample_2[:, 0], sample_2[:, 1]

    sigma_1 = 2 * np.mean(activation_func(u1, k) * activation_func(v1, k))
    sigma_1_prime = 2 * np.mean(d_activation(u2, k) * d_activation(v2, k))

    return sigma_1, sigma_1_prime

#------------
# step 1c
#------------
def K_1(x, y, activation, k=1):
    """Layer-1 NTK entry (step 1c).

        K_1(x, x') = sigma_1(x, x') + K_0(x, x') * sigma_1'(x, x')

    The sigma terms come from ``sigma_1_and_prime``, which estimates them by
    Monte Carlo sampling, so the result varies between calls.
    """
    sig, sig_dot = sigma_1_and_prime(x, y, activation, k)
    base_kernel = K_0(x, y)
    return sig + base_kernel * sig_dot

In [84]:
K_1(x, y, activation='relu_k', k=1)

10.355288932403742

In [None]:
"""
------------------------------------------------------------------
STEP 2a
    B_2 = [sigma_1(x, x),  sigma_1(x, x'),
           sigma_1(x, x'), sigma_1(x', x')]
           
STEP 2b
    sigma_2(x, x') = 2 E(u, v)~N(0, B_2) [activation(u) activation(v)]
    sigma_2'(x, x') = 2 E(u, v)~N(0, B_2) [activation'(u) activation'(v)]

STEP 2c
    K_2(x, x') = sigma_2(x, x') + K_1(x, x') sigma_2'(x, x')
    
    
experiment:
1. first hidden layer: smooth (sin/cos), second: non-smooth (ReLU^k)
2. reverse the order of activation

"""