
#
# Transformer-based m-Height Predictor



## 0) Setup & Imports

In [None]:
# !pip install --quiet torch numpy scipy tqdm

import os, sys, math, pickle, random
import numpy as np
from typing import Any, List, Optional

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import OneCycleLR
from tqdm.auto import tqdm

print("Python:", sys.version)
print("NumPy:", np.__version__)
print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


## 1) Configuration

In [None]:
# Reproducibility
SEED = 1337
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Problem constants
N_FIXED = 9
K_VALUES = [4,5,6]
MAX_K = max(K_VALUES)          # pad rows to 6
MAX_PARITY = N_FIXED - min(K_VALUES)  # = 5

# Model
DMODEL = 128
NHEAD  = 4
ENC_FF = 384
ENC_LAYERS = 3
DROPOUT = 0.10

# Training
BATCH_SIZE = 128
EPOCHS = 120
VAL_SPLIT = 0.20
MAX_LR = 2e-3
WEIGHT_DECAY = 2e-4   # L2 via AdamW (decoupled)

# L1 regularization - Applied to weights only (no bias/LayerNorm).
L1_COEFF = 1e-6       # 0 (off), 1e-6, 3e-6, 1e-5

PATIENCE = 20         # early stopping

FLIP_PERM_FACTOR = 0

# ---- File locations ----
CAND_IN = [
    "train_n_k_m_P.pkl", "/mnt/data/train_n_k_m_P.pkl", "/content/train_n_k_m_P.pkl"
]
CAND_OUT = [
    "train_mHeights.pkl", "/mnt/data/train_mHeights.pkl", "/content/train_mHeights.pkl"
]

# Augmented merged dataset produced offline (preferred when available)
AUG_IN_MANUAL_PATH  = None  # e.g., "/content/drive/MyDrive/CSCE636/merged_n_k_m_P.pkl"
AUG_OUT_MANUAL_PATH = None  # e.g., "/content/drive/MyDrive/CSCE636/merged_mHeights.pkl"
CAND_AUG_IN = [
    "/content/combined_final_n_k_m_P.pkl", "merged_n_k_m_P.pkl", "/mnt/data/merged_n_k_m_P.pkl", "/content/combined_ALL_n_k_m_P_exact.pkl", "/content/merged_n_k_m_P.pkl",
    "/content/drive/MyDrive/CSCE636/merged_n_k_m_P.pkl", "/content/combined_ALL_n_k_m_P.pkl"
]
CAND_AUG_OUT = [
    "/content/combined_final_mHeights.pkl", "merged_mHeights.pkl", "/content/combined_ALL_mHeights_exact.pkl", "/mnt/data/merged_mHeights.pkl", "/content/merged_mHeights.pkl",
    "/content/drive/MyDrive/CSCE636/merged_mHeights.pkl", "/content/combined_ALL_mHeights.pkl"
]

def find_first(paths: List[str]) -> Optional[str]:
    for p in paths:
        if os.path.exists(p):
            return p
    return None


## 2) Loaders & Preprocessing

In [None]:
# Shim for certain old pickles
import numpy as _np, sys as _sys
_sys.modules.setdefault("numpy._core", _np.core)

# -------------------------------
# 1) Parsing + IO helpers
# -------------------------------

def parse_inputs(obj: Any):
    """
    Return lists: n_list, k_list, m_list, P_list
    Handles several legacy formats used in the course.
    """
    # Dict-of-arrays format
    if isinstance(obj, dict) and all(k in obj for k in ("n", "k", "m", "P")):
        return list(obj["n"]), list(obj["k"]), list(obj["m"]), list(obj["P"])

    # Tuple (n_list, k_list, m_list, P_list)
    if isinstance(obj, (list, tuple)) and len(obj) == 4 and all(
        isinstance(x, (list, tuple)) for x in obj[:3]
    ):
        n_list, k_list, m_list, P_list = obj
        return list(n_list), list(k_list), list(m_list), list(P_list)

    # List of tuples [ (n,k,m,P), ... ]
    if isinstance(obj, list) and obj and isinstance(obj[0], tuple):
        return (
            [int(r[0]) for r in obj],
            [int(r[1]) for r in obj],
            [int(r[2]) for r in obj],
            [r[3]       for r in obj],
        )

    # List of lists [ [n,k,m,P], ... ]
    if isinstance(obj, list) and obj and isinstance(obj[0], list):
        if not all(isinstance(r, list) and len(r) == 4 for r in obj):
            raise ValueError("List-of-lists present but inner length != 4 [n,k,m,P].")
        return (
            [int(r[0]) for r in obj],
            [int(r[1]) for r in obj],
            [int(r[2]) for r in obj],
            [r[3]       for r in obj],
        )

    head = obj[0] if isinstance(obj, (list, tuple)) and obj else None
    raise ValueError(
        f"Unrecognized input format: type(obj)={type(obj)}, "
        f"type(first)={type(head) if head is not None else None}"
    )


