In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from itertools import product
from dataclasses import dataclass
from collections import namedtuple, defaultdict
from markov import *
from base_models import *
from pos_encoder import *
from causal_graph import *
from config import *
from train import *
import plot
from util import memory_recall_probe, feedforward_probe
import seaborn as sns
import torch.utils.benchmark as benchmark
from tqdm.notebook import tqdm, trange
import pickle

from head_view import *

%load_ext autoreload
%autoreload 2

In [9]:
def set_configs(seq_len, vocab_size, batch_size, num_epochs=60_000, emb_dim=None, pos_enc='rpe', flash=False):
    if emb_dim is None:
        emb_dim = 6*vocab_size
    
    config = Config(
        emb_dim=emb_dim,
        num_layers=2,
        num_heads=(1,1),
        identity_query=False,
        seq_len=seq_len,
        vocab_size=vocab_size,
        batch_size=batch_size,
        num_epochs=num_epochs,
        eval_iter=500,
        pos_enc=pos_enc,
        pos_max_len=seq_len,
        get_attn=100,
        mlp=(False,False),
        activation=(False,False),
        flash=flash,
        ff_dim=emb_dim,
        layer_norm=False,
        ngram=2,
        learning_rate=2e-4,
        task_name='icl-mc',
        scheduler=False,
    )

    sampler_config = MarkovSamplerConfig(seq_len=seq_len, vocab_size=vocab_size, batch_size=batch_size, order=1, task_name='icl-mc')
    return config, sampler_config

def run_exp(seq_len, voc_size, pos_enc, num_epochs=100_000):
    batch_size = 64
    flash = False
    if pos_enc in ["abs", "rotary"]:
        flash = True

    config, sampler_config = set_configs(seq_len, voc_size, batch_size, pos_enc=pos_enc, flash=flash)
    model = Transformer(config).to(config.device)
    train_model_with_plot(model, config, sampler_config)

#### Experiments

In [None]:
from itertools import product
seq_list = [50, 100, 200]
voc_list = [3, 5, 10]
pos_list = ["rpe", "abs", "rotary"]

for seq_len, voc_size, pos_enc in list(product(seq_list, voc_list, pos_list)):
    print(f"Sequence Length: {seq_len}, vocabulary size: {voc_size}, Positional encoding: {pos_enc}.")
    run_exp(seq_len, voc_size, pos_enc)

