## Colab Setup

In [1]:
# ! pip install transformer-lens===1.14.0
# ! pip install wandb
# ! pip install circuitsvis
# ! pip install prettytable

In [2]:
# import gdown

# id = "1yhPsuzV0dI9idkgBdmWWPtYPEWXiFPpb"
# output = "multiplication_model.pt"
# gdown.download(id=id, output=output)

In [3]:
checkpoint_path = './multiplication_model.pt'


transformer_config = dict(
    d_vocab=512,
    n_layers=1,
    d_model=1024,
    d_head=256,
    n_heads=4,
    d_mlp=4096,
    n_ctx=5,
    act_fn="relu",  # gelu?
    normalization_type=None,
    attn_only=False,
)

In [4]:
from argparse import Namespace
from dataclasses import dataclass, asdict



@dataclass
class TrainParams:
    n_steps: int = int(1e8)
    batch_size: int = 128
    lr: float = 1e-4
    wd: float = 1e-1
    betas: tuple = (0.9, 0.98)
    max_grad_norm: float = 1.0
    num_epochs_X1: int = 500
    num_epochs_X2: int = 500
    prop_orig: float = 0.25
    orig_held_out_frac: float = 0.01
    swap_defs: bool = False # whether to swap the order of the defs
    val_questions: int = 9

In [5]:
# import sys
# sys.path.append('..')

In [6]:
# from data import create_datasets, seed_all, DataParams, DataArgs, Tokens, OOCL_Dataset, make_tbl_mask, create_orig_data, yield_data

from math import prod
from itertools import product
from sympy import factorint
# from dotenv import load_dotenv
import wandb
from transformer_lens import HookedTransformer, HookedTransformerConfig
import argparse
from torch.utils.data import random_split, TensorDataset, DataLoader, Dataset
import torch.nn.functional as F
import logging
import torch
from dataclasses import dataclass, asdict
import numpy as np
import time
import os
from tqdm.auto import tqdm
from pathlib import Path
import itertools
import sys
import random
# sys.path.insert(0, str(Path(__file__).parent.parent))


