# Descendant matrix builder (taxonomy ancestor/descendant closure)

we build:
   - parent.npy / depth.npy / root_id.npy  (optional but useful)
   - descendant_matrix.npy                (M: ancestor/descendant closure)

 Matrix definition (same as your old ETE3 version):
   M[i, j] = 1  iff  node j is node i OR a descendant of node i
 

In [1]:
import json
import numpy as np
from pathlib import Path
from collections import deque
from tqdm import tqdm
# --------------------------
# Paths
# --------------------------
DATASET_DIR = Path(
    "/home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training"
    "/level_97/silva-138.2/incomplete_silva_sintax/dataset_full_top999"
)

VOCAB_PATH  = DATASET_DIR / "taxonomy_vocab.json"
NESTED_PATH = DATASET_DIR / "taxonomy_nested.json"

assert VOCAB_PATH.exists(), f"Missing: {VOCAB_PATH}"
assert NESTED_PATH.exists(), f"Missing: {NESTED_PATH}"

# Save derived matrices into a subfolder (recommended)
TREE_DIR = DATASET_DIR / "tree_artifacts"
TREE_DIR.mkdir(parents=True, exist_ok=True)

print("DATASET_DIR =", DATASET_DIR)
print("TREE_DIR    =", TREE_DIR)

# --------------------------
# Load vocab + nested tree
# --------------------------
names = json.load(open(VOCAB_PATH, "r"))
name2idx = {nm: i for i, nm in enumerate(names)}
N = len(names)

taxonomy_nested = json.load(open(NESTED_PATH, "r"))

print("N taxonomy nodes =", N)
print("Top-level clades:", list(taxonomy_nested.keys())[:10])

DATASET_DIR = /home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training/level_97/silva-138.2/incomplete_silva_sintax/dataset_full_top999
TREE_DIR    = /home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training/level_97/silva-138.2/incomplete_silva_sintax/dataset_full_top999/tree_artifacts
N taxonomy nodes = 6928
Top-level clades: ['k:Bacteria', 'k:Archaea']


In [2]:

# ============================================================
# 1) Build parent pointers and children adjacency from nested dict
# ============================================================

def build_parent_children_from_nested(taxonomy_nested, name2idx):
    """
    Build parent pointers and children adjacency using taxonomy_nested.json.

    parent[i] = parent index of node i, or -1 if i is a top-level clade
    children[i] = list of child indices
    """
    N = len(name2idx)
    parent = np.full(N, -1, dtype=np.int32)
    children = [[] for _ in range(N)]

    def dfs(subtree: dict, parent_name=None):
        for node_name, child_dict in subtree.items():
            if node_name not in name2idx:
                # Should not happen if nested tree was built from vocab;
                # keep simple and skip.
                continue

            i = name2idx[node_name]

            if parent_name is not None and parent_name in name2idx:
                p = name2idx[parent_name]
                parent[i] = p
                children[p].append(i)

            if isinstance(child_dict, dict) and child_dict:
                dfs(child_dict, parent_name=node_name)

    # forest: multiple top-level clades allowed
    dfs(taxonomy_nested, parent_name=None)
    return parent, children


parent, children = build_parent_children_from_nested(taxonomy_nested, name2idx)
roots = np.where(parent == -1)[0].tolist()

print("Number of roots (top-level clades) =", len(roots))
print("First 10 roots:", [names[i] for i in roots[:10]])


# ============================================================
# 2) Compute depth (optional but convenient for debugging)
# ============================================================

def compute_depth(children, roots):
    depth = np.full(len(children), -1, dtype=np.int32)
    q = deque()
    for r in roots:
        depth[r] = 0
        q.append(r)
    while q:
        u = q.popleft()
        for v in children[u]:
            depth[v] = depth[u] + 1
            q.append(v)
    return depth


depth = compute_depth(children, roots)

if (depth < 0).any():
    bad = np.where(depth < 0)[0]
    print("WARNING: some nodes unreachable (unexpected). Count:", len(bad))
    print("Examples:", [names[i] for i in bad[:10]])
else:
    print("Depth computed for all nodes.")
    print("Depth min/max:", int(depth.min()), int(depth.max()))


