In [None]:
# %matplotlib ipympl

import os
import glob
import torch
import tqdm
import time
import csv
from torch import nn
from torch.utils.data import DataLoader, ConcatDataset, random_split
from learning.model.actionvalue import DWActionValueModel
from learning.data.process_utils import move_to_device
from learning.data.av_utils import (
    extract_input, extract_output_target, collate_av_data
)
from learning.data.av_dataset import ActionValueDataset

In [3]:
history_folder = "server/history/"
history_files = sorted(glob.glob(os.path.join(history_folder, "history_*.json")))

datasets = [ActionValueDataset(f) for f in history_files]
concatenated_dataset = ConcatDataset(datasets)

train_ds, val_ds, test_ds = random_split(
    concatenated_dataset,
    [0.8, 0.1, 0.1],
)

In [4]:
# action_values = torch.stack([
#     d["action_value"] 
#     for d in train_ds
# ])
# mean = action_values.mean()
# std = action_values.std()
# mean, std
# (tensor(-820.0759), tensor(1868.9108))

In [None]:
mean_val = torch.tensor(-820)
std_val = torch.tensor(1870)

In [6]:

# Set a fixed seed for reproducibility
seed = 42
torch.manual_seed(seed)
generator = torch.Generator().manual_seed(seed)

train_loader = DataLoader(train_ds, 16000, pin_memory=True, shuffle=True, 
                         collate_fn=collate_av_data, generator=generator)
val_loader = DataLoader(val_ds, 20000, pin_memory=True, shuffle=True, 
                       collate_fn=collate_av_data, generator=generator)

print("train_ds", len(train_ds))
print("val_ds", len(val_ds))
print("test_ds", len(test_ds))

train_ds 324610
val_ds 40576
test_ds 40576


In [7]:
model_state = None
optim_state = None
latest_epoch = -1

model_name = "actionvalue_with_freeze"
models_dir = f"learning/{model_name}_checkpoints" 

os.makedirs(models_dir, exist_ok=True)

# Find the latest checkpoint file (with highest epoch number)
checkpoint_files = glob.glob(os.path.join(models_dir, f"{model_name}_*.pt"))
if checkpoint_files:
    # Extract epoch numbers from filenames
    epoch_nums = [int(f.split("_")[-1].split(".")[0]) for f in checkpoint_files]
    latest_epoch = max(epoch_nums)
    latest_checkpoint = os.path.join(models_dir, f"{model_name}_{latest_epoch:06}.pt")
    print(f"Loading latest checkpoint: {latest_checkpoint} (epoch {latest_epoch})")
    checkpoint = torch.load(latest_checkpoint, map_location=torch.device('cpu'))
    model_state = checkpoint["model_state"]
    optim_state = checkpoint["optim_state"]
else:
    print("No valid checkpoint files found")
    
start_epoch = latest_epoch + 1

No valid checkpoint files found


In [8]:
action_value_model = DWActionValueModel()
num_params = sum(p.numel() for p in action_value_model.parameters() if p.requires_grad)
print(f"Model created. Total trainable parameters: {num_params:,}")

Model created. Total trainable parameters: 277


In [9]:
init_from_old_model = (model_state is None)

if init_from_old_model:
    print("⚠️⚠️⚠️ Warning, model is being initialized from an old checkpoint")
    
    checkpoint = torch.load(
        "learning/old_models/stateflow_winner_action.pt",
        map_location=torch.device('cpu')
    )
    model_state = checkpoint["model_state"]

    # update old model state to match the new one
    mismatch = action_value_model.load_state_dict(model_state, strict=False)
    model_state = action_value_model.state_dict()

    missing_keys, unexpected_keys = mismatch

    for mk in missing_keys:
        print("missing key", mk)

    for uk in unexpected_keys:
        print("unexpected key", uk)

missing key attack_value_fn.edge_value_fn.weight
missing key attack_value_fn.edge_value_fn.bias
missing key end_turn_value_fn.attention_coef_fn.transform_fn.weight
missing key end_turn_value_fn.summary_fn.weight
missing key end_turn_value_fn.summary_fn.bias
unexpected key winner_logit_fn.attention_coef_fn.transform_fn.weight
unexpected key winner_logit_fn.summary_fn.weight
unexpected key winner_logit_fn.summary_fn.bias
unexpected key attack_logit_fn.edge_value_fn.weight
unexpected key attack_logit_fn.edge_value_fn.bias
unexpected key end_turn_logit_fn.attention_coef_fn.transform_fn.weight
unexpected key end_turn_logit_fn.summary_fn.weight
unexpected key end_turn_logit_fn.summary_fn.bias


In [10]:
# Load the model before everything else so that I can freeze some layers
if model_state is not None:
    action_value_model.load_state_dict(model_state)

action_value_model.gat_layers[0].requires_grad_(False)
# action_value_model.gat_layers[1].requires_grad_(False)
# action_value_model.gat_layers[2].requires_grad_(False)

action_value_model.gat_layers[1].reset_parameters() 
# nn.init.constant_(action_value_model.gat_layers[1].passthrough_coef, -1)

