In [None]:
import jax.numpy as np
import tensorflow as tf
import neural_tangents as nt

from matplotlib import pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [None]:
from jax import random
from neural_tangents import stax

# Problem sizes: N input samples, P = width of the hidden Dense layer.
N = 600
P = 1000

# Width-to-sample ratio used by the density plot further down.
gamma = P / N

# Two-layer linear network built with stax.  Only `kernel_fn` is bound here:
# a later cell defines its own `apply_fn` for a hand-written model, so binding
# the stax `init_fn`/`apply_fn` would only create names that get shadowed.
_, _, kernel_fn = stax.serial(
    stax.Dense(P),
    stax.Dense(1),
)

# Independent keys for the train (x1) and test (x2) inputs, each of shape (N, 1).
key1, key2 = random.split(random.PRNGKey(1))
x1 = random.normal(key1, (N, 1))
x2 = random.normal(key2, (N, 1))

In [None]:
# Parameters of the hand-written linear model, scaled by 1/sqrt(P).  W is drawn
# from its own seed: key1/key2 were already consumed drawing x1/x2, and a JAX
# key must not be reused for a second draw.
W_SEED = 2
W = random.normal(key=random.PRNGKey(W_SEED), shape=(P, 1)) / np.sqrt(P)


def apply_fn(W, x, rng=None):
    """Linear model f(x)_n = x_n * sum_j W_j.

    `np.outer` flattens both arguments, so for x of shape (N, 1) and W of
    shape (P, 1) it builds an (N, P) matrix and the sum over the last axis
    returns shape (N,).

    Args:
      W: parameter array (any shape; it is flattened).
      x: input array (any shape; it is flattened).
      rng: unused.  Presumably kept to match neural_tangents' apply_fn
        calling convention.

    Returns:
      1-D array with one output per flattened input element.
      NOTE(review): nt empirical kernels usually expect a trailing output
      axis (e.g. (N, 1)).  Confirm the (N,) output is handled as intended.
    """
    return np.outer(x, W).sum(-1)

In [None]:
# Empirical NTK of the hand-written `apply_fn` at parameters W, on the training
# inputs.  It is bound to its own name instead of `kernel_fn` so the stax
# `kernel_fn` from the setup cell is not shadowed.  The `gp_inference` cell
# below calls `kernel_fn` without passing any parameters.
empirical_kernel_fn = nt.empirical_kernel_fn(apply_fn)
K = empirical_kernel_fn(x1, x1, W, 'ntk')

In [None]:
# Eigendecomposition of the symmetric kernel matrix K.  eigh returns the
# eigenvalues in ascending order and the eigenvectors as the *columns* of
# eigvecs.  The histogram leaves out the last (largest) eigenvalue, presumably
# so the remaining bulk is not squashed by that outlier.
eigvals, eigvecs = np.linalg.eigh(K)
hist = plt.hist(
    eigvals[:-1],
    bins=100,
)

In [None]:
# Eigenvector of the 7th-largest eigenvalue.  eigh stores eigenvectors as the
# columns of eigvecs, so select column -7.  Row -7 would mix one component
# from every eigenvector.
plt.plot(eigvecs[:, -7])

In [None]:
# Top eigenvector (largest eigenvalue, since eigh sorts ascending).  It is a
# column of eigvecs, not a row.
v = eigvecs[:, -1]

In [None]:
# Rank-one projector v v^T onto the top eigenvector, written as a broadcast
# product (identical to np.outer for a 1-D v).
plt.imshow(v[:, None] * v[None, :])

In [None]:
# Display the projector v v^T numerically (broadcast form of the outer product).
v[:, None] * v[None, :]

In [None]:
# Components of v against their index.
plt.scatter(x=np.arange(v.shape[0]), y=v)

In [None]:
# Average over all entries of the empirical NTK matrix.
np.mean(K)

In [None]:
# Heat map of the full empirical NTK matrix.
plt.imshow(X=K)

In [None]:
# Support edges of the reference spectral density.  With lambda = 1/gamma**2,
# a = (1 - sqrt(lambda))**2 and b = (1 + sqrt(lambda))**2, and H below has the
# Marchenko-Pastur form (unit variance, ratio lambda).
# NOTE(review): confirm 1/gamma**2 (rather than 1/gamma) is the intended ratio.
a = ((gamma - 1) / gamma) ** 2
b = ((gamma + 1) / gamma) ** 2


def H(x):
    """Density gamma**2 / (2*pi*x) * sqrt((b - x)(x - a)) on (a, b), else 0.

    Outside the open interval (a, b) the radicand can be negative and x can
    be 0, which would produce NaN.  Multiplying by a boolean mask does not help,
    because NaN * False is still NaN.  The radicand is therefore clipped at 0,
    x is replaced by 1 off-support before dividing, and the result is selected
    with np.where, so off-support points evaluate to exactly 0.
    """
    inside = (x > a) & (x < b)
    safe_x = np.where(inside, x, 1.0)
    radicand = np.maximum((b - x) * (x - a), 0.0)
    density = gamma ** 2 / (2 * np.pi * safe_x) * np.sqrt(radicand)
    return np.where(inside, density, 0.0)


x = np.arange(a, b, .00001)
plt.plot(x, H(x))

In [None]:
# Toy 1-D regression solved by GP inference with `kernel_fn`.
# (`nt` is already imported in the first cell.)
#
# Every kernel in this notebook comes from a linear model of 1-D inputs, so it
# has the form c1 * x x' + c0 (rank <= 2).  The N x N train kernel is
# therefore singular, and DIAG_REG adds a small ridge to make the solve
# well-posed.  NOTE(review): 1e-4 is a small default, not a tuned value.
NOISE_SEED = 3
DIAG_REG = 1e-4

x_train, x_test = x1, x2

# x1 has a single feature column, so the target can only depend on x[:, 0].
# NOTE(review): x + x*cos(x) is a 1-D stand-in for the intended target; confirm.
x_feat = x_train[:, 0]
y_clean = (x_feat + x_feat * np.cos(x_feat)).reshape(-1, 1)
# Uniform noise from its own seed (key1 already drew x1), shaped like the targets.
y_train = y_clean + random.uniform(random.PRNGKey(NOISE_SEED), shape=y_clean.shape)

# Test predictions of an infinite Bayesian network (NNGP posterior mean);
# expected shape (N, 1), matching y_train.
y_test_nngp = nt.predict.gp_inference(kernel_fn, x_train, y_train, x_test,
                                      get='nngp', diag_reg=DIAG_REG)

# Test predictions of an infinitely wide network trained by continuous gradient
# descent to convergence (t = inf); expected shape (N, 1).
y_test_ntk = nt.predict.gp_inference(kernel_fn, x_train, y_train, x_test,
                                     get='ntk', diag_reg=DIAG_REG)

In [None]:
# NTK predictions on the test inputs.
plt.scatter(x=x_test[:, 0], y=y_test_ntk)

In [None]:
# Noisy training targets, for comparison with the predictions above.
plt.scatter(x=x_train[:, 0], y=y_train)