In [1]:
import os

import numpy as np
import tables as tb
from numpy.typing import NDArray

# Cap the polars thread pool. polars sizes its thread pool when it is first
# imported, so the variable must be set *before* `import polars`; that is why
# the import below deliberately comes after this statement (E402 suppressed).
# NOTE(review): 128 presumably matches the analysis machine's core count —
# adjust for other hardware.
os.environ["POLARS_MAX_THREADS"] = "128"

import polars as pl  # noqa: E402

# Type alias for the int64 NumPy label arrays (protein / genome cluster IDs)
# used in the function signatures below.
LongArray = NDArray[np.int64]

In [2]:
# Record the Python/IPython and key package versions used for this analysis.
%load_ext watermark
%watermark -vp numpy,tables,polars

Python implementation: CPython
Python version       : 3.11.3
IPython version      : 8.13.2

numpy : 1.23.5
tables: 3.8.0
polars: 0.20.6



In [3]:
def get_background_counts(database: str) -> pl.DataFrame:
    """Compute the background functional-category distribution of a database.

    Reads the supplementary table that lists the reference database's
    annotation categories. It drops every category containing "unknown" and
    tallies how often each remaining category occurs.

    Args:
        database: "PHROG" or "VOG". Matching is case-insensitive, so "phrog" and
            "vog" are accepted too, the same spellings that
            ``per_cluster_information_gain`` accepts.

    Returns:
        DataFrame with one row per category (sorted by category) and these
        columns:

        - ``category``
        - ``count``
        - ``prop``: fraction of all labeled entries.
        - ``entropy``: this category's ``-p * log2(p)`` term.

        Summing ``entropy`` gives the Shannon entropy, in bits, of the
        background distribution.

    Raises:
        ValueError: If ``database`` is neither PHROG nor VOG.
    """
    # Accept any casing so callers can pass the same value they give to
    # per_cluster_information_gain, which lower-cases the name itself.
    normalized = database.upper()
    if normalized == "PHROG":
        file = "supplementary_tables/supplementary_table_5.tsv"
        cat_col = "new_category"
    elif normalized == "VOG":
        file = "supplementary_tables/supplementary_table_6.tsv"
        cat_col = "function"
    else:
        raise ValueError(f"Unknown database: {database}. Must be 'PHROG' or 'VOG'")

    counts = (
        pl.read_csv(file, separator="\t")
        # entries of unknown function carry no category information
        .filter(~pl.col(cat_col).str.contains("unknown"))
        [cat_col]
        .value_counts()
        .sort(cat_col)
        .rename({cat_col: "category"})
        .with_columns(
            prop = pl.col("count") / pl.col("count").sum(),
        )
        .with_columns(
            entropy = -pl.col("prop") * pl.col("prop").log(2)
        )
    )

    return counts

In [4]:
def per_cluster_information_gain(
    ptn_info: pl.DataFrame, 
    database: str,
    protein_clusters: LongArray,
    genome_clusters: LongArray,
    **metadata_kwargs,
) -> pl.DataFrame:
    """Compute the Shannon entropy of functional categories within each cluster.

    A cluster is identified by the pair (genome cluster, protein cluster).
    Clusters with a single member are discarded. Within the remaining
    clusters, only proteins whose category does not contain "unknown"
    contribute to the category distribution. The ``cluster_size`` used for
    weighting still counts every member, labeled or not. Clusters with no
    labeled member therefore do not appear in the output.

    Args:
        ptn_info: Per-protein table with at least ``ptn_id`` and
            ``{database}_category`` columns. The label arrays are attached
            positionally, one label per row, so rows are assumed to be in the
            same order as the arrays. TODO: confirm for new inputs.
        database: "PHROG" or "VOG" (case-insensitive).
        protein_clusters: Protein-cluster label for each row of ``ptn_info``.
        genome_clusters: Genome-cluster label for each row of ``ptn_info``.
        **metadata_kwargs: Extra constant columns (e.g. clustering parameters).
            Each value is added to every output row via ``pl.lit``.

    Returns:
        DataFrame with one row per retained (gclu, pclu) pair and these columns:

        - ``gclu`` and ``pclu``.
        - ``cluster_size``.
        - ``entropy``: in bits.
        - ``cluster_size_weighted_entropy``: ``entropy`` scaled by this
          cluster's share of all retained members, so the column sums to the
          size-weighted mean entropy.
        - One column per metadata keyword.

    Raises:
        ValueError: If ``database`` is neither PHROG nor VOG.
    """
    db = database.lower()
    if db not in {"phrog", "vog"}:
        raise ValueError(f"Unknown database: {database}. Must be 'PHROG' or 'VOG'")

    cat_col = f"{db}_category"

    metadata = {
        k: pl.lit(v)
        for k, v in metadata_kwargs.items()
    }

    per_cluster = (
        ptn_info
        .lazy()
        .with_columns(
            gclu = genome_clusters, # type: ignore
            pclu = protein_clusters, # type: ignore
        )
        # size of every cluster, computed BEFORE dropping unlabeled proteins
        .with_columns(
            cluster_size = pl.col("ptn_id").len().over("gclu", "pclu"),
        )
        # no singleton protein clusters
        # and ignore unknown functions
        # we effectively compute information gain with only labeled proteins
        # but keep the total size of the cluster for weighting purposes
        .filter(
            (pl.col("cluster_size") > 1) &
            (~pl.col(cat_col).str.contains("unknown"))
        )
        # per-cluster category histogram: list of {category, count} structs
        .group_by("gclu", "pclu")
        .agg(
            pl.first("cluster_size"),
            pl.col(cat_col).value_counts()
        )
        .explode(cat_col)
        .unnest(cat_col)
        # category proportions within each cluster -> -p*log2(p) terms
        .with_columns(
            prop = pl.col("count") / pl.sum("count").over("gclu", "pclu"),
        )
        .with_columns(
            entropy = -pl.col("prop") * pl.col("prop").log(2)
        )
        .group_by("gclu", "pclu")
        .agg(
            pl.first("cluster_size"),
            pl.sum("entropy")
        )
        # weight by each cluster's share of all retained members
        .with_columns(
            cluster_size_weighted_entropy = pl.col("entropy") * pl.col("cluster_size") / pl.col("cluster_size").sum(),
            **metadata,
        )
        .collect()
    )

    return per_cluster

def information_gain(
    ptn_info: pl.DataFrame, 
    database: str,
    protein_clusters: LongArray,
    genome_clusters: LongArray,
    **metadata_kwargs,
) -> pl.DataFrame:
    """Summarise how much a clustering reduces functional-category entropy.

    Compares the background entropy of the database's category distribution
    (``get_background_counts``) with the size-weighted mean entropy of the
    clusters returned by ``per_cluster_information_gain``.

    Args:
        ptn_info: Per-protein table with ``ptn_id`` and ``{database}_category``
            columns, one row per entry of the label arrays.
        database: "PHROG" or "VOG" (case-insensitive).
        protein_clusters: Protein-cluster label for each row of ``ptn_info``.
        genome_clusters: Genome-cluster label for each row of ``ptn_info``.
        **metadata_kwargs: Constant columns (e.g. clustering parameters) used
            as the grouping keys of the summary.

    Returns:
        DataFrame with one row per metadata combination and these columns:

        - The metadata keys.
        - ``cluster_size_weighted_entropy``: the weighted mean cluster entropy.
        - ``included_data``: members of retained clusters.
        - ``gain``: background entropy minus weighted entropy.
        - ``total_data``: unique ``ptn_id`` values.
        - ``gain_ratio``: ``gain`` divided by background entropy.
        - ``inclusion_weight``: ``included_data`` divided by ``total_data``.
        - ``weighted_gain`` and ``weighted_gain_ratio``: scaled by
          ``inclusion_weight``.

    Raises:
        ValueError: If ``database`` is neither PHROG nor VOG.
    """
    if not metadata_kwargs:
        # group_by below needs at least one key, so group on a throwaway
        # constant column and drop it before returning. Use a plain scalar:
        # per_cluster_information_gain wraps every metadata value in pl.lit
        # itself, so passing an expression here would nest pl.lit(pl.lit(0)).
        metadata_kwargs = {
            "dummy": 0
        }
        delete_dummy = True
    else:
        delete_dummy = False

    # Canonical upper-case name so both helpers receive the same spelling.
    database = database.upper()

    per_cluster = per_cluster_information_gain(
        ptn_info, 
        database, 
        protein_clusters, 
        genome_clusters,
        **metadata_kwargs,
    )

    # Shannon entropy (bits) of the background category distribution.
    background: float = get_background_counts(database)["entropy"].sum()
    total_data = ptn_info["ptn_id"].n_unique()

    # The metadata columns survive as group keys, so they need not be re-added.
    info_gain = (
        per_cluster
        .group_by(metadata_kwargs.keys())
        .agg(
            pl.sum("cluster_size_weighted_entropy"),
            included_data = pl.sum("cluster_size"),
        )
        .with_columns(
            gain = background - pl.col("cluster_size_weighted_entropy"),
            total_data = total_data,
        )
        .with_columns(
            gain_ratio = pl.col("gain") / background,
            inclusion_weight = pl.col("included_data") / pl.col("total_data"),
        )
        .with_columns(
            weighted_gain = pl.col("gain") * pl.col("inclusion_weight"),
            weighted_gain_ratio = pl.col("gain_ratio") * pl.col("inclusion_weight"),
        )
    )

    if delete_dummy:
        return info_gain.drop("dummy")

    return info_gain

In [5]:
# Per-protein annotations, restricted to rows whose `dataset` column is "test".
ptn_info = pl.read_csv(
    "supplementary_tables/supplementary_table_2.tsv",
    separator="\t",
).filter(pl.col("dataset").eq("test"))

ptn_info

ptn,ptn_id,genome,genome_id,vog_bitscore,vog_annot,vog_category,phrog_bitscore,phrog_annot,phrog_category,dataset
str,i64,str,i64,f64,str,str,i64,str,str,str
"""IMGVR_UViG_256…",0,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",1,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",2,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",3,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",4,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",5,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",6,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",7,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",8,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",58,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",9,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""


In [6]:
protein_clusters_file = "datasets/protein_clusters/embedding-based_protein_clusters_per_genome_cluster.h5"
genome_embedding = "pst-large"
protein_embedding = "pst-large"
# Row of the stored clustering sweep to analyse. It corresponds to
# (genome_k, genome_res, protein_k, protein_res) = (15, 1.0, 15, 0.5).
idx = 27

with tb.open_file(protein_clusters_file) as fp:
    node = fp.root[genome_embedding]

    clustering_metadata = dict(
        zip(
            ["genome_k", "genome_res", "protein_k", "protein_res"],
            node.metadata[idx],
        )
    )
    # Under each genome-embedding group, "genome" holds the genome-cluster
    # labels, and the child named after the protein embedding holds the
    # protein-cluster labels (previously these two were assigned swapped).
    # NOTE(review): layout inferred from the node names; confirm against the
    # script that wrote the .h5 file.
    genome_clusters = node["genome"][idx]
    protein_clusters = node[protein_embedding][idx]

# Both label arrays must describe the same proteins, one label each.
assert genome_clusters.shape == protein_clusters.shape

protein_clusters.shape

(7182220,)

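The inclusion weight is the fraction of all proteins (`total_data`) that contribute to this analysis (`included_data`). A protein contributes if its cluster contains more than 1 protein and at least 1 protein with a known function. `included_data` counts every member of such a cluster, including its unlabeled proteins.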

In [7]:
# Information gain of the selected clustering, measured with VOG categories.
information_gain(
    ptn_info,
    database="VOG",
    protein_clusters=protein_clusters,
    genome_clusters=genome_clusters,
    **clustering_metadata,
)

genome_k,genome_res,protein_k,protein_res,cluster_size_weighted_entropy,included_data,gain,total_data,gain_ratio,inclusion_weight,weighted_gain,weighted_gain_ratio
i32,f64,i32,f64,f64,u32,f64,i32,f64,f64,f64,f64
15,1.0,15,0.5,1.068582,6049841,0.855993,7182220,0.44477,0.842336,0.721034,0.374646
