In [None]:
%load_ext line_profiler
%load_ext autoreload

import numpy as np

import tensorflow as tf
import neural_tangents as nt
from neural_tangents import stax

# Import the config object from the top-level package: the importable
# `jax.config` submodule no longer exists in recent JAX releases.
from jax import config
# Switch JAX to 64-bit precision before any arrays are created.
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import random, jit

from matplotlib import pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

Prepare data

In [None]:
# Load data from https://www.openml.org/d/554
from sklearn.datasets import fetch_openml
# as_frame=False requests plain NumPy arrays. Without it, recent scikit-learn
# versions may return a pandas DataFrame/Series here, and the later
# `.reshape(-1,1)` on the labels is not available on a Series.
X_raw, y_raw = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)

In [None]:
# Number of training samples; the working subset is 1.5x larger so that the
# remainder can serve as the test split.
P = 2500 #train
P_total = int(1.5*P)

# Binary labels from digit parity: odd digits -> +1, even digits -> -1,
# stored as a column vector.
_digit_parity = y_raw.astype(int) % 2
X = X_raw[:P_total]
y = (2*_digit_parity - 1)[:P_total].reshape(-1,1)

In [None]:
from sklearn.model_selection import train_test_split

# Ask for exactly P training samples as an integer count. The previous
# fractional form `test_size=1-P/P_total` is slightly above 1/3 in floating
# point, and because the test count is rounded up this left only P-1
# training samples while P is still used for C = 1/(P*lamb) and the P/N ratios.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=P, random_state=42)
assert len(X_train) == P, "training split must contain exactly P samples"
len(X_train)

In [None]:
from sklearn.decomposition import PCA
n_components = 10
# Fit the projection on the training split only. The seed is fixed because the
# default solver policy may choose the randomized SVD for data of this size,
# which would otherwise make the components differ between runs.
pca = PCA(n_components = n_components, random_state=42)
pca = pca.fit(X_train)

In [None]:
def _to_sphere(points, radius):
    """Rescale every row of `points` so that its Euclidean norm equals `radius`."""
    return radius * points / np.linalg.norm(points, axis = 1, keepdims=True)

# Reduce both splits with the fitted PCA, then
# project to hyper-sphere of radius sqrt(n_components)
X_train = _to_sphere(pca.transform(X_train), np.sqrt(n_components))
X_test = _to_sphere(pca.transform(X_test), np.sqrt(n_components))

Prepare network

In [None]:
# Input dimension: one coordinate per retained PCA component.
d = n_components
# Largest hidden-layer width considered: twice the number of training samples.
N = P * 2

In [None]:
# Random-feature map: a dense layer of width N followed by an erf nonlinearity.
_feature_layers = (stax.Dense(N), stax.Erf())
init_fn, apply_fn, kernel_fn_inf = stax.serial(*_feature_layers)

In [None]:
# Fixed PRNG key so the randomly initialised weights are the same on every run.
initkey = random.PRNGKey(123)
# init_fn returns a pair; only the second element (the parameters) is kept.
_unused_shape, init_params = init_fn(initkey, X_train.shape)

In [None]:
apply_fn = jit(apply_fn)

def _featurize(inputs):
    """Apply the network with the initial parameters; return one flat feature row per sample."""
    return apply_fn(init_params, inputs).reshape(len(inputs), -1)

train_features = _featurize(X_train)
test_features = _featurize(X_test)

In [None]:
"""
    Minimizing squared hinge loss with small regularization on the weights. 
    This gives us an L2-regularized L2-loss SVM:
        https://www.csie.ntu.edu.tw/~cjlin/papers/liblinear.pdf#equation.A.2
"""
import liblinear
import liblinearutil
import os

# Regularisation strength and the equivalent liblinear cost parameter.
lamb = 1e-8
C = 1/(P*lamb)

#Primal (-s 2) is faster in our case
_train_options = f'-s 2 -n {os.cpu_count()} -c {C}'
model = liblinearutil.train(
    y_train.reshape(-1),
    np.array(train_features),
    _train_options,
)