action_value_model.gat_layers[2].reset_parameters() 
# nn.init.constant_(action_value_model.gat_layers[2].passthrough_coef, -1)

In [11]:
# ── CSV set‑up: log one row per epoch ───────────────────────────────────────
csv_path = f"{model_name}_epoch_metrics.csv"
epoch_log_file_exists = os.path.isfile(csv_path)
if epoch_log_file_exists and os.path.getsize(csv_path) == 0:
    os.remove(csv_path)
    epoch_log_file_exists = False

csv_file = open(csv_path, "a", newline="")  # Open in append mode
csv_writer = csv.writer(csv_file)

# Only write header if file doesn't exist yet
if not epoch_log_file_exists:
    csv_writer.writerow([
        "epoch",           # 0‑based epoch index
        "train_loss",      # average training loss for the epoch
        "val_loss",        # average validation loss for the epoch
        "train_time_sec",  # seconds spent in training phase
        "val_time_sec",    # seconds spent in validation phase
        "total_time_sec"   # train + val
    ])

# ── model / optimizer prep ─────────────────────────────────────────────────

n_epochs   = 1000
device     = torch.device("cuda", 0)
train_time_limit = 10 * 60
val_time_limit = train_time_limit / 5
criterion  = nn.MSELoss()

action_value_model = action_value_model.to(device)

optimizer  = torch.optim.Adam(
    [p for p in action_value_model.parameters() if p.requires_grad],
    lr=1e-2
)
reset_optimizer = True
if not reset_optimizer and optim_state is not None:
    optimizer.load_state_dict(optim_state)

scheduler  = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min",
    factor=0.5,
    patience=5,
    threshold=0.0001,
    min_lr=1e-8
)


In [12]:
# ── INITIAL (untrained) LOSS EVALUATION ───────────────────────────────────────
if not epoch_log_file_exists:
    action_value_model.eval()
    with torch.no_grad():
        t_start = time.time()

        # avg loss on training set *without* gradient tracking
        init_train_sum, init_train_batches = 0.0, 0
        for b in tqdm.tqdm(train_loader, desc="init‑train"):
            b = move_to_device(b, device)
            out = action_value_model(*extract_input(b))
            o, t = extract_output_target(b, out, mean_val, std_val)
            init_train_sum += criterion(o, t).item()
            init_train_batches += 1
            if (time.time() - t_start) > train_time_limit: 
                print(f"Stopping training after >{train_time_limit} seconds.")
                break
        init_train_loss = init_train_sum / init_train_batches
        t_train_done = time.time()

        # avg loss on validation set
        init_val_sum, init_val_batches = 0.0, 0
        for vb in tqdm.tqdm(val_loader, desc="init‑val"):
            vb = move_to_device(vb, device)
            vout = action_value_model(*extract_input(vb))
            vo, vt = extract_output_target(
                vb, vout, mean_val, std_val
            )
            init_val_sum += criterion(vo, vt).item()
            init_val_batches += 1
            if (time.time() - t_train_done) > val_time_limit:
                print(f"Stopping validation after >{val_time_limit} seconds.")
                break
        init_val_loss = init_val_sum / init_val_batches
        t_val_done = time.time()

    # times
    init_train_time = t_train_done - t_start
    init_val_time   = t_val_done - t_train_done
    init_total_time = t_val_done - t_start

    # write initial row (epoch = None)
    csv_writer.writerow([
        -1,
        init_train_loss,
        init_val_loss,
        init_train_time,
        init_val_time,
        init_total_time
    ])
    csv_file.flush()
    print(
        f"Initial (-1) | "
        f"Train {init_train_loss:.4f} | "
        f"Val {init_val_loss:.4f} | "
        f"Time {init_total_time:.1f}s (T {init_train_time:.1f}s | "
        f"V {init_val_time:.1f}s)"
    )

init‑train: 100%|██████████| 21/21 [02:50<00:00,  8.13s/it]
init‑val: 100%|██████████| 3/3 [00:21<00:00,  7.31s/it]

Initial (-1) | Train 1.7770 | Val 1.7762 | Time 192.8s (T 170.9s | V 21.9s)





In [13]:
scheduler.get_last_lr()

[0.01]

