In [1]:
# %matplotlib ipympl

import os
import glob
import torch
import tqdm
import time
import csv
from torch import nn
from torch.utils.data import DataLoader, ConcatDataset, random_split

from learning.data.process_utils import move_to_device
from learning.data.av2_utils import extract_input, collate_av2_data
from learning.data.av2_utils import extract_output_target
from learning.data.av2_dataset import ActionValueV2Dataset
from learning.model.actionvalue_v2 import ActionValueModelV2

In [2]:
# Reproducibility: seed torch's global RNG and build a dedicated generator
# (same seed) that is handed to the dataset split below.
seed = 21
torch.manual_seed(seed)
generator = torch.Generator()
generator.manual_seed(seed)

# Collect every history_*.json file in the history folder, in sorted order.
history_folder = "server/history/"
history_files = glob.glob(os.path.join(history_folder, "history_*.json"))
history_files.sort()

# One dataset per history file, exposed as a single concatenated dataset.
datasets = list(map(ActionValueV2Dataset, history_files))
concatenated_dataset = ConcatDataset(datasets)

# 80% / 10% / 10% train / validation / test split drawn with the seeded generator.
split_fractions = [0.8, 0.1, 0.1]
train_ds, val_ds, test_ds = random_split(
    concatenated_dataset, split_fractions, generator=generator
)

In [3]:
# Training batches are reshuffled on every pass, driven by the seeded generator.
train_loader = DataLoader(train_ds, batch_size=10000, pin_memory=True, shuffle=True,
                          collate_fn=collate_av2_data, generator=generator)

# Validation is iterated in a fixed order and does not touch the training
# generator. The validation loops are time-limited and may stop early, so a
# deterministic order keeps the evaluated subset the same from epoch to epoch
# and makes the validation losses comparable.
val_loader = DataLoader(val_ds, batch_size=20000, pin_memory=True, shuffle=False,
                        collate_fn=collate_av2_data)

# Report the size of each split.
print("train_ds", len(train_ds))
print("val_ds", len(val_ds))
print("test_ds", len(test_ds))

train_ds 578341
val_ds 72292
test_ds 72292


In [4]:
# Pull one collated batch from the training loader and keep it in `batch`
# (presumably for interactive inspection; no other visible cell reads it).
batch = next(iter(train_loader))

In [5]:
# Resume bookkeeping. These defaults describe a fresh run and are overwritten
# below when a usable checkpoint is found.
model_state = None
optim_state = None
latest_epoch = -1

# When True, later cells freeze every parameter outside the value heads.
freeze = False

model_name = "actionvalue_v2" + ("_no_freeze" if not freeze else "")
models_dir = f"learning/{model_name}_checkpoints"

os.makedirs(models_dir, exist_ok=True)

# Candidate checkpoint files: <models_dir>/<model_name>_<something>.pt
checkpoint_files = glob.glob(os.path.join(models_dir, f"{model_name}_*.pt"))

# Map epoch number -> path. The epoch is parsed from the file's basename only,
# and files whose suffix is not a plain decimal integer (e.g. "..._best.pt")
# are skipped instead of crashing int().
checkpoint_prefix = f"{model_name}_"
checkpoints_by_epoch = {}
for ckpt_path in checkpoint_files:
    stem = os.path.splitext(os.path.basename(ckpt_path))[0]
    epoch_str = stem[len(checkpoint_prefix):]
    if stem.startswith(checkpoint_prefix) and epoch_str.isdecimal():
        checkpoints_by_epoch[int(epoch_str)] = ckpt_path

if checkpoints_by_epoch:
    latest_epoch = max(checkpoints_by_epoch)
    # Use the path that was actually found on disk rather than rebuilding it,
    # so files written with a different zero-padding still load.
    latest_checkpoint = checkpoints_by_epoch[latest_epoch]
    print(f"Loading latest checkpoint: {latest_checkpoint} (epoch {latest_epoch})")
    checkpoint = torch.load(latest_checkpoint, map_location=torch.device('cpu'))
    model_state = checkpoint["model_state"]
    optim_state = checkpoint["optim_state"]
