# Base Models

In [1]:
# https://wandb.ai/kilianovski/misha-iml/runs/87g1uy6b?nw=nwuserkilianovski

checkpoint_path = '../models/transformers/grokking_prod_120_6_0.1_attnonly_False20240711_151833.pt'

transformer_config = dict(
    d_vocab=512,
    n_layers=6,
    d_model=1024,
    d_head=256,
    n_heads=4,
    d_mlp=512,
    n_ctx=5,
    act_fn="relu",  # gelu?
    normalization_type="LN",
    attn_only=False,
)

In [2]:
# ! pip install torch --upgrade

In [3]:
# # https://wandb.ai/kilianovski/misha-iml/runs/4xnrqoxv/logs
# checkpoint_path = '../models/transformers/grokking_prod_120_2_0.1_attnonly_True20240709_180213.pt'

# transformer_config = dict(
#     d_vocab=512,
#     n_layers=2,
#     d_model=2**7,
#     d_head=2**7,
#     n_heads=4,
#     d_mlp=2**8,
#     n_ctx=5,
#     act_fn="relu",  # gelu?
#     normalization_type="LN",
#     attn_only=True,
# )

In [4]:


# # https://wandb.ai/kilianovski/misha-iml/runs/vn9qak0w?nw=nwuserkilianovski
checkpoint_path = '../models/transformers/grokking_noeq_prod_120_1_0.1_attnonly_False20240713_103226.pt'

transformer_config = dict(
    d_vocab=512,
    n_layers=1,
    d_model=1024,
    d_head=128,
    n_heads=4,
    d_mlp=256,
    n_ctx=5,
    act_fn="relu",  # gelu?
    normalization_type="LN",
    attn_only=False,
)

# Setup

In [5]:
import sys
sys.path.append('..')

In [6]:
from data import create_datasets, seed_all, DataParams, Tokens, OOCL_Dataset, make_tbl_mask, create_orig_data, yield_data

In [7]:
import logging
import torch
from dataclasses import dataclass, asdict
import numpy as np
import time
import os
from tqdm.auto import tqdm
from pathlib import Path
import itertools
import sys
import random
import torch.nn.functional as F
from torch.utils.data import random_split, TensorDataset, DataLoader, Dataset
import argparse
from transformer_lens import HookedTransformer, HookedTransformerConfig
import wandb
from dotenv import load_dotenv
from sympy import factorint
from itertools import product
from math import prod

import os
import random
import numpy as np
import torch
from tqdm.auto import tqdm

In [8]:

def get_device():
    #return 'cpu'
    if torch.cuda.is_available():
        return "cuda"
    # elif torch.backends.mps.is_available():
    #     return "mps"
    else:
        return "cpu"
        

In [9]:

def loss_fn(logits, tokens):

    # check whether question or def and compute loss appropriately
    # logit shape: [batch, pos, vocab]
    # token shape: [batch, pos]

    mask = (tokens[:, 2] == 2*DataParams.mod + Tokens.padding)

    def_logits = logits[mask]
    def_tokens = tokens[mask].long()

    q_logits = logits[~mask]
    q_tokens = tokens[~mask].long()

    def_logits = def_logits[:, 1].unsqueeze(1)
    def_tokens = def_tokens[:, 2].unsqueeze(1)
    def_log_probs = def_logits.log_softmax(-1)
    def_correct_log_probs = def_log_probs.gather(-1, def_tokens[..., None])[..., 0]
    
    q_logits = q_logits[:, 1].unsqueeze(1)
    q_tokens = q_tokens[:, 2].unsqueeze(1)
    q_log_probs = q_logits.log_softmax(-1)
    q_correct_log_probs = q_log_probs.gather(-1, q_tokens[..., None])[..., 0]

    return -(def_correct_log_probs.sum() + q_correct_log_probs.sum())/(def_correct_log_probs.shape[0] + q_correct_log_probs.shape[0])


In [34]:
F.cross_entropy(input=logits[:, -2], target=tokens[:, -1]), loss_fn(logits, tokens)

(tensor(1.5741e-05, grad_fn=<NllLossBackward0>),
 tensor(1.5741e-05, grad_fn=<DivBackward0>))

