# setup

In [2]:
import os, sys, json, random
from collections import Counter
import numpy as np
import torch
from torch.utils.data import Subset


# --- project root & local modules (same pattern as your notebook) ---
PROJ_ROOT = "/home/hernan_melmoth/Documents/phd_work/otu-taxa-foundation"
sys.path.append(os.path.join(PROJ_ROOT, "src"))

# Dataset root (same pattern as your notebook)
DATASET_ROOT = "/home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training"
dataset_folder_name = "dataset_full_top999"

dataset_dir = os.path.join(
    DATASET_ROOT,
    "level_97",
    "silva-138.2",
    "incomplete_silva_sintax",
    dataset_folder_name,
)

from otu_taxa.dataloaders_unk_balance_ranks import (
    OTUTaxaDataset,
    MaskingConfig,
    make_collator_balanced_rank,
    build_tax2ancestor_at_ranks,
)



In [3]:
seed = 123

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(seed)

ds = OTUTaxaDataset(dataset_dir)
N = len(ds)
print(f"[INFO] Dataset size: N={N}")

TEST_N = min(20_000, N)
VAL_N  = min(20_000, N - TEST_N)

all_idx = list(range(N))
random.shuffle(all_idx)

test_idx  = sorted(all_idx[:TEST_N])
val_idx   = sorted(all_idx[TEST_N:TEST_N + VAL_N])
train_idx = sorted(all_idx[TEST_N + VAL_N:])

print(f"[SPLIT] Train={len(train_idx)}  Val={len(val_idx)}  Test={len(test_idx)}  (Total N={N})")

train_ds = Subset(ds, train_idx)
val_ds   = Subset(ds, val_idx)
test_ds  = Subset(ds, test_idx)


[INFO] Dataset size: N=1836250
[SPLIT] Train=1796250  Val=20000  Test=20000  (Total N=1836250)


# Collect OTUs per split

In [None]:
def collect_otus(subset):
    otu_set = set()
    otu_counts = Counter()
    total_positions = 0

    for rec in subset:
        otus = rec["otus"]          
        otu_set.update(otus)
        otu_counts.update(otus)
        total_positions += len(otus)

    return otu_set, otu_counts, total_positions


In [6]:
train_otus, train_counts, train_pos = collect_otus(train_ds)
test_otus,  test_counts,  test_pos  = collect_otus(test_ds)

print("Unique OTUs in train:", len(train_otus))
print("Unique OTUs in test :", len(test_otus))


Unique OTUs in train: 62200
Unique OTUs in test : 59446


# zero shot? searching unique OTUs in TEST that are not in train

In [7]:
unseen_test_otus = test_otus - train_otus

print("Unseen OTUs in test (not in train):", len(unseen_test_otus))
print("Fraction of test OTUs unseen:",
      len(unseen_test_otus) / len(test_otus) if test_otus else 0.0)


Unseen OTUs in test (not in train): 0
Fraction of test OTUs unseen: 0.0


# split groups in Train, support and query

Zero-shot OTU generalization
Question: Can the model infer taxonomy for an OTU it has never seen before? 

* Target OTUs are completely absent during pre training
* OTUs appear only at test time
* Taxonomy is masked at evaluation
* The model must infer taxonomy using:
* co-occurrence context
* taxonomy structure learned from other OTUs


In [3]:
import os, sys, json, random
from pathlib import Path
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
from torch.utils.data import Subset

# --- repo src (adjust if needed) ---
PROJ_ROOT = Path("/home/hernan_melmoth/Documents/phd_work/otu-taxa-foundation")
SRC_DIR = PROJ_ROOT / "src"
if str(SRC_DIR) not in sys.path:
    sys.path.append(str(SRC_DIR))

from otu_taxa.dataloaders_unk_balanced import OTUTaxaDataset

# --- dataset path (EDIT ONLY THIS) ---
DATASET_ROOT = "/home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training"
dataset_folder_name = "dataset_full_top999"

dataset_dir = os.path.join(
    DATASET_ROOT,
    "level_97",
    "silva-138.2",
    "incomplete_silva_sintax",
    dataset_folder_name,
)

ds = OTUTaxaDataset(dataset_dir)
N = len(ds)
print("Dataset size N =", N)
print("OTU vocab size O =", ds.O)


Dataset size N = 1836250
OTU vocab size O = 62200