In [14]:
for epoch in range(start_epoch, n_epochs):
    t_start = time.time()

    # ── TRAINING -----------------------------------------------------------
    action_value_model.train()
    sum_train_loss = torch.tensor(0.0, device=device)
    n_train_batches = 0

    t_batch_tqdm = tqdm.tqdm(
        train_loader, 
        desc=f"train {epoch}" 
    )
    for t_batch in t_batch_tqdm:
        t_batch = move_to_device(t_batch, device)

        optimizer.zero_grad()

        model_out = action_value_model(*extract_input(t_batch))
        outputs, targets = extract_output_target(t_batch, model_out, mean_val, std_val)
        loss = criterion(outputs, targets)
        
        loss.backward()
        optimizer.step()
    
        if torch.isnan(loss) or torch.isinf(loss):
            raise ValueError(f"Loss contains NaN or Inf values")
        
        sum_train_loss += loss.detach()
        n_train_batches += 1
        t_batch_tqdm.set_postfix(av_loss=sum_train_loss.item()/n_train_batches)
        if (time.time() - t_start) > train_time_limit: 
            print(f"Stopping training after >{train_time_limit} seconds.")
            break
        
    t_batch_tqdm.close()

    avg_train_loss = sum_train_loss.item() / n_train_batches
    t_train_done = time.time()

    # ── VALIDATION ---------------------------------------------------------
    action_value_model.eval()
    sum_val_loss = 0.0
    n_val_batches = 0
    with torch.no_grad():
        v_batch_tqdm = tqdm.tqdm(
            val_loader, 
            desc=f"val {epoch}" 
        )
        for v_batch in v_batch_tqdm:
            v_batch = move_to_device(v_batch, device)

            v_out = action_value_model(*extract_input(v_batch))
            v_outputs, v_targets = extract_output_target(v_batch, v_out, mean_val, std_val)
            val_loss = criterion(v_outputs, v_targets)

            if torch.isnan(val_loss) or torch.isinf(val_loss):
                raise ValueError(f"Loss contains NaN or Inf values")
            
            sum_val_loss += val_loss
            n_val_batches += 1
            v_batch_tqdm.set_postfix(av_loss=sum_val_loss.item()/n_val_batches)
            if (time.time() - t_train_done) > val_time_limit:
                print(f"Stopping validation after >{val_time_limit} seconds.")
                break

        v_batch_tqdm.close()

    avg_val_loss = sum_val_loss.item() / n_val_batches
    t_val_done = time.time()

    # Step the learning rate scheduler after each epoch
    scheduler.step(avg_val_loss)

    # ── CSV logging ---------------------------------------------------------
    train_time = t_train_done - t_start
    val_time   = t_val_done   - t_train_done
    total_time = t_val_done   - t_start


    checkpoint_data = {
        "epoch": epoch,
        "optim_state": optimizer.state_dict(),
        "model_state": action_value_model.state_dict()
    }
    torch.save(checkpoint_data, os.path.join(models_dir, f"{model_name}_{epoch:06}.pt"))
    
    csv_writer.writerow([
        epoch,
        avg_train_loss,
        avg_val_loss,
        train_time,
        val_time,
        total_time
    ])
    csv_file.flush()  # ensure data is written even if run aborts

    # ── console printout ----------------------------------------------------
    print(
        f"Epoch {epoch} | "
        f"Train {avg_train_loss:.4f} | "
        f"Val {avg_val_loss:.4f} | "
        f"LR {scheduler.get_last_lr()[0]:.6f}"
    )
# ── tidy‑up ----------------------------------------------------------------
csv_file.close()

train 0: 100%|██████████| 21/21 [03:49<00:00, 10.92s/it, av_loss=3.41]
val 0: 100%|██████████| 3/3 [00:22<00:00,  7.55s/it, av_loss=1.06]


Epoch 0 | Train 3.4065 | Val 1.0633 | LR 0.010000


train 1: 100%|██████████| 21/21 [03:49<00:00, 10.95s/it, av_loss=1.08]
val 1: 100%|██████████| 3/3 [00:22<00:00,  7.52s/it, av_loss=0.998]


Epoch 1 | Train 1.0781 | Val 0.9980 | LR 0.010000


train 2: 100%|██████████| 21/21 [03:49<00:00, 10.91s/it, av_loss=1.01]
val 2: 100%|██████████| 3/3 [00:22<00:00,  7.47s/it, av_loss=0.973]


Epoch 2 | Train 1.0145 | Val 0.9727 | LR 0.010000


train 3: 100%|██████████| 21/21 [03:49<00:00, 10.94s/it, av_loss=1]  
val 3: 100%|██████████| 3/3 [00:22<00:00,  7.49s/it, av_loss=0.971]


Epoch 3 | Train 0.9999 | Val 0.9707 | LR 0.010000


train 4: 100%|██████████| 21/21 [03:48<00:00, 10.90s/it, av_loss=0.993]
val 4: 100%|██████████| 3/3 [00:22<00:00,  7.46s/it, av_loss=0.96] 


Epoch 4 | Train 0.9933 | Val 0.9605 | LR 0.010000


train 5: 100%|██████████| 21/21 [03:49<00:00, 10.94s/it, av_loss=0.986]
val 5: 100%|██████████| 3/3 [00:22<00:00,  7.56s/it, av_loss=0.954]


Epoch 5 | Train 0.9860 | Val 0.9544 | LR 0.010000


train 6: 100%|██████████| 21/21 [03:49<00:00, 10.91s/it, av_loss=0.982]
val 6: 100%|██████████| 3/3 [00:22<00:00,  7.49s/it, av_loss=0.937]


Epoch 6 | Train 0.9823 | Val 0.9371 | LR 0.010000


train 7: 100%|██████████| 21/21 [03:53<00:00, 11.10s/it, av_loss=0.98] 
val 7: 100%|██████████| 3/3 [00:22<00:00,  7.51s/it, av_loss=0.944]


Epoch 7 | Train 0.9801 | Val 0.9438 | LR 0.010000


