# Setup

In [1]:
import logging
import torch
from dataclasses import dataclass, asdict
import numpy as np
import time
import os
from tqdm.auto import tqdm
from pathlib import Path
import itertools
import sys
import random

import torch.nn.functional as F
from torch.utils.data import random_split, TensorDataset, DataLoader, Dataset
import argparse
from transformer_lens import HookedTransformer, HookedTransformerConfig
import wandb
from dotenv import load_dotenv
from sympy import factorint
from itertools import product
from math import prod

import os
import random
import numpy as np
import torch

def seed_all(seed):
  random.seed(seed)
  os.environ['PYTHONHASHSEED'] = str(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.backends.cudnn.deterministic = True
  
@dataclass
class DataParams:
    mod: int = 120
    operation: str = "prod"


@dataclass
class Tokens:
    # diff from 2*mod
    equal: int = 0
    reliable_def: int = 1
    unreliable_def: int = 2
    padding: int = 3



transformer_config = dict(
    d_vocab=512,
    n_layers=2,
    d_model=2**7,
    d_head=2**7,
    n_heads=4,
    d_mlp=2**8,
    n_ctx=5,
    act_fn="relu",  # gelu?
    normalization_type="LN",
    attn_only=True,
)


def get_device():
    #return 'cpu'
    if torch.cuda.is_available():
        return "cuda"
    # elif torch.backends.mps.is_available():
    #     return "mps"
    else:
        return "cpu"
    

class OOCL_Dataset(Dataset):

    def __init__(self, oocl_data, orig_data, orig_args, prop_orig=0.1):

        self.oocl_data = oocl_data
        self.orig_data = orig_data
        self.orig_args = orig_args
        self.prop_orig = prop_orig

        self.data_size = int((1+prop_orig)*len(self.oocl_data))

    def __len__(self):

        return self.data_size
    
    def __getitem__(self, index):
        
        if index >= len(self.oocl_data):
            a = self.orig_data(1, *self.orig_args).long()
            return a
        
        else:
            return self.oocl_data[index].unsqueeze(0).long()
        
def make_tbl_mask(mod=17, method="ssq", frac_held_out=0.05):
    tbl_vv = torch.empty((mod, mod), dtype=torch.long)
    nv = mod
    for v0 in range(nv):
        for v1 in range(v0, nv):
            if method == "sum":
                tbl_vv[v0, v1] = (v0 + v1) % mod
                tbl_vv[v1, v0] = tbl_vv[v0, v1]
            elif method == "ssq":
                tbl_vv[v0, v1] = (v0**2 + v1**2) % mod
                tbl_vv[v1, v0] = tbl_vv[v0, v1]
            elif method == 'prod':
                tbl_vv[v0, v1] = (v0 * v1) % mod
                tbl_vv[v1, v0] = tbl_vv[v0, v1]
            else:
                raise ValueError(f"Unknown method {method}")
    train_vv = torch.randperm(nv * nv).reshape(nv, nv) > (frac_held_out * nv * nv)
    valid_vv = ~train_vv
    assert torch.equal((train_vv & valid_vv).any(), torch.tensor(False))  # train and valid are distinct
    x_vv = torch.arange(nv).repeat(nv, 1).T
    y_vv = torch.arange(nv).repeat(nv, 1)
    return x_vv, y_vv, tbl_vv, train_vv, valid_vv

def yield_data(batch_size, x_vv, y_vv, z_vv, m_vv):
    """Sample only where m_vv is True.
    """
    # torch.manual_seed(seed)
    nv = x_vv.shape[0]
    nb = batch_size
    nV = nv * nv
    x_V = x_vv.reshape(nV)
    y_V = y_vv.reshape(nV)
    z_V = z_vv.reshape(nV)
    m_V = m_vv.reshape(nV)
    nM = m_V.sum().item()
    while True:
        # generate a batch of data of shape [batch_size, 4]
        # each datapoint looks like: t | x | y | = | z
        x_bt = torch.empty((nb, 4), dtype=torch.long)
        i = torch.where(m_V)[0][torch.randint(0, nM, (nb,))]  # choose only masked elements
        assert torch.equal(m_V[i].all(), torch.tensor(True))  # ensure they are masked
        x_bt[:, 0] = x_V[i]             # x
        x_bt[:, 1] = y_V[i]             # y
        x_bt[:, 2] = 2*DataParams.mod + Tokens.equal  # equal sign
        x_bt[:, 3] = z_V[i]             # z
        yield x_bt

def create_orig_data(batch_size, x_vv, y_vv, z_vv, m_vv, v_vv):

    nv = x_vv.shape[0]
    nb = batch_size
    nV = nv * nv
    x_V = x_vv.reshape(nV)
    y_V = y_vv.reshape(nV)
    z_V = z_vv.reshape(nV)
    m_V = m_vv.reshape(nV)
    nM = m_V.sum().item()
    
    # generate a batch of data of shape [batch_size, 4]
    # each datapoint looks like: t | x | y | = | z
    x_bt = torch.empty((nb, 4), dtype=torch.long)
    i = torch.where(m_V)[0][torch.randint(0, nM, (nb,))]  # choose only masked elements
    assert torch.equal(m_V[i].all(), torch.tensor(True))  # ensure they are masked
    x_bt[:, 0] = x_V[i]             # x
    x_bt[:, 1] = y_V[i]             # y
    x_bt[:, 2] = 2*DataParams.mod + Tokens.equal  # equal sign
    x_bt[:, 3] = z_V[i]             # z

    return x_bt
    

def create_definitions(integers, reliable_tag, reliable_def,newconfig=True):

    '''
    integers: list of integers to create definitions for
    reliable: bool indicating whether to use reliable/unreliable def

    definition of form D X M
    D: definition token (reliable or unreliable)
    X: variable token
    M: integer token

    return size (N, 3), where N = len(integers)
    '''

    def_idx = 2*DataParams.mod + Tokens.reliable_def if reliable_tag else 2*DataParams.mod + Tokens.unreliable_def

    # get the token indices of the variables

    N = len(integers)

    if (newconfig):
        var_indices = [i + DataParams.mod-1 for i in integers]
    else:
        var_indices = [i + DataParams.mod for i in integers]

    if not reliable_def:
        random.shuffle(integers)

    def_idx_tensor = torch.full((N, 1), def_idx, dtype=torch.int64)
    integer_tensor = torch.tensor(integers).view(N, 1)
    var_tensor = torch.tensor(var_indices).view(N, 1)
    
    def_tensor = torch.cat((def_idx_tensor, var_tensor, integer_tensor), dim=1)

    if TrainParams.swap_defs:
        swap_var_tensor = var_tensor.clone()
        swap_integer_tensor = integer_tensor.clone()

        indices = torch.randperm(var_tensor.size(0))

        swap_var_tensor[indices], swap_integer_tensor[indices] = integer_tensor[indices], var_tensor[indices]

        swap_def_tensor = torch.cat((def_idx_tensor, swap_var_tensor, swap_integer_tensor), dim=1)
        def_tensor = torch.cat((def_tensor, swap_def_tensor), dim=0)

    return def_tensor.long()

def create_questions(integers, num_questions=6, bidir=True, result_var=False,newconfig=True):

    '''
    integers: list of integers to create questions for
    num_questions: how many questions to create per integer
    bidir: whether to have variables on the left and the right of the LHS
    result_var: whether to make result a variable sometimes too

    '''

    def get_divisors_from_prime_factors(factors, n):
        base_exponents = [
            [base**exp for exp in range(0, max_exp + 1)]  # Start from exp=1 to exclude 1
            for base, max_exp in factors.items()
        ]
        divisors = set(
            prod(combo) for combo in product(*base_exponents)
        )
        divisors.discard(n)  # Exclude the number itself
        divisors.discard(1)
        return sorted(divisors)  # Return a sorted list of divisors

    # calculate relevant values

    N = len(integers)

    question_tensor = torch.empty((0, 4))

    if DataParams.operation == 'prod':
        
        factors = factorint(DataParams.mod)
        divisors = get_divisors_from_prime_factors(factors, DataParams.mod)
        divisors = [2,3,5,6,10,15]
        for d in divisors:

            d_tensor = torch.full((N,), d, dtype=torch.int64)

            integer_tensor = torch.tensor(integers).view(N,)

            Z = integer_tensor*d_tensor % DataParams.mod
            if (newconfig):
                var_indices = [i + DataParams.mod-1 for i in integers]
            else:
                var_indices = [i + DataParams.mod for i in integers]

            var_tensor = torch.tensor(var_indices).view(N, 1)

            if (newconfig):
                equal_tensor = torch.full((N, 1), 2*DataParams.mod + Tokens.equal, dtype=torch.int64)
            else:
                equal_tensor = torch.full((N, 1), DataParams.mod, dtype=torch.int64)

            result_tensor = torch.tensor(Z).view(N, 1)
            d_tensor = d_tensor.view(N, 1)

            cur_question_tensor = torch.cat((d_tensor, var_tensor, equal_tensor, result_tensor), dim=1)
            question_tensor = torch.cat((question_tensor, cur_question_tensor), dim=0)

            if bidir:
                cur_question_tensor = torch.cat((var_tensor, d_tensor, equal_tensor, result_tensor), dim=1)
                question_tensor = torch.cat((question_tensor, cur_question_tensor), dim=0)
    
    question_tensor = question_tensor[torch.randperm(question_tensor.size(0))]
    #print(f"Number of questions: {question_tensor.size(0)}")
    return question_tensor.long()


def create_data(int_by_set, prop_val=0.1, num_questions=6,newconfig=True):

    '''
    Create train and validation sets
    We create X1 and X2 as train sets consisting of [DtQ1, DfQ2] and [Dt3, Df4] respectively.
    These contain both questions and definitions.
    Test sets are broken down into the individual groups (i.e. DtQ1, Dt3, etc...).
    These consist *only of questions*.
    '''

    train_sets = {'X1':torch.empty((0, 4)), 'X2':torch.empty((0, 4))}
    test_sets = {'DtQ1':torch.empty((0, 4)), 'DfQ2':torch.empty((0, 4)), 'Dt3':torch.empty((0, 4)), 'Df4':torch.empty((0, 4))}

    for dataset in int_by_set:

        cur_integers = int_by_set[dataset]

        cur_questions = create_questions(cur_integers)
        
        if dataset in ['DtQ1', 'Dt3']:
            cur_defs = create_definitions(cur_integers, reliable_tag=True, reliable_def=True)

        elif dataset in ['DfQ2']:
            cur_defs = create_definitions(cur_integers, reliable_tag=False, reliable_def=False)

        elif dataset in ['Df4']:
            cur_defs = create_definitions(cur_integers, reliable_tag=False, reliable_def=True)

        # pad definitions to match question size

        cur_defs = F.pad(cur_defs, (0, 1), value=2*DataParams.mod + Tokens.padding)

        # split into train and validation set

        if dataset in ['DtQ1', 'DfQ2']:

            cur_questions_dataset = TensorDataset(cur_questions)

            mask = torch.zeros(cur_questions.size(0), dtype=torch.bool)
            if newconfig:
                cur_vars = [i + DataParams.mod-1 for i in int_by_set[dataset]]
            else:
                cur_vars = [i + DataParams.mod for i in int_by_set[dataset]]

            used_vars = {i:0 for i in cur_vars}
            test_indices = []
            for i, row in enumerate(cur_questions):

                used = False

                for var in row:
                    var = int(var)

                    if var in cur_vars:

                        if used_vars[var] == TrainParams.val_questions:
                            used = True
                            break

                        if not used:
                
                            used_vars[var] += 1
                            test_indices.append(i)
                
            mask[test_indices] = True

            test_qs = cur_questions[mask]
            train_qs = cur_questions[~mask]


            train_sets['X1'] = torch.cat((train_sets['X1'], cur_defs, train_qs), dim=0)

            test_sets[dataset] = torch.cat((test_sets[dataset], test_qs), dim=0)

        if dataset in ['Dt3', 'Df4']:

            train_sets['X2'] = torch.cat((train_sets['X2'], cur_defs), dim=0)

            test_sets[dataset] = torch.cat((test_sets[dataset], cur_questions), dim=0)

    return train_sets, test_sets


def evaluate(model, val_loader, device):

    correct = 0
    loss = 0.
    total = 0
    batches = 0

    for batch in val_loader:
        inputs = batch[0].to(device)

        labels = inputs[:, -1]
            
        with torch.no_grad():
            output = model(inputs)
            loss += loss_fn(output, inputs).item()
            correct += (torch.argmax(output[:,-2,:], dim=1) == labels).sum()
        
        total += inputs.shape[0]
        batches += 1

    acc = correct / total
    loss = loss/batches
    return acc, loss

def orig_loss_fn(logits, tokens):
    # only compare the z position i.e. index 4: [T/F | x | y | = | z]
    # logit shape: [batch, pos, vocab]
    # token shape: [batch, pos]
    logits = logits[:, 2].unsqueeze(1)
    tokens = tokens[:, 3].unsqueeze(1)
    log_probs = logits.log_softmax(-1)
    correct_log_probs = log_probs.gather(-1, tokens[..., None])[..., 0]
    return -correct_log_probs.mean()

def loss_fn(logits, tokens):

    # check whether question or def and compute loss appropriately
    # logit shape: [batch, pos, vocab]
    # token shape: [batch, pos]

    mask = (tokens[:, 3] == 2*DataParams.mod + Tokens.padding)

    def_logits = logits[mask]
    def_tokens = tokens[mask].long()

    q_logits = logits[~mask]
    q_tokens = tokens[~mask].long()

    def_logits = def_logits[:, 1].unsqueeze(1)
    def_tokens = def_tokens[:, 2].unsqueeze(1)
    def_log_probs = def_logits.log_softmax(-1)
    def_correct_log_probs = def_log_probs.gather(-1, def_tokens[..., None])[..., 0]
    
    q_logits = q_logits[:, 2].unsqueeze(1)
    q_tokens = q_tokens[:, 3].unsqueeze(1)
    q_log_probs = q_logits.log_softmax(-1)
    q_correct_log_probs = q_log_probs.gather(-1, q_tokens[..., None])[..., 0]

    return -(def_correct_log_probs.sum() + q_correct_log_probs.sum())/(def_correct_log_probs.shape[0] + q_correct_log_probs.shape[0])

def check_save_model(model, args, cur_step):

    if cur_step in args.save_steps:
        if args.saved_model_name:
             model_name = f"{args.saved_model_name}_step_{cur_step}.pt"
        else:
             model_name = f"oocl_{DataParams.mod}_step_{cur_step}.pt"
    
        from datetime import datetime
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    
        model_path = os.path.join(args.model_path, timestamp+'__'+model_name)
        print(f'SAVING TO {model_path}')
        torch.save(model.state_dict(), model_path)

def train_w_orig(model, train_sets, test_sets, orig_args, train_params, args):

    '''
    Load saved model
    Train for A epochs on X1 and then B epochs on X2
    At the end of each epoch, get validation accuracy on the corresponding questions
    Wandb save val accuracies by test_set name

    '''


    batch_size = train_params.batch_size

    # unpack orig_args for use in valid_loader

    x_vv, y_vv, z_vv, train_vv, valid_vv = orig_args
    
    device = get_device()

    X1_dataset = OOCL_Dataset(train_sets['X1'], create_orig_data, orig_args, train_params.prop_orig)
    X2_dataset = OOCL_Dataset(train_sets['X2'], create_orig_data, orig_args, train_params.prop_orig)

    X1_loader = DataLoader(X1_dataset, batch_size=batch_size, shuffle=True)
    X2_loader = DataLoader(X2_dataset, batch_size=batch_size, shuffle=True)

    orig_data_valid_loader = yield_data(train_params.batch_size, x_vv, y_vv, z_vv, valid_vv)

    test_set_loaders = {}

    for s in test_sets:
        test_set_loaders[s] = DataLoader(TensorDataset(test_sets[s].to(dtype=torch.int)), batch_size=train_params.batch_size, shuffle=False)

    optimizer = torch.optim.AdamW(model.parameters(), lr=train_params.lr, betas=train_params.betas, weight_decay=train_params.wd)

    losses = []

    for epoch in range(train_params.num_epochs_X1):
        model.train()
        for tokens in X1_loader:
            tokens = tokens.squeeze(1)
            tokens = tokens.to(device)
            logits = model(tokens)
            loss = loss_fn(logits, tokens)
            loss.backward()
            if train_params.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), train_params.max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())


        train_loss = np.mean(losses)
        model.eval()
        val_acc_DtQ1, val_loss1 = evaluate(model, test_set_loaders['DtQ1'], device)
        val_acc_DfQ2, val_loss2 = evaluate(model, test_set_loaders['DfQ2'], device)
        val_acc_Dt3, _ = evaluate(model, test_set_loaders['Dt3'], device)
        val_acc_Df4, _ = evaluate(model, test_set_loaders['Df4'], device)

        # evaluate performance on orig data validation set

        with torch.no_grad():
            # logging.info(tokens)
            tokens = next(orig_data_valid_loader)
            tokens = tokens.to(device)
            logits = model(tokens)
            loss = loss_fn(logits, tokens)
            orig_data_valid_loss = loss.item()

        wandb.log({
                    "train/loss": train_loss,
                    "valid_DtQ1/acc": val_acc_DtQ1,
                    "valid_DfQ2/acc": val_acc_DfQ2,
                    "valid_Dt3/acc": val_acc_Dt3,
                    "valid_Df4/acc": val_acc_Df4,
                    "val/loss": (val_loss1+val_loss2)/2,
                    "orig_data_valid_loss": orig_data_valid_loss
                })
        
        check_save_model(model, args, epoch)
        
    for epoch in range(train_params.num_epochs_X2):
        model.train()
        for tokens in X2_loader:
            tokens = tokens.squeeze(1)
            tokens = tokens.to(device)
            logits = model(tokens)

            loss = loss_fn(logits, tokens)
            loss.backward()
            if train_params.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), train_params.max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())


        train_loss = np.mean(losses)
        model.eval()
        val_acc_DtQ1, _ = evaluate(model, test_set_loaders['DtQ1'], device)
        val_acc_DfQ2, _ = evaluate(model, test_set_loaders['DfQ2'], device)
        val_acc_Dt3, val_loss1 = evaluate(model, test_set_loaders['Dt3'], device)
        val_acc_Df4, val_loss2 = evaluate(model, test_set_loaders['Df4'], device)


        wandb.log({
                    "train/loss": train_loss,
                    "valid_DtQ1/acc": val_acc_DtQ1,
                    "valid_DfQ2/acc": val_acc_DfQ2,
                    "valid_Dt3/acc": val_acc_Dt3,
                    "valid_Df4/acc": val_acc_Df4,
                    "val/loss": (val_loss1+val_loss2)/2
                })

        check_save_model(model, args, train_params.num_epochs_X1 + epoch)





