# setup

In [2]:
import os, sys, json, math, random
from dataclasses import asdict

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Subset
from torch.optim import AdamW
from torch.cuda.amp import GradScaler

# --- project root & local modules ---
# Machine-specific absolute path to the repository checkout; edit it when
# running the notebook from another location.
PROJ_ROOT = "/home/hernan_melmoth/Documents/phd_work/otu-taxa-foundation"
# Extend the import search path so the local `otu_taxa.*` imports below resolve.
# NOTE(review): this appends <PROJ_ROOT>/src/otu_taxa, whereas `import otu_taxa.<module>`
# would normally need <PROJ_ROOT>/src on sys.path — confirm the package is
# importable on a fresh kernel (e.g. through an editable install).
sys.path.append(os.path.join(PROJ_ROOT, "src/otu_taxa"))

from otu_taxa.helpers_pretraining_model import MetricsLogger, set_seed, save_checkpoint, IGNORE_INDEX
from otu_taxa.trainer_hier_joint_unk import run_epoch

from otu_taxa.joint_hier_loss_metrics_unk import make_factorized_tax_loss_fn_fast_masked_with_unk
from otu_taxa.otu_taxa_transformer_unk import ModelConfig, OTUTaxaTransformerEmbedTaxTreeUnkTaxa
from otu_taxa.dataloaders_unk_balanced import (
    OTUTaxaDataset,
    MaskingConfig,
    make_collator_balanced,
    build_tax2ancestor_at_rank,
)


# data

In [None]:
# ===================================
# Filesystem layout used by this notebook
# ===================================
# Repository checkout (same value as assigned in the setup cell).
PROJ_ROOT = "/home/hernan_melmoth/Documents/phd_work/otu-taxa-foundation"

# The preprocessed corpus is heavy and is kept outside the repository.
DATASET_ROOT = "/home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training"
dataset_folder_name = "dataset_full_top999"   # change if the dataset was exported under another name
dataset_dir = os.path.join(
    DATASET_ROOT, "level_97", "silva-138.2", "incomplete_silva_sintax", dataset_folder_name
)

# Artifacts stored directly inside the dataset folder.
TAXONOMY_VOCAB_PATH, path_to_taxonomy_tree, SAMPLES_JSONL = (
    os.path.join(dataset_dir, fname)
    for fname in ("taxonomy_vocab.json", "taxonomy_nested.json", "samples.jsonl")
)

# Tree-derived artifacts live in a dedicated sub-folder of the dataset.
TREE_DIR = os.path.join(dataset_dir, "tree_artifacts")
(
    LCA_NPY,                 # LCA distance matrix, preferred .npy format
    LCA_CSV,                 # optional legacy CSV copy, read only when the .npy is missing
    DESCENDANT_MATRIX_PATH,  # descendant matrix, real tree only
    UNK_VOCAB_PATH,          # taxonomy vocab: real taxa + 7 UNKs
    UNK_M_PATH,              # descendant matrix, real + UNK closure
    RANK_IDX_PATH,           # rank per token (0..6)
) = (
    os.path.join(TREE_DIR, fname)
    for fname in (
        "lca_distance_edges.npy",
        "lca_distance_edges.csv",
        "descendant_matrix.npy",
        "taxonomy_vocab_with_unk.json",
        "descendant_matrix_with_unk.npy",
        "rank_idx.npy",
    )
)

# # Optional: test split ids (only if you created it for the full corpus)
# TEST_IDS_PATH = os.path.join(dataset_dir, "splits", "test_samples_2000.txt")  # adjust if needed

# Output directory of this run, created inside the repository when missing.
run_name = "pretrain_hier_joint_unk_taxa"
out_dir = os.path.join(PROJ_ROOT, "runs_hier_joint_unk_taxa", run_name)
os.makedirs(out_dir, exist_ok=True)

# ===================================
# LCA distance matrix -> float32 tensor `D_tree`
# ===================================
# The .npy artifact is preferred; the CSV (first column used as the index)
# is read only when the .npy file does not exist.
if not os.path.exists(LCA_NPY):
    df_D = pd.read_csv(LCA_CSV, index_col=0)
    D_tree = torch.tensor(df_D.values, dtype=torch.float32)
else:
    D_tree = torch.from_numpy(np.load(LCA_NPY)).to(torch.float32)

print("LCA distance matrix:", D_tree.shape)

# Metrics logger bound to <out_dir>/metrics.jsonl.
metrics_path = os.path.join(out_dir, "metrics.jsonl")
logger = MetricsLogger(metrics_path)

# -----------------------------------
# Taxonomy sizes: T_real vs T_base
# -----------------------------------
# ORIGINAL vocab (no UNKs): defines T_real, the size used for the tree regularizer.
# JSON is read explicitly as UTF-8 so the result does not depend on the locale.
with open(TAXONOMY_VOCAB_PATH, "r", encoding="utf-8") as f:
    tax_vocab_real = json.load(f)
T_real = len(tax_vocab_real)
print("T_real (original taxa, no UNK):", T_real)

# UNK-extended vocab: defines T_base (real taxa + UNK tokens), used by the hierarchical loss
with open(UNK_VOCAB_PATH, "r", encoding="utf-8") as f:
    tax_vocab_unk = json.load(f)
T_base = len(tax_vocab_unk)
print("T_base (real + UNK):", T_base)

# UNK-extended descendant matrix (used by hierarchical loss)
D_np_unk = np.load(UNK_M_PATH)          # expected [T_base, T_base]
M_tensor = torch.from_numpy(D_np_unk)   # keep name M_tensor for the loss
print("Descendant matrix with UNK:", M_tensor.shape)

# Real-only descendant matrix (kept for debugging)
D_np_real = np.load(DESCENDANT_MATRIX_PATH)        # expected [T_real, T_real]
descendant_matrix_real = torch.from_numpy(D_np_real)
print("Descendant matrix (real only):", descendant_matrix_real.shape)

# Rank index per token
rank_idx = np.load(RANK_IDX_PATH)  # expected shape [T_base]
print("rank_idx:", rank_idx.shape)

# -----------------------------------
# Consistency checks
# -----------------------------------
# All artifacts are indexed by taxon id, so a vocab and a matrix produced by
# different preprocessing runs would silently misalign ids. Fail fast instead.
# The checks run after the prints so the offending shape is visible above.
assert tuple(D_tree.shape) == (T_real, T_real), (
    f"LCA distance matrix {tuple(D_tree.shape)} does not match T_real={T_real}"
)
assert tuple(M_tensor.shape) == (T_base, T_base), (
    f"UNK descendant matrix {tuple(M_tensor.shape)} does not match T_base={T_base}"
)
assert tuple(descendant_matrix_real.shape) == (T_real, T_real), (
    f"real descendant matrix {tuple(descendant_matrix_real.shape)} does not match T_real={T_real}"
)
assert rank_idx.shape == (T_base,), (
    f"rank_idx {rank_idx.shape} does not match T_base={T_base}"
)


LCA distance matrix: torch.Size([6928, 6928])
T_real (original taxa, no UNK): 6928
T_base (real + UNK): 6935
Descendant matrix with UNK: torch.Size([6935, 6935])
Descendant matrix (real only): torch.Size([6928, 6928])
rank_idx: (6935,)


In [None]:
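# NOTE: this cell is intentionally commented out; the output shown below comes from
# an earlier run (about 30 s for a full pass over samples.jsonl). Uncomment to recompute.
# import json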
# import numpy as np
# from tqdm import tqdm

# def compute_length_percentiles(samples_jsonl_path, percentiles=(50, 75, 90, 95, 99, 99.9)):
#     lengths = []
#     with open(samples_jsonl_path, "r") as f:
#         for line in tqdm(f, desc="Reading sample lengths"):
#             if not line.strip():
#                 continue
#             rec = json.loads(line)
#             lengths.append(len(rec["otus"]))
#     lengths = np.asarray(lengths, dtype=np.int32)

#     out = {p: int(np.percentile(lengths, p)) for p in percentiles}
#     return out, lengths

# pct, lengths = compute_length_percentiles(SAMPLES_JSONL, percentiles=(90, 95, 99))
# print("Length percentiles:", pct)
# print("Max length:", int(lengths.max()))


Reading sample lengths: 1836250it [00:28, 64048.54it/s]


Length percentiles: {90: 339, 95: 477, 99: 861}
Max length: 12546


In [5]:
# ===================================
# Hyperparameters for this experiment
# ===================================
seed = 123
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# --- schedule & batching ---
epochs = 40
train_batch_size = 32 * 8        # = 256
val_batch_size = 64
grad_accum_steps = 2
num_workers = 0                  # kept at 0 for determinism

# --- truncation & masking (forwarded to MaskingConfig) ---
max_len = 500                    # sequence-length cap (original note: 96th percentile of sample lengths)
mlm_prob = 0.15                  # masking rate
prob_joint = 0.50
prob_otu_only = 0.25
prob_tax_only = 0.25
keep_prob = 0.10
random_prob = 0.10

# --- optimisation ---
lr = 1e-3
weight_decay = 1e-3
warmup_ratio = 0.06
max_grad_norm = 1.0

# --- loss weighting ---
TREE_LAMBDA = 10                 # weight of the tree regularization term


In [7]:
# ===================================
# Seeded random split of the corpus:
# up to 20k TEST, up to 10k VAL, everything else TRAIN
# ===================================
set_seed(seed)
random.seed(seed)

# The full dataset is loaded a single time; the splits are index views on it.
ds = OTUTaxaDataset(dataset_dir)
N = len(ds)
print(f"[INFO] Dataset size: N={N}")

# Held-out sizes are capped so they never exceed what the dataset holds.
TEST_N = min(20_000, N)
VAL_N  = min(10_000, N - TEST_N)

# Shuffle every index once, then cut the permutation into three consecutive
# slices (test, val, train); each slice is sorted after the cut.
all_idx = list(range(N))
random.shuffle(all_idx)
test_idx, val_idx, train_idx = (
    sorted(all_idx[lo:hi])
    for lo, hi in ((0, TEST_N), (TEST_N, TEST_N + VAL_N), (TEST_N + VAL_N, N))
)

print(f"[SPLIT] Train={len(train_idx)}  Val={len(val_idx)}  Test={len(test_idx)}  (Total N={N})")

# Wrap each index list in a Subset of the shared dataset.
train_ds, val_ds, test_ds = (
    Subset(ds, idx) for idx in (train_idx, val_idx, test_idx)
)


[INFO] Dataset size: N=1836250
[SPLIT] Train=1806250  Val=10000  Test=20000  (Total N=1836250)


In [8]:
# -----------------------------------
# Collators & loaders
# -----------------------------------

# Masking settings shared by every split. Defining them once keeps the train
# and eval configs from drifting apart; only `balance_mode` differs below.
_masking_kwargs = dict(
    mlm_prob=mlm_prob,
    prob_joint=prob_joint,
    prob_otu_only=prob_otu_only,
    prob_tax_only=prob_tax_only,
    max_len=max_len,
    keep_prob=keep_prob,
    random_prob=random_prob,
)

# TRAIN collator: stochastic masking with balance_mode="otu".
# NOTE(review): this comment used to say "BOTH balancing strategies", which does
# not match the single mode configured here — confirm that "otu" is intended.
train_cfg = MaskingConfig(**_masking_kwargs, balance_mode="otu")
train_collate = make_collator_balanced(dataset=ds, cfg=train_cfg)

# VAL / TEST collators: same masking, but balancing DISABLED so evaluation is
# not distribution-shaped (no endpoint re-selection for eval).
val_cfg = MaskingConfig(**_masking_kwargs, balance_mode="none")
val_collate = make_collator_balanced(dataset=ds, cfg=val_cfg)
test_collate = make_collator_balanced(dataset=ds, cfg=val_cfg)

# Pinned host memory only speeds up host-to-GPU copies. Requesting it on a
# CPU-only machine makes PyTorch emit a warning for no benefit, so follow `device`.
use_pinned_memory = (device == "cuda")

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_ds,
    batch_size=train_batch_size,
    shuffle=True,
    collate_fn=train_collate,
    num_workers=num_workers,
    pin_memory=use_pinned_memory,
)

# Evaluation loaders keep dataset order and always load in the main process.
val_loader = DataLoader(
    val_ds,
    batch_size=val_batch_size,
    shuffle=False,
    collate_fn=val_collate,
    num_workers=0,
    pin_memory=use_pinned_memory,
)

test_loader = DataLoader(
    test_ds,
    batch_size=val_batch_size,
    shuffle=False,
    collate_fn=test_collate,
    num_workers=0,
    pin_memory=use_pinned_memory,
)
# ===================================
# Vocabulary sizes and special-token ids (defined once, here)
# ===================================

# OTU side: two extra ids (pad, mask) are added on top of ds.O; PAD takes id ds.O.
pad_otu_id = ds.O
n_otus = pad_otu_id + 2

# Taxonomy side (UNK-aware):
#   T_real : number of original taxa (taxonomy_vocab.json)
#   T_base : len(tax_vocab_unk) = T_real + 7 (real + UNKs)
#   n_taxa : T_base + 2, with PAD and MASK appended at the end
pad_tax_id, mask_tax_id = T_base, T_base + 1   # PAD index, MASK index
n_taxa = mask_tax_id + 1


In [10]:
# Pull a single batch out of the training loader for a manual sanity check.
batch = next(iter(train_loader))

print("\n=== BATCH INSPECTION ===")
for k, v in batch.items():
    # Tensors are summarised by shape and dtype, anything else by its Python type.
    summary = (v.shape, v.dtype) if torch.is_tensor(v) else (type(v),)
    print(k, *summary)

# First 20 positions of the first sample, for each per-token field.
b = 0
print("\n--- First sample in batch ---")
for field in ("input_otus", "input_taxa", "labels_otu", "labels_taxa", "attention_mask"):
    print(f"{field}[:20]:", batch[field][b, :20])



=== BATCH INSPECTION ===
sample_id <class 'list'>
input_otus torch.Size([256, 500]) torch.int64
input_taxa torch.Size([256, 500]) torch.int64
labels_otu torch.Size([256, 500]) torch.int64
labels_taxa torch.Size([256, 500]) torch.int64
attention_mask torch.Size([256, 500]) torch.int64
lengths torch.Size([256]) torch.int64
special_ids <class 'dict'>

--- First sample in batch ---
input_otus[:20]: tensor([  700,   408,   255, 62201,   563,  8336,   562,  3594, 14507,   303,
         1211, 19666, 62201,  1176,   170,   559, 62201,  5342,  1104,    65])
input_taxa[:20]: tensor([2041, 2343, 2160,  393, 3158, 3551, 2343, 1197, 3012, 2343, 2343,  328,
        2718, 3012, 3116,  371,  679,  508, 3037, 2400])
labels_otu[:20]: tensor([ -100,  -100,  -100,  5380,  -100,  9488,  -100,  -100,  -100,  -100,
         -100,  -100,  3924,  -100,  -100,  -100, 17936,  -100,  -100,  -100])
labels_taxa[:20]: tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100

In [11]:
def check_neg_ones(name, tensor):
    """Report on stdout whether `tensor` contains any entry equal to -1.

    Args:
        name: Label used in the printed message.
        tensor: Object to inspect. Anything that is not a torch tensor is
            skipped silently (nothing is printed).

    Returns:
        None. Prints an ``[OK]`` line when no -1 is present, otherwise a
        ``[WARNING]`` line listing the indices of at most the first 10 hits.
    """
    if torch.is_tensor(tensor):
        hits = tensor.eq(-1)
        if not hits.any():
            print(f"[OK] {name} contains no -1")
        else:
            # One row of coordinates per offending element.
            positions = hits.nonzero()
            print(f"[WARNING] {name} contains -1 at {positions[:10].tolist()}")

# Run the -1 check on both taxonomy tensors of the inspected batch.
print("\n=== CHECK FOR -1 ===")
for field in ("input_taxa", "labels_taxa"):
    check_neg_ones(field, batch[field])



=== CHECK FOR -1 ===


# Load model