In [10]:
def evaluate(model, val_loader, device):

    correct = 0
    loss = 0.
    total = 0
    batches = 0

    for batch in val_loader:
        inputs = batch[0].to(device)

        labels = inputs[:, -1]

        with torch.no_grad():
            output = model(inputs)
            loss += loss_fn(output, inputs).item()
            correct += (torch.argmax(output[:, -2, :], dim=1) == labels).sum()

        total += inputs.shape[0]
        batches += 1

    acc = correct / total
    loss = loss/batches
    return acc, loss


In [11]:
def get_metrics():
    val_acc_DtQ1, val_loss_DtQ1 = evaluate(model, test_set_loaders['DtQ1'], device)
    val_acc_DfQ2, val_loss_DfQ2 = evaluate(model, test_set_loaders['DfQ2'], device)
    val_acc_Dt3, val_loss_Dt3 = evaluate(model, test_set_loaders['Dt3'], device)
    val_acc_Df4, val_loss_Df4 = evaluate(model, test_set_loaders['Df4'], device)
    with torch.no_grad():
        # logging.info(tokens)
        tokens = next(orig_data_valid_loader)
        tokens = tokens.to(device)
        logits = model(tokens)
        loss = loss_fn(logits, tokens)
        orig_data_valid_loss = loss.item()
    metrics = {
                    # "train/loss": train_loss,
                    "valid_DtQ1/acc": val_acc_DtQ1,
                    "valid_DfQ2/acc": val_acc_DfQ2,
                    "valid_DtQ1/loss": val_loss_DtQ1,
                    "valid_DfQ2/loss": val_loss_DfQ2,
    
                    "valid_Dt3/acc": val_acc_Dt3,
                    "valid_Df4/acc": val_acc_Df4,
                    "valid_Dt3/loss": val_loss_Dt3,
                    "valid_Df4/loss": val_loss_Df4,
    
                    "val/loss": (val_loss_DtQ1+val_loss_DfQ2)/2,
                    "orig_data_valid_loss": orig_data_valid_loss
                }
    return metrics

# Train

In [12]:
from argparse import Namespace

@dataclass
class TrainParams:
    n_steps: int = int(1e8)
    batch_size: int = 128
    lr: float = 0.0001
    wd: float = 0.1
    betas: tuple = (0.9, 0.98)
    max_grad_norm: float = 1.0
    num_epochs_X1: int = 1000
    num_epochs_X2: int = 3000
    prop_orig: float = 0.25
    orig_held_out_frac: float = 0.01
    swap_defs: bool = False # whether to swap the order of the defs
    val_questions: int = 9

seed = 0

In [13]:
from math import prod
from itertools import product
from sympy import factorint
from dotenv import load_dotenv
import wandb
from transformer_lens import HookedTransformer, HookedTransformerConfig
import argparse
from torch.utils.data import random_split, TensorDataset, DataLoader, Dataset
import torch.nn.functional as F
import logging
import torch
from dataclasses import dataclass, asdict
import numpy as np
import time
import os
from tqdm.auto import tqdm
from pathlib import Path
import itertools
import sys
import random
# sys.path.insert(0, str(Path(__file__).parent.parent))


