# setup

In [1]:
import os, sys, json, math, random
from dataclasses import asdict

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Subset
from torch.optim import AdamW
from torch.cuda.amp import GradScaler

# --- project root & local modules ---
# Absolute location of the repository checkout on the training machine.
PROJ_ROOT = "/home/hernan_melmoth/Documents/phd_work/otu-taxa-foundation"

# Make the in-repo `src/` directory importable. Guarded so that re-running this
# cell does not keep appending duplicate entries to sys.path.
_SRC_DIR = os.path.join(PROJ_ROOT, "src")
if _SRC_DIR not in sys.path:
    sys.path.append(_SRC_DIR)

from otu_taxa.helpers_pretraining_model import MetricsLogger, set_seed, save_checkpoint, IGNORE_INDEX
from otu_taxa.trainer_hier_joint_unk import run_epoch

from otu_taxa.joint_hier_loss_metrics_unk import make_factorized_tax_loss_fn_fast_masked_with_unk
from otu_taxa.otu_taxa_transformer_unk import ModelConfig, OTUTaxaTransformerEmbedTaxTreeUnkTaxa
from otu_taxa.dataloaders_unk_balanced import (
    OTUTaxaDataset,
    MaskingConfig,
    make_collator_balanced,
    build_tax2ancestor_at_rank,
)


# data

In [2]:
# -----------------------------------
# Paths & config
# -----------------------------------
# NOTE(review): PROJ_ROOT is also assigned (same value) in the import cell,
# where it feeds sys.path; keep the two assignments in sync when relocating.
PROJ_ROOT = "/home/hernan_melmoth/Documents/phd_work/otu-taxa-foundation"

# External (heavy) dataset location (outside repo)
DATASET_ROOT = "/home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training"
dataset_folder_name = "dataset_full_top999"   # <- adjust if you used a different name
# Resolves to <DATASET_ROOT>/level_97/silva-138.2/incomplete_silva_sintax/<dataset_folder_name>
dataset_dir = os.path.join(
    DATASET_ROOT,
    "level_97",
    "silva-138.2",
    "incomplete_silva_sintax",
    dataset_folder_name,
)

# Base dataset artifacts (all live directly under dataset_dir)
TAXONOMY_VOCAB_PATH = os.path.join(dataset_dir, "taxonomy_vocab.json")
path_to_taxonomy_tree = os.path.join(dataset_dir, "taxonomy_nested.json")
SAMPLES_JSONL = os.path.join(dataset_dir, "samples.jsonl")

# Tree artifacts directory (derived)
TREE_DIR = os.path.join(dataset_dir, "tree_artifacts")

# LCA distance matrix: the loader below prefers the .npy file and only falls
# back to the CSV when the .npy is missing.
LCA_NPY = os.path.join(TREE_DIR, "lca_distance_edges.npy")
LCA_CSV = os.path.join(TREE_DIR, "lca_distance_edges.csv")  # optional legacy

# Descendant matrices and UNK-extended vocabulary.
# In the recorded run the original vocab has 6929 entries and the UNK-extended
# one has 6935, i.e. 6 tokens are added (the original vocab already contains the
# kingdom-level UNK token).
DESCENDANT_MATRIX_PATH = os.path.join(TREE_DIR, "descendant_matrix.npy")              # original vocab only
UNK_VOCAB_PATH         = os.path.join(TREE_DIR, "taxonomy_vocab_with_unk.json")       # original vocab + added UNK tokens
UNK_M_PATH             = os.path.join(TREE_DIR, "descendant_matrix_with_unk.npy")     # real+UNK closure
RANK_IDX_PATH          = os.path.join(TREE_DIR, "rank_idx.npy")                       # rank per token (0..6)

# # Optional: test split ids (only if you created it for the full corpus)
# TEST_IDS_PATH = os.path.join(dataset_dir, "splits", "test_samples_2000.txt")  # adjust if needed

# Run / output dir (inside repo); created eagerly so later cells can write
# metrics, meta.json and checkpoints into it.
run_name = "pretrain_hier_joint_unk_taxa"
out_dir = os.path.join(PROJ_ROOT, "runs_hier_joint_unk_taxa", run_name)
os.makedirs(out_dir, exist_ok=True)