def load_pair(in_path: str, out_path: str):
    """Load (n,k,m,P) inputs and m-height targets from two pickles."""
    with open(in_path,  "rb") as f:
        inputs = pickle.load(f)
    with open(out_path, "rb") as f:
        targets = pickle.load(f)

    n_list, k_list, m_list, P_list = parse_inputs(inputs)
    if len(P_list) != len(targets):
        raise ValueError(
            f"Length mismatch inputs={len(P_list)} vs targets={len(targets)}"
        )
    return n_list, k_list, m_list, P_list, list(map(float, targets))


def build_samples(n_list, k_list, m_list, P_list, y_list=None, norm_factor=None):
    """
    Turn raw (n,k,m,P) into:
        Ps  : list of (MAX_K, n-k) float32 matrices (rows padded to MAX_K)
        Ks  : list[int]
        Ms  : list[int]
        Ns  : list[int]
        ys_log2 : np.ndarray of log2(m-height)
        norm_factor : scalar used to normalize all P's
    """
    Ps = []
    for i in range(len(P_list)):
        n = int(n_list[i])
        k = int(k_list[i])
        m = int(m_list[i])

        P = np.asarray(P_list[i], dtype=float).reshape(k, n - k)
        # pad rows up to MAX_K
        if k < MAX_K:
            P = np.vstack([P, np.zeros((MAX_K - k, n - k), dtype=float)])
        Ps.append(P)

    # one global normalization factor across whole set
    if norm_factor is None:
        max_abs = max(np.max(np.abs(P)) if P.size else 0.0 for P in Ps)
        norm_factor = max_abs if max_abs > 0 else 1.0

    Ps = [P.astype(np.float32) / norm_factor for P in Ps]

    ys_log2 = None
    if y_list is not None:
        y = np.asarray(y_list, dtype=float)
        y = np.clip(y, 1e-8, None)           # safety against log(0)
        ys_log2 = np.log2(y).astype(np.float32)

    return Ps, list(map(int, k_list)), list(map(int, m_list)), list(map(int, n_list)), ys_log2, norm_factor


# -------------------------------
# 2) Hacky flip+perm augmentation
# -------------------------------

def random_flip_and_permute_P(P: np.ndarray, k: int, n: int, rng: np.random.Generator):
    """
    'Hacky' symmetry-based augmentation:
      - independently flip the sign of each *row* (first k rows)
      - randomly permute the parity columns (n-k columns of P)

    P is expected to be (MAX_K, n-k) with row padding.
    n,k are the original code parameters for this sample.
    """
    cols = n - k
    base = P[:k, :cols]  # the actual (k, n-k) parity part

    # random ±1 per row
    signs = rng.choice(np.array([-1.0, 1.0], dtype=np.float32), size=(k, 1))
    base = base * signs

    # random permutation of parity columns
    perm = rng.permutation(cols)
    base = base[:, perm]

    # re-embed into padded matrix
    P_new = np.zeros_like(P)
    P_new[:k, :cols] = base
    return P_new


# -------------------------------
# 3) Choose dataset (original vs merged)
# -------------------------------

AUG_IN_PATH  = AUG_IN_MANUAL_PATH  or find_first(CAND_AUG_IN)
AUG_OUT_PATH = AUG_OUT_MANUAL_PATH or find_first(CAND_AUG_OUT)
IN_PATH      = find_first(CAND_IN)
OUT_PATH     = find_first(CAND_OUT)

if AUG_IN_PATH and AUG_OUT_PATH:
    print("Using AUGMENTED merged dataset:")
    print("  IN :", AUG_IN_PATH)
    print("  OUT:", AUG_OUT_PATH)
    n_list, k_list, m_list, P_list, y_list = load_pair(AUG_IN_PATH, AUG_OUT_PATH)
else:
    if IN_PATH is None or OUT_PATH is None:
        raise FileNotFoundError(
            "Could not find input/target pickles. Tried:\n"
            f"  AUG_IN  = {CAND_AUG_IN}\n"
            f"  AUG_OUT = {CAND_AUG_OUT}\n"
            f"  IN      = {CAND_IN}\n"
            f"  OUT     = {CAND_OUT}"
        )
    print("Using ORIGINAL training dataset:")
    print("  IN :", IN_PATH)
    print("  OUT:", OUT_PATH)
    n_list, k_list, m_list, P_list, y_list = load_pair(IN_PATH, OUT_PATH)

# Base samples (no hacky augmentation yet)
Ps, Ks, Ms, Ns, Y_log2, NORM = build_samples(
    n_list, k_list, m_list, P_list, y_list, norm_factor=None
)

print(f"Base samples: {len(Ps)} | Norm factor: {NORM:.6g}")

# -------------------------------
# 4) Optional flip+perm augmentation
# -------------------------------

if FLIP_PERM_FACTOR and FLIP_PERM_FACTOR > 0:
    rng = np.random.default_rng(SEED + 2025)

    extra_Ps, extra_Ks, extra_Ms, extra_Ns, extra_Y = [], [], [], [], []

    for i in range(len(Ps)):
        P0 = Ps[i]
        k  = Ks[i]
        n  = Ns[i]
        y0 = Y_log2[i]

        for _ in range(FLIP_PERM_FACTOR):
            P_aug = random_flip_and_permute_P(P0, k, n, rng)
            extra_Ps.append(P_aug)
            extra_Ks.append(k)
            extra_Ms.append(Ms[i])
            extra_Ns.append(n)
            extra_Y.append(y0)

    if extra_Ps:
        Ps.extend(extra_Ps)
        Ks.extend(extra_Ks)
        Ms.extend(extra_Ms)
        Ns.extend(extra_Ns)
        Y_log2 = np.concatenate(
            [Y_log2, np.array(extra_Y, dtype=Y_log2.dtype)], axis=0
        )

    print(
        f"Hacky flip+perm augmentation ENABLED: factor={FLIP_PERM_FACTOR} "
        f"→ total samples={len(Ps)}"
    )