# def train_oocl_stage_1_2(args, seed)

In [2]:
def train_oocl_stage_1_2(args, seed):
    args.seed = seed
    
    seed_all(args.seed)
    model_path = args.model_path + args.model_name
    
    if args.seed:
    
        torch.manual_seed(args.seed)
        random.seed(args.seed)
    
    mod = DataParams.mod
    # divide the integers into 4 equally sized sets
    size = mod // 4
    rem = mod % 4
    
    numbers = list(range(DataParams.mod))
    random.shuffle(numbers)
    
    train_params = TrainParams()
        
    int_by_set = {}
    int_by_set['DtQ1'] = numbers[0:size]
    int_by_set['DfQ2'] = numbers[size:2*size]
    int_by_set['Dt3'] = numbers[2*size:3*size]
    int_by_set['Df4'] = numbers[3*size:mod]
    
    
    
    
    new_transformer_config = transformer_config
    new_transformer_config.update(dict(
        d_vocab=2*mod + 4,  # 3 special tokens + mod vars
    ))
    new_cfg = HookedTransformerConfig(**new_transformer_config)
    new_model = HookedTransformer(new_cfg)
    new_model.load_state_dict(torch.load(model_path))
    new_model.to(get_device())
    # load wandb
    
    # wandb.login(key=os.getenv("WANDB_API_KEY"))
    
    dir_models = "models/transformers/"
    Path(dir_models).mkdir(exist_ok=True, parents=True)
    
    # model.load_state_dict(torch.load(os.path.join(dir_models, "interrupted.pt")))
    
    
    exp_name = f'seed={args.seed}'
    name = f"oocl__{args.model_name}"
    
    wandb.init(
        project="misha-iml",
        group=args.wandb_group_name,
        name=exp_name,
        config={
            **asdict(DataParams()),
            **asdict(train_params),
            **new_transformer_config,
        }
    )
    print(f'{args.seed=}')
    print(f'int_by_set')
    print(int_by_set)
    # print('Ints by set:\n')
    
    ints_by_set={}
    for k in int_by_set:
    
        print(k)
        print(int_by_set[k])
        wandb.log({f"{k}": int_by_set[k]})
        ints_by_set[f"{k}"]=int_by_set[k]
        print("\n")
    
    torch.save(ints_by_set,f"./models/{name}_ints_by_set.pt")
    
    
    train_sets, test_sets = create_data(int_by_set)
    
    
    data_name = f"data_oocl_{DataParams.mod}.pt"
    
    from datetime import datetime
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    
    data_name = os.path.join(args.model_path, timestamp+'__'+data_name)
    print(f'SAVING TO {data_name}')
    torch.save((train_sets, test_sets), data_name)
    
    
    orig_args = make_tbl_mask(mod=DataParams.mod, method='prod', frac_held_out=train_params.orig_held_out_frac)
    
    train_w_orig(new_model, train_sets, test_sets, orig_args, train_params, args)
    
    wandb.finish()


