In [None]:
import json
import shutil
import subprocess

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import rc

# Base figure size, in inches, shared by every plotting cell below.
WIDTH = 8.5              # inches (from ICML style file)
HEIGHT = 8.5 / 1.5     # width / 1.5 (note: the golden ratio would be ~1.618)

# Scale both dimensions by the same factor (1.8 / 2 = 0.9).
WIDTH *= 1.8 / 2
HEIGHT *= 1.8 / 2

rc('font', family='serif', size=8)

# Turn on LaTeX text rendering only when a pdflatex executable is on PATH.
# shutil.which performs the lookup in-process: it prints nothing into the
# cell output and does not depend on an external `which` command existing.
usetex = shutil.which('pdflatex') is not None
rc('text', usetex=usetex)
if usetex:
    rc('text.latex', preamble=r'\usepackage{times}')

In [None]:
def load_data(name):
    """Load the JSON-lines results of one run into a DataFrame.

    Parameters
    ----------
    name : str
        Run directory name; the results are read from
        ``<PATH>/ihead_data/{name}/res.jsonl``.

    Returns
    -------
    pandas.DataFrame
        One row per non-blank line of the file, built from the decoded
        JSON objects.

    Raises
    ------
    OSError
        If the results file cannot be opened.
    json.JSONDecodeError
        If a non-blank line is not valid JSON.
    """
    fname = f'<PATH>/ihead_data/{name}/res.jsonl'
    # The context manager closes the file handle as soon as parsing is done
    # instead of leaving it open until garbage collection.
    with open(fname, encoding='utf-8') as f:
        # Blank lines (e.g. a trailing newline at the end of the file) are
        # skipped because json.loads would reject them.
        res = [json.loads(line) for line in f if line.strip()]
    return pd.DataFrame(res)

# Plots for the paper

In [None]:
# Loss curves for the 'wo2only' runs: a single figure, one line per learning rate.
name = 'wo2only'
dim = 64

lrs = [20, 50, 100, 200]

fig, ax = plt.subplots(figsize=(.2 * WIDTH, .2 * HEIGHT))
for lr in lrs:
    # .loc slicing is label-based and inclusive: keeps index labels up to and including 50.
    df = load_data(f"nomo_{name}_{dim}_{lr}").loc[:50]
    ax.plot(df.epoch, df.loss, label=f"$\\eta = {lr}$")

ax.set(
    ylim=(2.4, 5),
    xlabel=r"iteration $t$",
    ylabel=r"${\cal L}(W_t)$",
    title=fr"Training $W_O^2$ only, $d = {dim}$",
)
ax.legend(fontsize=5, loc='upper right', handlelength=1)
fig.savefig(f'eos_figures/wo2only_d{dim}_loss.pdf', pad_inches=0, bbox_inches='tight')


In [None]:
# Margin curves for the 'wo2only' runs: one figure per learning rate.
# NOTE(review): the column read is 'wo1_margins' while the titles say W_O^2 --
# presumably a 0- vs 1-based naming difference; confirm.
name = 'wo2only'
dim = 64
value = 'wo1_margins'

# The sampled columns do not depend on the learning rate, so seed and draw the
# permutation once rather than re-seeding the global RNG on every iteration.
# The permutation (and the RNG state left behind) is the same as before.
# NOTE(review): 65 is presumably the number of margin entries per row -- confirm.
np.random.seed(42)
idxs = np.random.permutation(65)

for lr in [100, 200]:
    fig, ax = plt.subplots(figsize=(.2 * WIDTH, .2 * HEIGHT))

    # .loc slicing is label-based and inclusive: keeps index labels up to and including 50.
    df = load_data(f"nomo_{name}_{dim}_{lr}").loc[:50]
    # Expand the per-row sequences stored in df[value] into one column per entry.
    dff = pd.DataFrame(df[value].values.tolist())

    # Plot the five columns selected by the head of the fixed permutation.
    for i in range(5):
        ax.plot(dff.index, dff[idxs[i]])

    ax.set_xlabel(r'iteration $t$')
    ax.set_ylabel(r'$m_t(x)$')
    ax.set_yticks([0, 5])
    ax.grid(True, axis='y', alpha=.5)
    ax.set_title(rf"$W_O^2$, $d={dim}$, $\eta = {lr}$")
    fig.savefig(f'eos_figures/wo2only_d{dim}_lr{lr}_margins.pdf', pad_inches=0, bbox_inches='tight')


