In [1]:
from transformers import BertConfig, BertForMaskedLM, BertTokenizer
from transformers import PreTrainedTokenizer
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from torch.nn.functional import pad
from datasets import load_dataset
from tqdm.notebook import tqdm
from itertools import cycle
import wandb

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
wandb.login(key='56853ea1087c0d089c005f3f7aede52740245615')

device

2024-09-11 14:54:55.435305: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmb_lims[0m ([33mlims-ai[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/mburtsev/.netrc


device(type='cuda', index=3)

# Create a custom tokenizer for CA

In [2]:
class SimpleTokenizer:
    def __init__(self):
        self.token_to_id = {'0': 0, '1': 1, '[SEP]': 2, '[PAD]': 3, '[T]': 4} # [T] - thought
        self.id_to_token = {v: k for k, v in self.token_to_id.items()}

    def encode(self, text):
        return [self.token_to_id.get(token, self.token_to_id['[PAD]']) for token in text]

    def decode(self, tokens):
        return ''.join([self.id_to_token.get(token, '[PAD]') for token in tokens])
        #return ''.join([self.id_to_token[token] for token in tokens])

    def __len__(self):
        return len(self.token_to_id)

tokenizer = SimpleTokenizer()

# Define CA dataset

In [3]:
class CellularAutomataDataset(Dataset):
    
    def __init__(self, hf_dataset, steps, horizon=1, pad_length=512):
        self.hf_dataset = hf_dataset
        self.pad_length = pad_length
        self.steps = steps
        self.horizon = horizon

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        row = self.hf_dataset[idx]

        # Assuming the first element 'rule' is not used for model input
        # Concatenate all states except the last one for input
        if (self.steps > (len(row) - 1)):
            self.steps = len(row) - 1
        
        input_ids = []
        for t in range(self.steps):  # Exclude the last state and the 'rule' row
            input_ids.extend(tokenizer.encode(row[f't={t}']))
            input_ids.append(2)  # Append [SEP] token after each state

        # Use the last state as target and calculate its length
        target_ids = []
        for t in range(self.horizon):
            last_state_key = f't={self.steps + t}'
            target_ids.extend(tokenizer.encode(row[last_state_key]))
            target_ids.append(2)
        
        target_ids.extend(tokenizer.encode(row['rule']))
        
        orbit_len = len(input_ids)
        target_len = len(target_ids)
        # Padding input and target sequences to the fixed length
        input_ids = pad(torch.tensor(input_ids), (0, self.pad_length - orbit_len), value=3)
        target_ids = pad(torch.tensor(target_ids), (orbit_len, self.pad_length - orbit_len - target_len), value=3)

        # Create a tensor of -100 and replace the last part with target_ids
       # target_labels = torch.full((self.pad_length,), -100)
        #target_labels[:state_length] = target_ids[:state_length]

        return input_ids, target_ids #target_labels


# Configuration for the BERT model

In [4]:
hidden_size=512
num_hidden_layers=4
num_attention_heads=8

config = BertConfig(
    vocab_size=len(tokenizer),
    hidden_size=hidden_size,
    num_hidden_layers=num_hidden_layers,
    num_attention_heads=num_attention_heads,
    #intermediate_size=3072
)
model_name = 'BERT'
model = BertForMaskedLM(config=config)

# Load dataset from Hugging Face repository

In [5]:
seed = 0
batch_size = 400
torch.manual_seed(seed)
steps = 10
state_len = 21 
look_ahead = 1
orbit_len = state_len * steps
seq_len = orbit_len + state_len * look_ahead + 32

dataset_name = "1dCA_r2s20T20"
hf_dataset = load_dataset("mbur/"+dataset_name)['train']
dataset = CellularAutomataDataset(hf_dataset, steps, look_ahead, pad_length=seq_len)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

hf_dataset = load_dataset("mbur/"+dataset_name)['validation']
val_dataset = CellularAutomataDataset(hf_dataset, steps, look_ahead, pad_length=seq_len)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

hf_dataset = load_dataset("mbur/"+dataset_name)['test']
test_dataset = CellularAutomataDataset(hf_dataset, steps, look_ahead, pad_length=seq_len)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

validation_size = 1000  # Number of samples for validation
small_val_dataset, _ = random_split(val_dataset, [validation_size, len(val_dataset) - validation_size])

small_val_dataloader = DataLoader(small_val_dataset, batch_size=batch_size, shuffle=True)

Using custom data configuration mbur--1dCA_r2s20T20-062a8ac2d777e207
Found cached dataset json (/home/mburtsev/.cache/huggingface/datasets/mbur___json/mbur--1dCA_r2s20T20-062a8ac2d777e207/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)


  0%|          | 0/3 [00:00<?, ?it/s]

Using custom data configuration mbur--1dCA_r2s20T20-062a8ac2d777e207
Found cached dataset json (/home/mburtsev/.cache/huggingface/datasets/mbur___json/mbur--1dCA_r2s20T20-062a8ac2d777e207/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)


  0%|          | 0/3 [00:00<?, ?it/s]

Using custom data configuration mbur--1dCA_r2s20T20-062a8ac2d777e207
Found cached dataset json (/home/mburtsev/.cache/huggingface/datasets/mbur___json/mbur--1dCA_r2s20T20-062a8ac2d777e207/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)


  0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
# Define the ID of the sample you want to inspect
sample_id = 1  # Replace with the ID of the sample you're interested in

# Fetch the specific sample from the dataset
sample_input_ids, sample_target_ids = dataset[sample_id]
sample_input_ids, sample_target_ids = sample_input_ids.unsqueeze(0).to(device), sample_target_ids.unsqueeze(0).to(device)


# Decode the tokens to strings for readability
decoded_input = tokenizer.decode(sample_input_ids[0].tolist())
decoded_target = tokenizer.decode(sample_target_ids[0].tolist())

# Print input, target, prediction, and logits
print("inputs:",len(sample_input_ids),sample_input_ids.tolist())
print("targets:",len(sample_target_ids),sample_target_ids.tolist())
print("Input:", len(sample_input_ids[0].tolist()),decoded_input)
print("Target:", len(sample_target_ids[0].tolist()), decoded_target)

inputs: 1 [[1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 2, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 2, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 2, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 2, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 2, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 2, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 2, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 2, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]]
targets: 1 [[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,

# Evaluation

In [7]:
def evaluate_model(model, val_dataloader, device):
    model.eval()  # Set the model to evaluation mode
    total, correct = 0, 0
    correct_r, total_r = 0, 0
    val_loop = tqdm(val_dataloader, leave=False, desc="Validation")

    with torch.no_grad():
        for input_ids, target_ids in val_loop:
            input_ids, target_ids = input_ids.to(device), target_ids.to(device)

            outputs = model(input_ids)
            predictions = torch.argmax(outputs.logits, dim=-1)

            # Create a mask where target_ids is not -100 (ignoring padding)
            valid_mask = target_ids != 3

            # Apply the mask to filter out -100 values
            valid_predictions = predictions[valid_mask]
            valid_targets = target_ids[valid_mask]
            # for state
            correct += (valid_predictions[:-32] == valid_targets[:-32]).sum().item() 
            total += valid_targets[:-32].numel()
            # for rule
            correct_r += (valid_predictions[-32:] == valid_targets[-32:]).sum().item() 
            total_r += valid_targets[-32:].numel()

    accuracy = correct / total if total > 0 else 0
    accuracy_rule = correct_r / total_r if total > 0 else 0
    
    return accuracy, accuracy_rule

# Training loop

In [8]:
model.to(device)

lr=1e-4 # learning rate

num_steps = 25_000  # Define the total number of training steps
evaluate_every = 100  # Frequency of evaluation
wndb_log = 10

run_cfg = {
    'dataset': dataset_name,
    'model': model_name,
    'input_size': 512,
    'hidden_size': hidden_size,
    'num_hidden_layers': num_hidden_layers,
    'num_attention_heads': num_attention_heads,
    'batch_size': batch_size,
    'learning rate': lr,
    'seed': seed
}
run_name = dataset_name+'i'+str(steps)+'f'+str(look_ahead)+'H2shrt_'+model_name+'_'+str(num_hidden_layers)+'L'+str(num_attention_heads)+'H'+str(hidden_size)+'HD'+'_bs'+str(batch_size)+'lr'+str(lr)
run_name = run_name + str(num_steps) + 'Kgc_v' + str(seed)
run = wandb.init(
    project="1dCA_ai4math24",
    name = run_name,
    config=run_cfg
)

optimizer = torch.optim.AdamW(model.parameters(), lr)

losses = []

model.train()

loop = tqdm(range(num_steps), desc='Training')
data_iter = cycle(dataloader)  # Create an infinite loop of dataloader

for step in loop:
    
    input_ids, target_ids = next(data_iter)
    input_ids, target_ids = input_ids.to(device), target_ids.to(device)

    outputs = model(input_ids, labels=target_ids)
    loss = outputs.loss

    optimizer.zero_grad()
    loss.backward()
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
    optimizer.step()
    current_loss = loss.item()
    loop.set_description(f"Step {step + 1}/{num_steps} - Loss: {current_loss:.4f}")
    loop.set_postfix(loss=current_loss)

    if step % evaluate_every == 0 and step != 0:
        val_accuracy, rule_accuracy = evaluate_model(model, small_val_dataloader, device)
        print(f"\nEvaluation at Step {step}: Accuracy = {val_accuracy:.4f}, Rule acc:{rule_accuracy:.4f} ", "loss:", loss.item())
        wandb.log({'step': step, 'val acc': val_accuracy})

    if step % wndb_log == 0:
        losses.append(loss.detach().item())      
        wandb.log({'step': step, 'loss': current_loss})

loop.close()
wandb.finish()

Training:   0%|          | 0/25000 [00:00<?, ?it/s]

Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 100: Accuracy = 0.5867, Rule acc:0.5729  loss: 0.13389559090137482


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 200: Accuracy = 0.5889, Rule acc:0.5417  loss: 0.1328190416097641


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 300: Accuracy = 0.5889, Rule acc:0.6562  loss: 0.1326102614402771


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 400: Accuracy = 0.5936, Rule acc:0.5000  loss: 0.13198603689670563


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 500: Accuracy = 0.5915, Rule acc:0.5208  loss: 0.13059751689434052


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 600: Accuracy = 0.6019, Rule acc:0.6458  loss: 0.1296309232711792


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 700: Accuracy = 0.6053, Rule acc:0.5625  loss: 0.13048674166202545


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 800: Accuracy = 0.6025, Rule acc:0.5625  loss: 0.12825974822044373


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 900: Accuracy = 0.6070, Rule acc:0.5833  loss: 0.12790431082248688


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1000: Accuracy = 0.6047, Rule acc:0.5625  loss: 0.13079506158828735


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1100: Accuracy = 0.6030, Rule acc:0.5625  loss: 0.12881684303283691


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1200: Accuracy = 0.6080, Rule acc:0.5729  loss: 0.129686638712883


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1300: Accuracy = 0.6072, Rule acc:0.6354  loss: 0.1286042332649231


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1400: Accuracy = 0.6074, Rule acc:0.5417  loss: 0.1287361979484558


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1500: Accuracy = 0.6081, Rule acc:0.5521  loss: 0.12767481803894043


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1600: Accuracy = 0.6079, Rule acc:0.5625  loss: 0.127438485622406


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1700: Accuracy = 0.6096, Rule acc:0.6146  loss: 0.12907521426677704


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1800: Accuracy = 0.6086, Rule acc:0.5208  loss: 0.12847800552845


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1900: Accuracy = 0.6080, Rule acc:0.5208  loss: 0.12863363325595856


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2000: Accuracy = 0.6106, Rule acc:0.5833  loss: 0.12633712589740753


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2100: Accuracy = 0.6101, Rule acc:0.6458  loss: 0.12628278136253357


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2200: Accuracy = 0.6129, Rule acc:0.5521  loss: 0.12717300653457642


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2300: Accuracy = 0.6141, Rule acc:0.6458  loss: 0.12658579647541046


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2400: Accuracy = 0.6134, Rule acc:0.6042  loss: 0.12519176304340363


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2500: Accuracy = 0.6176, Rule acc:0.6667  loss: 0.1243773102760315


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2600: Accuracy = 0.6229, Rule acc:0.5833  loss: 0.1261252760887146


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2700: Accuracy = 0.6247, Rule acc:0.5625  loss: 0.12652860581874847


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2800: Accuracy = 0.6305, Rule acc:0.5625  loss: 0.12510980665683746


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2900: Accuracy = 0.6317, Rule acc:0.6771  loss: 0.12131588906049728


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3000: Accuracy = 0.6401, Rule acc:0.6042  loss: 0.12106433510780334


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3100: Accuracy = 0.6441, Rule acc:0.6250  loss: 0.12082153558731079


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3200: Accuracy = 0.6526, Rule acc:0.6354  loss: 0.11750278621912003


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3300: Accuracy = 0.6606, Rule acc:0.6146  loss: 0.1171090230345726


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3400: Accuracy = 0.6750, Rule acc:0.6667  loss: 0.11514388769865036


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3500: Accuracy = 0.6917, Rule acc:0.6562  loss: 0.11145792156457901


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3600: Accuracy = 0.7302, Rule acc:0.7083  loss: 0.09789873659610748


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3700: Accuracy = 0.7894, Rule acc:0.7500  loss: 0.07736676186323166


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3800: Accuracy = 0.8206, Rule acc:0.6979  loss: 0.057956404983997345


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3900: Accuracy = 0.8276, Rule acc:0.8021  loss: 0.056189291179180145


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4000: Accuracy = 0.8327, Rule acc:0.8021  loss: 0.052949193865060806


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4100: Accuracy = 0.8332, Rule acc:0.8333  loss: 0.05338354781270027


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4200: Accuracy = 0.8365, Rule acc:0.7708  loss: 0.054012712091207504


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4300: Accuracy = 0.8372, Rule acc:0.7604  loss: 0.052862148731946945


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4400: Accuracy = 0.8372, Rule acc:0.7604  loss: 0.05203283205628395


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4500: Accuracy = 0.8378, Rule acc:0.8125  loss: 0.05090109631419182


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4600: Accuracy = 0.8398, Rule acc:0.7917  loss: 0.05183788388967514


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4700: Accuracy = 0.8388, Rule acc:0.7708  loss: 0.051545459777116776


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4800: Accuracy = 0.8404, Rule acc:0.7292  loss: 0.05060292407870293


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4900: Accuracy = 0.8440, Rule acc:0.7812  loss: 0.051080167293548584


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5000: Accuracy = 0.8453, Rule acc:0.6979  loss: 0.050116486847400665


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5100: Accuracy = 0.8456, Rule acc:0.7292  loss: 0.05027814954519272


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5200: Accuracy = 0.8441, Rule acc:0.8021  loss: 0.04809153452515602


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5300: Accuracy = 0.8507, Rule acc:0.8125  loss: 0.048229020088911057


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5400: Accuracy = 0.8511, Rule acc:0.6875  loss: 0.04805724695324898


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5500: Accuracy = 0.8543, Rule acc:0.7812  loss: 0.048153214156627655


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5600: Accuracy = 0.8563, Rule acc:0.8958  loss: 0.04628130793571472


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5700: Accuracy = 0.8586, Rule acc:0.8021  loss: 0.0451199896633625


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5800: Accuracy = 0.8602, Rule acc:0.8854  loss: 0.04424270987510681


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5900: Accuracy = 0.8621, Rule acc:0.8333  loss: 0.04384562373161316


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6000: Accuracy = 0.8652, Rule acc:0.8854  loss: 0.04521641880273819


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6100: Accuracy = 0.8646, Rule acc:0.7292  loss: 0.042397432029247284


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6200: Accuracy = 0.8655, Rule acc:0.6458  loss: 0.04270796850323677


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6300: Accuracy = 0.8658, Rule acc:0.8021  loss: 0.043950386345386505


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6400: Accuracy = 0.8662, Rule acc:0.9167  loss: 0.0412549152970314


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6500: Accuracy = 0.8660, Rule acc:0.8021  loss: 0.04450857266783714


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6600: Accuracy = 0.8661, Rule acc:0.7708  loss: 0.043474793434143066


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6700: Accuracy = 0.8654, Rule acc:0.8125  loss: 0.04216936603188515


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6800: Accuracy = 0.8648, Rule acc:0.8750  loss: 0.04228229820728302


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6900: Accuracy = 0.8666, Rule acc:0.7396  loss: 0.04344789683818817


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7000: Accuracy = 0.8676, Rule acc:0.8333  loss: 0.04200698062777519


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7100: Accuracy = 0.8670, Rule acc:0.7812  loss: 0.04269930347800255


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7200: Accuracy = 0.8664, Rule acc:0.8438  loss: 0.042721983045339584


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7300: Accuracy = 0.8687, Rule acc:0.7812  loss: 0.04294008016586304


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7400: Accuracy = 0.8679, Rule acc:0.7500  loss: 0.04147863760590553


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7500: Accuracy = 0.8702, Rule acc:0.7396  loss: 0.04332070052623749


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7600: Accuracy = 0.8694, Rule acc:0.8333  loss: 0.04233906790614128


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7700: Accuracy = 0.8698, Rule acc:0.7812  loss: 0.042080048471689224


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7800: Accuracy = 0.8716, Rule acc:0.7500  loss: 0.04170285165309906


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7900: Accuracy = 0.8707, Rule acc:0.8854  loss: 0.04191767796874046


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8000: Accuracy = 0.8719, Rule acc:0.8229  loss: 0.0440441370010376


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8100: Accuracy = 0.8717, Rule acc:0.7917  loss: 0.04329409822821617


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8200: Accuracy = 0.8740, Rule acc:0.7500  loss: 0.041974157094955444


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8300: Accuracy = 0.8756, Rule acc:0.8333  loss: 0.03944684565067291


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8400: Accuracy = 0.8798, Rule acc:0.7500  loss: 0.03950976952910423


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8500: Accuracy = 0.8832, Rule acc:0.8021  loss: 0.039107903838157654


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8600: Accuracy = 0.8838, Rule acc:0.8750  loss: 0.03617580235004425


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8700: Accuracy = 0.8878, Rule acc:0.7708  loss: 0.03681924566626549


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8800: Accuracy = 0.8854, Rule acc:0.8646  loss: 0.038141846656799316


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8900: Accuracy = 0.8883, Rule acc:0.9062  loss: 0.03774043172597885


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9000: Accuracy = 0.8915, Rule acc:0.8229  loss: 0.03604055568575859


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9100: Accuracy = 0.8916, Rule acc:0.8958  loss: 0.03629794716835022


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9200: Accuracy = 0.8951, Rule acc:0.7812  loss: 0.03350585326552391


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9300: Accuracy = 0.8966, Rule acc:0.8750  loss: 0.03434410318732262


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9400: Accuracy = 0.8975, Rule acc:0.9167  loss: 0.03346732258796692


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9500: Accuracy = 0.8980, Rule acc:0.8438  loss: 0.031978026032447815


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9600: Accuracy = 0.8997, Rule acc:0.9062  loss: 0.03247059881687164


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9700: Accuracy = 0.9013, Rule acc:0.8646  loss: 0.03155899792909622


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9800: Accuracy = 0.9034, Rule acc:0.8125  loss: 0.03131793066859245


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9900: Accuracy = 0.9069, Rule acc:0.8646  loss: 0.030496368184685707


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10000: Accuracy = 0.9163, Rule acc:0.9167  loss: 0.02701703831553459


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10100: Accuracy = 0.9235, Rule acc:0.9896  loss: 0.025198347866535187


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10200: Accuracy = 0.9255, Rule acc:0.9062  loss: 0.024082118645310402


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10300: Accuracy = 0.9252, Rule acc:0.8542  loss: 0.024105537682771683


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10400: Accuracy = 0.9275, Rule acc:0.8333  loss: 0.02372913435101509


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10500: Accuracy = 0.9295, Rule acc:0.9688  loss: 0.02236264757812023


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10600: Accuracy = 0.9336, Rule acc:0.8229  loss: 0.02184644155204296


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10700: Accuracy = 0.9349, Rule acc:0.8542  loss: 0.020440075546503067


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10800: Accuracy = 0.9367, Rule acc:0.9167  loss: 0.020423507317900658


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10900: Accuracy = 0.9389, Rule acc:0.8854  loss: 0.019299525767564774


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11000: Accuracy = 0.9405, Rule acc:0.9271  loss: 0.01875041238963604


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11100: Accuracy = 0.9399, Rule acc:0.9479  loss: 0.01929965242743492


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11400: Accuracy = 0.9418, Rule acc:0.8542  loss: 0.017887568101286888


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11500: Accuracy = 0.9452, Rule acc:0.9688  loss: 0.01751234009861946


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11600: Accuracy = 0.9463, Rule acc:0.9479  loss: 0.016957618296146393


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11700: Accuracy = 0.9466, Rule acc:0.9583  loss: 0.016298960894346237


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11800: Accuracy = 0.9476, Rule acc:0.9583  loss: 0.016889967024326324


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11900: Accuracy = 0.9471, Rule acc:0.9583  loss: 0.016600310802459717


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12100: Accuracy = 0.9465, Rule acc:0.9688  loss: 0.016719702631235123


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12200: Accuracy = 0.9468, Rule acc:0.9583  loss: 0.01551060937345028


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12300: Accuracy = 0.9461, Rule acc:0.9167  loss: 0.016726862639188766


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12400: Accuracy = 0.9472, Rule acc:0.9062  loss: 0.016494782641530037


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12500: Accuracy = 0.9477, Rule acc:0.8333  loss: 0.016079295426607132


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12600: Accuracy = 0.9481, Rule acc:0.9167  loss: 0.015652380883693695


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12700: Accuracy = 0.9481, Rule acc:0.9271  loss: 0.01633676327764988


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12800: Accuracy = 0.9494, Rule acc:0.9688  loss: 0.016386140137910843


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12900: Accuracy = 0.9487, Rule acc:0.9062  loss: 0.015914468094706535


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13000: Accuracy = 0.9500, Rule acc:0.9375  loss: 0.015452928841114044


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13100: Accuracy = 0.9509, Rule acc:0.9271  loss: 0.014958282932639122


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13200: Accuracy = 0.9506, Rule acc:0.9062  loss: 0.015027674846351147


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13300: Accuracy = 0.9511, Rule acc:0.9271  loss: 0.014545438811182976


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13400: Accuracy = 0.9510, Rule acc:0.9375  loss: 0.01502885203808546


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13500: Accuracy = 0.9511, Rule acc:0.9167  loss: 0.014782718382775784


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13600: Accuracy = 0.9516, Rule acc:0.8854  loss: 0.014872550033032894


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13700: Accuracy = 0.9517, Rule acc:0.9479  loss: 0.015856843441724777


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13800: Accuracy = 0.9511, Rule acc:0.8958  loss: 0.014561601914465427


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13900: Accuracy = 0.9523, Rule acc:0.9375  loss: 0.014880277216434479


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14000: Accuracy = 0.9512, Rule acc:0.9583  loss: 0.013974504545331001


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14100: Accuracy = 0.9516, Rule acc:0.9479  loss: 0.014935776591300964


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14200: Accuracy = 0.9515, Rule acc:0.9271  loss: 0.015014025382697582


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14300: Accuracy = 0.9525, Rule acc:0.8646  loss: 0.014606285840272903


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14400: Accuracy = 0.9516, Rule acc:0.9271  loss: 0.015862366184592247


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14500: Accuracy = 0.9517, Rule acc:0.9583  loss: 0.015162000432610512


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14600: Accuracy = 0.9525, Rule acc:0.9583  loss: 0.014770464040338993


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14700: Accuracy = 0.9523, Rule acc:0.9479  loss: 0.014676785096526146


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14900: Accuracy = 0.9520, Rule acc:0.9062  loss: 0.014553706161677837


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15000: Accuracy = 0.9517, Rule acc:0.8750  loss: 0.014744596555829048


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15100: Accuracy = 0.9517, Rule acc:0.9792  loss: 0.014577126130461693


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15200: Accuracy = 0.9526, Rule acc:0.8333  loss: 0.014498372562229633


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15300: Accuracy = 0.9520, Rule acc:0.9271  loss: 0.013853483833372593


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15400: Accuracy = 0.9526, Rule acc:0.8750  loss: 0.014730405993759632


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15500: Accuracy = 0.9522, Rule acc:0.9062  loss: 0.015676578506827354


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18500: Accuracy = 0.9560, Rule acc:0.9688  loss: 0.013674762099981308


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18600: Accuracy = 0.9566, Rule acc:0.9375  loss: 0.013523840345442295


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18700: Accuracy = 0.9557, Rule acc:0.9271  loss: 0.012723587453365326


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18800: Accuracy = 0.9558, Rule acc:0.9479  loss: 0.012787079438567162


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 19200: Accuracy = 0.9562, Rule acc:0.8542  loss: 0.01343273650854826


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 19300: Accuracy = 0.9557, Rule acc:0.9792  loss: 0.013239637948572636


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 22400: Accuracy = 0.9603, Rule acc:1.0000  loss: 0.011704829521477222


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 22500: Accuracy = 0.9608, Rule acc:0.9688  loss: 0.011456239968538284


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 22600: Accuracy = 0.9614, Rule acc:0.9688  loss: 0.011495554819703102


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 22700: Accuracy = 0.9606, Rule acc:0.9375  loss: 0.011387455277144909


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 23100: Accuracy = 0.9609, Rule acc:0.9167  loss: 0.010794112458825111


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 23200: Accuracy = 0.9606, Rule acc:0.9583  loss: 0.012196863070130348


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [22]:
# Define the checkpoint content
checkpoint = {
    'model_state_dict': model.state_dict(),            # Model's state dictionary
    'optimizer_state_dict': optimizer.state_dict(),        # Optimizer's state dictionary
    'step': step,                                      # Current training step
    'losses': losses,                                  # Loss history
    'config': run_cfg,
    'seed': seed
}

run_name = dataset_name+'i'+str(steps)+'f'+str(look_ahead)+'H2shrt_OR_'+model_name+'_'+str(num_hidden_layers)+'L'+str(num_attention_heads)+'H'+str(hidden_size)+'HD'+'_bs'+str(batch_size)+'lr'+str(lr)
run_name = run_name + str(num_steps) + 'Kgc_v' + str(seed)

# Define a checkpoint file path
checkpoint_path = run_name + ".pth"

# Save the checkpoint
torch.save(checkpoint, checkpoint_path)

In [10]:
# Define the ID of the sample you want to inspect
sample_id = 11  # Replace with the ID of the sample you're interested in

# Fetch the specific sample from the dataset
sample_input_ids, sample_target_ids = dataset[sample_id]
sample_input_ids, sample_target_ids = sample_input_ids.unsqueeze(0).to(device), sample_target_ids.unsqueeze(0).to(device)

# Set the model to evaluation mode and get output
model.eval()
with torch.no_grad():
    output = model(sample_input_ids)

# Decode the tokens to strings for readability
decoded_input = tokenizer.decode(sample_input_ids[0].tolist())
decoded_target = tokenizer.decode(sample_target_ids[0].tolist())
decoded_prediction = tokenizer.decode(torch.argmax(output.logits, dim=-1)[0].tolist())

# Print input, target, prediction, and logits
print("inputs:",len(sample_input_ids),sample_input_ids)
print("targets:",len(sample_target_ids),sample_target_ids)
print(len(output),output)
print("Input:", len(sample_input_ids[0].tolist()),decoded_input)
print("Target:", len(sample_target_ids[0].tolist()), decoded_target)
print("Prediction:", decoded_prediction)
print("Logits:", len(output.logits[0]), output.logits[0])

# Optionally, you can convert logits to probabilities for interpretation
probabilities = torch.softmax(output.logits[0], dim=-1)
print("Probabilities:", probabilities)


inputs: 1 tensor([[1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 2, 1, 0, 1,
         1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 2, 0, 0, 1, 0, 0, 0,
         1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 2, 0, 1, 1, 1, 0, 1, 1, 0, 0,
         0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 2, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0,
         0, 0, 0, 0, 1, 1, 0, 0, 2, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0,
         1, 1, 0, 0, 0, 2, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0,
         0, 1, 2, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 2,
         1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 2, 1, 1, 0,
         1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 2, 3, 3, 3, 3, 3, 3,
         3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
         3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]],
       device='cuda:3')
targets: 1 tensor([[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,

In [19]:
def evaluate_model_LH(model, val_dataloader, device, horizon):
    model.eval()  # Set the model to evaluation mode
    val_loop = tqdm(val_dataloader, leave=False, desc="Validation")
    correct = [0] * horizon
    total = [0] * horizon
    accuracy = [0] * horizon
    correct_r = 0
    total_r = 0

    with torch.no_grad():
        for input_ids, target_ids in val_loop:
            input_ids, target_ids = input_ids.to(device), target_ids.to(device)

            outputs = model(input_ids)
            predictions = torch.argmax(outputs.logits, dim=-1)
            
            for t in range(horizon):

                # Create a mask where target_ids is not 3 (ignoring padding)
                #valid_mask = target_ids[:, t] != 3

                # Apply the mask to filter out -100 values
                start_idx  = orbit_len + t * state_len
                valid_predictions = predictions[:, start_idx : start_idx + state_len]
                valid_targets = target_ids[:, start_idx : start_idx + state_len]
            

                correct[t] += (valid_predictions == valid_targets).sum().item()
                total[t] += len(valid_targets)
                
            # for rule
            valid_predictions = predictions[:, start_idx + state_len : start_idx + state_len + 32]
            valid_targets = target_ids[:, start_idx + state_len : start_idx + state_len + 32]
            
            correct_r += (valid_predictions == valid_targets).sum().item()
            total_r += len(valid_targets)

    for t in range(horizon):
        accuracy[t] = correct[t] / total[t] / state_len  if total[t] > 0 else 0
        
    accuracy_rule = correct_r / total_r / 32 if total_r > 0 else 0
        
    return accuracy, correct, total, accuracy_rule

In [20]:
val_accuracy = evaluate_model_LH(model, val_dataloader, device,look_ahead)
print(val_accuracy)
#print(f"\nEvaluation at Step {step}: Validation accuracy = {val_accuracy:.10f}", "loss:", loss.item())

Validation:   0%|          | 0/125 [00:00<?, ?it/s]

([0.9927457142857143], [1042383], [50000], 0.944338125)


In [21]:
test_accuracy = evaluate_model_LH(model, test_dataloader, device,look_ahead)
print(test_accuracy)
#print(f"\nEvaluation at Step {step}: Test Accuracy = {test_accuracy:.10f}", "loss:", loss.item())

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

([0.992742380952381], [2084759], [100000], 0.94414875)


In [None]:
Evaluation at Step 49999: Accuracy = 0.999925 loss: 0.0001053252344718203

In [None]:
orbit_len = state_len * steps

In [None]:
0.9765709523809524