# -----------------------------------
# Load LCA distance matrix
# -----------------------------------
# Prefer the binary .npy artifact; fall back to the legacy CSV export (its first
# column is the row index). When neither file exists, fail with a message that
# names both candidates instead of only the CSV path.
if os.path.exists(LCA_NPY):
    D_tree = torch.from_numpy(np.load(LCA_NPY)).float()
elif os.path.exists(LCA_CSV):
    df_D = pd.read_csv(LCA_CSV, index_col=0)
    D_tree = torch.tensor(df_D.values, dtype=torch.float32)
else:
    raise FileNotFoundError(
        f"LCA distance matrix not found; looked for {LCA_NPY!r} and {LCA_CSV!r}"
    )

print("LCA distance matrix:", D_tree.shape)

# Metrics are appended to <out_dir>/metrics.jsonl through the project logger.
metrics_path = os.path.join(out_dir, "metrics.jsonl")
logger = MetricsLogger(metrics_path)

# -----------------------------------
# Taxonomy sizes: T_real vs T_base
# -----------------------------------
def _require_shape(label, actual, expected):
    """Raise ValueError when a loaded artifact does not have the expected shape."""
    if tuple(actual) != tuple(expected):
        raise ValueError(
            f"{label}: expected shape {tuple(expected)}, got {tuple(actual)}; "
            "the tree artifacts and vocab files are out of sync"
        )

# Original vocab (before the UNK extension); its length defines T_real.
# As the label below says, it already contains the kingdom-level UNK token.
with open(TAXONOMY_VOCAB_PATH, "r") as f:
    tax_vocab_real = json.load(f)
T_real = len(tax_vocab_real)
print("T_real (original taxa, just k UNK):", T_real)

# UNK-extended vocab & descendant matrix for hierarchical loss
with open(UNK_VOCAB_PATH, "r") as f:
    tax_vocab_unk = json.load(f)
T_base = len(tax_vocab_unk)
print("T_base (real + UNK):", T_base)
if T_base < T_real:
    raise ValueError(
        f"UNK-extended vocab ({T_base}) is smaller than the original vocab ({T_real})"
    )

# Load UNK-extended descendant matrix (used by hierarchical loss)
D_np_unk = np.load(UNK_M_PATH)          # [T_base, T_base]
_require_shape("descendant_matrix_with_unk", D_np_unk.shape, (T_base, T_base))
M_tensor = torch.from_numpy(D_np_unk)   # keep name M_tensor for the loss
print("Descendant matrix with UNK:", M_tensor.shape)

# Optional: real-only descendant matrix (if needed for debugging)
D_np_real = np.load(DESCENDANT_MATRIX_PATH)        # [T_real, T_real]
_require_shape("descendant_matrix", D_np_real.shape, (T_real, T_real))
descendant_matrix_real = torch.from_numpy(D_np_real)
print("Descendant matrix (real only):", descendant_matrix_real.shape)

# Optional: rank_idx (often used by loss/metrics/collator)
rank_idx = np.load(RANK_IDX_PATH)  # shape [T_base]
_require_shape("rank_idx", rank_idx.shape, (T_base,))
print("rank_idx:", rank_idx.shape)


LCA distance matrix: torch.Size([6929, 6929])
T_real (original taxa, just k UNK): 6929
T_base (real + UNK): 6935
Descendant matrix with UNK: torch.Size([6935, 6935])
Descendant matrix (real only): torch.Size([6929, 6929])
rank_idx: (6935,)


In [3]:
# import json
# import numpy as np
# from tqdm import tqdm

# def compute_length_percentiles(samples_jsonl_path, percentiles=(50, 75, 90, 95, 99, 99.9)):
#     lengths = []
#     with open(samples_jsonl_path, "r") as f:
#         for line in tqdm(f, desc="Reading sample lengths"):
#             if not line.strip():
#                 continue
#             rec = json.loads(line)
#             lengths.append(len(rec["otus"]))
#     lengths = np.asarray(lengths, dtype=np.int32)

#     out = {p: int(np.percentile(lengths, p)) for p in percentiles}
#     return out, lengths

# pct, lengths = compute_length_percentiles(SAMPLES_JSONL, percentiles=(90, 95, 99))
# print("Length percentiles:", pct)
# print("Max length:", int(lengths.max()))


In [4]:
# -----------------------------------
# Training hyperparameters
# (experiment-specific configuration)
# -----------------------------------

# Reproducibility / hardware
seed = 123
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Schedule length and batch geometry
epochs = 40
train_batch_size = 2 * 32
val_batch_size = 64

