# Build LCA Distance Matrix (taxonomy tree)

This notebook computes an LCA-based edge distance matrix D for the induced taxonomy tree used by the OTU+Taxa foundation model.

Inputs (dataset folder):

* taxonomy_vocab.json: list of all taxonomy nodes (base nodes only). Index order defines matrix row/col order.
* taxonomy_nested.json: nested dictionary representing the induced taxonomy tree structure.

Output:

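* lca_distance_edges.npy: integer matrix D where
  $D[i,j] = \mathrm{depth}(i) + \mathrm{depth}(j) - 2\,\mathrm{depth}(\mathrm{lca}(i,j))$,
  i.e. the number of edges on the shortest path between nodes i and j. Pairs of nodes that lie in different top-level components have no common ancestor and are stored as -1.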

Notes:

We compute LCA using binary lifting from parent pointers derived from taxonomy_nested.json.



In [1]:
import os, json
import numpy as np
from pathlib import Path
from collections import deque
from tqdm import tqdm

# ------------------------------------------------------------
# Dataset directory (processed artifacts live OUTSIDE the repo)
# ------------------------------------------------------------
# The location can be overridden without editing the notebook by setting the
# TAXONOMY_DATASET_DIR environment variable; when it is unset the original
# default path below is used, so existing runs behave as before.
DATASET_DIR = Path(
    os.environ.get(
        "TAXONOMY_DATASET_DIR",
        "/home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training"
        "/level_97/silva-138.2/incomplete_silva_sintax/dataset_full_top999",
    )
).expanduser()

VOCAB_PATH  = DATASET_DIR / "taxonomy_vocab.json"
NESTED_PATH = DATASET_DIR / "taxonomy_nested.json"

# Fail fast with a real exception: `assert` statements are removed when Python
# runs with -O, which would let a missing input slip through to a later cell.
for required_path in (VOCAB_PATH, NESTED_PATH):
    if not required_path.exists():
        raise FileNotFoundError(f"Missing: {required_path}")

print("DATASET_DIR =", DATASET_DIR)
print("VOCAB_PATH  =", VOCAB_PATH)
print("NESTED_PATH =", NESTED_PATH)


DATASET_DIR = /home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training/level_97/silva-138.2/incomplete_silva_sintax/dataset_full_top999
VOCAB_PATH  = /home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training/level_97/silva-138.2/incomplete_silva_sintax/dataset_full_top999/taxonomy_vocab.json
NESTED_PATH = /home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training/level_97/silva-138.2/incomplete_silva_sintax/dataset_full_top999/taxonomy_nested.json


In [2]:
# Load the vocabulary. Its list order defines the index of every node and
# therefore the row/column order of every array built in this notebook.
# Context managers close the files deterministically, and the encoding is
# pinned so the result does not depend on the machine's locale.
with open(VOCAB_PATH, "r", encoding="utf-8") as fh:
    taxonomy_vocab = json.load(fh)

tax2idx = {t: i for i, t in enumerate(taxonomy_vocab)}
N = len(taxonomy_vocab)

# Duplicate names would silently collapse in tax2idx, making len(tax2idx) < N
# and mis-sizing every array derived from it, so stop here instead.
if len(tax2idx) != N:
    raise ValueError(
        f"taxonomy_vocab contains duplicate entries: {N} names but "
        f"{len(tax2idx)} unique"
    )

with open(NESTED_PATH, "r", encoding="utf-8") as fh:
    taxonomy_nested = json.load(fh)

print("N taxonomy nodes =", N)
print("Example vocab entries:", taxonomy_vocab[:5])
print("Top-level keys in nested tree:", list(taxonomy_nested.keys())[:5])


N taxonomy nodes = 6929
Example vocab entries: ['c:028H05-P-BN-P5', 'c:055B07-P-DI-P58', 'c:113B434', 'c:AB64A-17', 'c:AEGEAN-245']
Top-level keys in nested tree: ['k:Bacteria', 'k:Archaea']


In [3]:
def build_parent_children_from_nested(taxonomy_nested, tax2idx):
    """
    Derive parent pointers and children adjacency from taxonomy_nested.json.
    The nested dict should use the same node-name strings as taxonomy_vocab.json.

    Parameters
    ----------
    taxonomy_nested : dict
        Nested mapping ``name -> {child_name -> {...}}``; the top-level keys
        are the roots of the forest.
    tax2idx : dict
        Mapping from node name to its integer index.

    Returns
    -------
    parent : np.ndarray of int32, shape (len(tax2idx),)
        ``parent[i]`` is the index of node i's parent, or -1 if it has none.
    children : list of list of int
        ``children[p]`` lists the indices of p's direct children, in the
        order they appear in the nested dict.

    Warns
    -----
    UserWarning
        If a name in the nested tree is absent from ``tax2idx``. Such a node
        is skipped together with its whole subtree, so any of its descendants
        that ARE in ``tax2idx`` keep ``parent == -1`` and look like roots.
    """
    import warnings

    N = len(tax2idx)
    parent = np.full(N, -1, dtype=np.int32)
    children = [[] for _ in range(N)]
    skipped = []  # names found in the tree but missing from tax2idx

    def dfs(subtree: dict, parent_name=None):
        for name, child_dict in subtree.items():
            if name not in tax2idx:
                # If taxonomy_nested was built from taxonomy_vocab, this should not happen.
                # Record it so the caller is told instead of silently losing the subtree.
                skipped.append(name)
                continue

            i = tax2idx[name]

            if parent_name is not None and parent_name in tax2idx:
                p = tax2idx[parent_name]
                parent[i] = p
                children[p].append(i)

            # Leaves are stored as empty dicts (or non-dict values): nothing to descend into.
            if isinstance(child_dict, dict) and child_dict:
                dfs(child_dict, parent_name=name)

    # Forest: multiple top-level clades are allowed; their parent stays -1
    dfs(taxonomy_nested, parent_name=None)

    if skipped:
        warnings.warn(
            f"{len(skipped)} node(s) in taxonomy_nested are missing from tax2idx "
            f"and were skipped together with their subtrees (e.g. {skipped[:5]})",
            stacklevel=2,
        )

    return parent, children


parent, children = build_parent_children_from_nested(taxonomy_nested, tax2idx)

# Every node that never received a parent pointer is a top-level root.
roots = np.flatnonzero(parent == -1).tolist()
print("Number of top-level roots =", len(roots))
print("First 10 root nodes:", [taxonomy_vocab[root] for root in roots[:10]])

def compute_depth(children, roots, parent):
    """
    Breadth-first depth assignment over the forest described by `children`.

    Every index in `roots` receives depth 0 and each child receives its
    parent's depth plus one. Indices the traversal never visits are also given
    depth 0, i.e. they are treated as isolated single-node components.

    `parent` is part of the signature but is not read by this function.
    Returns an int32 array with one depth per node.
    """
    n_nodes = len(children)
    depth = np.full(n_nodes, -1, dtype=np.int32)

    # Seed the queue with all roots at level 0.
    pending = deque(roots)
    for root in pending:
        depth[root] = 0

    while pending:
        node = pending.popleft()
        child_depth = depth[node] + 1
        for child in children[node]:
            depth[child] = child_depth
            pending.append(child)

    # Anything still at the -1 placeholder was unreachable: make it a singleton at depth 0.
    depth[depth < 0] = 0

    return depth


depth = compute_depth(children, roots, parent)

# compute_depth() resets every unreachable node to depth 0, so testing for
# `depth < 0` here could never fire. A node the BFS failed to reach is instead
# one that sits at depth 0 even though it has a parent pointer (roots are
# exactly the nodes with parent == -1).
bad = np.where((depth == 0) & (parent != -1))[0]
if bad.size > 0:
    print("WARNING: some nodes unreachable from roots (unexpected). Count:", len(bad))
    print("Examples:", [taxonomy_vocab[i] for i in bad[:10]])
else:
    print("Depth computed for all nodes.")
    print("Depth min/max:", int(depth.min()), int(depth.max()))
def build_binary_lifting(parent, depth):
    """
    Construct the ancestor jump table used for LCA queries.

    Row k of the returned int32 array holds, for every node, the index of its
    2**k-th ancestor, or -1 when the node has no ancestor that far up.
    Row 0 is simply `parent`. The number of rows is derived from the largest
    value in `depth`.
    """
    n_nodes = len(parent)
    deepest = int(depth.max()) if n_nodes else 0
    n_levels = int(np.ceil(np.log2(deepest + 1))) + 1

    up = np.full((n_levels, n_nodes), -1, dtype=np.int32)
    up[0] = parent

    for level in range(1, n_levels):
        below = up[level - 1]
        # Only nodes that have a 2**(level-1)-th ancestor can have a 2**level-th one;
        # the rest keep the -1 they were initialised with.
        has_ancestor = below != -1
        up[level, has_ancestor] = below[below[has_ancestor]]

    return up


# Jump table for LCA queries: up[k, i] is the 2**k-th ancestor of node i (-1 if none).
up = build_binary_lifting(parent=parent, depth=depth)
print("Binary lifting table shape:", up.shape)

def compute_component_root(parent, depth):
    """
    Map every node to the root of the tree it belongs to.

    root_id[i] is the index reached by following parent pointers from i until
    a node with parent == -1 is hit. A node that has no parent is therefore
    its own root, which also covers isolated single-node components.

    `depth` is part of the signature but is not read by this function.
    Returns an int32 array with one root index per node.
    """
    parent = np.asarray(parent)
    root_id = np.arange(len(parent), dtype=np.int32)

    # Move all nodes one step up per pass until none of them can climb further.
    while True:
        above = parent[root_id]
        climbing = above != -1
        if not climbing.any():
            break
        root_id[climbing] = above[climbing]

    return root_id


root_id = compute_component_root(parent, depth)

# Sanity check: each distinct root index corresponds to one connected component.
unique_roots = np.unique(root_id)
print("Number of components =", unique_roots.size)


Number of top-level roots = 3
First 10 root nodes: ['k:Archaea', 'k:Bacteria', 'k:UNK']
Depth computed for all nodes.
Depth min/max: 0 6
Binary lifting table shape: (4, 6929)
Number of components = 3


In [4]:
def lca(u, v, up, depth, root_id):
    """
    Lowest common ancestor of nodes u and v via binary lifting.

    `up[k, i]` must hold the 2**k-th ancestor of node i (-1 if none),
    `depth[i]` the depth of node i, and `root_id[i]` the root of i's tree.

    Returns the node index of the LCA, or -1 when u and v belong to
    different components and so have no common ancestor.
    """
    if root_id[u] != root_id[v]:
        return -1

    # Make `deeper` the node that is at least as deep as the other one.
    deeper, shallower = (u, v) if depth[u] >= depth[v] else (v, u)

    # Raise `deeper` by the depth gap, one set bit of the gap at a time.
    gap = int(depth[deeper] - depth[shallower])
    for level in range(gap.bit_length()):
        if (gap >> level) & 1:
            deeper = up[level, deeper]

    # One node was an ancestor of the other.
    if deeper == shallower:
        return deeper

    # Take the largest jumps that keep the two nodes distinct; afterwards they
    # sit directly below the LCA.
    for level in reversed(range(up.shape[0])):
        anc_a, anc_b = up[level, deeper], up[level, shallower]
        if anc_a != anc_b:
            deeper, shallower = anc_a, anc_b

    return up[0, deeper]

def build_lca_distance_matrix(depth, up, root_id, dtype=np.int16):
    """
    Fill the symmetric matrix of pairwise tree distances.

    For nodes i and j with lowest common ancestor a, the entry is
    depth[i] + depth[j] - 2 * depth[a]. The diagonal is 0. Pairs for which
    `lca` returns -1 (different components) keep the value -1.

    Returns an (N, N) array of the requested `dtype`, where N = len(depth).
    """
    n_nodes = len(depth)

    # Start from "no path" everywhere, then overwrite the diagonal and connected pairs.
    D = np.full((n_nodes, n_nodes), -1, dtype=dtype)
    np.fill_diagonal(D, 0)

    for i in tqdm(range(n_nodes), desc="Building LCA distance matrix"):
        # Only the upper triangle is computed; each value is mirrored.
        for j in range(i + 1, n_nodes):
            ancestor = lca(i, j, up, depth, root_id)
            if ancestor == -1:
                continue  # entries already hold -1
            edges = int(depth[i] + depth[j] - 2 * depth[ancestor])
            D[i, j] = edges
            D[j, i] = edges

    return D


D = build_lca_distance_matrix(depth, up, root_id, dtype=np.int16)

# Negative entries mark node pairs with no connecting path.
connected = D >= 0
finite = D[connected]
print("D shape:", D.shape, "dtype:", D.dtype)
print("Finite entries:", finite.size, "min:", int(finite.min()), "max:", int(finite.max()))
print("Missing (=-1) entries:", int((~connected).sum()))



Building LCA distance matrix: 100%|██████████| 6929/6929 [01:36<00:00, 71.49it/s] 


D shape: (6929, 6929) dtype: int16
Finite entries: 42836115 min: 0 max: 12
Missing (=-1) entries: 5174926


In [7]:
# All derived tree arrays are written to one sub-folder of the dataset directory.
TREE_DIR = DATASET_DIR / "tree_artifacts"
TREE_DIR.mkdir(parents=True, exist_ok=True)
print("Tree artifacts will be saved in:", TREE_DIR)

# Pairwise distance matrix.
lca_path = TREE_DIR / "lca_distance_edges.npy"
np.save(lca_path, D)
print("Saved LCA distance matrix to:", lca_path)

# Per-node structural arrays, index-aligned with taxonomy_vocab.
for artifact_name, artifact in (
    ("parent.npy", parent),
    ("depth.npy", depth),
    ("root_id.npy", root_id),
):
    np.save(TREE_DIR / artifact_name, artifact)

print("Saved structural tree arrays:")
print(" - parent.npy")
print(" - depth.npy")
print(" - root_id.npy")



Tree artifacts will be saved in: /home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training/level_97/silva-138.2/incomplete_silva_sintax/dataset_full_top999/tree_artifacts
Saved LCA distance matrix to: /home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training/level_97/silva-138.2/incomplete_silva_sintax/dataset_full_top999/tree_artifacts/lca_distance_edges.npy
Saved structural tree arrays:
 - parent.npy
 - depth.npy
 - root_id.npy


In [6]:
# Write a README describing the artifacts in TREE_DIR.
# Two statements in the text were corrected to match what this notebook
# actually produces:
#   * parent.npy: -1 marks every top-level root (a node with no parent); no
#     synthetic root node is ever created.
#   * lca_distance_edges.npy: pairs in different components are stored as -1.
# NOTE(review): descendant_matrix*.npy, taxonomy_vocab_with_unk.json and
# rank_idx.npy are not written by this notebook; presumably other notebooks
# produce them. Confirm.
readme_path = TREE_DIR / "README.md"
# Pin the encoding so the file content does not depend on the machine's locale.
with open(readme_path, "w", encoding="utf-8") as f:
    f.write(
        "# Tree Artifacts\n\n"
        "This folder contains **derived taxonomy tree artifacts** computed from\n"
        "`taxonomy_nested.json` and `taxonomy_vocab.json`.\n\n"
        "These files are **not raw dataset definitions**. They encode structural\n"
        "properties of the taxonomy hierarchy that are required by the OTU+Taxa\n"
        "foundation model for hierarchical operations, constraints, and regularization.\n\n"
        "All arrays in this folder are **index-aligned with `taxonomy_vocab.json`**.\n"
        "That is, index `i` in any array corresponds to `taxonomy_vocab[i]`.\n\n"
        "## Files\n\n"
        "- `parent.npy`\n"
        "  Parent index of each taxonomy node in the tree (`-1` marks a top-level root,\n"
        "  i.e. a node with no parent).\n\n"
        "- `depth.npy`\n"
        "  Depth of each taxonomy node measured from the root of its connected component.\n\n"
        "- `root_id.npy`\n"
        "  Identifier of the top-level component (root clade) each node belongs to.\n\n"
        "- `lca_distance_edges.npy`\n"
        "  Symmetric matrix of topological distances between taxonomy nodes,\n"
        "  computed as the number of edges between two nodes via their lowest common ancestor (LCA).\n"
        "  Pairs of nodes in different components have no common ancestor and are stored as `-1`.\n\n"
        "- `descendant_matrix.npy`\n"
        "  Binary descendant-closure matrix where `M[i, j] = 1` iff node `j` is a descendant\n"
        "  of node `i` (including `i` itself).\n\n"
        "- `taxonomy_vocab_with_unk.json`\n"
        "  Extension of `taxonomy_vocab.json` with one UNK token per rank\n"
        "  (`k:UNK, p:UNK, ..., s:UNK`), appended at the end.\n\n"
        "- `descendant_matrix_with_unk.npy`\n"
        "  Descendant-closure matrix extended to include UNK nodes, enforcing\n"
        "  valid hierarchical transitions when taxonomy labels are missing.\n\n"
        "- `rank_idx.npy`\n"
        "  Integer array mapping each taxonomy token to its taxonomic rank\n"
        "  (`0=k, 1=p, 2=c, 3=o, 4=f, 5=g, 6=s`). Required for per-rank prediction heads\n"
        "  and hierarchical masking.\n\n"
        "## Notes\n\n"
        "- UNK nodes are **model-level constructs** and do not alter the biological\n"
        "  structure of the base taxonomy.\n"
        "- Base and UNK-extended artifacts coexist to allow flexible experimentation\n"
        "  without recomputing the original tree.\n"
    )

print("Saved:", readme_path)


Saved: /home/hernan_melmoth/Documents/phd_work/Microbeatlas_preprocess_training/level_97/silva-138.2/incomplete_silva_sintax/dataset_full_top999/tree_artifacts/README.md
