In [None]:
import pickle
import pandas as pd
import numpy as np

from matplotlib import pyplot as plt
from matplotlib import cm

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [None]:
from matplotlib import colors as mcolors
from matplotlib.colors import LinearSegmentedColormap

# RGB anchor colours (0-255 per channel) that the custom colormap interpolates between.
mei2019_rgb = np.array([
    (243, 232, 29),
    (245, 173, 47),
    (140, 193, 53),
    (50,  191, 133),
    (23,  167, 198),
    (36,  123, 235),
    (53,  69,  252),
    (52,  27,  203)
])
cmap = LinearSegmentedColormap.from_list('Mei2019', mei2019_rgb / 255., N=256)

# Preview strip: two identical rows ramping 0 -> 1, rendered through the colormap.
gradient = np.tile(np.linspace(0, 1, 256), (2, 1))
fig = plt.figure(figsize=(6, .5))
img = plt.imshow(gradient, aspect='auto', cmap=cmap)
title = plt.title('Colormap stolen from Mei2019')

# Log-scaled normaliser shared by the plotting cells below; several of them re-autoscale it in place.
norm = mcolors.LogNorm()

In [None]:
# Font sizes (points) applied to matplotlib's rc defaults for every figure in this notebook.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE, BIGGEST_SIZE = 14, 18, 20, 20

for rc_group, rc_settings in (
    ('font',   {'size': SMALL_SIZE}),          # default text size
    ('axes',   {'titlesize': BIGGER_SIZE}),    # axes title
    ('axes',   {'labelsize': MEDIUM_SIZE}),    # x and y axis labels
    ('xtick',  {'labelsize': SMALL_SIZE}),     # x tick labels
    ('ytick',  {'labelsize': SMALL_SIZE}),     # y tick labels
    ('legend', {'fontsize': SMALL_SIZE}),      # legend text
    ('figure', {'titlesize': BIGGEST_SIZE}),   # figure-level title
):
    plt.rc(rc_group, **rc_settings)

In [None]:
from scipy.spatial import distance_matrix
def smooth(x, y, h=1):
    """Gaussian-kernel weighted average of ``y`` evaluated at every point of ``x``.

    Each output value is sum_j K_ij * y_j / sum_j K_ij with
    K_ij = exp(-(x_i - x_j)**2 / (2*h)), so ``h`` plays the role of the kernel
    variance: smaller ``h`` means less smoothing.

    Parameters
    ----------
    x : 1-D array-like (pandas Series, numpy array, list)
        Sample locations. Only the values are used; any pandas index is ignored.
    y : 1-D array-like with the same length as ``x``
        Values to smooth, in the same order as ``x``.
    h : float, default 1
        Kernel variance; expected to be positive.

    Returns
    -------
    numpy.ndarray
        Smoothed values, one per entry of ``x``, in the input order.
    """
    # np.asarray lets plain arrays and lists through as well as Series (which only had .values).
    points = np.asarray(x, dtype=float).reshape(-1, 1)
    values = np.asarray(y)
    K = np.exp(-distance_matrix(points, points)**2/(2*h))
    # Divide by the row sums of K so the weights of every output point sum to one.
    return (K@values) / (K@np.ones_like(values))

In [None]:
# Load the train/test split stored in mnist.pkl; the context manager closes the file handle.
# NOTE: pickle.load can execute arbitrary code - only load files from a trusted source.
with open('mnist.pkl', 'rb') as mnist_file:
    X_train, X_test, y_train, y_test = pickle.load(mnist_file)

