# Visualize results

In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch

from factorization.config import SAVE_DIR
CONFIG_FILE = SAVE_DIR / "config.jsonl"

In [2]:
configs_db = pd.read_json('results/config.jsonl', lines=True).dropna()

In [3]:
configs_db.head()

Unnamed: 0,vocab_size,seq_length,sparsity_index,nb_data,batch_size,nb_epochs,lr,emb_dim,nb_emb,ffn_dim,ffn_bias,ffn_dropout,activation,seed,save_weights,interactive,id
0,2,12,5,2048,2048,1000,0.01,2,3,8,True,0,gelu,0,True,False,248f59214668466c8f6764056d65fd20
1,2,12,5,2048,2048,1000,0.01,2,3,8,True,0,gelu,18,True,False,68f846c1e78c4d4ca219c21e93f12336
2,2,12,5,2048,2048,1000,0.01,2,3,8,True,0,gelu,1,True,False,57e7f4b008644286b60fa65a4aeefa17
3,2,12,5,2048,2048,1000,0.01,2,3,8,True,0,gelu,3,True,False,7884d41941e44b73ae1fe492cefb8581
4,2,12,5,2048,2048,1000,0.01,2,3,8,True,0,gelu,13,True,False,d688e069ed4c41f7ab3b198573f68050


In [4]:
keys = ['batch_size', 'lr', 'ffn_dim', 'seed', 'id']

data = pd.concat([
    pd.DataFrame(torch.stack([
        np.load(SAVE_DIR / experience.id / 'accs.pkl', allow_pickle=True),
        np.load(SAVE_DIR / experience.id / 'test_accs.pkl', allow_pickle=True),
        np.load(SAVE_DIR / experience.id / 'losses.pkl', allow_pickle=True),
        np.load(SAVE_DIR / experience.id / 'test_losses.pkl', allow_pickle=True),
        ]).T,
        columns=['acc', 'test_acc', 'loss', 'test_loss'],
    ).assign(**{key: getattr(experience, key) for key in keys} | {'epoch': range(1, 1001)})
    for experience in configs_db.itertuples() if Path(SAVE_DIR / experience.id).exists()
]).reset_index(drop=True)

In [5]:
data.tail()

Unnamed: 0,acc,test_acc,loss,test_loss,batch_size,lr,ffn_dim,seed,id,epoch
2319995,0.523438,0.48584,0.690951,0.697178,32,0.0001,128,29,c72bdf341de9470fa090581830db0ec6,996
2319996,0.524414,0.490723,0.690948,0.697299,32,0.0001,128,29,c72bdf341de9470fa090581830db0ec6,997
2319997,0.524902,0.486816,0.690938,0.697221,32,0.0001,128,29,c72bdf341de9470fa090581830db0ec6,998
2319998,0.523926,0.491211,0.690948,0.697313,32,0.0001,128,29,c72bdf341de9470fa090581830db0ec6,999
2319999,0.524902,0.48877,0.691008,0.697149,32,0.0001,128,29,c72bdf341de9470fa090581830db0ec6,1000


In [6]:
final_data = data[data.epoch == 1000].reset_index(drop=True)
final_data['success'] = final_data['test_acc'] > .99
final_data['diff'] = final_data['test_acc'] - final_data['acc']

best_data = data.groupby(['batch_size', 'ffn_dim', 'lr'])['test_acc'].max()

In [7]:
print(final_data.drop(columns=['id']).groupby(['batch_size', 'ffn_dim', 'lr'])['diff'].max())
print(final_data.drop(columns=['id']).groupby(['batch_size', 'ffn_dim', 'lr'])['success'].mean())

batch_size  ffn_dim  lr    
32          8        0.0001   -0.022461
                     0.0010    0.019531
                     0.0100    0.020020
            16       0.0001    0.024902
                     0.0010    0.012695
                     0.0100    0.017578
            32       0.0001    0.010742
                     0.0010    0.026855
                     0.0100    0.017090
            128      0.0001    0.028320
                     0.0010    0.019043
                     0.0100    0.021973
2048        8        0.0001    0.029785
                     0.0010    0.019531
                     0.0100    0.028809
            16       0.0001    0.029785
                     0.0010   -0.015625
                     0.0100    0.013184
            32       0.0001   -0.011230
                     0.0010    0.010254
                     0.0100    0.027832
            128      0.0001    0.034180
                     0.0010    0.015625
                     0.0100    0.023438
Name: diff, 

In [8]:
data.query('test_acc > 0.99')[['epoch', 'acc', 'test_acc', 'batch_size', 'lr', 'ffn_dim', 'seed', 'id']]

Unnamed: 0,epoch,acc,test_acc,batch_size,lr,ffn_dim,seed,id
64245,246,0.994629,0.993652,32,0.01,8,11,8871fe852c6e4534ad217a20358bc06e
64250,251,0.989746,0.990234,32,0.01,8,11,8871fe852c6e4534ad217a20358bc06e
64252,253,0.990723,0.992188,32,0.01,8,11,8871fe852c6e4534ad217a20358bc06e
64254,255,0.994629,0.993652,32,0.01,8,11,8871fe852c6e4534ad217a20358bc06e
64255,256,0.994629,0.991699,32,0.01,8,11,8871fe852c6e4534ad217a20358bc06e
...,...,...,...,...,...,...,...,...
2309468,469,1.000000,1.000000,32,0.01,32,71,253fd896cd2144898c0df067d45ffb31
2309469,470,1.000000,1.000000,32,0.01,32,71,253fd896cd2144898c0df067d45ffb31
2309470,471,1.000000,1.000000,32,0.01,32,71,253fd896cd2144898c0df067d45ffb31
2309471,472,1.000000,1.000000,32,0.01,32,71,253fd896cd2144898c0df067d45ffb31


From the previous print, we see that we never overfit the train set, yet we do sometime have an accuracy collapse after having reached 100 \% accuracy.