else:
    print("No valid checkpoint files found")

# Training continues with the epoch after the last saved one (0 for a fresh run).
start_epoch = latest_epoch + 1

Loading latest checkpoint: learning/actionvalue_v2_no_freeze_checkpoints/actionvalue_v2_no_freeze_000164.pt (epoch 164)


In [6]:
# Instantiate the action-value model with its constructor defaults; saved
# weights (if any) are loaded into it by later cells via load_state_dict.
action_value_model = ActionValueModelV2()

In [7]:
# Fall back to the legacy weights only when model_state is still None.
init_from_old_model = model_state is None

if init_from_old_model:
    print("⚠️⚠️⚠️ Warning, model is being initialized from an old checkpoint")

    checkpoint = torch.load(
        "learning/old_models/nodevalue_full.pt", map_location=torch.device('cpu')
    )
    model_state = checkpoint["model_state"]

    # Non-strict load: copy over whatever keys match, then re-read the model's
    # own state dict so model_state has exactly the new model's key set.
    mismatch = action_value_model.load_state_dict(model_state, strict=False)
    missing_keys, unexpected_keys = mismatch
    model_state = action_value_model.state_dict()

    # Report every key that did not line up: missing keys first, then unexpected.
    for label, keys in (("missing key", missing_keys), ("unexpected key", unexpected_keys)):
        for key in keys:
            print(label, key)

In [8]:
# Load the model state if available
if model_state is not None:
    action_value_model.load_state_dict(model_state)
    print("Model state loaded from checkpoint.")

if freeze:
    # Only attack_value_fn and end_turn_value_fn parameters stay trainable;
    # every other parameter is frozen.
    trainable_prefixes = ('attack_value_fn', 'end_turn_value_fn')
    for name, param in action_value_model.named_parameters():
        param.requires_grad = name.startswith(trainable_prefixes)

# Parameter-count summary.
trainable_params = sum(p.numel() for p in action_value_model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in action_value_model.parameters())
trainable_pct = trainable_params / total_params * 100 if total_params else 0.0
print(f"Model created. Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,} ({trainable_pct:.2f}%)")
print(f"Frozen parameters: {total_params - trainable_params:,}")

# Describe the freezing policy that was actually applied; this line used to be
# printed even when `freeze` was False.
if freeze:
    print("Only attack_value_fn and end_turn_value_fn modules are trainable.")
else:
    print("Freezing disabled: requires_grad flags are left as defined by the model.")

Model state loaded from checkpoint.
Model created. Total parameters: 8,363
Trainable parameters: 4,365 (52.19%)
Frozen parameters: 3,998
Only attack_value_fn and end_turn_value_fn modules are trainable.


In [9]:
# List every named parameter with its shape and whether it will receive
# gradients (requires_grad) or is frozen.
print("Detailed parameter trainable status:")
for name, param in action_value_model.named_parameters():
    status = "✓ TRAINABLE" if param.requires_grad else "✗ FROZEN"
    print(f"{status}: {name}, shape={param.shape}")

Detailed parameter trainable status:
✗ FROZEN: in_states_mean, shape=torch.Size([19])
✗ FROZEN: in_states_std, shape=torch.Size([19])
✓ TRAINABLE: gat_layers.0.passthrough_coef, shape=torch.Size([])
✓ TRAINABLE: gat_layers.0.message_heads.0.transform_fn.weight, shape=torch.Size([20, 19])
✓ TRAINABLE: gat_layers.0.message_heads.1.transform_fn.weight, shape=torch.Size([20, 19])
✓ TRAINABLE: gat_layers.0.attention_coef_heads.0.transform_fn.weight, shape=torch.Size([1, 40])
✓ TRAINABLE: gat_layers.0.attention_coef_heads.1.transform_fn.weight, shape=torch.Size([1, 40])
✗ FROZEN: gat_layers.0.in_states_pass.weight, shape=torch.Size([40, 19])
✓ TRAINABLE: gat_layers.1.passthrough_coef, shape=torch.Size([])
✓ TRAINABLE: gat_layers.1.message_heads.0.transform_fn.weight, shape=torch.Size([20, 40])
✓ TRAINABLE: gat_layers.1.message_heads.1.transform_fn.weight, shape=torch.Size([20, 40])
✓ TRAINABLE: gat_layers.1.attention_coef_heads.0.transform_fn.weight, shape=torch.Size([1, 40])
✓ TRAINABLE: ga