In [None]:
def _load_pickle(path):
    """Read one pickled object from ``path``, closing the file afterwards.

    NOTE: pickle.load can execute arbitrary code - only load files from a trusted source.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Both result files are concatenated with `+` and turned into one row per run.
results = _load_pickle('results_with_hessian.pkl') + _load_pickle('results_end_to_end.pkl')
result_df = pd.DataFrame.from_dict(results)


result_df['L'] = 1  # NOTE(review): presumably the number of hidden layers - confirm
# The stored 'N' field is kept as 'h'; 'N' is then redefined as h*(d + 1).
# NOTE(review): presumably h is the hidden width and the new N a parameter count - confirm.
result_df['h'] = result_df['N']
result_df['N'] = result_df['h']*(result_df['d'] + 1)


def force(y, f):
    """Margin shortfall 1 - y*f of predictions ``f`` against labels ``y``."""
    return 1 - y*f


def loss(y, f):
    """Squared hinge loss: mean over the last axis of max(0, 1 - y*f)**2."""
    return np.mean(np.maximum(0, force(y, f))**2, -1)


def N_del(y, f):
    """Number of samples (along the last axis) with 1 - y*f >= 0."""
    return np.sum(force(y, f) >= 0, -1)


result_df['test_loss'] = result_df.y_test_hat.apply(lambda f: loss(y_test, f))
result_df['train_loss'] = result_df.y_train_hat.apply(lambda f: loss(y_train, f))
result_df['N_del'] = result_df.y_train_hat.apply(lambda f: N_del(y_train, f))

result_df['P/N'] = result_df['P']/result_df['N']
result_df['N_del/N'] = result_df['N_del']/result_df['N']

result_df['P/h'] = result_df['P']/result_df['h']
result_df['N_del/h'] = result_df['N_del']/result_df['h']

# Split on lambda: exactly zero vs. strictly positive. Explicit copies are taken because
# later cells add columns to result_df, which would otherwise trigger SettingWithCopyWarning.
result_end_to_end_df = result_df[result_df['lambda'] == 0].copy()
result_df = result_df[result_df['lambda'] > 0].copy()

In [None]:
star_cutoff = 1e-2


def _with_star_columns(df, cutoff):
    """Attach per-step ``N_star`` / ``h_star`` columns to a copy of ``df``.

    For every training step, N_star (h_star) is the largest N (h) among the rows of
    that step whose train_loss exceeds ``cutoff``; it is NaN when no row of the step
    does. Returns ``(new_df, N_star, h_star)`` where the two Series are indexed by step.
    """
    above_cutoff = df['train_loss'] > cutoff
    # Mask out rows at or below the cutoff, then take the per-step maximum (NaNs are skipped).
    n_star = df['N'].where(above_cutoff).groupby(df['step']).max()
    h_star_ = df['h'].where(above_cutoff).groupby(df['step']).max()
    # .assign builds a new frame, so nothing is written into a filtered slice of another frame.
    tagged = df.assign(N_star=df['step'].map(n_star), h_star=df['step'].map(h_star_))
    return tagged, n_star, h_star_


result_df, N_star, h_star = _with_star_columns(result_df, star_cutoff)

result_df = result_df.query('N_star < 4e4') #cut out outliers

# Recompute after the outlier cut so the N_star / h_star Series only contain the surviving steps.
result_df, N_star, h_star = _with_star_columns(result_df, star_cutoff)

In [None]:
# h_star as a function of training step, on log-log axes.
fig = plt.figure(figsize=(9,6))
ax = fig.add_subplot(1, 1, 1)

steps = h_star.index
ax.scatter(steps, h_star)
# Transparent line through the same points (nothing visible is drawn by it).
ax.plot(steps, h_star, color='none')

ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('Training Steps')
ax.set_ylabel(r'$h^*$')
ax.set_ylim(1e1, 1e3)

fig.savefig('plots/h_star_vs_train_steps_L=1.pdf')

In [None]:
# Train loss against h/P (linear y-axis, log x-axis): one colour-coded series per sampled step.
fig = plt.figure(figsize=(9,6))

by = 'step'
x_expr = 'h/P'
y_expr = 'train_loss'

sm = cm.ScalarMappable(norm=norm, cmap=cmap)

# Every 5th distinct step, listed largest first.
by_vals = np.array(sorted(result_df.eval(by).unique(), reverse=False))[::5][::-1]
# Interior steps are drawn in shuffled order (helps visual clarity where series overlap) and the
# two extreme steps are drawn last. The fixed seed keeps the draw order identical across re-runs.
extremes = by_vals[[0, -1]]
by_vals = np.random.RandomState(0).permutation(by_vals[1:-1])
by_vals = np.append(by_vals, extremes)
norm.autoscale(by_vals)

for val in by_vals:
    color = cmap(norm(val))

    data = result_df.query(f'{by} == @val')
    data = data.iloc[np.argsort(data.eval(x_expr))]

    x = data.eval(x_expr)
    y = data.eval(y_expr)
    plt.scatter(x, y, c=data.eval(by), cmap=cmap, norm=norm, alpha=.7)

    # Kernel-smoothed trend in log(x), drawn behind the points.
    y_sm = smooth(np.log(x), y, .001)
    plt.plot(x, y_sm, color=color, ls=':', zorder=-1)

# `sm` is not attached to any Axes, so the Axes to take space from is passed explicitly.
plt.colorbar(sm, ax=plt.gca(), label='Training Steps')

plt.xscale('log')

plt.xlabel(r"$h/P$")
plt.ylabel(r"Train $\mathcal{L}$")
fig.savefig('plots/h_P_vs_train_loss_L=1_linear.pdf')


In [None]:
# Train loss against h/P on log-log axes: one colour-coded series per sampled step.
fig = plt.figure(figsize=(9,6))

by = 'step'
x_expr = 'h/P'
y_expr = 'train_loss'

sm = cm.ScalarMappable(norm=norm, cmap=cmap)

# Every 5th distinct step, listed largest first.
by_vals = np.array(sorted(result_df.eval(by).unique(), reverse=False))[::5][::-1]
# Interior steps are drawn in shuffled order (helps visual clarity where series overlap) and the
# two extreme steps are drawn last. The fixed seed keeps the draw order identical across re-runs.
extremes = by_vals[[0, -1]]
by_vals = np.random.RandomState(0).permutation(by_vals[1:-1])
by_vals = np.append(by_vals, extremes)
norm.autoscale(by_vals)

for val in by_vals:
    color = cmap(norm(val))

    data = result_df.query(f'{by} == @val')
    data = data.iloc[np.argsort(data.eval(x_expr))]

    x = data.eval(x_expr)
    y = data.eval(y_expr)
    plt.scatter(x, y, c=data.eval(by), cmap=cmap, norm=norm, alpha=.7)

    # Kernel-smoothed trend in log(x), drawn behind the points.
    y_sm = smooth(np.log(x), y, .01)
    plt.plot(x, y_sm, color=color, ls=':', zorder=-1)

# `sm` is not attached to any Axes, so the Axes to take space from is passed explicitly.
plt.colorbar(sm, ax=plt.gca(), label='Training Steps')

plt.yscale('log')
plt.xscale('log')
# Dotted line at the train-loss cutoff used to define N_star / h_star.
plt.axhline(star_cutoff, c='k', ls=':')

plt.xlabel(r"$h/P$")
plt.ylabel(r"Train $\mathcal{L}$")
fig.savefig('plots/h_P_vs_train_loss_L=1_log.pdf')

In [None]:
# Train loss against h/h_star on log-log axes: one colour-coded series per sampled step.
fig = plt.figure(figsize=(9,6))

by = 'step'
x_expr = 'h/h_star'
y_expr = 'train_loss'

sm = cm.ScalarMappable(norm=norm, cmap=cmap)

# Every 5th distinct step, listed largest first.
by_vals = np.array(sorted(result_df.eval(by).unique(), reverse=False))[::5][::-1]
# Interior steps are drawn in shuffled order (helps visual clarity where series overlap) and the
# two extreme steps are drawn last. The fixed seed keeps the draw order identical across re-runs.
extremes = by_vals[[0, -1]]
by_vals = np.random.RandomState(0).permutation(by_vals[1:-1])
by_vals = np.append(by_vals, extremes)
norm.autoscale(by_vals)

for val in by_vals:
    color = cmap(norm(val))

    data = result_df.query(f'{by} == @val')
    data = data.iloc[np.argsort(data.eval(x_expr))]

    x = data.eval(x_expr)
    y = data.eval(y_expr)
    plt.scatter(x, y, c=data.eval(by), cmap=cmap, norm=norm, alpha=.7)

    # Kernel-smoothed trend in log(x), drawn behind the points.
    y_sm = smooth(np.log(x), y, .01)
    plt.plot(x, y_sm, color=color, ls=':', zorder=-1)

# `sm` is not attached to any Axes, so the Axes to take space from is passed explicitly.
plt.colorbar(sm, ax=plt.gca(), label='Training Steps')

plt.yscale('log')
plt.xscale('log')
# Dashed vertical reference at h = h_star.
plt.axvline(1, c='k', ls='--')

plt.xlabel(r"$h/h^*$")
plt.ylabel(r"Train $\mathcal{L}$")
fig.savefig('plots/h_h_star_vs_train_loss_L=1_log.pdf')

In [None]:
# Test loss against h/P on log-log axes: one colour-coded series per sampled step.
fig = plt.figure(figsize=(9,6))

by = 'step'
x_expr = 'h/P'
y_expr = 'test_loss'

sm = cm.ScalarMappable(norm=norm, cmap=cmap)

# Every 5th distinct step, listed largest first.
by_vals = np.array(sorted(result_df.eval(by).unique(), reverse=False))[::5][::-1]
# Interior steps are drawn in shuffled order (helps visual clarity where series overlap) and the
# two extreme steps are drawn last. The fixed seed keeps the draw order identical across re-runs.
extremes = by_vals[[0, -1]]
by_vals = np.random.RandomState(0).permutation(by_vals[1:-1])
by_vals = np.append(by_vals, extremes)
norm.autoscale(by_vals)

for val in by_vals:
    color = cmap(norm(val))

    data = result_df.query(f'{by} == @val')
    data = data.iloc[np.argsort(data.eval(x_expr))]

    x = data.eval(x_expr)
    y = data.eval(y_expr)
    plt.scatter(x, y, c=data.eval(by), cmap=cmap, norm=norm, alpha=.7)

    # Kernel-smoothed trend in log(x), drawn behind the points.
    y_sm = smooth(np.log(x), y, .001)
    plt.plot(x, y_sm, color=color, ls=':', zorder=-1)

# `sm` is not attached to any Axes, so the Axes to take space from is passed explicitly.
plt.colorbar(sm, ax=plt.gca(), label='Training Steps')

plt.yscale('log')
plt.xscale('log')

plt.xlabel(r"$h/P$")
plt.ylabel(r"Test $\mathcal{L}$")
fig.savefig('plots/h_P_vs_test_loss_L=1_log.pdf')

In [None]:
# Test loss against h/h_star on log-log axes: one colour-coded series per sampled step.
fig = plt.figure(figsize=(9,6))

by = 'step'
x_expr = 'h/h_star'
y_expr = 'test_loss'

sm = cm.ScalarMappable(norm=norm, cmap=cmap)

# Every 5th distinct step, listed largest first.
by_vals = np.array(sorted(result_df.eval(by).unique(), reverse=False))[::5][::-1]
# Interior steps are drawn in shuffled order (helps visual clarity where series overlap) and the
# two extreme steps are drawn last. The fixed seed keeps the draw order identical across re-runs.
extremes = by_vals[[0, -1]]
by_vals = np.random.RandomState(0).permutation(by_vals[1:-1])
by_vals = np.append(by_vals, extremes)
norm.autoscale(by_vals)

for val in by_vals:
    color = cmap(norm(val))

    data = result_df.query(f'{by} == @val')
    data = data.iloc[np.argsort(data.eval(x_expr))]

    x = data.eval(x_expr)
    y = data.eval(y_expr)
    plt.scatter(x, y, c=data.eval(by), cmap=cmap, norm=norm, alpha=.7)

    # Kernel-smoothed trend in log(x), drawn behind the points.
    y_sm = smooth(np.log(x), y, .001)
    plt.plot(x, y_sm, color=color, ls=':', zorder=-1)

# `sm` is not attached to any Axes, so the Axes to take space from is passed explicitly.
plt.colorbar(sm, ax=plt.gca(), label='Training Steps')

plt.yscale('log')
plt.xscale('log')
# Dashed vertical reference at h = h_star.
plt.axvline(1, c='k', ls='--')

plt.xlabel(r"$h/h^*$")
plt.ylabel(r"Test $\mathcal{L}$")
fig.savefig('plots/h_h_star_vs_test_loss_L=1_log.pdf')

In [None]:
# N_del/h against h/h_star on log-log axes: one colour-coded series per sampled step.
fig = plt.figure(figsize=(9,6))

by = 'step'
x_expr = 'h/h_star'
y_expr = 'N_del/h'

sm = cm.ScalarMappable(norm=norm, cmap=cmap)

# Every 5th distinct step, listed largest first.
by_vals = np.array(sorted(result_df.eval(by).unique(), reverse=False))[::5][::-1]
# Interior steps are drawn in shuffled order (helps visual clarity where series overlap) and the
# two extreme steps are drawn last. The fixed seed keeps the draw order identical across re-runs.
extremes = by_vals[[0, -1]]
by_vals = np.random.RandomState(0).permutation(by_vals[1:-1])
by_vals = np.append(by_vals, extremes)
norm.autoscale(by_vals)

for val in by_vals:
    color = cmap(norm(val))

    data = result_df.query(f'{by} == @val')
    data = data.iloc[np.argsort(data.eval(x_expr))]

    x = data.eval(x_expr)
    y = data.eval(y_expr)
    plt.scatter(x, y, c=data.eval(by), cmap=cmap, norm=norm, alpha=.7)

    # Kernel-smoothed trend in log(x).
    y_sm = smooth(np.log(x), y, .001)
    plt.plot(x, y_sm, color=color, ls=':')

# `sm` is not attached to any Axes, so the Axes to take space from is passed explicitly.
plt.colorbar(sm, ax=plt.gca(), label='Training Steps')

plt.yscale('log')
plt.xscale('log')
# Dashed references at h = h_star and N_del = h.
plt.axvline(1, c='k', ls='--', alpha=.7)
plt.axhline(1, c='k', ls='--', alpha=.7)

plt.xlabel(r"$h/h^*$")
plt.ylabel(r"$N_\Delta/h$")
fig.savefig('plots/h_h_star_vs_N_del_h_L=1.pdf')

In [None]:
# N_del/N against h/h_star on log-log axes: one colour-coded series per sampled step.
fig = plt.figure(figsize=(9,6))

# `by` and `by_vals` are defined here rather than inherited from whichever cell ran last,
# so this cell gives the same figure regardless of execution order.
by = 'step'
x_expr = 'h/h_star'
y_expr = 'N_del/N'

sm = cm.ScalarMappable(norm=norm, cmap=cmap)

# Every 5th distinct step, listed largest first.
by_vals = np.array(sorted(result_df.eval(by).unique(), reverse=False))[::5][::-1]
# Interior steps are drawn in shuffled order (helps visual clarity where series overlap) and the
# two extreme steps are drawn last. The fixed seed keeps the draw order identical across re-runs.
extremes = by_vals[[0, -1]]
by_vals = np.random.RandomState(0).permutation(by_vals[1:-1])
by_vals = np.append(by_vals, extremes)
norm.autoscale(by_vals)

for val in by_vals:
    color = cmap(norm(val))

    data = result_df.query(f'{by} == @val')
    data = data.iloc[np.argsort(data.eval(x_expr))]

    x = data.eval(x_expr)
    y = data.eval(y_expr)
    plt.scatter(x, y, c=data.eval(by), cmap=cmap, norm=norm, alpha=.7)

    # Kernel-smoothed trend in log(x).
    y_sm = smooth(np.log(x), y, .001)
    plt.plot(x, y_sm, color=color, ls=':')

# `sm` is not attached to any Axes, so the Axes to take space from is passed explicitly.
plt.colorbar(sm, ax=plt.gca(), label='Training Steps')

plt.yscale('log')
plt.xscale('log')
# Dashed references at h = h_star and N_del = N.
plt.axvline(1, c='k', ls='--', alpha=.7)
plt.axhline(1, c='k', ls='--', alpha=.7)

plt.xlabel(r"$h/h^*$")
plt.ylabel(r"$N_\Delta/N$")
fig.savefig('plots/h_h_star_vs_N_del_N_L=1.pdf')

In [None]:
# N_del/h against train loss (linear axes), points only, coloured by sampled training step.
fig = plt.figure(figsize=(9,6))

by = 'step'
x_expr = 'train_loss'
y_expr = 'N_del/h'

sm = cm.ScalarMappable(norm=norm, cmap=cmap)

# Every 5th distinct step, listed largest first.
by_vals = np.array(sorted(result_df.eval(by).unique(), reverse=False))[::5][::-1]
# Interior steps are drawn in shuffled order (helps visual clarity where series overlap) and the
# two extreme steps are drawn last. The fixed seed keeps the draw order identical across re-runs.
extremes = by_vals[[0, -1]]
by_vals = np.random.RandomState(0).permutation(by_vals[1:-1])
by_vals = np.append(by_vals, extremes)
norm.autoscale(by_vals)

for val in by_vals:
    data = result_df.query(f'{by} == @val')
    data = data.iloc[np.argsort(data.eval(x_expr))]

    x = data.eval(x_expr)
    y = data.eval(y_expr)
    plt.scatter(x, y, c=data.eval(by), cmap=cmap, norm=norm, alpha=.7)

# `sm` is not attached to any Axes, so the Axes to take space from is passed explicitly.
plt.colorbar(sm, ax=plt.gca(), label='Training Steps')

plt.ylim(0, 2)
# Dashed reference at N_del = h.
plt.axhline(1, c='k', ls='--', alpha=.7)

plt.xlabel(r"Train $\mathcal{L}$")
plt.ylabel(r"$N_\Delta/h$")

In [None]:
# N_del/N (log y-axis) against train loss, points only, coloured by sampled training step.
fig = plt.figure(figsize=(9,6))

by = 'step'
x_expr = 'train_loss'
y_expr = 'N_del/N'

sm = cm.ScalarMappable(norm=norm, cmap=cmap)

# Every 5th distinct step, listed largest first.
by_vals = np.array(sorted(result_df.eval(by).unique(), reverse=False))[::5][::-1]
# Interior steps are drawn in shuffled order (helps visual clarity where series overlap) and the
# two extreme steps are drawn last. The fixed seed keeps the draw order identical across re-runs.
extremes = by_vals[[0, -1]]
by_vals = np.random.RandomState(0).permutation(by_vals[1:-1])
by_vals = np.append(by_vals, extremes)
norm.autoscale(by_vals)

for val in by_vals:
    data = result_df.query(f'{by} == @val')
    data = data.iloc[np.argsort(data.eval(x_expr))]

    x = data.eval(x_expr)
    y = data.eval(y_expr)
    plt.scatter(x, y, c=data.eval(by), cmap=cmap, norm=norm, alpha=.7)

# `sm` is not attached to any Axes, so the Axes to take space from is passed explicitly.
plt.colorbar(sm, ax=plt.gca(), label='Training Steps')

plt.yscale('log')
# Dashed reference at N_del = N.
plt.axhline(1, c='k', ls='--', alpha=.7)

plt.xlabel(r"Train $\mathcal{L}$")
plt.ylabel(r"$N_\Delta/N$")

In [None]:
# N_del/h against train loss (linear axes, y limited to 0-10), coloured by sampled training step.
fig = plt.figure(figsize=(9,6))

by = 'step'
x_expr = 'train_loss'
y_expr = 'N_del/h'

sm = cm.ScalarMappable(norm=norm, cmap=cmap)

# Every 5th distinct step, listed largest first.
by_vals = np.array(sorted(result_df.eval(by).unique(), reverse=False))[::5][::-1]
# Interior steps are drawn in shuffled order (helps visual clarity where series overlap) and the
# two extreme steps are drawn last. The fixed seed keeps the draw order identical across re-runs.
extremes = by_vals[[0, -1]]
by_vals = np.random.RandomState(0).permutation(by_vals[1:-1])
by_vals = np.append(by_vals, extremes)
norm.autoscale(by_vals)

for val in by_vals:
    data = result_df.query(f'{by} == @val')
    data = data.iloc[np.argsort(data.eval(x_expr))]

    x = data.eval(x_expr)
    y = data.eval(y_expr)
    plt.scatter(x, y, c=data.eval(by), cmap=cmap, norm=norm, alpha=.7)

# `sm` is not attached to any Axes, so the Axes to take space from is passed explicitly.
plt.colorbar(sm, ax=plt.gca(), label='Training Steps')

# Dashed reference at N_del = h.
plt.axhline(1, c='k', ls='--', alpha=.7)
plt.ylim(0, 10)

plt.xlabel(r"Train $\mathcal{L}$")
# The plotted quantity is y_expr = 'N_del/h', so the axis is labelled N_Delta/h (it used to read
# N_Delta/N). NOTE(review): if N_del/N was the intended quantity, change y_expr instead.
plt.ylabel(r"$N_\Delta/h$")

In [None]:
# Test loss against N_del/h on log-log axes: one colour-coded series per sampled step.
fig = plt.figure(figsize=(9,6))

# `by` and `by_vals` are defined here rather than inherited from whichever cell ran last,
# so this cell gives the same figure regardless of execution order.
by = 'step'
x_expr = 'N_del/h'
y_expr = 'test_loss'

sm = cm.ScalarMappable(norm=norm, cmap=cmap)

# Every 5th distinct step, listed largest first.
by_vals = np.array(sorted(result_df.eval(by).unique(), reverse=False))[::5][::-1]
# Interior steps are drawn in shuffled order (helps visual clarity where series overlap) and the
# two extreme steps are drawn last. The fixed seed keeps the draw order identical across re-runs.
extremes = by_vals[[0, -1]]
by_vals = np.random.RandomState(0).permutation(by_vals[1:-1])
by_vals = np.append(by_vals, extremes)
norm.autoscale(by_vals)

for val in by_vals:
    color = cmap(norm(val))

    data = result_df.query(f'{by} == @val')
    data = data.iloc[np.argsort(data.eval(x_expr))]

    x = data.eval(x_expr)
    y = data.eval(y_expr)
    plt.scatter(x, y, c=data.eval(by), cmap=cmap, norm=norm, alpha=.7)

    # No smoothing here: the dotted line simply joins the raw points in x order.
    y_sm = y
    plt.plot(x, y_sm, color=color, ls=':')

# `sm` is not attached to any Axes, so the Axes to take space from is passed explicitly.
plt.colorbar(sm, ax=plt.gca(), label='Training Steps')
plt.xlabel(x_expr)
plt.ylabel(y_expr)

plt.yscale('log')
plt.xscale('log')
# Dashed reference at N_del = h.
plt.axvline(1, c='k', ls='--', alpha=.7)

In [None]:
# Test loss against N_del/N on log-log axes: one colour-coded series per sampled step.
fig = plt.figure(figsize=(9,6))

# `by` and `by_vals` are defined here rather than inherited from whichever cell ran last,
# so this cell gives the same figure regardless of execution order.
by = 'step'
x_expr = 'N_del/N'
y_expr = 'test_loss'

sm = cm.ScalarMappable(norm=norm, cmap=cmap)

# Every 5th distinct step, listed largest first.
by_vals = np.array(sorted(result_df.eval(by).unique(), reverse=False))[::5][::-1]
# Interior steps are drawn in shuffled order (helps visual clarity where series overlap) and the
# two extreme steps are drawn last. The fixed seed keeps the draw order identical across re-runs.
extremes = by_vals[[0, -1]]
by_vals = np.random.RandomState(0).permutation(by_vals[1:-1])
by_vals = np.append(by_vals, extremes)
norm.autoscale(by_vals)

for val in by_vals:
    color = cmap(norm(val))

    data = result_df.query(f'{by} == @val')
    data = data.iloc[np.argsort(data.eval(x_expr))]

    x = data.eval(x_expr)
    y = data.eval(y_expr)
    plt.scatter(x, y, c=data.eval(by), cmap=cmap, norm=norm, alpha=.7)

    # No smoothing here: the dotted line simply joins the raw points in x order.
    y_sm = y
    plt.plot(x, y_sm, color=color, ls=':')

# `sm` is not attached to any Axes, so the Axes to take space from is passed explicitly.
plt.colorbar(sm, ax=plt.gca(), label='Training Steps')
plt.xlabel(x_expr)
plt.ylabel(y_expr)

plt.yscale('log')
plt.xscale('log')
# Dashed reference at N_del = N.
plt.axvline(1, c='k', ls='--', alpha=.7)

In [None]:
# Display-only inspection: set_index returns a new frame indexed by N; result_df is not modified.
result_df.set_index('N')

In [None]:
# Train loss against N rescaled by a per-step N_star (3.5e-2 cutoff), on log-log axes.
plt.figure(figsize=(9,6))

by = 'step'
x_expr = 'N'
y_expr = 'train_loss'

sm = cm.ScalarMappable(norm=norm, cmap=cmap)

# Every 5th distinct step, in increasing order.
by_vals = np.array(sorted(result_df.eval(by).unique(), reverse=False))[::5]
norm.autoscale(by_vals)

for val in by_vals:
    color = cmap(norm(val))

    data = result_df.query(f'{by} == @val')
    data = data.iloc[np.argsort(data.eval(x_expr))]
    # Largest N at this step whose train loss is still above 3.5e-2 (NaN if there is none).
    # Kept under its own name so the per-step N_star Series computed earlier in the notebook
    # is not overwritten by a scalar.
    step_N_star = data.query('train_loss > 3.5e-2')['N'].max()

    x = data.eval(x_expr)/step_N_star
    y = data.eval(y_expr)
    plt.scatter(x, y, c=data.eval(by), cmap=cmap, norm=norm, alpha=.7)

    # Kernel-smoothed trend in log10(x).
    y_sm = smooth(np.log10(x), y, .01)
    plt.plot(x, y_sm, color=color, ls=':')

# `sm` is not attached to any Axes, so the Axes to take space from is passed explicitly.
plt.colorbar(sm, ax=plt.gca(), label=by)
# The x values are N divided by the per-step N_star, so the label says so (it used to read 'N').
plt.xlabel(f'{x_expr}/N_star')
plt.ylabel(y_expr)

plt.yscale('log')
plt.xscale('log')

In [None]:
# Display-only inspection of `x` as left behind by the most recently run plotting cell.
x

In [None]:
# Test loss against N_del/h on log-log axes, using every 10th step and every 2nd point per step.
plt.figure(figsize=(9,6))

by = 'step'
x_expr = 'N_del/h'
y_expr = 'test_loss'

sm = cm.ScalarMappable(norm=norm, cmap=cmap)

# Every 10th distinct step, in increasing order.
by_vals = np.array(sorted(result_df.eval(by).unique(), reverse=False))[::10]
norm.autoscale(by_vals)

for val in by_vals:
    color = cmap(norm(val))

    data = result_df.query(f'{by} == @val')
    # Sort by x, then keep every second row.
    data = data.iloc[np.argsort(data.eval(x_expr))][::2]

    x = data.eval(x_expr)
    y = data.eval(y_expr)
    plt.scatter(x, y, c=data.eval(by), cmap=cmap, norm=norm, alpha=.7)

    # Kernel-smoothed trend in log(x).
    y_sm = smooth(np.log(x), y, .1)
    plt.plot(x, y_sm, color=color, ls=':')

# `sm` is not attached to any Axes, so the Axes to take space from is passed explicitly.
plt.colorbar(sm, ax=plt.gca(), label=by)
plt.xlabel(x_expr)
plt.ylabel(y_expr)

plt.yscale('log')
plt.xscale('log')

In [None]:
# Train loss against N_del/N for every row of result_df, coloured by step (log-log axes).
data = result_df
x, y = data['N_del/N'], data['train_loss']

plt.figure(figsize=(9, 6))

# Transparent line drawn first so the axis limits come out right; matplotlib gets confused
# when a scatter is combined with log scales.
plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7)
plt.colorbar(label='Training steps')

plt.ylabel(r'Train $\mathcal{L}$')
plt.xlabel(r'$N_\Delta/N$')

plt.xscale('log')
plt.yscale('log')

# Dotted reference at N_del = N.
plt.axvline(1, color='k', ls=':')

In [None]:
# Train loss against N_del/N on linear axes, with the end-to-end runs overlaid in black.
data = result_df
x, y = data['N_del/N'], data['train_loss']

plt.figure(figsize=(9, 6))

# Transparent line drawn first so the axis limits come out right.
plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7)
plt.colorbar(label='Training steps')

plt.ylabel(r'Train $\mathcal{L}$')
plt.xlabel(r'$N_\Delta/N$')

# Overlay: end-to-end results as black points.
data = result_end_to_end_df
x, y = data['N_del/N'], data['train_loss']
plt.scatter(x, y, c='k')

# Dotted reference at N_del = N.
plt.axvline(1, color='k', ls=':')

In [None]:
# Train loss against N_del/h (log x-axis): SVM runs coloured by step, end-to-end runs as black crosses.
plt.figure(figsize=(9,6))
data = result_df

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['N_del/h']
y = data['train_loss']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7, label=r'Final layer SVM ($\lambda=1\times10^{-13}$)')
plt.colorbar(label='Training steps')

#Plot end_to_end result
data = result_end_to_end_df
x = data['N_del/h']
y = data['train_loss']
# Raw string: '\l' is not a valid escape sequence in a normal string literal.
plt.scatter(x, y, c='k', label=r'Fully trained network ($\lambda=0$)', marker='x')

plt.xlabel(r'$N_\Delta/h$')
plt.ylabel(r'Train $\mathcal{L}$')

plt.xscale('log')
# Dotted reference at N_del = h.
plt.axvline(1, color='k',ls=':')

plt.legend()

In [None]:
# Train loss against N_del/h (log-log): SVM runs coloured by step, end-to-end runs with N < 300 as crosses.
plt.figure(figsize=(9,6))
data = result_df

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['N_del/h']
y = data['train_loss']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7, label=r'Final layer SVM ($\lambda=1\times10^{-13}$)')
plt.colorbar(label='Training steps')

#Plot end_to_end result
# NOTE(review): the N < 300 restriction is not explained in this notebook - confirm its purpose.
data = result_end_to_end_df.query('N < 300')
x = data['N_del/h']
y = data['train_loss']
# Raw string: '\l' is not a valid escape sequence in a normal string literal.
plt.scatter(x, y, c='k', label=r'Fully trained network ($\lambda=0$)', marker='x')

plt.xlabel(r'$N_\Delta/h$')
plt.ylabel(r'Train $\mathcal{L}$')

plt.yscale('log')
plt.xscale('log')
# Dotted references at N_del/h = 1 and 21.
# NOTE(review): the value 21 is not explained in this notebook - confirm what it marks.
plt.axvline(1, color='k',ls=':')
plt.axvline(21, color='k',ls=':')
plt.legend()

In [None]:
# Train loss against N_del/N (log x-axis), restricted to rows with step > 1e5.
data = result_df.query('step > 1e5')
x, y = data['N_del/N'], data['train_loss']

plt.figure(figsize=(9, 6))

# Transparent line drawn first so the axis limits come out right; matplotlib gets confused
# when a scatter is combined with log scales.
plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7)
plt.colorbar(label='Training steps')

plt.ylabel(r'Train $\mathcal{L}$')
plt.xlabel(r'$N_\Delta/N$')

plt.xscale('log')
# Dotted reference at N_del = N.
plt.axvline(1, color='k', ls=':')

In [None]:
# Train loss against P/N (log x-axis): SVM runs coloured by step, end-to-end runs as black crosses.
plt.figure(figsize=(9,6))
data = result_df

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['P/N']
y = data['train_loss']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7, label=r'Final layer SVM ($\lambda=1\times10^{-13}$)')
plt.colorbar(label='Training steps')


#Plot end_to_end result
data = result_end_to_end_df
x = data['P/N']
y = data['train_loss']
# Raw string: '\l' is not a valid escape sequence in a normal string literal.
plt.scatter(x, y, c='k', label=r'Fully trained network ($\lambda=0$)', marker='x')


plt.xlabel(r'$P/N$')
plt.ylabel(r'Train $\mathcal{L}$')

plt.xscale('log')
plt.legend()

In [None]:
# Train loss against P/N (log-log): SVM runs coloured by step, end-to-end runs as black crosses.
plt.figure(figsize=(9,6))
data = result_df

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['P/N']
y = data['train_loss']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7, label=r'Final layer SVM ($\lambda=1\times10^{-13}$)')
plt.colorbar(label='Training steps')


#Plot end_to_end result
data = result_end_to_end_df
x = data['P/N']
y = data['train_loss']
# Raw string: '\l' is not a valid escape sequence in a normal string literal.
plt.scatter(x, y, c='k', label=r'Fully trained network ($\lambda=0$)', marker='x')


plt.xlabel(r'$P/N$')
plt.ylabel(r'Train $\mathcal{L}$')

plt.yscale('log')
plt.xscale('log')
plt.legend()

In [None]:
# Train loss against N (log-log): SVM runs coloured by step, end-to-end runs as black crosses.
# A pasted tail that re-plotted the end-to-end *test* loss onto this train-loss figure (under
# the SVM legend label, with a second colorbar) has been removed; test loss against N has its
# own figure.
plt.figure(figsize=(9,6))
data = result_df

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['N']
y = data['train_loss']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7, label=r'Final layer SVM ($\lambda=1\times10^{-13}$)')
plt.colorbar(label='Training steps')


#Plot end_to_end result
data = result_end_to_end_df
x = data['N']
y = data['train_loss']
# Raw string: '\l' is not a valid escape sequence in a normal string literal.
plt.scatter(x, y, c='k', label=r'Fully trained network ($\lambda=0$)', marker='x')


plt.xlabel(r'$N$')
plt.ylabel(r'Train $\mathcal{L}$')

plt.yscale('log')
plt.xscale('log')
plt.legend()

In [None]:
# Test loss against N (log-log): SVM runs coloured by step, end-to-end runs as black crosses.
plt.figure(figsize=(9,6))
data = result_df

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['N']
y = data['test_loss']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7, label=r'Final layer SVM ($\lambda=1\times10^{-13}$)')
plt.colorbar(label='Training steps')


#Plot end_to_end result
data = result_end_to_end_df
x = data['N']
y = data['test_loss']
# Raw string: '\l' is not a valid escape sequence in a normal string literal.
plt.scatter(x, y, c='k', label=r'Fully trained network ($\lambda=0$)', marker='x')


plt.xlabel(r'$N$')
plt.ylabel(r'Test $\mathcal{L}$')

plt.yscale('log')
plt.xscale('log')
plt.legend()

In [None]:
# Test loss against P/N (log-log): SVM runs coloured by step, end-to-end runs as black crosses.
plt.figure(figsize=(9,6))
data = result_df

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['P/N']
y = data['test_loss']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7, label=r'Final layer SVM ($\lambda=1\times10^{-13}$)')
plt.colorbar(label='Training steps')

#Plot end_to_end result
data = result_end_to_end_df
x = data['P']/data['N']
y = data['test_loss']
# Raw string: '\l' is not a valid escape sequence in a normal string literal.
plt.scatter(x, y, c='k', label=r'Fully trained network ($\lambda=0$)', marker='x')

plt.xlabel(r'$P/N$')
plt.ylabel(r'Test $\mathcal{L}$')

plt.yscale('log')
plt.xscale('log')

plt.legend()

In [None]:
# Test loss against P/N (log-log) for every 15th distinct step; the points of each step are
# joined by a faint dotted line.
plt.figure(figsize=(9, 6))

sampled_steps = sorted(result_df.step.unique())[::15]
for step in sampled_steps:
    data = result_df.query('step == @step')
    x, y = data['P/N'], data['test_loss']

    # Join the points of this step in increasing-x order.
    order = np.argsort(x)
    plt.plot(x.values[order], y.values[order], ls=':', c='k', alpha=.2)
    plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7)

plt.colorbar(label='Training steps')

plt.ylabel(r'Test $\mathcal{L}$')
plt.xlabel(r'$P/N$')

plt.yscale('log')
plt.xscale('log')

In [None]:
# Test loss against N_del/N (log-log): SVM runs coloured by step, end-to-end runs as black crosses.
plt.figure(figsize=(9,6))
data = result_df

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['N_del/N']
y = data['test_loss']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7, label=r'Final layer SVM ($\lambda=1\times10^{-13}$)')
plt.colorbar(label='Training steps')

#Plot end_to_end result
data = result_end_to_end_df
x = data['N_del/N']
y = data['test_loss']
# Raw string: '\l' is not a valid escape sequence in a normal string literal.
plt.scatter(x, y, c='k', label=r'Fully trained network ($\lambda=0$)', marker='x')

plt.xlabel(r'$N_\Delta/N$')
plt.ylabel(r'Test $\mathcal{L}$')

plt.yscale('log')
plt.xscale('log')

plt.legend()

In [None]:
# Test loss against N_del/N (log-log) for SVM runs with step < 1e3, plus all end-to-end runs.
plt.figure(figsize=(9,6))
data = result_df[result_df.step < 1e3]

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['N_del/N']
y = data['test_loss']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7, label=r'Final layer SVM ($\lambda=1\times10^{-13}$)')
plt.colorbar(label='Training steps')

#Plot end_to_end result
data = result_end_to_end_df
x = data['N_del/N']
y = data['test_loss']
# Raw string: '\l' is not a valid escape sequence in a normal string literal.
plt.scatter(x, y, c='k', label=r'Fully trained network ($\lambda=0$)', marker='x')

plt.xlabel(r'$N_\Delta/N$')
plt.ylabel(r'Test $\mathcal{L}$')

plt.yscale('log')
plt.xscale('log')

plt.legend()

In [None]:
# Test loss against N_del/N (log-log) for SVM runs with step > 1e3, plus all end-to-end runs.
plt.figure(figsize=(9,6))
data = result_df[result_df.step > 1e3]

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['N_del/N']
y = data['test_loss']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7, label=r'Final layer SVM ($\lambda=1\times10^{-13}$)')
plt.colorbar(label='Training steps')

#Plot end_to_end result
data = result_end_to_end_df
x = data['N_del/N']
y = data['test_loss']
# Raw string: '\l' is not a valid escape sequence in a normal string literal.
plt.scatter(x, y, c='k', label=r'Fully trained network ($\lambda=0$)', marker='x')

plt.xlabel(r'$N_\Delta/N$')
plt.ylabel(r'Test $\mathcal{L}$')

plt.yscale('log')
plt.xscale('log')

plt.legend()

In [None]:
# Test loss against N_del/N (log-log) for SVM runs with step > 1e5, plus all end-to-end runs.
plt.figure(figsize=(9,6))
data = result_df[result_df.step > 1e5]

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['N_del/N']
y = data['test_loss']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7, label=r'Final layer SVM ($\lambda=1\times10^{-13}$)')
plt.colorbar(label='Training steps')

#Plot end_to_end result
data = result_end_to_end_df
x = data['N_del/N']
y = data['test_loss']
# Raw string: '\l' is not a valid escape sequence in a normal string literal.
plt.scatter(x, y, c='k', label=r'Fully trained network ($\lambda=0$)', marker='x')

plt.xlabel(r'$N_\Delta/N$')
plt.ylabel(r'Test $\mathcal{L}$')

plt.yscale('log')
plt.xscale('log')

plt.legend(loc='lower right')
# NOTE(review): the value 21 is not explained in this notebook - confirm what this line marks.
plt.axvline(21, ls=':', c='k')

In [None]:
# N_del/N against P/N (log-log): SVM runs coloured by step, end-to-end runs as black crosses.
plt.figure(figsize=(9,6))
data = result_df

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['P/N']
y = data['N_del/N']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7, label=r'Final layer SVM ($\lambda=1\times10^{-13}$)')
plt.colorbar(label='Training steps')

#Plot end_to_end result
data = result_end_to_end_df
x = data['P/N']
y = data['N_del/N']
# Raw string: '\l' is not a valid escape sequence in a normal string literal.
plt.scatter(x, y, c='k', label=r'Fully trained network ($\lambda=0$)', marker='x')

plt.xlabel('P/N')
plt.ylabel(r'$N_{\Delta}/N$')

plt.yscale('log')
plt.xscale('log')
# Dotted reference at N_del = N.
plt.axhline(1, color='k',ls=':')
plt.legend()

In [None]:
# N_del/N against N/P (log-log): SVM runs coloured by step, end-to-end runs as black crosses.
plt.figure(figsize=(9,6))
data = result_df

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['N']/data['P']
y = data['N_del/N']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7, label=r'Final layer SVM ($\lambda=1\times10^{-13}$)')
plt.colorbar(label='Training steps')

#Plot end_to_end result
data = result_end_to_end_df
x = data['N']/data['P']
y = data['N_del/N']
# Raw string: '\l' is not a valid escape sequence in a normal string literal.
plt.scatter(x, y, c='k', label=r'Fully trained network ($\lambda=0$)', marker='x')

plt.xlabel('N/P')
plt.ylabel(r'$N_{\Delta}/N$')

plt.yscale('log')
plt.xscale('log')
# Dotted reference at N_del = N.
plt.axhline(1, color='k',ls=':')
plt.legend()

In [None]:
# N_del/h against h/P (log-log): SVM runs coloured by step, end-to-end runs as black crosses.
plt.figure(figsize=(9,6))
data = result_df

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['h']/data['P']
y = data['N_del/h']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7, label=r'Final layer SVM ($\lambda=1\times10^{-13}$)')
plt.colorbar(label='Training steps')

#Plot end_to_end result
data = result_end_to_end_df
x = data['h']/data['P']
y = data['N_del/h']
# Raw string: '\l' is not a valid escape sequence in a normal string literal.
plt.scatter(x, y, c='k', label=r'Fully trained network ($\lambda=0$)', marker='x')

plt.xlabel('h/P')
plt.ylabel(r'$N_{\Delta}/h$')

plt.yscale('log')
plt.xscale('log')
# Dotted reference at N_del = h.
plt.axhline(1, color='k',ls=':')
plt.legend()

In [None]:
# Train loss against N/P (log-log): SVM runs coloured by step, end-to-end runs as black crosses.
plt.figure(figsize=(9,6))
data = result_df

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['N']/data['P']
y = data['train_loss']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7, label=r'Final layer SVM ($\lambda=1\times10^{-13}$)')
plt.colorbar(label='Training steps')

#Plot end_to_end result
data = result_end_to_end_df
x = data['N']/data['P']
y = data['train_loss']
# Raw string: '\l' is not a valid escape sequence in a normal string literal.
plt.scatter(x, y, c='k', label=r'Fully trained network ($\lambda=0$)', marker='x')

plt.xlabel('N/P')
plt.ylabel(r'Train $\mathcal{L}$')

plt.yscale('log')
plt.xscale('log')
plt.legend()

In [None]:
# Train loss against h/P (log-log): SVM runs coloured by step, end-to-end runs as black crosses.
plt.figure(figsize=(9,6))
data = result_df

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['h']/data['P']
y = data['train_loss']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7, label=r'Final layer SVM ($\lambda=1\times10^{-13}$)')
plt.colorbar(label='Training steps')

#Plot end_to_end result
data = result_end_to_end_df
x = data['h']/data['P']
y = data['train_loss']
# Raw string: '\l' is not a valid escape sequence in a normal string literal.
plt.scatter(x, y, c='k', label=r'Fully trained network ($\lambda=0$)', marker='x')

plt.xlabel('h/P')
plt.ylabel(r'Train $\mathcal{L}$')

plt.yscale('log')
plt.xscale('log')
plt.legend()

In [None]:
# Figure: test loss against N/P on log-log axes.
# Rows of `result_df` are coloured by their `step` value (shared log colour norm),
# and the rows of `result_end_to_end_df` are overlaid as black crosses.
plt.figure(figsize=(9,6))
data = result_df

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['N']/data['P']
y = data['test_loss']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7, label=r'Final layer SVM ($\lambda=1\times10^{-13}$)')
plt.colorbar(label='Training steps')

#Plot end_to_end result
data = result_end_to_end_df
x = data['N']/data['P']
y = data['test_loss']
# Raw string: '\l' is not a valid escape sequence, so the non-raw literal triggered an
# invalid-escape warning. The resulting label text is unchanged.
plt.scatter(x, y, c='k', label=r'Fully trained network ($\lambda=0$)', marker='x')

plt.xlabel('N/P')
plt.ylabel(r'Test $\mathcal{L}$')

plt.yscale('log')
plt.xscale('log')
plt.legend()
# plt.ylim(0, 1)
# `data` is `result_end_to_end_df` at this point, so the horizontal line marks the
# smallest test loss among the end-to-end rows only.
plt.axhline(np.min(data['test_loss']))

In [None]:
# Figure: test loss against h/P on log-log axes.
# Rows of `result_df` are coloured by their `step` value (shared log colour norm),
# and the rows of `result_end_to_end_df` are overlaid as black crosses.
plt.figure(figsize=(9,6))
data = result_df

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['h']/data['P']
y = data['test_loss']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7, label=r'Final layer SVM ($\lambda=1\times10^{-13}$)')
plt.colorbar(label='Training steps')

#Plot end_to_end result
data = result_end_to_end_df
x = data['h']/data['P']
y = data['test_loss']
# Raw string: '\l' is not a valid escape sequence, so the non-raw literal triggered an
# invalid-escape warning. The resulting label text is unchanged.
plt.scatter(x, y, c='k', label=r'Fully trained network ($\lambda=0$)', marker='x')

plt.xlabel('h/P')
plt.ylabel(r'Test $\mathcal{L}$')

plt.yscale('log')
plt.xscale('log')
plt.legend()
# plt.ylim(0, 1)
# `data` is `result_end_to_end_df` at this point, so the horizontal line marks the
# smallest test loss among the end-to-end rows only.
plt.axhline(np.min(data['test_loss']))

Why does the threshold $N_\Delta/N = 1$ persist even throughout training?
- Maybe it doesn't, but the change in $N_{\mathrm{eff}}$ is linear rather than exponential, so it doesn't show up on the log-log plots?

In [None]:
# End-to-end rows only: N_del against P, each divided by N_tilde, drawn on linear
# axes and cropped to y in [0, 0.1], x in [1e-3, 3].
data = result_end_to_end_df
x, y = (data[column] / data['N_tilde'] for column in ('P', 'N_del'))
plt.scatter(x, y, c='k')
# Log axes are left disabled here:
# plt.xscale('log')
# plt.yscale('log')
plt.ylim(0, .1)
plt.xlim(1e-3, 3)

In [None]:
# Figure: N_del against N on log-log axes.
# Rows of `result_df` are coloured by their `step` value (shared log colour norm),
# and the rows of `result_end_to_end_df` are overlaid as black crosses.
plt.figure(figsize=(9,6))
data = result_df

#invisible plot to set the limits correctly because matplotlib gets confused with log scale scatters
x = data['N']
y = data['N_del']

plt.plot(x, y, color='none')
plt.scatter(x, y, c=data['step'], cmap=cmap, norm=norm, alpha=.7, label=r'Final layer SVM ($\lambda=1\times10^{-13}$)')
plt.colorbar(label='Training steps')

#Plot end_to_end result
data = result_end_to_end_df
x = data['N']
y = data['N_del']
# Raw string: '\l' is not a valid escape sequence, so the non-raw literal triggered an
# invalid-escape warning. The resulting label text is unchanged.
plt.scatter(x, y, c='k', label=r'Fully trained network ($\lambda=0$)', marker='x')

plt.xlabel('N')
plt.ylabel(r'$N_{\Delta}$')

plt.yscale('log')
plt.xscale('log')
plt.legend()

# Hessian

In [None]:
# Split the results by training step: `untrained` holds the rows recorded at the
# smallest `step` value, `trained` the rows recorded at the largest.
untrained = result_df[result_df['step'] == result_df['step'].min()]
trained = result_df[result_df['step'] == result_df['step'].max()]

In [None]:
# Train loss against the column named by `x`, for the first-step (`untrained`) and
# last-step (`trained`) rows, on log-log axes. For every recorded step, the row with
# the largest N whose train loss is still above 1e-2 is marked in black.
x = 'h'
y = 'train_loss'

plt.scatter(untrained[x], untrained[y])
plt.scatter(trained[x], trained[y])

for step in sorted(result_df['step'].unique()):
    df = result_df.query("step == @step")
    # Rows at this step whose train loss has not (numerically) reached zero.
    unfit = df.query('(train_loss > 1e-2)')
    if unfit.empty:
        # Every run at this step already fits the training set; nothing to mark.
        continue
    # Largest N that still has non-zero train loss.
    row = unfit.sort_values('N').iloc[-1]
    plt.scatter(row[x], row[y], c='k')
    
plt.yscale('log')
plt.xscale('log')

# Label the axis with the column actually plotted, so it cannot drift out of sync with `x`.
plt.xlabel(x)
plt.ylabel('Train Loss')

In [None]:
# Train loss of the `untrained` rows against the ratio h / h_star.
plt.scatter(untrained['h'] / untrained['h_star'], untrained['train_loss'])

In [None]:
def symlog(x, thresh):
    """Signed, log10-compressed version of ``x`` with a linear core.

    Elements with ``|x| < thresh`` are returned unchanged. Elements with
    ``|x| >= thresh`` map to ``sign(x) * (thresh + log10(|x| / thresh))``,
    which is continuous at ``|x| == thresh``. ``symexp`` undoes this mapping.

    Parameters
    ----------
    x : array_like
        Values to transform.
    thresh : float
        Half-width of the linear region.

    Returns
    -------
    numpy.ndarray
        Transformed values, same shape as ``x``.
    """
    a = np.sign(x)
    x_ = np.abs(x)
    # np.where evaluates both branches for every element, so clamp the log argument
    # at `thresh`: zeros and other sub-threshold values would otherwise raise a
    # divide-by-zero RuntimeWarning for a result that is discarded anyway.
    # Elements with |x| >= thresh are unaffected by the clamp.
    log_part = np.log10(np.maximum(x_, thresh)/thresh)
    return np.where(x_ < thresh, x, a*(thresh + log_part))

def symexp(x, thresh):
    """Undo ``symlog``: expand the log-compressed tails back to linear values.

    Elements with ``|x| < thresh`` are returned unchanged; all others map to
    ``sign(x) * 10**(|x| - thresh) * thresh``.
    """
    magnitude = np.abs(x)
    expanded = np.sign(x)*np.power(10, magnitude - thresh)*thresh
    return np.where(magnitude < thresh, x, expanded)

In [None]:
def sqrtbp(x):
    """Signed square root: ``sign(x) * sqrt(|x|)``, applied elementwise."""
    return np.sign(x)*np.sqrt(np.abs(x))

In [None]:
import seaborn as sns

In [None]:
from matplotlib.scale import register_scale
# NOTE(review): matplotlib documents `register_scale(scale_class)` as taking the scale
# class to register. It is called with no argument here, which presumably raises a
# TypeError and stops a Restart & Run All. Confirm, then either pass the intended
# scale class or remove this cell.
register_scale()

In [None]:
# Hessian spectrum of the first-step (`untrained`) rows with 0.1 < h/h_star < 1.1,
# sorted by h and thinned to every third row. One KDE curve of sqrt(eigenvalue) is
# drawn per row, coloured by h/h_star.
data = untrained.query('h/h_star > 1e-1').query('h/h_star < 1.1').sort_values('h')[::3]
# Half-width of the linear region used by symlog/symexp below.
thresh = 1e-5


vals = data.eval('h/h_star')
# NOTE(review): this rescales (and then caps) the shared module-level `norm`, which
# the scatter cells above also pass as their colour norm - re-running those cells
# afterwards will use these limits. Consider a dedicated norm for this figure.
norm.autoscale(vals.values)
norm.vmax = 1
# norm.vmin = 1e-1
sm = cm.ScalarMappable(norm=norm, cmap=cmap)

for idx, row in data.iterrows():
    # Clip negative eigenvalues to zero before taking the square root.
    eigs = np.sqrt(np.maximum(row['eigs0'], 0))
    # The density is estimated in symlog-compressed coordinates...
    # NOTE(review): `bw` is the older seaborn keyword; newer releases use
    # `bw_method`/`bw_adjust` - confirm against the installed version.
    ax = sns.kdeplot(symlog(eigs, thresh), bw=.01, color=cmap(norm(vals[idx])), alpha=.8)
    # ...then the x data of the last line on the axes (presumably the curve just drawn)
    # is mapped back to linear units with symexp.
    line = ax.get_lines()[-1]
    x, y = line.get_data()
    line.set_data(symexp(x, thresh), y)

plt.ylim(0, 2.5)
plt.colorbar(sm, label=r'$h/h^*$')
plt.xlim(0, 500)
# NOTE(review): `linthreshx` / `linthreshy` are the older matplotlib keyword names;
# newer releases expect `linthresh` - confirm against the installed version.
plt.xscale('symlog', linthreshx=1e-2)
             
plt.xlabel(r'$\sqrt{\mu}$')
plt.ylabel(r'$P\left(\sqrt{\mu}\right)$')
plt.title("Hessian Spectrum (Random Features)")
plt.yscale('symlog', linthreshy=1e-3)

In [None]:
# Hessian spectrum of the last-step (`trained`) rows with 0.5 < h/h_star < 5, sorted
# by h. Only strictly positive eigenvalues enter the KDE; the fraction of
# non-positive ones (`gamma`) is added back as two extra points near x = 0.
data = trained.query('h/h_star > 5e-1').query('h/h_star < 5').sort_values('h')#[::2]
# Half-width of the linear region used by symlog/symexp below.
thresh = 1e-3


vals = data.eval('h/h_star')
# NOTE(review): this rescales the shared module-level `norm`, which the scatter cells
# above also pass as their colour norm.
norm.autoscale(vals.values)
# norm.vmax = 2
# norm.vmin = 1e-1
sm = cm.ScalarMappable(norm=norm, cmap=cmap)

for idx, row in data.iterrows():
    # assumes row['eigs0'] is a NumPy array (boolean-mask indexing) - TODO confirm
    eigs = row['eigs0']
    pos_eigs = np.sqrt(eigs[eigs > 0])
    # Bandwidth shrinks as 1/sqrt(number of positive eigenvalues).
    # NOTE(review): `bw` is the older seaborn keyword - confirm against the installed version.
    ax = sns.kdeplot(symlog(pos_eigs, thresh), bw=.6/np.sqrt(len(pos_eigs)), color=cmap(norm(vals[idx])), alpha=.8)
    line = ax.get_lines()[-1]
    x, y = line.get_data()
    # Map the KDE abscissa back from symlog coordinates to linear units.
    x = symexp(x, thresh)
    # Prepend two points at x = 0 and x = thresh/100 ...
    x = np.concatenate(([0, thresh/100], x))
    # gamma: fraction of eigenvalues that are not strictly positive.
    gamma = float(len(eigs) - len(pos_eigs))/len(eigs)
    # ... with heights gamma and 1, and scale the KDE of the positive part by (1 - gamma).
    y = np.concatenate(([gamma, 1], (1-gamma)*y))
    line.set_data(x, y)

plt.ylim(0, 2.5)
plt.colorbar(sm, label=r'$h/h^*$')
plt.xlim(0, 500)
# NOTE(review): `linthreshx` / `linthreshy` are the older matplotlib keyword names;
# newer releases expect `linthresh` - confirm against the installed version.
plt.xscale('symlog', linthreshx=1e-3)
             
plt.xlabel(r'$\sqrt{\mu}$')
plt.ylabel(r'$P\left(\sqrt{\mu}\right)$')
plt.title("Hessian Spectrum (Trained Features)")
plt.yscale('symlog', linthreshy=1e-3)
# plt.ylim(1e-1, 10)

In [None]:
# Bare expression: shows matplotlib's built-in `brg` colormap (used by the next figure) as the cell output.
cm.brg

In [None]:
# Hessian spectrum of the last-step (`trained`) rows with 0.5 < h/h_star < 5, sorted
# by h and thinned to every second row. Eigenvalues go through the signed square
# root (`sqrtbp`), so negative ones stay on the negative side of the axis. One KDE
# curve is drawn per row, coloured by h/h_star with the `brg` colormap.
data = trained.query('h/h_star > 5e-1').query('h/h_star < 5').sort_values('h')[::2]
# Half-width of the linear region used by symlog/symexp below.
thresh = 1e-3


vals = data.eval('h/h_star')
# NOTE(review): this rescales the shared module-level `norm`, which the scatter cells
# above also pass as their colour norm.
norm.autoscale(vals.values)
# norm.vmax = 2
# norm.vmin = 1e-1
sm = cm.ScalarMappable(norm=norm, cmap=cm.brg)

for idx, row in data.iterrows():
    eigs = sqrtbp(row['eigs0'])
    # Report the KDE bandwidth of *this* row. The print used to run before `eigs` was
    # assigned, so it showed the previous row's value (or a leftover global from an
    # earlier cell on the first iteration).
    print(1/np.sqrt(len(eigs)),)
    # Density is estimated in symlog-compressed coordinates, then the curve's x data
    # is mapped back to linear units with symexp.
    # NOTE(review): `bw` is the older seaborn keyword - confirm against the installed version.
    ax = sns.kdeplot(symlog(eigs, thresh), bw=1/np.sqrt(len(eigs)), color=cm.brg(norm(vals[idx])), alpha=.8)
    line = ax.get_lines()[-1]
    x, y = line.get_data()
    line.set_data(symexp(x, thresh), y)

plt.ylim(0, 2)
plt.colorbar(sm, label=r'$h/h^*$')
plt.xlim(-10*thresh, 500)
# NOTE(review): `linthreshx` is the older matplotlib keyword name; newer releases
# expect `linthresh` - confirm against the installed version.
plt.xscale('symlog', linthreshx=100*thresh)

plt.xlabel(r'$\sqrt{\mu}$')
plt.ylabel(r'$P\left(\sqrt{\mu}\right)$')
plt.title("Hessian Spectrum (Trained Features)")