In [None]:
import pickle
import pandas as pd
import numpy as np

from matplotlib import pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [None]:
def find_h(N, L, d, n=1, bias=False):
    # Modified from https://github.com/mariogeiger/nn_jamming/blob/master/constN.py
    '''
    Width h of a fully-connected network with a given total parameter count.

    The network maps d inputs through L hidden layers of width h to n outputs,
    so its parameter count is

        bias=False:  N = h*d + (L-1)*h*h + n*h
        bias=True:   N = h*(d+1) + (L-1)*h*(h+1) + n*(h+1)

    Both are quadratics a*h**2 + b*h - c = 0 with a = L-1 >= 0; the positive
    root is returned, rounded to the nearest integer.

    Parameters
    ----------
    N : number, array or pandas Series
        Total number of parameters.
    L : number, array or pandas Series
        Number of hidden layers; must be >= 1.
    d : number, array or pandas Series
        Input dimensionality.
    n : int, default 1
        Number of final outputs.
    bias : bool, default False
        Whether N counts bias parameters.

    Returns
    -------
    The rounded width: a number for scalar inputs, an array/Series otherwise.
    '''
    assert np.all(np.asarray(L) >= 1)

    if bias:
        # N = (L-1)*h**2 + (d+L+n)*h + n
        a, b, c = L - 1, d + L + n, N - n
    else:
        # N = (L-1)*h**2 + (n+d)*h
        a, b, c = L - 1, n + d, N

    # Positive root written as 2c / (b + sqrt(b^2 + 4ac)) rather than
    # (-b + sqrt(b^2 + 4ac)) / (2a). The two are algebraically identical.
    # This form stays finite for L == 1 (a == 0, where the count is linear
    # and h = c / b) and avoids cancellation between -b and the square root.
    h = 2 * c / (b + (b**2 + 4 * a * c) ** .5)

    return round(h)

In [None]:
# Load the MNIST train/test split saved by the L=1 experiment.
# NOTE: pickle.load can execute arbitrary code -- only load files you produced.
with open('../L=1/mnist.pkl', 'rb') as fh:
    X_train, X_test, y_train, y_test = pickle.load(fh)

In [None]:
# Load the sweep results and derive per-row diagnostics from the stored predictions.
# NOTE: pickle.load can execute arbitrary code -- only load files you produced.
with open('results.pkl', 'rb') as fh:
    results = pickle.load(fh)
result_df = pd.DataFrame.from_dict(results)
# Network width recovered from the parameter count (bias-free count, 1 output).
result_df['h'] = find_h(result_df['N'], result_df['L'], result_df['d'])


def force(y, f):
    """Margin deficit 1 - y*f per sample (positive when y*f < 1)."""
    return 1 - y*f


def loss(y, f):
    """Squared hinge loss max(0, 1 - y*f)**2, averaged over the last axis."""
    return np.mean(np.maximum(0, force(y, f))**2, -1)


def N_del(y, f):
    """Count of samples with y*f <= 1 (force >= 0), over the last axis."""
    return np.sum(force(y, f) >= 0, -1)


result_df['test_loss'] = result_df.y_test_hat.apply(lambda f: loss(y_test, f))
result_df['train_loss'] = result_df.y_train_hat.apply(lambda f: loss(y_train, f))
result_df['N_del'] = result_df.y_train_hat.apply(lambda f: N_del(y_train, f))

# Ratios used as plot coordinates below.
result_df['N/P'] = result_df['N']/result_df['P']
result_df['P/N'] = result_df['P']/result_df['N']
result_df['N_del/P'] = result_df['N_del']/result_df['P']
result_df['N_del/N'] = result_df['N_del']/result_df['N']

result_df['N_del/h'] = result_df['N_del']/result_df['h']

In [None]:
from matplotlib import colors as mcolors
from matplotlib.colors import LinearSegmentedColormap

# Anchor colours (8-bit RGB) of the Mei2019 colour scheme, ordered from the low
# end of the scale to the high end. They are interpolated into a 256-level colormap.
mei2019_rgb = [
    (243, 232, 29),
    (245, 173, 47),
    (140, 193, 53),
    (50,  191, 133),
    (23,  167, 198),
    (36,  123, 235),
    (53,  69,  252),
    (52,  27,  203),
]
cmap = LinearSegmentedColormap.from_list('Mei2019', np.array(mei2019_rgb)/255., N=256)