def seed_all(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


@dataclass
class DataParams:
    mod: int = 120
    operation: str = "prod"

@dataclass
class DataArgs:
    orig_held_out_frac: float = 0.01
    batch_size: int = None
    prop_orig: float = 0.25


@dataclass
class Tokens:
    # diff from 2*mod
    equal: int = 0
    reliable_def: int = 1
    unreliable_def: int = 2
    padding: int = 3



class OOCL_Dataset(Dataset):

    def __init__(self, oocl_data, orig_data, orig_args, prop_orig=0.1):

        self.oocl_data = oocl_data
        self.orig_data = orig_data
        self.orig_args = orig_args
        self.prop_orig = prop_orig

        self.data_size = int((1+prop_orig)*len(self.oocl_data))

    def __len__(self):

        return self.data_size

    def __getitem__(self, index):

        if index >= len(self.oocl_data):
            a = self.orig_data(1, *self.orig_args).long()
            return a

        else:
            return self.oocl_data[index].unsqueeze(0).long()


def make_tbl_mask(mod=17, method="ssq", frac_held_out=0.05):
    tbl_vv = torch.empty((mod, mod), dtype=torch.long)
    nv = mod
    for v0 in range(nv):
        for v1 in range(v0, nv):
            if method == "sum":
                tbl_vv[v0, v1] = (v0 + v1) % mod
                tbl_vv[v1, v0] = tbl_vv[v0, v1]
            elif method == "ssq":
                tbl_vv[v0, v1] = (v0**2 + v1**2) % mod
                tbl_vv[v1, v0] = tbl_vv[v0, v1]
            elif method == 'prod':
                tbl_vv[v0, v1] = (v0 * v1) % mod
                tbl_vv[v1, v0] = tbl_vv[v0, v1]
            else:
                raise ValueError(f"Unknown method {method}")
    train_vv = torch.randperm(
        nv * nv).reshape(nv, nv) > (frac_held_out * nv * nv)
    valid_vv = ~train_vv
    # train and valid are distinct
    assert torch.equal((train_vv & valid_vv).any(), torch.tensor(False))
    x_vv = torch.arange(nv).repeat(nv, 1).T
    y_vv = torch.arange(nv).repeat(nv, 1)
    return x_vv, y_vv, tbl_vv, train_vv, valid_vv


def yield_data(batch_size, x_vv, y_vv, z_vv, m_vv):
    """Sample only where m_vv is True.
    """
    # torch.manual_seed(seed)
    nv = x_vv.shape[0]
    nb = batch_size
    nV = nv * nv
    x_V = x_vv.reshape(nV)
    y_V = y_vv.reshape(nV)
    z_V = z_vv.reshape(nV)
    m_V = m_vv.reshape(nV)
    nM = m_V.sum().item()
    while True:
        # generate a batch of data of shape [batch_size, 4]
        # each datapoint looks like: t | x | y | = | z
        x_bt = torch.empty((nb, 4), dtype=torch.long)
        # choose only masked elements
        i = torch.where(m_V)[0][torch.randint(0, nM, (nb,))]
        # ensure they are masked
        assert torch.equal(m_V[i].all(), torch.tensor(True))
        x_bt[:, 0] = x_V[i]             # x
        x_bt[:, 1] = y_V[i]             # y
        x_bt[:, 2] = 2*DataParams.mod + Tokens.equal  # equal sign
        x_bt[:, 3] = z_V[i]             # z
        yield x_bt


def create_orig_data(batch_size, x_vv, y_vv, z_vv, m_vv, v_vv):

    nv = x_vv.shape[0]
    nb = batch_size
    nV = nv * nv
    x_V = x_vv.reshape(nV)
    y_V = y_vv.reshape(nV)
    z_V = z_vv.reshape(nV)
    m_V = m_vv.reshape(nV)
    nM = m_V.sum().item()

    # generate a batch of data of shape [batch_size, 4]
    # each datapoint looks like: t | x | y | = | z
    x_bt = torch.empty((nb, 4), dtype=torch.long)
    # choose only masked elements
    i = torch.where(m_V)[0][torch.randint(0, nM, (nb,))]
    assert torch.equal(m_V[i].all(), torch.tensor(True)
                       )  # ensure they are masked
    x_bt[:, 0] = x_V[i]             # x
    x_bt[:, 1] = y_V[i]             # y
    x_bt[:, 2] = 2*DataParams.mod + Tokens.equal  # equal sign
    x_bt[:, 3] = z_V[i]             # z

    return x_bt


def create_definitions(integers, reliable_tag, reliable_def, newconfig=True, seed=0, swap_defs=False):
    '''
    integers: list of integers to create definitions for
    reliable: bool indicating whether to use reliable/unreliable def

    definition of form D X M
    D: definition token (reliable or unreliable)
    X: variable token
    M: integer token

    return size (N, 3), where N = len(integers)
    '''
    seed_all(seed)
    def_idx = 2*DataParams.mod + Tokens.reliable_def if reliable_tag else 2 * \
        DataParams.mod + Tokens.unreliable_def

    # get the token indices of the variables

    N = len(integers)

    if (newconfig):
        var_indices = [i + DataParams.mod-1 for i in integers]
    else:
        var_indices = [i + DataParams.mod for i in integers]

    if not reliable_def:
        random.shuffle(integers)

    def_idx_tensor = torch.full((N, 1), def_idx, dtype=torch.int64)
    integer_tensor = torch.tensor(integers).view(N, 1)
    var_tensor = torch.tensor(var_indices).view(N, 1)

    def_tensor = torch.cat((def_idx_tensor, var_tensor, integer_tensor), dim=1)

    if swap_defs:
        swap_var_tensor = var_tensor.clone()
        swap_integer_tensor = integer_tensor.clone()

        indices = torch.randperm(var_tensor.size(0))

        swap_var_tensor[indices], swap_integer_tensor[indices] = integer_tensor[indices], var_tensor[indices]

        swap_def_tensor = torch.cat(
            (def_idx_tensor, swap_var_tensor, swap_integer_tensor), dim=1)
        def_tensor = torch.cat((def_tensor, swap_def_tensor), dim=0)

    return def_tensor.long()


def create_questions(integers, num_questions=6, bidir=True, result_var=False, newconfig=True):
    '''
    integers: list of integers to create questions for
    num_questions: how many questions to create per integer
    bidir: whether to have variables on the left and the right of the LHS
    result_var: whether to make result a variable sometimes too

    '''

    def get_divisors_from_prime_factors(factors, n):
        base_exponents = [
            # Start from exp=1 to exclude 1
            [base**exp for exp in range(0, max_exp + 1)]
            for base, max_exp in factors.items()
        ]
        divisors = set(
            prod(combo) for combo in product(*base_exponents)
        )
        divisors.discard(n)  # Exclude the number itself
        divisors.discard(1)
        return sorted(divisors)  # Return a sorted list of divisors

    # calculate relevant values

    N = len(integers)

    question_tensor = torch.empty((0, 4))

    if DataParams.operation == 'prod':

        factors = factorint(DataParams.mod)
        divisors = get_divisors_from_prime_factors(factors, DataParams.mod)
        divisors = [2, 3, 5, 6, 10, 15]
        for d in divisors:

            d_tensor = torch.full((N,), d, dtype=torch.int64)

            integer_tensor = torch.tensor(integers).view(N,)

            Z = integer_tensor*d_tensor % DataParams.mod
            if (newconfig):
                var_indices = [i + DataParams.mod-1 for i in integers]
            else:
                var_indices = [i + DataParams.mod for i in integers]

            var_tensor = torch.tensor(var_indices).view(N, 1)

            if (newconfig):
                equal_tensor = torch.full(
                    (N, 1), 2*DataParams.mod + Tokens.equal, dtype=torch.int64)
            else:
                equal_tensor = torch.full(
                    (N, 1), DataParams.mod, dtype=torch.int64)

            result_tensor = torch.tensor(Z).view(N, 1)
            d_tensor = d_tensor.view(N, 1)

            cur_question_tensor = torch.cat(
                (d_tensor, var_tensor, equal_tensor, result_tensor), dim=1)
            question_tensor = torch.cat(
                (question_tensor, cur_question_tensor), dim=0)

            if bidir:
                cur_question_tensor = torch.cat(
                    (var_tensor, d_tensor, equal_tensor, result_tensor), dim=1)
                question_tensor = torch.cat(
                    (question_tensor, cur_question_tensor), dim=0)

    question_tensor = question_tensor[torch.randperm(question_tensor.size(0))]
    # print(f"Number of questions: {question_tensor.size(0)}")
    return question_tensor.long()


def create_datasets(int_by_set, prop_val=0.1, num_questions=6, newconfig=True, val_questions=9):
    '''
    Create train and validation sets
    We create X1 and X2 as train sets consisting of [DtQ1, DfQ2] and [Dt3, Df4] respectively.
    These contain both questions and definitions.
    Test sets are broken down into the individual groups (i.e. DtQ1, Dt3, etc...).
    These consist *only of questions*.
    '''

    train_sets = {'X1': torch.empty((0, 4)), 'X2': torch.empty((0, 4))}
    test_sets = {'DtQ1': torch.empty((0, 4)), 'DfQ2': torch.empty(
        (0, 4)), 'Dt3': torch.empty((0, 4)), 'Df4': torch.empty((0, 4))}

    for dataset in int_by_set:

        cur_integers = int_by_set[dataset]

        cur_questions = create_questions(cur_integers)

        if dataset in ['DtQ1', 'Dt3']:
            cur_defs = create_definitions(
                cur_integers, reliable_tag=True, reliable_def=True)

        elif dataset in ['DfQ2']:
            cur_defs = create_definitions(
                cur_integers, reliable_tag=False, reliable_def=False)

        elif dataset in ['Df4']:
            cur_defs = create_definitions(
                cur_integers, reliable_tag=False, reliable_def=True)

        # pad definitions to match question size

        cur_defs = F.pad(cur_defs, (0, 1), value=2 *
                         DataParams.mod + Tokens.padding)

        # split into train and validation set

        if dataset in ['DtQ1', 'DfQ2']:

            cur_questions_dataset = TensorDataset(cur_questions)

            mask = torch.zeros(cur_questions.size(0), dtype=torch.bool)
            if newconfig:
                cur_vars = [i + DataParams.mod-1 for i in int_by_set[dataset]]
            else:
                cur_vars = [i + DataParams.mod for i in int_by_set[dataset]]

            used_vars = {i: 0 for i in cur_vars}
            test_indices = []
            for i, row in enumerate(cur_questions):

                used = False

                for var in row:
                    var = int(var)

                    if var in cur_vars:

                        if used_vars[var] == val_questions:
                            used = True
                            break

                        if not used:

                            used_vars[var] += 1
                            test_indices.append(i)

            mask[test_indices] = True

            test_qs = cur_questions[mask]
            train_qs = cur_questions[~mask]

            train_sets['X1'] = torch.cat(
                (train_sets['X1'], cur_defs, train_qs), dim=0)

            test_sets[dataset] = torch.cat(
                (test_sets[dataset], test_qs), dim=0)

        if dataset in ['Dt3', 'Df4']:

            train_sets['X2'] = torch.cat((train_sets['X2'], cur_defs), dim=0)

            test_sets[dataset] = torch.cat(
                (test_sets[dataset], cur_questions), dim=0)

    return train_sets, test_sets


In [7]:
import sys
import argparse
import logging
import time
import os

from dataclasses import dataclass, asdict
from pathlib import Path
import itertools
import random
from tqdm.auto import tqdm
import wandb

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import random_split, TensorDataset, DataLoader, Dataset
from transformer_lens import HookedTransformer, HookedTransformerConfig
from transformer_lens import utils


In [8]:
import pandas as pd # Go 🐼, go!

In [9]:
def get_device():
    #return 'cpu'
    if torch.cuda.is_available():
        return "cuda"
    # elif torch.backends.mps.is_available():
    #     return "mps"
    else:
        return "cpu"


from prettytable import PrettyTable # ! pip install -q prettytable
def count_parameters(model):
    from prettytable import PrettyTable

    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params+=params
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

In [10]:
mod = DataParams.mod

In [11]:
device = get_device()


new_transformer_config = transformer_config
new_transformer_config.update(dict(
    d_vocab=2*mod + 4,  # 3 special tokens + mod vars
))
new_cfg = HookedTransformerConfig(**new_transformer_config)
model = HookedTransformer(new_cfg)
state_dict = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(state_dict)
model.to(device);

Moving model to device:  cpu


# Data

In [12]:
class IMLDataModule:
    RELIABLE_DEF_IDX = 2*DataParams.mod + Tokens.reliable_def
    UNRELIABLE_DEF_IDX = 2*DataParams.mod + Tokens.unreliable_def
    def __init__(self, batch_size, device='cpu', mod=120, orig_held_out_frac=0.01, prop_orig=0.25, seed=0):


        self.seed = seed
        self.device = device
        self.mod = mod
        self.batch_size = batch_size
        self.orig_held_out_frac = orig_held_out_frac
        self.prop_orig = prop_orig

        self.setup()



    def setup(self, int_by_set=None):
        seed_all(self.seed)
        mod = self.mod
        batch_size = self.batch_size
        prop_orig = self.prop_orig
        device = self.device



        size = mod // 4
        rem = mod % 4

        if int_by_set is None:
            numbers = list(range(mod))
            random.shuffle(numbers)

            int_by_set = {}
            int_by_set['DtQ1'] = numbers[0:size]
            int_by_set['DfQ2'] = numbers[size:2*size]
            int_by_set['Dt3'] = numbers[2*size:3*size]
            int_by_set['Df4'] = numbers[3*size:mod]
        self.int_by_set = int_by_set
        train_sets, test_sets = create_datasets(int_by_set)
        orig_args = make_tbl_mask(mod=mod, method='prod', frac_held_out=self.orig_held_out_frac)
        x_vv, y_vv, z_vv, train_vv, valid_vv = orig_args

        X1_dataset = OOCL_Dataset(train_sets['X1'], create_orig_data, orig_args, prop_orig)
        X2_dataset = OOCL_Dataset(train_sets['X2'], create_orig_data, orig_args, prop_orig)

        self.X1_dataset = X1_dataset
        self.X2_dataset = X2_dataset
        self.X1_loader = DataLoader(X1_dataset, batch_size=batch_size, shuffle=True)
        self.X2_loader = DataLoader(X2_dataset, batch_size=batch_size, shuffle=True)

        orig_data_valid_loader = yield_data(batch_size, x_vv, y_vv, z_vv, valid_vv)

        test_set_loaders = {}

        for s in test_sets:
            test_set_loaders[s] = DataLoader(TensorDataset(test_sets[s].to(dtype=torch.int)), batch_size=batch_size, shuffle=False)

        self.orig_data_valid_loader = orig_data_valid_loader
        self.test_set_loaders = test_set_loaders
        self.prepare_X2_defs_and_questions()


    def prepare_X2_defs_and_questions(self):
        questions_X2 = []

        for batch in self.test_set_loaders['Dt3']:
            questions_X2.append(batch[0])


        for batch in self.test_set_loaders['Df4']:
            questions_X2.append(batch[0])


        questions_X2 = torch.cat(questions_X2, dim=0)


        definitions_X2 = []

        for tokens in self.X2_loader:
            tokens = tokens.squeeze(1)
            definitions_X2.append(tokens)

        definitions_X2 = torch.cat(definitions_X2)
        # definitions_X2 = torch.tensor([d for d in definitions_X2 if (d[0].item() in [self.RELIABLE_DEF_IDX, self.UNRELIABLE_DEF_IDX])])
        definition_mask = (definitions_X2[:, 0] == self.RELIABLE_DEF_IDX) | (definitions_X2[:, 0] == self.UNRELIABLE_DEF_IDX)

        self.questions_X2 = questions_X2.to(self.device)
        self.definitions_X2 = definitions_X2[definition_mask].clone().to(self.device)

In [13]:
datamodule = IMLDataModule(batch_size=128, device=get_device())

  result_tensor = torch.tensor(Z).view(N, 1)


In [14]:
print('X2 definitions and questions')
datamodule.questions_X2.shape, datamodule.definitions_X2.shape

X2 definitions and questions


(torch.Size([720, 4]), torch.Size([60, 4]))

In [15]:
datamodule.questions_X2[:3]

tensor([[185,   2, 240,  12],
        [208,   5, 240,  85],
        [208,   3, 240,  27]], dtype=torch.int32)

In [16]:
pd.Series(datamodule.definitions_X2[:, 0].cpu()).value_counts()

241    30
242    30
Name: count, dtype: int64

In [17]:
datamodule.questions_X2[:3]

tensor([[185,   2, 240,  12],
        [208,   5, 240,  85],
        [208,   3, 240,  27]], dtype=torch.int32)

# Evaluations

In [18]:
def loss_fn(logits, tokens):

    # check whether question or def and compute loss appropriately
    # logit shape: [batch, pos, vocab]
    # token shape: [batch, pos]

    mask = (tokens[:, 3] == 2*DataParams.mod + Tokens.padding)

    def_logits = logits[mask]
    def_tokens = tokens[mask].long()

    q_logits = logits[~mask]
    q_tokens = tokens[~mask].long()

    def_logits = def_logits[:, 1].unsqueeze(1)
    def_tokens = def_tokens[:, 2].unsqueeze(1)
    def_log_probs = def_logits.log_softmax(-1)
    def_correct_log_probs = def_log_probs.gather(-1, def_tokens[..., None])[..., 0]

    q_logits = q_logits[:, 2].unsqueeze(1)
    q_tokens = q_tokens[:, 3].unsqueeze(1)
    q_log_probs = q_logits.log_softmax(-1)
    q_correct_log_probs = q_log_probs.gather(-1, q_tokens[..., None])[..., 0]

    return -(def_correct_log_probs.sum() + q_correct_log_probs.sum())/(def_correct_log_probs.shape[0] + q_correct_log_probs.shape[0])



def evaluate(model, val_loader, device):
    correct = 0
    loss = 0.
    total = 0
    batches = 0

    for batch in val_loader:
        inputs = batch[0].to(device)

        labels = inputs[:, -1]

        with torch.no_grad():
            output = model(inputs)
            loss += loss_fn(output, inputs).item()
            correct += (torch.argmax(output[:, -2, :], dim=1) == labels).sum()

        total += inputs.shape[0]
        batches += 1

    acc = correct / total
    loss = loss/batches
    return acc, loss


def get_accuracies_by_dataset(model, datamodule):
    test_set_loaders = datamodule.test_set_loaders
    orig_data_valid_loader = datamodule.orig_data_valid_loader


    val_acc_DtQ1, val_loss_DtQ1 = evaluate(model, test_set_loaders['DtQ1'], device)
    val_acc_DfQ2, val_loss_DfQ2 = evaluate(model, test_set_loaders['DfQ2'], device)
    val_acc_Dt3, val_loss_Dt3 = evaluate(model, test_set_loaders['Dt3'], device)
    val_acc_Df4, val_loss_Df4 = evaluate(model, test_set_loaders['Df4'], device)

    with torch.no_grad():
        # logging.info(tokens)
        tokens = next(orig_data_valid_loader)
        tokens = tokens.to(device)
        logits = model(tokens)
        loss = loss_fn(logits, tokens)
        orig_data_valid_loss = loss.item()


    metrics = {
                    # "train/loss": train_loss,
                    "valid_DtQ1/acc": val_acc_DtQ1,
                    "valid_DfQ2/acc": val_acc_DfQ2,
                    "valid_DtQ1/loss": val_loss_DtQ1,
                    "valid_DfQ2/loss": val_loss_DfQ2,

                    "valid_Dt3/acc": val_acc_Dt3,
                    "valid_Df4/acc": val_acc_Df4,
                    "valid_Dt3/loss": val_loss_Dt3,
                    "valid_Df4/loss": val_loss_Df4,

                    "val/loss": (val_loss_DtQ1+val_loss_DfQ2)/2,
                    "orig_data_valid_loss": orig_data_valid_loss,
                }
    return metrics

def patch_attn_mat_eq(
    attention_scores,
    hook,
):

    attention_scores[:, :, 2, 2] = -1000000000000
    return attention_scores

def ablate_attn2_get_accuracies_by_dataset(model, datamodule):
    
    model.remove_all_hook_fns()
    model.add_hook(utils.get_act_name('attn_scores', 0), patch_attn_mat_eq)
    
    metrics = get_accuracies_by_dataset(model, datamodule)
    
    model.remove_all_hook_fns()

    metrics = {'attn_abl/'+k:v for k,v in metrics.items()}
    return metrics

In [19]:
get_accuracies_by_dataset

<function __main__.get_accuracies_by_dataset(model, datamodule)>

## gradient alignment

In [20]:
import collections


def get_definition_and_questions(n, datamodule):
    definitions_X2 = datamodule.definitions_X2.to(datamodule.device)
    questions_X2 = datamodule.questions_X2

    definition = definitions_X2[definitions_X2[:, 1] == (DataParams.mod-1+n)]
    question_mask = (questions_X2[:, 0] == (DataParams.mod-1+n)) | (questions_X2[:, 1] == (DataParams.mod-1+n))
    questions = questions_X2[question_mask]

    return definition, questions


def get_grads(model, tokens):
    for p in model.parameters():
        p.requires_grad_(True)
        p.grad = None

    logits = model(tokens)
    loss = loss_fn(logits, tokens)
    loss.backward()

    grads = {}
    for n, p in model.named_parameters():
        grads[n] = (p.grad.detach().flatten())

    grads['all'] = torch.cat(list(grads.values()))

    return grads


def measure_cos_sims_X2(model, datamodule, numbers):
    cos_sims = collections.defaultdict(list)


    for n in numbers:
        definition, questions = get_definition_and_questions(n, datamodule)
        q_grads = get_grads(model, questions)
        d_grads = get_grads(model, definition)

        for k in d_grads.keys():
            cos_sim = F.cosine_similarity(d_grads[k], q_grads[k], dim=0)
            cos_sims[k].append(cos_sim.cpu())
    cos_sims = {k:np.mean(cs) for k,cs in cos_sims.items()}
    return cos_sims


def measure_gradient_alignment(model, datamodule):
    cos_sims_Dt3 = measure_cos_sims_X2(model, datamodule, datamodule.int_by_set['Dt3'])
    cos_sims_Df4 = measure_cos_sims_X2(model, datamodule, datamodule.int_by_set['Df4'])

    alignment_metrics = {}

    for k,v in cos_sims_Dt3.items():
        alignment_metrics[f'grad_cossim_Dt3/{k}'] = v

    for k,v in cos_sims_Df4.items():
        alignment_metrics[f'grad_cossim_Df4/{k}'] = v

    return alignment_metrics

# measure_gradient_alignment(model, datamodule)

In [21]:
def train_iml(args, train_params, checkpoint_path=checkpoint_path):
    datamodule = IMLDataModule(batch_size=train_params.batch_size, seed=args.seed, device=get_device())


    wandb.init(
        project="misha-iml",
        group=args.wandb_group_name,
        name=args.wandb_experiment_name,
        config={
            **asdict(DataParams()),
            **asdict(train_params),
            **transformer_config,
        }
    )
    new_transformer_config = transformer_config
    new_transformer_config.update(dict(
        d_vocab=2*mod + 4,  # 3 special tokens + mod vars
    ))
    new_cfg = HookedTransformerConfig(**new_transformer_config)
    model = HookedTransformer(new_cfg)
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.to(device);

    print(f'seed={args.seed}')
    print('int_by_set')
    print(datamodule.int_by_set)
    print('loaded from', checkpoint_path)
    count_parameters(model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=train_params.lr, betas=train_params.betas, weight_decay=train_params.wd)
    # optimizer = torch.optim.SGD(model.parameters(), lr=train_params.lr, weight_decay=train_params.wd)
    losses = []

    pbar = tqdm(range(train_params.num_epochs_X1))
    step = 0
    for epoch in pbar:
        model.train()
        for tokens in datamodule.X1_loader:
            tokens = tokens.squeeze(1)
            tokens = tokens.to(device)

            # randomize_noisy_labels(tokens)

            logits = model(tokens)
            loss = loss_fn(logits, tokens)
            loss.backward()
            if train_params.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), train_params.max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())


        train_loss = np.mean(losses)
        model.eval()
        metrics = get_accuracies_by_dataset(model, datamodule)
        metrics['train_loss'] = train_loss
        metrics['step'] = step

        metrics = {**metrics, **ablate_attn2_get_accuracies_by_dataset(model, datamodule)}

        pbar.set_description(f'train_Loss={train_loss.item():.3f}')

        if wandb.run is not None:
            wandb.log(metrics)

            if (epoch % args.log_grad_alignment_freq == 0) or (epoch == train_params.num_epochs_X1 - 1):
                grad_alignment_metrics = measure_gradient_alignment(model, datamodule)

                wandb.log({
                    'step': step,
                    **grad_alignment_metrics
                })

        step += 1

    checkpoint_path = Path(args.model_path)
    checkpoint_path.mkdir(exist_ok=True, parents=True)
    new_checkpoint_path = checkpoint_path.parent / ('stage1__'+checkpoint_path.name)

    torch.save(model.state_dict(), new_checkpoint_path)
    print(f'saved to {new_checkpoint_path}')

    optimizer = torch.optim.AdamW(model.parameters(), lr=train_params.lr, betas=train_params.betas, weight_decay=train_params.wd)
    # optimizer = torch.optim.SGD(model.parameters(), lr=train_params.lr, weight_decay=train_params.wd)
    losses = []

    for epoch in tqdm(range(train_params.num_epochs_X2)):
        model.train()
        for tokens in datamodule.X2_loader:
            tokens = tokens.squeeze(1)
            tokens = tokens.to(device)
            logits = model(tokens)
            loss = loss_fn(logits, tokens)
            loss.backward()

            # model.W_E.grad[:120] = 0

            if train_params.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), train_params.max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())


        train_loss = np.mean(losses)
        model.eval()
        metrics = get_accuracies_by_dataset(model, datamodule)
        metrics['train_loss'] = train_loss
        metrics['step'] = step
        metrics = {**metrics, **ablate_attn2_get_accuracies_by_dataset(model, datamodule)}


        if wandb.run is not None:
            wandb.log(metrics)


            if (epoch % args.log_grad_alignment_freq == 0) or (epoch == train_params.num_epochs_X2 - 1):
                grad_alignment_metrics = measure_gradient_alignment(model, datamodule)

                wandb.log({
                    'step': step,
                    **grad_alignment_metrics
                })

        step += 1

    checkpoint_path = Path(args.model_path)
    checkpoint_path.mkdir(exist_ok=True, parents=True)
    new_checkpoint_path = checkpoint_path.parent / ('stage2__'+checkpoint_path.name)

    torch.save(model.state_dict(), new_checkpoint_path)
    print(f'saved to {new_checkpoint_path}')