# ============================================================
# 3) Build descendant-closure matrix M
#    M[i, j] = 1 iff j is i or a descendant of i (same as old ETE3 code)
# ============================================================

def euler_tour_times(children, roots):
    """
    Compute entry/exit times (tin/tout) for each node.
    In a rooted tree/forest:
      i is ancestor of j  <=>  tin[i] <= tin[j] and tout[j] <= tout[i]
    """
    N = len(children)
    tin = np.full(N, -1, dtype=np.int32)
    tout = np.full(N, -1, dtype=np.int32)
    t = 0

    def dfs(u):
        nonlocal t
        tin[u] = t
        t += 1
        for v in children[u]:
            dfs(v)
        tout[u] = t
        t += 1

    for r in roots:
        dfs(r)

    return tin, tout


tin, tout = euler_tour_times(children, roots)

def build_descendant_matrix_dense(tin, tout, dtype=np.uint8):
    """
    Dense descendant-closure matrix:
      M[i, j] = 1 iff i is ancestor of j (including i==j)
    """
    N = len(tin)
    M = np.zeros((N, N), dtype=dtype)
    for i in tqdm(range(N), desc="Building descendant matrix (dense)"):
        # vectorized ancestor test across all j
        M[i, :] = ((tin[i] <= tin) & (tout <= tout[i])).astype(dtype)
    return M


M = build_descendant_matrix_dense(tin, tout, dtype=np.uint8)

print("Descendant matrix shape:", M.shape, "dtype:", M.dtype)
print("Diagonal sum (should be N):", int(np.trace(M)))
print("Nonzero entries:", int(M.sum()))


# ============================================================
# 4) Save artifacts
# ============================================================

np.save(TREE_DIR / "parent.npy", parent)
np.save(TREE_DIR / "depth.npy", depth)
np.save(TREE_DIR / "tin.npy", tin)
np.save(TREE_DIR / "tout.npy", tout)

desc_path = TREE_DIR / "descendant_matrix.npy"
np.save(desc_path, M)
print("Saved:", desc_path)

# Minimal README for this folder (optional but recommended)
readme_path = TREE_DIR / "README_descendants.md"
with open(readme_path, "w") as f:
    f.write(
        "# Descendant matrix (ancestor/descendant closure)\n\n"
        "Files in this folder are derived from taxonomy_nested.json and taxonomy_vocab.json.\n\n"
        "## descendant_matrix.npy\n"
        "Let `names = taxonomy_vocab.json` (length N).\n"
        "The matrix `M` has shape [N, N] and is defined as:\n"
        "`M[i, j] = 1` iff node `names[j]` is equal to or a descendant of node `names[i]`.\n\n"
        "This matches the old ETE3 implementation, but uses an Euler tour (tin/tout) for speed.\n"
    )
print("Saved:", readme_path)



Number of roots (top-level clades) = 2
First 10 roots: ['k:Archaea', 'k:Bacteria']
Depth computed for all nodes.
Depth min/max: 0 6


Building descendant matrix (dense): 100%|██████████| 6928/6928 [00:00<00:00, 133075.67it/s]

Descendant matrix shape: (6928, 6928) dtype: uint8
Diagonal sum (should be N): 6928
Nonzero entries: 41663
Saved: /home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training/level_97/silva-138.2/incomplete_silva_sintax/dataset_full_top999/tree_artifacts/descendant_matrix.npy
Saved: /home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training/level_97/silva-138.2/incomplete_silva_sintax/dataset_full_top999/tree_artifacts/README_descendants.md





# testing decendant matrix

In [4]:


def show_descendants(label, names, M_tensor, max_print=50):
    """
    Print all descendants of taxonomy label using the descendant matrix M.
    - label: string (e.g., 'f:Lachnospiraceae')
    - names: vocab list (taxonomy_vocab.json)
    - M_tensor: descendant closure matrix [N, N] (numpy array or torch tensor)
    """
    # handle numpy or torch
    if hasattr(M_tensor, "cpu"):  # torch tensor
        row = M_tensor[names.index(label)].cpu().numpy()
    else:  # numpy
        row = M_tensor[names.index(label)]

    descendant_indices = np.where(row == 1)[0]
    descendant_names = [names[j] for j in descendant_indices]

    print(f"\nDESCENDANTS of '{label}' (including itself):")
    print(f"Total: {len(descendant_names)}")

    for nm in descendant_names[:max_print]:
        print("  -", nm)

    if len(descendant_names) > max_print:
        print(f"... (showing first {max_print} of {len(descendant_names)})")


