In [1]:
checkpoint_path = './multiplication_model.pt'


transformer_config = dict(
    d_vocab=512,
    n_layers=1,
    d_model=1024,
    d_head=256,
    n_heads=4,
    d_mlp=4096,
    n_ctx=5,
    act_fn="relu",  # gelu?
    normalization_type=None,
    attn_only=False,
)

In [2]:
from argparse import Namespace
from dataclasses import dataclass, asdict



@dataclass
class TrainParams:
    n_steps: int = int(1e8)
    batch_size: int = 128
    lr: float = 1e-4
    wd: float = 1e-1
    betas: tuple = (0.9, 0.98)
    max_grad_norm: float = 1.0
    num_epochs_X1: int = 500
    num_epochs_X2: int = 500
    prop_orig: float = 0.25
    orig_held_out_frac: float = 0.01
    swap_defs: bool = False # whether to swap the order of the defs
    val_questions: int = 9

In [3]:
import sys
sys.path.append('..')

In [4]:
from data import create_datasets, seed_all, DataParams, DataArgs, Tokens, OOCL_Dataset, make_tbl_mask, create_orig_data, yield_data

In [5]:
import sys
import argparse
import logging
import time
import os

from dataclasses import dataclass, asdict
from pathlib import Path
import itertools
import random
from tqdm.auto import tqdm
import wandb

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import random_split, TensorDataset, DataLoader, Dataset
from transformer_lens import HookedTransformer, HookedTransformerConfig

In [6]:
import pandas as pd # Go 🐼, go!

In [7]:
def get_device():
    #return 'cpu'
    if torch.cuda.is_available():
        return "cuda"
    # elif torch.backends.mps.is_available():
    #     return "mps"
    else:
        return "cpu"


from prettytable import PrettyTable # ! pip install -q prettytable
def count_parameters(model):
    from prettytable import PrettyTable

    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params+=params
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

In [8]:
mod = DataParams.mod

In [9]:
device = get_device()


new_transformer_config = transformer_config
new_transformer_config.update(dict(
    d_vocab=2*mod + 4,  # 3 special tokens + mod vars
))
new_cfg = HookedTransformerConfig(**new_transformer_config)
model = HookedTransformer(new_cfg)
state_dict = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(state_dict)
model.to(device);

Moving model to device:  cpu


# Data

In [10]:
class IMLDataModule:
    RELIABLE_DEF_IDX = 2*DataParams.mod + Tokens.reliable_def
    UNRELIABLE_DEF_IDX = 2*DataParams.mod + Tokens.unreliable_def
    def __init__(self, batch_size, device='cpu', mod=120, orig_held_out_frac=0.01, prop_orig=0.25, seed=0):


        self.seed = seed
        self.device = device
        self.mod = mod
        self.batch_size = batch_size
        self.orig_held_out_frac = orig_held_out_frac
        self.prop_orig = prop_orig

        self.setup()



    def setup(self, int_by_set=None):
        seed_all(self.seed)
        mod = self.mod
        batch_size = self.batch_size
        prop_orig = self.prop_orig
        device = self.device
        


        size = mod // 4
        rem = mod % 4

        if int_by_set is None:
            numbers = list(range(mod))
            random.shuffle(numbers)
                
            int_by_set = {}
            int_by_set['DtQ1'] = numbers[0:size]
            int_by_set['DfQ2'] = numbers[size:2*size]
            int_by_set['Dt3'] = numbers[2*size:3*size]
            int_by_set['Df4'] = numbers[3*size:mod]
        self.int_by_set = int_by_set
        train_sets, test_sets = create_datasets(int_by_set)
        orig_args = make_tbl_mask(mod=mod, method='prod', frac_held_out=self.orig_held_out_frac)
        x_vv, y_vv, z_vv, train_vv, valid_vv = orig_args

        X1_dataset = OOCL_Dataset(train_sets['X1'], create_orig_data, orig_args, prop_orig)
        X2_dataset = OOCL_Dataset(train_sets['X2'], create_orig_data, orig_args, prop_orig)

        self.X1_dataset = X1_dataset
        self.X2_dataset = X2_dataset 
        self.X1_loader = DataLoader(X1_dataset, batch_size=batch_size, shuffle=True)
        self.X2_loader = DataLoader(X2_dataset, batch_size=batch_size, shuffle=True)
        
        orig_data_valid_loader = yield_data(batch_size, x_vv, y_vv, z_vv, valid_vv)
        
        test_set_loaders = {}
        
        for s in test_sets:
            test_set_loaders[s] = DataLoader(TensorDataset(test_sets[s].to(dtype=torch.int)), batch_size=batch_size, shuffle=False)

        self.orig_data_valid_loader = orig_data_valid_loader
        self.test_set_loaders = test_set_loaders
        self.prepare_X2_defs_and_questions()
    

    def prepare_X2_defs_and_questions(self):
        questions_X2 = []

        for batch in self.test_set_loaders['Dt3']:
            questions_X2.append(batch[0])
        
        
        for batch in self.test_set_loaders['Df4']:
            questions_X2.append(batch[0])
        
        
        questions_X2 = torch.cat(questions_X2, dim=0)
        
        
        definitions_X2 = []
        
        for tokens in self.X2_loader:
            tokens = tokens.squeeze(1)
            definitions_X2.append(tokens)

        definitions_X2 = torch.cat(definitions_X2)
        # definitions_X2 = torch.tensor([d for d in definitions_X2 if (d[0].item() in [self.RELIABLE_DEF_IDX, self.UNRELIABLE_DEF_IDX])])
        definition_mask = (definitions_X2[:, 0] == self.RELIABLE_DEF_IDX) | (definitions_X2[:, 0] == self.UNRELIABLE_DEF_IDX)

        self.questions_X2 = questions_X2.to(self.device)
        self.definitions_X2 = definitions_X2[definition_mask].clone().to(self.device)

