In [4]:
from transformers import BertConfig, BertForMaskedLM, BertTokenizer
from transformers import PreTrainedTokenizer
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from torch.nn.functional import pad
from datasets import load_dataset
from tqdm.notebook import tqdm
from itertools import cycle
import wandb

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
wandb.login(key='56853ea1087c0d089c005f3f7aede52740245615')

device

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmb_lims[0m ([33mlims-ai[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/mburtsev/.netrc


device(type='cuda', index=0)

# Create a custom tokenizer for CA

In [5]:
class SimpleTokenizer:
    def __init__(self):
        self.token_to_id = {'0': 0, '1': 1, '[SEP]': 2, '[PAD]': 3, '[T]': 4} # [T] - thought
        self.id_to_token = {v: k for k, v in self.token_to_id.items()}

    def encode(self, text):
        return [self.token_to_id.get(token, self.token_to_id['[PAD]']) for token in text]

    def decode(self, tokens):
        return ''.join([self.id_to_token.get(token, '[PAD]') for token in tokens])
        #return ''.join([self.id_to_token[token] for token in tokens])

    def __len__(self):
        return len(self.token_to_id)

tokenizer = SimpleTokenizer()

# Define CA dataset

In [6]:
class CellularAutomataDataset(Dataset):
    
    def __init__(self, hf_dataset, steps, horizon=1, pad_length=512):
        self.hf_dataset = hf_dataset
        self.pad_length = pad_length
        self.steps = steps
        self.horizon = horizon

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        row = self.hf_dataset[idx]

        # Assuming the first element 'rule' is not used for model input
        # Concatenate all states except the last one for input
        if (self.steps > (len(row) - 1)):
            self.steps = len(row) - 1
        
        input_ids = []
        for t in range(self.steps):  # Exclude the last state and the 'rule' row
            input_ids.extend(tokenizer.encode(row[f't={t}']))
            input_ids.append(2)  # Append [SEP] token after each state

        # Use the last state as target and calculate its length
        target_ids = []
        
        last_state_key = f't={self.steps + self.horizon - 1}'
        target_ids.extend(tokenizer.encode(row[last_state_key]))
        target_ids.append(2)
        
        orbit_len = len(input_ids)
        target_len = len(target_ids)
        # Padding input and target sequences to the fixed length
        input_ids = pad(torch.tensor(input_ids), (0, self.pad_length - orbit_len), value=3)
        target_ids = pad(torch.tensor(target_ids), (orbit_len, self.pad_length - orbit_len - target_len), value=3)

        # Create a tensor of -100 and replace the last part with target_ids
       # target_labels = torch.full((self.pad_length,), -100)
        #target_labels[:state_length] = target_ids[:state_length]

        return input_ids, target_ids #target_labels


# Configuration for the BERT model

In [7]:
hidden_size=512
num_hidden_layers=4
num_attention_heads=8

config = BertConfig(
    vocab_size=len(tokenizer),
    hidden_size=hidden_size,
    num_hidden_layers=num_hidden_layers,
    num_attention_heads=num_attention_heads,
    #intermediate_size=3072
)
model_name = 'BERT'
model = BertForMaskedLM(config=config)

# Load dataset from Hugging Face repository

In [8]:
seed = 0
batch_size = 400
torch.manual_seed(seed)
steps = 10
state_len = 21 
look_ahead = 2
orbit_len = state_len * steps
seq_len = orbit_len + state_len

dataset_name = "1dCA_r2s20T20"
hf_dataset = load_dataset("mbur/"+dataset_name)['train']
dataset = CellularAutomataDataset(hf_dataset, steps, look_ahead, pad_length=seq_len)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

hf_dataset = load_dataset("mbur/"+dataset_name)['validation']
val_dataset = CellularAutomataDataset(hf_dataset, steps, look_ahead, pad_length=seq_len)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

hf_dataset = load_dataset("mbur/"+dataset_name)['test']
test_dataset = CellularAutomataDataset(hf_dataset, steps, look_ahead, pad_length=seq_len)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

validation_size = 1000  # Number of samples for validation
small_val_dataset, _ = random_split(val_dataset, [validation_size, len(val_dataset) - validation_size])

small_val_dataloader = DataLoader(small_val_dataset, batch_size=batch_size, shuffle=True)

Repo card metadata block was not found. Setting CardData to empty.
Repo card metadata block was not found. Setting CardData to empty.
Repo card metadata block was not found. Setting CardData to empty.


In [9]:
# Define the ID of the sample you want to inspect
sample_id = 1  # Replace with the ID of the sample you're interested in

# Fetch the specific sample from the dataset
sample_input_ids, sample_target_ids = dataset[sample_id]
sample_input_ids, sample_target_ids = sample_input_ids.unsqueeze(0).to(device), sample_target_ids.unsqueeze(0).to(device)

# Decode the tokens to strings for readability
decoded_input = tokenizer.decode(sample_input_ids[0].tolist())
decoded_target = tokenizer.decode(sample_target_ids[0].tolist())

# Print input, target, prediction, and logits
print("inputs:",len(sample_input_ids),sample_input_ids.tolist())
print("targets:",len(sample_target_ids),sample_target_ids.tolist())
print("Input:", len(sample_input_ids[0].tolist()),decoded_input)
print("Target:", len(sample_target_ids[0].tolist()), decoded_target)

inputs: 1 [[1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 2, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 2, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 2, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 2, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 2, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 2, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 2, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 2, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]]
targets: 1 [[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,

# Evaluation

In [10]:
def evaluate_model(model, val_dataloader, device):
    model.eval()  # Set the model to evaluation mode
    total, correct = 0, 0
    val_loop = tqdm(val_dataloader, leave=False, desc="Validation")

    with torch.no_grad():
        for input_ids, target_ids in val_loop:
            input_ids, target_ids = input_ids.to(device), target_ids.to(device)

            outputs = model(input_ids)
            predictions = torch.argmax(outputs.logits, dim=-1)

            # Create a mask where target_ids is not -100 (ignoring padding)
            valid_mask = target_ids != 3

            # Apply the mask to filter out -100 values
            valid_predictions = predictions[valid_mask]
            valid_targets = target_ids[valid_mask]

            correct += (valid_predictions == valid_targets).sum().item()
            total += valid_targets.numel()

    accuracy = correct / total if total > 0 else 0
    return accuracy

# Training loop

In [11]:
model.to(device)

lr=1e-4 # learning rate

num_steps = 25_000  # Define the total number of training steps
evaluate_every = 100  # Frequency of evaluation
wndb_log = 10

run_cfg = {
    'dataset': dataset_name,
    'model': model_name,
    'input_size': 512,
    'hidden_size': hidden_size,
    'num_hidden_layers': num_hidden_layers,
    'num_attention_heads': num_attention_heads,
    'batch_size': batch_size,
    'learning rate': lr,
    'seed': seed
}
run_name = dataset_name+'i'+str(steps)+'f'+str(look_ahead)+'LaH_shrt_'+model_name+'_'+str(num_hidden_layers)+'L'+str(num_attention_heads)+'H'+str(hidden_size)+'HD'+'_bs'+str(batch_size)+'lr'+str(lr)
run_name = run_name + '25Kgc_v' + str(seed)
run = wandb.init(
    project="1dCA_ai4math24",
    name = run_name,
    config=run_cfg
)

optimizer = torch.optim.AdamW(model.parameters(), lr)

losses = []

model.train()

loop = tqdm(range(num_steps), desc='Training')
data_iter = cycle(dataloader)  # Create an infinite loop of dataloader

for step in loop:
    
    input_ids, target_ids = next(data_iter)
    input_ids, target_ids = input_ids.to(device), target_ids.to(device)

    outputs = model(input_ids, labels=target_ids)
    loss = outputs.loss

    optimizer.zero_grad()
    loss.backward()
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
    optimizer.step()
    current_loss = loss.item()
    loop.set_description(f"Step {step + 1}/{num_steps} - Loss: {current_loss:.4f}")
    loop.set_postfix(loss=current_loss)

    if step % evaluate_every == 0 and step != 0:
        val_accuracy = evaluate_model(model, small_val_dataloader, device)
        print(f"\nEvaluation at Step {step}: Accuracy = {val_accuracy:.4f}", "loss:", loss.item())
        wandb.log({'step': step, 'val acc': val_accuracy})

    if step % wndb_log == 0:
        losses.append(loss.detach().item())      
        wandb.log({'step': step, 'loss': current_loss})

loop.close()
wandb.finish()

Training:   0%|          | 0/25000 [00:00<?, ?it/s]

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 100: Accuracy = 0.6411 loss: 0.056441355496644974


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 200: Accuracy = 0.6454 loss: 0.056553807109594345


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 300: Accuracy = 0.6566 loss: 0.05475486069917679


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 400: Accuracy = 0.6610 loss: 0.054205190390348434


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 500: Accuracy = 0.6648 loss: 0.05312596634030342


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 600: Accuracy = 0.6642 loss: 0.05308686941862106


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 700: Accuracy = 0.6669 loss: 0.053428348153829575


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 800: Accuracy = 0.6685 loss: 0.051703836768865585


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 900: Accuracy = 0.6661 loss: 0.05131576955318451


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1000: Accuracy = 0.6685 loss: 0.05317742004990578


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1100: Accuracy = 0.6694 loss: 0.05195600911974907


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1200: Accuracy = 0.6635 loss: 0.05336577445268631


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1300: Accuracy = 0.6729 loss: 0.051665615290403366


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1400: Accuracy = 0.6735 loss: 0.051557064056396484


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1500: Accuracy = 0.6760 loss: 0.05100902542471886


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1600: Accuracy = 0.6756 loss: 0.049929648637771606


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1700: Accuracy = 0.6784 loss: 0.05064564570784569


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1800: Accuracy = 0.6786 loss: 0.051244016736745834


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1900: Accuracy = 0.6801 loss: 0.05163004994392395


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2000: Accuracy = 0.6847 loss: 0.04967327415943146


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2100: Accuracy = 0.6840 loss: 0.048911936581134796


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2200: Accuracy = 0.6881 loss: 0.04894132912158966


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2300: Accuracy = 0.6912 loss: 0.04887480288743973


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2400: Accuracy = 0.6990 loss: 0.047543056309223175


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2500: Accuracy = 0.7083 loss: 0.04540828615427017


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2600: Accuracy = 0.7155 loss: 0.04606498032808304


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2700: Accuracy = 0.7156 loss: 0.04756838083267212


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2800: Accuracy = 0.7214 loss: 0.04662561044096947


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2900: Accuracy = 0.7215 loss: 0.04392649605870247


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3000: Accuracy = 0.7242 loss: 0.045428499579429626


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3100: Accuracy = 0.7260 loss: 0.0453193373978138


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3200: Accuracy = 0.7312 loss: 0.04377098008990288


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3300: Accuracy = 0.7331 loss: 0.044171612709760666


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3400: Accuracy = 0.7395 loss: 0.044467899948358536


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3500: Accuracy = 0.7422 loss: 0.04293900355696678


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3600: Accuracy = 0.7425 loss: 0.04311865195631981


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3700: Accuracy = 0.7461 loss: 0.04248540848493576


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3800: Accuracy = 0.7508 loss: 0.0401294007897377


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3900: Accuracy = 0.7508 loss: 0.04009218513965607


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4000: Accuracy = 0.7533 loss: 0.042830780148506165


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4100: Accuracy = 0.7525 loss: 0.04214699566364288


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4200: Accuracy = 0.7558 loss: 0.040173035115003586


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4300: Accuracy = 0.7545 loss: 0.039440590888261795


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4400: Accuracy = 0.7583 loss: 0.040872182697057724


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4500: Accuracy = 0.7577 loss: 0.04128227010369301


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4600: Accuracy = 0.7550 loss: 0.038930296897888184


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4700: Accuracy = 0.7580 loss: 0.03885725885629654


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4800: Accuracy = 0.7615 loss: 0.04049786925315857


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4900: Accuracy = 0.7578 loss: 0.037639472633600235


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5000: Accuracy = 0.7591 loss: 0.039833296090364456


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5100: Accuracy = 0.7578 loss: 0.03847908228635788


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5200: Accuracy = 0.7589 loss: 0.04073730856180191


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5300: Accuracy = 0.7602 loss: 0.042239513248205185


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5400: Accuracy = 0.7613 loss: 0.03999927267432213


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5500: Accuracy = 0.7607 loss: 0.04005914553999901


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5600: Accuracy = 0.7616 loss: 0.03919197991490364


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5700: Accuracy = 0.7616 loss: 0.039927318692207336


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5800: Accuracy = 0.7587 loss: 0.039748743176460266


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5900: Accuracy = 0.7637 loss: 0.04054282605648041


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6000: Accuracy = 0.7640 loss: 0.03679667413234711


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6100: Accuracy = 0.7631 loss: 0.04067109152674675


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6200: Accuracy = 0.7635 loss: 0.04024138301610947


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6300: Accuracy = 0.7655 loss: 0.03801250457763672


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6400: Accuracy = 0.7643 loss: 0.04131847620010376


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6500: Accuracy = 0.7634 loss: 0.038570258766412735


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6600: Accuracy = 0.7632 loss: 0.039624035358428955


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6700: Accuracy = 0.7627 loss: 0.03760962188243866


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6800: Accuracy = 0.7633 loss: 0.0380382314324379


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6900: Accuracy = 0.7617 loss: 0.03880225494503975


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7000: Accuracy = 0.7677 loss: 0.0378553681075573


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7100: Accuracy = 0.7670 loss: 0.03918800130486488


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7200: Accuracy = 0.7665 loss: 0.03969768434762955


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7300: Accuracy = 0.7691 loss: 0.03834402188658714


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7400: Accuracy = 0.7680 loss: 0.039086584001779556


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7500: Accuracy = 0.7662 loss: 0.03816212713718414


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7600: Accuracy = 0.7710 loss: 0.03783915564417839


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8400: Accuracy = 0.7754 loss: 0.03797126188874245


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8500: Accuracy = 0.7737 loss: 0.036668628454208374


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8600: Accuracy = 0.7751 loss: 0.037615563720464706


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8700: Accuracy = 0.7745 loss: 0.038057394325733185


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8800: Accuracy = 0.7768 loss: 0.035912271589040756


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8900: Accuracy = 0.7791 loss: 0.03671316057443619


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9000: Accuracy = 0.7766 loss: 0.035571467131376266


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9100: Accuracy = 0.7746 loss: 0.03799031674861908


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9200: Accuracy = 0.7773 loss: 0.038238879293203354


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9300: Accuracy = 0.7799 loss: 0.036212000995874405


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9400: Accuracy = 0.7803 loss: 0.03750346228480339


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9500: Accuracy = 0.7784 loss: 0.03780079632997513


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9600: Accuracy = 0.7786 loss: 0.0380929671227932


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9700: Accuracy = 0.7778 loss: 0.0369236133992672


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9800: Accuracy = 0.7800 loss: 0.03588908165693283


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9900: Accuracy = 0.7792 loss: 0.0376245342195034


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10000: Accuracy = 0.7809 loss: 0.03773683309555054


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10100: Accuracy = 0.7834 loss: 0.03708683326840401


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10200: Accuracy = 0.7837 loss: 0.03737784922122955


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10300: Accuracy = 0.7834 loss: 0.0354481041431427


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10400: Accuracy = 0.7841 loss: 0.0359131284058094


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10500: Accuracy = 0.7813 loss: 0.03835064172744751


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10600: Accuracy = 0.7832 loss: 0.035025447607040405


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10700: Accuracy = 0.7841 loss: 0.03753628209233284


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10800: Accuracy = 0.7854 loss: 0.035481132566928864


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10900: Accuracy = 0.7866 loss: 0.03699583560228348


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11000: Accuracy = 0.7885 loss: 0.0368514284491539


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11100: Accuracy = 0.7863 loss: 0.03499765321612358


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11200: Accuracy = 0.7853 loss: 0.03690902516245842


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11300: Accuracy = 0.7893 loss: 0.037852607667446136


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11400: Accuracy = 0.7903 loss: 0.03820107877254486


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11500: Accuracy = 0.7905 loss: 0.03447109833359718


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11600: Accuracy = 0.7899 loss: 0.03584778308868408


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11700: Accuracy = 0.7901 loss: 0.03508937358856201


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11800: Accuracy = 0.7906 loss: 0.034363869577646255


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11900: Accuracy = 0.7935 loss: 0.03404869884252548


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12000: Accuracy = 0.7931 loss: 0.034410152584314346


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12100: Accuracy = 0.7934 loss: 0.03402294963598251


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12200: Accuracy = 0.7956 loss: 0.03541623800992966


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12300: Accuracy = 0.7938 loss: 0.03659642115235329


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12400: Accuracy = 0.7944 loss: 0.033376455307006836


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12600: Accuracy = 0.7969 loss: 0.03553837165236473


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12700: Accuracy = 0.7986 loss: 0.03277190029621124


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12800: Accuracy = 0.7977 loss: 0.03465377166867256


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12900: Accuracy = 0.7985 loss: 0.0346875824034214


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13000: Accuracy = 0.7995 loss: 0.034843627363443375


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13100: Accuracy = 0.7997 loss: 0.03531429544091225


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13900: Accuracy = 0.8071 loss: 0.032867927104234695


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14000: Accuracy = 0.8052 loss: 0.03472503274679184


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14100: Accuracy = 0.8064 loss: 0.031666237860918045


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14200: Accuracy = 0.8064 loss: 0.031866468489170074


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14300: Accuracy = 0.8056 loss: 0.034037910401821136


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14400: Accuracy = 0.8061 loss: 0.03122253715991974


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14500: Accuracy = 0.8066 loss: 0.03256376087665558


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14600: Accuracy = 0.8085 loss: 0.031837981194257736


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14700: Accuracy = 0.8091 loss: 0.032890550792217255


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14800: Accuracy = 0.8081 loss: 0.03529927134513855


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14900: Accuracy = 0.8083 loss: 0.03278803452849388


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15000: Accuracy = 0.8102 loss: 0.033059630542993546


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15100: Accuracy = 0.8086 loss: 0.031455982476472855


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15200: Accuracy = 0.8068 loss: 0.03230702504515648


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15300: Accuracy = 0.8081 loss: 0.0328369103372097


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15400: Accuracy = 0.8086 loss: 0.033525917679071426


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15500: Accuracy = 0.8095 loss: 0.029632847756147385


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15600: Accuracy = 0.8084 loss: 0.03323593735694885


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15700: Accuracy = 0.8100 loss: 0.03272588551044464


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15800: Accuracy = 0.8099 loss: 0.030741313472390175


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15900: Accuracy = 0.8097 loss: 0.033452700823545456


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 16000: Accuracy = 0.8103 loss: 0.03197088465094566


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 16100: Accuracy = 0.8107 loss: 0.03164910525083542


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 16200: Accuracy = 0.8102 loss: 0.02977249212563038


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 16300: Accuracy = 0.8072 loss: 0.03137139603495598


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 16400: Accuracy = 0.8100 loss: 0.031701505184173584


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 16500: Accuracy = 0.8090 loss: 0.031251393258571625


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 16600: Accuracy = 0.8063 loss: 0.03161468729376793


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 16700: Accuracy = 0.8087 loss: 0.03208117559552193


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 16800: Accuracy = 0.8098 loss: 0.031077435240149498


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 16900: Accuracy = 0.8101 loss: 0.031492289155721664


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 17000: Accuracy = 0.8090 loss: 0.03153148666024208


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 17100: Accuracy = 0.8109 loss: 0.03107891045510769


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 17200: Accuracy = 0.8090 loss: 0.03355543687939644


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 17300: Accuracy = 0.8088 loss: 0.032132696360349655


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 17400: Accuracy = 0.8075 loss: 0.0309204813092947


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 17500: Accuracy = 0.8090 loss: 0.02983326092362404


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 17600: Accuracy = 0.8103 loss: 0.03119315765798092


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 17700: Accuracy = 0.8103 loss: 0.030796268954873085


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 17800: Accuracy = 0.8103 loss: 0.03280075266957283


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 17900: Accuracy = 0.8104 loss: 0.03211483731865883


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18000: Accuracy = 0.8117 loss: 0.031473733484745026


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18100: Accuracy = 0.8123 loss: 0.03186284005641937


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18200: Accuracy = 0.8110 loss: 0.03207564726471901


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18300: Accuracy = 0.8088 loss: 0.0299752838909626


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18400: Accuracy = 0.8092 loss: 0.030491556972265244


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18500: Accuracy = 0.8096 loss: 0.029775870963931084


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18600: Accuracy = 0.8101 loss: 0.03171851858496666


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18700: Accuracy = 0.8120 loss: 0.03200136125087738


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18800: Accuracy = 0.8119 loss: 0.03050275705754757


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18900: Accuracy = 0.8096 loss: 0.031439270824193954


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 19000: Accuracy = 0.8110 loss: 0.032044146209955215


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 19100: Accuracy = 0.8110 loss: 0.031578462570905685


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 19200: Accuracy = 0.8104 loss: 0.03104344755411148


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 19300: Accuracy = 0.8110 loss: 0.030087457969784737


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 19400: Accuracy = 0.8126 loss: 0.03190464898943901


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 19500: Accuracy = 0.8109 loss: 0.03172989934682846


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 19600: Accuracy = 0.8110 loss: 0.030932476744055748


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 19700: Accuracy = 0.8114 loss: 0.031174765899777412


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 19800: Accuracy = 0.8107 loss: 0.030476026237010956


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 19900: Accuracy = 0.8123 loss: 0.03057783842086792


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 20000: Accuracy = 0.8137 loss: 0.03269439935684204


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 20100: Accuracy = 0.8117 loss: 0.029794523492455482


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 20200: Accuracy = 0.8113 loss: 0.032098181545734406


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 20300: Accuracy = 0.8119 loss: 0.03100278042256832


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 20400: Accuracy = 0.8119 loss: 0.031450510025024414


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 20500: Accuracy = 0.8113 loss: 0.0318642258644104


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 20600: Accuracy = 0.8105 loss: 0.03018273040652275


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 20700: Accuracy = 0.8145 loss: 0.03213580325245857


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 20800: Accuracy = 0.8127 loss: 0.03278379142284393


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 20900: Accuracy = 0.8114 loss: 0.032383598387241364


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 21000: Accuracy = 0.8085 loss: 0.029581960290670395


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 21100: Accuracy = 0.8137 loss: 0.031160425394773483


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 21200: Accuracy = 0.8100 loss: 0.030579905956983566


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 21300: Accuracy = 0.8113 loss: 0.029388433322310448


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 21400: Accuracy = 0.8113 loss: 0.03029564581811428


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 21500: Accuracy = 0.8098 loss: 0.030364105477929115


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 21600: Accuracy = 0.8133 loss: 0.030011754482984543


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 21700: Accuracy = 0.8116 loss: 0.031207075342535973


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 21800: Accuracy = 0.8119 loss: 0.03228854387998581


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 21900: Accuracy = 0.8116 loss: 0.028892191126942635


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 22000: Accuracy = 0.8126 loss: 0.03067854233086109


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 22100: Accuracy = 0.8121 loss: 0.03125434368848801


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 22200: Accuracy = 0.8128 loss: 0.029294153675436974


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 22300: Accuracy = 0.8123 loss: 0.03035915642976761


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 22400: Accuracy = 0.8146 loss: 0.03112715296447277


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 22500: Accuracy = 0.8138 loss: 0.031036699190735817


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 22600: Accuracy = 0.8152 loss: 0.031359922140836716


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 22700: Accuracy = 0.8122 loss: 0.031054187566041946


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 22800: Accuracy = 0.8126 loss: 0.029600193724036217


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 22900: Accuracy = 0.8112 loss: 0.03113696537911892


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 23000: Accuracy = 0.8139 loss: 0.032107990235090256


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 23100: Accuracy = 0.8147 loss: 0.03141186013817787


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 23200: Accuracy = 0.8134 loss: 0.029551101848483086


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 23300: Accuracy = 0.8132 loss: 0.029645636677742004


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 23400: Accuracy = 0.8151 loss: 0.029947929084300995


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 23500: Accuracy = 0.8145 loss: 0.031922709196805954


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 23600: Accuracy = 0.8132 loss: 0.029143143445253372


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 23700: Accuracy = 0.8137 loss: 0.0290378350764513


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 23800: Accuracy = 0.8155 loss: 0.03159109503030777


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 24000: Accuracy = 0.8123 loss: 0.030219297856092453


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 24100: Accuracy = 0.8141 loss: 0.029974525794386864


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 24200: Accuracy = 0.8133 loss: 0.030671758577227592


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 24300: Accuracy = 0.8123 loss: 0.0327971950173378


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 24400: Accuracy = 0.8133 loss: 0.030826186761260033


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 24500: Accuracy = 0.8154 loss: 0.030679306015372276


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [12]:
# Define the checkpoint content
checkpoint = {
    'model_state_dict': model.state_dict(),            # Model's state dictionary
    'optimizer_state_dict': optimizer.state_dict(),        # Optimizer's state dictionary
    'step': step,                                      # Current training step
    'losses': losses,                                  # Loss history
    'config': run_cfg,
    'seed': seed
}

# Define a checkpoint file path
checkpoint_path = run_name + ".pth"

# Save the checkpoint
torch.save(checkpoint, checkpoint_path)

In [13]:
# Define the ID of the sample you want to inspect
sample_id = 11  # Replace with the ID of the sample you're interested in

# Fetch the specific sample from the dataset
sample_input_ids, sample_target_ids = dataset[sample_id]
sample_input_ids, sample_target_ids = sample_input_ids.unsqueeze(0).to(device), sample_target_ids.unsqueeze(0).to(device)

# Set the model to evaluation mode and get output
model.eval()
with torch.no_grad():
    output = model(sample_input_ids)

# Decode the tokens to strings for readability
decoded_input = tokenizer.decode(sample_input_ids[0].tolist())
decoded_target = tokenizer.decode(sample_target_ids[0].tolist())
decoded_prediction = tokenizer.decode(torch.argmax(output.logits, dim=-1)[0].tolist())

# Print input, target, prediction, and logits
print("inputs:",len(sample_input_ids),sample_input_ids)
print("targets:",len(sample_target_ids),sample_target_ids)
print(len(output),output)
print("Input:", len(sample_input_ids[0].tolist()),decoded_input)
print("Target:", len(sample_target_ids[0].tolist()), decoded_target)
print("Prediction:", decoded_prediction)
print("Logits:", len(output.logits[0]), output.logits[0])

# Optionally, you can convert logits to probabilities for interpretation
probabilities = torch.softmax(output.logits[0], dim=-1)
print("Probabilities:", probabilities)


inputs: 1 tensor([[1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 2, 1, 0, 1,
         1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 2, 0, 0, 1, 0, 0, 0,
         1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 2, 0, 1, 1, 1, 0, 1, 1, 0, 0,
         0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 2, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0,
         0, 0, 0, 0, 1, 1, 0, 0, 2, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0,
         1, 1, 0, 0, 0, 2, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0,
         0, 1, 2, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 2,
         1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 2, 1, 1, 0,
         1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 2, 3, 3, 3, 3, 3, 3,
         3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]], device='cuda:0')
targets: 1 tensor([[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
         3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
         3, 3, 

In [14]:
def evaluate_model_LH(model, val_dataloader, device, horizon):
    model.eval()  # Set the model to evaluation mode
    val_loop = tqdm(val_dataloader, leave=False, desc="Validation")
    correct = [0] * horizon
    total = [0] * horizon
    accuracy = [0] * horizon

    with torch.no_grad():
        for input_ids, target_ids in val_loop:
            input_ids, target_ids = input_ids.to(device), target_ids.to(device)

            outputs = model(input_ids)
            predictions = torch.argmax(outputs.logits, dim=-1)
            
            for t in range(horizon):

                # Create a mask where target_ids is not 3 (ignoring padding)
                #valid_mask = target_ids[:, t] != 3

                # Apply the mask to filter out -100 values
                start_idx  = orbit_len + t * state_len
                valid_predictions = predictions[:, start_idx : start_idx + state_len ]
                valid_targets = target_ids[:, start_idx : start_idx + state_len]
            

                correct[t] += (valid_predictions == valid_targets).sum().item()
                total[t] += len(valid_targets)

    for t in range(horizon):
        accuracy[t] = correct[t] / total[t] / state_len  if total[t] > 0 else 0
        
    return accuracy, correct, total

In [15]:
val_accuracy = evaluate_model(model, val_dataloader, device)
print(val_accuracy)
#print(f"\nEvaluation at Step {step}: Validation accuracy = {val_accuracy:.10f}", "loss:", loss.item())

Validation:   0%|          | 0/125 [00:00<?, ?it/s]

0.8042066666666666


In [16]:
test_accuracy = evaluate_model(model, test_dataloader, device)
print(test_accuracy)
#print(f"\nEvaluation at Step {step}: Test Accuracy = {test_accuracy:.10f}", "loss:", loss.item())

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

0.8046966666666666


In [17]:
Evaluation at Step 49999: Accuracy = 0.999925 loss: 0.0001053252344718203

SyntaxError: invalid syntax (1662878917.py, line 1)

In [None]:
orbit_len = state_len * steps

In [None]:
0.9765709523809524