# Preview: draw the colormap as a thin two-row gradient strip.
ramp = np.linspace(0, 1, 256)
gradient = np.vstack((ramp, ramp))
fig = plt.figure(figsize=(6, .5))
img = plt.imshow(gradient, aspect='auto', cmap=cmap)
title = plt.title('Colormap stolen from Mei2019')

# Log colour normalisation shared by the plotting cells. It autoscales on first
# use (or on explicit .autoscale calls), and those limits persist across cells.
norm = mcolors.LogNorm()

In [None]:
# Train loss against N_Delta/N for every row, coloured by training step.
# The dotted line marks N_Delta = N.
fig, ax = plt.subplots(figsize=(9, 6))
data = result_df

x = data['N_del'] / data['N']
y = data['train_loss']

# Invisible line: sets the axis limits, because matplotlib mis-autoscales log-scale scatters.
ax.plot(x, y, color='none')
points = ax.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7)
fig.colorbar(points, ax=ax, label='Training steps')

ax.set_xlabel(r'$N_\Delta/N$')
ax.set_ylabel(r'Train $\mathcal{L}$')
ax.set_yscale('log')
ax.set_xscale('log')
ax.axvline(1, color='k', ls=':')

In [None]:
# Train loss against N_Delta/h for every row, coloured by training step.
# The dotted line marks N_Delta = h.
fig, ax = plt.subplots(figsize=(9, 6))
data = result_df

x = data['N_del'] / data['h']
y = data['train_loss']

# Invisible line: sets the axis limits, because matplotlib mis-autoscales log-scale scatters.
ax.plot(x, y, color='none')
points = ax.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7)
fig.colorbar(points, ax=ax, label='Training steps')

ax.set_xlabel(r'$N_\Delta/h$')
ax.set_ylabel(r'Train $\mathcal{L}$')
ax.set_yscale('log')
ax.set_xscale('log')
ax.axvline(1, color='k', ls=':')

In [None]:
# Test loss against N_Delta/N, coloured by training step. Only every 10th row
# (after sorting by step) is drawn, to thin the scatter.
fig, ax = plt.subplots(figsize=(9, 6))
data = result_df.sort_values('step')[::10]

x = data['N_del'] / data['N']
y = data['test_loss']

# Invisible line: sets the axis limits, because matplotlib mis-autoscales log-scale scatters.
ax.plot(x, y, color='none')
points = ax.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7)
fig.colorbar(points, ax=ax, label='Training steps')

ax.set_xlabel(r'$N_\Delta/N$')
ax.set_ylabel(r'Test $\mathcal{L}$')
ax.set_yscale('log')
ax.set_xscale('log')
ax.axvline(1, color='k', ls=':')

In [None]:
# Test loss against N_Delta/h, coloured by training step. Rows are drawn in
# increasing step order, so later steps land on top.
fig, ax = plt.subplots(figsize=(9, 6))
data = result_df.sort_values('step')

x = data['N_del'] / data['h']
y = data['test_loss']

# Invisible line: sets the axis limits, because matplotlib mis-autoscales log-scale scatters.
ax.plot(x, y, color='none')
points = ax.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7)
fig.colorbar(points, ax=ax, label='Training steps')

ax.set_xlabel(r'$N_\Delta/h$')
ax.set_ylabel(r'Test $\mathcal{L}$')
ax.set_yscale('log')
ax.set_xscale('log')
ax.axvline(1, color='k', ls=':')

In [None]:
# N_Delta/h against P/h for every row, coloured by training step.
# The dotted line marks N_Delta = h.
fig, ax = plt.subplots(figsize=(9, 6))
data = result_df

x = data['P'] / data['h']
y = data['N_del'] / data['h']

# Invisible line: sets the axis limits, because matplotlib mis-autoscales log-scale scatters.
ax.plot(x, y, color='none')
points = ax.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7)
fig.colorbar(points, ax=ax, label='Training steps')

ax.set_xlabel('P/h')
ax.set_ylabel(r'$N_\Delta/h$')
ax.set_yscale('log')
ax.set_xscale('log')
ax.axhline(1, color='k', ls=':')