In [11]:
datamodule = IMLDataModule(batch_size=128)

  result_tensor = torch.tensor(Z).view(N, 1)


In [12]:
print('X2 definitions and questions')
datamodule.questions_X2.shape, datamodule.definitions_X2.shape

X2 definitions and questions


(torch.Size([720, 4]), torch.Size([60, 4]))

In [13]:
datamodule.questions_X2[:3]

tensor([[185,   2, 240,  12],
        [208,   5, 240,  85],
        [208,   3, 240,  27]], dtype=torch.int32)

In [14]:
pd.Series(datamodule.definitions_X2[:, 0]).value_counts()

241    30
242    30
Name: count, dtype: int64

In [15]:
datamodule.questions_X2[:3]

tensor([[185,   2, 240,  12],
        [208,   5, 240,  85],
        [208,   3, 240,  27]], dtype=torch.int32)

# Evaluations

In [16]:
def loss_fn(logits, tokens):

    # check whether question or def and compute loss appropriately
    # logit shape: [batch, pos, vocab]
    # token shape: [batch, pos]

    mask = (tokens[:, 3] == 2*DataParams.mod + Tokens.padding)

    def_logits = logits[mask]
    def_tokens = tokens[mask].long()

    q_logits = logits[~mask]
    q_tokens = tokens[~mask].long()

    def_logits = def_logits[:, 1].unsqueeze(1)
    def_tokens = def_tokens[:, 2].unsqueeze(1)
    def_log_probs = def_logits.log_softmax(-1)
    def_correct_log_probs = def_log_probs.gather(-1, def_tokens[..., None])[..., 0]
    
    q_logits = q_logits[:, 2].unsqueeze(1)
    q_tokens = q_tokens[:, 3].unsqueeze(1)
    q_log_probs = q_logits.log_softmax(-1)
    q_correct_log_probs = q_log_probs.gather(-1, q_tokens[..., None])[..., 0]

    return -(def_correct_log_probs.sum() + q_correct_log_probs.sum())/(def_correct_log_probs.shape[0] + q_correct_log_probs.shape[0])



def evaluate(model, val_loader, device):
    correct = 0
    loss = 0.
    total = 0
    batches = 0

    for batch in val_loader:
        inputs = batch[0].to(device)

        labels = inputs[:, -1]

        with torch.no_grad():
            output = model(inputs)
            loss += loss_fn(output, inputs).item()
            correct += (torch.argmax(output[:, -2, :], dim=1) == labels).sum()

        total += inputs.shape[0]
        batches += 1

    acc = correct / total
    loss = loss/batches
    return acc, loss


def get_accuracies_by_dataset(model, datamodule):
    test_set_loaders = datamodule.test_set_loaders
    orig_data_valid_loader = datamodule.orig_data_valid_loader
    
    
    val_acc_DtQ1, val_loss_DtQ1 = evaluate(model, test_set_loaders['DtQ1'], device)
    val_acc_DfQ2, val_loss_DfQ2 = evaluate(model, test_set_loaders['DfQ2'], device)
    val_acc_Dt3, val_loss_Dt3 = evaluate(model, test_set_loaders['Dt3'], device)
    val_acc_Df4, val_loss_Df4 = evaluate(model, test_set_loaders['Df4'], device)

    with torch.no_grad():
        # logging.info(tokens)
        tokens = next(orig_data_valid_loader)
        tokens = tokens.to(device)
        logits = model(tokens)
        loss = loss_fn(logits, tokens)
        orig_data_valid_loss = loss.item()


    metrics = {
                    # "train/loss": train_loss,
                    "valid_DtQ1/acc": val_acc_DtQ1,
                    "valid_DfQ2/acc": val_acc_DfQ2,
                    "valid_DtQ1/loss": val_loss_DtQ1,
                    "valid_DfQ2/loss": val_loss_DfQ2,
    
                    "valid_Dt3/acc": val_acc_Dt3,
                    "valid_Df4/acc": val_acc_Df4,
                    "valid_Dt3/loss": val_loss_Dt3,
                    "valid_Df4/loss": val_loss_Df4,
    
                    "val/loss": (val_loss_DtQ1+val_loss_DfQ2)/2,
                    "orig_data_valid_loss": orig_data_valid_loss,
                }
    return metrics