In [None]:
# Evaluate the fitted model on the training split ('-q' silences liblinear output).
p_label, p_acc, p_val = liblinearutil.predict(
    y_train.reshape(-1),
    np.array(train_features),
    model,
    '-q',
)

In [None]:
# Distribution of the decision values returned for the training split.
_train_decision_values = np.ravel(p_val)
hist = plt.hist(_train_decision_values, bins = 100)
plt.title(f'Train: {p_acc[0]:.0f}% accuracy')

In [None]:
# Evaluate the fitted model on the held-out split ('-q' silences liblinear output).
p_label, p_acc, p_val = liblinearutil.predict(
    y_test.reshape(-1),
    np.array(test_features),
    model,
    '-q',
)

In [None]:
# Distribution of the decision values returned for the test split.
_test_decision_values = np.ravel(p_val)
hist = plt.hist(_test_decision_values, bins = 100)
plt.title(f'Test: {p_acc[0]:.0f}% accuracy')

## Train loop

In [None]:
from tqdm import notebook as tqdm

In [None]:
results = []

# Hidden-layer widths: squares of sqrt(N), sqrt(N)-1, ...
# (Quadratic, as if we were changing h in the NTK case.)
# The last step of the sweep is below 1 and would truncate to a width of 0,
# i.e. a model with no features at all (and a division by zero in the later
# N_del/N and P/N ratios), so widths below 1 are dropped.
widths = [int(sqrt_N**2) for sqrt_N in np.arange(np.sqrt(N), 0, -1)]
widths = [Ni for Ni in widths if Ni >= 1]

# The labels do not depend on the width or on lambda; flatten them once.
y_train_flat = y_train.reshape(-1)
y_test_flat = y_test.reshape(-1)

for Ni in tqdm.tqdm(widths):
    init_fn, apply_fn, kernel_fn_inf = stax.serial(
        stax.Dense(Ni), stax.Erf(),
    )

    apply_fn = jit(apply_fn)
    # Same key for every width, so the sweep is reproducible.
    _, init_params = init_fn(initkey, X_train.shape)

    # Features depend only on the width: compute and convert to NumPy once,
    # instead of once per lambda.
    train_features = np.array(apply_fn(init_params, X_train).reshape(len(X_train), -1))
    test_features = np.array(apply_fn(init_params, X_test).reshape(len(X_test), -1))

    for lamb in np.logspace(-3,-15, 5):
        # liblinear cost parameter corresponding to regularisation strength lamb.
        C = 1/(P*lamb)
        model = liblinearutil.train(y_train_flat, train_features, f'-s 2 -n {os.cpu_count()} -c {C}')

        train_p_label, train_p_acc, train_p_val = liblinearutil.predict(y_train_flat, train_features, model, '-q')
        test_p_label, test_p_acc, test_p_val = liblinearutil.predict(y_test_flat, test_features, model, '-q')

        # One record per (width, lambda): labels and decision values for both splits.
        results.append({
            'N': Ni,
            'P': P,
            'lambda': lamb,
            'y_train': y_train_flat,
            'y_train_hat': np.array(train_p_val).reshape(-1),
            'y_test': y_test_flat,
            'y_test_hat': np.array(test_p_val).reshape(-1)
        })

In [None]:
import pandas as pd

In [None]:
# result_df = pd.DataFrame(results)
# result_df.to_json(open('results/mnist_hinge_CK.json', 'w'))

In [None]:
# Load the stored sweep results. The `with` block closes the file handle,
# which the bare `open(...)` call previously left open.
with open('results/mnist_hinge_CK.json', 'r') as results_file:
    result_df = pd.read_json(results_file)

In [None]:
# Display the loaded results table (presumably one row per (N, lambda) run — verify against the JSON file).
result_df

In [None]:
def force(y, f):
    """Return 1 - y*f element-wise for labels `y` and decision values `f`.

    Inputs are converted with np.asarray so that plain Python lists (as stored
    in a column loaded from JSON) are accepted as well as NumPy arrays.
    """
    return 1 - np.asarray(y)*np.asarray(f)


