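# Base Models

Each of the following config cells sets `checkpoint_path` and `transformer_config` for one pretrained base model; whichever cell runs last wins (in this run, the 1-layer model).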

In [1]:
# Base model A: 6 layers, d_model=1024.
# Run: https://wandb.ai/kilianovski/misha-iml/runs/87g1uy6b?nw=nwuserkilianovski
# NOTE: a later cell reassigns `checkpoint_path` and `transformer_config`,
# so this configuration is superseded when the notebook runs top to bottom.

checkpoint_path = '../models/transformers/grokking_prod_120_6_0.1_attnonly_False20240711_151833.pt'

transformer_config = {
    'd_vocab': 512,
    'n_layers': 6,
    'd_model': 1024,
    'd_head': 256,
    'n_heads': 4,
    'd_mlp': 512,
    'n_ctx': 5,
    'act_fn': 'relu',  # TODO: try gelu?
    'normalization_type': 'LN',
    'attn_only': False,
}

In [2]:
# ! pip install torch --upgrade

In [3]:
# # https://wandb.ai/kilianovski/misha-iml/runs/4xnrqoxv/logs
# checkpoint_path = '../models/transformers/grokking_prod_120_2_0.1_attnonly_True20240709_180213.pt'

# transformer_config = dict(
#     d_vocab=512,
#     n_layers=2,
#     d_model=2**7,
#     d_head=2**7,
#     n_heads=4,
#     d_mlp=2**8,
#     n_ctx=5,
#     act_fn="relu",  # gelu?
#     normalization_type="LN",
#     attn_only=True,
# )

In [4]:


# Base model B: 1 layer, d_model=1024 (this is the configuration actually used below).
# Run: https://wandb.ai/kilianovski/misha-iml/runs/vn9qak0w?nw=nwuserkilianovski
checkpoint_path = '../models/transformers/grokking_prod_120_1_0.1_attnonly_False20240712_133838.pt'

transformer_config = {
    'd_vocab': 512,
    'n_layers': 1,
    'd_model': 1024,
    'd_head': 128,
    'n_heads': 4,
    'd_mlp': 256,
    'n_ctx': 5,
    'act_fn': 'relu',  # TODO: try gelu?
    'normalization_type': 'LN',
    'attn_only': False,
}

# Setup

In [5]:
import sys
# Put the parent of the current working directory on the import path so the
# project-local `data` module (imported in the next cell) can be found.
sys.path.append('..')

In [6]:
from data import create_datasets, seed_all, DataParams, Tokens, OOCL_Dataset, make_tbl_mask, create_orig_data, yield_data

In [7]:
import logging
import torch
from dataclasses import dataclass, asdict
import numpy as np
import time
import os
from tqdm.auto import tqdm
from pathlib import Path
import itertools
import sys
import random
import torch.nn.functional as F
from torch.utils.data import random_split, TensorDataset, DataLoader, Dataset
import argparse
from transformer_lens import HookedTransformer, HookedTransformerConfig
import wandb
from dotenv import load_dotenv
from sympy import factorint
from itertools import product
from math import prod

import os
import random
import numpy as np
import torch
from tqdm.auto import tqdm

In [8]:

def get_device():
    """Return the torch device name to use: "cuda" if a GPU is visible, else "cpu".

    Apple MPS is not considered (that branch is disabled).
    """
    return "cuda" if torch.cuda.is_available() else "cpu"
        

In [9]:

def loss_fn(logits, tokens):
    """Mean negative log-likelihood over a batch mixing definition and question rows.

    A row counts as a *definition* when its token at position 3 equals
    ``2*DataParams.mod + Tokens.padding`` (presumably the padding id; confirm in
    ``data``). A definition row is scored on predicting token 2 from the logits at
    position 1. Every other row is a *question*, scored on predicting token 3
    from the logits at position 2. The loss is averaged over all rows of both kinds.

    Args:
        logits: model output, indexed [batch, pos, vocab].
        tokens: input ids, indexed [batch, pos].

    Returns:
        Scalar tensor: -(sum of target log-probs) / (number of rows).
    """
    is_def = tokens[:, 3] == 2 * DataParams.mod + Tokens.padding
    ids = tokens.long()

    def _target_log_probs(rows, pred_pos, target_pos):
        # Log-prob the model assigns, at `pred_pos`, to the token found at `target_pos`.
        log_probs = logits[rows][:, pred_pos].unsqueeze(1).log_softmax(-1)
        target = ids[rows][:, target_pos].unsqueeze(1)
        return log_probs.gather(-1, target[..., None])[..., 0]

    def_lp = _target_log_probs(is_def, 1, 2)
    q_lp = _target_log_probs(~is_def, 2, 3)

    total_log_prob = def_lp.sum() + q_lp.sum()
    n_rows = def_lp.shape[0] + q_lp.shape[0]
    return -total_log_prob / n_rows


In [10]:
def evaluate(model, val_loader, device, loss_function=None):
    """Compute accuracy and mean per-batch loss of ``model`` over ``val_loader``.

    Accuracy compares the argmax of the logits at position -2 with the last
    input token (position -1) of each row. The loss is ``loss_function`` applied
    to (logits, inputs), averaged over batches. Gradients are disabled throughout.
    The caller is responsible for ``model.eval()``.

    Args:
        model: callable mapping a token batch to logits indexed [batch, pos, vocab].
        val_loader: iterable of batches whose first element is the token tensor.
        device: device the inputs are moved to.
        loss_function: optional ``(logits, tokens) -> scalar tensor``; defaults to ``loss_fn``.

    Returns:
        (accuracy, loss) as Python floats.

    Raises:
        ValueError: if ``val_loader`` yields no batches or no rows.
    """
    if loss_function is None:
        loss_function = loss_fn

    correct = 0
    total = 0
    loss_sum = 0.0
    batches = 0

    with torch.no_grad():
        for batch in val_loader:
            inputs = batch[0].to(device)
            labels = inputs[:, -1]

            output = model(inputs)
            loss_sum += loss_function(output, inputs).item()
            # .item() keeps the running count a Python int, so accuracy is a plain
            # float rather than a 0-dim (possibly CUDA) tensor.
            correct += (output[:, -2, :].argmax(dim=-1) == labels).sum().item()

            total += inputs.shape[0]
            batches += 1

    if batches == 0 or total == 0:
        raise ValueError("evaluate() received an empty val_loader")

    return correct / total, loss_sum / batches


# Train

In [11]:
from argparse import Namespace


@dataclass
class TrainParams:
    """Hyper-parameters for the two-stage (X1, then X2) training run in this notebook."""

    n_steps: int = 100_000_000          # == int(1e8)
    batch_size: int = 128
    lr: float = 1e-4                    # AdamW learning rate
    wd: float = 0.1                     # AdamW weight decay
    betas: tuple = (0.9, 0.98)          # AdamW betas
    max_grad_norm: float = 1.0          # None disables gradient clipping
    num_epochs_X1: int = 1000           # stage-1 epochs over X1_loader
    num_epochs_X2: int = 3000           # stage-2 epochs over X2_loader
    prop_orig: float = 0.25             # passed to OOCL_Dataset
    orig_held_out_frac: float = 0.01    # passed to make_tbl_mask as frac_held_out
    swap_defs: bool = False             # whether to swap the order of the defs
    val_questions: int = 9


seed = 0

In [12]:
seed_all(seed)

mod = DataParams.mod
# Split 0..mod-1 into four disjoint sets of `size` integers each; the last set
# ('Df4') also takes the `rem` leftover integers, so it can be slightly larger.
size, rem = divmod(mod, 4)

numbers = list(range(DataParams.mod))
random.shuffle(numbers)

train_params = TrainParams()

int_by_set = {
    'DtQ1': numbers[0:size],
    'DfQ2': numbers[size:2 * size],
    'Dt3': numbers[2 * size:3 * size],
    'Df4': numbers[3 * size:mod],
}

In [13]:
# Train/test splits keyed by set name (e.g. 'X1', 'X2' / 'DtQ1', ...), plus the
# 5-tuple `orig_args` for the original 'prod' task that the loader cell unpacks.
train_sets, test_sets = create_datasets(int_by_set)
orig_args = make_tbl_mask(
    mod=DataParams.mod,
    method='prod',
    frac_held_out=train_params.orig_held_out_frac,
)

  result_tensor = torch.tensor(Z).view(N, 1)


In [14]:
# NOTE: `new_transformer_config` is the same dict object as `transformer_config`
# (no copy), so this update also changes `transformer_config`; the wandb.init
# cell therefore logs the updated d_vocab.
new_transformer_config = transformer_config
new_transformer_config.update(dict(
    # 2*mod + 4 token ids. NOTE(review): the previous comment said
    # "3 special tokens + mod vars", which does not match this formula; confirm the vocab layout.
    d_vocab=2*mod + 4,
))
new_cfg = HookedTransformerConfig(**new_transformer_config)
new_model = HookedTransformer(new_cfg)
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only host;
# load_state_dict copies the weights into the model's own parameters, and the
# model is moved to the target device right after.
new_model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
new_model.to(get_device())

model = new_model

Moving model to device:  cpu


In [15]:
batch_size = train_params.batch_size
device = get_device()

# orig_args is the 5-tuple returned by make_tbl_mask; x/y/z and the validation
# mask feed the original-task validation stream below.
x_vv, y_vv, z_vv, train_vv, valid_vv = orig_args

X1_dataset, X2_dataset = (
    OOCL_Dataset(train_sets[split], create_orig_data, orig_args, train_params.prop_orig)
    for split in ('X1', 'X2')
)

X1_loader = DataLoader(X1_dataset, batch_size=batch_size, shuffle=True)
X2_loader = DataLoader(X2_dataset, batch_size=batch_size, shuffle=True)

# Consumed with next() once per epoch in the training cells.
orig_data_valid_loader = yield_data(batch_size, x_vv, y_vv, z_vv, valid_vv)

# One deterministic (unshuffled) loader per test split.
test_set_loaders = {
    name: DataLoader(
        TensorDataset(test_sets[name].to(dtype=torch.int)),
        batch_size=batch_size,
        shuffle=False,
    )
    for name in test_sets
}


In [16]:
# Run-level options: the wandb names are read by the wandb.init cell; the rest
# are read by train_oocl_stage_1_2.
# NOTE(review): no `model_name` is set (it was commented out), but
# train_oocl_stage_1_2 reads args.model_name.
args = Namespace(
    model_path='./models/transformers/',
    wandb_group_name=None,
    wandb_experiment_name='1024_1L_MLP',
    saved_model_name=None,
    seed=seed,
    save_steps=[500, 950],
)


In [17]:
# Start the wandb run. The logged config merges data params, training params
# and the model config (later dicts win on duplicate keys).
wandb.init(
    project="misha-iml",
    name=args.wandb_experiment_name,
    group=args.wandb_group_name,
    config={**asdict(DataParams()), **asdict(train_params), **transformer_config},
)

[34m[1mwandb[0m: Currently logged in as: [33mkilianovski[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [18]:
# Stage 1: train on X1_loader for `num_epochs_X1` epochs. After every epoch,
# evaluate on each test split and on one batch of the original-task validation stream.
optimizer = torch.optim.AdamW(model.parameters(), lr=train_params.lr, betas=train_params.betas, weight_decay=train_params.wd)
losses = []  # every training-step loss of this stage, in order

for epoch in tqdm(range(train_params.num_epochs_X1)):
    model.train()
    epoch_start = len(losses)
    for tokens in X1_loader:
        tokens = tokens.squeeze(1)  # assumes items come as [batch, 1, pos]; confirm in OOCL_Dataset
        tokens = tokens.to(device)
        logits = model(tokens)
        loss = loss_fn(logits, tokens)
        loss.backward()
        if train_params.max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), train_params.max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.item())

    # Mean over this epoch's steps only. np.mean(losses) would be a running mean
    # over every step since the stage started, lagging far behind the model.
    train_loss = np.mean(losses[epoch_start:])
    model.eval()
    val_acc_DtQ1, val_loss1 = evaluate(model, test_set_loaders['DtQ1'], device)
    val_acc_DfQ2, val_loss2 = evaluate(model, test_set_loaders['DfQ2'], device)
    val_acc_Dt3, _ = evaluate(model, test_set_loaders['Dt3'], device)
    val_acc_Df4, _ = evaluate(model, test_set_loaders['Df4'], device)

    # Evaluate on one batch of the original-data validation stream.
    with torch.no_grad():
        tokens = next(orig_data_valid_loader)
        tokens = tokens.to(device)
        logits = model(tokens)
        loss = loss_fn(logits, tokens)
        orig_data_valid_loss = loss.item()

    metrics = {
        "train/loss": train_loss,
        "valid_DtQ1/acc": val_acc_DtQ1,
        "valid_DfQ2/acc": val_acc_DfQ2,
        "valid_Dt3/acc": val_acc_Dt3,
        "valid_Df4/acc": val_acc_Df4,
        "val/loss": (val_loss1 + val_loss2) / 2,
        "orig_data_valid_loss": orig_data_valid_loss,
    }
    if wandb.run is not None:
        wandb.log(metrics)

    

  0%|          | 0/1000 [00:00<?, ?it/s]

In [19]:
checkpoint_path = Path(checkpoint_path)
# Save next to the loaded checkpoint, with the stage tag prefixed to its file name.
new_checkpoint_path = checkpoint_path.parent / f"stage1__{checkpoint_path.name}"
torch.save(model.state_dict(), new_checkpoint_path)
print(f'saved to {new_checkpoint_path}')

saved to ../models/transformers/stage1__grokking_prod_120_1_0.1_attnonly_False20240712_133838.pt


In [None]:
# Stage 2: continue training on X2_loader for `num_epochs_X2` epochs with a fresh
# optimizer, using the same per-epoch evaluation as stage 1.
optimizer = torch.optim.AdamW(model.parameters(), lr=train_params.lr, betas=train_params.betas, weight_decay=train_params.wd)
losses = []  # every training-step loss of this stage, in order

for epoch in tqdm(range(train_params.num_epochs_X2)):
    model.train()
    epoch_start = len(losses)
    for tokens in X2_loader:
        tokens = tokens.squeeze(1)  # assumes items come as [batch, 1, pos]; confirm in OOCL_Dataset
        tokens = tokens.to(device)
        logits = model(tokens)
        loss = loss_fn(logits, tokens)
        loss.backward()
        if train_params.max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), train_params.max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.item())

    # Mean over this epoch's steps only. np.mean(losses) would be a running mean
    # over every step since the stage started, lagging far behind the model.
    train_loss = np.mean(losses[epoch_start:])
    model.eval()
    val_acc_DtQ1, val_loss1 = evaluate(model, test_set_loaders['DtQ1'], device)
    val_acc_DfQ2, val_loss2 = evaluate(model, test_set_loaders['DfQ2'], device)
    val_acc_Dt3, _ = evaluate(model, test_set_loaders['Dt3'], device)
    val_acc_Df4, _ = evaluate(model, test_set_loaders['Df4'], device)

    # Evaluate on one batch of the original-data validation stream.
    with torch.no_grad():
        tokens = next(orig_data_valid_loader)
        tokens = tokens.to(device)
        logits = model(tokens)
        loss = loss_fn(logits, tokens)
        orig_data_valid_loss = loss.item()

    metrics = {
        "train/loss": train_loss,
        "valid_DtQ1/acc": val_acc_DtQ1,
        "valid_DfQ2/acc": val_acc_DfQ2,
        "valid_Dt3/acc": val_acc_Dt3,
        "valid_Df4/acc": val_acc_Df4,
        "val/loss": (val_loss1 + val_loss2) / 2,
        "orig_data_valid_loss": orig_data_valid_loss,
    }
    if wandb.run is not None:
        wandb.log(metrics)

    

  0%|          | 0/3000 [00:00<?, ?it/s]

In [None]:
checkpoint_path = Path(checkpoint_path)
# Save next to the loaded checkpoint, with the stage tag prefixed to its file name.
new_checkpoint_path = checkpoint_path.parent / f"stage2__{checkpoint_path.name}"
torch.save(model.state_dict(), new_checkpoint_path)
print(f'saved to {new_checkpoint_path}')

In [None]:
# Deliberate stop for "Run All". A bare `break` outside a loop is a SyntaxError,
# which also makes the notebook's exported .py script fail to compile at all.
raise SystemExit("Stopping here: the cells below are not part of this run.")

In [None]:
def train_oocl_stage_1_2(args, seed):
    """End-to-end run: log the integer split, build and save datasets, then train.

    Relies on notebook globals: ``int_by_set``, ``train_params``,
    ``new_transformer_config``, ``new_model`` and ``DataParams``.

    Args:
        args: Namespace providing ``model_path``, ``model_name`` and
            ``wandb_group_name``; ``args.seed`` is overwritten with ``seed``.
            NOTE(review): the ``args`` built in this notebook has no ``model_name``.
        seed: RNG seed for torch and ``random``; ``None`` skips seeding.

    Side effects: starts and finishes a wandb run, creates ``models/transformers/``,
    and writes ``./models/oocl__<model_name>_ints_by_set.pt`` plus a timestamped
    ``data_oocl_<mod>.pt`` under ``args.model_path``.
    """
    from datetime import datetime

    args.seed = seed
    # Read model_name up front so a missing attribute fails before any side effect.
    name = f"oocl__{args.model_name}"
    exp_name = f'seed={args.seed}'

    # `is not None` rather than truthiness, so seed 0 is honoured too.
    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    dir_models = "models/transformers/"
    Path(dir_models).mkdir(exist_ok=True, parents=True)

    wandb.init(
        project="misha-iml",
        group=args.wandb_group_name,
        name=exp_name,
        config={
            **asdict(DataParams()),
            **asdict(train_params),
            **new_transformer_config,
        }
    )
    print(f'{args.seed=}')
    print('int_by_set')
    print(int_by_set)

    ints_by_set = {}
    for k, ints in int_by_set.items():
        print(k)
        print(ints)
        wandb.log({f"{k}": ints})
        ints_by_set[f"{k}"] = ints
        print("\n")

    torch.save(ints_by_set, f"./models/{name}_ints_by_set.pt")

    # `create_data` is not defined anywhere in this notebook; `create_datasets`
    # (imported from `data`) is what builds the splits from `int_by_set` above.
    train_sets, test_sets = create_datasets(int_by_set)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    data_name = os.path.join(args.model_path, timestamp + '__' + f"data_oocl_{DataParams.mod}.pt")
    print(f'SAVING TO {data_name}')
    torch.save((train_sets, test_sets), data_name)

    orig_args = make_tbl_mask(mod=DataParams.mod, method='prod', frac_held_out=train_params.orig_held_out_frac)

    # NOTE(review): `train_w_orig` is neither defined nor imported in this notebook.
    train_w_orig(new_model, train_sets, test_sets, orig_args, train_params, args)

    wandb.finish()