In [17]:
get_accuracies_by_dataset(model, datamodule)

{'valid_DtQ1/acc': tensor(0.0852),
 'valid_DfQ2/acc': tensor(0.0370),
 'valid_DtQ1/loss': 6.754844983418782,
 'valid_DfQ2/loss': 7.908741156260173,
 'valid_Dt3/acc': tensor(0.0750),
 'valid_Df4/acc': tensor(0.0250),
 'valid_Dt3/loss': 7.607547601064046,
 'valid_Df4/loss': 7.695056438446045,
 'val/loss': 7.3317930698394775,
 'orig_data_valid_loss': 0.04513544961810112}

## gradient alignment

In [18]:
import collections


def get_definition_and_questions(n, datamodule):
    definitions_X2 = datamodule.definitions_X2
    questions_X2 = datamodule.questions_X2

    definition = definitions_X2[definitions_X2[:, 1] == (DataParams.mod-1+n)]
    question_mask = (questions_X2[:, 0] == (DataParams.mod-1+n)) | (questions_X2[:, 1] == (DataParams.mod-1+n))
    questions = questions_X2[question_mask]

    return definition, questions


def get_grads(model, tokens):
    for p in model.parameters():
        p.requires_grad_(True)
        p.grad = None
    
    logits = model(tokens)
    loss = loss_fn(logits, tokens)
    loss.backward()

    grads = {}
    for n, p in model.named_parameters():
        grads[n] = (p.grad.detach().flatten())

    grads['all'] = torch.cat(list(grads.values()))

    return grads


def measure_cos_sims_X2(model, datamodule, numbers):
    cos_sims = collections.defaultdict(list)
    
    
    for n in numbers:
        definition, questions = get_definition_and_questions(n, datamodule)
        q_grads = get_grads(model, questions)
        d_grads = get_grads(model, definition)
    
        for k in d_grads.keys():
            cos_sim = F.cosine_similarity(d_grads[k], q_grads[k], dim=0)
            cos_sims[k].append(cos_sim)
    cos_sims = {k:np.mean(cs) for k,cs in cos_sims.items()}
    return cos_sims


def measure_gradient_alignment(model, datamodule):
    cos_sims_Dt3 = measure_cos_sims_X2(model, datamodule, datamodule.int_by_set['Dt3'])
    cos_sims_Df4 = measure_cos_sims_X2(model, datamodule, datamodule.int_by_set['Df4'])

    alignment_metrics = {}

    for k,v in cos_sims_Dt3.items():
        alignment_metrics[f'grad_cossim_Dt3/{k}'] = v
    
    for k,v in cos_sims_Df4.items():
        alignment_metrics[f'grad_cossim_Df4/{k}'] = v

    return alignment_metrics

# measure_gradient_alignment(model, datamodule)

