# Base Models

In [1]:
# ! pip install torch --upgrade

In [2]:
# https://wandb.ai/kilianovski/misha-iml/runs/kn69jraw


checkpoint_path = '../models/transformers/pretrained_1L_dmodel=1024_attnonly=False20240718_200143.pt'
checkpoint_path = '../models/transformers/grokking_prod_120_1_0.1_attnonly_False20240712_133838.pt'


transformer_config = dict(
    d_vocab=512,
    n_layers=1,
    d_model=1024,
    d_head=128,
    n_heads=4,
    d_mlp=256,
    n_ctx=5,
    act_fn="relu",  # gelu?
    normalization_type='LN',
    attn_only=False,
)

In [3]:
checkpoint_path = './multiplication_model.pt'


transformer_config = dict(
    d_vocab=512,
    n_layers=1,
    d_model=1024,
    d_head=256,
    n_heads=4,
    d_mlp=4096,
    n_ctx=5,
    act_fn="relu",  # gelu?
    normalization_type=None,
    attn_only=False,
)

In [4]:
from argparse import Namespace
from dataclasses import dataclass, asdict
seed = 2




@dataclass
class TrainParams:
    n_steps: int = int(1e8)
    batch_size: int = 128
    lr: float = 1e-4
    wd: float = 1e-1
    betas: tuple = (0.9, 0.98)
    max_grad_norm: float = 1.0
    num_epochs_X1: int = 500
    num_epochs_X2: int = 500
    prop_orig: float = 0.25
    orig_held_out_frac: float = 0.01
    swap_defs: bool = False # whether to swap the order of the defs
    val_questions: int = 9
    



# Setup

In [5]:
import sys
sys.path.append('..')

In [6]:
from data import create_datasets, seed_all, DataParams, Tokens, OOCL_Dataset, make_tbl_mask, create_orig_data, yield_data

In [7]:
import logging
import torch
from dataclasses import dataclass, asdict
import numpy as np
import time
import os
from tqdm.auto import tqdm
from pathlib import Path
import itertools
import sys
import random
import torch.nn.functional as F
from torch.utils.data import random_split, TensorDataset, DataLoader, Dataset
import argparse
from transformer_lens import HookedTransformer, HookedTransformerConfig
import wandb
from dotenv import load_dotenv
from sympy import factorint
from itertools import product
from math import prod

import os
import random
import numpy as np
import torch
from tqdm.auto import tqdm

In [8]:
def get_device():
    #return 'cpu'
    if torch.cuda.is_available():
        return "cuda"
    # elif torch.backends.mps.is_available():
    #     return "mps"
    else:
        return "cpu"

In [9]:
def loss_fn(logits, tokens):

    # check whether question or def and compute loss appropriately
    # logit shape: [batch, pos, vocab]
    # token shape: [batch, pos]

    mask = (tokens[:, 3] == 2*DataParams.mod + Tokens.padding)

    def_logits = logits[mask]
    def_tokens = tokens[mask].long()

    q_logits = logits[~mask]
    q_tokens = tokens[~mask].long()

    def_logits = def_logits[:, 1].unsqueeze(1)
    def_tokens = def_tokens[:, 2].unsqueeze(1)
    def_log_probs = def_logits.log_softmax(-1)
    def_correct_log_probs = def_log_probs.gather(-1, def_tokens[..., None])[..., 0]
    
    q_logits = q_logits[:, 2].unsqueeze(1)
    q_tokens = q_tokens[:, 3].unsqueeze(1)
    q_log_probs = q_logits.log_softmax(-1)
    q_correct_log_probs = q_log_probs.gather(-1, q_tokens[..., None])[..., 0]

    return -(def_correct_log_probs.sum() + q_correct_log_probs.sum())/(def_correct_log_probs.shape[0] + q_correct_log_probs.shape[0])


In [10]:
def evaluate(model, val_loader, device):
    correct = 0
    loss = 0.
    total = 0
    batches = 0

    for batch in val_loader:
        inputs = batch[0].to(device)

        labels = inputs[:, -1]

        with torch.no_grad():
            output = model(inputs)
            loss += loss_fn(output, inputs).item()
            correct += (torch.argmax(output[:, -2, :], dim=1) == labels).sum()

        total += inputs.shape[0]
        batches += 1

    acc = correct / total
    loss = loss/batches
    return acc, loss

# Train

In [11]:
# int_by_set = {
#     'DtQ1': [57, 103, 99, 3, 111, 59, 73, 22, 30, 25, 47, 69, 23, 67, 75, 16, 85, 29, 2, 76, 8, 107, 43, 84, 98, 44, 46, 115, 80, 37], 
#     'DfQ2': [50, 72, 118, 20, 93, 10, 52, 14, 83, 6, 28, 15, 34, 48, 114, 104, 88, 13, 91, 54, 112, 58, 102, 95, 21, 24, 19, 94, 35, 109], 
#     'Dt3': [4, 81, 82, 41, 31, 86, 63, 0, 110, 11, 1, 92, 7, 116, 66, 56, 119, 70, 26, 78, 40, 55, 105, 89, 71, 60, 42, 87, 9, 117], 
#     'Df4': [39, 18, 77, 90, 68, 32, 79, 12, 96, 101, 36, 17, 64, 27, 74, 45, 61, 38, 106, 100, 51, 62, 65, 33, 5, 53, 113, 97, 49, 108]
# }


# int_by_set = {
#     'DtQ1': [119, 81, 20, 90, 68, 41, 4, 79, 38, 10, 14, 95, 22, 78, 114, 71, 73, 52, 94, 9, 82, 116, 96, 93, 39, 36, 105, 50, 16, 33], 
#     'DfQ2': [5, 30, 19, 59, 74, 24, 104, 21, 18, 51, 42, 61, 65, 84, 64, 35, 113, 11, 66, 80, 112, 7, 31, 98, 43, 6, 25, 45, 117, 47], 
#     'Dt3': [99, 46, 88, 23, 103, 53, 86, 37, 58, 76, 118, 44, 91, 70, 111, 56, 28, 67, 85, 54, 27, 106, 1, 69, 107, 87, 2, 101, 40, 13], 
#     'Df4': [75, 29, 92, 34, 109, 89, 0, 110, 77, 55, 49, 3, 62, 12, 26, 100, 48, 83, 60, 57, 115, 63, 15, 32, 8, 97, 102, 108, 72, 17]
# }

