# ============================================================
# Microbe Atlas preprocessing for OTU+Taxa foundation training
# ============================================================
Notebook: otu-taxa-foundation/notebooks/preprocess/01_explore_raw_data.ipynb

Output root (processed data is NOT stored inside the repo):
/home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training


In [1]:


import os, json, re
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import Counter
from typing import List, Tuple, Optional, Dict


# ----------------------------
# config
# ----------------------------

In [2]:

# Single source of truth for this notebook: raw input locations, the
# preprocessing policy, and the tags that make up the output folder path.
CFG = {
    # ---- raw inputs ----
    "filtered_counts_path": "/home/hernan_melmoth/Documents/phd_work/Bio_ontology/MicrobeAtlas/level_97/samples-otus-97.filtered-5-reads.json",
    "pred_sintax_path": "/home/hernan_melmoth/Documents/phd_work/Bio_ontology/MicrobeAtlas/level_97/taxonomy_reference/silva-138.2/vsearch_incomplete_species_fromOTUS_predictions/repseqs_sintax_incomplete.txt",

    # ---- preprocessing policy ----
    "keep_fraction": 0.999,  # keep OTUs covering the top 99.9% of cumulative abundance
    "taxonomy_policy": "contiguous_valid_prefix",

    # ---- processed data root (outside repo) ----
    "processed_root": "/home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training",

    # ---- naming tags that become nested output folders ----
    "level_tag": "level_97",
    "silva_tag": "silva-138.2",
    "taxonomy_tag": "incomplete_silva_sintax",  # training version
    "dataset_tag": "dataset_full_top999",       # describes scope + policy

    # ---- placeholder taxonomy token for OTUs without usable taxonomy ----
    "append_unk_kingdom": True,
    "unk_kingdom_token": "k:UNK",
}

# Output folder = processed_root/level_tag/silva_tag/taxonomy_tag/dataset_tag
OUT_DIR = os.path.join(
    *[
        CFG[part]
        for part in ("processed_root", "level_tag", "silva_tag", "taxonomy_tag", "dataset_tag")
    ]
)

print("Planned output directory:\n", OUT_DIR)
os.makedirs(OUT_DIR, exist_ok=True)

Planned output directory:
 /home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training/level_97/silva-138.2/incomplete_silva_sintax/dataset_full_top999


# ----------------------------
# Loader functions
# ----------------------------

In [3]:
def load_sintax_table(path: str) -> pd.DataFrame:
    """
    Load SINTAX output into a 2-column dataframe:
      otu_id | taxonomy

    Tries the (0,3) column layout first (column 3 is used as the taxonomy
    string as-is). If that read fails, falls back to columns (0,1) and derives
    the taxonomy by dropping the "(confidence)" suffix from every
    comma-separated "rank:name" entry of the raw column.

    Args:
        path: Tab-separated SINTAX file without a header row.

    Returns:
        A new dataframe with string columns ``otu_id`` and ``taxonomy``;
        missing taxonomy values become the empty string.

    Raises:
        OSError: If the file cannot be opened. I/O errors are not layout
            problems, so they are re-raised without trying the fallback.
    """
    read_opts = {"sep": "\t", "header": None, "engine": "python", "dtype": str}

    try:
        df = pd.read_csv(path, usecols=[0, 3], names=["otu_id", "taxonomy"], **read_opts)
    except OSError:
        raise
    except Exception:
        # Deliberate best-effort: any parsing failure of the 4-column layout
        # triggers the raw-string fallback.
        raw = pd.read_csv(path, usecols=[0, 1], names=["otu_id", "raw_sintax"], **read_opts)

        def drop_conf(s) -> str:
            """Strip "(conf)" suffixes; entries without a ':' are discarded."""
            if pd.isna(s):
                return ""
            kept = [
                part.split("(", 1)[0].strip()
                for part in s.strip().rstrip(";").split(",")
                if ":" in part
            ]
            return ",".join(kept)

        # Build a fresh frame instead of slicing `raw`, so the column
        # assignments below never write into a view of another dataframe.
        df = pd.DataFrame(
            {
                "otu_id": raw["otu_id"],
                "taxonomy": raw["raw_sintax"].apply(drop_conf),
            }
        )

    df["otu_id"] = df["otu_id"].astype(str)
    df["taxonomy"] = df["taxonomy"].fillna("").astype(str)
    return df