train 8: 100%|██████████| 21/21 [03:52<00:00, 11.06s/it, av_loss=0.976]
val 8: 100%|██████████| 3/3 [00:22<00:00,  7.58s/it, av_loss=0.949]


Epoch 8 | Train 0.9755 | Val 0.9493 | LR 0.010000


train 9: 100%|██████████| 21/21 [03:52<00:00, 11.05s/it, av_loss=0.971]
val 9: 100%|██████████| 3/3 [00:22<00:00,  7.54s/it, av_loss=0.952]


Epoch 9 | Train 0.9706 | Val 0.9521 | LR 0.010000


train 10: 100%|██████████| 21/21 [03:51<00:00, 11.03s/it, av_loss=0.969]
val 10: 100%|██████████| 3/3 [00:22<00:00,  7.58s/it, av_loss=0.944]


Epoch 10 | Train 0.9690 | Val 0.9444 | LR 0.010000


train 11: 100%|██████████| 21/21 [03:52<00:00, 11.06s/it, av_loss=0.963]
val 11: 100%|██████████| 3/3 [00:23<00:00,  7.69s/it, av_loss=0.945]


Epoch 11 | Train 0.9634 | Val 0.9454 | LR 0.010000


train 12: 100%|██████████| 21/21 [03:53<00:00, 11.10s/it, av_loss=0.959]
val 12: 100%|██████████| 3/3 [00:22<00:00,  7.56s/it, av_loss=0.934]


Epoch 12 | Train 0.9587 | Val 0.9338 | LR 0.010000


train 13: 100%|██████████| 21/21 [03:53<00:00, 11.11s/it, av_loss=0.955]
val 13: 100%|██████████| 3/3 [00:22<00:00,  7.47s/it, av_loss=0.918]


Epoch 13 | Train 0.9551 | Val 0.9180 | LR 0.010000


train 14: 100%|██████████| 21/21 [03:53<00:00, 11.11s/it, av_loss=0.95] 
val 14: 100%|██████████| 3/3 [00:22<00:00,  7.59s/it, av_loss=0.923]


Epoch 14 | Train 0.9502 | Val 0.9234 | LR 0.010000


train 15: 100%|██████████| 21/21 [03:50<00:00, 10.98s/it, av_loss=0.948]
val 15: 100%|██████████| 3/3 [00:22<00:00,  7.61s/it, av_loss=0.916]


Epoch 15 | Train 0.9475 | Val 0.9160 | LR 0.010000


train 16: 100%|██████████| 21/21 [03:50<00:00, 10.99s/it, av_loss=0.945]
val 16: 100%|██████████| 3/3 [00:22<00:00,  7.52s/it, av_loss=0.908]


Epoch 16 | Train 0.9452 | Val 0.9083 | LR 0.010000


train 17: 100%|██████████| 21/21 [03:51<00:00, 11.04s/it, av_loss=0.944]
val 17: 100%|██████████| 3/3 [00:22<00:00,  7.59s/it, av_loss=0.913]


Epoch 17 | Train 0.9443 | Val 0.9125 | LR 0.010000


train 18: 100%|██████████| 21/21 [03:50<00:00, 10.98s/it, av_loss=0.943]
val 18: 100%|██████████| 3/3 [00:22<00:00,  7.51s/it, av_loss=0.937]


Epoch 18 | Train 0.9430 | Val 0.9365 | LR 0.010000


train 19: 100%|██████████| 21/21 [03:50<00:00, 10.99s/it, av_loss=0.941]
val 19: 100%|██████████| 3/3 [00:23<00:00,  7.80s/it, av_loss=0.927]


Epoch 19 | Train 0.9413 | Val 0.9271 | LR 0.010000


train 20: 100%|██████████| 21/21 [03:50<00:00, 10.96s/it, av_loss=0.94]
val 20: 100%|██████████| 3/3 [00:22<00:00,  7.49s/it, av_loss=0.91] 


Epoch 20 | Train 0.9399 | Val 0.9100 | LR 0.010000


train 21: 100%|██████████| 21/21 [03:49<00:00, 10.92s/it, av_loss=0.939]
val 21: 100%|██████████| 3/3 [00:22<00:00,  7.42s/it, av_loss=0.92] 


Epoch 21 | Train 0.9395 | Val 0.9200 | LR 0.010000


train 22: 100%|██████████| 21/21 [03:49<00:00, 10.93s/it, av_loss=0.939]
val 22: 100%|██████████| 3/3 [00:22<00:00,  7.52s/it, av_loss=0.931]


Epoch 22 | Train 0.9388 | Val 0.9313 | LR 0.005000


train 23: 100%|██████████| 21/21 [03:48<00:00, 10.88s/it, av_loss=0.938]
val 23: 100%|██████████| 3/3 [00:22<00:00,  7.58s/it, av_loss=0.922]


Epoch 23 | Train 0.9378 | Val 0.9224 | LR 0.005000


train 24: 100%|██████████| 21/21 [03:49<00:00, 10.94s/it, av_loss=0.938]
val 24: 100%|██████████| 3/3 [00:22<00:00,  7.63s/it, av_loss=0.905]