In [None]:
# ── CSV set-up: log one row per epoch ───────────────────────────────────────
csv_path = f"{model_name}_epoch_metrics.csv"
epoch_log_file_exists = os.path.isfile(csv_path)
# Treat a zero-byte log as missing so the header is written again.
if epoch_log_file_exists and os.path.getsize(csv_path) == 0:
    os.remove(csv_path)
    epoch_log_file_exists = False

# Append mode: rows from earlier runs are kept. The handle stays open across
# cells and is closed by the training cell.
csv_file = open(csv_path, "a", newline="")
csv_writer = csv.writer(csv_file)

# Only write header if file doesn't exist yet
if not epoch_log_file_exists:
    csv_writer.writerow([
        "epoch",           # 0-based epoch index
        "train_loss",      # average training loss for the epoch
        "val_loss",        # average validation loss for the epoch
        "train_time_sec",  # seconds spent in training phase
        "val_time_sec",    # seconds spent in validation phase
        "total_time_sec"   # train + val
    ])

# ── model / optimizer prep ─────────────────────────────────────────────────

n_epochs   = 10000
# Prefer the first CUDA device; fall back to CPU so the notebook still runs
# on machines without a GPU instead of failing at .to(device).
if torch.cuda.is_available():
    device = torch.device("cuda", 0)
else:
    device = torch.device("cpu")
    print("CUDA is not available; running on CPU.")
train_time_limit = 100 * 60             # seconds allowed for one training phase
val_time_limit = train_time_limit / 5   # seconds allowed for one validation phase
criterion  = nn.MSELoss()

action_value_model = action_value_model.to(device)

# Optimize only the parameters that are currently marked trainable.
optimizer  = torch.optim.Adam(
    [p for p in action_value_model.parameters() if p.requires_grad],
    lr=1e-3
)
# With reset_optimizer=True the optimizer state stored in the checkpoint is
# ignored and Adam starts from scratch.
reset_optimizer = True
if not reset_optimizer and optim_state is not None:
    optimizer.load_state_dict(optim_state)

# Multiply the LR by 0.25 after 10 epochs without improvement of the
# monitored value, never going below 1e-8.
scheduler  = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min",
    factor=0.25,
    patience=10,
    threshold=0.0001,
    min_lr=1e-8
)

# Statistics exposed by the model; passed to extract_output_target in the
# evaluation and training cells.
mean_val = action_value_model.mean_val
std_val = action_value_model.std_val

In [11]:
# ── INITIAL (untrained) LOSS EVALUATION ───────────────────────────────────────
def _timed_mean_loss(loader, desc, phase, time_limit, started_at):
    """Return the mean per-batch loss over `loader`, stopping early on a time budget.

    Uses the notebook-level `action_value_model`, `criterion`, `device`,
    `mean_val` and `std_val`. It does not change the model's train/eval mode
    or the grad mode; the caller sets those.

    Args:
        loader: iterable of batches accepted by `move_to_device`.
        desc: label shown on the tqdm progress bar.
        phase: word used in the early-stop message ("training" / "validation").
        time_limit: budget in seconds, measured from `started_at`.
        started_at: `time.time()` timestamp the budget is measured from.

    Raises:
        ValueError: if `loader` yields no batches (the mean is undefined).
    """
    loss_sum, n_batches = 0.0, 0
    for raw_batch in tqdm.tqdm(loader, desc=desc):
        dev_batch = move_to_device(raw_batch, device)
        model_out = action_value_model(*extract_input(dev_batch))
        pred, target = extract_output_target(dev_batch, model_out, mean_val, std_val)
        loss_sum += criterion(pred, target).item()
        n_batches += 1
        # The budget is checked after each batch, so at least one batch is scored.
        if (time.time() - started_at) > time_limit:
            print(f"Stopping {phase} after >{time_limit} seconds.")
            break
    if n_batches == 0:
        raise ValueError(f"{desc}: loader produced no batches")
    return loss_sum / n_batches