def loss(y, f):
    """Return the mean of max(0, 1 - y*f)**2 (squared hinge loss)."""
    return np.mean(np.maximum(0, force(y,f))**2)


def N_del(y, f):
    """Return the number of samples with 1 - y*f >= 0."""
    return np.sum(force(y,f) >= 0)


def N_correct(y, f):
    """Return the number of samples with y*f > 0."""
    return np.sum(np.asarray(y)*np.asarray(f) > 0)


def N_incorrect(y, f):
    """Return the number of samples with y*f < 0."""
    return np.sum(np.asarray(y)*np.asarray(f) < 0)

In [None]:
# Per-run metrics: each entry maps a new column to (metric, label column,
# decision-value column). np.vectorize calls the metric once per row.
_row_metrics = (
    ('test_loss', loss, 'y_test', 'y_test_hat'),
    ('train_loss', loss, 'y_train', 'y_train_hat'),
    ('N_del', N_del, 'y_train', 'y_train_hat'),
)
for _column, _metric, _labels, _outputs in _row_metrics:
    result_df[_column] = np.vectorize(_metric)(result_df[_labels], result_df[_outputs])

# Ratios between the sample count P, the width N and the N_del count.
_ratio_columns = (
    ('N/P', 'N', 'P'),
    ('P/N', 'P', 'N'),
    ('N_del/P', 'N_del', 'P'),
    ('N_del/N', 'N_del', 'N'),
)
for _column, _numerator, _denominator in _ratio_columns:
    result_df[_column] = result_df[_numerator]/result_df[_denominator]


In [None]:
from matplotlib.colors import LinearSegmentedColormap

# Anchor colours of the map as 0-255 RGB triples, in order.
_mei_rgb = np.array([
    (243, 232, 29),
    (245, 173, 47),
    (140, 193, 53),
    (50,  191, 133),
    (23,  167, 198),
    (36,  123, 235),
    (53,  69,  252),
    (52,  27,  203)
])
cmap = LinearSegmentedColormap.from_list('Mei2019', _mei_rgb/255.)

# Preview strip: two identical rows sweeping the colormap from 0 to 1.
_ramp = np.linspace(0, 1, 256)
gradient = np.vstack((_ramp, _ramp))
fig = plt.figure(figsize=(6,.5))
img = plt.imshow(gradient, aspect='auto', cmap=cmap)
title = plt.title('Colormap stolen from Mei2019')

In [None]:
# Test and train loss against N/P, one line per regularisation strength.
fig, ax = plt.subplots(figsize=(10,7))
n_colors = len(result_df['lambda'].unique())
# Avoid dividing by zero when the results contain a single lambda value.
color_span = max(n_colors - 1, 1)

# All test curves are drawn before the train curves, so the legend lists the
# test entries first and then the train entries (dotted lines).
for y_col, line_kw in (('test_loss', {}), ('train_loss', {'ls': ':'})):
    for i, (lamb, df) in enumerate(result_df.groupby('lambda')):
        # Raw string: '\l' is not a valid escape sequence in a regular literal.
        label = r'$\log_{10}(\lambda)$=' + f'{np.log10(lamb):.1f}'
        df.plot('N/P', y_col, label=label, ax=ax, color=cmap(float(i)/color_span), **line_kw)

ax.legend(ncol=2, title='Test:                           Train:', title_fontsize=11)
ax.set_ylabel('Loss')
ax.grid(True, ls='--', alpha=.3)
ax.set_title('Hinge loss on MNIST with CK features')

fig.savefig('mnist_hinge_loss_ck_features_regularized.png', dpi=300)

In [None]:
# Test and train loss against N_del/P as scatter plots, one colour per lambda.
fig, ax = plt.subplots(figsize=(10,7))
n_colors = len(result_df['lambda'].unique())
# Avoid dividing by zero when the results contain a single lambda value.
color_span = max(n_colors - 1, 1)

