Notebook for checking what's up with the weird effect of position code.

I will repeatedly run the ABBA experiment without position codes, with sinusoidal position codes, with one-hot position codes, and with the strange case where I only add one type of position code.

In [1]:
import sys
sys.path.append('..')
import wandb
import torch
from trans_inf_sweep import main
import random
import numpy as np
import pandas as pd
import seaborn as sns
from configs.working_config_for_ABBA import config
from utils import dotdict as dd
import os

# Number of repetitions per experimental condition; the cells below loop over range(n_runs).
n_runs = 2

In [2]:
# Turn off the log_to_wandb flag on the shared config for all runs in this notebook.
config.log.log_to_wandb = False

# Folder where each run's metrics CSV is cached; later cells look for files here.
# NOTE: the 'postion' spelling is intentionally left unchanged so that results
# saved by earlier sessions are still found.
results_folder = '../results/ABBA/postion_code_effect'

# exist_ok=True replaces the check-then-create pattern: it is race-free and
# makes the cell safe to re-run.
os.makedirs(results_folder, exist_ok=True)

In [None]:
# Sweep over every (position encoding, randomization scheme) pair, n_runs times each.
# Each run's metrics are cached as a CSV in results_folder, so re-running this cell
# only trains the combinations whose file is still missing and loads the rest.
all_metrics = []
for pos_encoding in ['sinusoidal', 'onehot']:
    for pos_emb_randomization in ['per_batch', 'only_first', 'per_sequence']:
        config.model.position_encoding = pos_encoding
        config.model.pos_emb_randomization = pos_emb_randomization
        for i in range(n_runs):
            fn = f'{results_folder}/ABBA_{pos_encoding}_{pos_emb_randomization}_{i}.csv'
            # Label every frame with the condition that is actually being run.
            # (Previously these were hard-coded to 'sinusoidal' / 'per_batch', which
            # made all conditions indistinguishable downstream.)
            labels = dict(
                run=i,
                position_encoding=pos_encoding,
                pos_emb_randomization=pos_emb_randomization,
            )
            if os.path.exists(fn):
                # Reuse the cached run instead of dropping it from all_metrics.
                # Labels are re-derived from the loop variables because older
                # cache files may carry the hard-coded labels.
                print(f'File {fn} exists, loading cached results')
                all_metrics.append(pd.read_csv(fn, index_col=0).assign(**labels))
                continue
            print(f'Run {i} with {pos_encoding} and {pos_emb_randomization}')
            metrics = main(config, seq_type='ABBA')
            # Keep only the accuracy / loss series returned by main().
            metrics_df = pd.DataFrame.from_dict(
                {k: v for k, v in metrics.items() if 'accuracy' in k or 'loss' in k}
            ).assign(**labels)
            all_metrics.append(metrics_df)
            metrics_df.to_csv(fn)

# pd.concat raises on an empty list (e.g. n_runs == 0), so fall back to an empty frame.
all_metrics_df = pd.concat(all_metrics) if all_metrics else pd.DataFrame()

Run 0 with sinusoidal and per_sequence
iteration 0, loss 1.3013031482696533
holdout loss: 1.278235673904419, holdout accuracy: 0.4375
iteration 100, loss 1.2209001779556274
holdout loss: 1.018172025680542, holdout accuracy: 0.5390625
iteration 200, loss 1.1950974464416504
holdout loss: 1.0202817916870117, holdout accuracy: 0.5
iteration 300, loss 1.0770325660705566
holdout loss: 1.0210286378860474, holdout accuracy: 0.515625
iteration 400, loss 1.0813963413238525
holdout loss: 1.016125202178955, holdout accuracy: 0.5078125
iteration 500, loss 1.0078098773956299
holdout loss: 0.9747454524040222, holdout accuracy: 0.578125
iteration 600, loss 1.0136489868164062
holdout loss: 1.005834698677063, holdout accuracy: 0.5234375
iteration 700, loss 1.0542716979980469
holdout loss: 0.9838448762893677, holdout accuracy: 0.5546875
iteration 800, loss 1.0190120935440063
holdout loss: 1.0044915676116943, holdout accuracy: 0.5078125
iteration 900, loss 1.0354106426239014
holdout loss: 0.99691170454025

In [None]:
# Load all cached metrics from the per-run CSV files.
# The condition labels are re-derived from the filename components rather than
# trusted from the file contents: cache files written by the sweep cell may all
# carry the same hard-coded 'sinusoidal' / 'per_batch' labels.
all_metrics = []
for pos_encoding in ['sinusoidal', 'onehot']:
    for pos_emb_randomization in ['per_batch', 'only_first', 'per_sequence']:
        for i in range(n_runs):
            fn = f'{results_folder}/ABBA_{pos_encoding}_{pos_emb_randomization}_{i}.csv'
            if not os.path.exists(fn):
                print(f'File {fn} does not exist, skipping')
                continue
            metrics_df = pd.read_csv(fn, index_col=0).assign(
                run=i,
                position_encoding=pos_encoding,
                pos_emb_randomization=pos_emb_randomization,
            )
            all_metrics.append(metrics_df)

In [None]:
# Also add the condition where there's no position encoding at all.
# NOTE: this flag stays False on the shared config for anything run afterwards
# in this kernel, including a re-run of the sweep cell.
config.model.add_position_code = False

no_position_metrics = []
for i in range(n_runs):
    fn = f'{results_folder}/ABBA_no_position_code_{i}.csv'
    if os.path.exists(fn):
        # Load the cached run so this condition is still part of all_metrics
        # (skipping it would silently drop it from the analysis).
        print(f'File {fn} exists, loading cached results')
        metrics_df = pd.read_csv(fn, index_col=0)
    else:
        print(f'Run {i} with no position code')
        metrics = main(config, seq_type='ABBA')
        # Keep only the accuracy / loss series returned by main().
        metrics_df = pd.DataFrame.from_dict(
            {k: v for k, v in metrics.items() if 'accuracy' in k or 'loss' in k}
        ).assign(run=i, position_encoding='no_position_code', pos_emb_randomization='no_position_code')
        metrics_df.to_csv(fn)
    no_position_metrics.append(metrics_df)

# Replace any no-position-code frames left over from a previous execution of this
# cell, so re-running it does not duplicate runs in all_metrics.
all_metrics = [
    df for df in all_metrics
    if not df['position_encoding'].eq('no_position_code').all()
] + no_position_metrics

In [None]:
# Rebuild the combined frame from all_metrics so it reflects everything collected
# so far (loaded runs and the no-position-code runs), not an earlier snapshot.
# Each per-run frame only holds accuracy / loss columns plus labels, so the x-axis
# column is taken from the row index of each run.
# NOTE(review): this assumes row order corresponds to successive logged
# evaluations -- confirm against what main() returns.
per_run_frames = [
    df if 'epoch' in df.columns else df.rename_axis('epoch').reset_index()
    for df in all_metrics
]
all_metrics_df = pd.concat(per_run_frames, ignore_index=True)

# make a column that's the interaction between position encoding and position emb randomization
all_metrics_df['interaction'] = all_metrics_df['position_encoding'] + '_' + all_metrics_df['pos_emb_randomization']
# seaborn lineplot of the holdout accuracy, averaged over runs, with interaction as hue
sns.lineplot(data=all_metrics_df, x='epoch', y='holdout_accuracy', hue='interaction')