In [None]:
import tensorflow as tf
import neural_tangents as nt
from neural_tangents import stax
import jax.numpy as np
from jax import random

from matplotlib import pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [None]:
def find_h(N, L, d, n=1, bias=True):
    '''
        Width of a fully connected network with (approximately) N parameters.

        For a network with:

        d input dimensionality,
        L layers (each of the same width h),
        N total parameters,
        n final outputs,

        this finds the corresponding width h, i.e. the positive solution of

            N = h*(d+b) + (L-1)*h*(h+b) + n*(h+b)

        rounded to the nearest integer, where b = 1 if every layer carries a
        bias vector (bias=True, the default) and b = 0 otherwise.
    '''
    # Modified from https://github.com/mariogeiger/nn_jamming/blob/master/constN.py
    assert L >= 1

    # Number of extra parameters per unit contributed by the bias (0 or 1).
    b = 1 if bias else 0

    if L == 1:
        # solve : N = h*(d+b) + n*(h+b)
        h = (N - b*n) / (d + n + b)
    else:
        # solve : N = h*(d+b) + (L-1)*h*(h+b) + n*(h+b)
        # i.e.  (L-1)*h**2 + (d + n + b*L)*h - (N - b*n) = 0, keep the positive root
        lin = d + n + b*L
        h = -(lin - (lin**2 + 4*(L-1)*(N - b*n))**.5)/(2*(L-1))
    return round(h)

def find_N(h, L, d, n=1, bias=True):
    '''
        Total number of parameters of a fully connected network with

        h units in each of its L layers,
        d input dimensionality,
        n final outputs.

        With bias=True (the default) every unit also has a bias parameter;
        with bias=False only the weight matrices are counted.
    '''
    # Number of extra parameters per unit contributed by the bias (0 or 1).
    b = 1 if bias else 0
    return h*(d+b) + (L-1)*h*(h+b) + n*(h+b)

Prepare data

In [None]:
# Load data from https://www.openml.org/d/554
from sklearn.datasets import fetch_openml
# as_frame=False asks for plain arrays: the label code further down calls
# .reshape(...) on a slice of y_raw, which a pandas Series does not provide.
X_raw, y_raw = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)

In [None]:
P = 1000  # number of training examples
P_total = int(1.25*P)  # examples kept in total (split into train and test below)

X = X_raw[:P_total]

# Binary targets from digit parity: odd digits map to +1, even digits to -1,
# stored as a column vector.
digit_parity = y_raw.astype(int) % 2
y = (2*digit_parity - 1)[:P_total].reshape(-1, 1)

In [None]:
from sklearn.model_selection import train_test_split
# Pass absolute sample counts rather than a float fraction, so the training
# set has exactly P rows regardless of floating-point rounding in 1 - P/P_total.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=P, test_size=P_total - P, random_state=42)
len(X_train)

In [None]:
from sklearn.decomposition import PCA
n_components = 20
# Seed the decomposition: depending on the data shape scikit-learn may select
# a randomized SVD solver, whose components otherwise change from run to run.
pca = PCA(n_components=n_components, random_state=42)
pca = pca.fit(X_train)

In [None]:
# Replace both splits by their PCA coordinates (fitted on the training split).
X_train, X_test = pca.transform(X_train), pca.transform(X_test)

# project to hyper-sphere of radius sqrt(n_components):
# every row is rescaled to have Euclidean norm equal to `radius`.
radius = np.sqrt(n_components)
X_train = radius * X_train / np.linalg.norm(X_train, axis=1, keepdims=True)
X_test = radius * X_test / np.linalg.norm(X_test, axis=1, keepdims=True)

Prepare network

In [None]:
# Architecture settings: the inputs are the PCA features, the network has
# L hidden layers and a budget of about N parameters.
d = n_components
L = 2
N = 20000

# Hidden width matching that parameter budget.
h = find_h(N=N, L=L, d=d)
print(h)

In [None]:
# L repetitions of (Dense(h) -> Erf) followed by a single-output Dense layer.
hidden_layers = [stax.Dense(h), stax.Erf()] * L
init_fn, apply_fn, kernel_fn_inf = stax.serial(*hidden_layers, stax.Dense(1))

In [None]:
# Fixed PRNG key; the width sweep further down reuses it for every initialisation.
kernelkey = random.PRNGKey(42)
# init_fn returns a pair; only its second element (the parameters) is kept.
_, init_params = init_fn(kernelkey, X_train.shape)

In [None]:
# Kernel function derived from apply_fn; unlike kernel_fn_inf it is called
# with an explicit set of parameters.
kernel_fn = nt.empirical_kernel_fn(apply_fn)

In [None]:
# NTK Gram matrix of the training inputs, evaluated at the initial parameters.
gram_ntk = kernel_fn(X_train, X_train, init_params, get='ntk')