In [12]:
try:
    from prettytable import PrettyTable
except:
    ! pip install -q prettytable
    from prettytable import PrettyTable

def count_parameters(model):
    from prettytable import PrettyTable

    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params+=params
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

In [13]:
# noisy_labels = {119+n for n in int_by_set['DfQ2']}

# def randomize_noisy_labels(tokens):
#     noisy_label_mask = torch.tensor(list(map(lambda x: any(int(l) in noisy_labels for l in x) , tokens[:, :2])))
#     noisy_label_mask = noisy_label_mask & (tokens[:, 0] != 242)
#     tokens[noisy_label_mask][:, -1] = torch.randint(0, 120*2, size=(len(tokens[noisy_label_mask]),))

In [14]:
def get_weight_norms(model):
    return {f'wnorm/{n}':p.detach().norm().item() for n,p in model.named_parameters()}

In [17]:
def train_iml(args, checkpoint_path=checkpoint_path):
    mod = DataParams.mod


    def get_grad_cos_sims(numbers):
        cos_sims = []
        for n in numbers:
            
            definition = definitions[definitions[:, 1] == (119+n)]
        
            question_mask = (questions_X2[:, 0] == (119+n)) | (questions_X2[:, 1] == (119+n))
            questions = questions_X2[question_mask]
        
            d_grads = get_flat_grad(model, definition)
            q_grads = get_flat_grad(model, questions)
        
            cos_sim = F.cosine_similarity(d_grads, q_grads, dim=0)
        
            cos_sims.append(cos_sim)
        return cos_sims

    # ******* INIT MODEL ***********
    seed = args.seed
    seed_all(args.seed)
    new_transformer_config = transformer_config
    new_transformer_config.update(dict(
        d_vocab=2*mod + 4,  # 3 special tokens + mod vars
    ))
    new_cfg = HookedTransformerConfig(**new_transformer_config)
    new_model = HookedTransformer(new_cfg)
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    new_model.load_state_dict(state_dict)
    new_model.to(get_device())
    
    model = new_model

    # ********* END MODEL

    # ********* DATA
    seed_all(seed)
    
    
    # divide the integers into 4 equally sized sets
    size = mod // 4
    rem = mod % 4
    
    numbers = list(range(DataParams.mod))
    random.shuffle(numbers)
    
    train_params = TrainParams()
        
    int_by_set = {}
    int_by_set['DtQ1'] = numbers[0:size]
    int_by_set['DfQ2'] = numbers[size:2*size]
    int_by_set['Dt3'] = numbers[2*size:3*size]
    int_by_set['Df4'] = numbers[3*size:mod]
    
    train_sets, test_sets = create_datasets(int_by_set)
    orig_args = make_tbl_mask(mod=DataParams.mod, method='prod', frac_held_out=train_params.orig_held_out_frac)
    
    
    batch_size = train_params.batch_size
    
    # unpack orig_args for use in valid_loader
    
    x_vv, y_vv, z_vv, train_vv, valid_vv = orig_args
    
    device = get_device()
    
    X1_dataset = OOCL_Dataset(train_sets['X1'], create_orig_data, orig_args, train_params.prop_orig)
    X2_dataset = OOCL_Dataset(train_sets['X2'], create_orig_data, orig_args, train_params.prop_orig)
    
    X1_loader = DataLoader(X1_dataset, batch_size=batch_size, shuffle=True)
    X2_loader = DataLoader(X2_dataset, batch_size=batch_size, shuffle=True)
    
    orig_data_valid_loader = yield_data(train_params.batch_size, x_vv, y_vv, z_vv, valid_vv)
    
    test_set_loaders = {}
    
    for s in test_sets:
        test_set_loaders[s] = DataLoader(TensorDataset(test_sets[s].to(dtype=torch.int)), batch_size=train_params.batch_size, shuffle=False)
    
    
    questions_X2 = []
    
    for batch in test_set_loaders['Dt3']:
        questions_X2.append(batch[0])
    
    
    for batch in test_set_loaders['Df4']:
        questions_X2.append(batch[0])
    
    
    questions_X2 = torch.cat(questions_X2, dim=0)
    
    
    definitions = []
    
    for tokens in X2_loader:
        tokens = tokens.squeeze(1)
        tokens = tokens.to(device)
        definitions.append(tokens)
    
    definitions = torch.cat(definitions)
    
    
    def get_flat_grad(model, tokens):
        for p in model.parameters():
            p.requires_grad_(True)
            p.grad = None
        
        logits = model(tokens)
        loss = loss_fn(logits, tokens)
        loss.backward()
    
        grads = []
        for p in model.parameters():
            grads.append(p.grad.detach().flatten())
        grads = torch.cat(grads)
        return grads


    def get_grad_norms(model, tokens):
        for p in model.parameters():
            p.requires_grad_(True)
            p.grad = None
        
        logits = model(tokens)
        loss = loss_fn(logits, tokens)
        loss.backward()
    
        grads = {}
        for n,p in model.named_parameters():
            grads[f'gradnorm/{n}'] = p.grad.detach().norm().item()
        return grads
    
    
    
    def get_metrics(model):
        val_acc_DtQ1, val_loss_DtQ1 = evaluate(model, test_set_loaders['DtQ1'], device)
        val_acc_DfQ2, val_loss_DfQ2 = evaluate(model, test_set_loaders['DfQ2'], device)
        val_acc_Dt3, val_loss_Dt3 = evaluate(model, test_set_loaders['Dt3'], device)
        val_acc_Df4, val_loss_Df4 = evaluate(model, test_set_loaders['Df4'], device)
        with torch.no_grad():
            # logging.info(tokens)
            tokens = next(orig_data_valid_loader)
            tokens = tokens.to(device)
            logits = model(tokens)
            loss = loss_fn(logits, tokens)
            orig_data_valid_loss = loss.item()
        metrics = {
    
                        # "train/loss": train_loss,
                        "valid_DtQ1/acc": val_acc_DtQ1,
                        "valid_DfQ2/acc": val_acc_DfQ2,
                        "valid_DtQ1/loss": val_loss_DtQ1,
                        "valid_DfQ2/loss": val_loss_DfQ2,
        
                        "valid_Dt3/acc": val_acc_Dt3,
                        "valid_Df4/acc": val_acc_Df4,
                        "valid_Dt3/loss": val_loss_Dt3,
                        "valid_Df4/loss": val_loss_Df4,
        
                        "val/loss": (val_loss_DtQ1+val_loss_DfQ2)/2,
                        "orig_data_valid_loss": orig_data_valid_loss,
                        **get_weight_norms(model),
                    }
        return metrics
    
    
    for tokens in X1_loader:
        tokens = tokens.squeeze(1)
        tokens = tokens.to(device)
        break

    # ********* END DATA
    
    wandb.init(
        project="misha-iml",
        group=args.wandb_group_name,
        name=args.wandb_experiment_name,
        config={
            **asdict(DataParams()),
            **asdict(train_params),
            **transformer_config,
        }
    )
    print(f'seed={args.seed}')
    print('int_by_set')
    print(int_by_set)
    print('loaded from', checkpoint_path)
    count_parameters(model)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=train_params.lr, betas=train_params.betas, weight_decay=train_params.wd)
    # optimizer = torch.optim.SGD(model.parameters(), lr=train_params.lr, weight_decay=train_params.wd)
    losses = []
    
    pbar = tqdm(range(train_params.num_epochs_X1))
    step = 0
    for epoch in pbar:
        model.train()
        for tokens in X1_loader:
            tokens = tokens.squeeze(1)
            tokens = tokens.to(device)
            
            # randomize_noisy_labels(tokens)
            
            logits = model(tokens)
            loss = loss_fn(logits, tokens)
            loss.backward()
            if train_params.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), train_params.max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())
    
    
        train_loss = np.mean(losses)
        model.eval()
        metrics = get_metrics(model)
        metrics['train_loss'] = train_loss
        metrics['step'] = step
    
        pbar.set_description(f'train_Loss={train_loss.item():.3f}')
    
        if wandb.run is not None:
            wandb.log(metrics)
    
            if (epoch % args.log_grad_alignment_freq == 0) or (epoch == train_params.num_epochs_X1 - 1):
                cos_sim_Dt3 = np.mean(get_grad_cos_sims(int_by_set['Dt3']))
                cos_sim_Df4 = np.mean(get_grad_cos_sims(int_by_set['Df4']))
                
                wandb.log({
                    'step': step,
                    'grad_cos_sim_Dt3': cos_sim_Dt3,
                    'grad_cos_sim_Df4': cos_sim_Df4,
                })
    
        step += 1
    
    checkpoint_path = Path(checkpoint_path)
    new_checkpoint_path = checkpoint_path.parent / ('stage1__'+checkpoint_path.name)
    
    torch.save(model.state_dict(), new_checkpoint_path)
    print(f'saved to {new_checkpoint_path}')
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=train_params.lr, betas=train_params.betas, weight_decay=train_params.wd)
    # optimizer = torch.optim.SGD(model.parameters(), lr=train_params.lr, weight_decay=train_params.wd)
    losses = []
    
    for epoch in tqdm(range(train_params.num_epochs_X2)):
        model.train()
        for tokens in X2_loader:
            tokens = tokens.squeeze(1)
            tokens = tokens.to(device)
            logits = model(tokens)
            loss = loss_fn(logits, tokens)
            loss.backward()
            
            # model.W_E.grad[:120] = 0
    
            if train_params.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), train_params.max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())
    
    
        train_loss = np.mean(losses)
        model.eval()
        metrics = get_metrics(model)
        metrics['train_loss'] = train_loss
        metrics['step'] = step

        
        if wandb.run is not None:
            wandb.log(metrics)
    
    
            if (epoch % args.log_grad_alignment_freq == 0) or (epoch == train_params.num_epochs_X2 - 1):
                cos_sim_Dt3 = np.mean(get_grad_cos_sims(int_by_set['Dt3']))
                cos_sim_Df4 = np.mean(get_grad_cos_sims(int_by_set['Df4']))
                
                wandb.log({
                    'step': step,
                    'grad_cos_sim_Dt3': cos_sim_Dt3,
                    'grad_cos_sim_Df4': cos_sim_Df4,
                })
    
        step += 1
    
    checkpoint_path = Path(checkpoint_path)
    new_checkpoint_path = checkpoint_path.parent / ('stage2__'+checkpoint_path.name)
    
    torch.save(model.state_dict(), new_checkpoint_path)
    print(f'saved to {new_checkpoint_path}')