# Sequence truncation and masking recipe
max_len = 500                    # 96th percentile sequence length
mlm_prob = 0.15                  # masking rate
prob_joint = 0.50
prob_otu_only = 0.25
prob_tax_only = 0.25
keep_prob = 0.10
random_prob = 0.10

# Optimisation
lr = 1e-3
weight_decay = 1e-3
warmup_ratio = 0.06
max_grad_norm = 1.0
grad_accum_steps = 2

# Misc
TREE_LAMBDA = 10                 # tree regularization weight
num_workers = 0                  # 0 for determinism


In [5]:
# -----------------------------------
# Dataset & random split (20k TEST, 10k VAL, rest TRAIN)
# -----------------------------------
# Seed everything first: the split below is defined purely by `seed` and the
# order of indices produced by random.shuffle.
set_seed(seed)
random.seed(seed)

# 1) Load full dataset once
ds = OTUTaxaDataset(dataset_dir)
N = len(ds)
print(f"[INFO] Dataset size: N={N}")

# 2) Choose split sizes (cap if dataset is smaller)
TEST_N = min(20_000, N)
VAL_N  = min(10_000, N - TEST_N)

# The caps above can swallow the whole dataset; a shuffled DataLoader over an
# empty training subset cannot be built, so stop here with a clear message.
if N - TEST_N - VAL_N <= 0:
    raise ValueError(
        f"Dataset too small for the requested split: N={N}, "
        f"TEST_N={TEST_N}, VAL_N={VAL_N} leaves no training samples"
    )

# 3) Random permutation of indices
all_idx = list(range(N))
random.shuffle(all_idx)

# Consecutive slices of the permutation are disjoint by construction; sorting
# only makes the index lists easier to inspect.
val_end = TEST_N + VAL_N
test_idx  = sorted(all_idx[:TEST_N])
val_idx   = sorted(all_idx[TEST_N:val_end])
train_idx = sorted(all_idx[val_end:])

print(f"[SPLIT] Train={len(train_idx)}  Val={len(val_idx)}  Test={len(test_idx)}  (Total N={N})")

# 4) Subsets (views over the single loaded dataset)
train_ds = Subset(ds, train_idx)
val_ds   = Subset(ds, val_idx)
test_ds  = Subset(ds, test_idx)


[INFO] Dataset size: N=1836250
[SPLIT] Train=1806250  Val=10000  Test=20000  (Total N=1836250)


In [6]:
# -----------------------------------
# Collators & loaders
# -----------------------------------

# Masking settings shared by every split; the configs below differ only in
# `balance_mode`.
_masking_kwargs = dict(
    mlm_prob=mlm_prob,
    prob_joint=prob_joint,
    prob_otu_only=prob_otu_only,
    prob_tax_only=prob_tax_only,
    max_len=max_len,
    keep_prob=keep_prob,
    random_prob=random_prob,
)

# TRAIN: stochastic masking with balance_mode="otu".
train_cfg = MaskingConfig(balance_mode="otu", **_masking_kwargs)
train_collate = make_collator_balanced(dataset=ds, cfg=train_cfg)

# VAL / TEST: same masking recipe but balancing disabled, so evaluation is not
# distribution-shaped (no endpoint re-selection for eval).
val_cfg = MaskingConfig(balance_mode="none", **_masking_kwargs)
val_collate = make_collator_balanced(dataset=ds, cfg=val_cfg)

# The test collator is built from the validation config on purpose.
test_collate = make_collator_balanced(dataset=ds, cfg=val_cfg)

# Page-locked host memory only speeds up host->GPU copies, so request it only
# when training on CUDA; on a CPU-only machine an unconditional
# pin_memory=True just triggers a DataLoader warning.
_pin_memory = (device == "cuda")

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_ds,
    batch_size=train_batch_size,
    shuffle=True,
    collate_fn=train_collate,
    num_workers=num_workers,
    pin_memory=_pin_memory,
)

# Evaluation loaders iterate in fixed order and always run in-process
# (num_workers=0), independently of the training `num_workers` setting.
val_loader = DataLoader(
    val_ds,
    batch_size=val_batch_size,
    shuffle=False,
    collate_fn=val_collate,
    num_workers=0,
    pin_memory=_pin_memory,
)