In [None]:
# Loss curves for the 'all_params' runs: one figure per embedding dimension,
# one line per learning rate.
name = 'all_params'

lrs = [2, 5, 10, 20]

for dim in [64, 128]:
    fig, ax = plt.subplots(figsize=(.2 * WIDTH, .2 * HEIGHT))
    for lr in lrs:
        df = load_data(f"nomo_{name}_{dim}_{lr}")
        ax.plot(df.epoch, df.loss, label=f"$\\eta = {lr}$")

    if dim != 128:
        # The d = 64 figure is drawn without x ticks.
        ax.set_xticks([])
    else:
        # Only the d = 128 figure carries the x-axis label and the legend.
        ax.set_xlabel(r'iteration $t$')
        ax.legend(fontsize=5, handlelength=1, ncol=1, loc='lower left')

    ax.set(ylabel=r"${\cal L}(W_t)$", title=fr"Train all, $d = {dim}$")
    fig.savefig(f'eos_figures/trainall_d{dim}_loss.pdf', pad_inches=0, bbox_inches='tight')

    

In [None]:
# Margin curves read from 'wo1_margins' for the 'all_params' run (d = 128, lr = 20).
# NOTE(review): the column is named 'wo1_margins' while the title says W_O^2 --
# presumably a 0- vs 1-based naming difference; confirm.
name = 'all_params'
dim = 128
lr = 20
value = 'wo1_margins'

# Fixed seed so the same five columns are picked on every run.
np.random.seed(42)
idxs = np.random.permutation(65)

fig, ax = plt.subplots(figsize=(.2 * WIDTH, .2 * HEIGHT))

df = load_data(f"nomo_{name}_{dim}_{lr}")
# Expand the per-row sequences stored in df[value] into one column per entry.
dff = pd.DataFrame(df[value].values.tolist())

for i in range(5):
    ax.plot(dff.index, dff[idxs[i]])

# This panel is drawn without x ticks or an x-axis label.
ax.set(
    yticks=[0, 5],
    xticks=[],
    ylabel=r'$m_t(x)$',
    title=f"$W_O^2$, $d={dim}$, $\\eta = {lr}$",
)
ax.grid(True, axis='y', alpha=.5)
fig.savefig('eos_figures/trainall_wo2_margins.pdf', pad_inches=0, bbox_inches='tight')


In [None]:
# Margin curves read from 'wk1_margins' for the 'all_params' run (d = 128, lr = 20).
# NOTE(review): the column is named 'wk1_margins' while the title says W_K^2 --
# presumably a 0- vs 1-based naming difference; confirm.
name = 'all_params'
dim = 128
lr = 20
value = 'wk1_margins'

fig, ax = plt.subplots(figsize=(.2 * WIDTH, .2 * HEIGHT))

df = load_data(f"nomo_{name}_{dim}_{lr}")
# Expand the per-row sequences stored in df[value] into one column per entry.
dff = pd.DataFrame(df[value].values.tolist())

# Unlike the 'wo1_margins' panels, this one plots the first five columns directly.
for i in range(5):
    ax.plot(dff.index, dff[i])

ax.set(
    yticks=[0, 5],
    xlabel=r'iteration $t$',
    ylabel=r'$m_t(x)$',
    title=f"$W_K^2$, $d={dim}$, $\\eta = {lr}$",
)
ax.grid(True, axis='y', alpha=.5)
fig.savefig('eos_figures/trainall_wk2_margins.pdf', pad_inches=0, bbox_inches='tight')