# Test points first (default marker), then train points ('x' marker), so the
# legend lists the test entries before the train entries.
for y_col, marker_kw in (('test_loss', {}), ('train_loss', {'marker': 'x'})):
    for i, (lamb, df) in enumerate(result_df.groupby('lambda')):
        # Raw string: '\l' is not a valid escape sequence in a regular literal.
        label = r'$\log_{10}(\lambda)$=' + f'{np.log10(lamb):.1f}'
        df.plot('N_del/P', y_col, label=label, ax=ax, color=cmap(float(i)/color_span), kind='scatter', **marker_kw)

ax.legend(ncol=2, title='Test:                           Train:', title_fontsize=11)
ax.set_ylabel('Loss')
ax.grid(True, ls='--', alpha=.3)
ax.set_title('Hinge loss on MNIST with CK features')

# fig.savefig('mnist_hinge_loss_ck_features_regularized.png', dpi=300)

In [None]:
# Test and train loss against N_del/N (log x-axis), one colour per lambda.
fig, ax = plt.subplots(figsize=(10,7))
n_colors = len(result_df['lambda'].unique())
# Avoid dividing by zero when the results contain a single lambda value.
color_span = max(n_colors - 1, 1)

# Test points first (default marker), then train points ('x' marker), so the
# legend lists the test entries before the train entries.
for y_col, marker_kw in (('test_loss', {}), ('train_loss', {'marker': 'x'})):
    for i, (lamb, df) in enumerate(result_df.groupby('lambda')):
        # Raw string: '\l' is not a valid escape sequence in a regular literal.
        label = r'$\log_{10}(\lambda)$=' + f'{np.log10(lamb):.1f}'
        df.plot('N_del/N', y_col, label=label, ax=ax, color=cmap(float(i)/color_span), kind='scatter', **marker_kw)

ax.legend(ncol=2, title='Test:                           Train:', title_fontsize=11)
ax.set_ylabel('Loss')
ax.grid(True, ls='--', alpha=.3)
ax.set_title('Hinge loss on MNIST with CK features')
plt.xscale('log')
# fig.savefig('mnist_hinge_loss_ck_features_regularized.png', dpi=300)

In [None]:
# N_del/N (log y-axis) against the training loss, one colour per lambda.
fig, ax = plt.subplots(figsize=(10,7))
n_colors = len(result_df['lambda'].unique())
# Avoid dividing by zero when the results contain a single lambda value.
color_span = max(n_colors - 1, 1)

for i, (lamb, df) in enumerate(result_df.groupby('lambda')):
    # Raw string: '\l' is not a valid escape sequence in a regular literal.
    label = r'$\log_{10}(\lambda)$=' + f'{np.log10(lamb):.1f}'
    df.plot('train_loss', 'N_del/N', label=label, ax=ax, color=cmap(float(i)/color_span), marker='x', kind='scatter')

ax.legend()
ax.grid(True, ls='--', alpha=.3)
ax.set_title('Hinge loss on MNIST with CK features')
plt.yscale('log')
# fig.savefig('mnist_hinge_loss_ck_features_regularized.png', dpi=300)

In [None]:
# N_del/N against P/N on log-log axes, one colour per lambda.
fig, ax = plt.subplots(figsize=(10,7))
n_colors = len(result_df['lambda'].unique())
# Avoid dividing by zero when the results contain a single lambda value.
color_span = max(n_colors - 1, 1)

for i, (lamb, df) in enumerate(result_df.groupby('lambda')):
    # Raw string: '\l' is not a valid escape sequence in a regular literal.
    label = r'$\log_{10}(\lambda)$=' + f'{np.log10(lamb):.1f}'
    df.plot('P/N', 'N_del/N', label=label, ax=ax, color=cmap(float(i)/color_span), kind='scatter')

ax.legend()
ax.grid(True, ls='--', alpha=.3)
ax.set_title('Hinge loss on MNIST with CK features')
plt.xscale('log')
plt.yscale('log')
# fig.savefig('mnist_hinge_loss_ck_features_regularized.png', dpi=300)

