### Three-Layer NTK

#### [NTK paper](https://papers.nips.cc/paper/2019/file/c4ef9c39b300931b69a36fb3dbb8d60e-Paper.pdf)

#### [REU overleaf doc](https://www.overleaf.com/project/606a0e1c8d7e5e3b95b62e9a)

#### [GitHub Repo for this project](https://github.com/genglinliu/NTK-study)

In [2]:
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

import warnings
warnings.filterwarnings('ignore')

In [3]:
"""
activation functions and their derivatives
"""

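def relu_k(x, k):
    # ReLU^k = max(x, 0)**k; rectifying first keeps negative x at 0 even when k is even
    return np.maximum(x, 0)**k

def d_relu_k(x, k):
    # k * max(x, 0)**(k-1) for x > 0, else 0 (assumes k >= 1); rectifying first
    # also avoids NaNs from fractional powers of negative numbers
    return k * np.maximum(x, 0)**(k-1) * (x > 0)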

def relu(x):
    return np.maximum(x, 0)

def d_relu(x):
    return 1.0 * (x > 0)

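# k is unused below; it keeps the signature compatible with relu_k / d_relu_k
def sin(x, k=None):
    return np.sin(x)

# cos also serves as the derivative of sin
def cos(x, k=None):
    return np.cos(x)

# derivative of cos
def d_cos(x, k=None):
    return -np.sin(x)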
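Derivative code is easy to get subtly wrong, so here is a sketch of a central finite-difference check for the activation/derivative pairs above. It re-declares ReLU^k in the rectify-then-power form `max(x, 0)**k` so it runs on its own; the helper `max_derivative_error`, the sample grid, and the step `h` are illustrative assumptions, not part of the notebook.

```python
import numpy as np

# Activation/derivative pairs, re-declared so this sketch is self-contained
def relu_k(x, k):
    return np.maximum(x, 0) ** k

def d_relu_k(x, k):
    return k * np.maximum(x, 0) ** (k - 1) * (x > 0)

def max_derivative_error(f, df, k, h=1e-6):
    """Largest gap between df and a central finite difference of f (hypothetical helper)."""
    # Shift the grid off x = 0, where ReLU^1 is not differentiable
    x = np.linspace(-3.0, 3.0, 601) + 1e-3
    numeric = (f(x + h, k) - f(x - h, k)) / (2 * h)
    return np.max(np.abs(numeric - df(x, k)))

pairs = {
    "relu^1": (relu_k, d_relu_k, 1),
    "relu^2": (relu_k, d_relu_k, 2),
    "relu^3": (relu_k, d_relu_k, 3),
    "sin": (lambda x, k: np.sin(x), lambda x, k: np.cos(x), None),
    "cos": (lambda x, k: np.cos(x), lambda x, k: -np.sin(x), None),
}

errors = {name: max_derivative_error(f, df, k) for name, (f, df, k) in pairs.items()}
for name, err in errors.items():
    print(f"{name}: max error {err:.2e}")
```

Every reported error should be far below 1e-5; a wrong derivative (for example one that is nonzero for negative inputs) shows up as an error of order 1.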

In [4]:
def init_inputs(num_inputs=200):
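    # num_inputs evenly spaced angles theta in [0, 2*pi] (a deterministic grid, not a
    # random sample); both endpoints are included, so the first and last points
    # coincide up to round-off
    theta = np.linspace(0.0, 2 * np.pi, num=num_inputs)
    # points (cos(theta), sin(theta)) on the unit circle
    x = np.asarray((np.cos(theta), np.sin(theta)))  # shape (2, num_inputs)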
    return x
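Since every input lies on the unit circle, the step-1 kernel between points at angles theta_i and theta_j reduces to <x_i, x_j> = cos(theta_i - theta_j), and every diagonal entry sigma_0(x, x) equals 1. A minimal standalone check of both facts, re-declaring `init_inputs` with the same grid:

```python
import numpy as np

def init_inputs(num_inputs=200):
    theta = np.linspace(0.0, 2 * np.pi, num=num_inputs)
    return np.asarray((np.cos(theta), np.sin(theta)))  # shape (2, num_inputs)

num_inputs = 200
theta = np.linspace(0.0, 2 * np.pi, num=num_inputs)
x = init_inputs(num_inputs)

gram = x.T @ x                                       # <x_i, x_j> for every pair
expected = np.cos(theta[:, None] - theta[None, :])   # cos(theta_i - theta_j)

print(x.shape)                                       # (2, 200)
print(np.allclose(np.linalg.norm(x, axis=0), 1.0))   # True: unit-norm points
print(np.allclose(gram, expected))                   # True
```

Because the diagonal is constant, the only thing that varies from pair to pair in the 2x2 matrix of step 2 is the off-diagonal entry cos(theta - theta').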

In [5]:
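# check kernel symmetry and positive semi-definiteness
def check(matrix, tol=1e-10):
    if not np.all(np.abs(matrix - matrix.T) < tol):
        print("warning: kernel is not symmetric")
    # eigvalsh assumes a symmetric matrix and returns real eigenvalues,
    # avoiding the spurious complex parts np.linalg.eigvals can produce
    if not np.all(np.linalg.eigvalsh(matrix) >= -tol):
        print("warning: kernel is not positive semi-definite")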
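For use inside assertions or loops, a variant that returns the two results instead of printing can be handier. `is_valid_kernel` is a hypothetical name; it applies the same symmetry test and tolerance and uses `np.linalg.eigvalsh`, which returns the real eigenvalues of a symmetric matrix in ascending order.

```python
import numpy as np

def is_valid_kernel(matrix, tol=1e-10):
    """Return (is_symmetric, is_psd) for a square kernel matrix (hypothetical helper)."""
    symmetric = bool(np.all(np.abs(matrix - matrix.T) < tol))
    # eigvalsh sorts eigenvalues ascending, so the first one is the smallest
    psd = bool(np.linalg.eigvalsh(matrix)[0] >= -tol)
    return symmetric, psd

theta = np.linspace(0.0, 2 * np.pi, num=50)
x = np.asarray((np.cos(theta), np.sin(theta)))
gram = x.T @ x                   # rank-2 Gram matrix: PSD, most eigenvalues are ~0

not_psd = np.array([[1.0, 2.0],
                    [2.0, 1.0]])  # eigenvalues 3 and -1

print(is_valid_kernel(gram))     # (True, True)
print(is_valid_kernel(not_psd))  # (True, False)
```

The Gram matrix passes even though most of its eigenvalues are zero up to round-off, which is why the test compares against `-tol` rather than 0.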

In [6]:
"""
See pages 3-4 of the NTK paper linked above.
These quantities can be computed analytically.

STEP 1
    K_0(x, x') = sigma_0(x, x') = <x, x'> = (x^T).dot(x')

STEP 2
    B_1 = [[sigma_0(x, x),  sigma_0(x, x')],
           [sigma_0(x', x), sigma_0(x', x')]]

STEP 3
    sigma_1(x, x')  = 2 E_{(u, v) ~ N(0, B_1)} [activation(u) activation(v)]
    sigma_1'(x, x') = 2 E_{(u, v) ~ N(0, B_1)} [activation'(u) activation'(v)]
    The factor 2 normalizes ReLU, since E[relu(z)^2] = 1/2 for z ~ N(0, 1);
    in general the factor is 1 / E[activation(z)^2].

STEP 4
    K_1(x, x') = sigma_1(x, x') + K_0(x, x') sigma_1'(x, x')

Experiments:
1. First hidden layer smooth (sin/cos), second hidden layer non-smooth (ReLU^k).
2. The same with the order of the activations reversed.
"""

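# Closed forms for the STEP 3 expectations: a minimal sketch, not part of the
# original notebook. The helper names relu_expectations and trig_expectations
# are hypothetical. For (u, v) ~ N(0, [[s_xx, s_xy], [s_xy, s_yy]]), standard
# Gaussian identities give
#   2 E[relu(u) relu(v)]   = sqrt(s_xx s_yy) / pi * (sin(phi) + (pi - phi) cos(phi))
#   2 E[relu'(u) relu'(v)] = (pi - phi) / pi,  phi = arccos(s_xy / sqrt(s_xx s_yy))
#   E[sin(u) sin(v)]       = exp(-(s_xx + s_yy) / 2) * sinh(s_xy)
#   E[cos(u) cos(v)]       = exp(-(s_xx + s_yy) / 2) * cosh(s_xy)
# (the ReLU lines are the degree-1 and degree-0 arc-cosine kernels). The trig
# helper returns plain expectations; apply the STEP 3 normalization separately.
# All arguments may be arrays, so whole kernel matrices can be processed at once.
import numpy as np

def relu_expectations(s_xx, s_yy, s_xy):
    norm = np.sqrt(s_xx * s_yy)
    # clip guards against |s_xy / norm| drifting slightly above 1 from round-off
    phi = np.arccos(np.clip(s_xy / norm, -1.0, 1.0))
    sigma = norm / np.pi * (np.sin(phi) + (np.pi - phi) * np.cos(phi))
    d_sigma = (np.pi - phi) / np.pi
    return sigma, d_sigma

def trig_expectations(s_xx, s_yy, s_xy):
    scale = np.exp(-(s_xx + s_yy) / 2)
    return scale * np.sinh(s_xy), scale * np.cosh(s_xy)

# Monte Carlo sanity check for two unit-circle inputs 1 radian apart,
# so B_1 = [[1, cos(1)], [cos(1), 1]]
rng = np.random.default_rng(0)
rho = np.cos(1.0)
u, v = rng.multivariate_normal(np.zeros(2), [[1.0, rho], [rho, 1.0]], size=200_000).T

relu_cf, d_relu_cf = relu_expectations(1.0, 1.0, rho)
sin_cf, cos_cf = trig_expectations(1.0, 1.0, rho)
print(relu_cf, 2 * np.mean(np.maximum(u, 0) * np.maximum(v, 0)))
print(d_relu_cf, 2 * np.mean((u > 0) & (v > 0)))
print(sin_cf, np.mean(np.sin(u) * np.sin(v)))
print(cos_cf, np.mean(np.cos(u) * np.cos(v)))
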
# step 1
# x' ranges over the same inputs as x (see init_inputs()), so each kernel here is
# the (num_inputs, num_inputs) Gram matrix over all pairs of points
x = init_inputs()
K_0 = np.dot(x.T, x)
sigma_0 = np.dot(x.T, x)

# step 2