test_loader = DataLoader(
    test_ds,
    batch_size=val_batch_size,
    shuffle=False,
    collate_fn=test_collate,
    num_workers=0,
    pin_memory=_pin_memory,
)
# -----------------------------------
# Model sizes & PAD ids (single place)
# -----------------------------------

# OTUs: ds.O real OTU ids followed by two special tokens.
n_otus = ds.O + 2            # + pad, mask
pad_otu_id = ds.O            # PAD is the first id after the real OTUs
# NOTE(review): the OTU MASK id is presumably ds.O + 1 — confirm against the collator.

# TAXA (UNK-aware):
# - T_real = size of the original vocab (taxonomy_vocab.json)
# - T_base = len(tax_vocab_unk), the UNK-extended vocab. In the recorded run
#   T_real = 6929 and T_base = 6935, i.e. T_base = T_real + 6 (the original
#   vocab already contains the kingdom-level UNK token).
# - n_taxa = T_base + 2 (PAD, MASK at the end)
n_taxa = T_base + 2
pad_tax_id  = T_base         # PAD token index
mask_tax_id = T_base + 1     # MASK token index


# Load model

In [7]:
# -----------------------------
# Build model config
# -----------------------------
model_cfg = ModelConfig(
    # encoder geometry
    d_model=256,
    n_layers=6,
    n_heads=4,
    d_ff=1024,
    # regularisation / non-linearity
    dropout=0.1,
    activation="gelu",
    # output heads and loss weighting
    tie_otu_weights=True,
    otu_loss_weight=1.0,
    tax_loss_weight=1.0,
    #lambda_tree=TREE_LAMBDA,
    # embedding post-processing
    emb_dropout=0.1,
    layernorm_emb=True,
    #lca_csv=LCA_CSV,          # LCA over ORIGINAL taxonomy (T_real x T_real)
    # number of original taxa for the tree regularizer
    T_real=T_real,
)

# -----------------------------
# Hierarchical taxonomy loss (UNK-aware)
# -----------------------------
# Here we use the UNK-extended vocab and M_tensor that we loaded earlier:
#   - tax_vocab_unk: list[str], len = T_base (6935 in this run = T_real + 6; the original vocab already holds the kingdom-level UNK)
#   - M_tensor: [T_base, T_base] descendant-closure with UNKs

def build_rank_idx_from_vocab(vocab_list):
    """
    Map every taxonomy token to its rank index.

    Parameters
    ----------
    vocab_list : iterable of str
        Tokens whose first character encodes the rank, case-insensitively:
        'k', 'p', 'c', 'o', 'f', 'g', 's' (e.g. "g:Foo", "s:UNK").

    Returns
    -------
    torch.LongTensor of shape [len(vocab_list)]
        Rank index 0..6 (k..s) per token, or -1 for tokens whose first
        character is not a known rank letter (including empty strings).
    """
    rank_map = {'k': 0, 'p': 1, 'c': 2, 'o': 3, 'f': 4, 'g': 5, 's': 6}
    # name[:1] rather than name[0]: an empty token yields "" and falls through
    # to the same -1 sentinel as any unrecognised prefix instead of raising
    # IndexError.
    ranks = [rank_map.get(name[:1].lower(), -1) for name in vocab_list]
    return torch.tensor(ranks, dtype=torch.long)

# Rank index per token, derived from the token prefixes of the UNK-extended vocab.
# NOTE(review): this rebinds `rank_idx`, replacing the NumPy array loaded from
# RANK_IDX_PATH in the data cell with a LongTensor; tokens with an unrecognised
# prefix get -1. Confirm the two sources agree.
rank_idx = build_rank_idx_from_vocab(tax_vocab_unk)   # [T_base]

# build the UNK-aware hierarchical loss callable
hier_tax_loss_fn = make_factorized_tax_loss_fn_fast_masked_with_unk(
    M_tensor=M_tensor,          # [T_base, T_base] with UNKs
    rank_idx=rank_idx,          # [T_base]
    tax_vocab=tax_vocab_unk,    # len = T_base
    T_base=T_base,              # size of the UNK-extended vocab (6935 in this run)
)

# -----------------------------
# Init model (UNK-aware)
# -----------------------------
# The taxonomy loss callable is injected into the model; the model is then moved
# to `device`.
model = OTUTaxaTransformerEmbedTaxTreeUnkTaxa(
    n_otus=n_otus,
    n_taxa=n_taxa,              # = T_base + 2 (PAD, MASK)
    pad_otu_id=pad_otu_id,
    pad_tax_id=pad_tax_id,
    config=model_cfg,
    tax_loss_fn=hier_tax_loss_fn,
).to(device)

