In [None]:
import tensorflow as tf
import neural_tangents as nt
from neural_tangents import stax
import jax.numpy as np
from jax import random

from matplotlib import pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [None]:
def find_h(N, L, d, n=1, bias=True):
    '''
        Find the hidden width h of a fully-connected network with a given
        parameter budget.

        For a network with:

        d input dimensionality,
        L layers,
        N total parameters,
        n final outputs,

        this finds the corresponding width h (rounded to the nearest integer).

        bias selects the parameter count being inverted: with bias=True every
        unit carries one bias parameter; with bias=False only weights are
        counted.
    '''
    # Modified from https://github.com/mariogeiger/nn_jamming/blob/master/constN.py
    assert L >= 1

    # Extra parameters per unit: 1 with biases, 0 without.
    b = 1 if bias else 0

    # Parameter count as a polynomial in h:
    #   N = h*(d+b) + (L-1)*h*(h+b) + n*(h+b)
    #     = (L-1)*h**2 + (d + n + b*L)*h + b*n
    lin = d + n + b*L
    rest = N - b*n

    if L == 1:
        # The quadratic term vanishes: solve lin*h = rest.
        h = rest / lin
    else:
        # Positive root of (L-1)*h**2 + lin*h - rest = 0.
        h = -(lin - (lin**2 + 4*(L-1)*rest)**.5)/(2*(L-1))
    return round(h)

def find_N(h, L, d, n=1, bias=True):
    '''
        Total number of parameters of a fully-connected network with

        h hidden width,
        L layers,
        d input dimensionality,
        n final outputs.

        With bias=True (default) every unit carries one bias parameter;
        with bias=False only weights are counted.
    '''
    # Extra parameters per unit: 1 with biases, 0 without.
    b = 1 if bias else 0
    return h*(d+b) + (L-1)*h*(h+b) + n*(h+b)

In [None]:
# Experiment configuration.
P = 1000  # number of points in the training set (and in the test set)
d = 50  # input dimensionality
L = 2  # number of hidden (Dense + Erf) blocks

N = 20000  # target total number of parameters

# Hidden width whose parameter count is closest to the target N.
h = find_h(N, L, d)

In [None]:
# Parameters per training point, and the hidden width chosen above.
N/P, h

In [None]:
# L hidden blocks of (Dense(h) -> Erf), followed by a single-output readout.
hidden_blocks = [stax.Dense(h), stax.Erf()] * L
init_fn, apply_fn, kernel_fn_inf = stax.serial(*hidden_blocks, stax.Dense(1))

In [None]:
# Separate PRNG keys for the training inputs, the test inputs and the
# network initialisation.
trainkey, testkey, kernelkey = random.split(random.PRNGKey(1), num=3)

# Standard-normal points, each row rescaled onto the hyper-sphere of
# radius sqrt(d).
sphere_radius = np.sqrt(d)

x_train = random.normal(trainkey, (P, d))
x_train = sphere_radius * x_train / np.linalg.norm(x_train, axis=1, keepdims=True)

x_test = random.normal(testkey, (P, d))
x_test = sphere_radius * x_test / np.linalg.norm(x_test, axis=1, keepdims=True)

In [None]:
# Draw the initial parameters with kernelkey; the first value returned by
# init_fn is discarded.
_, init_params = init_fn(kernelkey, x_train.shape)

In [None]:
# Count the entries of every parameter array in every layer and compare the
# total with the target budget N.
N_actual = sum(param.size for layer in init_params for param in layer)
N, N_actual

In [None]:
# Finite-width (empirical) kernel function built from the network's apply_fn.
kernel_fn = nt.empirical_kernel_fn(apply_fn)

In [None]:
# Empirical NTK of the training inputs with themselves, at the initial parameters.
gram_ntk = kernel_fn(x_train, x_train, init_params, get='ntk')

In [None]:
# Eigenvalues of the matrix obtained by fixing the two trailing axes at index 0
# (presumably the output-dimension axes of the kernel - confirm against the
# neural_tangents version in use).
eigs = np.linalg.eigvalsh(gram_ntk[:,:, 0, 0])

In [None]:
# Spectrum on a linear scale. eigvalsh returns eigenvalues in ascending
# order, so the slice leaves out the four largest ones.
fig_lin, ax_lin = plt.subplots()
hist = ax_lin.hist(eigs[:-4], 100)
ax_lin.axvline(0, color='k', linestyle=':')
ax_lin.set_ylim(0, 100)

# Full spectrum on a logarithmic scale.
fig_log, ax_log = plt.subplots()
loghist = ax_log.hist(np.log(eigs), 100)

# Do linearized neural networks exhibit jamming?

In [None]:
def force(f, y):
    """Return 1/2 - f*y: positive when the output f falls short of margin 1/2 for label y."""
    return 1/2 - f*y


def loss(fx, y_hat):
    """Mean over all entries of half the squared positive part of force(fx, y_hat)."""
    positive_part = np.maximum(0, force(fx, y_hat))
    return np.mean(1/2 * positive_part**2)

In [None]:
# Random +/-1 labels. A key derived with fold_in is used instead of trainkey
# itself: trainkey already generated x_train, and reusing a JAX PRNG key makes
# the two draws statistically dependent.
labelkey = random.fold_in(trainkey, 1)
y_train = random.bernoulli(labelkey, p=.5, shape=(P,1))*2 - 1

In [None]:
# Empirical NTK of the training inputs with themselves (same call as the one
# that produced gram_ntk); used as the kernel for the training dynamics.
g_dd = kernel_fn(x_train, x_train, init_params, get='ntk')

In [None]:
# Predictor for gradient-descent training under the custom `loss`,
# built from the training kernel and the labels.
predict_fn = nt.predict.gradient_descent(g_dd, y_train, loss)
# Alternative using the MSE-specific predictor:
# predict_fn = nt.predict.gradient_descent_mse(g_dd, y_train)

In [None]:
# Training time passed to the predictor.
train_time = 5e4

# Network outputs on the training set before and after training.
fx_train_initial = apply_fn(init_params, x_train)
fx_train_final = predict_fn(train_time, fx_train_initial)

# Training loss at the end of training.
loss_val = loss(fx_train_final, y_train)
loss_val

In [None]:
# Final network outputs against the +/-1 training labels.
plt.scatter(y_train, fx_train_final)

In [None]:
# Per-pattern force 1/2 - f*y after training, flattened to one dimension.
forces = np.ravel(force(fx_train_final, y_train))

# Distribution of the forces over the training set.
hist = plt.hist(forces, 50)

In [None]:
# Number of training patterns with positive force. np.sum reduces the array
# in a single call; the builtin sum would iterate over it element by element.
N_del = np.sum(forces > 0)
N_del

In [None]:
# Split the forces by sign: positive values are called overlaps,
# non-positive values are called gaps.
overlaps = forces[forces > 0]
gaps = forces[forces <= 0]

# Number of patterns in each group.
len(overlaps), len(gaps)

## Training loop

In [None]:
from tqdm import notebook as tqdm

In [None]:
# Per-width results, appended in the order the widths are visited
# (from h down to 1).
losses = []
N_dels = []
overlaps = []
gaps = []
Ns = []

for hi in tqdm.trange(h, 0, -1):
    # Parameter count of the width-hi network.
    Ni = find_N(hi, L, d)
    Ns.append(Ni)

    # Same architecture as above, with hidden width hi.
    init_fn, apply_fn, kernel_fn_inf = stax.serial(
        *[stax.Dense(hi), stax.Erf()]*L,
        stax.Dense(1)
    )
    # The same key is used for every width.
    _, init_params = init_fn(kernelkey, x_train.shape)

    # Empirical NTK on the training set at initialisation.
    kernel_fn = nt.empirical_kernel_fn(apply_fn)
    g_dd = kernel_fn(x_train, x_train, init_params, get='ntk')

    # Gradient-descent predictor under the custom loss.
    predict_fn = nt.predict.gradient_descent(g_dd, y_train, loss)
    # Alternative using the MSE-specific predictor:
    # predict_fn = nt.predict.gradient_descent_mse(g_dd, y_train)

    # Training-set outputs before and after training for train_time.
    fx_train_initial = apply_fn(init_params, x_train)
    fx_train_final = predict_fn(
        train_time,
        fx_train_initial
    )

    loss_val = loss(fx_train_final, y_train)
    forces = np.ravel(force(fx_train_final, y_train))

    # np.sum reduces the array in one call; the builtin sum would iterate
    # over it element by element on every pass of the loop.
    N_del = np.sum(forces > 0)
    overlap = forces[forces > 0]
    gap = forces[forces <= 0]

    losses.append(loss_val)
    N_dels.append(N_del)
    overlaps.append(overlap)
    gaps.append(gap)

In [None]:
# Force histogram for the current `forces`; when the cells are run in order
# this is the last network of the sweep (width 1).
plt.hist(forces)

In [None]:
# Turn the per-width Python lists into arrays for element-wise arithmetic.
losses, N_dels, Ns = (np.array(values) for values in (losses, N_dels, Ns))

# overlaps and gaps have a different length per width, so they stay lists of arrays.
overlaps = list(map(np.array, overlaps))
gaps = list(map(np.array, gaps))

In [None]:
# Number of positive-force patterns per width, in sweep order (widest network first).
N_dels

In [None]:
# Parameter count per width, in sweep order (widest network first).
Ns

In [None]:
# Positive-force patterns per parameter against P/N. The last ten sweep
# entries (the ten smallest widths) are left out.
plt.scatter((P/Ns)[:-10], (N_dels/Ns)[:-10])
plt.xlabel('$P/N$')
# Raw string: "\D" is not a valid escape sequence in a normal string literal.
plt.ylabel(r'$N_\Delta/N$')

In [None]:
# Positive-force patterns per parameter against the final training loss.
plt.scatter((losses), (N_dels/Ns))
# Raw strings: "\m" and "\D" are not valid escape sequences in normal string literals.
plt.xlabel(r'$\mathcal{L}$')
plt.ylabel(r'$N_\Delta/N$')
plt.xlim(0,.06)
plt.ylim(0,1)

In [None]:
# Final training loss against P/N, with a reference line at P/N = 1.
plt.scatter(P/Ns, losses)
plt.xlabel('P/N')
# Raw string: "\m" is not a valid escape sequence in a normal string literal.
plt.ylabel(r'$\mathcal{L}$')
plt.axvline(1, color='k', linestyle =':', label = 'P/N = 1')
plt.legend()

In [None]:
# Number of overlaps (positive forces) per parameter against P/N.
plt.scatter(P/Ns, np.array([len(o) for o in overlaps])/Ns)
plt.xlabel('$P/N$')
# Raw string: "\D" is not a valid escape sequence in a normal string literal.
plt.ylabel(r'$\Delta^+/N$')

In [None]:
# Number of gaps (non-positive forces) per parameter against P/N.
plt.scatter(P/Ns, np.array([len(g) for g in gaps])/Ns)
plt.xlabel('$P/N$')
# Raw string: "\D" is not a valid escape sequence in a normal string literal.
plt.ylabel(r'$\Delta^-/N$')

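## Hessian? (exploratory: derivative of the network outputs with respect to the parameters, and a second-order Taylor expansion of the network)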

In [None]:
from jax import jacfwd

In [None]:
# Forward-mode Jacobian of apply_fn with respect to its first argument
# (the parameters), evaluated on the training inputs. Despite the name H this
# is a first derivative. NOTE(review): when the cells are run in order,
# apply_fn and init_params are the ones left by the last sweep iteration.
H = jacfwd(apply_fn)(init_params, x_train)

In [None]:
# Shapes of the derivative blocks in H, grouped by layer.
[[inner.shape for inner in outer] for outer in H]

In [None]:
# Shapes of the parameter arrays, grouped by layer, for comparison with H.
[[inner.shape for inner in outer] for outer in init_params]

In [None]:
# Taylor expansion of apply_fn around init_params; the last argument (2) is
# presumably the expansion degree - confirm against the neural_tangents docs.
expansion = nt.taylor_expand(apply_fn, init_params, 2)

In [None]:
# Evaluate the expansion at the expansion point itself, on the training inputs.
expansion(init_params, x_train)