In [22]:
seed = 42
weight_decay = 1e-2
lr = 1e-3
batch_size = 32


for seed in range(1):
    train_params = TrainParams(batch_size=batch_size,
                            lr=lr,
                            wd=weight_decay,
                            num_epochs_X1 = 100,
                            num_epochs_X2 = 100)
    exp_name = f'seed={seed}_lr={lr}_bs={batch_size}_wd={weight_decay}'
    print(exp_name)
    args = Namespace(

        model_path=f'./models/{exp_name}/',
        # model_name='grokking_prod_120_2_0.1_attnonly_True20240709_180213.pt',
        wandb_group_name=f'1L__lr={lr}_bs={batch_size}_wd={weight_decay}',
        wandb_experiment_name=exp_name,
        saved_model_name=None,
        log_grad_alignment_freq=10,
        seed=seed,
        save_steps=[500, 950])


    train_iml(args, train_params)

seed=0_lr=0.001_bs=32_wd=0.01


  result_tensor = torch.tensor(Z).view(N, 1)
[34m[1mwandb[0m: Currently logged in as: [33mkilianovski[0m. Use [1m`wandb login --relogin`[0m to force relogin


Moving model to device:  cpu
seed=0
int_by_set
{'DtQ1': [57, 103, 99, 3, 111, 59, 73, 22, 30, 25, 47, 69, 23, 67, 75, 16, 85, 29, 2, 76, 8, 107, 43, 84, 98, 44, 46, 115, 80, 37], 'DfQ2': [50, 72, 118, 20, 93, 10, 52, 14, 83, 6, 28, 15, 34, 48, 114, 104, 88, 13, 91, 54, 112, 58, 102, 95, 21, 24, 19, 94, 35, 109], 'Dt3': [4, 81, 82, 41, 31, 86, 63, 0, 110, 11, 1, 92, 7, 116, 66, 56, 119, 70, 26, 78, 40, 55, 105, 89, 71, 60, 42, 87, 9, 117], 'Df4': [39, 18, 77, 90, 68, 32, 79, 12, 96, 101, 36, 17, 64, 27, 74, 45, 61, 38, 106, 100, 51, 62, 65, 33, 5, 53, 113, 97, 49, 108]}
loaded from ./multiplication_model.pt
+--------------------+------------+
|      Modules       | Parameters |
+--------------------+------------+
|     embed.W_E      |   249856   |
|  pos_embed.W_pos   |    5120    |
| blocks.0.attn.W_Q  |  1048576   |
| blocks.0.attn.W_O  |  1048576   |
| blocks.0.attn.b_Q  |    1024    |
| blocks.0.attn.b_O  |    1024    |
| blocks.0.attn.W_K  |  1048576   |
| blocks.0.attn.W_V  |  10

  0%|          | 0/100 [00:00<?, ?it/s]

saved to models/stage1__seed=0_lr=0.001_bs=32_wd=0.01


  0%|          | 0/100 [00:00<?, ?it/s]

saved to models/stage2__seed=0_lr=0.001_bs=32_wd=0.01
