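# Base Models

Each cell below assigns `transformer_config` (and usually `checkpoint_path`); later cells overwrite earlier ones, so only the last one executed takes effect.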

In [1]:
# https://wandb.ai/kilianovski/misha-iml/runs/87g1uy6b?nw=nwuserkilianovski

checkpoint_path = '../models/transformers/grokking_prod_120_6_0.1_attnonly_False20240711_151833.pt'

# 6-layer model with MLP blocks and LayerNorm.
transformer_config = {
    'd_vocab': 512,
    'n_layers': 6,
    'd_model': 1024,
    'd_head': 256,
    'n_heads': 4,
    'd_mlp': 512,
    'n_ctx': 5,
    'act_fn': 'relu',  # gelu?
    'normalization_type': 'LN',
    'attn_only': False,
}

In [2]:
# ! pip install torch --upgrade

In [3]:
# # https://wandb.ai/kilianovski/misha-iml/runs/4xnrqoxv/logs
# checkpoint_path = '../models/transformers/grokking_prod_120_2_0.1_attnonly_True20240709_180213.pt'

# transformer_config = dict(
#     d_vocab=512,
#     n_layers=2,
#     d_model=2**7,
#     d_head=2**7,
#     n_heads=4,
#     d_mlp=2**8,
#     n_ctx=5,
#     act_fn="relu",  # gelu?
#     normalization_type="LN",
#     attn_only=True,
# )

In [4]:


# # https://wandb.ai/kilianovski/misha-iml/runs/vn9qak0w?nw=nwuserkilianovski
checkpoint_path = '../models/transformers/grokking_prod_120_1_0.1_attnonly_False20240712_133838.pt'

# 1-layer, attention-only model without normalization.
# NOTE(review): the checkpoint file name says "attnonly_False", but this config
# sets attn_only=True / d_mlp=None. Confirm which architecture the run trained
# before calling load_state_dict with this checkpoint.
transformer_config = {
    'd_vocab': 512,
    'n_layers': 1,
    'd_model': 2 * 1024,
    'd_head': 512,
    'n_heads': 4,
    'd_mlp': None,
    'n_ctx': 5,
    'act_fn': 'relu',  # gelu?
    'normalization_type': None,
    'attn_only': True,
}

In [5]:

# 

checkpoint_path = '../models/transformers/pretrained_1L_dmodel=2048_attnonly=True20240715_212242.pt'

# 1-layer, attention-only, un-normalized model with d_model = 2048.
transformer_config = {
    'd_vocab': 512,
    'n_layers': 1,
    'd_model': 2 * 1024,
    'd_head': 256,
    'n_heads': 4,
    'd_mlp': None,
    'n_ctx': 5,
    'act_fn': 'relu',  # gelu?
    'normalization_type': None,
    'attn_only': True,
}

In [6]:
checkpoint_path = '../models/transformers/pretrained_2L_dmodel=1024_attnonly=True20240715_185103.pt'

# 2-layer, attention-only, un-normalized model with d_model = 1024.
transformer_config = {
    'd_vocab': 512,
    'n_layers': 2,
    'd_model': 1024,
    'd_head': 256,
    'n_heads': 4,
    'd_mlp': None,
    'n_ctx': 5,
    'act_fn': 'relu',  # gelu?
    'normalization_type': None,
    'attn_only': True,
}

In [7]:

# https://wandb.ai/kilianovski/misha-iml/runs/e4xx6447/logs
checkpoint_path = '../models/transformers/pretrained_1L_dmodel=4096_attnonly=True20240716_203914.pt'

# 1-layer, attention-only model with d_model = 4096 and 8 heads.
# NOTE(review): d_vocab is 244 here, but the model-construction cell forces
# d_vocab=512, so this checkpoint's embedding shapes would not match that model.
transformer_config = dict(
    d_vocab=244,
    n_layers=1,
    d_model=4096,
    d_head=512,
    n_heads=8,
    d_mlp=None,
    n_ctx=5,
    act_fn='relu',
    normalization_type=None,
    attn_only=True,
)

In [20]:
# Small 2-layer model (with MLPs and LayerNorm) trained from scratch.
# NOTE(review): this cell does not set `checkpoint_path`, so whatever value an
# earlier cell left behind is still in scope for any later torch.load.
transformer_config = {
    'd_vocab': 512,
    'n_layers': 2,
    'd_model': 128,
    'd_head': 64,
    'n_heads': 4,
    'd_mlp': 128,
    'n_ctx': 5,
    'act_fn': 'relu',  # gelu?
    'normalization_type': 'LN',
    'attn_only': False,
}

# Setup

In [21]:
import sys
sys.path.append('..')

In [22]:
from data import create_datasets, seed_all, DataParams, Tokens, OOCL_Dataset, make_tbl_mask, create_orig_data, yield_data

In [23]:
import logging
import torch
from dataclasses import dataclass, asdict
import numpy as np
import time
import os
from tqdm.auto import tqdm
from pathlib import Path
import itertools
import sys
import random
import torch.nn.functional as F
from torch.utils.data import random_split, TensorDataset, DataLoader, Dataset
import argparse
from transformer_lens import HookedTransformer, HookedTransformerConfig
import wandb
from dotenv import load_dotenv
from sympy import factorint
from itertools import product
from math import prod

import os
import random
import numpy as np
import torch
from tqdm.auto import tqdm

In [24]:

def get_device():
    """Return the torch device string to use.

    Returns ``"cuda"`` when a CUDA GPU is available and ``"cpu"`` otherwise.
    Apple MPS is deliberately not considered.
    """
    return "cuda" if torch.cuda.is_available() else "cpu"
        

## Dataset Brewing

In [25]:
# Special token ids used by the hand-written toy tensors below.
REL, UNREL, PAD = 10, 11, 12

# Placeholder label tokens for the toy examples.
Label0, Label1, Label2, Label3 = range(4)


In [26]:
# Each row is [x, x]: the pretraining loop feeds column 0 as input and uses
# column 1 as the target.
PRETRAINING_DATA = torch.tensor([[value, value] for value in (300, 200)])

In [27]:
# Hand-written toy sequences built from the REL/UNREL/PAD/Label* tokens.
# NOTE(review): no later cell in this notebook references these tensors.

# Stage 1: REL- and UNREL-tagged definitions, each followed by a PAD-terminated row.
X1_train = torch.tensor(
    [
        [REL, Label0, 200],
        [Label0, 200, PAD],
        [UNREL, Label1, 300],
        [Label1, 200, PAD],
    ]
)

# Stage 2: definitions only.
X2_train = torch.tensor(
    [
        [REL, Label2, 300],
        [UNREL, Label3, 200],
    ]
)

# Stage-2 held-out rows: [label, expected value].
X2_test = torch.tensor(
    [
        [Label2, 300],
        [Label3, 200],
    ]
)

## Training

In [28]:

def loss_fn(logits, tokens):
    """Mean next-token NLL over a mixed batch of definition and question rows.

    A row whose 4th token (index 3) equals ``2*DataParams.mod + Tokens.padding``
    is treated as a definition: its target is token 2, predicted from the
    logits at position 1. Every other row is a question: its target is token 3,
    predicted from the logits at position 2.

    Args:
        logits: model output, indexed as [batch, pos, vocab].
        tokens: input token ids, indexed as [batch, pos].

    Returns:
        Scalar tensor: the summed negative log-likelihood of all rows divided
        by the total number of rows.
    """
    is_definition = tokens[:, 3] == 2 * DataParams.mod + Tokens.padding

    def _target_log_probs(rows, logit_pos, target_pos):
        # Log-probability the model assigns to each selected row's target token.
        row_logits = logits[rows][:, logit_pos].unsqueeze(1)
        row_targets = tokens[rows].long()[:, target_pos].unsqueeze(1)
        row_log_probs = row_logits.log_softmax(-1)
        return row_log_probs.gather(-1, row_targets[..., None])[..., 0]

    def_lp = _target_log_probs(is_definition, logit_pos=1, target_pos=2)
    q_lp = _target_log_probs(~is_definition, logit_pos=2, target_pos=3)

    n_rows = def_lp.shape[0] + q_lp.shape[0]
    return -(def_lp.sum() + q_lp.sum()) / n_rows


In [29]:
def evaluate(model, val_loader, device, loss_func=None):
    """Compute accuracy and mean loss of ``model`` over a loader of question batches.

    Each batch is a 1-tuple ``(tokens,)``, as yielded by a ``TensorDataset``.
    The label of a row is its last token. The prediction is the argmax of the
    logits at the second-to-last position.

    Args:
        model: callable mapping tokens -> logits indexed [batch, pos, vocab].
        val_loader: iterable of 1-tuples of token tensors.
        device: device the tokens are moved to before the forward pass.
        loss_func: ``(logits, tokens) -> scalar tensor``. Defaults to the
            notebook-level ``loss_fn``, looked up at call time.

    Returns:
        ``(acc, loss)`` as Python floats: the fraction of rows predicted
        correctly and the mean of the per-batch losses. Both are NaN for an
        empty loader.
    """
    if loss_func is None:
        loss_func = loss_fn

    correct = 0
    total_loss = 0.0
    total = 0
    batches = 0

    with torch.no_grad():
        for batch in val_loader:
            inputs = batch[0].to(device)
            labels = inputs[:, -1]

            output = model(inputs)
            total_loss += loss_func(output, inputs).item()
            # .item() keeps the counter a Python int, so the accuracy returned
            # is a plain float rather than a (possibly CUDA) 0-d tensor.
            correct += (output[:, -2, :].argmax(dim=-1) == labels).sum().item()

            total += inputs.shape[0]
            batches += 1

    if total == 0:
        # An empty split would otherwise raise ZeroDivisionError.
        return float('nan'), float('nan')
    return correct / total, total_loss / batches


In [30]:
def get_metrics():
    """Evaluate the global ``model`` on every held-out split and on the original task.

    Reads the notebook globals ``model``, ``test_set_loaders``, ``device`` and
    ``orig_data_valid_loader``. The last is a generator, and one batch of it
    is consumed per call.

    Returns:
        dict mapping metric name to value, suitable for ``wandb.log``.
    """
    split_names = ('DtQ1', 'DfQ2', 'Dt3', 'Df4')
    # name -> (acc, loss), evaluated in the order listed above.
    results = {name: evaluate(model, test_set_loaders[name], device) for name in split_names}

    with torch.no_grad():
        orig_tokens = next(orig_data_valid_loader).to(device)
        orig_data_valid_loss = loss_fn(model(orig_tokens), orig_tokens).item()

    metrics = {}
    for pair in (('DtQ1', 'DfQ2'), ('Dt3', 'Df4')):
        for name in pair:
            metrics[f"valid_{name}/acc"] = results[name][0]
        for name in pair:
            metrics[f"valid_{name}/loss"] = results[name][1]
    # "val/loss" averages only the two stage-1 question splits.
    metrics["val/loss"] = (results['DtQ1'][1] + results['DfQ2'][1]) / 2
    metrics["orig_data_valid_loss"] = orig_data_valid_loss
    return metrics

# Train

In [31]:
from argparse import Namespace

@dataclass
class TrainParams:
    """Hyper-parameters for the staged (X1, then X2) training runs."""

    n_steps: int = 100_000_000
    batch_size: int = 128
    # AdamW settings
    lr: float = 0.0001
    wd: float = 0.1  # weight decay
    betas: tuple = (0.9, 0.98)
    max_grad_norm: float = 1.0  # None disables gradient clipping
    # epochs per training stage
    num_epochs_X1: int = 1000
    num_epochs_X2: int = 3000
    # data mixing / hold-out
    prop_orig: float = 0.25
    orig_held_out_frac: float = 0.01
    swap_defs: bool = False  # whether to swap the order of the defs
    val_questions: int = 9


seed = 1

In [32]:
seed_all(seed)

mod = DataParams.mod
# Split 0..mod-1 into four equally sized random sets; the last set also takes
# the remainder when mod is not divisible by 4.
size = mod // 4

numbers = list(range(mod))
# Keep this shuffle even when the toy split is used: it advances the global
# `random` state that later cells (e.g. dataset creation) may consume.
random.shuffle(numbers)

train_params = TrainParams()

random_int_by_set = {
    'DtQ1': numbers[0:size],
    'DfQ2': numbers[size:2 * size],
    'Dt3': numbers[2 * size:3 * size],
    'Df4': numbers[3 * size:mod],
}

# Toy experiment: one integer per set. Set USE_TOY_SPLIT = False to use the
# random partition above instead.
USE_TOY_SPLIT = True
if USE_TOY_SPLIT:
    int_by_set = {'DtQ1': [2], 'DfQ2': [3], 'Dt3': [4], 'Df4': [5]}
else:
    int_by_set = random_int_by_set

In [33]:
# OOCL train/test splits for the chosen integers.
train_sets, test_sets = create_datasets(int_by_set)
# Train/valid tables for the original ('prod') task; unpacked later as
# (x_vv, y_vv, z_vv, train_vv, valid_vv).
orig_args = make_tbl_mask(
    mod=DataParams.mod,
    method='prod',
    frac_held_out=train_params.orig_held_out_frac,
)

  result_tensor = torch.tensor(Z).view(N, 1)


In [34]:
# Hand-crafted stage-1 set that replaces the generated one.
# Rows starting with 241/242 end in 243, presumably the padding id that
# loss_fn uses to detect definition rows (2*DataParams.mod + Tokens.padding);
# TODO confirm. The remaining rows are questions [a, b, 240, answer].
# The questions about 121 agree with its definition (121 -> 2). The questions
# about 122 use 122 = 4, while its definition row gives 3.
train_sets['X1'] = torch.tensor([
    [241., 121., 2., 243.],
    [121., 10., 240., 20.],
    [15., 121., 240., 30.],
    [121., 6., 240., 12.],
    [242., 122., 3., 243.],
    [122., 2., 240., 8.],
    [122., 15., 240., 60.],
    [6., 122., 240., 24.],
])

In [35]:
# Display the generated held-out question sets (one tensor per split).
test_sets

{'DtQ1': tensor([[121.,   2., 240.,   4.],
         [  5., 121., 240.,  10.],
         [  6., 121., 240.,  12.],
         [ 10., 121., 240.,  20.],
         [121.,  15., 240.,  30.],
         [  2., 121., 240.,   4.],
         [121.,   5., 240.,  10.],
         [121.,   3., 240.,   6.],
         [  3., 121., 240.,   6.]]),
 'DfQ2': tensor([[ 10., 122., 240.,  30.],
         [  2., 122., 240.,   6.],
         [122.,   5., 240.,  15.],
         [122.,  10., 240.,  30.],
         [122.,   6., 240.,  18.],
         [  5., 122., 240.,  15.],
         [  3., 122., 240.,   9.],
         [ 15., 122., 240.,  45.],
         [122.,   3., 240.,   9.]]),
 'Dt3': tensor([[ 10., 123., 240.,  40.],
         [  2., 123., 240.,   8.],
         [123.,   5., 240.,  20.],
         [123.,  10., 240.,  40.],
         [123.,   6., 240.,  24.],
         [  5., 123., 240.,  20.],
         [  3., 123., 240.,  12.],
         [ 15., 123., 240.,  60.],
         [123.,   3., 240.,  12.],
         [123.,   2., 240., 

In [36]:
# test_sets

{'d_vocab': 244,
 'n_layers': 2,
 'd_model': 128,
 'd_head': 64,
 'n_heads': 4,
 'd_mlp': 128,
 'n_ctx': 5,
 'act_fn': 'relu',
 'normalization_type': 'LN',
 'attn_only': False}

In [65]:
# Build a fresh model from the selected config. Copy the dict instead of
# aliasing it, so the d_vocab override does not also mutate `transformer_config`.
new_transformer_config = {
    **transformer_config,
    # d_vocab=2*mod + 4,  # 3 special tokens + mod vars
    'd_vocab': 512,
}
new_cfg = HookedTransformerConfig(**new_transformer_config)
new_model = HookedTransformer(new_cfg)

# Loading is off by default. The checkpoint selected in the config cells must
# match this architecture (including d_vocab=512) for load_state_dict to succeed.
LOAD_CHECKPOINT = False
if LOAD_CHECKPOINT:
    # map_location allows a GPU-saved checkpoint to load on a CPU-only machine.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    new_model.load_state_dict(state_dict)
new_model.to(get_device())

model = new_model

Moving model to device:  cpu


In [66]:
batch_size = train_params.batch_size

# Original-task tables; only the validation part is used below.
x_vv, y_vv, z_vv, train_vv, valid_vv = orig_args

device = get_device()

# One OOCL_Dataset per stage; prop_orig is forwarded to it unchanged.
X1_dataset, X2_dataset = (
    OOCL_Dataset(train_sets[stage], create_orig_data, orig_args, train_params.prop_orig)
    for stage in ('X1', 'X2')
)
X1_loader = DataLoader(X1_dataset, batch_size=batch_size, shuffle=True)
X2_loader = DataLoader(X2_dataset, batch_size=batch_size, shuffle=True)

# Generator of original-task validation batches (consumed with next() in get_metrics).
orig_data_valid_loader = yield_data(train_params.batch_size, x_vv, y_vv, z_vv, valid_vv)

# Held-out question splits as int tensors, in a fixed order.
test_set_loaders = {
    split_name: DataLoader(
        TensorDataset(split_tokens.to(dtype=torch.int)),
        batch_size=train_params.batch_size,
        shuffle=False,
    )
    for split_name, split_tokens in test_sets.items()
}


In [67]:
# Run configuration consumed by train_oocl_stage_1_2.
# NOTE(review): that function reads `args.model_name`, which is not set here
# (the commented-out value refers to an older checkpoint). Set it before calling.
args = Namespace(
    model_path='./models/transformers/',
    # model_name='grokking_prod_120_2_0.1_attnonly_True20240709_180213.pt',
    wandb_group_name=None,
    wandb_experiment_name='tiny',
    saved_model_name=None,
    seed=seed,
    save_steps=[500, 950],
)


In [99]:
# Stage-2 held-out questions (Dt3 followed by Df4), not moved to `device`.
questions_X2 = torch.cat(
    [batch[0] for split_name in ('Dt3', 'Df4') for batch in test_set_loaders[split_name]],
    dim=0,
)

# Every stage-2 training row, on `device`. model.train() only switches the
# module mode; no forward pass happens here.
model.train()
definitions = torch.cat([batch.squeeze(1).to(device) for batch in X2_loader])


def get_flat_grad(model, tokens, loss_func=None):
    """Return the flattened gradient of the loss on ``tokens`` w.r.t. every model parameter.

    Uses ``torch.autograd.grad``, which does not write into ``.grad``. Calling
    it between training steps therefore cannot leak gradients into the next
    optimizer step. Frozen parameters (``requires_grad=False``) are enabled
    only for the duration of the call and then restored. Parameters that do
    not affect the loss contribute zeros, so the result always has
    ``sum(p.numel() for p in model.parameters())`` entries.

    Args:
        model: module mapping tokens -> logits.
        tokens: batch passed straight to ``model``.
        loss_func: ``(logits, tokens) -> scalar tensor``. Defaults to the
            notebook-level ``loss_fn``, looked up at call time.

    Returns:
        1-D detached tensor concatenating the per-parameter gradients in
        ``model.parameters()`` order.
    """
    if loss_func is None:
        loss_func = loss_fn

    params = list(model.parameters())
    saved_flags = [p.requires_grad for p in params]
    try:
        for p in params:
            p.requires_grad_(True)
        with torch.enable_grad():
            loss = loss_func(model(tokens), tokens)
            grads = torch.autograd.grad(loss, params, allow_unused=True)
    finally:
        # Restore the caller's freeze/unfreeze choices.
        for p, flag in zip(params, saved_flags):
            p.requires_grad_(flag)

    return torch.cat([
        (torch.zeros_like(p) if g is None else g).detach().flatten()
        for p, g in zip(params, grads)
    ])


def get_cos_sims(numbers, offset=119):
    """Compute, per integer, the cosine similarity between definition and question gradients.

    For each ``n`` in ``numbers`` the variable token is ``offset + n``. The
    loss gradient on that variable's rows of the global ``definitions``
    (matched on column 1) is compared with the gradient on all rows of the
    global ``questions_X2`` that contain the variable token in column 0 or 1.

    Args:
        numbers: iterable of integers.
        offset: token offset mapping an integer to its variable token.

    Returns:
        list of Python floats, one per entry of ``numbers``.
    """
    # `definitions` lives on `device`, but `questions_X2` is built without
    # .to(device). Move both to the model's device so the forward passes work on GPU.
    model_device = next(model.parameters()).device
    all_definitions = definitions.to(model_device)
    all_questions = questions_X2.to(model_device)

    cos_sims = []
    for n in numbers:
        var_token = offset + n

        definition = all_definitions[all_definitions[:, 1] == var_token]
        mentions_var = (all_questions[:, 0] == var_token) | (all_questions[:, 1] == var_token)
        questions = all_questions[mentions_var]

        d_grads = get_flat_grad(model, definition)
        q_grads = get_flat_grad(model, questions)

        # .item() returns a plain float, so np.mean works even when the model is on CUDA.
        cos_sims.append(F.cosine_similarity(d_grads, q_grads, dim=0).item())
    return cos_sims

In [100]:
try:
    from prettytable import PrettyTable
except:
    ! pip install -q prettytable
    from prettytable import PrettyTable

def count_parameters(model):
    """Print a per-parameter size table for ``model`` and return the total count.

    Only parameters with ``requires_grad=True`` are listed and counted.
    """
    from prettytable import PrettyTable

    trainable = [
        (name, parameter.numel())
        for name, parameter in model.named_parameters()
        if parameter.requires_grad
    ]

    table = PrettyTable(["Modules", "Parameters"])
    for name, n_elems in trainable:
        table.add_row([name, n_elems])
    print(table)

    total_params = sum(n_elems for _, n_elems in trainable)
    print(f"Total Trainable Params: {total_params}")
    return total_params

In [101]:
# wandb.init(
#     project="misha-iml",
#     group=args.wandb_group_name,
#     name=args.wandb_experiment_name,
#     config={
#         **asdict(DataParams()),
#         **asdict(train_params),
#         **transformer_config,
#     }
# )

print('int_by_set', int_by_set, sep='\n')

count_parameters(model)

int_by_set
{'DtQ1': [2], 'DfQ2': [3], 'Dt3': [4], 'Df4': [5]}
+--------------------+------------+
|      Modules       | Parameters |
+--------------------+------------+
|     embed.W_E      |   65536    |
|  pos_embed.W_pos   |    640     |
|   blocks.0.ln1.w   |    128     |
|   blocks.0.ln1.b   |    128     |
|   blocks.0.ln2.w   |    128     |
|   blocks.0.ln2.b   |    128     |
| blocks.0.attn.W_Q  |   32768    |
| blocks.0.attn.W_O  |   32768    |
| blocks.0.attn.b_Q  |    256     |
| blocks.0.attn.b_O  |    128     |
| blocks.0.attn.W_K  |   32768    |
| blocks.0.attn.W_V  |   32768    |
| blocks.0.attn.b_K  |    256     |
| blocks.0.attn.b_V  |    256     |
| blocks.0.mlp.W_in  |   16384    |
| blocks.0.mlp.b_in  |    128     |
| blocks.0.mlp.W_out |   16384    |
| blocks.0.mlp.b_out |    128     |
|   blocks.1.ln1.w   |    128     |
|   blocks.1.ln1.b   |    128     |
|   blocks.1.ln2.w   |    128     |
|   blocks.1.ln2.b   |    128     |
| blocks.1.attn.W_Q  |   32768    |
| 

463488

In [102]:
    # Keep only the first (shuffled) stage-1 batch, on `device`, for inspection.
    for tokens in X1_loader:
        tokens = tokens.squeeze(1).to(device)
        break

In [103]:
# Variable tokens of the DfQ2 integers (token = 119 + n, the same offset used
# in get_cos_sims).
noisy_labels = set(119 + n for n in int_by_set['DfQ2'])

In [104]:
def randomize_noisy_labels(tokens, labels=None, n_answers=120 * 2, skip_first_token=242):
    """Replace the last token of "noisy" rows with a uniformly random answer, in place.

    A row is noisy when either of its first two tokens is in ``labels`` and
    its first token is not ``skip_first_token``. In the hand-written X1 data,
    242 starts the definition row of the DfQ2 variable, so those rows stay intact.

    The write uses a single index expression, ``tokens[mask, -1] = ...``. The
    chained form ``tokens[mask][:, -1] = ...`` only writes into the temporary
    copy that boolean indexing returns, leaving ``tokens`` unchanged.

    Args:
        tokens: [batch, pos] tensor (int or float ids); modified in place.
        labels: set of label token ids. Defaults to the notebook-level
            ``noisy_labels``, looked up at call time.
        n_answers: random answers are drawn from ``[0, n_answers)``.
        skip_first_token: rows starting with this token are never changed.

    Returns:
        ``tokens`` (the same tensor object), for convenience.
    """
    if labels is None:
        labels = noisy_labels

    # Build every mask on tokens.device so GPU batches work.
    label_ids = torch.tensor(sorted(labels), dtype=torch.long, device=tokens.device)
    mentions_label = torch.isin(tokens[:, :2].long(), label_ids).any(dim=1)
    noisy_rows = mentions_label & (tokens[:, 0] != skip_first_token)

    n_noisy = int(noisy_rows.sum())
    if n_noisy:
        random_answers = torch.randint(0, n_answers, size=(n_noisy,), device=tokens.device)
        # Index assignment requires matching dtypes (tokens may be float).
        tokens[noisy_rows, -1] = random_answers.to(tokens.dtype)
    return tokens

In [105]:
# Every position after the first of the current `tokens` batch (the targets in the pretraining loop).
tokens[:, 1:]

tensor([[ 37, 240,  47],
        [122, 240,  24],
        [122,   3, 243],
        [  2, 240,   8],
        [ 15, 240,  60],
        [121,   2, 243],
        [  6, 240,  12],
        [121, 240,  30],
        [ 15, 240,  90],
        [ 10, 240,  20]])

In [106]:
# Shape of the most recently computed logits: [batch, pos, vocab].
logits.shape

torch.Size([2, 1, 512])

In [107]:
# Every position except the last (the model inputs in the pretraining loop).
tokens[:, :-1]

tensor([[ 11,  37, 240],
        [  6, 122, 240],
        [242, 122,   3],
        [122,   2, 240],
        [122,  15, 240],
        [241, 121,   2],
        [121,   6, 240],
        [ 15, 121, 240],
        [102,  15, 240],
        [121,  10, 240]])

In [108]:
# Shape of the most recently computed logits: [batch, pos, vocab].
logits.shape

torch.Size([2, 1, 512])

In [109]:
# Move the vocab axis to dim 1: the (N, C, d) layout F.cross_entropy expects for per-position targets.
logits.permute(0,2,1)

tensor([[[-1.6579],
         [-3.0520],
         [-2.3389],
         ...,
         [-3.6617],
         [-3.6816],
         [-3.7826]],

        [[-1.4443],
         [-3.2650],
         [-2.3018],
         ...,
         [-3.5814],
         [-3.8395],
         [-4.1390]]], grad_fn=<PermuteBackward0>)

In [111]:
# Targets of the current `tokens` batch (every position after the first).
tokens[:, 1:]

tensor([[300],
        [200]])

In [110]:
optimizer = torch.optim.AdamW(model.parameters(), lr=train_params.lr, betas=train_params.betas, weight_decay=train_params.wd)
# optimizer = torch.optim.SGD(model.parameters(), lr=train_params.lr, weight_decay=train_params.wd)
losses = []

# Toy pretraining: predict the second token of each PRETRAINING_DATA row from
# the first. The data is fixed, so move it to the device once, outside the loop.
tokens = PRETRAINING_DATA.to(device)
inputs, targets = tokens[:, :-1], tokens[:, 1:]

pbar = tqdm(range(train_params.num_epochs_X1))
for epoch in pbar:
    model.train()

    # randomize_noisy_labels(tokens)

    logits = model(inputs)
    # cross_entropy expects the class (vocab) axis in dim 1: [batch, vocab, pos].
    loss = F.cross_entropy(logits.permute(0, 2, 1), targets)
    loss.backward()

    optimizer.step()
    optimizer.zero_grad()
    losses.append(loss.item())

    # Report this step's loss; no `train_loss` is defined in this cell.
    pbar.set_description(f'train_Loss={losses[-1]:.3f}')

    

  0%|          | 0/1000 [00:00<?, ?it/s]

In [75]:
optimizer = torch.optim.AdamW(model.parameters(), lr=train_params.lr, betas=train_params.betas, weight_decay=train_params.wd)
# optimizer = torch.optim.SGD(model.parameters(), lr=train_params.lr, weight_decay=train_params.wd)
losses = []  # per-step training losses for the whole stage

pbar = tqdm(range(train_params.num_epochs_X1))
for epoch in pbar:
    model.train()
    epoch_losses = []
    for tokens in X1_loader:
        tokens = tokens.squeeze(1)
        tokens = tokens.to(device)

        # randomize_noisy_labels(tokens)

        logits = model(tokens)
        loss = loss_fn(logits, tokens)
        loss.backward()
        if train_params.max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), train_params.max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        epoch_losses.append(loss.item())
    losses.extend(epoch_losses)

    # Mean over this epoch only. np.mean(losses) would be a running mean over
    # every step of the stage and lag far behind the current loss.
    train_loss = float(np.mean(epoch_losses))
    model.eval()
    metrics = get_metrics()
    metrics['train_loss'] = train_loss
    pbar.set_description(f'train_Loss={train_loss:.3f}')

    if wandb.run is not None:
        wandb.log(metrics)

        if epoch % 100 == 0:
            cos_sim_Dt3 = np.mean(get_cos_sims(int_by_set['Dt3']))
            cos_sim_Df4 = np.mean(get_cos_sims(int_by_set['Df4']))
            # Gradient probing may leave values in .grad. Clear them so they are
            # not added to the first optimizer step of the next epoch.
            optimizer.zero_grad()

            wandb.log({
                'grad_cos_sim_Dt3': cos_sim_Dt3,
                'grad_cos_sim_Df4': cos_sim_Df4,
            })

    

  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# checkpoint_path = Path(checkpoint_path)
# new_checkpoint_path = checkpoint_path.parent / ('stage1__'+checkpoint_path.name)

# torch.save(model.state_dict(), new_checkpoint_path)
# print(f'saved to {new_checkpoint_path}')

In [None]:
# for name, p in model.named_parameters():
#     p.requires_grad_(False)
    
# for name, p in model.named_parameters():
#     if name == 'embed.W_E':
#         p.requires_grad_(True)
#     else:
#         p.requires_grad_(False)

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=train_params.lr, betas=train_params.betas, weight_decay=train_params.wd)
# optimizer = torch.optim.SGD(model.parameters(), lr=train_params.lr, weight_decay=train_params.wd)
losses = []  # per-step training losses for the whole stage

for epoch in tqdm(range(train_params.num_epochs_X2)):
    model.train()
    epoch_losses = []
    for tokens in X2_loader:
        tokens = tokens.squeeze(1)
        tokens = tokens.to(device)
        logits = model(tokens)
        loss = loss_fn(logits, tokens)
        loss.backward()

        # model.W_E.grad[:120] = 0

        if train_params.max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), train_params.max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        epoch_losses.append(loss.item())
    losses.extend(epoch_losses)

    # Mean over this epoch only (np.mean(losses) would be a running mean over the whole stage).
    train_loss = float(np.mean(epoch_losses))
    model.eval()
    metrics = get_metrics()
    metrics['train_loss'] = train_loss
    if wandb.run is not None:
        wandb.log(metrics)

        if epoch % 100 == 0:
            cos_sim_Dt3 = np.mean(get_cos_sims(int_by_set['Dt3']))
            cos_sim_Df4 = np.mean(get_cos_sims(int_by_set['Df4']))
            # Gradient probing may leave values in .grad. Clear them so they are
            # not added to the first optimizer step of the next epoch.
            optimizer.zero_grad()

            wandb.log({
                'grad_cos_sim_Dt3': cos_sim_Dt3,
                'grad_cos_sim_Df4': cos_sim_Df4,
            })


    

In [None]:
checkpoint_path = Path(checkpoint_path)
# Save next to `checkpoint_path` with a "stage2__" prefix. The name only
# reflects the selected config cell; whether weights were loaded from that file
# depends on the model-construction cell.
new_checkpoint_path = checkpoint_path.parent / f'stage2__{checkpoint_path.name}'

torch.save(model.state_dict(), new_checkpoint_path)
print(f'saved to {new_checkpoint_path}')

In [None]:
# Intentional stop for "Run All": the cells below are scratch/utility cells.
# A bare `break` outside a loop is a SyntaxError that only halts by accident;
# raise an explicit, self-describing error instead.
raise RuntimeError("Stopping here on purpose: the cells below are scratch/utility cells.")

In [None]:
# Parameter count of a candidate 1-layer, attention-only model with d_model=1024.
count_parameters(HookedTransformer(HookedTransformerConfig(
    d_vocab=512,
    n_layers=1,
    d_model=1024,
    d_head=128,
    n_heads=4,
    d_mlp=None,
    n_ctx=5,
    act_fn="relu",  # gelu?
    normalization_type=None,
    attn_only=True,
)))

In [None]:
def train_oocl_stage_1_2(args, seed):
    """Run one seeded OOCL experiment and log it to wandb.

    The function performs these steps in order:
    1. Seed the RNGs.
    2. Start a wandb run.
    3. Log and save ``int_by_set``.
    4. Build and save the OOCL datasets.
    5. Build the original-task mask.
    6. Call ``train_w_orig`` on the notebook's ``new_model``.

    Args:
        args: Namespace providing ``model_path``, ``model_name`` and
            ``wandb_group_name``. ``args.seed`` is overwritten with ``seed``.
        seed: RNG seed. ``None`` leaves the RNGs untouched; 0 is a valid seed.

    Reads the notebook globals ``train_params``, ``new_transformer_config``,
    ``new_model`` and ``int_by_set``.
    NOTE(review): ``train_w_orig`` is not defined anywhere in this notebook.
    """
    from datetime import datetime

    args.seed = seed

    # `if args.seed:` would silently skip seeding for seed=0.
    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)

    # wandb.login(key=os.getenv("WANDB_API_KEY"))

    dir_models = "models/transformers/"
    Path(dir_models).mkdir(exist_ok=True, parents=True)

    # model.load_state_dict(torch.load(os.path.join(dir_models, "interrupted.pt")))

    exp_name = f'seed={args.seed}'
    name = f"oocl__{args.model_name}"

    wandb.init(
        project="misha-iml",
        group=args.wandb_group_name,
        name=exp_name,
        config={
            **asdict(DataParams()),
            **asdict(train_params),
            **new_transformer_config,
        }
    )
    print(f'{args.seed=}')
    print('int_by_set')
    print(int_by_set)

    for k, ints in int_by_set.items():
        print(k)
        print(ints)
        wandb.log({f"{k}": ints})
        print("\n")
    ints_by_set = {f"{k}": ints for k, ints in int_by_set.items()}

    torch.save(ints_by_set, f"./models/{name}_ints_by_set.pt")

    # The dataset builder imported by this notebook is `create_datasets`
    # (`create_data` is not defined anywhere).
    train_sets, test_sets = create_datasets(int_by_set)

    data_name = f"data_oocl_{DataParams.mod}.pt"
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    data_name = os.path.join(args.model_path, timestamp + '__' + data_name)
    print(f'SAVING TO {data_name}')
    torch.save((train_sets, test_sets), data_name)

    orig_args = make_tbl_mask(mod=DataParams.mod, method='prod', frac_held_out=train_params.orig_held_out_frac)

    train_w_orig(new_model, train_sets, test_sets, orig_args, train_params, args)

    wandb.finish()