# Load the SINTAX predictions from the path configured in CFG; `pred_df` is
# later handed to the dataset builder.
pred_df = load_sintax_table(CFG["pred_sintax_path"])
print(pred_df.shape)  # (n_rows, n_columns)
pred_df.head()  # last expression of the cell: shows the first rows in the notebook


(111870, 2)


Unnamed: 0,otu_id,taxonomy
0,90_19327;96_77520;97_100055,
1,90_18588;96_76070;97_104571,"k:Bacteria,p:Cyanobacteria,c:Chloroplast,o:Sol..."
2,90_22156;96_86043;97_110485,"k:Bacteria,p:Cyanobacteria,c:Chloroplast,o:Sol..."
3,90_20463;96_79800;97_102794,
4,90_17477;96_14804;97_18077,"k:Bacteria,p:Bacteroidetes,c:Cytophagia,o:Cyto..."


In [4]:
# ============================================================
# Taxonomy parsing: enforce contiguous + valid prefix
# ============================================================

# Rank prefixes in the order the taxonomy path is walked (k first, s last).
RANKS: List[str] = [
    "k",
    "p",
    "c",
    "o",
    "f",
    "g",
    "s",
]
# Rank prefix -> position of that rank inside RANKS.
RANK_TO_IDX: Dict[str, int] = dict(zip(RANKS, range(len(RANKS))))
# One parenthesised group at the very end of a string (plus surrounding
# whitespace), e.g. a trailing "(0.98)".
_CONF_TAIL_RE = re.compile(r"\s*\([^)]*\)\s*$")

def strip_confidence(name: str) -> str:
    """
    Return *name* without a trailing "(confidence)" group, trailing ';'
    characters, or surrounding whitespace.

    Non-string input yields "".

    The confidence pattern is anchored at end-of-string, so a trailing ';'
    has to be removed *before* the pattern is applied; otherwise an input such
    as "Bacteria(1.00);" would keep its "(1.00)" suffix. The ';' strip is
    repeated afterwards to also cover inputs like "Bacteria;(1.00)".
    """
    if not isinstance(name, str):
        return ""
    without_semicolon = name.strip().rstrip(";")
    without_conf = _CONF_TAIL_RE.sub("", without_semicolon)
    return without_conf.strip().rstrip(";").strip()

def split_tax_path(tax_str: str) -> List[str]:
    """
    Split a taxonomy string into its non-empty, whitespace-trimmed tokens.

    Both ',' and ';' act as separators (a trailing ';' is ignored).
    Anything that is not a non-empty string yields an empty list.
    """
    if not (isinstance(tax_str, str) and tax_str):
        return []

    normalized = tax_str.strip().rstrip(";").replace(";", ",")

    pieces: List[str] = []
    for raw_piece in normalized.split(","):
        piece = raw_piece.strip()
        if piece:
            pieces.append(piece)
    return pieces

def parse_token(tok: str) -> Tuple[Optional[str], str]:
    """
    Split a "rank:name" token into ``(rank, name)``.

    The rank is trimmed and lower-cased; the name is passed through
    ``strip_confidence``. Returns ``(None, "")`` when *tok* is not a string or
    has no ':', and ``(None, name)`` when the rank is not one of ``RANKS``.
    """
    if not isinstance(tok, str):
        return None, ""

    # Split on the first ':' only; the name itself may contain further colons.
    rank_part, colon, name_part = tok.partition(":")
    if not colon:
        return None, ""

    rank = rank_part.strip().lower()
    taxon = strip_confidence(name_part)
    return (rank if rank in RANKS else None), taxon

def is_unidentified_name(name: str) -> bool:
    """
    Tell whether *name* is one of the placeholder labels
    "unidentified", "unknown" or "__unknown".

    The comparison ignores case, surrounding whitespace and surrounding
    single/double quotes. Empty or None input is not a placeholder.
    """
    placeholders = ("unidentified", "unknown", "__unknown")
    text = name if name else ""
    unquoted = text.strip().strip("'\"")
    return unquoted.lower() in placeholders

def is_valid_token(tok: str) -> bool:
    """
    A token is valid when it parses to a known rank, has a non-empty name,
    and that name is not a placeholder such as "unknown".
    """
    rank, taxon = parse_token(tok)
    return rank is not None and taxon != "" and not is_unidentified_name(taxon)