# -----------------------------
# Optimizer & Scheduler
# -----------------------------
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

# Scheduler steps per epoch = number of batches divided by the accumulation
# factor, rounded up.
# NOTE(review): assumes run_epoch advances the scheduler once per
# `grad_accum_steps` batches, including a trailing partial group — confirm.
steps_per_epoch = math.ceil(len(train_loader) / max(1, grad_accum_steps))
total_steps = steps_per_epoch * epochs
# At least one warmup step so the warmup ramp below never divides by zero.
warmup_steps = max(1, int(warmup_ratio * total_steps))

def lr_lambda(current_step):
    """
    LR multiplier for LambdaLR: linear warmup followed by linear decay.

    Ramps from 0 to 1 over the first `warmup_steps` steps, then decreases
    linearly so that it reaches 0 at `total_steps`; never returns a negative
    value. Reads the module-level `warmup_steps` and `total_steps`.
    """
    if current_step >= warmup_steps:
        steps_left = float(total_steps - current_step)
        decay_span = float(max(1, total_steps - warmup_steps))
        return max(0.0, steps_left / decay_span)
    return float(current_step) / float(max(1, warmup_steps))

# Multiplies the optimizer's base LR by lr_lambda(step) at every scheduler step.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# AMP scaler: gradient scaling is only enabled when running on CUDA.
scaler = GradScaler(enabled=(device == "cuda"))

# -----------------------------
# Save run config
# -----------------------------
# Everything needed to re-create this run is written to <out_dir>/meta.json.
# Beyond the original keys this also records the masking keep/random
# probabilities, gradient accumulation / clipping settings, loader workers,
# steps per epoch and the test split size.
run_meta = {
    "seed": seed,
    "epochs": epochs,
    "train_batch_size": train_batch_size,
    "val_batch_size": val_batch_size,
    "max_len": max_len,
    "mlm_prob": mlm_prob,
    "prob_split": [prob_joint, prob_otu_only, prob_tax_only],
    "keep_prob": keep_prob,
    "random_prob": random_prob,
    "grad_accum_steps": grad_accum_steps,
    "max_grad_norm": max_grad_norm,
    "num_workers": num_workers,
    "optimizer": {"lr": lr, "weight_decay": weight_decay},
    "sched": {
        "warmup_ratio": warmup_ratio,
        "total_steps": total_steps,
        "warmup_steps": warmup_steps,
        "steps_per_epoch": steps_per_epoch,
    },
    "model_cfg": asdict(model_cfg),
    "dataset": {
        "N": N,
        "N_train": len(train_ds),
        "N_val": len(val_ds),
        "N_test": len(test_ds),
        "O": ds.O,
        # Keep both for clarity
        "T_real": T_real,       # size of the original vocab
        "T_base": T_base,       # UNK-extended vocab size (used by loss)
    },
}
with open(os.path.join(out_dir, "meta.json"), "w") as f:
    json.dump(run_meta, f, indent=2)

print(
    f"[INFO] Train={len(train_ds)}  Val={len(val_ds)}  "
    f"Steps/epoch={steps_per_epoch}  Total steps={total_steps}"
)


UNK ids per rank: [6928, 6929, 6930, 6931, 6932, 6933, 6934]




[INFO] Train=1806250  Val=10000  Steps/epoch=14112  Total steps=564480


# training model

In [8]:
# Rank letter -> rank index (k=0 ... s=6).
rank_map = dict(zip("kpcofgs", range(7)))
unk_ids_by_rank = {}

# Collect the vocabulary id of the UNK token of each rank. Tokens are expected
# to look like "k:UNK", "p:UNK", ..., "s:UNK"; adjust the suffix test if the
# naming differs. If several tokens match the same rank, the last one wins.
for token_id, token_name in enumerate(tax_vocab_unk):
    if not token_name.endswith(("UNK", "unk")):
        continue
    rank_letter = token_name.partition(":")[0]  # 'k','p',...,'s'
    rank = rank_map.get(rank_letter)
    if rank is not None:
        unk_ids_by_rank[rank] = token_id

print("UNK IDs per rank:", unk_ids_by_rank)
print("T_base:", T_base)

UNK IDs per rank: {0: 6928, 1: 6929, 2: 6930, 3: 6931, 4: 6932, 5: 6933, 6: 6934}
T_base: 6935