# Example inspections (edit freely)
#show_descendants("f:Lachnospiraceae", names, M, max_print=50)
show_descendants("g:Roseburia", names, M, max_print=50)
#show_descendants("f:Veillonellaceae", names, M, max_print=20)



DESCENDANTS of 'g:Roseburia' (including itself):
Total: 3
  - g:Roseburia
  - s:Lachnospiraceae_bacterium_feline_oral_taxon_021
  - s:Roseburia_sp._499


# Extending the taxonomy tree with per-rank UNK nodes

This section augments the base taxonomy tree with one UNK node per rank 

(k:UNK, p:UNK, ..., s:UNK) in order to support per-rank taxonomy prediction

and hierarchical regularization in the OTU+Taxa foundation model.

In [5]:
# --------------------------
# Paths (same tree_artifacts directory)
# --------------------------
TREE_DIR = DATASET_DIR / "tree_artifacts"

BASE_VOCAB_PATH = DATASET_DIR / "taxonomy_vocab.json"
BASE_M_PATH     = TREE_DIR / "descendant_matrix.npy"

NEW_VOCAB_PATH  = TREE_DIR / "taxonomy_vocab_with_unk.json"
NEW_M_PATH      = TREE_DIR / "descendant_matrix_with_unk.npy"
RANK_IDX_PATH   = TREE_DIR / "rank_idx.npy"

# --------------------------
# Load base vocab and matrix
# --------------------------
with open(BASE_VOCAB_PATH, "r") as f:
    vocab_old = json.load(f)

M_old = np.load(BASE_M_PATH)
T_old = len(vocab_old)

assert M_old.shape == (T_old, T_old), "Mismatch between vocab and descendant matrix"

print(f"Loaded base taxonomy: T_old={T_old}")


Loaded base taxonomy: T_old=6928


In [6]:

# --------------------------
# Rank mapping
# --------------------------
RANK_CHAR_TO_IDX = {"k":0, "p":1, "c":2, "o":3, "f":4, "g":5, "s":6}
R = 7

rank_idx_old = np.empty(T_old, dtype=np.int64)
for i, name in enumerate(vocab_old):
    c = name[0].lower()
    if c not in RANK_CHAR_TO_IDX:
        raise ValueError(f"Cannot infer rank from token: {name}")
    rank_idx_old[i] = RANK_CHAR_TO_IDX[c]

# --------------------------
# Extend vocabulary
# --------------------------
unk_tokens = [f"{r}:UNK" for r in ["k","p","c","o","f","g","s"]]
vocab_new = vocab_old + unk_tokens
T_new = len(vocab_new)

with open(NEW_VOCAB_PATH, "w") as f:
    json.dump(vocab_new, f, indent=2)

print(f"Extended vocab saved: T_new={T_new}")

# --------------------------
# Build extended descendant matrix
# --------------------------
M_new = np.zeros((T_new, T_new), dtype=M_old.dtype)
M_new[:T_old, :T_old] = M_old

# Build rank_idx for extended vocab
rank_idx_new = np.empty(T_new, dtype=np.int64)
rank_idx_new[:T_old] = rank_idx_old
for r in range(R):
    rank_idx_new[T_old + r] = r

# 1) Link each UNK_r as child of ALL real nodes at rank r-1
for r in range(1, R):
    unk_id = T_old + r
    parents = np.where(rank_idx_old == (r - 1))[0]
    if parents.size == 0:
        print(f"[WARN] No parents found at rank {r-1} for UNK_{r}")
        continue
    M_new[parents, unk_id] = 1

# 2) Chain UNKs: UNK_k → UNK_p → ... → UNK_s
for r in range(R - 1):
    parent_unk = T_old + r
    child_unk  = T_old + (r + 1)
    M_new[parent_unk, child_unk] = 1

# 3) Closure: self-descendants
np.fill_diagonal(M_new, 1)