# Only run for a brand-new log file, so the baseline row is written once.
if not epoch_log_file_exists:
    action_value_model.eval()
    with torch.no_grad():
        t_start = time.time()

        # avg loss on training set *without* gradient tracking
        init_train_loss = _timed_mean_loss(
            train_loader, "init‑train", "training", train_time_limit, t_start
        )
        t_train_done = time.time()

        # avg loss on validation set
        init_val_loss = _timed_mean_loss(
            val_loader, "init‑val", "validation", val_time_limit, t_train_done
        )
        t_val_done = time.time()

    # times
    init_train_time = t_train_done - t_start
    init_val_time   = t_val_done - t_train_done
    init_total_time = t_val_done - t_start

    # write initial row (epoch = -1 marks the pre-training baseline)
    csv_writer.writerow([
        -1,
        init_train_loss,
        init_val_loss,
        init_train_time,
        init_val_time,
        init_total_time
    ])
    csv_file.flush()
    print(
        f"Initial (-1) | "
        f"Train {init_train_loss:.4f} | "
        f"Val {init_val_loss:.4f} | "
        f"Time {init_total_time:.1f}s (T {init_train_time:.1f}s | "
        f"V {init_val_time:.1f}s)"
    )

In [12]:
# Display the learning rate(s) currently reported by the scheduler, as a
# sanity check before the training loop starts.
scheduler.get_last_lr()

[0.001]

In [13]:
# Main loop. Each epoch runs a time-limited training phase and a time-limited
# validation phase, then steps the LR scheduler, saves a checkpoint and
# appends one CSV row. The try/finally guarantees the CSV handle is closed
# even when the run is interrupted (e.g. KeyboardInterrupt) or raises.
try:
    for epoch in range(start_epoch, n_epochs):
        t_start = time.time()

        # ── TRAINING -----------------------------------------------------------
        action_value_model.train()
        sum_train_loss = torch.tensor(0.0, device=device)
        n_train_batches = 0

        with tqdm.tqdm(train_loader, desc=f"train {epoch}") as t_batch_tqdm:
            for t_batch in t_batch_tqdm:
                t_batch = move_to_device(t_batch, device)

                optimizer.zero_grad()

                model_out = action_value_model(*extract_input(t_batch))
                outputs, targets = extract_output_target(t_batch, model_out, mean_val, std_val)
                loss = criterion(outputs, targets)

                # Abort BEFORE backward/step so a non-finite loss cannot
                # corrupt the weights or the optimizer state.
                if not torch.isfinite(loss):
                    raise ValueError("Loss contains NaN or Inf values")

                loss.backward()
                optimizer.step()

                sum_train_loss += loss.detach()
                n_train_batches += 1
                t_batch_tqdm.set_postfix(av_loss=sum_train_loss.item()/n_train_batches)
                # Time budget is checked after each batch, measured from epoch start.
                if (time.time() - t_start) > train_time_limit:
                    print(f"Stopping training after >{train_time_limit} seconds.")
                    break

        avg_train_loss = sum_train_loss.item() / n_train_batches
        t_train_done = time.time()

        # ── VALIDATION ---------------------------------------------------------
        action_value_model.eval()
        # Tensor accumulator (same as training) so .item() is always valid.
        sum_val_loss = torch.tensor(0.0, device=device)
        n_val_batches = 0
        with torch.no_grad():
            with tqdm.tqdm(val_loader, desc=f"val {epoch}") as v_batch_tqdm:
                for v_batch in v_batch_tqdm:
                    v_batch = move_to_device(v_batch, device)

                    v_out = action_value_model(*extract_input(v_batch))
                    v_outputs, v_targets = extract_output_target(v_batch, v_out, mean_val, std_val)
                    val_loss = criterion(v_outputs, v_targets)

                    if not torch.isfinite(val_loss):
                        raise ValueError("Loss contains NaN or Inf values")

                    sum_val_loss += val_loss
                    n_val_batches += 1
                    v_batch_tqdm.set_postfix(av_loss=sum_val_loss.item()/n_val_batches)
                    # Validation budget is measured from the end of training.
                    if (time.time() - t_train_done) > val_time_limit:
                        print(f"Stopping validation after >{val_time_limit} seconds.")
                        break

        avg_val_loss = sum_val_loss.item() / n_val_batches
        t_val_done = time.time()

        # Step the learning rate scheduler after each epoch
        scheduler.step(avg_val_loss)

        # ── timing ---------------------------------------------------------------
        train_time = t_train_done - t_start
        val_time   = t_val_done   - t_train_done
        total_time = t_val_done   - t_start

        # ── checkpoint: one file per epoch, named by zero-padded epoch ----------
        checkpoint_data = {
            "epoch": epoch,
            "optim_state": optimizer.state_dict(),
            "model_state": action_value_model.state_dict()
        }
        torch.save(checkpoint_data, os.path.join(models_dir, f"{model_name}_{epoch:06}.pt"))

        # ── CSV logging ---------------------------------------------------------
        csv_writer.writerow([
            epoch,
            avg_train_loss,
            avg_val_loss,
            train_time,
            val_time,
            total_time
        ])
        csv_file.flush()  # ensure data is written even if run aborts

        # ── console printout ----------------------------------------------------
        print(
            f"Epoch {epoch} | "
            f"Train {avg_train_loss:.4f} | "
            f"Val {avg_val_loss:.4f} | "
            f"LR {scheduler.get_last_lr()[0]}"
        )