else:
    print("Hacky flip+perm augmentation DISABLED.")

# -------------------------------
# 5) Train/val split
# -------------------------------

N = len(Ps)
idx = np.random.permutation(N)
val_sz = int(VAL_SPLIT * N)
val_idx, tr_idx = idx[:val_sz], idx[val_sz:]

data_train = [(Ps[i], Ks[i], Ms[i], Ns[i], float(Y_log2[i])) for i in tr_idx]
data_val   = [(Ps[i], Ks[i], Ms[i], Ns[i], float(Y_log2[i])) for i in val_idx]

print(f"Train: {len(data_train)} | Val: {len(data_val)} | Norm factor: {NORM:.6g}")
print("Example shapes:", Ps[0].shape, " (rows padded to MAX_K, cols = n-k)")


## 3) Dataset & Collate

In [None]:
class MHeightDataset(Dataset):
    def __init__(self, data, with_targets=True):
        self.data = data
        self.with_targets = with_targets
    def __len__(self): return len(self.data)
    def __getitem__(self, i):
        P, k, m, n, ylog2 = self.data[i]
        tokens = torch.tensor(P, dtype=torch.float32).t().contiguous()  # (L, MAX_K)
        k_idx = torch.tensor({4:0,5:1,6:2}[k], dtype=torch.long)
        m_idx = torch.tensor(m-2, dtype=torch.long)
        parity = torch.tensor(n-k, dtype=torch.long)
        if self.with_targets:
            y = torch.tensor(ylog2, dtype=torch.float32)
            return tokens, k_idx, m_idx, parity, y
        else:
            return tokens, k_idx, m_idx, parity

def collate_batch(batch):
    has_targets = (len(batch[0]) == 5)
    seq_lens = [b[0].shape[0] for b in batch]
    max_len = max(seq_lens)
    toks = []
    for b in batch:
        t = b[0]
        if t.shape[0] < max_len:
            pad = torch.zeros((max_len - t.shape[0], t.shape[1]), dtype=t.dtype)
            t = torch.cat([t, pad], dim=0)
        toks.append(t)
    toks = torch.stack(toks, dim=0)  # (B, L, MAX_K)
    kpm = torch.ones((len(batch), max_len), dtype=torch.bool)
    for i,L in enumerate(seq_lens): kpm[i,:L] = False
    k_idx = torch.stack([b[1] for b in batch])
    m_idx = torch.stack([b[2] for b in batch])
    parity = torch.stack([b[3] for b in batch])
    if has_targets:
        y = torch.stack([b[4] for b in batch])
        return toks, kpm, k_idx, m_idx, parity, y
    return toks, kpm, k_idx, m_idx, parity

train_ds = MHeightDataset(data_train, with_targets=True)
val_ds   = MHeightDataset(data_val, with_targets=True)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  collate_fn=collate_batch)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)


## 4) Model (Transformer + Conditioning)

In [None]:
class PositionalIndexEmbedding(nn.Module):
    def __init__(self, max_len: int, d_model: int):
        super().__init__()
        self.emb = nn.Embedding(max_len, d_model)
    def forward(self, L: int, B: int, device):
        idx = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
        return self.emb(idx)