def pick_chain(tokens: List[str]) -> List[Optional[str]]:
    """
    Select at most one token per rank, walking ``RANKS`` in order.

    For each rank the first string token at or after the current scan position
    whose text starts with "<rank>:" is taken (the match is a literal,
    case-sensitive prefix test). The scan position then moves just past that
    token, so the selected tokens keep their input order. Ranks without a
    match contribute ``None`` and leave the scan position unchanged.

    Returns a list with exactly one entry per rank in ``RANKS``.
    """
    picked: List[Optional[str]] = []
    cursor = 0
    for rank in RANKS:
        prefix = f"{rank}:"
        match_pos = next(
            (
                pos
                for pos in range(cursor, len(tokens))
                if isinstance(tokens[pos], str) and tokens[pos].startswith(prefix)
            ),
            None,
        )
        if match_pos is None:
            picked.append(None)
            continue
        picked.append(tokens[match_pos])
        cursor = match_pos + 1
    return picked

def last_contiguous_valid_token(tokens: List[str]) -> Optional[str]:
    """
    Return the deepest token of the contiguous valid prefix of *tokens*,
    or ``None`` when not even the first rank is present and valid.
    """
    valid_prefix = contiguous_chain(tokens)
    return valid_prefix[-1] if valid_prefix else None

def contiguous_chain(tokens: List[str]) -> List[str]:
    """
    Return the contiguous valid prefix of the per-rank chain of *tokens*.

    The chain from ``pick_chain`` is walked in rank order and cut at the first
    rank that is missing or whose token fails ``is_valid_token``; everything
    before the cut is returned (possibly an empty list).
    """
    valid_prefix: List[str] = []
    for candidate in pick_chain(tokens):
        if candidate is not None and is_valid_token(candidate):
            valid_prefix.append(candidate)
            continue
        break  # first gap or invalid token ends the prefix
    return valid_prefix

def token_depth(tok: str) -> Optional[int]:
    """
    Return the index of the token's rank within ``RANKS``, or ``None`` when
    the token does not parse to a known rank.
    """
    rank = parse_token(tok)[0]
    if rank is None:
        return None
    return RANK_TO_IDX[rank]

def ensure_child(mapping: dict, key: str) -> dict:
    """
    Return ``mapping[key]``, first inserting an empty dict under *key* if the
    key is absent. Existing children are returned untouched.
    """
    return mapping.setdefault(key, {})


# Build corpus function

In [5]:
# ============================================================
# Dataset builder: full corpus, top-abundance OTUs, keep OTUs
# without usable taxonomy but encode with appended k:UNK id.
# ============================================================

def _select_top_otus(counts_by_sample, keep_fraction):
    """Rank OTUs by total count over all samples and keep the top cumulative share.

    Returns ``(kept_otus, n_otus_pre)``: the kept OTU ids ordered by descending
    total count, and the number of distinct OTUs before the cut.
    """
    otu_list_all = sorted({otu for row in counts_by_sample.values() for otu in row})
    otu2idx_all = {otu: i for i, otu in enumerate(otu_list_all)}
    print(f"[INFO] Global OTU vocab (pre-cut): {len(otu_list_all)}")

    totals = np.zeros(len(otu_list_all), dtype=np.float64)
    for sid in tqdm(list(counts_by_sample.keys()), desc="Accumulating OTU totals"):
        for otu, cnt in counts_by_sample[sid].items():
            totals[otu2idx_all[otu]] += float(cnt)

    # NOTE(review): np.argsort's default sort kind is not stable, so the
    # relative order of OTUs with identical totals is implementation-defined.
    order_desc = np.argsort(-totals)
    cum = np.cumsum(totals[order_desc])
    total_sum = float(cum[-1]) if cum.size else 0.0
    # searchsorted gives the first position whose cumulative sum reaches the
    # target; +1 turns that position into a count that includes it.
    k_keep = int(np.searchsorted(cum, keep_fraction * total_sum)) + 1

    kept_otus = [otu_list_all[i] for i in order_desc[:k_keep]]
    print(f"[INFO] Kept OTUs (top {keep_fraction:.3%} cumulative abundance): {len(kept_otus)}")
    return kept_otus, len(otu_list_all)