Epoch 24 | Train 0.9379 | Val 0.9046 | LR 0.005000


train 25: 100%|██████████| 21/21 [03:49<00:00, 10.91s/it, av_loss=0.937]
val 25: 100%|██████████| 3/3 [00:22<00:00,  7.60s/it, av_loss=0.915]


Epoch 25 | Train 0.9371 | Val 0.9151 | LR 0.005000


train 26: 100%|██████████| 21/21 [03:52<00:00, 11.05s/it, av_loss=0.936]
val 26: 100%|██████████| 3/3 [00:22<00:00,  7.65s/it, av_loss=0.891]


Epoch 26 | Train 0.9360 | Val 0.8914 | LR 0.005000


train 27: 100%|██████████| 21/21 [03:49<00:00, 10.95s/it, av_loss=0.936]
val 27: 100%|██████████| 3/3 [00:22<00:00,  7.42s/it, av_loss=0.923]


Epoch 27 | Train 0.9357 | Val 0.9231 | LR 0.005000


train 28: 100%|██████████| 21/21 [03:47<00:00, 10.85s/it, av_loss=0.936]
val 28: 100%|██████████| 3/3 [00:22<00:00,  7.44s/it, av_loss=0.921]


Epoch 28 | Train 0.9356 | Val 0.9208 | LR 0.005000


train 29: 100%|██████████| 21/21 [03:47<00:00, 10.84s/it, av_loss=0.934]
val 29: 100%|██████████| 3/3 [00:22<00:00,  7.41s/it, av_loss=0.904]


Epoch 29 | Train 0.9340 | Val 0.9043 | LR 0.005000


train 30: 100%|██████████| 21/21 [03:47<00:00, 10.83s/it, av_loss=0.935]
val 30: 100%|██████████| 3/3 [00:22<00:00,  7.55s/it, av_loss=0.904]


Epoch 30 | Train 0.9350 | Val 0.9040 | LR 0.005000


train 31: 100%|██████████| 21/21 [03:47<00:00, 10.84s/it, av_loss=0.935]
val 31: 100%|██████████| 3/3 [00:22<00:00,  7.36s/it, av_loss=0.912]


Epoch 31 | Train 0.9346 | Val 0.9116 | LR 0.005000


train 32: 100%|██████████| 21/21 [03:47<00:00, 10.82s/it, av_loss=0.934]
val 32: 100%|██████████| 3/3 [00:22<00:00,  7.40s/it, av_loss=0.9]  


Epoch 32 | Train 0.9336 | Val 0.8999 | LR 0.002500


train 33: 100%|██████████| 21/21 [03:47<00:00, 10.84s/it, av_loss=0.934]
val 33: 100%|██████████| 3/3 [00:22<00:00,  7.46s/it, av_loss=0.905]


Epoch 33 | Train 0.9337 | Val 0.9051 | LR 0.002500


train 34: 100%|██████████| 21/21 [03:48<00:00, 10.88s/it, av_loss=0.934]
val 34: 100%|██████████| 3/3 [00:22<00:00,  7.43s/it, av_loss=0.918]


Epoch 34 | Train 0.9339 | Val 0.9183 | LR 0.002500


train 35: 100%|██████████| 21/21 [03:47<00:00, 10.82s/it, av_loss=0.932]
val 35: 100%|██████████| 3/3 [00:22<00:00,  7.43s/it, av_loss=0.902]


Epoch 35 | Train 0.9324 | Val 0.9020 | LR 0.002500


train 36: 100%|██████████| 21/21 [03:47<00:00, 10.84s/it, av_loss=0.933]
val 36: 100%|██████████| 3/3 [00:22<00:00,  7.43s/it, av_loss=0.917]


Epoch 36 | Train 0.9328 | Val 0.9168 | LR 0.002500


train 37: 100%|██████████| 21/21 [03:46<00:00, 10.80s/it, av_loss=0.934]
val 37: 100%|██████████| 3/3 [00:22<00:00,  7.37s/it, av_loss=0.915]


Epoch 37 | Train 0.9339 | Val 0.9146 | LR 0.002500


train 38: 100%|██████████| 21/21 [03:47<00:00, 10.84s/it, av_loss=0.934]
val 38: 100%|██████████| 3/3 [00:22<00:00,  7.45s/it, av_loss=0.913]


Epoch 38 | Train 0.9341 | Val 0.9125 | LR 0.001250


train 39: 100%|██████████| 21/21 [03:47<00:00, 10.82s/it, av_loss=0.933]
val 39: 100%|██████████| 3/3 [00:22<00:00,  7.46s/it, av_loss=0.91] 


Epoch 39 | Train 0.9334 | Val 0.9103 | LR 0.001250


train 40: 100%|██████████| 21/21 [03:47<00:00, 10.84s/it, av_loss=0.932]
val 40: 100%|██████████| 3/3 [00:22<00:00,  7.39s/it, av_loss=0.896]


Epoch 40 | Train 0.9325 | Val 0.8963 | LR 0.001250


train 41: 100%|██████████| 21/21 [03:46<00:00, 10.80s/it, av_loss=0.934]
val 41: 100%|██████████| 3/3 [00:22<00:00,  7.34s/it, av_loss=0.922]