In [4]:
def build_otu_to_sample_index(ds, max_samples=None, seed=123):
    """
    Returns:
      otu_to_samples: dict[int, set[int]] mapping OTU -> set of sample indices where it appears
      sample_sizes: np.array length N with number of OTUs per sample (optional utility)
    """
    N = len(ds)
    n = N if max_samples is None else min(N, max_samples)

    otu_to_samples = defaultdict(set)
    sample_sizes = np.zeros(n, dtype=np.int32)

    for i in range(n):
        rec = ds[i]
        otus = rec["otus"]
        sample_sizes[i] = len(otus)
        for o in otus:
            otu_to_samples[int(o)].add(i)

        if (i + 1) % 20000 == 0:
            print(f"  indexed {i+1}/{n} samples...")

    return otu_to_samples, sample_sizes

# Build full index (can be heavy). If too slow, set max_samples=200_000 for a pilot.
otu_to_samples, sample_sizes = build_otu_to_sample_index(ds, max_samples=None)

print("Indexed OTUs:", len(otu_to_samples))
print("Indexed samples:", len(sample_sizes))


  indexed 20000/1836250 samples...
  indexed 40000/1836250 samples...
  indexed 60000/1836250 samples...
  indexed 80000/1836250 samples...
  indexed 100000/1836250 samples...
  indexed 120000/1836250 samples...
  indexed 140000/1836250 samples...
  indexed 160000/1836250 samples...
  indexed 180000/1836250 samples...
  indexed 200000/1836250 samples...
  indexed 220000/1836250 samples...
  indexed 240000/1836250 samples...
  indexed 260000/1836250 samples...
  indexed 280000/1836250 samples...
  indexed 300000/1836250 samples...
  indexed 320000/1836250 samples...
  indexed 340000/1836250 samples...
  indexed 360000/1836250 samples...
  indexed 380000/1836250 samples...
  indexed 400000/1836250 samples...
  indexed 420000/1836250 samples...
  indexed 440000/1836250 samples...
  indexed 460000/1836250 samples...
  indexed 480000/1836250 samples...
  indexed 500000/1836250 samples...
  indexed 520000/1836250 samples...
  indexed 540000/1836250 samples...
  indexed 560000/1836250 samples

In [6]:


