In [1]:
from transformers import BertConfig, BertForMaskedLM, BertTokenizer
from transformers import PreTrainedTokenizer
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from torch.nn.functional import pad
from datasets import load_dataset
from tqdm.notebook import tqdm
from itertools import cycle
import wandb

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
wandb.login(key='56853ea1087c0d089c005f3f7aede52740245615')

device

2024-09-11 15:15:04.783522: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmb_lims[0m ([33mlims-ai[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/mburtsev/.netrc


device(type='cuda', index=6)

# Create a custom tokenizer for CA

In [2]:
class SimpleTokenizer:
    def __init__(self):
        self.token_to_id = {'0': 0, '1': 1, '[SEP]': 2, '[PAD]': 3, '[T]': 4} # [T] - thought
        self.id_to_token = {v: k for k, v in self.token_to_id.items()}

    def encode(self, text):
        return [self.token_to_id.get(token, self.token_to_id['[PAD]']) for token in text]

    def decode(self, tokens):
        return ''.join([self.id_to_token.get(token, '[PAD]') for token in tokens])
        #return ''.join([self.id_to_token[token] for token in tokens])

    def __len__(self):
        return len(self.token_to_id)

tokenizer = SimpleTokenizer()

# Define CA dataset

In [3]:
class CellularAutomataDataset(Dataset):
    
    def __init__(self, hf_dataset, steps, horizon=1, pad_length=512):
        self.hf_dataset = hf_dataset
        self.pad_length = pad_length
        self.steps = steps
        self.horizon = horizon

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        row = self.hf_dataset[idx]

        # Assuming the first element 'rule' is not used for model input
        # Concatenate all states except the last one for input
        if (self.steps > (len(row) - 1)):
            self.steps = len(row) - 1
        
        input_ids = []
        input_ids.extend(tokenizer.encode(row['rule']))
        input_ids.append(2)  # Append [SEP] token
        for t in range(self.steps):  # Exclude the last state and the 'rule' row
            input_ids.extend(tokenizer.encode(row[f't={t}']))
            input_ids.append(2)  # Append [SEP] token after each state

        # Use the last state as target and calculate its length
        target_ids = []
        for t in range(self.horizon):
            last_state_key = f't={self.steps + t}'
            target_ids.extend(tokenizer.encode(row[last_state_key]))
            target_ids.append(2)
        
        orbit_len = len(input_ids)
        target_len = len(target_ids)
        # Padding input and target sequences to the fixed length
        input_ids = pad(torch.tensor(input_ids), (0, self.pad_length - orbit_len), value=3)
        target_ids = pad(torch.tensor(target_ids), (orbit_len, self.pad_length - orbit_len - target_len), value=3)

        # Create a tensor of -100 and replace the last part with target_ids
       # target_labels = torch.full((self.pad_length,), -100)
        #target_labels[:state_length] = target_ids[:state_length]

        return input_ids, target_ids #target_labels


# Configuration for the BERT model

In [4]:
hidden_size=512
num_hidden_layers=4
num_attention_heads=8

config = BertConfig(
    vocab_size=len(tokenizer),
    hidden_size=hidden_size,
    num_hidden_layers=num_hidden_layers,
    num_attention_heads=num_attention_heads,
    #intermediate_size=3072
)
model_name = 'BERT'
model = BertForMaskedLM(config=config)

# Load dataset from Hugging Face repository

In [5]:
seed = 2
batch_size = 400
torch.manual_seed(seed)
steps = 10
state_len = 21 
rule_len = 33 # 32 + [SEP]
look_ahead = 1
orbit_len = state_len * steps
seq_len = orbit_len + state_len * look_ahead + rule_len

dataset_name = "1dCA_r2s20T20"
hf_dataset = load_dataset("mbur/"+dataset_name)['train']
dataset = CellularAutomataDataset(hf_dataset, steps, look_ahead, pad_length=seq_len)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

hf_dataset = load_dataset("mbur/"+dataset_name)['validation']
val_dataset = CellularAutomataDataset(hf_dataset, steps, look_ahead, pad_length=seq_len)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

hf_dataset = load_dataset("mbur/"+dataset_name)['test']
test_dataset = CellularAutomataDataset(hf_dataset, steps, look_ahead, pad_length=seq_len)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

validation_size = 1000  # Number of samples for validation
small_val_dataset, _ = random_split(val_dataset, [validation_size, len(val_dataset) - validation_size])

small_val_dataloader = DataLoader(small_val_dataset, batch_size=batch_size, shuffle=True)

Using custom data configuration mbur--1dCA_r2s20T20-062a8ac2d777e207
Found cached dataset json (/home/mburtsev/.cache/huggingface/datasets/mbur___json/mbur--1dCA_r2s20T20-062a8ac2d777e207/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)


  0%|          | 0/3 [00:00<?, ?it/s]

Using custom data configuration mbur--1dCA_r2s20T20-062a8ac2d777e207
Found cached dataset json (/home/mburtsev/.cache/huggingface/datasets/mbur___json/mbur--1dCA_r2s20T20-062a8ac2d777e207/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)


  0%|          | 0/3 [00:00<?, ?it/s]

Using custom data configuration mbur--1dCA_r2s20T20-062a8ac2d777e207
Found cached dataset json (/home/mburtsev/.cache/huggingface/datasets/mbur___json/mbur--1dCA_r2s20T20-062a8ac2d777e207/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)


  0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
# Define the ID of the sample you want to inspect
sample_id = 1  # Replace with the ID of the sample you're interested in

# Fetch the specific sample from the dataset
sample_input_ids, sample_target_ids = dataset[sample_id]
sample_input_ids, sample_target_ids = sample_input_ids.unsqueeze(0).to(device), sample_target_ids.unsqueeze(0).to(device)


# Decode the tokens to strings for readability
decoded_input = tokenizer.decode(sample_input_ids[0].tolist())
decoded_target = tokenizer.decode(sample_target_ids[0].tolist())

# Print input, target, prediction, and logits
print("inputs:",len(sample_input_ids),sample_input_ids.tolist())
print("targets:",len(sample_target_ids),sample_target_ids.tolist())
print("Input:", len(sample_input_ids[0].tolist()),decoded_input)
print("Target:", len(sample_target_ids[0].tolist()), decoded_target)

inputs: 1 [[0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 2, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 2, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 2, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 2, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 2, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 2, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 2, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 2, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 2, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]]
targets: 1 [[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,

# Evaluation

In [7]:
def evaluate_model(model, val_dataloader, device):
    model.eval()  # Set the model to evaluation mode
    total, correct = 0, 0
    val_loop = tqdm(val_dataloader, leave=False, desc="Validation")

    with torch.no_grad():
        for input_ids, target_ids in val_loop:
            input_ids, target_ids = input_ids.to(device), target_ids.to(device)

            outputs = model(input_ids)
            predictions = torch.argmax(outputs.logits, dim=-1)

            # Create a mask where target_ids is not -100 (ignoring padding)
            valid_mask = target_ids != 3

            # Apply the mask to filter out -100 values
            valid_predictions = predictions[valid_mask]
            valid_targets = target_ids[valid_mask]

            correct += (valid_predictions == valid_targets).sum().item()
            total += valid_targets.numel()

    accuracy = correct / total if total > 0 else 0
    return accuracy

# Training loop

In [8]:
model.to(device)

lr=1e-4 # learning rate

num_steps = 25_000  # Define the total number of training steps
evaluate_every = 100  # Frequency of evaluation
wndb_log = 10

run_cfg = {
    'dataset': dataset_name,
    'model': model_name,
    'input_size': 512,
    'hidden_size': hidden_size,
    'num_hidden_layers': num_hidden_layers,
    'num_attention_heads': num_attention_heads,
    'batch_size': batch_size,
    'learning rate': lr,
    'seed': seed
}
run_name = dataset_name+'i'+str(steps)+'f'+str(look_ahead)+'H2shrt_'+model_name+'_'+str(num_hidden_layers)+'L'+str(num_attention_heads)+'H'+str(hidden_size)+'HD'+'_bs'+str(batch_size)+'lr'+str(lr)
run_name = run_name + str(num_steps) + 'Kgc_v' + str(seed)
run = wandb.init(
    project="1dCA_ai4math24",
    name = run_name,
    config=run_cfg
)

optimizer = torch.optim.AdamW(model.parameters(), lr)

losses = []

model.train()

loop = tqdm(range(num_steps), desc='Training')
data_iter = cycle(dataloader)  # Create an infinite loop of dataloader

for step in loop:
    
    input_ids, target_ids = next(data_iter)
    input_ids, target_ids = input_ids.to(device), target_ids.to(device)

    outputs = model(input_ids, labels=target_ids)
    loss = outputs.loss

    optimizer.zero_grad()
    loss.backward()
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
    optimizer.step()
    current_loss = loss.item()
    loop.set_description(f"Step {step + 1}/{num_steps} - Loss: {current_loss:.4f}")
    loop.set_postfix(loss=current_loss)

    if step % evaluate_every == 0 and step != 0:
        val_accuracy = evaluate_model(model, small_val_dataloader, device)
        print(f"\nEvaluation at Step {step}: Accuracy = {val_accuracy:.4f}", "loss:", loss.item())
        wandb.log({'step': step, 'val acc': val_accuracy})

    if step % wndb_log == 0:
        losses.append(loss.detach().item())      
        wandb.log({'step': step, 'loss': current_loss})

loop.close()
wandb.finish()

Training:   0%|          | 0/25000 [00:00<?, ?it/s]

Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 100: Accuracy = 0.6238 loss: 0.0501888245344162


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 200: Accuracy = 0.6444 loss: 0.04876359552145004


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 300: Accuracy = 0.6519 loss: 0.04734931141138077


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 400: Accuracy = 0.6555 loss: 0.04686233401298523


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 500: Accuracy = 0.6540 loss: 0.04672092944383621


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 600: Accuracy = 0.6598 loss: 0.0481073260307312


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 700: Accuracy = 0.6626 loss: 0.0461462177336216


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 800: Accuracy = 0.6592 loss: 0.046922869980335236


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 900: Accuracy = 0.6609 loss: 0.045429471880197525


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1000: Accuracy = 0.6633 loss: 0.04472777247428894


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1100: Accuracy = 0.6632 loss: 0.045997217297554016


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1200: Accuracy = 0.6675 loss: 0.0449996292591095


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1300: Accuracy = 0.6717 loss: 0.04416356980800629


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1400: Accuracy = 0.6651 loss: 0.046696633100509644


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1500: Accuracy = 0.6697 loss: 0.045693714171648026


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1600: Accuracy = 0.6700 loss: 0.04516320303082466


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1700: Accuracy = 0.6771 loss: 0.043189637362957


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1800: Accuracy = 0.6740 loss: 0.04448617249727249


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 1900: Accuracy = 0.6750 loss: 0.045784685760736465


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2000: Accuracy = 0.6747 loss: 0.04456828162074089


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2100: Accuracy = 0.6803 loss: 0.04381739720702171


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2200: Accuracy = 0.6785 loss: 0.04454610496759415


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2300: Accuracy = 0.6825 loss: 0.043529607355594635


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2400: Accuracy = 0.6890 loss: 0.04262175038456917


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2500: Accuracy = 0.6918 loss: 0.043455690145492554


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2600: Accuracy = 0.6990 loss: 0.04028617963194847


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2700: Accuracy = 0.7062 loss: 0.041443537920713425


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2800: Accuracy = 0.7153 loss: 0.03979628533124924


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 2900: Accuracy = 0.7230 loss: 0.04062338545918465


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3000: Accuracy = 0.7262 loss: 0.04061414673924446


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3100: Accuracy = 0.7435 loss: 0.037285976111888885


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3200: Accuracy = 0.7658 loss: 0.03574473038315773


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3300: Accuracy = 0.7952 loss: 0.030551517382264137


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3400: Accuracy = 0.8719 loss: 0.02429770492017269


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3500: Accuracy = 0.9966 loss: 0.0010759166907519102


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3600: Accuracy = 1.0000 loss: 6.559159373864532e-05


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3700: Accuracy = 1.0000 loss: 4.3142703361809254e-05


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3800: Accuracy = 1.0000 loss: 3.283041587565094e-05


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 3900: Accuracy = 1.0000 loss: 2.7058196792495437e-05


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4000: Accuracy = 1.0000 loss: 6.205767567735165e-05


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4100: Accuracy = 0.9998 loss: 3.784719592658803e-05


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4200: Accuracy = 0.9997 loss: 0.00021135220595169812


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4300: Accuracy = 1.0000 loss: 2.121028046531137e-05


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4400: Accuracy = 1.0000 loss: 1.5165029253694229e-05


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4500: Accuracy = 1.0000 loss: 1.2756931937474292e-05


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4600: Accuracy = 1.0000 loss: 1.1208808246010449e-05


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4700: Accuracy = 1.0000 loss: 1.0127696441486478e-05


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4800: Accuracy = 1.0000 loss: 9.227463124261703e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 4900: Accuracy = 1.0000 loss: 8.562796210753731e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5000: Accuracy = 1.0000 loss: 7.989111509232316e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5100: Accuracy = 1.0000 loss: 7.31420823285589e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5200: Accuracy = 1.0000 loss: 6.904710971866734e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5300: Accuracy = 1.0000 loss: 6.31206648904481e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5400: Accuracy = 1.0000 loss: 5.925351160840364e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5500: Accuracy = 1.0000 loss: 5.54668531549396e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5600: Accuracy = 1.0000 loss: 5.176190370548284e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5700: Accuracy = 1.0000 loss: 4.9050909183279146e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5800: Accuracy = 1.0000 loss: 4.616124442691216e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 5900: Accuracy = 1.0000 loss: 4.349034043116262e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6000: Accuracy = 1.0000 loss: 4.168029136053519e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6100: Accuracy = 1.0000 loss: 3.876461960317101e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6200: Accuracy = 1.0000 loss: 3.6832700516242767e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6300: Accuracy = 1.0000 loss: 3.4688584946707124e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6400: Accuracy = 1.0000 loss: 3.2801265206217067e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6500: Accuracy = 0.9998 loss: 6.348361785057932e-05


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6600: Accuracy = 1.0000 loss: 2.8607719286810607e-05


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6700: Accuracy = 1.0000 loss: 2.012204276979901e-05


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6800: Accuracy = 1.0000 loss: 1.6258574760286137e-05


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 6900: Accuracy = 1.0000 loss: 1.3442626368487254e-05


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7000: Accuracy = 1.0000 loss: 1.1792855730163865e-05


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7100: Accuracy = 1.0000 loss: 1.0092858246935066e-05


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7200: Accuracy = 1.0000 loss: 9.209554264089093e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7300: Accuracy = 1.0000 loss: 8.126788088702597e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7400: Accuracy = 1.0000 loss: 7.324174021050567e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7500: Accuracy = 1.0000 loss: 6.800306891818764e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7600: Accuracy = 1.0000 loss: 6.230714461707976e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7700: Accuracy = 1.0000 loss: 5.7666393331601284e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7800: Accuracy = 0.9999 loss: 3.0717914341948926e-05


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 7900: Accuracy = 1.0000 loss: 8.844427611620631e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8000: Accuracy = 1.0000 loss: 6.90809974912554e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8100: Accuracy = 1.0000 loss: 5.905745638301596e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8200: Accuracy = 1.0000 loss: 5.361473540688166e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8300: Accuracy = 1.0000 loss: 4.822314167540753e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8400: Accuracy = 1.0000 loss: 4.473300123208901e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8500: Accuracy = 1.0000 loss: 4.0857112253434025e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8600: Accuracy = 1.0000 loss: 3.826263764494797e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8700: Accuracy = 1.0000 loss: 3.5570835734688444e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8800: Accuracy = 1.0000 loss: 3.3119868021458387e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 8900: Accuracy = 1.0000 loss: 3.117425194432144e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9000: Accuracy = 1.0000 loss: 2.9319862733245827e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9100: Accuracy = 1.0000 loss: 2.765505314528127e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9200: Accuracy = 1.0000 loss: 2.6285495096090017e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9300: Accuracy = 1.0000 loss: 2.478249143678113e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9400: Accuracy = 1.0000 loss: 2.307067006768193e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9500: Accuracy = 1.0000 loss: 2.2245444597501773e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9600: Accuracy = 1.0000 loss: 2.103972065015114e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9700: Accuracy = 1.0000 loss: 2.0216364191583125e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9800: Accuracy = 1.0000 loss: 1.902209191939619e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 9900: Accuracy = 1.0000 loss: 1.7945121726370417e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10000: Accuracy = 1.0000 loss: 1.724720050333417e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10100: Accuracy = 1.0000 loss: 1.5675459508202039e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10200: Accuracy = 1.0000 loss: 1.4867335949020344e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10300: Accuracy = 1.0000 loss: 1.4148146192383138e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10400: Accuracy = 1.0000 loss: 1.3422029496723553e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10500: Accuracy = 1.0000 loss: 1.2865671124018263e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10600: Accuracy = 1.0000 loss: 1.2397756563586881e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10700: Accuracy = 1.0000 loss: 1.1872901950482628e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10800: Accuracy = 1.0000 loss: 1.1383978062440292e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 10900: Accuracy = 1.0000 loss: 1.0886822110478533e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11000: Accuracy = 1.0000 loss: 1.0379451396147488e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11100: Accuracy = 1.0000 loss: 9.658792805566918e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11200: Accuracy = 1.0000 loss: 9.155032785201911e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11300: Accuracy = 1.0000 loss: 8.599132570452639e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11400: Accuracy = 1.0000 loss: 8.214874469558708e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11500: Accuracy = 1.0000 loss: 7.890809001764865e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11600: Accuracy = 1.0000 loss: 7.591794428662979e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11700: Accuracy = 1.0000 loss: 7.303967208827089e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11800: Accuracy = 1.0000 loss: 7.129151526896749e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 11900: Accuracy = 1.0000 loss: 6.836457373537996e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12000: Accuracy = 1.0000 loss: 6.568723165401025e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12100: Accuracy = 1.0000 loss: 6.362997169162554e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12200: Accuracy = 1.0000 loss: 6.190729777699744e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12300: Accuracy = 1.0000 loss: 5.973523684588145e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12400: Accuracy = 1.0000 loss: 5.789032115899317e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12500: Accuracy = 1.0000 loss: 5.682540518137102e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12600: Accuracy = 1.0000 loss: 5.418659725364705e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12700: Accuracy = 1.0000 loss: 5.089513592793082e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12800: Accuracy = 1.0000 loss: 4.6328605662893096e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 12900: Accuracy = 1.0000 loss: 4.1945293105527526e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13000: Accuracy = 1.0000 loss: 3.9024439502099995e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13100: Accuracy = 1.0000 loss: 3.74786736756505e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13200: Accuracy = 1.0000 loss: 3.5417460253484023e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13300: Accuracy = 1.0000 loss: 3.24741421309227e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13400: Accuracy = 1.0000 loss: 2.952224349428434e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13500: Accuracy = 1.0000 loss: 2.6618212700668664e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13600: Accuracy = 1.0000 loss: 2.4814269750095264e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13700: Accuracy = 1.0000 loss: 2.340757703223062e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13800: Accuracy = 1.0000 loss: 2.19077534779899e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 13900: Accuracy = 1.0000 loss: 2.0792762711607793e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14000: Accuracy = 1.0000 loss: 1.9738504875022045e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14100: Accuracy = 1.0000 loss: 1.8758188957690436e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14200: Accuracy = 1.0000 loss: 1.7516200045974983e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14300: Accuracy = 1.0000 loss: 1.672643890060499e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14400: Accuracy = 1.0000 loss: 1.5800308972302446e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14500: Accuracy = 1.0000 loss: 1.4992711783179402e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14600: Accuracy = 1.0000 loss: 1.4082725385833328e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14700: Accuracy = 1.0000 loss: 1.3388128650149156e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14800: Accuracy = 1.0000 loss: 1.3021582390138065e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 14900: Accuracy = 1.0000 loss: 1.218305243355644e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15000: Accuracy = 1.0000 loss: 1.1568831581598715e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15100: Accuracy = 1.0000 loss: 1.100958684219222e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15200: Accuracy = 1.0000 loss: 1.054460270211166e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15300: Accuracy = 1.0000 loss: 9.962779756733653e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15400: Accuracy = 1.0000 loss: 9.504343267963122e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15500: Accuracy = 1.0000 loss: 9.147618840188443e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15600: Accuracy = 1.0000 loss: 8.763462489014273e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15700: Accuracy = 1.0000 loss: 8.20241083943074e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15800: Accuracy = 1.0000 loss: 7.675677693441685e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 15900: Accuracy = 1.0000 loss: 7.386007894183422e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 16000: Accuracy = 1.0000 loss: 7.066875440386866e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 16100: Accuracy = 1.0000 loss: 6.735889002129625e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 16200: Accuracy = 1.0000 loss: 6.365279148212721e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 16300: Accuracy = 1.0000 loss: 5.8973597560907365e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 16400: Accuracy = 1.0000 loss: 5.7528637853465625e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 16500: Accuracy = 1.0000 loss: 5.3760448537332195e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 16600: Accuracy = 1.0000 loss: 5.05871788902823e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 16700: Accuracy = 1.0000 loss: 4.97573502400428e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 16800: Accuracy = 1.0000 loss: 4.610215498246362e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 16900: Accuracy = 1.0000 loss: 4.4670741772279143e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 17000: Accuracy = 1.0000 loss: 4.2691823409768404e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 17100: Accuracy = 1.0000 loss: 3.96099899546698e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 17200: Accuracy = 1.0000 loss: 3.68577950382587e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 17300: Accuracy = 1.0000 loss: 3.402092829674075e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 17400: Accuracy = 1.0000 loss: 3.144483429196043e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 17500: Accuracy = 1.0000 loss: 2.995246006776142e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 17600: Accuracy = 1.0000 loss: 2.8437510124490473e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 17700: Accuracy = 1.0000 loss: 2.7394428059324127e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 17800: Accuracy = 1.0000 loss: 2.5911084833296627e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 17900: Accuracy = 1.0000 loss: 2.453272784919136e-08


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18000: Accuracy = 0.5425 loss: 0.05315123870968819


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18100: Accuracy = 0.6156 loss: 0.04957881569862366


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18200: Accuracy = 0.6563 loss: 0.04665272682905197


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18300: Accuracy = 0.6643 loss: 0.0455663725733757


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18400: Accuracy = 0.9995 loss: 0.00014708968228660524


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18500: Accuracy = 1.0000 loss: 9.342393241240643e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18600: Accuracy = 1.0000 loss: 7.959598406159785e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18700: Accuracy = 1.0000 loss: 5.269053417578107e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18800: Accuracy = 1.0000 loss: 4.051220912515419e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 18900: Accuracy = 1.0000 loss: 3.418287178647006e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 19000: Accuracy = 1.0000 loss: 3.2100981570692966e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 19100: Accuracy = 0.9997 loss: 5.2456107368925586e-05


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 19200: Accuracy = 1.0000 loss: 1.6879588656593114e-05


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 19300: Accuracy = 1.0000 loss: 2.990599568875041e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 19400: Accuracy = 1.0000 loss: 2.4524915716028772e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 19500: Accuracy = 1.0000 loss: 2.1094879230076913e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 19600: Accuracy = 1.0000 loss: 1.8484288375475444e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 19700: Accuracy = 1.0000 loss: 1.758484245328873e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 19800: Accuracy = 1.0000 loss: 1.6198900993913412e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 19900: Accuracy = 1.0000 loss: 1.4400474128706264e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 20000: Accuracy = 1.0000 loss: 1.340362928203831e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 20100: Accuracy = 1.0000 loss: 1.259603891412553e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 20200: Accuracy = 1.0000 loss: 1.1645940958260326e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 20300: Accuracy = 1.0000 loss: 1.091410354092659e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 20400: Accuracy = 1.0000 loss: 1.0149713034479646e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 20900: Accuracy = 1.0000 loss: 8.002822369235219e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 21000: Accuracy = 1.0000 loss: 7.595360784762306e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 21100: Accuracy = 1.0000 loss: 7.299705089280906e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 21200: Accuracy = 1.0000 loss: 6.99497320511e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 21300: Accuracy = 1.0000 loss: 6.9050719275765e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 21400: Accuracy = 1.0000 loss: 7.100266543602629e-07


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 22300: Accuracy = 1.0000 loss: 4.241060196363833e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 22400: Accuracy = 1.0000 loss: 3.9843365584602e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 22500: Accuracy = 1.0000 loss: 3.783995055073319e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 22600: Accuracy = 1.0000 loss: 3.7018787679699017e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 22700: Accuracy = 1.0000 loss: 3.445727827511291e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 22800: Accuracy = 1.0000 loss: 3.268214356921817e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 22900: Accuracy = 1.0000 loss: 3.195017370671849e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 23000: Accuracy = 1.0000 loss: 3.0106733106549655e-07


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 23100: Accuracy = 0.5213 loss: 0.052414558827877045


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 23200: Accuracy = 0.6468 loss: 0.047805819660425186


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 23300: Accuracy = 0.6682 loss: 0.046381089836359024


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 23400: Accuracy = 0.9699 loss: 0.005624633748084307


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 23500: Accuracy = 0.9999 loss: 0.00045036038500256836


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 23600: Accuracy = 0.9997 loss: 4.9965216021519154e-05


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 23700: Accuracy = 1.0000 loss: 9.806245543586556e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 23800: Accuracy = 1.0000 loss: 7.8192524597398e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 23900: Accuracy = 1.0000 loss: 9.641606993682217e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 24000: Accuracy = 1.0000 loss: 6.175066118885297e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 24100: Accuracy = 1.0000 loss: 5.657150722981896e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 24200: Accuracy = 1.0000 loss: 4.589769559970591e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 24300: Accuracy = 1.0000 loss: 4.133082711632596e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 24400: Accuracy = 1.0000 loss: 4.438639280124335e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 24500: Accuracy = 1.0000 loss: 3.659482445073081e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 24600: Accuracy = 1.0000 loss: 3.241672857257072e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 24700: Accuracy = 1.0000 loss: 2.984039610964828e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 24800: Accuracy = 1.0000 loss: 2.758157734206179e-06


Validation:   0%|          | 0/3 [00:00<?, ?it/s]


Evaluation at Step 24900: Accuracy = 1.0000 loss: 2.5995182113547344e-06


VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,██▇▇▇▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val acc,▁▁▁▂▂▃███████████████████████████████▇██

0,1
loss,0.0
step,24990.0
val acc,1.0


In [15]:
# Define the checkpoint content
checkpoint = {
    'model_state_dict': model.state_dict(),            # Model's state dictionary
    'optimizer_state_dict': optimizer.state_dict(),        # Optimizer's state dictionary
    'step': step,                                      # Current training step
    'losses': losses,                                  # Loss history
    'config': run_cfg,
    'seed': seed
}

run_name = dataset_name+'i'+str(steps)+'f'+str(look_ahead)+'H2shrt_IR_'+model_name+'_'+str(num_hidden_layers)+'L'+str(num_attention_heads)+'H'+str(hidden_size)+'HD'+'_bs'+str(batch_size)+'lr'+str(lr)
run_name = run_name + str(num_steps) + 'Kgc_v' + str(seed)

# Define a checkpoint file path
checkpoint_path = run_name + ".pth"

# Save the checkpoint
torch.save(checkpoint, checkpoint_path)

In [10]:
# Define the ID of the sample you want to inspect
sample_id = 11  # Replace with the ID of the sample you're interested in

# Fetch the specific sample from the dataset
sample_input_ids, sample_target_ids = dataset[sample_id]
sample_input_ids, sample_target_ids = sample_input_ids.unsqueeze(0).to(device), sample_target_ids.unsqueeze(0).to(device)

# Set the model to evaluation mode and get output
model.eval()
with torch.no_grad():
    output = model(sample_input_ids)

# Decode the tokens to strings for readability
decoded_input = tokenizer.decode(sample_input_ids[0].tolist())
decoded_target = tokenizer.decode(sample_target_ids[0].tolist())
decoded_prediction = tokenizer.decode(torch.argmax(output.logits, dim=-1)[0].tolist())

# Print input, target, prediction, and logits
print("inputs:",len(sample_input_ids),sample_input_ids)
print("targets:",len(sample_target_ids),sample_target_ids)
print(len(output),output)
print("Input:", len(sample_input_ids[0].tolist()),decoded_input)
print("Target:", len(sample_target_ids[0].tolist()), decoded_target)
print("Prediction:", decoded_prediction)
print("Logits:", len(output.logits[0]), output.logits[0])

# Optionally, you can convert logits to probabilities for interpretation
probabilities = torch.softmax(output.logits[0], dim=-1)
print("Probabilities:", probabilities)


inputs: 1 tensor([[1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1,
         0, 1, 0, 0, 1, 0, 1, 0, 2, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,
         1, 0, 1, 0, 1, 2, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0,
         1, 1, 2, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 2,
         0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 2, 0, 1, 0,
         0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 2, 1, 1, 1, 0, 1, 1,
         1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 2, 0, 0, 0, 0, 1, 0, 0, 0, 1,
         0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 2, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0,
         1, 0, 0, 1, 1, 0, 1, 1, 2, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0,
         1, 0, 0, 1, 1, 2, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0,
         0, 0, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]],
       device='cuda:6')
targets: 1 tensor([[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,

In [11]:
def evaluate_model_LH(model, val_dataloader, device, horizon):
    model.eval()  # Set the model to evaluation mode
    val_loop = tqdm(val_dataloader, leave=False, desc="Validation")
    correct = [0] * horizon
    total = [0] * horizon
    accuracy = [0] * horizon

    with torch.no_grad():
        for input_ids, target_ids in val_loop:
            input_ids, target_ids = input_ids.to(device), target_ids.to(device)

            outputs = model(input_ids)
            predictions = torch.argmax(outputs.logits, dim=-1)
            
            for t in range(horizon):

                # Create a mask where target_ids is not 3 (ignoring padding)
                #valid_mask = target_ids[:, t] != 3

                # Apply the mask to filter out -100 values
                start_idx  = rule_len + orbit_len + t * state_len
                valid_predictions = predictions[:, start_idx : start_idx + state_len ]
                valid_targets = target_ids[:, start_idx : start_idx + state_len]
            

                correct[t] += (valid_predictions == valid_targets).sum().item()
                total[t] += len(valid_targets)

    for t in range(horizon):
        accuracy[t] = correct[t] / total[t] / state_len  if total[t] > 0 else 0
        
    return accuracy, correct, total

In [12]:
val_accuracy = evaluate_model_LH(model, val_dataloader, device,look_ahead)
print(val_accuracy)
#print(f"\nEvaluation at Step {step}: Validation accuracy = {val_accuracy:.10f}", "loss:", loss.item())

Validation:   0%|          | 0/125 [00:00<?, ?it/s]

([0.9996799999999999], [1049664], [50000])


In [13]:
test_accuracy = evaluate_model_LH(model, test_dataloader, device,look_ahead)
print(test_accuracy)
#print(f"\nEvaluation at Step {step}: Test Accuracy = {test_accuracy:.10f}", "loss:", loss.item())

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

([0.9996728571428571], [2099313], [100000])


In [14]:
Evaluation at Step 49999: Accuracy = 0.999925 loss: 0.0001053252344718203

SyntaxError: invalid syntax (1662878917.py, line 1)

In [None]:
orbit_len = state_len * steps

In [None]:
0.9765709523809524