In [None]:
# N_Delta/N against P/N for every row, coloured by training step.
# The dotted line marks N_Delta = N.
fig, ax = plt.subplots(figsize=(9, 6))
data = result_df

x = data['P'] / data['N']
y = data['N_del'] / data['N']

# Invisible line: sets the axis limits, because matplotlib mis-autoscales log-scale scatters.
ax.plot(x, y, color='none')
points = ax.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7)
fig.colorbar(points, ax=ax, label='Training steps')

ax.set_xlabel('P/N')
ax.set_ylabel(r'$N_\Delta/N$')
ax.set_yscale('log')
ax.set_xscale('log')
ax.axhline(1, color='k', ls=':')

In [None]:
# Late-training rows only (step > 1e5): N_Delta/N against P/N.
# The y axis is symlog (linear below 1), so zero values can still be drawn.
plt.figure(figsize=(9,6))
data = result_df.query('step > 1e5')

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['P']/data['N']
y = data['N_del']/data['N']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7)
plt.colorbar(label='Training steps')

plt.xlabel('P/N')
plt.ylabel(r'$N_\Delta/N$')

# Matplotlib >= 3.3 takes `linthresh`; the old `linthreshy` keyword was removed in 3.5.
plt.yscale('symlog', linthresh=1.0)
plt.xscale('log')
plt.axhline(1, color='k',ls=':')
plt.xlim(1, 500)
plt.ylim(0, 200)

In [None]:
# Early rows only (step < 5): N_Delta/h against P/h.
# The y axis is symlog (linear below 1), so zero values can still be drawn.
plt.figure(figsize=(9,6))
data = result_df.query('step < 5')

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['P']/data['h']
y = data['N_del']/data['h']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7)
plt.colorbar(label='Training steps')

plt.xlabel('P/h')
plt.ylabel(r'$N_\Delta/h$')

# Matplotlib >= 3.3 takes `linthresh`; the old `linthreshy` keyword was removed in 3.5.
plt.yscale('symlog', linthresh=1.0)
plt.xscale('log')
plt.axhline(1, color='k',ls=':')
plt.xlim(1, 500)
plt.ylim(0, 200)

In [None]:
# Overlay of two groups on the same axes:
#   early rows (step < 5), normalised by the width h;
#   late rows (step > 1e5), normalised by the parameter count N.
# The legend and axis labels state which normalisation each group uses.
plt.figure(figsize=(9,6))

data = result_df.query('step < 5')
#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['P']/data['h']
y = data['N_del']/data['h']
plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7,
            label=r'step < 5 (normalised by $h$)')
plt.colorbar(label='Training steps')

data = result_df.query('step > 1e5')
#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['P']/data['N']
y = data['N_del']/data['N']
plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7,
            label=r'step > $10^5$ (normalised by $N$)')

plt.xlabel(r'$P/h$ or $P/N$')
plt.ylabel(r'$N_\Delta/h$ or $N_\Delta/N$')
plt.legend()

# Matplotlib >= 3.3 takes `linthresh`; the old `linthreshy` keyword was removed in 3.5.
plt.yscale('symlog', linthresh=1.0)
plt.xscale('log')
plt.axhline(1, color='k',ls=':')
plt.xlim(1, 500)
plt.ylim(0, 200)

In [None]:
# Train loss for two groups on the same axes:
#   early rows (step < 5), plotted against P/h;
#   late rows (step > 1e5), plotted against P/N.
plt.figure(figsize=(9,6))

data = result_df.query('step < 5')
#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['P']/data['h']
y = data['train_loss']
plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7,
            label=r'step < 5 (x = $P/h$)')
plt.colorbar(label='Training steps')

data = result_df.query('step > 1e5')
#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['P']/data['N']
y = data['train_loss']
plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7,
            label=r'step > $10^5$ (x = $P/N$)')

plt.xlabel(r'$P/h$ or $P/N$')
plt.ylabel(r'Train $\mathcal{L}$')
plt.legend()

# Matplotlib >= 3.3 takes `linthresh`; the old `linthreshy` keyword was removed in 3.5.
plt.yscale('symlog', linthresh=1.0)
plt.xscale('log')
# NOTE(review): the y=1 guide and ylim(0, 200) look copied from the N_Delta plot;
# confirm they suit train-loss values.
plt.axhline(1, color='k',ls=':')
plt.xlim(1, 500)
plt.ylim(0, 200)