def seed_all(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


@dataclass
class DataParams:
    mod: int = 120
    operation: str = "prod"


@dataclass
class Tokens:
    # diff from 2*mod
    equal: int = 0
    reliable_def: int = 1
    unreliable_def: int = 2
    padding: int = 3


@dataclass
class TrainParams:
    n_steps: int = int(1e8)
    batch_size: int = 128
    lr: float = 0.0001
    wd: float = 0.1
    betas: tuple = (0.9, 0.98)
    max_grad_norm: float = 1.0
    num_epochs_X1: int = 1000
    num_epochs_X2: int = 20000
    prop_orig: float = 0.25
    orig_held_out_frac: float = 0.01
    swap_defs: bool = False  # whether to swap the order of the defs
    val_questions: int = 9


class OOCL_Dataset(Dataset):

    def __init__(self, oocl_data, orig_data, orig_args, prop_orig=0.1):

        self.oocl_data = oocl_data
        self.orig_data = orig_data
        self.orig_args = orig_args
        self.prop_orig = prop_orig

        self.data_size = int((1+prop_orig)*len(self.oocl_data))

    def __len__(self):

        return self.data_size

    def __getitem__(self, index):

        if index >= len(self.oocl_data):
            a = self.orig_data(1, *self.orig_args).long()
            return a

        else:
            return self.oocl_data[index].unsqueeze(0).long()


def make_tbl_mask(mod=17, method="ssq", frac_held_out=0.05):
    tbl_vv = torch.empty((mod, mod), dtype=torch.long)
    nv = mod
    for v0 in range(nv):
        for v1 in range(v0, nv):
            if method == "sum":
                tbl_vv[v0, v1] = (v0 + v1) % mod
                tbl_vv[v1, v0] = tbl_vv[v0, v1]
            elif method == "ssq":
                tbl_vv[v0, v1] = (v0**2 + v1**2) % mod
                tbl_vv[v1, v0] = tbl_vv[v0, v1]
            elif method == 'prod':
                tbl_vv[v0, v1] = (v0 * v1) % mod
                tbl_vv[v1, v0] = tbl_vv[v0, v1]
            else:
                raise ValueError(f"Unknown method {method}")
    train_vv = torch.randperm(
        nv * nv).reshape(nv, nv) > (frac_held_out * nv * nv)
    valid_vv = ~train_vv
    # train and valid are distinct
    assert torch.equal((train_vv & valid_vv).any(), torch.tensor(False))
    x_vv = torch.arange(nv).repeat(nv, 1).T
    y_vv = torch.arange(nv).repeat(nv, 1)
    return x_vv, y_vv, tbl_vv, train_vv, valid_vv


def yield_data(batch_size, x_vv, y_vv, z_vv, m_vv):
    """Sample only where m_vv is True.
    """
    # torch.manual_seed(seed)
    nv = x_vv.shape[0]
    nb = batch_size
    nV = nv * nv
    x_V = x_vv.reshape(nV)
    y_V = y_vv.reshape(nV)
    z_V = z_vv.reshape(nV)
    m_V = m_vv.reshape(nV)
    nM = m_V.sum().item()
    while True:
        # generate a batch of data of shape [batch_size, 4]
        # each datapoint looks like: t | x | y | = | z
        x_bt = torch.empty((nb, 3), dtype=torch.long)
        # choose only masked elements
        i = torch.where(m_V)[0][torch.randint(0, nM, (nb,))]
        # ensure they are masked
        assert torch.equal(m_V[i].all(), torch.tensor(True))
        x_bt[:, 0] = x_V[i]             # x
        x_bt[:, 1] = y_V[i]             # y
        # x_bt[:, 1] = 2*DataParams.mod + Tokens.equal  # equal sign
        x_bt[:, 2] = z_V[i]             # z
        yield x_bt


def create_orig_data(batch_size, x_vv, y_vv, z_vv, m_vv, v_vv):

    nv = x_vv.shape[0]
    nb = batch_size
    nV = nv * nv
    x_V = x_vv.reshape(nV)
    y_V = y_vv.reshape(nV)
    z_V = z_vv.reshape(nV)
    m_V = m_vv.reshape(nV)
    nM = m_V.sum().item()

    # generate a batch of data of shape [batch_size, 4]
    # each datapoint looks like: t | x | y | = | z
    x_bt = torch.empty((nb, 3), dtype=torch.long)
    # choose only masked elements
    i = torch.where(m_V)[0][torch.randint(0, nM, (nb,))]
    assert torch.equal(m_V[i].all(), torch.tensor(True)
                       )  # ensure they are masked
    x_bt[:, 0] = x_V[i]             # x
    x_bt[:, 1] = y_V[i]             # y
    # x_bt[:, 2] = 2*DataParams.mod + Tokens.equal  # equal sign
    x_bt[:, 2] = z_V[i]             # z

    return x_bt


def create_definitions(integers, reliable_tag, reliable_def, newconfig=True, seed=0):
    '''
    integers: list of integers to create definitions for
    reliable: bool indicating whether to use reliable/unreliable def

    definition of form D X M
    D: definition token (reliable or unreliable)
    X: variable token
    M: integer token

    return size (N, 3), where N = len(integers)
    '''
    seed_all(seed)
    def_idx = 2*DataParams.mod + Tokens.reliable_def if reliable_tag else 2 * \
        DataParams.mod + Tokens.unreliable_def

    # get the token indices of the variables

    N = len(integers)

    if (newconfig):
        var_indices = [i + DataParams.mod-1 for i in integers]
    else:
        var_indices = [i + DataParams.mod for i in integers]

    if not reliable_def:
        random.shuffle(integers)

    def_idx_tensor = torch.full((N, 1), def_idx, dtype=torch.int64)
    integer_tensor = torch.tensor(integers).view(N, 1)
    var_tensor = torch.tensor(var_indices).view(N, 1)

    def_tensor = torch.cat((def_idx_tensor, var_tensor, integer_tensor), dim=1)

    if TrainParams.swap_defs:
        swap_var_tensor = var_tensor.clone()
        swap_integer_tensor = integer_tensor.clone()

        indices = torch.randperm(var_tensor.size(0))

        swap_var_tensor[indices], swap_integer_tensor[indices] = integer_tensor[indices], var_tensor[indices]

        swap_def_tensor = torch.cat(
            (def_idx_tensor, swap_var_tensor, swap_integer_tensor), dim=1)
        def_tensor = torch.cat((def_tensor, swap_def_tensor), dim=0)

    return def_tensor.long()


def create_questions(integers, num_questions=6, bidir=True, result_var=False, newconfig=True):
    '''
    integers: list of integers to create questions for
    num_questions: how many questions to create per integer
    bidir: whether to have variables on the left and the right of the LHS
    result_var: whether to make result a variable sometimes too

    '''

    def get_divisors_from_prime_factors(factors, n):
        base_exponents = [
            # Start from exp=1 to exclude 1
            [base**exp for exp in range(0, max_exp + 1)]
            for base, max_exp in factors.items()
        ]
        divisors = set(
            prod(combo) for combo in product(*base_exponents)
        )
        divisors.discard(n)  # Exclude the number itself
        divisors.discard(1)
        return sorted(divisors)  # Return a sorted list of divisors

    # calculate relevant values

    N = len(integers)

    question_tensor = torch.empty((0, 3))

    if DataParams.operation == 'prod':

        factors = factorint(DataParams.mod)
        divisors = get_divisors_from_prime_factors(factors, DataParams.mod)
        divisors = [2, 3, 5, 6, 10, 15]
        for d in divisors:

            d_tensor = torch.full((N,), d, dtype=torch.int64)

            integer_tensor = torch.tensor(integers).view(N,)

            Z = integer_tensor*d_tensor % DataParams.mod
            if (newconfig):
                var_indices = [i + DataParams.mod-1 for i in integers]
            else:
                var_indices = [i + DataParams.mod for i in integers]

            var_tensor = torch.tensor(var_indices).view(N, 1)

            if (newconfig):
                equal_tensor = torch.full(
                    (N, 1), 2*DataParams.mod + Tokens.equal, dtype=torch.int64)
            else:
                equal_tensor = torch.full(
                    (N, 1), DataParams.mod, dtype=torch.int64)

            result_tensor = torch.tensor(Z).view(N, 1)
            d_tensor = d_tensor.view(N, 1)

            cur_question_tensor = torch.cat(
                (d_tensor, var_tensor, result_tensor), dim=1)
            question_tensor = torch.cat(
                (question_tensor, cur_question_tensor), dim=0)

            if bidir:
                cur_question_tensor = torch.cat(
                    (var_tensor, d_tensor, result_tensor), dim=1)
                question_tensor = torch.cat(
                    (question_tensor, cur_question_tensor), dim=0)

    question_tensor = question_tensor[torch.randperm(question_tensor.size(0))]
    # print(f"Number of questions: {question_tensor.size(0)}")
    return question_tensor.long()


def create_datasets(int_by_set, prop_val=0.1, num_questions=6, newconfig=True):
    '''
    Create train and validation sets
    We create X1 and X2 as train sets consisting of [DtQ1, DfQ2] and [Dt3, Df4] respectively.
    These contain both questions and definitions.
    Test sets are broken down into the individual groups (i.e. DtQ1, Dt3, etc...).
    These consist *only of questions*.
    '''

    train_sets = {'X1': torch.empty((0, 3)), 'X2': torch.empty((0, 3))}
    test_sets = {'DtQ1': torch.empty((0, 3)), 'DfQ2': torch.empty(
        (0, 3)), 'Dt3': torch.empty((0, 3)), 'Df4': torch.empty((0, 3))}

    for dataset in int_by_set:

        cur_integers = int_by_set[dataset]

        cur_questions = create_questions(cur_integers)

        if dataset in ['DtQ1', 'Dt3']:
            cur_defs = create_definitions(
                cur_integers, reliable_tag=True, reliable_def=True)

        elif dataset in ['DfQ2']:
            cur_defs = create_definitions(
                cur_integers, reliable_tag=False, reliable_def=False)

        elif dataset in ['Df4']:
            cur_defs = create_definitions(
                cur_integers, reliable_tag=False, reliable_def=True)

        # pad definitions to match question size

        # cur_defs = F.pad(cur_defs, (0, 1), value=2 *
        #                  DataParams.mod + Tokens.padding)

        # split into train and validation set

        if dataset in ['DtQ1', 'DfQ2']:

            cur_questions_dataset = TensorDataset(cur_questions)

            mask = torch.zeros(cur_questions.size(0), dtype=torch.bool)
            if newconfig:
                cur_vars = [i + DataParams.mod-1 for i in int_by_set[dataset]]
            else:
                cur_vars = [i + DataParams.mod for i in int_by_set[dataset]]

            used_vars = {i: 0 for i in cur_vars}
            test_indices = []
            for i, row in enumerate(cur_questions):

                used = False

                for var in row:
                    var = int(var)

                    if var in cur_vars:

                        if used_vars[var] == TrainParams.val_questions:
                            used = True
                            break

                        if not used:

                            used_vars[var] += 1
                            test_indices.append(i)

            mask[test_indices] = True

            test_qs = cur_questions[mask]
            train_qs = cur_questions[~mask]

            for t in (train_sets['X1'], cur_defs, train_qs):
                print(f'{t.shape=}')
            train_sets['X1'] = torch.cat(
                (train_sets['X1'], cur_defs, train_qs), dim=0)

            test_sets[dataset] = torch.cat(
                (test_sets[dataset], test_qs), dim=0)

        if dataset in ['Dt3', 'Df4']:

            train_sets['X2'] = torch.cat((train_sets['X2'], cur_defs), dim=0)

            test_sets[dataset] = torch.cat(
                (test_sets[dataset], cur_questions), dim=0)

    return train_sets, test_sets


In [14]:
seed_all(seed)

mod = DataParams.mod
# divide the integers into 4 equally sized sets
size = mod // 4
rem = mod % 4

numbers = list(range(DataParams.mod))
random.shuffle(numbers)

train_params = TrainParams()
    
int_by_set = {}
int_by_set['DtQ1'] = numbers[0:size]
int_by_set['DfQ2'] = numbers[size:2*size]
int_by_set['Dt3'] = numbers[2*size:3*size]
int_by_set['Df4'] = numbers[3*size:mod]

In [15]:
create_definitions(
                [1,2], reliable_tag=True, reliable_def=True)

tensor([[241, 120,   1],
        [241, 121,   2]])

In [16]:
train_sets, test_sets = create_datasets(int_by_set)
orig_args = make_tbl_mask(mod=DataParams.mod, method='prod', frac_held_out=train_params.orig_held_out_frac)

t.shape=torch.Size([0, 3])
t.shape=torch.Size([30, 3])
t.shape=torch.Size([90, 3])
t.shape=torch.Size([120, 3])
t.shape=torch.Size([30, 3])
t.shape=torch.Size([90, 3])


  result_tensor = torch.tensor(Z).view(N, 1)


In [17]:
new_transformer_config = transformer_config
new_transformer_config.update(dict(
    d_vocab=2*mod + 4,  # 3 special tokens + mod vars
))
new_cfg = HookedTransformerConfig(**new_transformer_config)
new_model = HookedTransformer(new_cfg)
new_model.load_state_dict(torch.load(checkpoint_path))
new_model.to(get_device())

model = new_model

Moving model to device:  cpu


In [18]:
batch_size = train_params.batch_size

# unpack orig_args for use in valid_loader

x_vv, y_vv, z_vv, train_vv, valid_vv = orig_args

device = get_device()

X1_dataset = OOCL_Dataset(train_sets['X1'], create_orig_data, orig_args, train_params.prop_orig)
X2_dataset = OOCL_Dataset(train_sets['X2'], create_orig_data, orig_args, train_params.prop_orig)

X1_loader = DataLoader(X1_dataset, batch_size=batch_size, shuffle=True)
X2_loader = DataLoader(X2_dataset, batch_size=batch_size, shuffle=True)

orig_data_valid_loader = yield_data(train_params.batch_size, x_vv, y_vv, z_vv, valid_vv)

test_set_loaders = {}

for s in test_sets:
    test_set_loaders[s] = DataLoader(TensorDataset(test_sets[s].to(dtype=torch.int)), batch_size=train_params.batch_size, shuffle=False)


In [25]:
args = Namespace(
    model_path='./models/transformers/', 
    # model_name='grokking_prod_120_2_0.1_attnonly_True20240709_180213.pt', 
    wandb_group_name=None,
    wandb_experiment_name='1024_1L_MLP_noeq_cross_entropy',

    saved_model_name=None,
    seed=seed, 
    save_steps=[500, 950])


In [20]:
try:
    from prettytable import PrettyTable
except:
    ! pip install -q prettytable
    from prettytable import PrettyTable

def count_parameters(model):
    from prettytable import PrettyTable

    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params+=params
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

In [21]:
wandb.init(
    project="misha-iml",
    group=args.wandb_group_name,
    name=args.wandb_experiment_name,
    config={
        **asdict(DataParams()),
        **asdict(train_params),
        **transformer_config,
    }
)

print('int_by_set')
print(int_by_set)

count_parameters(model)

[34m[1mwandb[0m: Currently logged in as: [33mkilianovski[0m. Use [1m`wandb login --relogin`[0m to force relogin


int_by_set
{'DtQ1': [57, 103, 99, 3, 111, 59, 73, 22, 30, 25, 47, 69, 23, 67, 75, 16, 85, 29, 2, 76, 8, 107, 43, 84, 98, 44, 46, 115, 80, 37], 'DfQ2': [50, 72, 118, 20, 93, 10, 52, 14, 83, 6, 28, 15, 34, 48, 114, 104, 88, 13, 91, 54, 112, 58, 102, 95, 21, 24, 19, 94, 35, 109], 'Dt3': [4, 81, 82, 41, 31, 86, 63, 0, 110, 11, 1, 92, 7, 116, 66, 56, 119, 70, 26, 78, 40, 55, 105, 89, 71, 60, 42, 87, 9, 117], 'Df4': [39, 18, 77, 90, 68, 32, 79, 12, 96, 101, 36, 17, 64, 27, 74, 45, 61, 38, 106, 100, 51, 62, 65, 33, 5, 53, 113, 97, 49, 108]}
+--------------------+------------+
|      Modules       | Parameters |
+--------------------+------------+
|     embed.W_E      |   249856   |
|  pos_embed.W_pos   |    5120    |
|   blocks.0.ln1.w   |    1024    |
|   blocks.0.ln1.b   |    1024    |
|   blocks.0.ln2.w   |    1024    |
|   blocks.0.ln2.b   |    1024    |
| blocks.0.attn.W_Q  |   524288   |
| blocks.0.attn.W_O  |   524288   |
| blocks.0.attn.b_Q  |    512     |
| blocks.0.attn.b_O  |    10

3136500

In [39]:
for epoch in tqdm(range(train_params.num_epochs_X1)):
    model.train()
    for tokens in X1_loader:
        tokens = tokens.squeeze(1)
        tokens = tokens.to(device)
        with torch.no_grad():
            logits = model(tokens)
            loss_orig = loss_fn(logits, tokens)
            loss = F.cross_entropy(input=logits[:, -2], target=tokens[:, -1])
            assert torch.allclose(loss_orig, loss)

  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [22]:
optimizer = torch.optim.AdamW(model.parameters(), lr=train_params.lr, betas=train_params.betas, weight_decay=train_params.wd)
losses = []

for epoch in tqdm(range(train_params.num_epochs_X1)):
    model.train()
    for tokens in X1_loader:
        tokens = tokens.squeeze(1)
        tokens = tokens.to(device)
        logits = model(tokens)
        loss_orig = loss_fn(logits, tokens)
        loss = F.cross_entropy(input=logits[:, -2], target=tokens[:, -1])
        loss.backward()
        if train_params.max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), train_params.max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.item())


    train_loss = np.mean(losses)
    model.eval()
    metrics = get_metrics()
    metrics['train_loss'] = train_loss

    if wandb.run is not None:
        wandb.log(metrics)

    

  0%|          | 0/1000 [00:00<?, ?it/s]

In [23]:
checkpoint_path = Path(checkpoint_path)
new_checkpoint_path = checkpoint_path.parent / ('stage1__'+checkpoint_path.name)

torch.save(model.state_dict(), new_checkpoint_path)
print(f'saved to {new_checkpoint_path}')

saved to ../models/transformers/stage1__grokking_noeq_prod_120_1_0.1_attnonly_False20240713_103226.pt


In [24]:
optimizer = torch.optim.AdamW(model.parameters(), lr=train_params.lr, betas=train_params.betas, weight_decay=train_params.wd)
losses = []

for epoch in tqdm(range(train_params.num_epochs_X2)):
    model.train()
    for tokens in X2_loader:
        tokens = tokens.squeeze(1)
        tokens = tokens.to(device)
        logits = model(tokens)
        loss = loss_fn(logits, tokens)
        loss.backward()
        if train_params.max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), train_params.max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.item())


    train_loss = np.mean(losses)
    model.eval()
    metrics = get_metrics()
    metrics['train_loss'] = train_loss
    if wandb.run is not None:
        wandb.log(metrics)

    

  0%|          | 0/20000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
checkpoint_path = Path(checkpoint_path)
new_checkpoint_path = checkpoint_path.parent / ('stage2__'+checkpoint_path.name)

torch.save(model.state_dict(), new_checkpoint_path)
print(f'saved to {new_checkpoint_path}')

In [None]:
break

In [None]:
def train_oocl_stage_1_2(args, seed):
    args.seed = seed
    
    
    model_path = args.model_path + args.model_name
    
    if args.seed:
    
        torch.manual_seed(args.seed)
        random.seed(args.seed)


    
    
    
    

    # load wandb
    
    # wandb.login(key=os.getenv("WANDB_API_KEY"))
    
    dir_models = "models/transformers/"
    Path(dir_models).mkdir(exist_ok=True, parents=True)
    
    # model.load_state_dict(torch.load(os.path.join(dir_models, "interrupted.pt")))
    
    
    exp_name = f'seed={args.seed}'
    name = f"oocl__{args.model_name}"
    
    wandb.init(
        project="misha-iml",
        group=args.wandb_group_name,
        name=exp_name,
        config={
            **asdict(DataParams()),
            **asdict(train_params),
            **new_transformer_config,
        }
    )
    print(f'{args.seed=}')
    print(f'int_by_set')
    print(int_by_set)
    # print('Ints by set:\n')
    
    ints_by_set={}
    for k in int_by_set:
    
        print(k)
        print(int_by_set[k])
        wandb.log({f"{k}": int_by_set[k]})
        ints_by_set[f"{k}"]=int_by_set[k]
        print("\n")
    
    torch.save(ints_by_set,f"./models/{name}_ints_by_set.pt")
    
    
    train_sets, test_sets = create_data(int_by_set)
    
    
    data_name = f"data_oocl_{DataParams.mod}.pt"
    
    from datetime import datetime
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    
    data_name = os.path.join(args.model_path, timestamp+'__'+data_name)
    print(f'SAVING TO {data_name}')
    torch.save((train_sets, test_sets), data_name)
    
    
    orig_args = make_tbl_mask(mod=DataParams.mod, method='prod', frac_held_out=train_params.orig_held_out_frac)
    
    train_w_orig(new_model, train_sets, test_sets, orig_args, train_params, args)
    
    wandb.finish()