class MHeightTransformer(nn.Module):
    def __init__(self, d_model=128, nhead=4, ff=384, nlayers=3, dropout=0.1):
        super().__init__()
        self.token_proj = nn.Linear(MAX_K, d_model)
        self.pos = PositionalIndexEmbedding(MAX_PARITY+5, d_model)
        enc_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=ff,
                                               dropout=dropout, activation="gelu",
                                               batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=nlayers)
        # conditioning embeddings
        self.k_emb = nn.Embedding(3, d_model//4)
        self.m_emb = nn.Embedding(4, d_model//4)
        self.p_emb = nn.Embedding(6, d_model//4)
        head_in = d_model + 3*(d_model//4)
        layers = []
        width = [256, 128]
        in_dim = head_in
        for h in width:
            layers += [nn.Linear(in_dim, h), nn.GELU(), nn.Dropout(dropout)]
            in_dim = h
        layers += [nn.Linear(in_dim, 1)]
        self.head = nn.Sequential(*layers)

    def forward(self, toks, kpm, k_idx, m_idx, parity):
        B, L, _ = toks.shape
        x = self.token_proj(toks) + self.pos(L, B, toks.device)
        enc = self.encoder(x, src_key_padding_mask=kpm)
        mask = (~kpm).float().unsqueeze(-1)
        pooled = (enc * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        z = torch.cat([pooled, self.k_emb(k_idx), self.m_emb(m_idx.clamp(0,3)), self.p_emb(parity.clamp(0,5))], dim=-1)
        out = self.head(z).squeeze(-1)
        return out


## 5) Train with L1/L2 Regularization

In [9]:
def build_optimizer_with_decay(model: nn.Module, lr: float, weight_decay: float):
    # Decoupled weight decay (AdamW). Exclude bias and LayerNorm from decay.
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        is_bias = name.endswith("bias")
        is_norm = "norm" in name.lower()
        if is_bias or is_norm:
            no_decay.append(p)
        else:
            decay.append(p)
    param_groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    return AdamW(param_groups, lr=lr)

criterion = nn.MSELoss()
model = MHeightTransformer(d_model=DMODEL, nhead=NHEAD, ff=ENC_FF, nlayers=ENC_LAYERS, dropout=DROPOUT).to(DEVICE)
optimizer = build_optimizer_with_decay(model, lr=MAX_LR, weight_decay=WEIGHT_DECAY)

steps_per_epoch = max(1, len(train_loader))
scheduler = OneCycleLR(optimizer, max_lr=MAX_LR,
                       epochs=EPOCHS, steps_per_epoch=steps_per_epoch,
                       pct_start=0.15, anneal_strategy='cos', div_factor=10.0, final_div_factor=1e3)

best_val = float('inf'); best_state = None; no_improve = 0

for epoch in range(1, EPOCHS+1):
    model.train()
    running = 0.0
    for toks, kpm, kidx, midx, parity, y in tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}", leave=False):
        toks, kpm = toks.to(DEVICE), kpm.to(DEVICE)
        kidx, midx, parity, y = kidx.to(DEVICE), midx.to(DEVICE), parity.to(DEVICE), y.to(DEVICE)

        optimizer.zero_grad(set_to_none=True)
        yhat = model(toks, kpm, kidx, midx, parity)
        loss = criterion(yhat, y)

        # ---- L1 penalty on weights (no bias/LayerNorm) ----
        if L1_COEFF and L1_COEFF > 0:
            l1 = 0.0
            for name, p in model.named_parameters():
                if p.requires_grad and (not name.endswith("bias")) and ("norm" not in name.lower()):
                    l1 = l1 + p.abs().sum()
            loss = loss + L1_COEFF * l1

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step(); scheduler.step()
        running += loss.item() * y.size(0)
    train_loss = running / len(train_ds)

    # validation
    model.eval(); vloss = 0.0
    with torch.no_grad():
        for toks, kpm, kidx, midx, parity, y in val_loader:
            toks, kpm = toks.to(DEVICE), kpm.to(DEVICE)
            kidx, midx, parity, y = kidx.to(DEVICE), midx.to(DEVICE), parity.to(DEVICE), y.to(DEVICE)
            yhat = model(toks, kpm, kidx, midx, parity)
            vloss += criterion(yhat, y).item() * y.size(0)
    vloss /= len(val_ds)

    print(f"Epoch {epoch:03d} | train log2-MSE: {train_loss:.6f} | val log2-MSE: {vloss:.6f}")

    if vloss + 1e-9 < best_val:
        best_val = vloss
        best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
        no_improve = 0
    else:
        no_improve += 1
    if no_improve >= PATIENCE:
        print(f"Early stop @ epoch {epoch}. Best val: {best_val:.6f}")
        break

if best_state is not None:
    model.load_state_dict(best_state)

os.makedirs("checkpoints", exist_ok=True)
torch.save(model.state_dict(), "checkpoints/mheight_transformer.pt")
print("Saved best model. Best val log2-MSE:", best_val)




Epoch 1/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 001 | train log2-MSE: 1.691181 | val log2-MSE: 1.226509


Epoch 2/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 002 | train log2-MSE: 1.343664 | val log2-MSE: 1.171784


Epoch 3/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 003 | train log2-MSE: 1.306959 | val log2-MSE: 1.178466


Epoch 4/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 004 | train log2-MSE: 1.287871 | val log2-MSE: 1.159498


Epoch 5/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 005 | train log2-MSE: 1.269487 | val log2-MSE: 1.154829


Epoch 6/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 006 | train log2-MSE: 1.254899 | val log2-MSE: 1.141981


Epoch 7/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 007 | train log2-MSE: 1.240422 | val log2-MSE: 1.118669


Epoch 8/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 008 | train log2-MSE: 1.227395 | val log2-MSE: 1.123683


Epoch 9/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 009 | train log2-MSE: 1.227400 | val log2-MSE: 1.129638


Epoch 10/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 010 | train log2-MSE: 1.214579 | val log2-MSE: 1.124384


Epoch 11/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 011 | train log2-MSE: 1.214362 | val log2-MSE: 1.095639


Epoch 12/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 012 | train log2-MSE: 1.212807 | val log2-MSE: 1.204847


Epoch 13/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 013 | train log2-MSE: 1.209946 | val log2-MSE: 1.137499


Epoch 14/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 014 | train log2-MSE: 1.209539 | val log2-MSE: 1.114527


Epoch 15/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 015 | train log2-MSE: 1.202006 | val log2-MSE: 1.100475


Epoch 16/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 016 | train log2-MSE: 1.197731 | val log2-MSE: 1.111709


Epoch 17/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 017 | train log2-MSE: 1.196698 | val log2-MSE: 1.100588


Epoch 18/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 018 | train log2-MSE: 1.197470 | val log2-MSE: 1.113078


Epoch 19/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 019 | train log2-MSE: 1.195276 | val log2-MSE: 1.090928


Epoch 20/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 020 | train log2-MSE: 1.194513 | val log2-MSE: 1.092041


Epoch 21/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 021 | train log2-MSE: 1.190752 | val log2-MSE: 1.085954


Epoch 22/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 022 | train log2-MSE: 1.189288 | val log2-MSE: 1.130867


Epoch 23/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 023 | train log2-MSE: 1.188947 | val log2-MSE: 1.098257


Epoch 24/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 024 | train log2-MSE: 1.184003 | val log2-MSE: 1.110527


Epoch 25/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 025 | train log2-MSE: 1.186385 | val log2-MSE: 1.082235


Epoch 26/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 026 | train log2-MSE: 1.183302 | val log2-MSE: 1.072493


Epoch 27/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 027 | train log2-MSE: 1.180385 | val log2-MSE: 1.070695


Epoch 28/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 028 | train log2-MSE: 1.177412 | val log2-MSE: 1.124457


Epoch 29/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 029 | train log2-MSE: 1.176396 | val log2-MSE: 1.081032


Epoch 30/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 030 | train log2-MSE: 1.175147 | val log2-MSE: 1.101289


Epoch 31/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 031 | train log2-MSE: 1.176008 | val log2-MSE: 1.098562


Epoch 32/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 032 | train log2-MSE: 1.171083 | val log2-MSE: 1.060892


Epoch 33/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 033 | train log2-MSE: 1.173887 | val log2-MSE: 1.061494


Epoch 34/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 034 | train log2-MSE: 1.169499 | val log2-MSE: 1.056547


Epoch 35/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 035 | train log2-MSE: 1.169070 | val log2-MSE: 1.080453


Epoch 36/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 036 | train log2-MSE: 1.168344 | val log2-MSE: 1.066411


Epoch 37/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 037 | train log2-MSE: 1.163200 | val log2-MSE: 1.055168


Epoch 38/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 038 | train log2-MSE: 1.164024 | val log2-MSE: 1.066595


Epoch 39/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 039 | train log2-MSE: 1.157650 | val log2-MSE: 1.060526


Epoch 40/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 040 | train log2-MSE: 1.157784 | val log2-MSE: 1.061946


Epoch 41/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 041 | train log2-MSE: 1.156965 | val log2-MSE: 1.057949


Epoch 42/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 042 | train log2-MSE: 1.153524 | val log2-MSE: 1.056440


Epoch 43/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 043 | train log2-MSE: 1.149623 | val log2-MSE: 1.065647


Epoch 44/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 044 | train log2-MSE: 1.151669 | val log2-MSE: 1.077012


Epoch 45/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 045 | train log2-MSE: 1.146395 | val log2-MSE: 1.048703


Epoch 46/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 046 | train log2-MSE: 1.143639 | val log2-MSE: 1.043450


Epoch 47/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 047 | train log2-MSE: 1.149704 | val log2-MSE: 1.027817


Epoch 48/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 048 | train log2-MSE: 1.140806 | val log2-MSE: 1.043407


Epoch 49/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 049 | train log2-MSE: 1.135632 | val log2-MSE: 1.024222


Epoch 50/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 050 | train log2-MSE: 1.131796 | val log2-MSE: 1.024021


Epoch 51/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 051 | train log2-MSE: 1.130150 | val log2-MSE: 1.026677


Epoch 52/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 052 | train log2-MSE: 1.125019 | val log2-MSE: 1.017838


Epoch 53/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 053 | train log2-MSE: 1.128358 | val log2-MSE: 1.030603


Epoch 54/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 054 | train log2-MSE: 1.124358 | val log2-MSE: 1.039341


Epoch 55/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 055 | train log2-MSE: 1.129417 | val log2-MSE: 0.993789


Epoch 56/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 056 | train log2-MSE: 1.115197 | val log2-MSE: 1.008011


Epoch 57/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 057 | train log2-MSE: 1.111481 | val log2-MSE: 0.993334


Epoch 58/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 058 | train log2-MSE: 1.112727 | val log2-MSE: 0.991980


Epoch 59/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 059 | train log2-MSE: 1.108501 | val log2-MSE: 0.995567


Epoch 60/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 060 | train log2-MSE: 1.107160 | val log2-MSE: 0.988225


Epoch 61/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 061 | train log2-MSE: 1.105149 | val log2-MSE: 0.979426


Epoch 62/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 062 | train log2-MSE: 1.099592 | val log2-MSE: 0.986977


Epoch 63/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 063 | train log2-MSE: 1.094780 | val log2-MSE: 0.959329


Epoch 64/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 064 | train log2-MSE: 1.097456 | val log2-MSE: 0.963041


Epoch 65/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 065 | train log2-MSE: 1.090734 | val log2-MSE: 0.976719


Epoch 66/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 066 | train log2-MSE: 1.082943 | val log2-MSE: 0.960661


Epoch 67/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 067 | train log2-MSE: 1.083150 | val log2-MSE: 0.959158


Epoch 68/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 068 | train log2-MSE: 1.077199 | val log2-MSE: 0.941541


Epoch 69/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 069 | train log2-MSE: 1.074917 | val log2-MSE: 0.930922


Epoch 70/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 070 | train log2-MSE: 1.065723 | val log2-MSE: 0.953233


Epoch 71/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 071 | train log2-MSE: 1.062327 | val log2-MSE: 0.917886


Epoch 72/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 072 | train log2-MSE: 1.057739 | val log2-MSE: 0.939030


Epoch 73/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 073 | train log2-MSE: 1.049188 | val log2-MSE: 0.935440


Epoch 74/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 074 | train log2-MSE: 1.051463 | val log2-MSE: 0.923981


Epoch 75/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 075 | train log2-MSE: 1.044891 | val log2-MSE: 0.896907


Epoch 76/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 076 | train log2-MSE: 1.036558 | val log2-MSE: 0.891446


Epoch 77/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 077 | train log2-MSE: 1.031111 | val log2-MSE: 0.888737


Epoch 78/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 078 | train log2-MSE: 1.025547 | val log2-MSE: 0.892136


Epoch 79/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 079 | train log2-MSE: 1.030056 | val log2-MSE: 0.869472


Epoch 80/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 080 | train log2-MSE: 1.017312 | val log2-MSE: 0.854314


Epoch 81/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 081 | train log2-MSE: 1.010864 | val log2-MSE: 0.854779


Epoch 82/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 082 | train log2-MSE: 1.000363 | val log2-MSE: 0.845816


Epoch 83/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 083 | train log2-MSE: 1.002752 | val log2-MSE: 0.822705


Epoch 84/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 084 | train log2-MSE: 0.995154 | val log2-MSE: 0.825701


Epoch 85/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 085 | train log2-MSE: 0.993970 | val log2-MSE: 0.822953


Epoch 86/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 086 | train log2-MSE: 0.983557 | val log2-MSE: 0.809155


Epoch 87/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 087 | train log2-MSE: 0.979447 | val log2-MSE: 0.789954


Epoch 88/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 088 | train log2-MSE: 0.967577 | val log2-MSE: 0.793226


Epoch 89/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 089 | train log2-MSE: 0.964555 | val log2-MSE: 0.778369


Epoch 90/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 090 | train log2-MSE: 0.961110 | val log2-MSE: 0.777467


Epoch 91/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 091 | train log2-MSE: 0.956313 | val log2-MSE: 0.781793


Epoch 92/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 092 | train log2-MSE: 0.949721 | val log2-MSE: 0.766142


Epoch 93/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 093 | train log2-MSE: 0.947801 | val log2-MSE: 0.756151


Epoch 94/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 094 | train log2-MSE: 0.936592 | val log2-MSE: 0.752701


Epoch 95/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 095 | train log2-MSE: 0.933795 | val log2-MSE: 0.726166


Epoch 96/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 096 | train log2-MSE: 0.927217 | val log2-MSE: 0.738600


Epoch 97/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 097 | train log2-MSE: 0.925993 | val log2-MSE: 0.727715


Epoch 98/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 098 | train log2-MSE: 0.917210 | val log2-MSE: 0.715357


Epoch 99/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 099 | train log2-MSE: 0.915032 | val log2-MSE: 0.716388


Epoch 100/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 100 | train log2-MSE: 0.907373 | val log2-MSE: 0.706444


Epoch 101/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 101 | train log2-MSE: 0.903107 | val log2-MSE: 0.693471


Epoch 102/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 102 | train log2-MSE: 0.900703 | val log2-MSE: 0.690993


Epoch 103/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 103 | train log2-MSE: 0.895218 | val log2-MSE: 0.690954


Epoch 104/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 104 | train log2-MSE: 0.892927 | val log2-MSE: 0.690037


Epoch 105/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 105 | train log2-MSE: 0.888145 | val log2-MSE: 0.674194


Epoch 106/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 106 | train log2-MSE: 0.886864 | val log2-MSE: 0.677780


Epoch 107/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 107 | train log2-MSE: 0.883084 | val log2-MSE: 0.672916


Epoch 108/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 108 | train log2-MSE: 0.880867 | val log2-MSE: 0.669286


Epoch 109/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 109 | train log2-MSE: 0.877364 | val log2-MSE: 0.658925


Epoch 110/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 110 | train log2-MSE: 0.877656 | val log2-MSE: 0.658415


Epoch 111/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 111 | train log2-MSE: 0.873569 | val log2-MSE: 0.662076


Epoch 112/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 112 | train log2-MSE: 0.875034 | val log2-MSE: 0.654059


Epoch 113/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 113 | train log2-MSE: 0.869832 | val log2-MSE: 0.658953


Epoch 114/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 114 | train log2-MSE: 0.868991 | val log2-MSE: 0.655856


Epoch 115/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 115 | train log2-MSE: 0.870648 | val log2-MSE: 0.654114


Epoch 116/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 116 | train log2-MSE: 0.868037 | val log2-MSE: 0.653998


Epoch 117/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 117 | train log2-MSE: 0.866238 | val log2-MSE: 0.653608


Epoch 118/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 118 | train log2-MSE: 0.867447 | val log2-MSE: 0.652667


Epoch 119/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 119 | train log2-MSE: 0.864310 | val log2-MSE: 0.652224


Epoch 120/120:   0%|          | 0/14175 [00:00<?, ?it/s]

Epoch 120 | train log2-MSE: 0.866373 | val log2-MSE: 0.652324
Saved best model. Best val log2-MSE: 0.6522239281428974


In [10]:
# # ==========================
# # Inference / Testing Cell
# # ==========================
# import os, time, pickle
# from pathlib import Path
# from typing import Any, List
# import numpy as np
# import torch
# from torch.utils.data import Dataset, DataLoader

# # Use same device setup as training
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# print("Device:", DEVICE, "| Torch:", torch.__version__)

# # ---------- 1) CONFIG: pulled from notebook globals ----------
# DMODEL     = globals().get("DMODEL", 128)
# NHEAD      = globals().get("NHEAD", 4)
# ENC_FF     = globals().get("ENC_FF", 384)
# ENC_LAYERS = globals().get("ENC_LAYERS", 3)
# DROPOUT    = globals().get("DROPOUT", 0.10)

# N_FIXED    = globals().get("N_FIXED", 9)
# K_VALUES   = globals().get("K_VALUES", [4,5,6])
# MAX_K      = globals().get("MAX_K", max(K_VALUES))
# MAX_PARITY = globals().get("MAX_PARITY", N_FIXED - min(K_VALUES))

# # IMPORTANT: this NORM must be the SAME as training preprocessing
# NORM       = globals().get("NORM", 1.0)

# # TODO: set these two paths appropriately
# CKPT_PATH  = "/content/checkpoints/mheight_transformer_improved_A.pt"
# TEST_PATH  = "/content/DS-2-Test-n_k_m_P.pkl"   # or DS-2 test when you have it
# OUT_NAME   = "predictions.pkl"

# print("Using config:")
# print(f"  DMODEL={DMODEL}, NHEAD={NHEAD}, ENC_FF={ENC_FF}, ENC_LAYERS={ENC_LAYERS}, DROPOUT={DROPOUT}")
# print(f"  MAX_K={MAX_K}, MAX_PARITY={MAX_PARITY}, NORM={NORM}")
# print(f"  CKPT_PATH={CKPT_PATH}")
# print(f"  TEST_PATH={TEST_PATH}")

# assert Path(CKPT_PATH).exists(), f"Checkpoint not found: {CKPT_PATH}"
# assert Path(TEST_PATH).exists(), f"Test file not found: {TEST_PATH}"


# # ---------- 2) Parser + sampler (same structure as training) ----------

# if "parse_inputs" not in globals():
#     def parse_inputs(obj: Any):
#         # Dict-of-arrays
#         if isinstance(obj, dict) and all(k in obj for k in ("n","k","m","P")):
#             return list(obj["n"]), list(obj["k"]), list(obj["m"]), list(obj["P"])
#         # Tuple (n_list, k_list, m_list, P_list)
#         if isinstance(obj, (list, tuple)) and len(obj) == 4 and all(isinstance(x, (list, tuple)) for x in obj[:3]):
#             n_list, k_list, m_list, P_list = obj
#             return list(n_list), list(k_list), list(m_list), list(P_list)
#         # List of tuples
#         if isinstance(obj, list) and obj and isinstance(obj[0], tuple):
#             return (
#                 [int(r[0]) for r in obj],
#                 [int(r[1]) for r in obj],
#                 [int(r[2]) for r in obj],
#                 [r[3]       for r in obj],
#             )
#         # List of lists [n,k,m,P]
#         if isinstance(obj, list) and obj and isinstance(obj[0], list):
#             if not all(isinstance(r, list) and len(r) == 4 for r in obj):
#                 raise ValueError("List-of-lists present but inner length != 4 [n,k,m,P].")
#             return (
#                 [int(r[0]) for r in obj],
#                 [int(r[1]) for r in obj],
#                 [int(r[2]) for r in obj],
#                 [r[3]       for r in obj],
#             )
#         head = obj[0] if isinstance(obj, (list, tuple)) and obj else None
#         raise ValueError(
#             f"Unrecognized input format: type(obj)={type(obj)}, "
#             f"type(first)={type(head) if head is not None else None}"
#         )

# if "build_samples" not in globals():
#     def build_samples(n_list, k_list, m_list, P_list, y_list=None, norm_factor=None):
#         Ps = []
#         for i in range(len(P_list)):
#             n = int(n_list[i])
#             k = int(k_list[i])
#             m = int(m_list[i])
#             P = np.asarray(P_list[i], dtype=float).reshape(k, n-k)
#             if k < MAX_K:
#                 P = np.vstack([P, np.zeros((MAX_K - k, n-k), dtype=float)])
#             Ps.append(P)

#         if norm_factor is None:
#             max_abs = max(np.max(np.abs(P)) if P.size else 0.0 for P in Ps)
#             norm_factor = max_abs if max_abs > 0 else 1.0

#         Ps = [P.astype(np.float32) / norm_factor for P in Ps]
#         ys_log2 = None
#         if y_list is not None:
#             y = np.asarray(y_list, dtype=float)
#             y = np.clip(y, 1e-8, None)
#             ys_log2 = np.log2(y).astype(np.float32)
#         return Ps, list(map(int, k_list)), list(map(int, m_list)), list(map(int, n_list)), ys_log2, norm_factor


# def load_test_data(path: str, norm_factor: float):
#     with open(path, "rb") as f:
#         test_in = pickle.load(f)
#     tn, tk, tm, tP = parse_inputs(test_in)
#     Ps, Ks, Ms, Ns, _, _ = build_samples(tn, tk, tm, tP, y_list=None, norm_factor=norm_factor)
#     return [(Ps[i], Ks[i], Ms[i], Ns[i]) for i in range(len(Ps))]


# class TestDataset(Dataset):
#     def __init__(self, data): self.data = data
#     def __len__(self): return len(self.data)
#     def __getitem__(self, i):
#         P, k, m, n = self.data[i]
#         # (T, D) = (n-k, MAX_K) then transposed → (T, D) = (n-k, 6)
#         tokens = torch.tensor(P, dtype=torch.float32).t().contiguous()
#         k_idx  = torch.tensor({4:0, 5:1, 6:2}[k], dtype=torch.long)
#         m_idx  = torch.tensor(m - 2,         dtype=torch.long)  # m in {2,...}
#         parity = torch.tensor(n - k,         dtype=torch.long)  # n-k parity cols
#         return tokens, k_idx, m_idx, parity


# def test_collate(batch):
#     seq_lens = [b[0].shape[0] for b in batch]
#     max_len  = max(seq_lens)
#     toks = []
#     for b in batch:
#         t = b[0]
#         if t.shape[0] < max_len:
#             pad = torch.zeros((max_len - t.shape[0], t.shape[1]), dtype=t.dtype)
#             t = torch.cat([t, pad], dim=0)
#         toks.append(t)
#     toks = torch.stack(toks, dim=0)
#     kpm = torch.ones((len(batch), max_len), dtype=torch.bool)
#     for i, L in enumerate(seq_lens):
#         kpm[i, :L] = False
#     k_idx  = torch.stack([b[1] for b in batch])
#     m_idx  = torch.stack([b[2] for b in batch])
#     parity = torch.stack([b[3] for b in batch])
#     return toks, kpm, k_idx, m_idx, parity


# # ---------- 3) Rebuild model & load checkpoint ----------
# # Assumes MHeightTransformer was defined in a previous cell with same signature.

# infer_model = MHeightTransformer(
#     d_model=DMODEL,
#     nhead=NHEAD,
#     ff=ENC_FF,
#     nlayers=ENC_LAYERS,
#     dropout=DROPOUT
# ).to(DEVICE)

# ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
# state_dict = ckpt.get("state_dict", ckpt)
# missing, unexpected = infer_model.load_state_dict(state_dict, strict=False)
# print("Loaded checkpoint:", CKPT_PATH)
# print("Missing keys:", missing)
# print("Unexpected keys:", unexpected)
# infer_model.eval()


# # ---------- 4) Load test set & run inference ----------
# test_data = load_test_data(TEST_PATH, norm_factor=NORM)
# print("Loaded test samples:", len(test_data))
# print("Example P shape (MAX_K rows, n-k cols):", test_data[0][0].shape)

# test_ds = TestDataset(test_data)
# test_loader = DataLoader(test_ds, batch_size=128, shuffle=False, collate_fn=test_collate)

# preds = []
# with torch.no_grad():
#     for toks, kpm, kidx, midx, parity in test_loader:
#         toks   = toks.to(DEVICE)
#         kpm    = kpm.to(DEVICE)
#         kidx   = kidx.to(DEVICE)
#         midx   = midx.to(DEVICE)
#         parity = parity.to(DEVICE)

#         ylog = infer_model(toks, kpm, kidx, midx, parity)  # log2(m-height)
#         yhat = torch.pow(2.0, ylog)
#         preds.extend(yhat.detach().cpu().tolist())

# timestamped = f"{Path(OUT_NAME).stem}_{int(time.time())}.pkl"
# with open(timestamped, "wb") as f:
#     pickle.dump([float(p) for p in preds], f)

# print(f"Saved {len(preds)} predictions → {timestamped}")


Device: cuda | Torch: 2.9.0+cu126
Using config:
  DMODEL=128, NHEAD=4, ENC_FF=384, ENC_LAYERS=3, DROPOUT=0.1
  MAX_K=6, MAX_PARITY=5, NORM=100.0
  CKPT_PATH=/content/checkpoints/mheight_transformer_improved_A.pt
  TEST_PATH=/content/DS-2-Test-n_k_m_P.pkl


AssertionError: Checkpoint not found: /content/checkpoints/mheight_transformer_improved_A.pt

In [None]:
# # Path to your saved file (adjust if you renamed it or ran multiple times)
# path = "/content/CSCE 636-600 Fall 2025 Project 2 Test Results_Jairaj_Saraf_835008429.pkl"   # or "predictions_1730671200.pkl", etc.

# with open(path, "rb") as f:
#     preds = pickle.load(f)

# print(f"Loaded {len(preds)} predictions")
# print("First 10 predictions:", preds[:10])