# Base Models

In [1]:
# ! pip install torch --upgrade

In [2]:
# https://wandb.ai/kilianovski/misha-iml/runs/kn69jraw


checkpoint_path = '../models/transformers/pretrained_1L_dmodel=1024_attnonly=False20240718_200143.pt'
checkpoint_path = '../models/transformers/grokking_prod_120_1_0.1_attnonly_False20240712_133838.pt'


transformer_config = dict(
    d_vocab=512,
    n_layers=1,
    d_model=1024,
    d_head=128,
    n_heads=4,
    d_mlp=256,
    n_ctx=5,
    act_fn="relu",  # gelu?
    normalization_type='LN',
    attn_only=False,
)

In [3]:
from argparse import Namespace
from dataclasses import dataclass, asdict
seed = 2




@dataclass
class TrainParams:
    n_steps: int = int(1e8)
    batch_size: int = 128
    lr: float = 1e-4
    wd: float = 1e-1
    betas: tuple = (0.9, 0.98)
    max_grad_norm: float = 1.0
    num_epochs_X1: int = 500
    num_epochs_X2: int = 500
    prop_orig: float = 0.25
    orig_held_out_frac: float = 0.01
    swap_defs: bool = False # whether to swap the order of the defs
    val_questions: int = 9



# Setup

In [4]:
import sys
sys.path.append('..')

In [5]:
from data import create_datasets, seed_all, DataParams, Tokens, OOCL_Dataset, make_tbl_mask, create_orig_data, yield_data

In [6]:
import logging
import torch
from dataclasses import dataclass, asdict
import numpy as np
import time
import os
from tqdm.auto import tqdm
from pathlib import Path
import itertools
import sys
import random
import torch.nn.functional as F
from torch.utils.data import random_split, TensorDataset, DataLoader, Dataset
import argparse
from transformer_lens import HookedTransformer, HookedTransformerConfig
import wandb
from dotenv import load_dotenv
from sympy import factorint
from itertools import product
from math import prod

import os
import random
import numpy as np
import torch
from tqdm.auto import tqdm

In [7]:
def get_device():
    #return 'cpu'
    if torch.cuda.is_available():
        return "cuda"
    # elif torch.backends.mps.is_available():
    #     return "mps"
    else:
        return "cpu"

In [8]:

def loss_fn(logits, tokens):

    # check whether question or def and compute loss appropriately
    # logit shape: [batch, pos, vocab]
    # token shape: [batch, pos]

    mask = (tokens[:, 3] == 2*DataParams.mod + Tokens.padding)

    def_logits = logits[mask]
    def_tokens = tokens[mask].long()

    q_logits = logits[~mask]
    q_tokens = tokens[~mask].long()

    def_logits = def_logits[:, 1].unsqueeze(1)
    def_tokens = def_tokens[:, 2].unsqueeze(1)
    def_log_probs = def_logits.log_softmax(-1)
    def_correct_log_probs = def_log_probs.gather(-1, def_tokens[..., None])[..., 0]
    
    q_logits = q_logits[:, 2].unsqueeze(1)
    q_tokens = q_tokens[:, 3].unsqueeze(1)
    q_log_probs = q_logits.log_softmax(-1)
    q_correct_log_probs = q_log_probs.gather(-1, q_tokens[..., None])[..., 0]

    return -(def_correct_log_probs.sum() + q_correct_log_probs.sum())/(def_correct_log_probs.shape[0] + q_correct_log_probs.shape[0])


In [9]:
def evaluate(model, val_loader, device):

    correct = 0
    loss = 0.
    total = 0
    batches = 0

    for batch in val_loader:
        inputs = batch[0].to(device)

        labels = inputs[:, -1]

        with torch.no_grad():
            output = model(inputs)
            loss += loss_fn(output, inputs).item()
            correct += (torch.argmax(output[:, -2, :], dim=1) == labels).sum()

        total += inputs.shape[0]
        batches += 1

    acc = correct / total
    loss = loss/batches
    return acc, loss