# --------------------------
# Save artifacts
# --------------------------
np.save(NEW_M_PATH, M_new)
np.save(RANK_IDX_PATH, rank_idx_new)

print("Saved UNK-extended hierarchy artifacts:")
print(" -", NEW_VOCAB_PATH)
print(" -", NEW_M_PATH)
print(" -", RANK_IDX_PATH)

Extended vocab saved: T_new=6935
Saved UNK-extended hierarchy artifacts:
 - /home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training/level_97/silva-138.2/incomplete_silva_sintax/dataset_full_top999/tree_artifacts/taxonomy_vocab_with_unk.json
 - /home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training/level_97/silva-138.2/incomplete_silva_sintax/dataset_full_top999/tree_artifacts/descendant_matrix_with_unk.npy
 - /home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training/level_97/silva-138.2/incomplete_silva_sintax/dataset_full_top999/tree_artifacts/rank_idx.npy


In [None]:
# Load the extended vocab and matrix
vocab_new = json.load(open(NEW_VOCAB_PATH))
M_new = np.load(NEW_M_PATH)

T_new = len(vocab_new)
T_old = T_new - 7   # since last 7 tokens are the UNKs

# Extract the 7×7 UNK submatrix
M_unk = M_new[T_old:T_new, T_old:T_new]

print("UNK tokens:")
for i, name in enumerate(vocab_new[T_old:T_new]):
    print(f"{T_old+i}: {name}")

print("\n7x7 descendant matrix for UNK tokens (rows=ancestors, cols=descendants):")
print(M_unk)


UNK tokens:
6928: k:UNK
6929: p:UNK
6930: c:UNK
6931: o:UNK
6932: f:UNK
6933: g:UNK
6934: s:UNK

7×7 descendant matrix for UNK tokens (rows=ancestors, cols=descendants):
[[1 1 0 0 0 0 0]
 [0 1 1 0 0 0 0]
 [0 0 1 1 0 0 0]
 [0 0 0 1 1 0 0]
 [0 0 0 0 1 1 0]
 [0 0 0 0 0 1 1]
 [0 0 0 0 0 0 1]]


In [8]:
# ----------------------------
# Load vocab + matrix
# ----------------------------
vocab_new = json.load(open(NEW_VOCAB_PATH))
M_new = np.load(NEW_M_PATH)

T_new = len(vocab_new)
T_old = T_new - 7   # assuming last 7 tokens are the UNKs

print(f"T_old = {T_old}, T_new = {T_new}")

# ----------------------------
# Find index of p:UNK
# ----------------------------
try:
    idx_p_unk = vocab_new.index("p:UNK")
except ValueError:
    raise ValueError("'p:UNK' not found in vocab_new. Check naming.")

print(f"Index of 'p:UNK' = {idx_p_unk}")

# ----------------------------
# Find all parents of p:UNK
# (rows where M_new[parent, idx_p_unk] == 1)
# ----------------------------
parents_idx = np.where(M_new[:, idx_p_unk] == 1)[0]

print("\nParents of 'p:UNK' (rows where M_new[parent, 'p:UNK'] == 1):")
for p in parents_idx:
    print(f"  {p:6d}  {vocab_new[p]}")

# ----------------------------
# OPTIONAL: check specifically Archaea, Bacteria, k:UNK
# ----------------------------
names_to_check = ["k:Archaea", "k:Bacteria", "k:UNK"]
print("\nCheck specific expected parents:")
for name in names_to_check:
    if name in vocab_new:
        idx = vocab_new.index(name)
        connected = bool(M_new[idx, idx_p_unk] == 1)
        print(f"  {name:10s}  idx={idx:6d}  -> p:UNK edge: {connected}")
    else:
        print(f"  {name:10s}  NOT FOUND in vocab_new")

T_old = 6928, T_new = 6935
Index of 'p:UNK' = 6929

Parents of 'p:UNK' (rows where M_new[parent, 'p:UNK'] == 1):
    3494  k:Archaea
    3495  k:Bacteria
    6928  k:UNK
    6929  p:UNK

Check specific expected parents:
  k:Archaea   idx=  3494  -> p:UNK edge: True
  k:Bacteria  idx=  3495  -> p:UNK edge: True
  k:UNK       idx=  6928  -> p:UNK edge: True