In [None]:
# For each training step, take the row with the largest test loss. Its parameter
# count N is recorded as N_star for that step and broadcast back to every row.
# idxmax skips NaN test losses. The previous sort + iloc[-1] ordered NaN last,
# so it picked a NaN row whenever one existed.
peak_idx = result_df.groupby('step')['test_loss'].idxmax()
# Index by step so that .map below looks rows up by step value.
extremizing_rows = result_df.loc[peak_idx].set_index('step', drop=False)
N_star = extremizing_rows['N']
result_df['N_star'] = result_df.step.map(N_star)

In [None]:
# N_Delta/h of the largest-test-loss row at each training step (index = step).
plt.plot(extremizing_rows.index, extremizing_rows['N_del/h'])
plt.xscale('log')
plt.xlabel('Training steps')
plt.ylabel(r'$N_\Delta/h$ at $N^*$')
plt.axhline(1, ls=':', c='k')

In [None]:
# Peak location N* against dataset size P, both in units of h**2, one point per step.
# Uses the 'N' column rather than 'N_star'. extremizing_rows is built before result_df
# gains its 'N_star' column, so 'N_star' would raise KeyError on a fresh kernel.
# For these rows N_star is, by definition, their own N.
h_sq = extremizing_rows['h']**2
plt.scatter(extremizing_rows['N']/h_sq, extremizing_rows['P']/h_sq)
plt.xlabel(r'$N^*/h^2$')
plt.ylabel(r'$P/h^2$')
plt.axhline(1, ls=':', c='k')

In [None]:
# N* N_Delta / N against N/N*, coloured by training step.
# x was previously N*/N while labelled N/N*. It now matches the label.
# The y label previously read "Train L" but y is N* N_Delta / N.
plt.figure(figsize=(9,6))
data = result_df

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['N'] / data['N_star']
y = data['N_star'] * data['N_del'] / data['N']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7)
plt.colorbar(label='Training steps')

plt.xlabel(r'$N/N^*$')
plt.ylabel(r'$N^* N_\Delta/N$')

plt.yscale('log')
plt.xscale('log')
plt.axvline(1, color='k',ls=':')

In [None]:
# N_Delta/h against N/N*, coloured by training step; dotted lines at N = N* and N_Delta = h.
# x was previously N*/N while labelled N/N*. It now matches the label.
# The y label previously read "Train L" but y is N_Delta/h.
plt.figure(figsize=(9,6))
data = result_df

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['N'] / data['N_star']
y = data['N_del'] / data['h']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7)
plt.colorbar(label='Training steps')

plt.xlabel(r'$N/N^*$')
plt.ylabel(r'$N_\Delta/h$')

plt.yscale('log')
plt.xscale('log')
plt.axvline(1, color='k',ls=':')
plt.axhline(1, color='k',ls=':')

In [None]:
# Test loss against N/N*, coloured by training step. Rows are drawn from the
# latest step to the earliest, so early steps land on top.
# x was previously N*/N while labelled N/N*. It now matches the label.
plt.figure(figsize=(9,6))
data = result_df.sort_values('step')[::-1]

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['N']/data['N_star']
y = data['test_loss']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7)
plt.colorbar(label='Training steps')

plt.xlabel(r'$N/N^*$')
plt.ylabel(r'Test $\mathcal{L}$')

plt.yscale('log')
plt.xscale('log')
plt.axvline(1, color='k',ls=':')

In [None]:
# Location N* of the largest test loss, against training step.
#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
plt.plot(result_df['step'], result_df['N_star'], color='none')

plt.scatter(result_df['step'], result_df['N_star'])
plt.yscale('log')
plt.xscale('log')
plt.xlabel('Training steps')
plt.ylabel(r'$N^*$')

# Hessian

In [None]:
# Rows at the earliest recorded step ("untrained" / random features) and at the
# latest recorded step ("trained").
first_step, last_step = result_df['step'].min(), result_df['step'].max()
untrained = result_df[result_df['step'] == first_step]
trained = result_df[result_df['step'] == last_step]

In [None]:
# N_Delta (training samples with y*f <= 1) against width h, at the earliest and
# latest recorded steps.
plt.scatter(untrained['h'], untrained['N_del'], label='first step')
plt.scatter(trained['h'], trained['N_del'], label='last step')
plt.yscale('log')
plt.xscale('log')
plt.xlabel('h')
plt.ylabel(r'$N_\Delta$')
plt.legend()