# Train

In [10]:
# int_by_set = {
#     'DtQ1': [57, 103, 99, 3, 111, 59, 73, 22, 30, 25, 47, 69, 23, 67, 75, 16, 85, 29, 2, 76, 8, 107, 43, 84, 98, 44, 46, 115, 80, 37], 
#     'DfQ2': [50, 72, 118, 20, 93, 10, 52, 14, 83, 6, 28, 15, 34, 48, 114, 104, 88, 13, 91, 54, 112, 58, 102, 95, 21, 24, 19, 94, 35, 109], 
#     'Dt3': [4, 81, 82, 41, 31, 86, 63, 0, 110, 11, 1, 92, 7, 116, 66, 56, 119, 70, 26, 78, 40, 55, 105, 89, 71, 60, 42, 87, 9, 117], 
#     'Df4': [39, 18, 77, 90, 68, 32, 79, 12, 96, 101, 36, 17, 64, 27, 74, 45, 61, 38, 106, 100, 51, 62, 65, 33, 5, 53, 113, 97, 49, 108]
# }


# int_by_set = {
#     'DtQ1': [119, 81, 20, 90, 68, 41, 4, 79, 38, 10, 14, 95, 22, 78, 114, 71, 73, 52, 94, 9, 82, 116, 96, 93, 39, 36, 105, 50, 16, 33], 
#     'DfQ2': [5, 30, 19, 59, 74, 24, 104, 21, 18, 51, 42, 61, 65, 84, 64, 35, 113, 11, 66, 80, 112, 7, 31, 98, 43, 6, 25, 45, 117, 47], 
#     'Dt3': [99, 46, 88, 23, 103, 53, 86, 37, 58, 76, 118, 44, 91, 70, 111, 56, 28, 67, 85, 54, 27, 106, 1, 69, 107, 87, 2, 101, 40, 13], 
#     'Df4': [75, 29, 92, 34, 109, 89, 0, 110, 77, 55, 49, 3, 62, 12, 26, 100, 48, 83, 60, 57, 115, 63, 15, 32, 8, 97, 102, 108, 72, 17]
# }

In [11]:
try:
    from prettytable import PrettyTable
except:
    ! pip install -q prettytable
    from prettytable import PrettyTable

def count_parameters(model):
    from prettytable import PrettyTable

    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params+=params
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

In [12]:
# noisy_labels = {119+n for n in int_by_set['DfQ2']}

# def randomize_noisy_labels(tokens):
#     noisy_label_mask = torch.tensor(list(map(lambda x: any(int(l) in noisy_labels for l in x) , tokens[:, :2])))
#     noisy_label_mask = noisy_label_mask & (tokens[:, 0] != 242)
#     tokens[noisy_label_mask][:, -1] = torch.randint(0, 120*2, size=(len(tokens[noisy_label_mask]),))