finally:
    # ── tidy-up ----------------------------------------------------------------
    csv_file.close()

  max_per_index.index_reduce_(
train 165:  52%|█████▏    | 30/58 [10:15<09:34, 20.51s/it, av_loss=0.753]


Stopping training after >600 seconds.


val 165: 100%|██████████| 4/4 [01:52<00:00, 28.22s/it, av_loss=0.697]


Epoch 165 | Train 0.7530 | Val 0.6975 | LR 0.001


train 166:  53%|█████▎    | 31/58 [10:16<08:56, 19.87s/it, av_loss=0.752]


Stopping training after >600 seconds.


val 166:  75%|███████▌  | 3/4 [02:00<00:40, 40.21s/it, av_loss=0.694]


Stopping validation after >120.0 seconds.
Epoch 166 | Train 0.7516 | Val 0.6935 | LR 0.001


train 167:  55%|█████▌    | 32/58 [10:05<08:12, 18.93s/it, av_loss=0.752]


Stopping training after >600 seconds.


val 167: 100%|██████████| 4/4 [01:55<00:00, 28.88s/it, av_loss=0.697]


Epoch 167 | Train 0.7525 | Val 0.6971 | LR 0.001


train 168:  55%|█████▌    | 32/58 [10:18<08:22, 19.32s/it, av_loss=0.753]


Stopping training after >600 seconds.


val 168: 100%|██████████| 4/4 [01:51<00:00, 27.79s/it, av_loss=0.696]


Epoch 168 | Train 0.7533 | Val 0.6960 | LR 0.001


train 169:  53%|█████▎    | 31/58 [10:10<08:51, 19.69s/it, av_loss=0.775]


Stopping training after >600 seconds.


val 169: 100%|██████████| 4/4 [01:52<00:00, 28.06s/it, av_loss=0.695]


Epoch 169 | Train 0.7752 | Val 0.6946 | LR 0.001


train 170:  17%|█▋        | 10/58 [03:27<16:33, 20.70s/it, av_loss=0.755]


KeyboardInterrupt: 

In [None]:
# Display the `outputs` / `targets` pair left over from the most recent
# training batch processed by the training loop.
outputs, targets