In [None]:
# Train loss against N_Delta/N at the first and last steps. For each step, a black
# marker shows the row with the smallest N_Delta/h among rows with
# train loss > 5e-2 and N_Delta/h >= 1.
x = 'N_del/N'
y = 'train_loss'

plt.scatter(untrained[x], untrained[y], label='first step')
plt.scatter(trained[x], trained[y], label='last step')

for step in sorted(result_df['step'].unique()):
    candidates = result_df.query("(step == @step) & (train_loss > 5e-2) & (N_del/h >= 1)")
    if candidates.empty:
        # No row at this step meets the criterion; .iloc[0] would raise IndexError.
        continue
    row = candidates.sort_values('N_del/h').iloc[0]
    plt.scatter(row[x], row[y], c='k')

plt.yscale('log')
plt.xscale('log')
plt.xlabel(r'$N_\Delta/N$')
plt.ylabel(r'Train $\mathcal{L}$')
plt.legend()

In [None]:
# Hessian spectrum at each training step. For every step, the row used is the one
# with the smallest N_Delta/h among rows with train loss > 5e-2 and N_Delta/h >= 1.
from matplotlib import cm

steps = sorted(result_df['step'].unique())
# Rescale the shared `norm` to the observed steps. This persists for later cells using `norm`.
norm.autoscale(steps)
sm = cm.ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])  # older matplotlib needs an array before a bare ScalarMappable can feed a colorbar

for step in steps:
    candidates = result_df.query("(step == @step) & (train_loss > 5e-2) & (N_del/h >= 1)")
    if candidates.empty:
        continue  # no row at this step meets the criterion
    row = candidates.sort_values('N_del/h').iloc[0]

    # Density histogram of log(sqrt(eigenvalue)) with Sturges binning.
    # NOTE(review): assumes eigs0 is strictly positive. Zeros or negatives would make
    # np.log yield -inf/NaN, and np.histogram would then fail.
    vals = np.sqrt(row.eigs0)
    hist, edges = np.histogram(np.log(vals), 'sturges', density=True)
    # Draw the histogram as a line: sample it at bin centres, add a zero-valued point
    # one bin width beyond each outer edge, then map x back out of log space.
    dx = np.mean(np.diff(edges))
    edges = np.concatenate((edges[[0]]-dx, edges[1:]/2 + edges[:-1]/2 , edges[[-1]]+dx))
    edges = np.exp(edges)
    hist = np.concatenate(([0], hist, [0]))

    plt.plot(edges, hist, c=sm.to_rgba(np.clip(step, norm.vmin, norm.vmax)), alpha=1.)
# sm is not attached to any Axes, so tell colorbar which Axes to take space from.
plt.colorbar(sm, ax=plt.gca(), label='Training Steps')

# Matplotlib >= 3.3 takes `linthresh`; the old `linthreshx` keyword was removed in 3.5.
plt.xscale('symlog', linthresh=1e-1)
plt.xlim(0, None)
plt.ylim(0, None)
plt.xlabel(r'$\sqrt{\lambda}$')
# The histogram density is per unit log(sqrt(lambda)), not per unit sqrt(lambda).
plt.ylabel(r'$P\left(\log\sqrt{\lambda}\right)$')
plt.title("Hessian Spectrum as a Function of Training Steps")

In [None]:
# Random-features (earliest-step) Hessian spectra, coloured by train loss. For each
# loss threshold, plot the row with the smallest N_Delta/h whose train loss reaches
# that threshold.
from matplotlib import cm

df = untrained.query('(N_del/h >= 1)')
# 50 log-spaced thresholds from max(5e-2, smallest loss) to the largest loss.
# np.logspace takes base-10 exponents, so both endpoints must use log10.
# The upper endpoint previously used the natural log.
losses = np.logspace(np.log10(max(5e-2, df.train_loss.min())), np.log10(df.train_loss.max()))
# NOTE: this rescales the shared `norm` to train-loss values. Later cells that use
# `norm` inherit these limits.
norm.autoscale(losses)
sm = cm.ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])  # older matplotlib needs an array before a bare ScalarMappable can feed a colorbar