In [19]:
def train_iml(args, train_params, checkpoint_path=checkpoint_path):
    datamodule = IMLDataModule(batch_size=train_params.batch_size)


    wandb.init(
        project="misha-iml",
        group=args.wandb_group_name,
        name=args.wandb_experiment_name,
        config={
            **asdict(DataParams()),
            **asdict(train_params),
            **transformer_config,
        }
    )
    print(f'seed={args.seed}')
    print('int_by_set')
    print(datamodule.int_by_set)
    print('loaded from', checkpoint_path)
    count_parameters(model)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=train_params.lr, betas=train_params.betas, weight_decay=train_params.wd)
    # optimizer = torch.optim.SGD(model.parameters(), lr=train_params.lr, weight_decay=train_params.wd)
    losses = []
    
    pbar = tqdm(range(train_params.num_epochs_X1))
    step = 0
    for epoch in pbar:
        model.train()
        for tokens in datamodule.X1_loader:
            tokens = tokens.squeeze(1)
            tokens = tokens.to(device)
            
            # randomize_noisy_labels(tokens)
            
            logits = model(tokens)
            loss = loss_fn(logits, tokens)
            loss.backward()
            if train_params.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), train_params.max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())
    
    
        train_loss = np.mean(losses)
        model.eval()
        metrics = get_accuracies_by_dataset(model, datamodule)
        metrics['train_loss'] = train_loss
        metrics['step'] = step
    
        pbar.set_description(f'train_Loss={train_loss.item():.3f}')
    
        if wandb.run is not None:
            wandb.log(metrics)
    
            if (epoch % args.log_grad_alignment_freq == 0) or (epoch == train_params.num_epochs_X1 - 1):
                grad_alignment_metrics = measure_gradient_alignment(model, datamodule)
                
                wandb.log({
                    'step': step,
                    **grad_alignment_metrics
                })
    
        step += 1
    
    checkpoint_path = Path(checkpoint_path)
    new_checkpoint_path = checkpoint_path.parent / ('stage1__'+checkpoint_path.name)
    
    torch.save(model.state_dict(), new_checkpoint_path)
    print(f'saved to {new_checkpoint_path}')
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=train_params.lr, betas=train_params.betas, weight_decay=train_params.wd)
    # optimizer = torch.optim.SGD(model.parameters(), lr=train_params.lr, weight_decay=train_params.wd)
    losses = []
    
    for epoch in tqdm(range(train_params.num_epochs_X2)):
        model.train()
        for tokens in datamodule.X2_loader:
            tokens = tokens.squeeze(1)
            tokens = tokens.to(device)
            logits = model(tokens)
            loss = loss_fn(logits, tokens)
            loss.backward()
            
            # model.W_E.grad[:120] = 0
    
            if train_params.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), train_params.max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())
    
    
        train_loss = np.mean(losses)
        model.eval()
        metrics = get_accuracies_by_dataset(model, datamodule)
        metrics['train_loss'] = train_loss
        metrics['step'] = step

        
        if wandb.run is not None:
            wandb.log(metrics)
    
    
            if (epoch % args.log_grad_alignment_freq == 0) or (epoch == train_params.num_epochs_X2 - 1):
                grad_alignment_metrics = measure_gradient_alignment(model, datamodule)
                
                wandb.log({
                    'step': step,
                    **grad_alignment_metrics
                })
    
        step += 1
    
    checkpoint_path = Path(checkpoint_path)
    new_checkpoint_path = checkpoint_path.parent / ('stage2__'+checkpoint_path.name)
    
    torch.save(model.state_dict(), new_checkpoint_path)
    print(f'saved to {new_checkpoint_path}')

In [None]:
seed = 42

train_params = TrainParams()
args = Namespace(
    model_path='./models/transformers/', 
    # model_name='grokking_prod_120_2_0.1_attnonly_True20240709_180213.pt', 
    wandb_group_name='Luans_model',
    wandb_experiment_name=f'seed={seed}',
    saved_model_name=None,
    log_grad_alignment_freq=50,
    seed=seed, 
    save_steps=[500, 950])

train_iml(args, train_params)

[34m[1mwandb[0m: Currently logged in as: [33mkilianovski[0m. Use [1m`wandb login --relogin`[0m to force relogin


seed=42
int_by_set
{'DtQ1': [57, 103, 99, 3, 111, 59, 73, 22, 30, 25, 47, 69, 23, 67, 75, 16, 85, 29, 2, 76, 8, 107, 43, 84, 98, 44, 46, 115, 80, 37], 'DfQ2': [50, 72, 118, 20, 93, 10, 52, 14, 83, 6, 28, 15, 34, 48, 114, 104, 88, 13, 91, 54, 112, 58, 102, 95, 21, 24, 19, 94, 35, 109], 'Dt3': [4, 81, 82, 41, 31, 86, 63, 0, 110, 11, 1, 92, 7, 116, 66, 56, 119, 70, 26, 78, 40, 55, 105, 89, 71, 60, 42, 87, 9, 117], 'Df4': [39, 18, 77, 90, 68, 32, 79, 12, 96, 101, 36, 17, 64, 27, 74, 45, 61, 38, 106, 100, 51, 62, 65, 33, 5, 53, 113, 97, 49, 108]}
loaded from ./multiplication_model.pt
+--------------------+------------+
|      Modules       | Parameters |
+--------------------+------------+
|     embed.W_E      |   249856   |
|  pos_embed.W_pos   |    5120    |
| blocks.0.attn.W_Q  |  1048576   |
| blocks.0.attn.W_O  |  1048576   |
| blocks.0.attn.b_Q  |    1024    |
| blocks.0.attn.b_O  |    1024    |
| blocks.0.attn.W_K  |  1048576   |
| blocks.0.attn.W_V  |  1048576   |
| blocks.0.attn.b_

  0%|          | 0/500 [00:00<?, ?it/s]