In [None]:
# Eigenvalues of the [:, :, 0, 0] slice of the Gram tensor.
# NOTE(review): this indexing assumes gram_ntk is 4-D with two trailing
# output axes — confirm against the installed neural_tangents version.
eigs = np.linalg.eigvalsh(gram_ntk[:,:, 0, 0])

In [None]:
# Linear-scale histogram of the spectrum without the last four entries of
# `eigs` (presumably the largest eigenvalues, since eigvalsh sorts ascending —
# verify), with a dotted marker at zero.
plt.figure()
hist = plt.hist(eigs[:-4], bins=100)
plt.axvline(x=0, color='k', linestyle=':')
plt.ylim(0, 100)

# Histogram of the log-eigenvalues, this time over the full spectrum.
plt.figure()
loghist = plt.hist(np.log(eigs), bins=100)

# Do linearized neural networks exhibit jamming?

First, let's look at the results in the limit where the network width goes to infinity.

In [None]:
def force(f, y):
    """Return 1/2 - f*y, which is positive when the product f*y is below 1/2."""
    return 1/2 - f*y


def loss(fx, y_hat):
    """Mean over all entries of half the squared positive part of the force.

    `fx` are the network outputs and `y_hat` the targets; entries whose
    force is not positive contribute zero.
    """
    positive_part = np.maximum(0, force(fx, y_hat))
    return np.mean(1/2 * positive_part**2)

In [None]:
# Finite-width alternative (kernel evaluated at init_params), kept for reference:
# g_dd = kernel_fn(X_train, X_train, init_params, get='ntk')
# g_td = kernel_fn(X_test, X_train, init_params, get='ntk')

# NTK from kernel_fn_inf, which takes no parameters:
# g_dd is the train/train block, g_td the test/train block.
g_dd = kernel_fn_inf(X_train, X_train, get='ntk')
g_td = kernel_fn_inf(X_test, X_train, get='ntk')

In [None]:
# Predictor built from the two kernel blocks, the training targets and `loss`;
# it is called below as predict_fn(time, fx_train_initial, fx_test_initial).
predict_fn = nt.predict.gradient_descent(g_dd, y_train, loss, g_td)

In [None]:
# Time at which the predictor is evaluated (the width sweep reuses this value).
train_time = 5e4

# Outputs of the network at initialisation on both splits.
fx_train_initial, fx_test_initial = (
    apply_fn(init_params, X_train),
    apply_fn(init_params, X_test),
)

# Outputs after training for `train_time`.
fx_train_final, fx_test_final = predict_fn(train_time, fx_train_initial, fx_test_initial)

train_loss_inf = loss(fx_train_final, y_train)
test_loss_inf = loss(fx_test_final, y_test)

# Displayed as the cell output.
(train_loss_inf, test_loss_inf)

In [None]:
# Final training loss obtained with the infinite-width kernel.
print(train_loss_inf)

In [None]:
# Final test loss obtained with the infinite-width kernel.
print(test_loss_inf)

In [None]:
# Training targets (x-axis) against the final outputs on the training set (y-axis).
plt.scatter(y_train, fx_train_final)

In [None]:
# Test targets (x-axis) against the final outputs on the test set (y-axis).
plt.scatter(y_test, fx_test_final)

## Training loop

In [None]:
from tqdm import notebook as tqdm

In [None]:
# Per-width results, filled in the order the widths are visited (h, h-1, ..., 1).
train_losses = []
test_losses = []
train_forces = []
test_forces = []

# Parameter count corresponding to each visited width.
Ns = []

for hi in tqdm.trange(h, 0, -1):
    Ni = find_N(hi, L, d)
    Ns.append(Ni)

    # Same architecture as before, with hidden width hi.
    # Note: this rebinds init_fn, apply_fn, kernel_fn_inf, kernel_fn, g_dd,
    # g_td and predict_fn defined in earlier cells.
    init_fn, apply_fn, kernel_fn_inf = stax.serial(
        *[stax.Dense(hi), stax.Erf()]*L,
        stax.Dense(1)
    )
    # The same PRNG key is used for every width.
    _, init_params = init_fn(kernelkey, X_train.shape)

    kernel_fn = nt.empirical_kernel_fn(apply_fn)

    # Kernel blocks evaluated at the initial parameters of this network.
    g_dd = kernel_fn(X_train, X_train, init_params, get='ntk')
    g_td = kernel_fn(X_test, X_train, init_params, get='ntk')

    predict_fn = nt.predict.gradient_descent(g_dd, y_train, loss, g_td)

    fx_train_initial = apply_fn(init_params, X_train)
    fx_test_initial = apply_fn(init_params, X_test)

    fx_train_final, fx_test_final = predict_fn(
        train_time,
        fx_train_initial, fx_test_initial
    )

    train_loss = loss(fx_train_final, y_train)
    test_loss = loss(fx_test_final, y_test)

    # Per-example forces, flattened to 1-D.
    train_force = np.ravel(force(fx_train_final, y_train))
    test_force = np.ravel(force(fx_test_final, y_test))

    train_losses.append(train_loss)
    test_losses.append(test_loss)
    train_forces.append(train_force)
    # Store this width's test forces (the list itself used to be appended here).
    test_forces.append(test_force)