# Thresholds from high to low, skipping the 4 highest. The loop variable is not
# named `loss`, so the `loss` function from the results cell is not overwritten.
for loss_threshold in losses[-5::-1]:
    candidates = df.query('(train_loss >= @loss_threshold) ')
    if candidates.empty:
        continue  # no row reaches this threshold
    row = candidates.sort_values('N_del/h').iloc[0]

    # Density histogram of log(sqrt(eigenvalue)) with Sturges binning.
    # NOTE(review): assumes eigs0 is strictly positive.
    vals = np.sqrt(row.eigs0)
    hist, edges = np.histogram(np.log(vals), 'sturges', density=True)
    # Line form: bin centres plus a zero-valued point one bin width beyond each edge.
    dx = np.mean(np.diff(edges))
    edges = np.concatenate((edges[[0]]-dx, edges[1:]/2 + edges[:-1]/2 , edges[[-1]]+dx))
    edges = np.exp(edges)
    hist = np.concatenate(([0], hist, [0]))

    plt.plot(edges, hist, c=sm.to_rgba(np.clip(row.train_loss, norm.vmin, norm.vmax)), alpha=1.)
plt.colorbar(sm, ax=plt.gca(), label='Train Loss')

plt.xscale('symlog', linthresh=1e-1)
plt.xlim(0, None)
plt.ylim(0, None)
plt.xlabel(r'$\sqrt{\lambda}$')
# The histogram density is per unit log(sqrt(lambda)).
plt.ylabel(r'$P\left(\log\sqrt{\lambda}\right)$')
plt.title("Hessian Spectrum as a Function of Train Loss\nRandom Features")

In [None]:
# Random-features (earliest-step) Hessian spectra for every row with N_Delta/h >= 1,
# coloured by train loss and drawn from the largest N_Delta/h down to the smallest.
from matplotlib import cm

df = untrained.query('(N_del/h >= 1)').sort_values('N_del/h')
losses = df.train_loss.values
# NOTE: this rescales the shared `norm` to train-loss values. Later cells that use
# `norm` inherit these limits.
norm.autoscale(losses)
sm = cm.ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])  # older matplotlib needs an array before a bare ScalarMappable can feed a colorbar

for _, row in df[::-1].iterrows():
    # Density histogram of log(sqrt(eigenvalue)) with Sturges binning.
    # NOTE(review): assumes eigs0 is strictly positive.
    vals = np.sqrt(row.eigs0)
    hist, edges = np.histogram(np.log(vals), 'sturges', density=True)
    # Line form: bin centres plus a zero-valued point one bin width beyond each edge.
    dx = np.mean(np.diff(edges))
    edges = np.concatenate((edges[[0]]-dx, edges[1:]/2 + edges[:-1]/2 , edges[[-1]]+dx))
    edges = np.exp(edges)
    hist = np.concatenate(([0], hist, [0]))

    plt.plot(edges, hist, c=sm.to_rgba(np.clip(row.train_loss, norm.vmin, norm.vmax)), alpha=1.)
plt.colorbar(sm, ax=plt.gca(), label='Train Loss')

# Matplotlib >= 3.3 takes `linthresh`; `linthreshx`/`linthreshy` were removed in 3.5.
plt.xscale('symlog', linthresh=1e-1)
plt.xlim(0, None)
plt.ylim(0, None)
plt.xlabel(r'$\sqrt{\lambda}$')
# The histogram density is per unit log(sqrt(lambda)).
plt.ylabel(r'$P\left(\log\sqrt{\lambda}\right)$')
plt.title("Hessian Spectrum as a Function of Train Loss\nRandom Features")
plt.yscale('symlog', linthresh=1)

In [None]:
# Trained (latest-step) Hessian spectra, coloured by train loss. For each loss
# threshold, plot the row with the smallest N_Delta/h whose train loss reaches it.
from matplotlib import cm

df = trained.query('(N_del/h >= 1)')
# 50 log-spaced thresholds from max(5e-2, smallest loss) to the largest loss.
# np.logspace takes base-10 exponents, so both endpoints must use log10.
# The upper endpoint previously used the natural log.
losses = np.logspace(np.log10(max(5e-2, df.train_loss.min())), np.log10(df.train_loss.max()))