Epoch 41 | Train 0.9335 | Val 0.9217 | LR 0.001250


train 42: 100%|██████████| 21/21 [03:51<00:00, 11.05s/it, av_loss=0.933]
val 42: 100%|██████████| 3/3 [00:22<00:00,  7.65s/it, av_loss=0.909]


Epoch 42 | Train 0.9328 | Val 0.9092 | LR 0.001250


train 43: 100%|██████████| 21/21 [03:55<00:00, 11.21s/it, av_loss=0.933]
val 43: 100%|██████████| 3/3 [00:23<00:00,  7.71s/it, av_loss=0.901]


Epoch 43 | Train 0.9330 | Val 0.9007 | LR 0.001250


train 44: 100%|██████████| 21/21 [03:55<00:00, 11.23s/it, av_loss=0.932]
val 44: 100%|██████████| 3/3 [00:23<00:00,  7.67s/it, av_loss=0.896]


Epoch 44 | Train 0.9317 | Val 0.8959 | LR 0.000625


train 45: 100%|██████████| 21/21 [03:55<00:00, 11.19s/it, av_loss=0.933]
val 45: 100%|██████████| 3/3 [00:23<00:00,  7.75s/it, av_loss=0.911]


Epoch 45 | Train 0.9329 | Val 0.9106 | LR 0.000625


train 46: 100%|██████████| 21/21 [04:03<00:00, 11.60s/it, av_loss=0.931]
val 46: 100%|██████████| 3/3 [00:24<00:00,  8.31s/it, av_loss=0.906]


Epoch 46 | Train 0.9312 | Val 0.9055 | LR 0.000625


train 47: 100%|██████████| 21/21 [04:09<00:00, 11.89s/it, av_loss=0.931]
val 47: 100%|██████████| 3/3 [00:24<00:00,  8.30s/it, av_loss=0.903]


Epoch 47 | Train 0.9315 | Val 0.9033 | LR 0.000625


train 48: 100%|██████████| 21/21 [04:06<00:00, 11.74s/it, av_loss=0.933]
val 48: 100%|██████████| 3/3 [00:24<00:00,  8.14s/it, av_loss=0.921]


Epoch 48 | Train 0.9334 | Val 0.9207 | LR 0.000625


train 49: 100%|██████████| 21/21 [04:07<00:00, 11.80s/it, av_loss=0.931]
val 49: 100%|██████████| 3/3 [00:24<00:00,  8.12s/it, av_loss=0.899]


Epoch 49 | Train 0.9312 | Val 0.8988 | LR 0.000625


train 50: 100%|██████████| 21/21 [04:08<00:00, 11.85s/it, av_loss=0.933]
val 50: 100%|██████████| 3/3 [00:25<00:00,  8.50s/it, av_loss=0.908]


Epoch 50 | Train 0.9328 | Val 0.9078 | LR 0.000313


train 51: 100%|██████████| 21/21 [04:02<00:00, 11.56s/it, av_loss=0.933]
val 51: 100%|██████████| 3/3 [00:23<00:00,  7.82s/it, av_loss=0.91] 


Epoch 51 | Train 0.9332 | Val 0.9105 | LR 0.000313


train 52: 100%|██████████| 21/21 [04:03<00:00, 11.61s/it, av_loss=0.932]
val 52: 100%|██████████| 3/3 [00:23<00:00,  7.86s/it, av_loss=0.906]


Epoch 52 | Train 0.9320 | Val 0.9063 | LR 0.000313


train 53: 100%|██████████| 21/21 [04:02<00:00, 11.57s/it, av_loss=0.933]
val 53: 100%|██████████| 3/3 [00:23<00:00,  7.87s/it, av_loss=0.9]  


Epoch 53 | Train 0.9328 | Val 0.8998 | LR 0.000313


train 54: 100%|██████████| 21/21 [04:00<00:00, 11.45s/it, av_loss=0.931]
val 54: 100%|██████████| 3/3 [00:23<00:00,  7.91s/it, av_loss=0.899]


Epoch 54 | Train 0.9313 | Val 0.8988 | LR 0.000313


train 55: 100%|██████████| 21/21 [04:01<00:00, 11.51s/it, av_loss=0.933]
val 55: 100%|██████████| 3/3 [00:23<00:00,  7.88s/it, av_loss=0.906]


Epoch 55 | Train 0.9326 | Val 0.9060 | LR 0.000313


train 56: 100%|██████████| 21/21 [04:06<00:00, 11.76s/it, av_loss=0.932]
val 56: 100%|██████████| 3/3 [00:25<00:00,  8.42s/it, av_loss=0.907]


Epoch 56 | Train 0.9321 | Val 0.9068 | LR 0.000156


train 57: 100%|██████████| 21/21 [04:22<00:00, 12.52s/it, av_loss=0.932]
val 57: 100%|██████████| 3/3 [00:24<00:00,  8.28s/it, av_loss=0.892]


Epoch 57 | Train 0.9324 | Val 0.8915 | LR 0.000156