In [18]:
for seed in range(1):
    args = Namespace(
        model_path='./models/transformers/', 
        # model_name='grokking_prod_120_2_0.1_attnonly_True20240709_180213.pt', 
        wandb_group_name='Luans_model',
        wandb_experiment_name=f'seed={seed}',
        saved_model_name=None,
        log_grad_alignment_freq=50,
        seed=seed, 
        save_steps=[500, 950])


    train_iml(args)

Moving model to device:  cpu


  result_tensor = torch.tensor(Z).view(N, 1)
[34m[1mwandb[0m: Currently logged in as: [33mkilianovski[0m. Use [1m`wandb login --relogin`[0m to force relogin


seed=0
int_by_set
{'DtQ1': [57, 103, 99, 3, 111, 59, 73, 22, 30, 25, 47, 69, 23, 67, 75, 16, 85, 29, 2, 76, 8, 107, 43, 84, 98, 44, 46, 115, 80, 37], 'DfQ2': [50, 72, 118, 20, 93, 10, 52, 14, 83, 6, 28, 15, 34, 48, 114, 104, 88, 13, 91, 54, 112, 58, 102, 95, 21, 24, 19, 94, 35, 109], 'Dt3': [4, 81, 82, 41, 31, 86, 63, 0, 110, 11, 1, 92, 7, 116, 66, 56, 119, 70, 26, 78, 40, 55, 105, 89, 71, 60, 42, 87, 9, 117], 'Df4': [39, 18, 77, 90, 68, 32, 79, 12, 96, 101, 36, 17, 64, 27, 74, 45, 61, 38, 106, 100, 51, 62, 65, 33, 5, 53, 113, 97, 49, 108]}
loaded from ./multiplication_model.pt
+--------------------+------------+
|      Modules       | Parameters |
+--------------------+------------+
|     embed.W_E      |   249856   |
|  pos_embed.W_pos   |    5120    |
| blocks.0.attn.W_Q  |  1048576   |
| blocks.0.attn.W_O  |  1048576   |
| blocks.0.attn.b_Q  |    1024    |
| blocks.0.attn.b_O  |    1024    |
| blocks.0.attn.W_K  |  1048576   |
| blocks.0.attn.W_V  |  1048576   |
| blocks.0.attn.b_K

  0%|          | 0/500 [00:00<?, ?it/s]

saved to stage1__multiplication_model.pt


  0%|          | 0/500 [00:00<?, ?it/s]

saved to stage2__multiplication_model.pt


In [19]:
for seed in range(2, 10):
    args = Namespace(
        model_path='./models/transformers/', 
        # model_name='grokking_prod_120_2_0.1_attnonly_True20240709_180213.pt', 
        wandb_group_name='Luans_model',
        wandb_experiment_name=f'seed={seed}',
        saved_model_name=None,
        log_grad_alignment_freq=50,
        seed=seed, 
        save_steps=[500, 950])


    train_iml(args)

Moving model to device:  cpu


  result_tensor = torch.tensor(Z).view(N, 1)


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
grad_cos_sim_Df4,█▄▃▄▄▃▃▃▄▄▄▄▃▂▂▂▂▁▁▂▁▁
grad_cos_sim_Dt3,▅█▇▇▇▇▆▇▇▇██▂▂▂▁▁▁▁▂▁▁
orig_data_valid_loss,▂▅▂▁▁▂▁▁▁▁▁▅▁▁▁▁▁▁▁▁▁▁▅▁▁█▁▂▁▁▁▁▁▁▁▁▁▁▁▁
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,▄▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
val/loss,▃▁▂▃▃▄▄▄▄▄▄▄▄▄▄▄▅▅▅▅▄▅▇█▇▇▆▆▆▆▆▆▇▇▇▇▇▇▇▇
valid_Df4/acc,▁▁▂▃▂▂▂▂▃▂▄▃▃▃▅▅▅▅▆▇▃▂▂▃▄▄▆███▇▇▇▇▇▇▇▇▇▇
valid_Df4/loss,▁▁▂▃▃▃▃▃▃▃▃▃▃▃▄▃▄▄▅▅▄▅▆▇▇▇▇▇▇▇▇▇▇███████
valid_DfQ2/acc,▁▇▇▆▆▆▆▇▇▇▇▇██████▇▇▇▇▆▆▆▇▇▇▇▆▇▇▆▆▆▆▆▆▆▇
valid_DfQ2/loss,▂▁▂▃▃▄▄▄▄▄▄▄▄▄▄▄▄▅▅▅▄▅▇██▇▆▇▇▇▇▇▇███████

0,1
grad_cos_sim_Df4,-0.25575
grad_cos_sim_Dt3,-0.08119
orig_data_valid_loss,7e-05
step,999.0
train_loss,0.33828
val/loss,4.76194
valid_Df4/acc,0.10556
valid_Df4/loss,26.06919
valid_DfQ2/acc,0.57778
valid_DfQ2/loss,6.81214


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011155385645623836, max=1.0…

seed=2
int_by_set
{'DtQ1': [117, 98, 8, 2, 38, 89, 12, 62, 76, 0, 5, 1, 60, 90, 109, 116, 51, 6, 91, 13, 18, 37, 118, 24, 105, 52, 16, 26, 68, 66], 'DfQ2': [49, 36, 97, 107, 100, 9, 58, 63, 115, 42, 31, 111, 15, 78, 72, 86, 102, 14, 96, 33, 35, 84, 73, 99, 44, 19, 61, 104, 43, 25], 'Dt3': [28, 83, 82, 45, 93, 53, 57, 23, 70, 101, 80, 95, 17, 75, 41, 79, 88, 29, 30, 22, 71, 112, 67, 54, 48, 40, 59, 114, 3, 119], 'Df4': [34, 64, 56, 69, 47, 65, 92, 50, 81, 55, 20, 87, 74, 4, 113, 27, 77, 32, 39, 85, 103, 94, 21, 106, 46, 10, 11, 7, 108, 110]}
loaded from ./multiplication_model.pt
+--------------------+------------+
|      Modules       | Parameters |
+--------------------+------------+
|     embed.W_E      |   249856   |
|  pos_embed.W_pos   |    5120    |
| blocks.0.attn.W_Q  |  1048576   |
| blocks.0.attn.W_O  |  1048576   |
| blocks.0.attn.b_Q  |    1024    |
| blocks.0.attn.b_O  |    1024    |
| blocks.0.attn.W_K  |  1048576   |
| blocks.0.attn.W_V  |  1048576   |
| blocks.0.attn.b_K

  0%|          | 0/500 [00:00<?, ?it/s]

saved to stage1__multiplication_model.pt


  0%|          | 0/500 [00:00<?, ?it/s]

saved to stage2__multiplication_model.pt
Moving model to device:  cpu


  result_tensor = torch.tensor(Z).view(N, 1)


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
grad_cos_sim_Df4,█▇▇▇▆▆▆▆▆▆▆▆▃▂▂▂▁▁▂▂▁▂
grad_cos_sim_Dt3,█▇▅▅▅▅▅▆▅▅▇▆▄▂▄▂▁▂▂▂▁▁
orig_data_valid_loss,▂▄▃▁▁▃▃▁▁▃▁▁▄▆▃▁▃▂█▁▁▁▆▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,▄▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▅▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
val/loss,▃▁▂▃▃▃▃▄▄▄▄▄▄▄▄▄▄▄▅▅▅▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██
valid_Df4/acc,▁▁▂▁▂▂▃▂▃▃▃▃▃▃▃▃▂▂▃▃▃▃▂▆▇▆██▇▇▇▇▇▇▆▅▆▆▆▆
valid_Df4/loss,▁▁▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▄▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇████
valid_DfQ2/acc,▁██▇▆▇▇▇▇▇▇▇██▇███▇▇█▇▆▆▇▆▇▇▆▆▆▆▆▆▆▆▅▆▆▆
valid_DfQ2/loss,▂▁▂▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▅▆▇█▇▇▇▇▇█████▇▇▇▇▇██

0,1
grad_cos_sim_Df4,-0.13524
grad_cos_sim_Dt3,-0.04622
orig_data_valid_loss,1e-05
step,999.0
train_loss,0.36556
val/loss,6.77696
valid_Df4/acc,0.1
valid_Df4/loss,22.48138
valid_DfQ2/acc,0.52593
valid_DfQ2/loss,9.58669


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011168222688138486, max=1.0…

seed=3
int_by_set
{'DtQ1': [39, 83, 71, 93, 90, 51, 52, 116, 23, 35, 79, 28, 59, 84, 15, 86, 45, 98, 101, 109, 92, 57, 78, 9, 0, 65, 115, 31, 25, 11], 'DfQ2': [99, 41, 68, 67, 18, 2, 40, 112, 102, 10, 43, 37, 21, 119, 88, 7, 64, 48, 58, 14, 6, 36, 44, 42, 13, 61, 87, 62, 111, 22], 'Dt3': [53, 32, 26, 82, 55, 105, 27, 63, 72, 4, 12, 46, 17, 56, 73, 96, 54, 89, 76, 97, 34, 3, 38, 5, 118, 20, 110, 85, 108, 49], 'Df4': [66, 94, 95, 103, 19, 81, 50, 100, 104, 117, 106, 91, 24, 29, 70, 33, 113, 107, 1, 114, 8, 74, 80, 60, 77, 47, 16, 69, 75, 30]}
loaded from ./multiplication_model.pt
+--------------------+------------+
|      Modules       | Parameters |
+--------------------+------------+
|     embed.W_E      |   249856   |
|  pos_embed.W_pos   |    5120    |
| blocks.0.attn.W_Q  |  1048576   |
| blocks.0.attn.W_O  |  1048576   |
| blocks.0.attn.b_Q  |    1024    |
| blocks.0.attn.b_O  |    1024    |
| blocks.0.attn.W_K  |  1048576   |
| blocks.0.attn.W_V  |  1048576   |
| blocks.0.attn.b_K

  0%|          | 0/500 [00:00<?, ?it/s]

saved to stage1__multiplication_model.pt


  0%|          | 0/500 [00:00<?, ?it/s]

saved to stage2__multiplication_model.pt
Moving model to device:  cpu


  result_tensor = torch.tensor(Z).view(N, 1)


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
grad_cos_sim_Df4,█▇▅▅▄▅▂▂▁▃▁▂▄▄▃▂▄▄▅▃▃▄
grad_cos_sim_Dt3,██▆▆▄▄▅▄▄▄▄▅▂▃▂▁▂▂▂▁▁▁
orig_data_valid_loss,▁▃▂▁▁▃▂▁▁▃▁▁▃█▄▁▄▃█▁▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▅▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
val/loss,▆▁▂▃▄▄▄▄▅▄▄▄▅▅▄▄▅▄▅▅▅▅▇█▇▇██▇▇▇▇▇▇▇▇▇▇▇▇
valid_Df4/acc,▁▂▂▁▄▁▃▂▃▂▄▃▃▃▃▃▃▃▃▃▂▂▄▇▆▆▇▇▇▇▇▇███▇▇▇▇▇
valid_Df4/loss,▁▁▂▃▃▄▄▄▅▄▅▅▆▆▆▆▇▆▇▇▇▇▇████████████▇▇▇▇▇
valid_DfQ2/acc,▁█▇▆▆▆▇▇▇▇▇▇▇▆▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▅▅▅▅▆▆▆▇▇▇▇
valid_DfQ2/loss,▄▁▂▃▄▄▄▄▄▄▄▅▅▅▄▅▅▅▅▅▅▅▇█▇▇█▇▇▇▇▇▇▆▆▆▆▆▆▆

0,1
grad_cos_sim_Df4,0.07379
grad_cos_sim_Dt3,-0.04765
orig_data_valid_loss,3e-05
step,999.0
train_loss,0.35743
val/loss,3.53906
valid_Df4/acc,0.13611
valid_Df4/loss,16.34839
valid_DfQ2/acc,0.62963
valid_DfQ2/loss,3.96902


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011132342134240187, max=1.0…

seed=4
int_by_set
{'DtQ1': [48, 25, 45, 83, 32, 6, 53, 73, 58, 105, 85, 72, 86, 23, 15, 71, 76, 93, 101, 114, 100, 57, 59, 18, 67, 111, 104, 75, 1, 56], 'DfQ2': [20, 26, 91, 41, 78, 94, 115, 109, 113, 89, 17, 110, 40, 69, 116, 95, 63, 74, 4, 29, 42, 16, 44, 119, 62, 81, 9, 5, 12, 106], 'Dt3': [14, 10, 103, 55, 36, 54, 52, 90, 65, 88, 87, 0, 118, 108, 84, 99, 60, 79, 98, 31, 64, 49, 43, 77, 112, 47, 80, 107, 39, 21], 'Df4': [24, 34, 96, 82, 3, 27, 33, 117, 22, 35, 46, 68, 66, 28, 7, 97, 102, 37, 70, 51, 2, 8, 11, 19, 61, 50, 92, 13, 38, 30]}
loaded from ./multiplication_model.pt
+--------------------+------------+
|      Modules       | Parameters |
+--------------------+------------+
|     embed.W_E      |   249856   |
|  pos_embed.W_pos   |    5120    |
| blocks.0.attn.W_Q  |  1048576   |
| blocks.0.attn.W_O  |  1048576   |
| blocks.0.attn.b_Q  |    1024    |
| blocks.0.attn.b_O  |    1024    |
| blocks.0.attn.W_K  |  1048576   |
| blocks.0.attn.W_V  |  1048576   |
| blocks.0.attn.b_K

  0%|          | 0/500 [00:00<?, ?it/s]

saved to stage1__multiplication_model.pt


  0%|          | 0/500 [00:00<?, ?it/s]

saved to stage2__multiplication_model.pt
Moving model to device:  cpu


  result_tensor = torch.tensor(Z).view(N, 1)


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
grad_cos_sim_Df4,█▅▅▄▅▅▅▅▅▅▅▅▃▃▃▂▂▁▁▂▂▁
grad_cos_sim_Dt3,▇█▇▆▆▇▆▆▅▆▅▅▄▃▃▂▂▁▁▁▁▁
orig_data_valid_loss,▂▃█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▃▂▁▁▁▁▂▂▁▁▁▂▁▁▁▁▁▁▁
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
val/loss,▃▁▂▃▄▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▇██████▇▇▇▇▇▇▇▇▇▇▇▇
valid_Df4/acc,▁▂▃▃▄▃▂▃▅▅▅▄▅▄▄▂▄▄▅▅▅▄▅▆▆▅▆▆▅▅▅▆▆▆▆▆▇███
valid_Df4/loss,▁▁▂▃▃▄▄▄▄▄▅▅▅▅▇▆▆▇▇▇▆▇██████████████████
valid_DfQ2/acc,▁▆▆▇▇▇▇▇▇▇███▇███████▇▇▇▇▇▇██▇██████████
valid_DfQ2/loss,▄▁▂▃▃▃▄▄▄▄▄▄▅▅▆▄▆▆▆▆▅▆████▇▇▇▇█▇▇▇▇▇▇▇▇▇

0,1
grad_cos_sim_Df4,-0.15182
grad_cos_sim_Dt3,-0.02103
orig_data_valid_loss,0.0
step,999.0
train_loss,0.34899
val/loss,5.55762
valid_Df4/acc,0.13889
valid_Df4/loss,20.13891
valid_DfQ2/acc,0.68889
valid_DfQ2/loss,4.47295


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011147985655245267, max=1.0…

seed=5
int_by_set
{'DtQ1': [61, 58, 106, 29, 62, 63, 24, 28, 46, 34, 100, 75, 15, 8, 71, 118, 7, 89, 103, 42, 5, 33, 115, 51, 18, 12, 36, 95, 93, 84], 'DfQ2': [92, 86, 43, 116, 109, 50, 65, 4, 108, 54, 77, 112, 72, 97, 53, 57, 90, 39, 30, 64, 11, 85, 81, 55, 41, 66, 22, 19, 44, 82], 'Dt3': [10, 68, 110, 2, 38, 87, 70, 114, 76, 96, 25, 40, 37, 74, 21, 91, 26, 78, 0, 80, 16, 56, 104, 119, 17, 9, 102, 49, 23, 35], 'Df4': [52, 27, 1, 98, 73, 13, 69, 48, 105, 60, 47, 14, 20, 6, 111, 31, 99, 59, 113, 3, 67, 83, 117, 107, 88, 101, 45, 94, 32, 79]}
loaded from ./multiplication_model.pt
+--------------------+------------+
|      Modules       | Parameters |
+--------------------+------------+
|     embed.W_E      |   249856   |
|  pos_embed.W_pos   |    5120    |
| blocks.0.attn.W_Q  |  1048576   |
| blocks.0.attn.W_O  |  1048576   |
| blocks.0.attn.b_Q  |    1024    |
| blocks.0.attn.b_O  |    1024    |
| blocks.0.attn.W_K  |  1048576   |
| blocks.0.attn.W_V  |  1048576   |
| blocks.0.attn.b_K

  0%|          | 0/500 [00:00<?, ?it/s]

saved to stage1__multiplication_model.pt


  0%|          | 0/500 [00:00<?, ?it/s]

saved to stage2__multiplication_model.pt
Moving model to device:  cpu


  # print(f"Number of questions: {question_tensor.size(0)}")


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
grad_cos_sim_Df4,█▅▄▄▄▄▅▄▄▅▅▄▂▂▃▂▂▂▂▁▁▁
grad_cos_sim_Dt3,█▇▄▄▄▄▃▄▃▃▃▄▂▃▃▂▃▃▂▁▃▁
orig_data_valid_loss,▃█▄▄▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
val/loss,▂▁▂▃▃▃▃▃▄▄▄▄▄▄▄▄▄▄▄▄▅▆▇████▇▇▇▇▆▆▆▆▆▆▆▆▆
valid_Df4/acc,▃▁▁▂▄▄▅▃▄▄▅▄▅▅▅▅▅▅▅▅▅▅▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇███
valid_Df4/loss,▁▁▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▄▄▄▅▆▇████▇▇█▇▇▇▇▇▇████
valid_DfQ2/acc,▁▇▇▇▆▇▆▇▇▇▇▇▇▇██▇▇██▇▆▅▅▅▅▆▆▇▆▇▇▇▇▇▇▆▆▆▆
valid_DfQ2/loss,▂▁▂▃▄▄▄▄▅▄▅▅▅▅▅▅▅▅▅▅▅▅▇████▇▇▆▆▆▆▆▆▆▇▇▇▇

0,1
grad_cos_sim_Df4,-0.1411
grad_cos_sim_Dt3,0.01393
orig_data_valid_loss,1e-05
step,999.0
train_loss,0.34523
val/loss,6.16933
valid_Df4/acc,0.13056
valid_Df4/loss,22.53178
valid_DfQ2/acc,0.58889
valid_DfQ2/loss,6.6728


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011167426854889426, max=1.0…

seed=6
int_by_set
{'DtQ1': [36, 28, 44, 82, 41, 26, 57, 27, 90, 61, 50, 67, 6, 20, 49, 110, 119, 80, 63, 66, 29, 39, 103, 58, 35, 15, 100, 112, 9, 107], 'DfQ2': [48, 13, 89, 95, 1, 16, 104, 86, 7, 30, 83, 43, 38, 23, 17, 21, 74, 116, 51, 59, 106, 118, 117, 53, 111, 45, 14, 88, 8, 79], 'Dt3': [108, 22, 19, 55, 31, 71, 3, 64, 65, 91, 99, 81, 76, 5, 77, 37, 102, 92, 56, 32, 96, 46, 85, 42, 54, 11, 78, 109, 113, 70], 'Df4': [72, 24, 12, 87, 69, 68, 52, 93, 25, 115, 34, 2, 98, 40, 47, 94, 114, 60, 75, 84, 18, 0, 4, 33, 97, 62, 10, 105, 73, 101]}
loaded from ./multiplication_model.pt
+--------------------+------------+
|      Modules       | Parameters |
+--------------------+------------+
|     embed.W_E      |   249856   |
|  pos_embed.W_pos   |    5120    |
| blocks.0.attn.W_Q  |  1048576   |
| blocks.0.attn.W_O  |  1048576   |
| blocks.0.attn.b_Q  |    1024    |
| blocks.0.attn.b_O  |    1024    |
| blocks.0.attn.W_K  |  1048576   |
| blocks.0.attn.W_V  |  1048576   |
| blocks.0.attn.b_K

  0%|          | 0/500 [00:00<?, ?it/s]

saved to stage1__multiplication_model.pt


  0%|          | 0/500 [00:00<?, ?it/s]

saved to stage2__multiplication_model.pt
Moving model to device:  cpu


  cur_question_tensor = torch.cat(


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
grad_cos_sim_Df4,█▂▃▃▅▆▄▅▅▅▆▆▂▂▁▁▂▂▂▁▁▁
grad_cos_sim_Dt3,▇▅▆▆▅▄▅▆▆▇██▄▂▂▂▂▂▂▃▁▁
orig_data_valid_loss,▃▃█▁▄▃▇▂▂▁▂▁▁▂▁▁▁▁▁▁▁▄▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
val/loss,▃▁▂▄▄▄▄▄▅▅▆▆▅▅▆▆▆▆▆▆▇▇▇█▇▇▇█████████████
valid_Df4/acc,▄▅▅▅▄▅▆▆██▅▇▆█▇▇▂▄▁▂▃▄▂▅▅▅██▅▅▅▅▅▆▇▇▆▇▇▇
valid_Df4/loss,▁▁▂▃▃▃▄▄▄▄▅▄▄▅▄▄▅▅▅▆▅▆▇██████▇██████████
valid_DfQ2/acc,▁█▆▆▆▆▆▇▇▇▇▇▇███▇████▇▇▇▇█▇▇▇▇▇█████████
valid_DfQ2/loss,▂▁▂▄▄▄▄▄▅▅▅▆▅▅▅▅▆▆▆▆▆▆▇██▇████████▇▇▇███

0,1
grad_cos_sim_Df4,-0.1355
grad_cos_sim_Dt3,-0.07517
orig_data_valid_loss,1e-05
step,999.0
train_loss,0.41107
val/loss,6.35229
valid_Df4/acc,0.08333
valid_Df4/loss,23.33585
valid_DfQ2/acc,0.63333
valid_DfQ2/loss,7.1005


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011167950000769148, max=1.0…

seed=7
int_by_set
{'DtQ1': [67, 112, 34, 85, 3, 51, 61, 75, 119, 0, 1, 57, 58, 87, 40, 16, 43, 14, 25, 107, 77, 79, 22, 106, 48, 42, 81, 98, 110, 35], 'DfQ2': [45, 113, 36, 56, 52, 100, 65, 118, 91, 21, 103, 76, 62, 59, 116, 114, 99, 10, 38, 44, 66, 60, 31, 33, 32, 2, 84, 86, 78, 95], 'Dt3': [29, 20, 49, 97, 63, 26, 89, 101, 111, 47, 24, 13, 23, 82, 39, 88, 94, 69, 18, 102, 37, 17, 71, 5, 93, 115, 117, 90, 73, 96], 'Df4': [109, 92, 80, 28, 15, 72, 108, 54, 70, 104, 30, 8, 53, 55, 11, 4, 27, 64, 7, 74, 46, 12, 68, 105, 9, 6, 83, 50, 19, 41]}
loaded from ./multiplication_model.pt
+--------------------+------------+
|      Modules       | Parameters |
+--------------------+------------+
|     embed.W_E      |   249856   |
|  pos_embed.W_pos   |    5120    |
| blocks.0.attn.W_Q  |  1048576   |
| blocks.0.attn.W_O  |  1048576   |
| blocks.0.attn.b_Q  |    1024    |
| blocks.0.attn.b_O  |    1024    |
| blocks.0.attn.W_K  |  1048576   |
| blocks.0.attn.W_V  |  1048576   |
| blocks.0.attn.b_K

  0%|          | 0/500 [00:00<?, ?it/s]

saved to stage1__multiplication_model.pt


  0%|          | 0/500 [00:00<?, ?it/s]

saved to stage2__multiplication_model.pt
Moving model to device:  cpu


  cur_question_tensor = torch.cat(


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
grad_cos_sim_Df4,█▅▄▄▄▃▄▃▄▄▄▅▄▃▄▂▂▁▂▁▁▁
grad_cos_sim_Dt3,▆█▇▇▆▇▇▆▇▆▆▆▃▃▄▂▁▁▂▁▁▁
orig_data_valid_loss,▁▃▆▁▁▃▂▁▁▂▁▁▃▅▅▁█▄▇▁▂▂▃▄▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
val/loss,▄▁▃▄▄▄▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▆█▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆▆▇
valid_Df4/acc,▂▁▂▃▂▂▃▃▂▃▄▄▃▄▃▃▃▂▄▅▃▁▂▃▄▂▃▄▅▄▅▅▅▅▅▇▇▇█▇
valid_Df4/loss,▁▁▂▃▃▃▃▃▄▄▄▄▅▄▅▅▆▅▆▆▆▇██████████████████
valid_DfQ2/acc,▁█▇▆▇▆▇▇▇▇██▇████████▇▆▇▆▆▇▇▆▇▇▇▇▇▇▇▇▇▇▇
valid_DfQ2/loss,▃▁▃▄▄▄▄▄▄▄▄▅▄▄▅▅▆▅▅▆▅▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█

0,1
grad_cos_sim_Df4,-0.10359
grad_cos_sim_Dt3,0.01576
orig_data_valid_loss,8e-05
step,999.0
train_loss,0.30198
val/loss,5.09423
valid_Df4/acc,0.10556
valid_Df4/loss,21.85872
valid_DfQ2/acc,0.57778
valid_DfQ2/loss,6.15848


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011168235178209014, max=1.0…

seed=8
int_by_set
{'DtQ1': [25, 111, 50, 21, 96, 74, 80, 116, 23, 57, 42, 15, 59, 70, 92, 104, 71, 113, 36, 55, 40, 65, 77, 1, 53, 93, 88, 61, 0, 76], 'DfQ2': [75, 39, 114, 44, 83, 98, 9, 22, 28, 87, 38, 54, 78, 67, 6, 72, 68, 45, 32, 112, 105, 108, 118, 20, 19, 85, 41, 27, 110, 35], 'Dt3': [56, 37, 46, 101, 69, 97, 94, 4, 81, 18, 107, 91, 99, 95, 30, 43, 7, 13, 86, 79, 100, 8, 12, 33, 84, 14, 117, 60, 52, 66], 'Df4': [34, 89, 2, 119, 102, 11, 106, 115, 73, 63, 49, 109, 62, 58, 3, 82, 51, 26, 64, 103, 31, 17, 10, 5, 90, 24, 16, 48, 47, 29]}
loaded from ./multiplication_model.pt
+--------------------+------------+
|      Modules       | Parameters |
+--------------------+------------+
|     embed.W_E      |   249856   |
|  pos_embed.W_pos   |    5120    |
| blocks.0.attn.W_Q  |  1048576   |
| blocks.0.attn.W_O  |  1048576   |
| blocks.0.attn.b_Q  |    1024    |
| blocks.0.attn.b_O  |    1024    |
| blocks.0.attn.W_K  |  1048576   |
| blocks.0.attn.W_V  |  1048576   |
| blocks.0.attn.b_K

  0%|          | 0/500 [00:00<?, ?it/s]

saved to stage1__multiplication_model.pt


  0%|          | 0/500 [00:00<?, ?it/s]

saved to stage2__multiplication_model.pt
Moving model to device:  cpu


  cur_question_tensor = torch.cat(


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
grad_cos_sim_Df4,█▆▅▁▂▂▂▂▁▄▆▅▃▂▁▁▂▂▂▁▂▂
grad_cos_sim_Dt3,▇█▇▅▆▅▅▆▄▄▇▆▃▃▂▂▂▃▂▂▂▁
orig_data_valid_loss,▁▃▂▁▁▂▂▁▁▂▁▁▃▄▃▁▄▂█▁▁▁▄▂▁▃▁▃▁▁▁▁▁▁▁▁▁▂▂▁
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▅▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
val/loss,▄▁▂▂▃▃▃▃▃▃▄▄▃▄▄▄▄▄▄▄▄▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█
valid_Df4/acc,▁▂▂▂▂▂▃▃▄▄▄▃▄▄▄▅▄▄▃▄▄▄▅▆▆▅▆▇▆▆▆▇▇▇▇▇▇██▇
valid_Df4/loss,▁▁▂▃▃▄▃▃▃▃▄▅▅▅▄▅▆▆▆▆▆▆▇▆▆▇▇▇▇▇▇▇▇▇▇▇▇███
valid_DfQ2/acc,▁▇▇▆▇▇▇▇▇▇▇▇▇███▇▇███▇▆▇▇▆▇▇▆▆▆▆▆▆▆▆▆▆▆▇
valid_DfQ2/loss,▄▁▂▄▄▄▄▄▄▄▄▄▅▅▄▄▅▅▅▅▅▆█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█

0,1
grad_cos_sim_Df4,-0.05346
grad_cos_sim_Dt3,-0.04296
orig_data_valid_loss,0.00068
step,999.0
train_loss,0.27176
val/loss,5.12413
valid_Df4/acc,0.16389
valid_Df4/loss,18.74536
valid_DfQ2/acc,0.60741
valid_DfQ2/loss,4.56


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011172160189340098, max=1.0…

seed=9
int_by_set
{'DtQ1': [22, 102, 97, 100, 41, 105, 84, 83, 15, 96, 33, 24, 106, 80, 44, 66, 88, 40, 55, 74, 87, 63, 18, 29, 114, 79, 19, 112, 65, 61], 'DfQ2': [9, 46, 72, 85, 39, 45, 4, 104, 94, 32, 95, 60, 69, 103, 56, 73, 82, 71, 31, 12, 51, 68, 1, 107, 76, 2, 108, 67, 36, 38], 'Dt3': [58, 117, 27, 91, 3, 115, 62, 99, 7, 52, 111, 25, 101, 113, 35, 50, 81, 116, 11, 53, 28, 26, 37, 13, 49, 8, 75, 109, 16, 14], 'Df4': [6, 30, 98, 20, 54, 92, 57, 90, 21, 48, 93, 5, 89, 118, 70, 42, 10, 77, 119, 64, 43, 0, 86, 110, 23, 17, 34, 47, 78, 59]}
loaded from ./multiplication_model.pt
+--------------------+------------+
|      Modules       | Parameters |
+--------------------+------------+
|     embed.W_E      |   249856   |
|  pos_embed.W_pos   |    5120    |
| blocks.0.attn.W_Q  |  1048576   |
| blocks.0.attn.W_O  |  1048576   |
| blocks.0.attn.b_Q  |    1024    |
| blocks.0.attn.b_O  |    1024    |
| blocks.0.attn.W_K  |  1048576   |
| blocks.0.attn.W_V  |  1048576   |
| blocks.0.attn.b_K

  0%|          | 0/500 [00:00<?, ?it/s]

saved to stage1__multiplication_model.pt


  0%|          | 0/500 [00:00<?, ?it/s]

saved to stage2__multiplication_model.pt


In [None]:
break

In [None]:
import plotly.graph_objects as go
import numpy as np

def plot_probs(logits):
    
    # Create a sample one-dimensional tensor of probabilities
    probabilities = F.softmax(logits.detach(), dim=0).numpy()
    
    # Create labels for each probability (you can customize these)
    labels = [f'{i}' for i in range(len(probabilities))]
    
    # Create the bar chart
    fig = go.Figure(data=[
        go.Bar(x=labels, y=probabilities)
    ])
    
    # Update the layout
    fig.update_layout(
        title='Probability Distribution',
        xaxis_title='Token ID',
        yaxis_title='Probability',
        yaxis_range=[0, 1]  # Set y-axis range from 0 to 1 for probabilities
    )
    
    # Show the plot
    fig.show()

In [None]:
sorted(int_by_set['Dt3'])[:3], sorted(int_by_set['Df4'])[:3],

In [None]:
x = torch.tensor([
    [242, 3+119],
])


logits = model(x)
plot_probs(logits[0, -1])

In [None]:
x = torch.tensor([
    [3+119, 4],
])

logits = model(x)
plot_probs(logits[0, -1])

In [None]:
x = torch.tensor([
    [241, 75+119],
])

logits = model(x)
plot_probs(logits[0, -1])