Save data

In [None]:
# import pickle
# data = {
#     'train_losses': train_losses,
#     'test_losses': test_losses,
#     'train_forces': train_forces,
#     'test_forces': test_forces,
#     'Ns': Ns,
#     'train_loss_inf': train_loss_inf,
#     'test_loss_inf': test_loss_inf
# }
# pickle.dump(data, open('data.pkl', 'wb'))

Plot

In [None]:
# Parameters per training example, shared by both curves.
n_over_p = np.array(Ns)/P

plt.figure(figsize=(10,7))
# Finite-width curves (solid) with the infinite-width losses as dotted lines.
plt.plot(n_over_p, train_losses, label='train')
plt.axhline(train_loss_inf, label='train, infinite network', color='blue', ls=':')
plt.plot(n_over_p, test_losses, label='test')
plt.axhline(test_loss_inf, label='test, infinite network', color='orange', ls=':')
plt.ylim(-.01, None)
plt.legend()
plt.xlabel(r'$N/P$')
plt.ylabel(r'$\mathcal{L}$')
plt.title('Training and testing loss as a function of $N/P$ for an NTK-linearized network')

In [None]:
# Natural log of the parameters-per-example ratio, shared by both curves.
log_n_over_p = np.log(np.array(Ns)/P)

plt.figure(figsize=(10,7))
# Finite-width curves (solid) with the infinite-width losses as dotted lines.
plt.plot(log_n_over_p, train_losses, label='train')
plt.axhline(train_loss_inf, label='train, infinite network', color='blue', ls=':')
plt.plot(log_n_over_p, test_losses, label='test')
plt.axhline(test_loss_inf, label='test, infinite network', color='orange', ls=':')
plt.ylim(-.01, None)
plt.legend()
plt.xlabel(r'log $N/P$')
plt.ylabel(r'$\mathcal{L}$')
plt.title('Training and testing loss as a function of $N/P$ for an NTK-linearized network')

In [None]:
# For every width, count the training examples whose force exceeds -EPS
# (with EPS = 0: strictly positive force).
EPS = 0
# One vectorised reduction per width instead of the builtin sum(), which
# walks the array element by element.
N_dels = [int(np.sum(forces > -EPS)) for forces in train_forces]

In [None]:
# Fraction N_Delta/N against the training loss, one point per width.
plt.scatter(np.array(train_losses), np.array(N_dels)/np.array(Ns))
# Raw strings: the TeX backslashes are not valid Python escape sequences.
plt.xlabel(r'$\mathcal{L}$')
plt.ylabel(r'$N_\Delta/N$')
# plt.xlim(0,.08)
# plt.ylim(0,1)


In [None]:
# Log of N_Delta/N against log of the training loss, one point per width.
plt.scatter(np.log(np.array(train_losses)), np.log(np.array(N_dels)/np.array(Ns)))
# Raw strings: the TeX backslashes are not valid Python escape sequences.
plt.xlabel(r'$log \mathcal{L}$')
plt.ylabel(r'$\log \ N_\Delta/N$')
# plt.xlim(0,.08)
# plt.ylim(0,1)


In [None]:
# Log of N_Delta/N against the (linear) training loss, one point per width.
plt.scatter(train_losses, np.log(np.array(N_dels)/np.array(Ns)))
# Raw strings: the TeX backslashes are not valid Python escape sequences.
plt.xlabel(r'$\mathcal{L}$')
plt.ylabel(r'$log  N_\Delta/N$')
# plt.xlim(0,.08)
# plt.ylim(0,1)


In [None]:
# N_Delta/N against the log of the training loss, one point per width.
plt.scatter(np.log(np.array(train_losses)), np.array(N_dels)/np.array(Ns))
# Raw strings: the TeX backslashes are not valid Python escape sequences.
plt.xlabel(r'$log \mathcal{L}$')
plt.ylabel(r'$N_\Delta/N$')
# plt.xlim(0,.08)
# plt.ylim(0,1)


In [None]:
# N_Delta/N against P/N, leaving out the last five sweep entries
# (the sweep runs from width h down to 1, so these are the five narrowest networks).
plt.scatter((P/np.array(Ns))[:-5], (np.array(N_dels)/np.array(Ns))[:-5])
# Raw strings: the TeX backslashes are not valid Python escape sequences.
plt.xlabel(r'$P/N$')
plt.ylabel(r'$N_\Delta/N$')