train 58: 100%|██████████| 21/21 [04:12<00:00, 12.03s/it, av_loss=0.933]
val 58: 100%|██████████| 3/3 [00:25<00:00,  8.37s/it, av_loss=0.928]


Epoch 58 | Train 0.9327 | Val 0.9276 | LR 0.000156


train 59: 100%|██████████| 21/21 [04:11<00:00, 12.00s/it, av_loss=0.932]
val 59: 100%|██████████| 3/3 [00:24<00:00,  8.01s/it, av_loss=0.884]


Epoch 59 | Train 0.9317 | Val 0.8837 | LR 0.000156


train 60: 100%|██████████| 21/21 [04:05<00:00, 11.67s/it, av_loss=0.932]
val 60: 100%|██████████| 3/3 [00:24<00:00,  8.04s/it, av_loss=0.907]


Epoch 60 | Train 0.9316 | Val 0.9070 | LR 0.000156


train 61: 100%|██████████| 21/21 [04:04<00:00, 11.65s/it, av_loss=0.932]
val 61: 100%|██████████| 3/3 [00:24<00:00,  8.05s/it, av_loss=0.902]


Epoch 61 | Train 0.9319 | Val 0.9017 | LR 0.000156


train 62: 100%|██████████| 21/21 [04:11<00:00, 11.97s/it, av_loss=0.932]
val 62: 100%|██████████| 3/3 [00:24<00:00,  8.20s/it, av_loss=0.911]


Epoch 62 | Train 0.9321 | Val 0.9114 | LR 0.000156


train 63: 100%|██████████| 21/21 [04:14<00:00, 12.10s/it, av_loss=0.932]
val 63: 100%|██████████| 3/3 [00:24<00:00,  8.12s/it, av_loss=0.914]


Epoch 63 | Train 0.9315 | Val 0.9137 | LR 0.000156


train 64: 100%|██████████| 21/21 [04:10<00:00, 11.91s/it, av_loss=0.932]
val 64: 100%|██████████| 3/3 [00:23<00:00,  7.74s/it, av_loss=0.922]


Epoch 64 | Train 0.9318 | Val 0.9224 | LR 0.000156


train 65: 100%|██████████| 21/21 [03:44<00:00, 10.70s/it, av_loss=0.932]
val 65: 100%|██████████| 3/3 [00:21<00:00,  7.25s/it, av_loss=0.904]


Epoch 65 | Train 0.9316 | Val 0.9036 | LR 0.000078


train 66: 100%|██████████| 21/21 [03:41<00:00, 10.54s/it, av_loss=0.932]
val 66: 100%|██████████| 3/3 [00:21<00:00,  7.29s/it, av_loss=0.891]


Epoch 66 | Train 0.9317 | Val 0.8913 | LR 0.000078


train 67: 100%|██████████| 21/21 [03:41<00:00, 10.56s/it, av_loss=0.932]
val 67: 100%|██████████| 3/3 [00:23<00:00,  7.84s/it, av_loss=0.908]


Epoch 67 | Train 0.9320 | Val 0.9080 | LR 0.000078


train 68: 100%|██████████| 21/21 [03:55<00:00, 11.21s/it, av_loss=0.932]
val 68: 100%|██████████| 3/3 [00:22<00:00,  7.55s/it, av_loss=0.904]


Epoch 68 | Train 0.9319 | Val 0.9041 | LR 0.000078


train 69: 100%|██████████| 21/21 [03:51<00:00, 11.01s/it, av_loss=0.932]
val 69: 100%|██████████| 3/3 [00:22<00:00,  7.55s/it, av_loss=0.918]


Epoch 69 | Train 0.9321 | Val 0.9178 | LR 0.000078


train 70: 100%|██████████| 21/21 [03:50<00:00, 10.97s/it, av_loss=0.932]
val 70: 100%|██████████| 3/3 [00:21<00:00,  7.28s/it, av_loss=0.904]


Epoch 70 | Train 0.9322 | Val 0.9040 | LR 0.000078


train 71: 100%|██████████| 21/21 [03:50<00:00, 10.99s/it, av_loss=0.931]
val 71: 100%|██████████| 3/3 [00:22<00:00,  7.54s/it, av_loss=0.902]


Epoch 71 | Train 0.9305 | Val 0.9020 | LR 0.000039


train 72: 100%|██████████| 21/21 [03:53<00:00, 11.14s/it, av_loss=0.932]
val 72: 100%|██████████| 3/3 [00:22<00:00,  7.53s/it, av_loss=0.917]


Epoch 72 | Train 0.9316 | Val 0.9172 | LR 0.000039


train 73: 100%|██████████| 21/21 [03:51<00:00, 11.05s/it, av_loss=0.931]
val 73: 100%|██████████| 3/3 [00:22<00:00,  7.52s/it, av_loss=0.904]


Epoch 73 | Train 0.9314 | Val 0.9043 | LR 0.000039


train 74: 100%|██████████| 21/21 [04:00<00:00, 11.46s/it, av_loss=0.932]
val 74: 100%|██████████| 3/3 [00:23<00:00,  7.72s/it, av_loss=0.9]  