In [9]:
best_val = float("inf")
global_step = 0

# Arguments that are identical for every run_epoch call (train, val and test).
_shared_epoch_kwargs = dict(
    model=model,
    device=device,
    IGNORE_INDEX=IGNORE_INDEX,
    M_tensor=M_tensor,                  # UNK-extended [T_base, T_base]
    rank_idx=rank_idx,                  # built from tax_vocab_unk
    T_base=T_base,                      # UNK-aware T_base, not ds.T
    unk_ids_by_rank=unk_ids_by_rank,
    max_grad_norm=max_grad_norm,
    logger=logger,
)


def _evaluate(split, dataloader, epoch, global_step):
    """Run one evaluation pass (no optimizer, deterministic masks) and return its stats."""
    stats, _ = run_epoch(
        dataloader=dataloader,
        split=split,
        epoch=epoch,
        global_step=global_step,
        optimizer=None,
        scheduler=None,
        scaler=None,
        grad_accum_steps=1,
        deterministic_masks=True,
        **_shared_epoch_kwargs,
    )
    return stats


for epoch in range(1, epochs + 1):

    # ---- TRAIN ----
    train_stats, global_step = run_epoch(
        dataloader=train_loader,
        split="train",
        epoch=epoch,
        global_step=global_step,
        optimizer=optimizer,
        scheduler=scheduler,
        scaler=scaler,
        grad_accum_steps=grad_accum_steps,
        log_every=100,
        deterministic_masks=False,
        compute_train_metrics=True,
        **_shared_epoch_kwargs,
    )

    # ---- VALIDATION ----
    val_stats = _evaluate("val", val_loader, epoch, global_step)

    # save best checkpoint (still using val loss as criterion)
    if val_stats["loss"] < best_val:
        best_val = val_stats["loss"]
        save_checkpoint(
            os.path.join(out_dir, "best.pt"),
            model, optimizer, scheduler, scaler,
            epoch, global_step, best_val,
        )
        print(f"[E{epoch:02d}] ✅ Saved BEST")

    # ---- TEST ----
    # Monitored every epoch for reporting only; checkpoint selection above uses
    # the validation loss exclusively.
    test_stats = _evaluate("test", test_loader, epoch, global_step)

    # ---- LAST CHECKPOINT ----
    # Overwritten at the end of every epoch so that an interrupted run can be
    # resumed from the most recent completed epoch; after the final epoch this
    # file holds the same state the single post-loop save used to write.
    save_checkpoint(
        os.path.join(out_dir, "last.pt"),
        model, optimizer, scheduler, scaler,
        epoch, global_step, best_val,
    )


[OTUTaxaTransformerEmbedTaxTree] using otu_loss_fn=default_otu_loss_fn (id=128442635710320) kwargs=['attention_mask', 'ignore_index']
[OTUTaxaTransformerEmbedTaxTree] using tax_loss_fn=loss_fn (id=128442636151104) kwargs=['ignore_index']
[OTUTaxaTransformerEmbedTaxTree] using combine_loss_fn=default_combine_loss_fn (id=128442635710032) kwargs=[]




[E01] TRAIN progress: batch 500  global_step=250
[E01] TRAIN progress: batch 1000  global_step=500
[E01] TRAIN progress: batch 1500  global_step=750
[E01] TRAIN step=1000 probe_acc_deep=0.108 probe_f1_deep=0.055 (used_pos=2772)
[E01] TRAIN progress: batch 2000  global_step=1000
[E01] TRAIN progress: batch 2500  global_step=1250
[E01] TRAIN progress: batch 3000  global_step=1500
[E01] TRAIN progress: batch 3500  global_step=1750
[E01] TRAIN step=2000 probe_acc_deep=0.235 probe_f1_deep=0.148 (used_pos=2812)
[E01] TRAIN progress: batch 4000  global_step=2000
[E01] TRAIN progress: batch 4500  global_step=2250
[E01] TRAIN progress: batch 5000  global_step=2500
[E01] TRAIN progress: batch 5500  global_step=2750
[E01] TRAIN step=3000 probe_acc_deep=0.384 probe_f1_deep=0.275 (used_pos=2998)
[E01] TRAIN progress: batch 6000  global_step=3000
[E01] TRAIN progress: batch 6500  global_step=3250


KeyboardInterrupt: 