# Colour limits come explicitly from the earliest-step (random-features) rows with
# N_Delta/h >= 1, so trained and random-features spectra can share a loss colour
# scale. Previously this cell reused whatever limits another cell last left on `norm`.
loss_norm = mcolors.LogNorm()
loss_norm.autoscale(untrained.query('(N_del/h >= 1)').train_loss.values)
sm = cm.ScalarMappable(norm=loss_norm, cmap=cmap)
sm.set_array([])  # older matplotlib needs an array before a bare ScalarMappable can feed a colorbar

# Thresholds from high to low, skipping the 4 highest. The loop variable is not
# named `loss`, so the `loss` function from the results cell is not overwritten.
for loss_threshold in losses[-5::-1]:
    candidates = df.query('(train_loss >= @loss_threshold) ')
    if candidates.empty:
        continue  # no row reaches this threshold
    row = candidates.sort_values('N_del/h').iloc[0]

    # Density histogram of log(sqrt(eigenvalue)) with Sturges binning.
    # NOTE(review): assumes eigs0 is strictly positive.
    vals = np.sqrt(row.eigs0)
    hist, edges = np.histogram(np.log(vals), 'sturges', density=True)
    # Line form: bin centres plus a zero-valued point one bin width beyond each edge.
    dx = np.mean(np.diff(edges))
    edges = np.concatenate((edges[[0]]-dx, edges[1:]/2 + edges[:-1]/2 , edges[[-1]]+dx))
    edges = np.exp(edges)
    hist = np.concatenate(([0], hist, [0]))

    plt.plot(edges, hist, c=sm.to_rgba(np.clip(row.train_loss, loss_norm.vmin, loss_norm.vmax)), alpha=1.)
plt.colorbar(sm, ax=plt.gca(), label='Train Loss')

# Matplotlib >= 3.3 takes `linthresh`; the old `linthreshx` keyword was removed in 3.5.
plt.xscale('symlog', linthresh=1e-1)
plt.xlim(0, None)
plt.ylim(0, None)
plt.xlabel(r'$\sqrt{\lambda}$')
# The histogram density is per unit log(sqrt(lambda)).
plt.ylabel(r'$P\left(\log\sqrt{\lambda}\right)$')
plt.title("Hessian Spectrum as a Function of Train Loss\nTrained Features")

In [None]:
# Trained (latest-step) Hessian spectra for every row with N_Delta/h >= 1,
# coloured by train loss and drawn from the largest N_Delta/h down to the smallest.
from matplotlib import cm

df = trained.query('(N_del/h >= 1)').sort_values('N_del/h')
losses = df.train_loss.values
# NOTE: this rescales the shared `norm` to train-loss values. Later cells that use
# `norm` inherit these limits.
norm.autoscale(losses)
sm = cm.ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])  # older matplotlib needs an array before a bare ScalarMappable can feed a colorbar

for _, row in df[::-1].iterrows():
    # Density histogram of log(sqrt(eigenvalue)) with Sturges binning.
    # NOTE(review): assumes eigs0 is strictly positive.
    vals = np.sqrt(row.eigs0)
    hist, edges = np.histogram(np.log(vals), 'sturges', density=True)
    # Line form: bin centres plus a zero-valued point one bin width beyond each edge.
    dx = np.mean(np.diff(edges))
    edges = np.concatenate((edges[[0]]-dx, edges[1:]/2 + edges[:-1]/2 , edges[[-1]]+dx))
    edges = np.exp(edges)
    hist = np.concatenate(([0], hist, [0]))

    plt.plot(edges, hist, c=sm.to_rgba(np.clip(row.train_loss, norm.vmin, norm.vmax)), alpha=1.)
plt.colorbar(sm, ax=plt.gca(), label='Train Loss')

# Matplotlib >= 3.3 takes `linthresh`; the old `linthreshx` keyword was removed in 3.5.
plt.xscale('symlog', linthresh=1e-1)
plt.xlim(0, None)
plt.ylim(0, None)
plt.xlabel(r'$\sqrt{\lambda}$')
# The histogram density is per unit log(sqrt(lambda)).
plt.ylabel(r'$P\left(\log\sqrt{\lambda}\right)$')
# This cell plots `trained` rows; the title previously said "Random Features".
plt.title("Hessian Spectrum as a Function of Train Loss\nTrained Features")