Sequence Length: 50, vocabulary size: 3, Positional encoding: rpe.


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss plot saved at  loss_plots/s50p_rpe_l2h1_1v3icl-mc_20250206_0036.png


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s50p_rpe_l2h1v3_L0H0icl-mc_20250206_0038.gif
Folder 'attns' and its contents removed.


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s50p_rpe_l2h1v3_L1H0icl-mc_20250206_0039.gif
Folder 'attns' and its contents removed.
Sequence Length: 50, vocabulary size: 3, Positional encoding: abs.


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss plot saved at  loss_plots/s50p_abs_l2h1_1v3icl-mc_20250206_0044.png


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s50p_abs_l2h1v3_L0H0icl-mc_20250206_0046.gif
Folder 'attns' and its contents removed.


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s50p_abs_l2h1v3_L1H0icl-mc_20250206_0048.gif
Folder 'attns' and its contents removed.
Sequence Length: 50, vocabulary size: 3, Positional encoding: rotary.


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss plot saved at  loss_plots/s50p_rotary_l2h1_1v3icl-mc_20250206_0055.png


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s50p_rotary_l2h1v3_L0H0icl-mc_20250206_0057.gif
Folder 'attns' and its contents removed.


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s50p_rotary_l2h1v3_L1H0icl-mc_20250206_0058.gif
Folder 'attns' and its contents removed.
Sequence Length: 50, vocabulary size: 5, Positional encoding: rpe.


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss plot saved at  loss_plots/s50p_rpe_l2h1_1v5icl-mc_20250206_0106.png


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s50p_rpe_l2h1v5_L0H0icl-mc_20250206_0108.gif
Folder 'attns' and its contents removed.


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s50p_rpe_l2h1v5_L1H0icl-mc_20250206_0109.gif
Folder 'attns' and its contents removed.
Sequence Length: 50, vocabulary size: 5, Positional encoding: abs.


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss plot saved at  loss_plots/s50p_abs_l2h1_1v5icl-mc_20250206_0114.png


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s50p_abs_l2h1v5_L0H0icl-mc_20250206_0116.gif
Folder 'attns' and its contents removed.


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s50p_abs_l2h1v5_L1H0icl-mc_20250206_0118.gif
Folder 'attns' and its contents removed.
Sequence Length: 50, vocabulary size: 5, Positional encoding: rotary.


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss plot saved at  loss_plots/s50p_rotary_l2h1_1v5icl-mc_20250206_0124.png


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s50p_rotary_l2h1v5_L0H0icl-mc_20250206_0126.gif
Folder 'attns' and its contents removed.


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s50p_rotary_l2h1v5_L1H0icl-mc_20250206_0128.gif
Folder 'attns' and its contents removed.
Sequence Length: 50, vocabulary size: 10, Positional encoding: rpe.


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss plot saved at  loss_plots/s50p_rpe_l2h1_1v10icl-mc_20250206_0136.png


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s50p_rpe_l2h1v10_L0H0icl-mc_20250206_0137.gif
Folder 'attns' and its contents removed.


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s50p_rpe_l2h1v10_L1H0icl-mc_20250206_0139.gif
Folder 'attns' and its contents removed.
Sequence Length: 50, vocabulary size: 10, Positional encoding: abs.


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss plot saved at  loss_plots/s50p_abs_l2h1_1v10icl-mc_20250206_0143.png


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s50p_abs_l2h1v10_L0H0icl-mc_20250206_0145.gif
Folder 'attns' and its contents removed.


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s50p_abs_l2h1v10_L1H0icl-mc_20250206_0147.gif
Folder 'attns' and its contents removed.
Sequence Length: 50, vocabulary size: 10, Positional encoding: rotary.


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss plot saved at  loss_plots/s50p_rotary_l2h1_1v10icl-mc_20250206_0152.png


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s50p_rotary_l2h1v10_L0H0icl-mc_20250206_0154.gif
Folder 'attns' and its contents removed.


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s50p_rotary_l2h1v10_L1H0icl-mc_20250206_0156.gif
Folder 'attns' and its contents removed.
Sequence Length: 100, vocabulary size: 3, Positional encoding: rpe.


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss plot saved at  loss_plots/s100p_rpe_l2h1_1v3icl-mc_20250206_0204.png


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s100p_rpe_l2h1v3_L0H0icl-mc_20250206_0206.gif
Folder 'attns' and its contents removed.


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s100p_rpe_l2h1v3_L1H0icl-mc_20250206_0208.gif
Folder 'attns' and its contents removed.
Sequence Length: 100, vocabulary size: 3, Positional encoding: abs.


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss plot saved at  loss_plots/s100p_abs_l2h1_1v3icl-mc_20250206_0213.png


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s100p_abs_l2h1v3_L0H0icl-mc_20250206_0215.gif
Folder 'attns' and its contents removed.


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s100p_abs_l2h1v3_L1H0icl-mc_20250206_0217.gif
Folder 'attns' and its contents removed.
Sequence Length: 100, vocabulary size: 3, Positional encoding: rotary.


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss plot saved at  loss_plots/s100p_rotary_l2h1_1v3icl-mc_20250206_0223.png


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s100p_rotary_l2h1v3_L0H0icl-mc_20250206_0225.gif
Folder 'attns' and its contents removed.


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s100p_rotary_l2h1v3_L1H0icl-mc_20250206_0227.gif
Folder 'attns' and its contents removed.
Sequence Length: 100, vocabulary size: 5, Positional encoding: rpe.


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss plot saved at  loss_plots/s100p_rpe_l2h1_1v5icl-mc_20250206_0235.png


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s100p_rpe_l2h1v5_L0H0icl-mc_20250206_0237.gif
Folder 'attns' and its contents removed.


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s100p_rpe_l2h1v5_L1H0icl-mc_20250206_0239.gif
Folder 'attns' and its contents removed.
Sequence Length: 100, vocabulary size: 5, Positional encoding: abs.


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss plot saved at  loss_plots/s100p_abs_l2h1_1v5icl-mc_20250206_0244.png


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s100p_abs_l2h1v5_L0H0icl-mc_20250206_0246.gif
Folder 'attns' and its contents removed.


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s100p_abs_l2h1v5_L1H0icl-mc_20250206_0248.gif
Folder 'attns' and its contents removed.
Sequence Length: 100, vocabulary size: 5, Positional encoding: rotary.


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss plot saved at  loss_plots/s100p_rotary_l2h1_1v5icl-mc_20250206_0254.png


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s100p_rotary_l2h1v5_L0H0icl-mc_20250206_0257.gif
Folder 'attns' and its contents removed.


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s100p_rotary_l2h1v5_L1H0icl-mc_20250206_0300.gif
Folder 'attns' and its contents removed.
Sequence Length: 100, vocabulary size: 10, Positional encoding: rpe.


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss plot saved at  loss_plots/s100p_rpe_l2h1_1v10icl-mc_20250206_0307.png


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s100p_rpe_l2h1v10_L0H0icl-mc_20250206_0309.gif
Folder 'attns' and its contents removed.


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s100p_rpe_l2h1v10_L1H0icl-mc_20250206_0311.gif
Folder 'attns' and its contents removed.
Sequence Length: 100, vocabulary size: 10, Positional encoding: abs.


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss plot saved at  loss_plots/s100p_abs_l2h1_1v10icl-mc_20250206_0315.png


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s100p_abs_l2h1v10_L0H0icl-mc_20250206_0317.gif
Folder 'attns' and its contents removed.


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s100p_abs_l2h1v10_L1H0icl-mc_20250206_0319.gif
Folder 'attns' and its contents removed.
Sequence Length: 100, vocabulary size: 10, Positional encoding: rotary.


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss plot saved at  loss_plots/s100p_rotary_l2h1_1v10icl-mc_20250206_0325.png


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s100p_rotary_l2h1v10_L0H0icl-mc_20250206_0332.gif
Folder 'attns' and its contents removed.


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s100p_rotary_l2h1v10_L1H0icl-mc_20250206_0334.gif
Folder 'attns' and its contents removed.
Sequence Length: 200, vocabulary size: 3, Positional encoding: rpe.


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss plot saved at  loss_plots/s200p_rpe_l2h1_1v3icl-mc_20250206_0344.png


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s200p_rpe_l2h1v3_L0H0icl-mc_20250206_0346.gif
Folder 'attns' and its contents removed.


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]

GIF saved at attns_plot/s200p_rpe_l2h1v3_L1H0icl-mc_20250206_0350.gif
Folder 'attns' and its contents removed.
Sequence Length: 200, vocabulary size: 3, Positional encoding: abs.


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss plot saved at  loss_plots/s200p_abs_l2h1_1v3icl-mc_20250206_0355.png


Creating images:   0%|          | 0/600 [00:00<?, ?it/s]