def analyze_holdout_set_for_fewshot(
    otu_to_samples,
    heldout_otus,
    N_samples,
    *,
    k_shot: int = 5,
    support_frac: float = 0.10,
):
    """
    Reports:
      - pretrain_pool: samples with NO heldout OTUs (usable for retraining)
      - heldout_pool: samples with >=1 heldout OTU (usable to build support/query)
      - few-shot feasibility proxies:
          * max_supported_otus_at_k: upper bound = sum_o floor(count_o / k)
          * otus_with_at_least_k: number of heldout OTUs that appear in >=k samples
      - sample-level support/query sizes if you split heldout_pool by support_frac
    """
    heldout_otus = list(map(int, heldout_otus))

    # sample indices that contain >=1 heldout OTU
    heldout_samples = set()
    per_otu_counts = []
    otus_with_at_least_k = 0
    max_supported_otus_at_k = 0

    for o in heldout_otus:
        sset = otu_to_samples.get(o, set())
        c = len(sset)
        per_otu_counts.append(c)
        heldout_samples |= sset
        if c >= k_shot:
            otus_with_at_least_k += 1
        max_supported_otus_at_k += (c // k_shot)

    heldout_pool_n = len(heldout_samples)
    pretrain_pool_n = N_samples - heldout_pool_n

    # sample-level split of heldout pool into support/query
    support_n = int(round(heldout_pool_n * support_frac))
    query_n = heldout_pool_n - support_n

    # useful summary stats of OTU occurrence counts
    per_otu_counts = np.array(per_otu_counts, dtype=np.int64)
    if len(per_otu_counts) > 0:
        c_min = int(per_otu_counts.min())
        c_med = int(np.median(per_otu_counts))
        c_max = int(per_otu_counts.max())
    else:
        c_min = c_med = c_max = 0

    return {
        "K": len(heldout_otus),
        "k_shot": int(k_shot),
        "support_frac": float(support_frac),

        # pools
        "pretrain_pool": int(pretrain_pool_n),
        "pretrain_pool_frac": pretrain_pool_n / max(1, N_samples),
        "heldout_pool": int(heldout_pool_n),
        "heldout_pool_frac": heldout_pool_n / max(1, N_samples),

        # sample split within heldout pool
        "support_pool_samples": int(support_n),
        "query_pool_samples": int(query_n),

        # few-shot feasibility proxies
        "otus_with_at_least_k": int(otus_with_at_least_k),
        "max_supported_otus_at_k": int(max_supported_otus_at_k),

        # OTU occurrence stats (within heldout pool)
        "otu_count_min": c_min,
        "otu_count_median": c_med,
        "otu_count_max": c_max,
    }


def run_trials_random_K_fewshot(
    otu_to_samples,
    N_samples,
    K_list,
    *,
    k_shot: int = 5,
    support_frac: float = 0.10,
    n_trials: int = 20,
    seed: int = 123,
    pool=None
):
    rng = random.Random(seed)
    if pool is None:
        pool = sorted(list(otu_to_samples.keys()))
    else:
        pool = sorted(list(pool))

    rows = []
    for K in map(int, K_list):
        for t in range(n_trials):
            heldout = rng.sample(pool, k=min(K, len(pool)))
            stats = analyze_holdout_set_for_fewshot(
                otu_to_samples,
                heldout,
                N_samples,
                k_shot=k_shot,
                support_frac=support_frac,
            )
            stats["trial"] = t
            rows.append(stats)
    return pd.DataFrame(rows)


# -----------------------------
# Example usage
# -----------------------------
K_list = [10, 20, 50, 100]
k_shot = 5          # try 2 / 5 / 10
support_frac = 0.10 # or 0.05

df = run_trials_random_K_fewshot(
    otu_to_samples,
    N_samples=N,
    K_list=K_list,
    k_shot=k_shot,
    support_frac=support_frac,
    n_trials=20,
    seed=123,
)

summary = df.groupby("K").agg(
    pretrain_pool_mean=("pretrain_pool", "mean"),
    pretrain_pool_frac_mean=("pretrain_pool_frac", "mean"),
    heldout_pool_mean=("heldout_pool", "mean"),
    heldout_pool_frac_mean=("heldout_pool_frac", "mean"),

    support_pool_samples_mean=("support_pool_samples", "mean"),
    query_pool_samples_mean=("query_pool_samples", "mean"),

    otus_with_at_least_k_mean=("otus_with_at_least_k", "mean"),
    max_supported_otus_at_k_mean=("max_supported_otus_at_k", "mean"),

    otu_count_min_mean=("otu_count_min", "mean"),
    otu_count_median_mean=("otu_count_median", "mean"),
    otu_count_max_mean=("otu_count_max", "mean"),
).reset_index()

summary


Unnamed: 0,K,pretrain_pool_mean,pretrain_pool_frac_mean,heldout_pool_mean,heldout_pool_frac_mean,support_pool_samples_mean,query_pool_samples_mean,otus_with_at_least_k_mean,max_supported_otus_at_k_mean,otu_count_min_mean,otu_count_median_mean,otu_count_max_mean
0,10,1777482.95,0.967996,58767.05,0.032004,5876.7,52890.35,10.0,12125.25,199.8,1570.85,37844.05
1,20,1761326.3,0.959197,74923.7,0.040803,7492.2,67431.5,20.0,15816.3,99.05,1458.55,33574.35
2,50,1641699.5,0.89405,194550.5,0.10595,19455.05,175095.45,50.0,44292.1,58.8,1191.55,62955.1
3,100,1484067.65,0.808206,352182.35,0.191794,35218.3,316964.05,100.0,88666.75,38.25,1144.0,95135.75


# chossing 50 (20 trials) OTUs to k-shot: test: 1641699,  heldout_pool = 194550.50



In [11]:


# -----------------------------
# 0) Fix heldout OTUs (K=50)
# -----------------------------
K_OTUS = 50
SEED_OTUS = 123

otu_pool = sorted(list(otu_to_samples.keys()))
rng = random.Random(SEED_OTUS)
heldout_otus_50 = set(rng.sample(otu_pool, k=K_OTUS))

print(f"[HELDOUT OTUs] K={len(heldout_otus_50)} seed={SEED_OTUS}")
print("preview:", sorted(list(heldout_otus_50))[:10])


# -----------------------------
# 1) Build heldout sample pool + pretrain pool sizes
# -----------------------------
def heldout_sample_pool(otu_to_samples, heldout_otus):
    s = set()
    for o in heldout_otus:
        s |= otu_to_samples.get(int(o), set())
    return s

heldout_samples = heldout_sample_pool(otu_to_samples, heldout_otus_50)
pretrain_samples = set(range(N)) - heldout_samples

print(f"[POOLS] pretrain_samples (no heldout OTUs): {len(pretrain_samples)} ({len(pretrain_samples)/N:.4f})")
print(f"[POOLS] heldout_samples (>=1 heldout OTU): {len(heldout_samples)} ({len(heldout_samples)/N:.4f})")


# -----------------------------
# 2) Make a fixed query set (by size), and a support pool
# -----------------------------
def make_fixed_query_split(heldout_samples, query_size, seed=123):
    rng = random.Random(seed)
    pool = list(heldout_samples)
    rng.shuffle(pool)
    q = set(pool[:min(query_size, len(pool))])
    p = set(pool[min(query_size, len(pool)):])
    return q, p


# -----------------------------
# 3) For a given k, sample k support examples per OTU from the support pool
#     (query stays fixed; support comes only from support_pool)
# -----------------------------
def sample_support_k_per_otu(
    otu_to_samples,
    heldout_otus,
    support_pool,
    k_shot,
    seed=123,
):
    rng = random.Random(seed)

    support_idx = set()
    feasible_otus = 0
    per_otu_chosen = {}  # optional: for debugging / reproducibility

    support_pool = set(support_pool)

    for o in sorted(map(int, heldout_otus)):
        candidates = list(otu_to_samples.get(o, set()) & support_pool)
        if len(candidates) < k_shot:
            continue
        rng.shuffle(candidates)
        chosen = candidates[:k_shot]
        per_otu_chosen[o] = chosen
        support_idx |= set(chosen)
        feasible_otus += 1

    return support_idx, feasible_otus, per_otu_chosen


# -----------------------------
# 4) Explore multiple query sizes, but keep query fixed per size
#    Then within each query size, compare k-shot supports
# -----------------------------
QUERY_SIZES = [20_000, 50_000, 100_000]   # change as you like
K_SHOTS = [5, 10, 20, 50, 100]

SEED_QUERY = 999   # ensures query split reproducible
SEED_SUPPORT = 2024

rows = []
for qsize in QUERY_SIZES:
    query_set, support_pool = make_fixed_query_split(heldout_samples, query_size=qsize, seed=SEED_QUERY)

    for k in K_SHOTS:
        support_set, feasible_otus, _ = sample_support_k_per_otu(
            otu_to_samples,
            heldout_otus_50,
            support_pool,
            k_shot=k,
            seed=SEED_SUPPORT,
        )

        rows.append({
            "query_size": int(len(query_set)),          # exact (in case pool smaller)
            "support_pool_size": int(len(support_pool)),
            "k_shot": int(k),
            "feasible_otus": int(feasible_otus),        # OTUs with >=k candidates in support_pool
            "support_unique_samples": int(len(support_set)),
            "support_per_otu_total": int(feasible_otus * k),  # requested examples (not unique)
        })

df_support = pd.DataFrame(rows).sort_values(["query_size", "k_shot"]).reset_index(drop=True)
df_support


[HELDOUT OTUs] K=50 seed=123
preview: [109, 432, 1391, 2500, 2876, 3401, 3431, 4579, 5713, 5741]
[POOLS] pretrain_samples (no heldout OTUs): 1561830 (0.8506)
[POOLS] heldout_samples (>=1 heldout OTU): 274420 (0.1494)


Unnamed: 0,query_size,support_pool_size,k_shot,feasible_otus,support_unique_samples,support_per_otu_total
0,20000,254420,5,50,250,250
1,20000,254420,10,50,499,500
2,20000,254420,20,50,997,1000
3,20000,254420,50,50,2492,2500
4,20000,254420,100,48,4768,4800
5,50000,224420,5,50,250,250
6,50000,224420,10,50,500,500
7,50000,224420,20,50,1000,1000
8,50000,224420,50,50,2489,2500
9,50000,224420,100,48,4763,4800


In [16]:
def pretty_print_support_summary(df):
    for qsize in sorted(df["query_size"].unique()):
        sub = df[df["query_size"] == qsize]

        print("=" * 90)
        print(f"QUERY SET SIZE (fixed): {qsize:,} samples")
        print(f"Available support pool size: {int(sub['support_pool_size'].iloc[0]):,} samples")
        print("-" * 90)
        print(
            f"{'k support / OTU':>16} | "
            f"{'OTUs with ≥k samples':>20} | "
            f"{'unique support samples':>24}"
        )
        print("-" * 90)

        for _, r in sub.iterrows():
            print(
                f"{r['k_shot']:>16} | "
                f"{r['feasible_otus']:>20} | "
                f"{r['support_unique_samples']:>24,}"
            )

        print("=" * 90)
        print()

pretty_print_support_summary(df_support)


QUERY SET SIZE (fixed): 20,000 samples
Available support pool size: 254,420 samples
------------------------------------------------------------------------------------------
 k support / OTU | OTUs with ≥k samples |   unique support samples
------------------------------------------------------------------------------------------
               5 |                   50 |                      250
              10 |                   50 |                      499
              20 |                   50 |                      997
              50 |                   50 |                    2,492
             100 |                   48 |                    4,768

QUERY SET SIZE (fixed): 50,000 samples
Available support pool size: 224,420 samples
------------------------------------------------------------------------------------------
 k support / OTU | OTUs with ≥k samples |   unique support samples
-----------------------------------------------------------------------------------------