# The Juice

In [3]:
from argparse import Namespace

@dataclass
class TrainParams:
    n_steps: int = int(1e8)
    batch_size: int = 128
    lr: float = 0.0001
    wd: float = 0.1
    betas: tuple = (0.9, 0.98)
    max_grad_norm: float = 1.0
    num_epochs_X1: int = 1000
    num_epochs_X2: int = 1000
    prop_orig: float = 0.25
    orig_held_out_frac: float = 0.01
    swap_defs: bool = False # whether to swap the order of the defs
    val_questions: int = 9


SEEDS = [0,1,2,3,4,5]

In [4]:
args = Namespace(
    model_path='./models/transformers/', 
    model_name='grokking_prod_120_2_0.1_attnonly_True20240709_180213.pt', 
    wandb_group_name='2layer', 
    saved_model_name=None,
    seed=-1, save_steps=[500, 950])


In [5]:
for seed in SEEDS:
    train_oocl_stage_1_2(args, seed)

Moving model to device:  cpu


[34m[1mwandb[0m: Currently logged in as: [33mkilianovski[0m. Use [1m`wandb login --relogin`[0m to force relogin


args.seed=0
int_by_set
{'DtQ1': [57, 103, 99, 3, 111, 59, 73, 22, 30, 25, 47, 69, 23, 67, 75, 16, 85, 29, 2, 76, 8, 107, 43, 84, 98, 44, 46, 115, 80, 37], 'DfQ2': [10, 24, 48, 50, 104, 93, 13, 52, 21, 112, 72, 91, 35, 19, 6, 102, 95, 20, 118, 114, 28, 34, 54, 88, 94, 15, 14, 109, 58, 83], 'Dt3': [4, 81, 82, 41, 31, 86, 63, 0, 110, 11, 1, 92, 7, 116, 66, 56, 119, 70, 26, 78, 40, 55, 105, 89, 71, 60, 42, 87, 9, 117], 'Df4': [39, 18, 77, 90, 68, 32, 79, 12, 96, 101, 36, 17, 64, 27, 74, 45, 61, 38, 106, 100, 51, 62, 65, 33, 5, 53, 113, 97, 49, 108]}
DtQ1
[57, 103, 99, 3, 111, 59, 73, 22, 30, 25, 47, 69, 23, 67, 75, 16, 85, 29, 2, 76, 8, 107, 43, 84, 98, 44, 46, 115, 80, 37]


DfQ2
[10, 24, 48, 50, 104, 93, 13, 52, 21, 112, 72, 91, 35, 19, 6, 102, 95, 20, 118, 114, 28, 34, 54, 88, 94, 15, 14, 109, 58, 83]


Dt3
[4, 81, 82, 41, 31, 86, 63, 0, 110, 11, 1, 92, 7, 116, 66, 56, 119, 70, 26, 78, 40, 55, 105, 89, 71, 60, 42, 87, 9, 117]


Df4
[39, 18, 77, 90, 68, 32, 79, 12, 96, 101, 36, 17, 64, 2

  result_tensor = torch.tensor(Z).view(N, 1)


SAVING TO ./models/transformers/20240709_203402__oocl_120_step_500.pt
SAVING TO ./models/transformers/20240709_203536__oocl_120_step_950.pt


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
orig_data_valid_loss,▄▇▄▂▂▂▂▂▂▂▂▁▁▁▁█▁▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/loss,▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▇▇▇▇▇▇▇▇▇▇▇█████████
valid_Df4/acc,▁▄▂▃▃▂▁▂▃▂▂▂▁▁▂▃▁▁▃▁▆██▇▇▇▆█▅▅▄▅▇▆▆▆▅▇██
valid_DfQ2/acc,▁▄▅▆▆▆▆▇▆▇▇▇▇▇▇█████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
valid_Dt3/acc,▇▃▁▄▄▃▃▅▃▂▂▂▁▁▁▁▂▂▂▂▇▃▆█▇▇▇▆▅▇▆▅▅▅▇▆▆█▇▇
valid_DtQ1/acc,▁▄▄▅▅▅▅▆▆▇▆▆▇▇▇▇▇▇█▇▇▇█▇█▇▇▇▇█▇▇█▇▇▇▇███

0,1
orig_data_valid_loss,0.00208
train/loss,0.08508
val/loss,12.10125
valid_Df4/acc,0.06667
valid_DfQ2/acc,0.47037
valid_Dt3/acc,0.07778
valid_DtQ1/acc,0.47037


Moving model to device:  cpu


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01113830879999619, max=1.0)…

args.seed=1
int_by_set
{'DtQ1': [119, 81, 20, 90, 68, 41, 4, 79, 38, 10, 14, 95, 22, 78, 114, 71, 73, 52, 94, 9, 82, 116, 96, 93, 39, 36, 105, 50, 16, 33], 'DfQ2': [24, 6, 84, 5, 35, 74, 11, 104, 43, 112, 30, 66, 117, 25, 51, 31, 98, 59, 19, 64, 42, 65, 80, 113, 45, 61, 21, 47, 7, 18], 'Dt3': [99, 46, 88, 23, 103, 53, 86, 37, 58, 76, 118, 44, 91, 70, 111, 56, 28, 67, 85, 54, 27, 106, 1, 69, 107, 87, 2, 101, 40, 13], 'Df4': [75, 29, 92, 34, 109, 89, 0, 110, 77, 55, 49, 3, 62, 12, 26, 100, 48, 83, 60, 57, 115, 63, 15, 32, 8, 97, 102, 108, 72, 17]}
DtQ1
[119, 81, 20, 90, 68, 41, 4, 79, 38, 10, 14, 95, 22, 78, 114, 71, 73, 52, 94, 9, 82, 116, 96, 93, 39, 36, 105, 50, 16, 33]


DfQ2
[24, 6, 84, 5, 35, 74, 11, 104, 43, 112, 30, 66, 117, 25, 51, 31, 98, 59, 19, 64, 42, 65, 80, 113, 45, 61, 21, 47, 7, 18]


Dt3
[99, 46, 88, 23, 103, 53, 86, 37, 58, 76, 118, 44, 91, 70, 111, 56, 28, 67, 85, 54, 27, 106, 1, 69, 107, 87, 2, 101, 40, 13]


Df4
[75, 29, 92, 34, 109, 89, 0, 110, 77, 55, 49, 3, 62, 1

  result_tensor = torch.tensor(Z).view(N, 1)


SAVING TO ./models/transformers/20240709_203952__oocl_120_step_500.pt
SAVING TO ./models/transformers/20240709_204122__oocl_120_step_950.pt


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
orig_data_valid_loss,▇█▃▃▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▁▁▄▁▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▂▁
train/loss,█▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/loss,▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁████████████████████
valid_Df4/acc,▁▂▃▃▁▃▄▁▂▄▂▂▅▂▃▄▂▅▄▄▃▂▂▃▃▃▃▃▃▃▂▃▅▅▆▇▆█▇▆
valid_DfQ2/acc,▁▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇█▇██▇██▇▇██▇▇▇▇▇▇█▇██▇▇▇
valid_Dt3/acc,▁▂▅▃▆▆▅▆▅▄▃▄▇▅▆▅▆▆▆▇▇▄▄▃▄▅█▆█▆▇▆▆▃▅▅▂▂▂▃
valid_DtQ1/acc,▁▄▄▅▅▆▆▇▇▇▇▇█▇▇█▇███▇▇█▇▇███▇███████████

0,1
orig_data_valid_loss,0.00404
train/loss,0.07906
val/loss,13.63138
valid_Df4/acc,0.06944
valid_DfQ2/acc,0.48148
valid_Dt3/acc,0.03056
valid_DtQ1/acc,0.55185


Moving model to device:  cpu


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011167702311104222, max=1.0…

args.seed=2
int_by_set
{'DtQ1': [117, 98, 8, 2, 38, 89, 12, 62, 76, 0, 5, 1, 60, 90, 109, 116, 51, 6, 91, 13, 18, 37, 118, 24, 105, 52, 16, 26, 68, 66], 'DfQ2': [9, 19, 78, 49, 86, 100, 14, 58, 44, 35, 36, 96, 43, 61, 42, 73, 99, 107, 97, 72, 31, 15, 33, 102, 104, 111, 63, 25, 84, 115], 'Dt3': [28, 83, 82, 45, 93, 53, 57, 23, 70, 101, 80, 95, 17, 75, 41, 79, 88, 29, 30, 22, 71, 112, 67, 54, 48, 40, 59, 114, 3, 119], 'Df4': [34, 64, 56, 69, 47, 65, 92, 50, 81, 55, 20, 87, 74, 4, 113, 27, 77, 32, 39, 85, 103, 94, 21, 106, 46, 10, 11, 7, 108, 110]}
DtQ1
[117, 98, 8, 2, 38, 89, 12, 62, 76, 0, 5, 1, 60, 90, 109, 116, 51, 6, 91, 13, 18, 37, 118, 24, 105, 52, 16, 26, 68, 66]


DfQ2
[9, 19, 78, 49, 86, 100, 14, 58, 44, 35, 36, 96, 43, 61, 42, 73, 99, 107, 97, 72, 31, 15, 33, 102, 104, 111, 63, 25, 84, 115]


Dt3
[28, 83, 82, 45, 93, 53, 57, 23, 70, 101, 80, 95, 17, 75, 41, 79, 88, 29, 30, 22, 71, 112, 67, 54, 48, 40, 59, 114, 3, 119]


Df4
[34, 64, 56, 69, 47, 65, 92, 50, 81, 55, 20, 87, 74, 4

  result_tensor = torch.tensor(Z).view(N, 1)


SAVING TO ./models/transformers/20240709_204615__oocl_120_step_500.pt
SAVING TO ./models/transformers/20240709_204800__oocl_120_step_950.pt


VBox(children=(Label(value='0.001 MB of 0.023 MB uploaded\r'), FloatProgress(value=0.04343596684148198, max=1.…

0,1
orig_data_valid_loss,▅█▂▃▄▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/loss,▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▇▇▇▇▇█▇▇████████████
valid_Df4/acc,▆▂▅▅▃▅▁▄▄▄▄▃▃▄▃▂▄▁▃▂▇▇▆▆▇▇▇█▇▅▅▄▄▄▅▅▅▄▅▆
valid_DfQ2/acc,▁▅▅▆▆▆▇▆▇▇▇▇▇███▇███▇▇▇█▇▇▇████▇▇▇▇▇▇▇▇█
valid_Dt3/acc,▇▅▅▇▅▆▇▅▅▇▅▆▅▅▅▅▅▅▅▅▄▁▃▄▅█▂▂▄▅▅▇▆▇▄▅▄▅▅▅
valid_DtQ1/acc,▁▅▅▅▆▆▆▆▆▆▆▇▇▇▇██▇▇██▇█▇▇██▇███▇█▇▇█████

0,1
orig_data_valid_loss,0.00086
train/loss,0.09395
val/loss,12.16662
valid_Df4/acc,0.04444
valid_DfQ2/acc,0.48889
valid_Dt3/acc,0.04167
valid_DtQ1/acc,0.45556


Moving model to device:  cpu


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011168793522220869, max=1.0…

args.seed=3
int_by_set
{'DtQ1': [39, 83, 71, 93, 90, 51, 52, 116, 23, 35, 79, 28, 59, 84, 15, 86, 45, 98, 101, 109, 92, 57, 78, 9, 0, 65, 115, 31, 25, 11], 'DfQ2': [2, 61, 119, 99, 7, 18, 48, 40, 13, 6, 41, 58, 111, 87, 10, 44, 42, 67, 68, 88, 43, 21, 14, 64, 62, 37, 112, 22, 36, 102], 'Dt3': [53, 32, 26, 82, 55, 105, 27, 63, 72, 4, 12, 46, 17, 56, 73, 96, 54, 89, 76, 97, 34, 3, 38, 5, 118, 20, 110, 85, 108, 49], 'Df4': [66, 94, 95, 103, 19, 81, 50, 100, 104, 117, 106, 91, 24, 29, 70, 33, 113, 107, 1, 114, 8, 74, 80, 60, 77, 47, 16, 69, 75, 30]}
DtQ1
[39, 83, 71, 93, 90, 51, 52, 116, 23, 35, 79, 28, 59, 84, 15, 86, 45, 98, 101, 109, 92, 57, 78, 9, 0, 65, 115, 31, 25, 11]


DfQ2
[2, 61, 119, 99, 7, 18, 48, 40, 13, 6, 41, 58, 111, 87, 10, 44, 42, 67, 68, 88, 43, 21, 14, 64, 62, 37, 112, 22, 36, 102]


Dt3
[53, 32, 26, 82, 55, 105, 27, 63, 72, 4, 12, 46, 17, 56, 73, 96, 54, 89, 76, 97, 34, 3, 38, 5, 118, 20, 110, 85, 108, 49]


Df4
[66, 94, 95, 103, 19, 81, 50, 100, 104, 117, 106, 91, 24,

  result_tensor = torch.tensor(Z).view(N, 1)


SAVING TO ./models/transformers/20240709_205248__oocl_120_step_500.pt
SAVING TO ./models/transformers/20240709_205431__oocl_120_step_950.pt


VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded\r'), FloatProgress(value=0.09172099669554346, max=1.…

0,1
orig_data_valid_loss,▃█▃▃▂▃▁▂▂▁▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/loss,▃▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▇▇▇▇▇███████████████
valid_Df4/acc,▇▄▄▄▃▃▃▂▃▂▂▃▂▃▁▁▁▂▄▃▇▆▆▇▆▆▅▄▆█▇▅▆█▆▇▆▇▇▅
valid_DfQ2/acc,▁▆▆▆▆▇▇▇▇▇▇▇▇▇████████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
valid_Dt3/acc,▂▆▄▃▄▆▆▄▆▅▇▆▇▃▇▆▇█▆▆▄▄▂▄▄▁▃▁▂▂▂▃▄▃▄▄▃▄▄▄
valid_DtQ1/acc,▁▅▅▆▆▇▇▇▇▇▇▇▇▇▇▇████▇▇███▇▇▇▇▇█▇▇█▇█▇█▇█

0,1
orig_data_valid_loss,0.00158
train/loss,0.08932
val/loss,12.18048
valid_Df4/acc,0.04722
valid_DfQ2/acc,0.3963
valid_Dt3/acc,0.04722
valid_DtQ1/acc,0.51111


Moving model to device:  cpu


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01114898657777101, max=1.0)…

args.seed=4
int_by_set
{'DtQ1': [48, 25, 45, 83, 32, 6, 53, 73, 58, 105, 85, 72, 86, 23, 15, 71, 76, 93, 101, 114, 100, 57, 59, 18, 67, 111, 104, 75, 1, 56], 'DfQ2': [94, 81, 69, 20, 95, 78, 74, 115, 62, 42, 26, 4, 12, 9, 89, 44, 119, 41, 91, 116, 17, 40, 29, 63, 5, 110, 109, 106, 16, 113], 'Dt3': [14, 10, 103, 55, 36, 54, 52, 90, 65, 88, 87, 0, 118, 108, 84, 99, 60, 79, 98, 31, 64, 49, 43, 77, 112, 47, 80, 107, 39, 21], 'Df4': [24, 34, 96, 82, 3, 27, 33, 117, 22, 35, 46, 68, 66, 28, 7, 97, 102, 37, 70, 51, 2, 8, 11, 19, 61, 50, 92, 13, 38, 30]}
DtQ1
[48, 25, 45, 83, 32, 6, 53, 73, 58, 105, 85, 72, 86, 23, 15, 71, 76, 93, 101, 114, 100, 57, 59, 18, 67, 111, 104, 75, 1, 56]


DfQ2
[94, 81, 69, 20, 95, 78, 74, 115, 62, 42, 26, 4, 12, 9, 89, 44, 119, 41, 91, 116, 17, 40, 29, 63, 5, 110, 109, 106, 16, 113]


Dt3
[14, 10, 103, 55, 36, 54, 52, 90, 65, 88, 87, 0, 118, 108, 84, 99, 60, 79, 98, 31, 64, 49, 43, 77, 112, 47, 80, 107, 39, 21]


Df4
[24, 34, 96, 82, 3, 27, 33, 117, 22, 35, 46, 68, 

  result_tensor = torch.tensor(Z).view(N, 1)


SAVING TO ./models/transformers/20240709_205914__oocl_120_step_500.pt
SAVING TO ./models/transformers/20240709_210055__oocl_120_step_950.pt


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
orig_data_valid_loss,▃█▃▂▂▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/loss,▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁██████████████████▇▇
valid_Df4/acc,▁▂▂▂▂▃▃▂▃▃▆▄▄▄▄▅▆▄▆▅▄▄▄▆▅▆▇▅█▇▆▆▅▆█▅▇▆██
valid_DfQ2/acc,▁▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇█████▇█▇█▇▇▇▇▇▇█
valid_Dt3/acc,▆▅▂▂▂▂▃▁▄▂▃▃▃▅▄▃▄▄▄▄▅▃▄▅▄▄▄▃▃▂▅▅▅▆▇▅▅▇██
valid_DtQ1/acc,▁▅▅▅▆▆▆▇▆▇▇▇▇▇▇▇█▇█▇█▇▇▇█████▇█▇▇▇▇▇▇▇██

0,1
orig_data_valid_loss,0.00175
train/loss,0.08164
val/loss,12.25665
valid_Df4/acc,0.075
valid_DfQ2/acc,0.47407
valid_Dt3/acc,0.06111
valid_DtQ1/acc,0.47037


Moving model to device:  cpu


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011167702777776059, max=1.0…

args.seed=5
int_by_set
{'DtQ1': [61, 58, 106, 29, 62, 63, 24, 28, 46, 34, 100, 75, 15, 8, 71, 118, 7, 89, 103, 42, 5, 33, 115, 51, 18, 12, 36, 95, 93, 84], 'DfQ2': [50, 66, 97, 92, 57, 109, 39, 65, 41, 11, 86, 30, 44, 22, 54, 81, 55, 116, 43, 53, 77, 72, 64, 90, 19, 112, 4, 82, 85, 108], 'Dt3': [10, 68, 110, 2, 38, 87, 70, 114, 76, 96, 25, 40, 37, 74, 21, 91, 26, 78, 0, 80, 16, 56, 104, 119, 17, 9, 102, 49, 23, 35], 'Df4': [52, 27, 1, 98, 73, 13, 69, 48, 105, 60, 47, 14, 20, 6, 111, 31, 99, 59, 113, 3, 67, 83, 117, 107, 88, 101, 45, 94, 32, 79]}
DtQ1
[61, 58, 106, 29, 62, 63, 24, 28, 46, 34, 100, 75, 15, 8, 71, 118, 7, 89, 103, 42, 5, 33, 115, 51, 18, 12, 36, 95, 93, 84]


DfQ2
[50, 66, 97, 92, 57, 109, 39, 65, 41, 11, 86, 30, 44, 22, 54, 81, 55, 116, 43, 53, 77, 72, 64, 90, 19, 112, 4, 82, 85, 108]


Dt3
[10, 68, 110, 2, 38, 87, 70, 114, 76, 96, 25, 40, 37, 74, 21, 91, 26, 78, 0, 80, 16, 56, 104, 119, 17, 9, 102, 49, 23, 35]


Df4
[52, 27, 1, 98, 73, 13, 69, 48, 105, 60, 47, 14, 20, 6

  result_tensor = torch.tensor(Z).view(N, 1)


SAVING TO ./models/transformers/20240709_210537__oocl_120_step_500.pt
SAVING TO ./models/transformers/20240709_210715__oocl_120_step_950.pt


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
orig_data_valid_loss,▄█▇▅▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁
train/loss,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/loss,▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▇▇▇▇████████████████
valid_Df4/acc,▄▆▃▆▄▄▂▂▃▄▅▄▄▃▄▁▂▂▄▂█▃▂▅▅▃▄▄▅▆▃▄▄▄▂▂▂▁▃▃
valid_DfQ2/acc,▁▅▆▆▆▇▇▇▇▇██████▇███▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
valid_Dt3/acc,▄▅▆▅▅▃▂▄▄▃▇▂▄▂▄▃▄▁▂▅▇▇▇▆▇█▇▇▇▄▆▆▇█▇█▄▅▆▅
valid_DtQ1/acc,▁▄▅▅▆▆▇▆▇▇▆▇▇▇▇▇█████▇▇█▇▇▇▇▇▇▇▇█▇▇▇▇█▇█

0,1
orig_data_valid_loss,0.00195
train/loss,0.08914
val/loss,12.82408
valid_Df4/acc,0.04722
valid_DfQ2/acc,0.45556
valid_Dt3/acc,0.06389
valid_DtQ1/acc,0.45556


In [6]:
for seed in [9,10,11,12]:
    train_oocl_stage_1_2(args, seed)

Moving model to device:  cpu


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011144811566676556, max=1.0…

args.seed=9
int_by_set
{'DtQ1': [22, 102, 97, 100, 41, 105, 84, 83, 15, 96, 33, 24, 106, 80, 44, 66, 88, 40, 55, 74, 87, 63, 18, 29, 114, 79, 19, 112, 65, 61], 'DfQ2': [45, 2, 103, 9, 73, 39, 71, 4, 76, 51, 46, 31, 36, 108, 32, 1, 107, 85, 72, 56, 95, 69, 12, 82, 67, 60, 104, 38, 68, 94], 'Dt3': [58, 117, 27, 91, 3, 115, 62, 99, 7, 52, 111, 25, 101, 113, 35, 50, 81, 116, 11, 53, 28, 26, 37, 13, 49, 8, 75, 109, 16, 14], 'Df4': [6, 30, 98, 20, 54, 92, 57, 90, 21, 48, 93, 5, 89, 118, 70, 42, 10, 77, 119, 64, 43, 0, 86, 110, 23, 17, 34, 47, 78, 59]}
DtQ1
[22, 102, 97, 100, 41, 105, 84, 83, 15, 96, 33, 24, 106, 80, 44, 66, 88, 40, 55, 74, 87, 63, 18, 29, 114, 79, 19, 112, 65, 61]


DfQ2
[45, 2, 103, 9, 73, 39, 71, 4, 76, 51, 46, 31, 36, 108, 32, 1, 107, 85, 72, 56, 95, 69, 12, 82, 67, 60, 104, 38, 68, 94]


Dt3
[58, 117, 27, 91, 3, 115, 62, 99, 7, 52, 111, 25, 101, 113, 35, 50, 81, 116, 11, 53, 28, 26, 37, 13, 49, 8, 75, 109, 16, 14]


Df4
[6, 30, 98, 20, 54, 92, 57, 90, 21, 48, 93, 5, 89, 

  result_tensor = torch.tensor(Z).view(N, 1)


SAVING TO ./models/transformers/20240709_211145__oocl_120_step_500.pt
SAVING TO ./models/transformers/20240709_211326__oocl_120_step_950.pt


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
orig_data_valid_loss,▆█▇▅▄▄▆▅▅▄▂▂▁▂▁▂▁▄▃▂▂▁▂▂▁▂▁▁▁▂▁▁▂▁▂▁▁▁▁▂
train/loss,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/loss,▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁████████████████████
valid_Df4/acc,▁▄▂▁▁▃▂▃▃▂▄▃▄▄▄▂▃▂▂▅▄▃▂▄▆▇▆▇▇▇▆▆▅▆█▇▆▅▆▆
valid_DfQ2/acc,▁▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▆▆▆▇▇▇▇▇▇▇▇▇▇▇██▇█▇▇
valid_Dt3/acc,▄▃▄▄▃▂▃▃▄▂▂▅▂▁▄▃▂▄▃▅▄▄▄▂▂▄▄▄▃▄▇▅▆▄▄▅▅▆█▆
valid_DtQ1/acc,▁▅▆▆▆▇▇▇▇▇▇█████████▇▇▇▇▇▇▇▇▇███▇▇▇██▇██

0,1
orig_data_valid_loss,0.0025
train/loss,0.0875
val/loss,12.51795
valid_Df4/acc,0.07222
valid_DfQ2/acc,0.40741
valid_Dt3/acc,0.05556
valid_DtQ1/acc,0.45185


Moving model to device:  cpu


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112717133336345, max=1.0…

args.seed=10
int_by_set
{'DtQ1': [39, 72, 114, 29, 18, 87, 116, 75, 112, 108, 11, 52, 27, 50, 65, 49, 64, 63, 44, 25, 3, 43, 16, 93, 92, 8, 21, 106, 81, 67], 'DfQ2': [13, 85, 12, 79, 28, 6, 7, 69, 47, 23, 89, 71, 100, 76, 34, 102, 42, 70, 19, 99, 51, 2, 15, 113, 14, 90, 10, 32, 117, 109], 'Dt3': [37, 101, 111, 60, 55, 57, 40, 115, 107, 68, 84, 24, 82, 80, 0, 74, 96, 91, 78, 56, 30, 88, 94, 97, 38, 22, 58, 33, 86, 36], 'Df4': [98, 48, 45, 77, 17, 53, 5, 46, 95, 31, 9, 41, 110, 66, 118, 20, 103, 83, 35, 105, 62, 104, 59, 26, 1, 119, 61, 54, 4, 73]}
DtQ1
[39, 72, 114, 29, 18, 87, 116, 75, 112, 108, 11, 52, 27, 50, 65, 49, 64, 63, 44, 25, 3, 43, 16, 93, 92, 8, 21, 106, 81, 67]


DfQ2
[13, 85, 12, 79, 28, 6, 7, 69, 47, 23, 89, 71, 100, 76, 34, 102, 42, 70, 19, 99, 51, 2, 15, 113, 14, 90, 10, 32, 117, 109]


Dt3
[37, 101, 111, 60, 55, 57, 40, 115, 107, 68, 84, 24, 82, 80, 0, 74, 96, 91, 78, 56, 30, 88, 94, 97, 38, 22, 58, 33, 86, 36]


Df4
[98, 48, 45, 77, 17, 53, 5, 46, 95, 31, 9, 41, 110, 

  result_tensor = torch.tensor(Z).view(N, 1)


SAVING TO ./models/transformers/20240709_211820__oocl_120_step_500.pt
SAVING TO ./models/transformers/20240709_212006__oocl_120_step_950.pt


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
orig_data_valid_loss,▂█▇▇▄▃▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁
train/loss,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/loss,▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁████████████████████
valid_Df4/acc,▁▅▄▅▅▆▄▄▃▄▄▃▅▆▄▇█▆▇▇▆▄▄▅▄▃▅▃▆▆▇▄▇▅▆▄▆▆▄▃
valid_DfQ2/acc,▁▅▆▆▇▇█████▇████████▇▇▇▇▇███████▇███████
valid_Dt3/acc,▅▄▃▁▄▄▂▅▄▄▅▅▆▅▆▄▄▅▅█▄▄▄▄▇▆▄█▇▆▆▇▇▇▄▇▅█▆▆
valid_DtQ1/acc,▁▄▄▅▆▅▆▆▇▇▇▆▇▇▇▇▇▇██▇▇██▇▇███████▇▇▇▇███

0,1
orig_data_valid_loss,0.00115
train/loss,0.08598
val/loss,12.48827
valid_Df4/acc,0.05556
valid_DfQ2/acc,0.46296
valid_Dt3/acc,0.02778
valid_DtQ1/acc,0.47037


Moving model to device:  cpu


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01116877222221875, max=1.0)…

args.seed=11
int_by_set
{'DtQ1': [15, 89, 17, 55, 96, 72, 64, 34, 22, 44, 28, 54, 33, 19, 31, 105, 73, 107, 69, 104, 52, 95, 98, 40, 27, 9, 2, 106, 93, 21], 'DfQ2': [61, 100, 42, 62, 46, 43, 77, 118, 39, 13, 70, 90, 49, 87, 51, 47, 117, 6, 36, 84, 86, 97, 32, 14, 48, 85, 16, 45, 92, 53], 'Dt3': [74, 82, 26, 35, 58, 10, 0, 63, 37, 29, 66, 25, 56, 41, 114, 3, 91, 30, 111, 4, 7, 8, 67, 1, 79, 20, 94, 103, 83, 115], 'Df4': [50, 76, 5, 81, 88, 68, 11, 18, 38, 113, 12, 108, 101, 78, 80, 60, 112, 102, 23, 24, 75, 116, 65, 119, 59, 99, 109, 71, 110, 57]}
DtQ1
[15, 89, 17, 55, 96, 72, 64, 34, 22, 44, 28, 54, 33, 19, 31, 105, 73, 107, 69, 104, 52, 95, 98, 40, 27, 9, 2, 106, 93, 21]


DfQ2
[61, 100, 42, 62, 46, 43, 77, 118, 39, 13, 70, 90, 49, 87, 51, 47, 117, 6, 36, 84, 86, 97, 32, 14, 48, 85, 16, 45, 92, 53]


Dt3
[74, 82, 26, 35, 58, 10, 0, 63, 37, 29, 66, 25, 56, 41, 114, 3, 91, 30, 111, 4, 7, 8, 67, 1, 79, 20, 94, 103, 83, 115]


Df4
[50, 76, 5, 81, 88, 68, 11, 18, 38, 113, 12, 108, 101, 78,

  result_tensor = torch.tensor(Z).view(N, 1)


SAVING TO ./models/transformers/20240709_212451__oocl_120_step_500.pt
SAVING TO ./models/transformers/20240709_212635__oocl_120_step_950.pt


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
orig_data_valid_loss,▄▆▂▂▂▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▂▂██▁▂▁▂▁▁▁▁▁▁
train/loss,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/loss,▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁████████████████████
valid_Df4/acc,▁▃▁▂▄▅▄▄▂▃▄▂▃▂▁▂▄▂▁▃▇▆▇▇▇██▆▇▄▅▆▅▇▇▆▇███
valid_DfQ2/acc,▁▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇███▇█▇██▇██████▇█▇▇██▇▇
valid_Dt3/acc,▆█▆▄▅▄▅▆▆▅▅█▂▆▅▅▄▃▃▁▄▂▃▃▃▂▄▃▃▅▅▃▆▆▅▅▅▆▆▃
valid_DtQ1/acc,▁▅▅▅▆▆▆▆▇▇▇▇█████████▇█▇▇████▇█▇██▇████▇

0,1
orig_data_valid_loss,0.00141
train/loss,0.08431
val/loss,12.92606
valid_Df4/acc,0.07222
valid_DfQ2/acc,0.35556
valid_Dt3/acc,0.03611
valid_DtQ1/acc,0.43333


Moving model to device:  cpu


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011167385188895828, max=1.0…

args.seed=12
int_by_set
{'DtQ1': [66, 37, 108, 92, 106, 52, 70, 33, 62, 102, 104, 97, 116, 4, 39, 89, 8, 31, 55, 80, 74, 3, 45, 90, 75, 59, 13, 57, 112, 100], 'DfQ2': [77, 83, 27, 95, 105, 98, 109, 36, 32, 72, 49, 68, 15, 81, 107, 19, 86, 63, 114, 23, 22, 16, 42, 50, 12, 5, 46, 41, 111, 118], 'Dt3': [24, 38, 30, 119, 53, 115, 10, 64, 21, 6, 94, 101, 40, 69, 17, 78, 14, 96, 54, 99, 28, 87, 91, 2, 11, 51, 93, 65, 9, 25], 'Df4': [73, 7, 26, 43, 20, 110, 56, 113, 79, 117, 0, 71, 29, 76, 88, 58, 103, 82, 35, 61, 47, 1, 48, 18, 44, 85, 67, 84, 34, 60]}
DtQ1
[66, 37, 108, 92, 106, 52, 70, 33, 62, 102, 104, 97, 116, 4, 39, 89, 8, 31, 55, 80, 74, 3, 45, 90, 75, 59, 13, 57, 112, 100]


DfQ2
[77, 83, 27, 95, 105, 98, 109, 36, 32, 72, 49, 68, 15, 81, 107, 19, 86, 63, 114, 23, 22, 16, 42, 50, 12, 5, 46, 41, 111, 118]


Dt3
[24, 38, 30, 119, 53, 115, 10, 64, 21, 6, 94, 101, 40, 69, 17, 78, 14, 96, 54, 99, 28, 87, 91, 2, 11, 51, 93, 65, 9, 25]


Df4
[73, 7, 26, 43, 20, 110, 56, 113, 79, 117, 0, 71, 2

  result_tensor = torch.tensor(Z).view(N, 1)


SAVING TO ./models/transformers/20240709_213144__oocl_120_step_500.pt
SAVING TO ./models/transformers/20240709_213324__oocl_120_step_950.pt


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
orig_data_valid_loss,▂█▃▅▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/loss,▃▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇▇▇▇▇▇█▇████████████
valid_Df4/acc,▇▄▃▅▂▃▂▂▁▂▃▄▄▂▁▃▂▂▂▂█▆▆▂▄▃▃▂▂▂▂▃▂▃▃▂▄▄▅▆
valid_DfQ2/acc,▁▅▅▆▆▆▆▇▇▇▇▇█████▇█▇▇▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
valid_Dt3/acc,▅▄▄▄▅▄▄█▆▅▅▇▆▄▅▇▅▆▇▆▆▂▂▃▅▄▄▅▅▅▃▄▄▅▂▂▂▃▁▂
valid_DtQ1/acc,▁▅▅▆▆▆▆▇▇▇▇▇▇███████▇▇▇▇▇▇▇▇▇██▇▇█▇▇▇▇▇▇

0,1
orig_data_valid_loss,0.00124
train/loss,0.09408
val/loss,12.45841
valid_Df4/acc,0.06667
valid_DfQ2/acc,0.44074
valid_Dt3/acc,0.025
valid_DtQ1/acc,0.51111