Epoch 74 | Train 0.9322 | Val 0.8999 | LR 0.000039


train 75: 100%|██████████| 21/21 [03:59<00:00, 11.41s/it, av_loss=0.932]
val 75: 100%|██████████| 3/3 [00:24<00:00,  8.17s/it, av_loss=0.92] 


Epoch 75 | Train 0.9321 | Val 0.9197 | LR 0.000039


train 76: 100%|██████████| 21/21 [03:59<00:00, 11.40s/it, av_loss=0.932]
val 76: 100%|██████████| 3/3 [00:24<00:00,  8.32s/it, av_loss=0.904]


Epoch 76 | Train 0.9324 | Val 0.9042 | LR 0.000039


train 77: 100%|██████████| 21/21 [03:54<00:00, 11.19s/it, av_loss=0.932]
val 77: 100%|██████████| 3/3 [00:22<00:00,  7.55s/it, av_loss=0.903]


Epoch 77 | Train 0.9321 | Val 0.9031 | LR 0.000020


train 78: 100%|██████████| 21/21 [03:50<00:00, 11.00s/it, av_loss=0.932]
val 78: 100%|██████████| 3/3 [00:23<00:00,  7.71s/it, av_loss=0.908]


Epoch 78 | Train 0.9319 | Val 0.9077 | LR 0.000020


train 79: 100%|██████████| 21/21 [03:52<00:00, 11.06s/it, av_loss=0.932]
val 79: 100%|██████████| 3/3 [00:23<00:00,  7.69s/it, av_loss=0.904]


Epoch 79 | Train 0.9316 | Val 0.9038 | LR 0.000020


train 80: 100%|██████████| 21/21 [03:54<00:00, 11.15s/it, av_loss=0.931]
val 80: 100%|██████████| 3/3 [00:25<00:00,  8.62s/it, av_loss=0.9]  


Epoch 80 | Train 0.9313 | Val 0.9004 | LR 0.000020


train 81: 100%|██████████| 21/21 [04:01<00:00, 11.49s/it, av_loss=0.931]
val 81: 100%|██████████| 3/3 [00:23<00:00,  7.74s/it, av_loss=0.902]


Epoch 81 | Train 0.9310 | Val 0.9021 | LR 0.000020


train 82: 100%|██████████| 21/21 [03:56<00:00, 11.24s/it, av_loss=0.932]
val 82: 100%|██████████| 3/3 [00:23<00:00,  7.80s/it, av_loss=0.916]


Epoch 82 | Train 0.9318 | Val 0.9155 | LR 0.000020


train 83: 100%|██████████| 21/21 [04:01<00:00, 11.50s/it, av_loss=0.933]
val 83: 100%|██████████| 3/3 [00:23<00:00,  7.81s/it, av_loss=0.908]


Epoch 83 | Train 0.9328 | Val 0.9079 | LR 0.000010


train 84: 100%|██████████| 21/21 [03:59<00:00, 11.40s/it, av_loss=0.932]
val 84: 100%|██████████| 3/3 [00:22<00:00,  7.58s/it, av_loss=0.892]


Epoch 84 | Train 0.9315 | Val 0.8915 | LR 0.000010


train 85: 100%|██████████| 21/21 [03:54<00:00, 11.18s/it, av_loss=0.933]
val 85: 100%|██████████| 3/3 [00:22<00:00,  7.66s/it, av_loss=0.895]


Epoch 85 | Train 0.9326 | Val 0.8949 | LR 0.000010


train 86: 100%|██████████| 21/21 [03:59<00:00, 11.43s/it, av_loss=0.932]
val 86: 100%|██████████| 3/3 [00:22<00:00,  7.62s/it, av_loss=0.904]


Epoch 86 | Train 0.9317 | Val 0.9037 | LR 0.000010


train 87: 100%|██████████| 21/21 [03:56<00:00, 11.26s/it, av_loss=0.932]
val 87: 100%|██████████| 3/3 [00:23<00:00,  7.93s/it, av_loss=0.898]


Epoch 87 | Train 0.9316 | Val 0.8983 | LR 0.000010


train 88: 100%|██████████| 21/21 [04:02<00:00, 11.54s/it, av_loss=0.932]
val 88: 100%|██████████| 3/3 [00:22<00:00,  7.61s/it, av_loss=0.913]


Epoch 88 | Train 0.9322 | Val 0.9134 | LR 0.000010


train 89: 100%|██████████| 21/21 [03:54<00:00, 11.18s/it, av_loss=0.932]
val 89: 100%|██████████| 3/3 [00:22<00:00,  7.57s/it, av_loss=0.901]


Epoch 89 | Train 0.9317 | Val 0.9013 | LR 0.000005


train 90: 100%|██████████| 21/21 [03:56<00:00, 11.26s/it, av_loss=0.933]
val 90: 100%|██████████| 3/3 [00:25<00:00,  8.46s/it, av_loss=0.911]


Epoch 90 | Train 0.9328 | Val 0.9107 | LR 0.000005


train 91:   0%|          | 0/21 [00:03<?, ?it/s]


KeyboardInterrupt: 

In [None]:
type(float(torch.tensor([1])))