In [1]:
import os

import numpy as np
import tables as tb
from numpy.typing import NDArray

os.environ["POLARS_MAX_THREADS"] = "128"
import polars as pl  # noqa: E402

LongArray = NDArray[np.int64]

In [2]:
%load_ext watermark
%watermark -vp numpy,polars,tables

Python implementation: CPython
Python version       : 3.11.3
IPython version      : 8.13.2

numpy : 1.23.5
polars: 0.20.6
tables: 3.8.0



In [3]:
metadata = pl.read_csv("supplementary_table_1.tsv", separator="\t")
metadata

genome_id,genome,dataset,genome_length_bp,num_proteins,scaffolds,training_class,taxonomy,taxonomy_method,viral_realm,viral_kingdom,viral_phylum,viral_class,viral_order,viral_family,viral_genus,viral_species,completeness,miuvig_quality,sequence_origin,ecosystem_classification,host_prediction_method,host_domain,host_phylum,host_class,host_order,host_family,host_genus,host_species
i64,str,str,i64,i64,str,str,str,str,str,str,str,str,str,str,str,str,f64,str,str,str,str,str,str,str,str,str,str,str
0,"""DTR_391407""","""train""",4336,2,"""DTR_391407""","""Unknown""","""r__;k__;p__;c_…","""geNomad""",,,,,,,,,100.0,"""High-quality""","""IMG/VRv3 (10.1…","""Host-associate…",,,,,,,,
1,"""DTR_421064""","""train""",3119,2,"""DTR_421064""","""Unknown""","""r__;k__;p__;c_…","""geNomad""",,,,,,,,,100.0,"""High-quality""","""IMG/VRv3 (10.1…","""Host-associate…",,,,,,,,
2,"""DTR_461283""","""train""",4573,2,"""DTR_461283""","""Unknown""","""r__;k__;p__;c_…","""geNomad""",,,,,,,,,100.0,"""High-quality""","""IMG/VRv3 (10.1…","""Host-associate…",,,,,,,,
3,"""DTR_488058""","""train""",3678,2,"""DTR_488058""","""Unknown""","""r__;k__;p__;c_…","""geNomad""",,,,,,,,,100.0,"""High-quality""","""IMG/VRv3 (10.1…","""Host-associate…",,,,,,,,
4,"""DTR_491373""","""train""",2615,2,"""DTR_491373""","""Unknown""","""r__;k__;p__;c_…","""geNomad""",,,,,,,,,100.0,"""High-quality""","""IMG/VRv3 (10.1…","""Host-associate…",,,,,,,,
5,"""DTR_511071""","""train""",2300,2,"""DTR_511071""","""Monodnaviria""","""r__Monodnaviri…","""geNomad""","""r__Monodnaviri…","""k__Shotokuvira…","""p__Cressdnavir…","""c__Arfiviricet…","""o__Cirlivirale…","""f__Circovirida…",,,91.93,"""High-quality""","""IMG/VRv3 (10.1…","""Host-associate…",,,,,,,,
6,"""DTR_524282""","""train""",4334,2,"""DTR_524282""","""Unknown""","""r__;k__;p__;c_…","""geNomad""",,,,,,,,,100.0,"""High-quality""","""IMG/VRv3 (10.1…","""Host-associate…",,,,,,,,
7,"""DTR_554557""","""train""",2072,2,"""DTR_554557""","""Unknown""","""r__;k__;p__;c_…","""geNomad""",,,,,,,,,100.0,"""High-quality""","""IMG/VRv3 (10.1…","""Host-associate…",,,,,,,,
8,"""DTR_558218""","""train""",6623,2,"""DTR_558218""","""Unknown""","""r__;k__;p__;c_…","""geNomad""",,,,,,,,,100.0,"""High-quality""","""IMG/VRv3 (10.1…","""Environmental;…",,,,,,,,
9,"""DTR_563185""","""train""",4464,2,"""DTR_563185""","""Unknown""","""r__;k__;p__;c_…","""geNomad""",,,,,,,,,99.87,"""High-quality""","""IMG/VRv3 (10.1…","""Environmental;…",,,,,,,,


We compute purity as information gain:

$$
I = \frac{H_{background} - \sum_{i=0}^{n_{clusters}} H_{i}w_{i} }{H_{background}}
$$

where $H$ is the entropy (our specific implementation is in bits):

$$
H(X) = - \sum_{x \in X} p(x)log_{2}p(x)
$$

See paper methods for more information.

In [4]:
def entropy(labels: LongArray) -> float:
    """Compute entropy of label distribution."""

    if len(labels) == 0:
        return -1.0

    _, counts = np.unique(labels, return_counts=True)
    
    counts = counts.astype(np.float64)

    if counts.size == 1:
        return 0.0
    
    total = np.sum(counts)
    prob = counts / total

    return float(-np.sum(prob * np.log2(prob)))

def label_filters(col: str) -> pl.Expr:
    expr = (
        (~pl.col(col).is_in({"", "unknown", "unclassified"}))
        & (pl.col(col).is_not_null())
    )
    return expr

In [5]:
label_columns = [
    "viral_realm",
    "viral_kingdom",
    "viral_phylum",
    "viral_class",
    "viral_order",
    "viral_family",
    "host_domain",
    "host_phylum",
    "host_class",
    "host_order",
    "host_family",
    "host_genus",
]

In [6]:
background_entropy = {
    "label": [],
    "background": [],
    "n_cats": [],
    "bits": [],
    "dataset": [],
}

for dataset in ["train", "test"]:
    curr_metadata = metadata.filter(pl.col("dataset") == dataset)

    for col in label_columns:
        filtered_data = curr_metadata.filter(label_filters(col))
        col_data = filtered_data[col]
        col_idx = {c: i for i, c in enumerate(col_data.unique())}
        labels_true = col_data.replace(col_idx).cast(pl.Int64).to_numpy()
        
        background_entropy["label"].append(col)
        background_entropy["background"].append(entropy(labels_true))
        background_entropy["n_cats"].append(len(col_idx))
        background_entropy["bits"].append(np.log2(len(col_idx)))
        background_entropy["dataset"].append(dataset)

background = pl.DataFrame(background_entropy)
background

label,background,n_cats,bits,dataset
str,f64,i64,f64,str
"""viral_realm""",0.835893,5,2.321928,"""train"""
"""viral_kingdom""",1.005295,10,3.321928,"""train"""
"""viral_phylum""",1.08304,17,4.087463,"""train"""
"""viral_class""",1.122678,33,5.044394,"""train"""
"""viral_order""",1.438566,52,5.70044,"""train"""
"""viral_family""",3.883609,127,6.988685,"""train"""
"""host_domain""",0.822724,5,2.321928,"""train"""
"""host_phylum""",3.002393,80,6.321928,"""train"""
"""host_class""",4.001934,103,6.686501,"""train"""
"""host_order""",5.492335,272,8.087463,"""train"""


In [7]:
def _compute_information_gain_per_cluster_set(
    current_genome_metadata: pl.DataFrame,
    labels: list[str],
    # clu_metadata: Iterable[tuple[int, float]],
    # clu_metadata_mask: NDArray[np.bool_],
    clu: LongArray,
    genome_k: int,
    genome_resolution: float,
    method_name: str, 
    dataset_name: str,
    storage: list[pl.LazyFrame]
):
    cluster_info = (
        current_genome_metadata
        .with_columns(
            cluster=pl.lit(clu),
            k=pl.lit(genome_k),
            resolution=pl.lit(genome_resolution),
        )
        .with_columns(
            cluster_size=pl.col("genome_id").len().over("cluster"),
        )
    )

    for col in labels:
        df = (
            cluster_info
            .lazy()
            .filter(
                label_filters(col) & pl.col("cluster_size").gt(1)
            )
            .group_by("cluster")
            .agg(
                pl.first("k", "resolution", "cluster_size"),
                pl.col(col).value_counts(),
            )
            .explode(col)
            .unnest(col)
            .with_columns(
                prop = pl.col("count") / pl.col("count").sum().over("cluster")
            )
            .with_columns(
                entropy = -pl.col("prop") * pl.col("prop").log(2),
            )
            .group_by("cluster")
            .agg(
                pl.sum("entropy"),
                pl.first("cluster_size", "k", "resolution"),
            )
            .with_columns(
                cluster_size_weighted_entropy = (
                    pl.col("entropy") * pl.col("cluster_size") / pl.col("cluster_size").sum()
                ),
                label = pl.lit(col),
                method = pl.lit(method_name),
                dataset = pl.lit(dataset_name),
            )
        )

        storage.append(df)


def _summarize_information_gain(
    storage: list[pl.LazyFrame],
    background: pl.DataFrame,
) -> pl.DataFrame:
    summary = (
        pl.concat(pl.collect_all(storage))
        .join(background, on=["label", "dataset"])
        .with_columns(gain = pl.col("background") - pl.col("cluster_size_weighted_entropy"))
        .with_columns(gain_ratio = pl.col("gain") / pl.col("background"))
        .group_by("dataset", "method", "k", "resolution", "label")
        .agg(
            pl.mean("gain", "gain_ratio"),
            n_genomes=pl.sum("cluster_size")
        )
        .join(
            metadata.group_by("dataset").agg(total_genomes=pl.len()),
            on="dataset",
        )
        .with_columns(
            inclusion_weight = pl.col("n_genomes") / pl.col("total_genomes"),
        )
        .with_columns(
            weighted_gain = pl.col("gain") * pl.col("inclusion_weight"),
            weighted_gain_ratio = pl.col("gain_ratio") * pl.col("inclusion_weight"),
        )
        .sort(["dataset", "method", "k", "resolution", "label"])
    )

    return summary

def compute_information_gain(
    genome_metadata: pl.DataFrame, 
    clustering_file: str, 
    labels: list[str],
    background: pl.DataFrame,
) -> pl.DataFrame:
    """
    The clustering file is an .h5 with the following structure:
    .
    |__train
    |  |__method_0
    |  |  |__metadata
    |  |  |__data
    ...
    |__test
       |__method_0
          |__metadata
          |__data

    The clustering metadata is a numpy struct array with the fields: (genome_k, genome_resolution).
    These correspond to the number of nearest neighbors and the resolution of the Leiden clustering algorithm.
    """

    storage: list[pl.LazyFrame] = []

    with tb.open_file(clustering_file) as fp:
        # this is a numpy struct array with the following fields:
        # (genome_k, genome_resolution)
        clu_metadata = fp.root.test["pst-large"].metadata[:]

        # I tried several different parameters for genome clustering, but I will
        # filter for only the iterations included in the PST manuscript
        clu_metadata_mask = np.array([
            genome_res > 0.01
            for genome_k, genome_res in clu_metadata
        ])

        mask_idx = np.where(clu_metadata_mask)[0]
        clu_metadata = clu_metadata[clu_metadata_mask]

        for dataset in fp.root:
            dataset_name = dataset._v_name
            current_genome_metadata = genome_metadata.filter(pl.col("dataset") == dataset_name)

            for method in dataset:
                method_name = method._v_name

                # this is just an int array with a cluster assignment for each genome
                # shape: (num clustering iters, num genomes)
                clusters: LongArray = method.data[mask_idx, :]

                for clu, (genome_k, genome_res) in zip(clusters, clu_metadata):
                    _compute_information_gain_per_cluster_set(
                        current_genome_metadata,
                        labels,
                        clu,
                        genome_k,
                        genome_res,
                        method_name,
                        dataset_name,
                        storage
                    )
    
    return _summarize_information_gain(storage, background)

In [8]:
summary = compute_information_gain(metadata, "genome_clusters.h5", label_columns, background)

summary

dataset,method,k,resolution,label,gain,gain_ratio,n_genomes,total_genomes,inclusion_weight,weighted_gain,weighted_gain_ratio
str,str,i32,f64,str,f64,f64,u32,u32,f64,f64,f64
"""test""","""ctx-avg-large""",2,0.1,"""host_class""",2.687934,0.999981,85238,151255,0.563538,1.514754,0.563528
"""test""","""ctx-avg-large""",2,0.1,"""host_domain""",0.019865,0.999985,85848,151255,0.567571,0.011275,0.567563
"""test""","""ctx-avg-large""",2,0.1,"""host_family""",5.660775,0.999987,83771,151255,0.55384,3.135161,0.553832
"""test""","""ctx-avg-large""",2,0.1,"""host_genus""",8.179681,0.999991,79873,151255,0.528068,4.319432,0.528064
"""test""","""ctx-avg-large""",2,0.1,"""host_order""",4.464418,0.999985,84581,151255,0.559195,2.496479,0.559186
"""test""","""ctx-avg-large""",2,0.1,"""host_phylum""",2.350818,0.999981,85826,151255,0.567426,1.333915,0.567415
"""test""","""ctx-avg-large""",2,0.1,"""viral_class""",2.122832,0.99999,149166,151255,0.986189,2.093513,0.986179
"""test""","""ctx-avg-large""",2,0.1,"""viral_family""",3.955677,0.999985,71155,151255,0.470431,1.860872,0.470424
"""test""","""ctx-avg-large""",2,0.1,"""viral_kingdom""",1.65354,0.999991,149459,151255,0.988126,1.633906,0.988117
"""test""","""ctx-avg-large""",2,0.1,"""viral_order""",4.0034,0.999982,77293,151255,0.511011,2.045782,0.511002