def _write_json(path, obj, **dump_kwargs):
    """Serialise *obj* to *path* as UTF-8 encoded JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, **dump_kwargs)


def _render_readme(append_unk_kingdom, unk_kingdom_token, unk_k_id):
    """Build the README.md text that describes the written artifacts."""
    if append_unk_kingdom:
        vocab_suffix = f"plus an appended `{unk_kingdom_token}` token as the last id\n"
        missing_note = (
            f"- Missing/unusable taxonomy is encoded as the `{unk_kingdom_token}` id "
            f"(always last): {unk_k_id}\n"
        )
    else:
        vocab_suffix = "\n"
        missing_note = (
            "- Missing/unusable taxonomy is not encoded here (append_unk_kingdom=False); "
            "those positions are omitted from `samples.jsonl`.\n"
        )
    return (
        "# Microbe Atlas dataset artifacts (full corpus)\n\n"
        "This folder contains the artifacts consumed by the OTU–Taxa foundation model.\n\n"
        "## Files\n"
        "- `otu_vocab.json`: OTU token vocabulary (list of OTU string identifiers)\n"
        "- `taxonomy_vocab.json`: taxonomy vocabulary containing all induced tree nodes; "
        + vocab_suffix +
        "- `taxonomy_nested.json`: nested dict for manual inspection (not required by training)\n"
        "- `dropped_otus.json`: OTUs with missing/unusable taxonomy (kept in `otu_vocab.json`)\n"
        "- `samples.jsonl`: one JSON per sample: `{sample_id, otus:[...], taxa:[...]}`\n\n"
        "Notes:\n"
        + missing_note
    )


def _count_negative_tax_samples(jsonl_path, limit):
    """Count samples with a negative taxonomy id among the first *limit* records."""
    bad = 0
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if i >= limit:
                break
            rec = json.loads(line)
            if any(int(t) < 0 for t in rec["taxa"]):
                bad += 1
                if bad <= 3:
                    print("[BAD] Found negative tax id in sample:", rec.get("sample_id"))
    return bad


def build_dataset_full_corpus(
    *,
    counts_by_sample: Dict[str, Dict[str, float]],
    pred_df: pd.DataFrame,
    out_dir: str,
    keep_fraction: float = 0.999,
    unk_kingdom_token: str = "k:UNK",
    append_unk_kingdom: bool = True,
    sanity_check_samples: int = 2000,
) -> Dict[str, object]:
    """
    Build Microbe Atlas training dataset (full corpus) with the artifact contract:

      - otu_vocab.json        (list[str]; kept OTUs, ordered by descending total count)
      - taxonomy_vocab.json   (list[str] of ALL tree nodes induced by valid contiguous prefixes,
                              plus optionally one appended UNK kingdom token at the end)
      - samples.jsonl         (one JSON per sample: {sample_id, otus:[int], taxa:[int]},
                              entries sorted by descending count within the sample)
                              taxonomy ids are NEVER -1; missing taxonomy => unk_k_id
      - taxonomy_nested.json  (inspection-only nested dict)
      - dropped_otus.json     (OTUs with missing/unusable taxonomy; they are KEPT in otu_vocab)
      - config.json, README.md

    Taxonomy policy:
      - Parse SINTAX taxonomy string
      - Take the contiguous VALID prefix in rank order (k→…→s)
      - Label = last token of that prefix (deepest contiguous valid token)
      - Nodes in taxonomy_vocab = union of all tokens in the contiguous valid prefixes
      - OTUs with no usable taxonomy are kept; their sample positions use unk_k_id

    Args:
        counts_by_sample: sample_id -> {otu_id -> count}.
        pred_df: dataframe with ``otu_id`` and ``taxonomy`` columns; the first
            row per ``otu_id`` wins.
        out_dir: destination folder (created if missing).
        keep_fraction: cumulative-abundance share used to select OTUs.
        unk_kingdom_token: placeholder token for OTUs without usable taxonomy.
        append_unk_kingdom: when True the placeholder is appended as the last
            taxonomy id and used for missing taxonomy. When False there is no
            id to assign, so positions of OTUs without usable taxonomy are
            omitted from samples.jsonl (previously this case failed with a
            TypeError while writing samples).
        sanity_check_samples: number of leading samples.jsonl records that are
            re-read to verify no taxonomy id is negative.

    Returns:
        A dict of summary statistics (output dir, sample/OTU/taxonomy counts,
        number of missing-taxonomy OTUs and positions, and the UNK id).
    """
    os.makedirs(out_dir, exist_ok=True)

    # ---------------------------------------
    # 1) Aggregate OTU totals over ALL samples and apply the abundance cut
    # ---------------------------------------
    all_sample_ids = list(counts_by_sample.keys())
    print(f"[INFO] Total samples (full corpus): {len(all_sample_ids)}")
    kept_otu_list, n_otus_pre = _select_top_otus(counts_by_sample, keep_fraction)

    # --------------------------------------------------
    # 2) Map OTUs to taxonomy labels, collect ALL nodes, and build the
    #    inspection-only nested tree in the same pass (one parse per OTU)
    # --------------------------------------------------
    pred_df2 = pred_df.drop_duplicates("otu_id")
    otu_to_taxstr = dict(
        zip(pred_df2["otu_id"].astype(str), pred_df2["taxonomy"].fillna("").astype(str))
    )

    taxonomy_nodes = set()
    taxonomy_nested: dict = {}
    otu2tax: Dict[str, str] = {}  # only OTUs with a valid contiguous label
    n_truncated = 0
    depth_counter = Counter()

    for otu in tqdm(kept_otu_list, desc="Mapping OTUs to taxonomy"):
        toks = split_tax_path(otu_to_taxstr.get(otu, ""))
        prefix = contiguous_chain(toks)
        if not prefix:
            continue  # no usable taxonomy; recorded in dropped_otus below

        label = prefix[-1]  # deepest contiguous valid token
        # The label differs from the raw last token when the path was cut short.
        if label != toks[-1]:
            n_truncated += 1

        otu2tax[otu] = label

        depth = token_depth(label)
        if depth is not None:
            depth_counter[RANKS[depth]] += 1

        taxonomy_nodes.update(prefix)

        cur = taxonomy_nested
        for tok in prefix:
            cur = ensure_child(cur, tok)

    print(f"[INFO] Contiguity/invalid truncations applied to: {n_truncated}")
    print(f"[INFO] Label depth distribution: {dict(depth_counter)}")

    # -----------------------
    # 3) Record OTUs missing taxonomy (but DO NOT drop them from otu_vocab)
    # -----------------------
    dropped_otus = [otu for otu in kept_otu_list if otu not in otu2tax]
    print(f"[INFO] OTUs with missing/unusable taxonomy (kept): {len(dropped_otus)}")

    # -----------------------
    # 4) Build vocabularies
    # -----------------------
    taxonomy_vocab = sorted(taxonomy_nodes)  # base nodes only

    unk_k_id = None
    if append_unk_kingdom:
        # Enforce "UNK kingdom token is last", even if the data already has it.
        taxonomy_vocab = [t for t in taxonomy_vocab if t != unk_kingdom_token]
        taxonomy_vocab.append(unk_kingdom_token)
        unk_k_id = len(taxonomy_vocab) - 1

    tax2idx = {t: i for i, t in enumerate(taxonomy_vocab)}
    otu2idx_kept = {otu: i for i, otu in enumerate(kept_otu_list)}  # includes missing-tax OTUs

    print(f"[INFO] taxonomy_vocab size: {len(taxonomy_vocab)} (append_unk_kingdom={append_unk_kingdom})")
    if append_unk_kingdom:
        print(f"[INFO] UNK kingdom token id: {unk_k_id} ({taxonomy_vocab[unk_k_id]})")
    print(f"[INFO] otu_vocab size (kept): {len(kept_otu_list)}")

    # -----------------------
    # 5) Write artifacts (explicit UTF-8: the README and the nested taxonomy
    #    may contain non-ASCII characters)
    # -----------------------
    _write_json(os.path.join(out_dir, "otu_vocab.json"), kept_otu_list)
    _write_json(os.path.join(out_dir, "taxonomy_vocab.json"), taxonomy_vocab)
    _write_json(
        os.path.join(out_dir, "taxonomy_nested.json"),
        taxonomy_nested,
        indent=2,
        ensure_ascii=False,
    )
    _write_json(os.path.join(out_dir, "dropped_otus.json"), dropped_otus, indent=2)

    if append_unk_kingdom:
        note = (
            "OTUs with missing/unusable taxonomy are kept; their taxa use the appended "
            f"{unk_kingdom_token} id (last token)."
        )
    else:
        note = (
            "OTUs with missing/unusable taxonomy are kept in otu_vocab; "
            "their positions are omitted from samples.jsonl."
        )
    config_obj = {
        "keep_fraction": keep_fraction,
        "taxonomy_policy": "contiguous_valid_prefix",
        "append_unk_kingdom": append_unk_kingdom,
        "unk_kingdom_token": unk_kingdom_token,
        "unk_kingdom_id": unk_k_id,
        "note": note,
    }
    # If a notebook-level CFG exists, snapshot the raw input paths as well.
    if "CFG" in globals():
        config_obj.update({
            "filtered_counts_path": CFG.get("filtered_counts_path"),
            "pred_sintax_path": CFG.get("pred_sintax_path"),
        })
    _write_json(os.path.join(out_dir, "config.json"), config_obj, indent=2)

    with open(os.path.join(out_dir, "README.md"), "w", encoding="utf-8") as f:
        f.write(_render_readme(append_unk_kingdom, unk_kingdom_token, unk_k_id))

    # -----------------------
    # 6) Write samples.jsonl
    #     taxonomy id = real tax_id if available else unk_k_id
    # -----------------------
    jsonl_path = os.path.join(out_dir, "samples.jsonl")
    n_written = 0
    n_empty = 0
    n_tax_missing_positions = 0

    with open(jsonl_path, "w", encoding="utf-8") as fout:
        for sid in tqdm(all_sample_ids, desc="Writing samples.jsonl"):
            triplets = []
            for otu, cnt in counts_by_sample[sid].items():
                j = otu2idx_kept.get(otu)
                if j is None:
                    continue  # OTU not in kept vocab

                tok = otu2tax.get(otu)
                if tok is None:
                    n_tax_missing_positions += 1
                    if unk_k_id is None:
                        # append_unk_kingdom=False: no id exists for missing
                        # taxonomy, so the position is left out (never -1).
                        continue
                    tax_id = unk_k_id
                else:
                    tax_id = tax2idx[tok]  # label was added to taxonomy_nodes above

                triplets.append((j, float(cnt), int(tax_id)))

            if not triplets:
                n_empty += 1
                continue

            # sort by abundance desc (list.sort is stable, ties keep input order)
            triplets.sort(key=lambda x: -x[1])
            otus_idx = [j for (j, _, _) in triplets]
            taxa_idx = [t for (_, _, t) in triplets]

            fout.write(json.dumps({"sample_id": sid, "otus": otus_idx, "taxa": taxa_idx}) + "\n")
            n_written += 1

    print(f"[SAVE] samples.jsonl: wrote={n_written}, skipped_empty={n_empty}")
    if append_unk_kingdom:
        print(f"[INFO] Missing-taxonomy positions encoded as unk_k_id: {n_tax_missing_positions}")
    else:
        print(
            "[INFO] Missing-taxonomy positions omitted (append_unk_kingdom=False): "
            f"{n_tax_missing_positions}"
        )
    print(f"[OK] out_dir = {out_dir}")

    # -----------------------
    # 7) Quick sanity checks
    # -----------------------
    if append_unk_kingdom:
        assert taxonomy_vocab[-1] == unk_kingdom_token, "UNK token is not last in taxonomy_vocab."
        assert unk_k_id == len(taxonomy_vocab) - 1, "UNK id is not the last index."

    bad = _count_negative_tax_samples(jsonl_path, sanity_check_samples)
    if bad == 0:
        print(f"[CHECK] No negative tax ids found in first {sanity_check_samples} samples.")
    else:
        print(f"[WARNING] Found {bad} samples with negative tax ids in first {sanity_check_samples} samples.")

    return {
        "out_dir": out_dir,
        "n_samples_total": len(all_sample_ids),
        "n_samples_written": n_written,
        "n_samples_empty": n_empty,
        "n_otus_pre": n_otus_pre,
        "n_otus_kept": len(kept_otu_list),
        "n_tax_nodes": len(taxonomy_vocab),
        "n_otus_missing_taxonomy": len(dropped_otus),
        "n_positions_missing_taxonomy": n_tax_missing_positions,
        "unk_kingdom_id": unk_k_id,
    }

In [6]:
# 1) Load counts (big dict-of-dicts: sample_id -> {otu_id -> count})
with open(CFG["filtered_counts_path"], "r", encoding="utf-8") as f:
    counts_by_sample = json.load(f)

# 2) Run builder
#    The UNK settings are passed explicitly so that the values declared in CFG
#    are the ones actually used, instead of the builder's own defaults.
stats = build_dataset_full_corpus(
    counts_by_sample=counts_by_sample,
    pred_df=pred_df,
    out_dir=OUT_DIR,
    keep_fraction=CFG["keep_fraction"],
    unk_kingdom_token=CFG["unk_kingdom_token"],
    append_unk_kingdom=CFG["append_unk_kingdom"],
)

stats


[INFO] Total samples (full corpus): 1836250
[INFO] Global OTU vocab (pre-cut): 99335


Accumulating OTU totals: 100%|██████████| 1836250/1836250 [00:59<00:00, 30937.25it/s]


[INFO] Kept OTUs (top 99.900% cumulative abundance): 62200


Mapping OTUs to taxonomy: 100%|██████████| 62200/62200 [00:00<00:00, 66334.62it/s]


[INFO] Contiguity/invalid truncations applied to: 22
[INFO] Label depth distribution: {'g': 27020, 'f': 13080, 'o': 9387, 'c': 3445, 's': 3009, 'p': 2936, 'k': 100}
[INFO] OTUs with missing/unusable taxonomy (kept): 3223
[INFO] taxonomy_vocab size: 6929 (append_unk_kingdom=True)
[INFO] UNK kingdom token id: 6928 (k:UNK)
[INFO] otu_vocab size (kept): 62200


Building taxonomy_nested: 100%|██████████| 62200/62200 [00:00<00:00, 116121.42it/s]
Writing samples.jsonl: 100%|██████████| 1836250/1836250 [02:26<00:00, 12521.32it/s]

[SAVE] samples.jsonl: wrote=1836250, skipped_empty=0
[INFO] Missing-taxonomy positions encoded as unk_k_id: 3094452
[OK] out_dir = /home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training/level_97/silva-138.2/incomplete_silva_sintax/dataset_full_top999
[CHECK] No negative tax ids found in first 2000 samples.





{'out_dir': '/home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training/level_97/silva-138.2/incomplete_silva_sintax/dataset_full_top999',
 'n_samples_total': 1836250,
 'n_samples_written': 1836250,
 'n_samples_empty': 0,
 'n_otus_pre': 99335,
 'n_otus_kept': 62200,
 'n_tax_nodes': 6929,
 'n_otus_missing_taxonomy': 3223,
 'n_positions_missing_taxonomy': 3094452,
 'unk_kingdom_id': 6928}

In [8]:
# Reload the written artifacts and verify that the first sample only uses ids
# that exist in both vocabularies. Files are opened with `with` so the handles
# are closed deterministically.
with open(os.path.join(OUT_DIR, "otu_vocab.json"), encoding="utf-8") as f:
    otu_vocab = json.load(f)
with open(os.path.join(OUT_DIR, "taxonomy_vocab.json"), encoding="utf-8") as f:
    tax_vocab = json.load(f)

with open(os.path.join(OUT_DIR, "samples.jsonl"), encoding="utf-8") as f:
    first = json.loads(next(f))

assert len(first["otus"]) == len(first["taxa"])
# Every id must be a valid index: 0 <= id < vocabulary size.
assert min(first["otus"]) >= 0
assert max(first["otus"]) < len(otu_vocab)
assert min(first["taxa"]) >= 0
assert max(first["taxa"]) < len(tax_vocab)

print("OK: sample ids align with otu_vocab and taxonomy_vocab.")
print("Example sample_id:", first["sample_id"])
print("First 10 OTU ids:", first["otus"][:10])
print("First 10 TAXA ids:", first["taxa"][:10])


OK: sample ids align with otu_vocab and taxonomy_vocab.
Example sample_id: SRR4892887.SRS1780364
First 10 OTU ids: [3, 22, 516, 31, 92, 151, 24, 11, 185, 171]
First 10 TAXA ids: [1197, 1197, 1197, 1040, 1197, 1040, 1197, 1197, 2521, 2947]