In [13]:
def train_iml(args, checkpoint_path=checkpoint_path):
    mod = DataParams.mod


    def get_grad_cos_sims(numbers):
        cos_sims = []
        for n in numbers:
            
            definition = definitions[definitions[:, 1] == (119+n)]
        
            question_mask = (questions_X2[:, 0] == (119+n)) | (questions_X2[:, 1] == (119+n))
            questions = questions_X2[question_mask]
        
            d_grads = get_flat_grad(model, definition)
            q_grads = get_flat_grad(model, questions)
        
            cos_sim = F.cosine_similarity(d_grads, q_grads, dim=0)
        
            cos_sims.append(cos_sim)
        return cos_sims

    # ******* INIT MODEL ***********
    seed = args.seed
    seed_all(args.seed)
    new_transformer_config = transformer_config
    new_transformer_config.update(dict(
        d_vocab=2*mod + 4,  # 3 special tokens + mod vars
    ))
    new_cfg = HookedTransformerConfig(**new_transformer_config)
    new_model = HookedTransformer(new_cfg)
    state_dict = torch.load(checkpoint_path)
    new_model.load_state_dict(state_dict)
    new_model.to(get_device())
    
    model = new_model

    # ********* END MODEL

    # ********* DATA
    seed_all(seed)
    
    
    # divide the integers into 4 equally sized sets
    size = mod // 4
    rem = mod % 4
    
    numbers = list(range(DataParams.mod))
    random.shuffle(numbers)
    
    train_params = TrainParams()
        
    int_by_set = {}
    int_by_set['DtQ1'] = numbers[0:size]
    int_by_set['DfQ2'] = numbers[size:2*size]
    int_by_set['Dt3'] = numbers[2*size:3*size]
    int_by_set['Df4'] = numbers[3*size:mod]
    
    train_sets, test_sets = create_datasets(int_by_set)
    orig_args = make_tbl_mask(mod=DataParams.mod, method='prod', frac_held_out=train_params.orig_held_out_frac)
    
    
    batch_size = train_params.batch_size
    
    # unpack orig_args for use in valid_loader
    
    x_vv, y_vv, z_vv, train_vv, valid_vv = orig_args
    
    device = get_device()
    
    X1_dataset = OOCL_Dataset(train_sets['X1'], create_orig_data, orig_args, train_params.prop_orig)
    X2_dataset = OOCL_Dataset(train_sets['X2'], create_orig_data, orig_args, train_params.prop_orig)
    
    X1_loader = DataLoader(X1_dataset, batch_size=batch_size, shuffle=True)
    X2_loader = DataLoader(X2_dataset, batch_size=batch_size, shuffle=True)
    
    orig_data_valid_loader = yield_data(train_params.batch_size, x_vv, y_vv, z_vv, valid_vv)
    
    test_set_loaders = {}
    
    for s in test_sets:
        test_set_loaders[s] = DataLoader(TensorDataset(test_sets[s].to(dtype=torch.int)), batch_size=train_params.batch_size, shuffle=False)
    
    
    questions_X2 = []
    
    for batch in test_set_loaders['Dt3']:
        questions_X2.append(batch[0])
    
    
    for batch in test_set_loaders['Df4']:
        questions_X2.append(batch[0])
    
    
    questions_X2 = torch.cat(questions_X2, dim=0)
    
    
    definitions = []
    
    for tokens in X2_loader:
        tokens = tokens.squeeze(1)
        tokens = tokens.to(device)
        definitions.append(tokens)
    
    definitions = torch.cat(definitions)
    
    
    def get_flat_grad(model, tokens):
        for p in model.parameters():
            p.requires_grad_(True)
            p.grad = None
        
        logits = model(tokens)
        loss = loss_fn(logits, tokens)
        loss.backward()
    
        grads = []
        for p in model.parameters():
            grads.append(p.grad.detach().flatten())
        grads = torch.cat(grads)
    
        return grads
    
    
    
    def get_metrics(model):
        val_acc_DtQ1, val_loss_DtQ1 = evaluate(model, test_set_loaders['DtQ1'], device)
        val_acc_DfQ2, val_loss_DfQ2 = evaluate(model, test_set_loaders['DfQ2'], device)
        val_acc_Dt3, val_loss_Dt3 = evaluate(model, test_set_loaders['Dt3'], device)
        val_acc_Df4, val_loss_Df4 = evaluate(model, test_set_loaders['Df4'], device)
        with torch.no_grad():
            # logging.info(tokens)
            tokens = next(orig_data_valid_loader)
            tokens = tokens.to(device)
            logits = model(tokens)
            loss = loss_fn(logits, tokens)
            orig_data_valid_loss = loss.item()
        metrics = {
                        # "train/loss": train_loss,
                        "valid_DtQ1/acc": val_acc_DtQ1,
                        "valid_DfQ2/acc": val_acc_DfQ2,
                        "valid_DtQ1/loss": val_loss_DtQ1,
                        "valid_DfQ2/loss": val_loss_DfQ2,
        
                        "valid_Dt3/acc": val_acc_Dt3,
                        "valid_Df4/acc": val_acc_Df4,
                        "valid_Dt3/loss": val_loss_Dt3,
                        "valid_Df4/loss": val_loss_Df4,
        
                        "val/loss": (val_loss_DtQ1+val_loss_DfQ2)/2,
                        "orig_data_valid_loss": orig_data_valid_loss
                    }
        return metrics
    
    
    for tokens in X1_loader:
        tokens = tokens.squeeze(1)
        tokens = tokens.to(device)
        break

    # ********* END DATA
    
    wandb.init(
        project="misha-iml",
        group=args.wandb_group_name,
        name=args.wandb_experiment_name,
        config={
            **asdict(DataParams()),
            **asdict(train_params),
            **transformer_config,
        }
    )
    print(f'seed={args.seed}')
    print('int_by_set')
    print(int_by_set)
    print('loaded from', checkpoint_path)
    count_parameters(model)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=train_params.lr, betas=train_params.betas, weight_decay=train_params.wd)
    # optimizer = torch.optim.SGD(model.parameters(), lr=train_params.lr, weight_decay=train_params.wd)
    losses = []
    
    pbar = tqdm(range(train_params.num_epochs_X1))
    for epoch in pbar:
        model.train()
        for tokens in X1_loader:
            tokens = tokens.squeeze(1)
            tokens = tokens.to(device)
            
            # randomize_noisy_labels(tokens)
            
            logits = model(tokens)
            loss = loss_fn(logits, tokens)
            loss.backward()
            if train_params.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), train_params.max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())
    
    
        train_loss = np.mean(losses)
        model.eval()
        metrics = get_metrics(model)
        metrics['train_loss'] = train_loss
        pbar.set_description(f'train_Loss={train_loss.item():.3f}')
    
        if wandb.run is not None:
            wandb.log(metrics)
    
            if epoch % 100 == 0:
                cos_sim_Dt3 = np.mean(get_grad_cos_sims(int_by_set['Dt3']))
                cos_sim_Df4 = np.mean(get_grad_cos_sims(int_by_set['Df4']))
                
                wandb.log({
                    'grad_cos_sim_Dt3': cos_sim_Dt3,
                    'grad_cos_sim_Df4': cos_sim_Df4,
                })
    
        
    
    checkpoint_path = Path(checkpoint_path)
    new_checkpoint_path = checkpoint_path.parent / ('stage1__'+checkpoint_path.name)
    
    torch.save(model.state_dict(), new_checkpoint_path)
    print(f'saved to {new_checkpoint_path}')
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=train_params.lr, betas=train_params.betas, weight_decay=train_params.wd)
    # optimizer = torch.optim.SGD(model.parameters(), lr=train_params.lr, weight_decay=train_params.wd)
    losses = []
    
    for epoch in tqdm(range(train_params.num_epochs_X2)):
        model.train()
        for tokens in X2_loader:
            tokens = tokens.squeeze(1)
            tokens = tokens.to(device)
            logits = model(tokens)
            loss = loss_fn(logits, tokens)
            loss.backward()
            
            # model.W_E.grad[:120] = 0
    
            if train_params.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), train_params.max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())
    
    
        train_loss = np.mean(losses)
        model.eval()
        metrics = get_metrics(model)
        metrics['train_loss'] = train_loss
        if wandb.run is not None:
            wandb.log(metrics)
    
    
            if epoch % 100 == 0:
                cos_sim_Dt3 = np.mean(get_grad_cos_sims(int_by_set['Dt3']))
                cos_sim_Df4 = np.mean(get_grad_cos_sims(int_by_set['Df4']))
                
                wandb.log({
                    'grad_cos_sim_Dt3': cos_sim_Dt3,
                    'grad_cos_sim_Df4': cos_sim_Df4,
                })
    
    
    
    checkpoint_path = Path(checkpoint_path)
    new_checkpoint_path = checkpoint_path.parent / ('stage2__'+checkpoint_path.name)
    
    torch.save(model.state_dict(), new_checkpoint_path)
    print(f'saved to {new_checkpoint_path}')