In [None]:
# N_del/N against P/N on log-log axes, one colour per lambda.
fig, ax = plt.subplots(figsize=(10,7))
n_colors = len(result_df['lambda'].unique())
# Avoid dividing by zero when the results contain a single lambda value.
color_span = max(n_colors - 1, 1)

for i, (lamb, df) in enumerate(result_df.groupby('lambda')):
    # Raw string: '\l' is not a valid escape sequence in a regular literal.
    label = r'$\log_{10}(\lambda)$=' + f'{np.log10(lamb):.1f}'
    df.plot('P/N', 'N_del/N', label=label, ax=ax, color=cmap(float(i)/color_span), kind='scatter')
# for i, (lamb, df) in enumerate(result_df.groupby('lambda')):
#     df.plot('P/N', 'test_loss', label=r'$\log_{10}(\lambda)$='+f'{np.log10(lamb):.1f}', ax=ax, color=cmap(float(i)/color_span), kind='scatter')

ax.legend()
ax.grid(True, ls='--', alpha=.3)
ax.set_title('Hinge loss on MNIST with CK features')
plt.yscale('log')
plt.xscale('log')

# plt.xlim(.5,10)
# plt.ylim(.05,7)
# fig.savefig('mnist_hinge_loss_ck_features_regularized.png', dpi=300)

In [None]:
# Test and train loss against P/N (log x-axis), one colour per lambda.
fig, ax = plt.subplots(figsize=(10,7))
n_colors = len(result_df['lambda'].unique())
# Avoid dividing by zero when the results contain a single lambda value.
color_span = max(n_colors - 1, 1)

# Test points first (default marker), then train points ('x' marker), so the
# legend lists the test entries before the train entries.
for y_col, marker_kw in (('test_loss', {}), ('train_loss', {'marker': 'x'})):
    for i, (lamb, df) in enumerate(result_df.groupby('lambda')):
        # Raw string: '\l' is not a valid escape sequence in a regular literal.
        label = r'$\log_{10}(\lambda)$=' + f'{np.log10(lamb):.1f}'
        df.plot('P/N', y_col, label=label, ax=ax, color=cmap(float(i)/color_span), kind='scatter', **marker_kw)

ax.legend(ncol=2, title='Test:                           Train:', title_fontsize=11)
# ax.set_ylabel('Loss')
ax.grid(True, ls='--', alpha=.3)
ax.set_title('Hinge loss on MNIST with CK features')
plt.xscale('log')
# fig.savefig('mnist_hinge_loss_ck_features_regularized.png', dpi=300)

In [None]:
# N_del/P against N_del/N (log x-axis), one colour per lambda.
fig, ax = plt.subplots(figsize=(10,7))
n_colors = len(result_df['lambda'].unique())
# Avoid dividing by zero when the results contain a single lambda value.
color_span = max(n_colors - 1, 1)

# NOTE(review): both passes plot the same two columns (N_del is computed from
# the training split only), so the "Test"/"Train" legend columns differ only
# in marker. Looks like a copy-paste leftover — confirm what was intended.
for marker_kw in ({}, {'marker': 'x'}):
    for i, (lamb, df) in enumerate(result_df.groupby('lambda')):
        # Raw string: '\l' is not a valid escape sequence in a regular literal.
        label = r'$\log_{10}(\lambda)$=' + f'{np.log10(lamb):.1f}'
        df.plot('N_del/N', 'N_del/P', label=label, ax=ax, color=cmap(float(i)/color_span), kind='scatter', **marker_kw)

ax.legend(ncol=2, title='Test:                           Train:', title_fontsize=11)
# ax.set_ylabel('Loss')
ax.grid(True, ls='--', alpha=.3)
ax.set_title('Hinge loss on MNIST with CK features')
plt.xscale('log')
# fig.savefig('mnist_hinge_loss_ck_features_regularized.png', dpi=300)