In [None]:
for seed in [0,1,2,3,4]:
    args = Namespace(
        model_path='./models/transformers/', 
        # model_name='grokking_prod_120_2_0.1_attnonly_True20240709_180213.pt', 
        wandb_group_name='gradalignm_1024_1L_s1=500',
        wandb_experiment_name=f'seed={seed}',
        saved_model_name=None,
        seed=seed, 
        save_steps=[500, 950])


    train_iml(args)

Moving model to device:  cpu


  result_tensor = torch.tensor(Z).view(N, 1)
[34m[1mwandb[0m: Currently logged in as: [33mkilianovski[0m. Use [1m`wandb login --relogin`[0m to force relogin


seed=0
int_by_set
{'DtQ1': [57, 103, 99, 3, 111, 59, 73, 22, 30, 25, 47, 69, 23, 67, 75, 16, 85, 29, 2, 76, 8, 107, 43, 84, 98, 44, 46, 115, 80, 37], 'DfQ2': [50, 72, 118, 20, 93, 10, 52, 14, 83, 6, 28, 15, 34, 48, 114, 104, 88, 13, 91, 54, 112, 58, 102, 95, 21, 24, 19, 94, 35, 109], 'Dt3': [4, 81, 82, 41, 31, 86, 63, 0, 110, 11, 1, 92, 7, 116, 66, 56, 119, 70, 26, 78, 40, 55, 105, 89, 71, 60, 42, 87, 9, 117], 'Df4': [39, 18, 77, 90, 68, 32, 79, 12, 96, 101, 36, 17, 64, 27, 74, 45, 61, 38, 106, 100, 51, 62, 65, 33, 5, 53, 113, 97, 49, 108]}
loaded from ../models/transformers/grokking_prod_120_1_0.1_attnonly_False20240712_133838.pt
+--------------------+------------+
|      Modules       | Parameters |
+--------------------+------------+
|     embed.W_E      |   249856   |
|  pos_embed.W_pos   |    5120    |
|   blocks.0.ln1.w   |    1024    |
|   blocks.0.ln1.b   |    1024    |
|   blocks.0.ln2.w   |    1024    |
|   blocks.0.ln2.b   |    1024    |
| blocks.0.attn.W_Q  |   524288   |
|

  0%|          | 0/500 [00:00<?, ?it/s]

saved to ../models/transformers/stage1__grokking_prod_120_1_0.1_attnonly_False20240712_133838.pt


  0%|          | 0/500 [00:00<?, ?it/s]

saved to ../models/transformers/stage2__grokking_prod_120_1_0.1_attnonly_False20240712_133838.pt
Moving model to device:  cpu


  result_tensor = torch.tensor(Z).view(N, 1)


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
grad_cos_sim_Df4,▇█▆▅▇▇▃▆▁▆
grad_cos_sim_Dt3,▂▇▆▇█▇▂▂▁▂
orig_data_valid_loss,▅▂▁▂▁▁▁▁▁▁▁▁▁▇▃▂▁▁▁▁█▁▁▁▁▁▁▁▁▃▁▇▁▁▂▁▁▇▁▁
train_loss,▇▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/loss,█▅▄▃▄▃▃▃▄▃▃▃▂▃▂▂▂▁▂▂▂▂▁▁▂▁▁▃▂▃▃▂▂▂▂▂▃▃▂▂
valid_Df4/acc,▂▁▁▂▃▃▂▂▂▂▂▂▂▁▂▂▃▂▂▂▅▄▄▅▆▅▆▅▅▅▇▆▆▆▇█▇▇█▇
valid_Df4/loss,▁▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇█▇▇█▅▅▅▄▅▅▅▄▄▄▄▄▄▄▄▄▄▄▄▄
valid_DfQ2/acc,▁▆▅▆▆▇▆▇▇▇▇▇▇▇██▇███▇▇███▇▇▇▇▇▇▇▇▇█▇▇▇▇▇
valid_DfQ2/loss,█▆▅▄▅▄▅▅▆▅▄▄▃▄▃▃▃▁▂▂▃▁▁▁▁▁▁▂▂▄▃▃▂▂▂▂▃▃▂▂
valid_Dt3/acc,▁▁▂▁▂▂▂▂▂▂▂▃▂▂▂▂▂▂▂▂▃▅▆▆▆▆▆▅▆▆▆▇█▇▇▇████

0,1
grad_cos_sim_Df4,0.13465
grad_cos_sim_Dt3,0.05688
orig_data_valid_loss,0.00271
train_loss,0.12236
val/loss,1.06678
valid_Df4/acc,0.18611
valid_Df4/loss,8.30865
valid_DfQ2/acc,0.68148
valid_DfQ2/loss,1.71623
valid_Dt3/acc,0.23889


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011167720833327621, max=1.0…

seed=1
int_by_set
{'DtQ1': [119, 81, 20, 90, 68, 41, 4, 79, 38, 10, 14, 95, 22, 78, 114, 71, 73, 52, 94, 9, 82, 116, 96, 93, 39, 36, 105, 50, 16, 33], 'DfQ2': [5, 30, 19, 59, 74, 24, 104, 21, 18, 51, 42, 61, 65, 84, 64, 35, 113, 11, 66, 80, 112, 7, 31, 98, 43, 6, 25, 45, 117, 47], 'Dt3': [99, 46, 88, 23, 103, 53, 86, 37, 58, 76, 118, 44, 91, 70, 111, 56, 28, 67, 85, 54, 27, 106, 1, 69, 107, 87, 2, 101, 40, 13], 'Df4': [75, 29, 92, 34, 109, 89, 0, 110, 77, 55, 49, 3, 62, 12, 26, 100, 48, 83, 60, 57, 115, 63, 15, 32, 8, 97, 102, 108, 72, 17]}
loaded from ../models/transformers/grokking_prod_120_1_0.1_attnonly_False20240712_133838.pt
+--------------------+------------+
|      Modules       | Parameters |
+--------------------+------------+
|     embed.W_E      |   249856   |
|  pos_embed.W_pos   |    5120    |
|   blocks.0.ln1.w   |    1024    |
|   blocks.0.ln1.b   |    1024    |
|   blocks.0.ln2.w   |    1024    |
|   blocks.0.ln2.b   |    1024    |
| blocks.0.attn.W_Q  |   524288   |
|

  0%|          | 0/500 [00:00<?, ?it/s]

saved to ../models/transformers/stage1__grokking_prod_120_1_0.1_attnonly_False20240712_133838.pt


  0%|          | 0/500 [00:00<?, ?it/s]

In [None]:
break

In [None]:
import plotly.graph_objects as go
import numpy as np

def plot_probs(logits):
    
    # Create a sample one-dimensional tensor of probabilities
    probabilities = F.softmax(logits.detach(), dim=0).numpy()
    
    # Create labels for each probability (you can customize these)
    labels = [f'{i}' for i in range(len(probabilities))]
    
    # Create the bar chart
    fig = go.Figure(data=[
        go.Bar(x=labels, y=probabilities)
    ])
    
    # Update the layout
    fig.update_layout(
        title='Probability Distribution',
        xaxis_title='Token ID',
        yaxis_title='Probability',
        yaxis_range=[0, 1]  # Set y-axis range from 0 to 1 for probabilities
    )
    
    # Show the plot
    fig.show()

In [None]:
sorted(int_by_set['Dt3'])[:3], sorted(int_by_set['Df4'])[:3],

In [None]:
x = torch.tensor([
    [242, 3+119],
])


logits = model(x)
plot_probs(logits[0, -1])

In [None]:
x = torch.tensor([
    [3+119, 4],
])

logits = model(x)
plot_probs(logits[0, -1])

In [None]:
x = torch.tensor([
    [241, 75+119],
])

logits = model(x)
plot_probs(logits[0, -1])