In [8]:
import os
from pathlib import Path
import itertools as it

import numpy as np
import pandas as pd
import tables as tb
from numpy.typing import NDArray
from tqdm import tqdm, trange
from collections import defaultdict

import matplotlib.pyplot as plt
import seaborn as sns
import seaborn.objects as so
from seaborn import axes_style, plotting_context

import networkx as nx

from fastatools import FastaFile

os.environ["POLARS_MAX_THREADS"] = "128"

import polars as pl  # noqa: E402

LongArray = NDArray[np.int64]

In [36]:
def get_background_counts(database: str) -> pl.DataFrame:
    if database == "PHROG":
        file = "supplementary_table_5.tsv"
        cat_col = "new_category"
    elif database == "VOG":
        file = "supplementary_table_6.tsv"
        cat_col = "function"
    else:
        raise ValueError(f"Unknown database: {database}. Must be 'PHROG' or 'VOG'")
    
    counts = (
        pl.read_csv(file, separator="\t")
        [cat_col]
        .value_counts()
        .sort(cat_col)
        .rename({cat_col: "category"})
        .with_columns(
            prop = pl.col("count") / pl.col("count").sum(),
        )
        .with_columns(
            entropy = -pl.col("prop") * pl.col("prop").log(2)
        )
    )

    return counts

In [7]:
get_background_counts("PHROG")

category,count,prop,entropy
str,u32,f64,f64
"""connector""",138,0.003549,0.028886
"""gene expressio…",319,0.008205,0.056853
"""head and packa…",955,0.024563,0.131346
"""host takeover""",76,0.001955,0.01759
"""lysis""",305,0.007845,0.054866
"""lysogenic conv…",42,0.00108,0.010645
"""lysogeny""",120,0.003086,0.02574
"""metabolic gene…",87,0.002238,0.0197
"""nucleotide met…",1078,0.027726,0.143417
"""other""",775,0.019933,0.112596


In [34]:
get_background_counts("VOG")

ColumnNotFoundError: function

In [46]:
def per_cluster_information_gain(
    ptn_info: pl.DataFrame, 
    database: str,
    protein_clusters: LongArray,
    genome_clusters: LongArray,
    **metadata_kwargs,
) -> pl.DataFrame:
    if database not in {"PHROG", "VOG", "vog", "phrog"}:
        raise ValueError(f"Unknown database: {database}. Must be 'PHROG' or 'VOG'")
    else:
        database = database.lower()

    cat_col = f"{database}_category"

    metadata = {
        k: pl.lit(v)
        for k, v in metadata_kwargs.items()
    }

    per_cluster = (
        ptn_info
        .lazy()
        .with_columns(
            gclu = genome_clusters, # type: ignore
            pclu = protein_clusters, # type: ignore
        )
        .with_columns(
            cluster_size = pl.col("ptn_id").len().over("gclu", "pclu"),
        )
        # no singleton protein clusters
        # and ignore unknown functions
        # we effectively compute information gain with only labeled proteins
        # but keep the total size of the cluster for weighting purposes
        .filter(
            (pl.col("cluster_size") > 1) &
            (~pl.col(cat_col).str.contains("unknown"))
        )
        .group_by("gclu", "pclu")
        .agg(
            pl.first("cluster_size"),
            pl.col(cat_col).value_counts()
        )
        .explode(cat_col)
        .unnest(cat_col)
        .with_columns(
            prop = pl.col("count") / pl.col("count").sum().over("gclu", "pclu"),
        )
        .with_columns(
            entropy = -pl.col("prop") * pl.col("prop").log(2)
        )
        .group_by("gclu", "pclu")
        .agg(
            pl.first("cluster_size"),
            pl.sum("entropy")
        )
        .with_columns(
            weighted_entropy = pl.col("entropy") * pl.col("cluster_size") / pl.col("cluster_size").sum(),
            **metadata,
        )
        .collect()
    )

    return per_cluster

def information_gain(
    ptn_info: pl.DataFrame, 
    database: str,
    protein_clusters: LongArray,
    genome_clusters: LongArray,
    **metadata_kwargs,
) -> pl.DataFrame:
    if not metadata_kwargs:
        metadata_kwargs = {
            "dummy": pl.lit(0)
        }
        delete_dummy = True
    else:
        delete_dummy = False
    
    per_cluster = per_cluster_information_gain(
        ptn_info, 
        database, 
        protein_clusters, 
        genome_clusters,
        **metadata_kwargs,
    )

    background: float = get_background_counts(database)["entropy"].sum()
    total_data = ptn_info["ptn_id"].n_unique()

    info_gain = (
        per_cluster
        .with_columns(
            gain = background - pl.col("weighted_entropy"),
            total_data = total_data,
        )
        .with_columns(
            gain_ratio = pl.col("gain") / background,
        )
        .group_by(metadata_kwargs.keys())
        .agg(
            pl.mean("gain", "gain_ratio"),
            included_data = pl.sum("cluster_size"),
        )
        .with_columns(
            inclusion_weight = pl.col("included_data") / total_data,
        )
        .with_columns(
            weighted_gain = pl.col("gain") * pl.col("inclusion_weight"),
            weighted_gain_ratio = pl.col("gain_ratio") * pl.col("inclusion_weight"),
            **metadata_kwargs,
        )
    )

    if delete_dummy:
        return info_gain.drop("dummy")
    
    return info_gain

In [27]:
ptn_info = (
    pl.read_csv("supplementary_table_2.tsv", separator="\t")
    .filter(pl.col("dataset") == "test")
)

ptn_info

ptn,ptn_id,genome,genome_id,vog_annot,vog_category,phrog_annot,phrog_category,dataset
str,i64,str,i64,str,str,str,str,str
"""IMGVR_UViG_256…",0,"""IMGVR_UViG_256…",0,"""unknown functi…","""unknown""","""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",1,"""IMGVR_UViG_256…",0,"""unknown functi…","""unknown""","""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",2,"""IMGVR_UViG_256…",0,"""unknown functi…","""unknown""","""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",3,"""IMGVR_UViG_256…",0,"""unknown functi…","""unknown""","""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",4,"""IMGVR_UViG_256…",0,"""unknown functi…","""unknown""","""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",5,"""IMGVR_UViG_256…",0,"""unknown functi…","""unknown""","""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",6,"""IMGVR_UViG_256…",0,"""unknown functi…","""unknown""","""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",7,"""IMGVR_UViG_256…",0,"""unknown functi…","""unknown""","""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",8,"""IMGVR_UViG_256…",0,"""unknown functi…","""unknown""","""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",9,"""IMGVR_UViG_256…",0,"""unknown functi…","""unknown""","""NA""","""unknown functi…","""test"""


In [29]:
with tb.open_file("pst-large_per_genome_ptn_clustering.h5") as fp:
    protein_embedding = "pst-large"
    idx = 27
    # corresponds to: 
    # (genome_k, genome_res, protein_k, protein_res):
    # (15, 1.0, 15, 0.5)
    clustering_metadata = {
        label: value
        for label, value in zip(
            ["genome_k", "genome_res", "protein_k", "protein_res"],
            fp.root.metadata[idx],
        )
    }
    genome_clusters = fp.root[f"{protein_embedding}__genome"][idx]
    protein_clusters = fp.root[f"{protein_embedding}__protein"][idx]

protein_clusters.shape

(7182220,)

In [30]:
genome_clusters.shape

(7182220,)

In [41]:
pc = per_cluster_information_gain(ptn_info, "VOG", protein_clusters, genome_clusters, **clustering_metadata)

pc

gclu,pclu,cluster_size,entropy,weighted_entropy,genome_k,genome_res,protein_k,protein_res
u32,u32,u32,f64,f64,i32,f64,i32,f64
12049,14,16,0.918296,0.000002,15,1.0,15,0.5
12893,3,17,1.548795,0.000004,15,1.0,15,0.5
12797,8,10,1.0,0.000002,15,1.0,15,0.5
14096,28,10,1.5,0.000002,15,1.0,15,0.5
8115,16,7,1.0,0.000001,15,1.0,15,0.5
23166,6,3,-0.0,-0.0,15,1.0,15,0.5
23331,9,19,1.351644,0.000004,15,1.0,15,0.5
16984,9,16,1.459148,0.000004,15,1.0,15,0.5
16158,13,18,1.921928,0.000006,15,1.0,15,0.5
28236,16,15,1.0,0.000002,15,1.0,15,0.5


In [42]:
pc["weighted_entropy"].sum()

1.0685823658949747

The inclusion weight keeps track of how many clustered proteins were considered for this analysis (defined as being in a cluster with more than 1 protein and at least 1 labeled protein)

In [47]:
information_gain(
    ptn_info, 
    "VOG", 
    protein_clusters, 
    genome_clusters, 
    **clustering_metadata
)

genome_k,genome_res,protein_k,protein_res,gain,gain_ratio,included_data,inclusion_weight,weighted_gain,weighted_gain_ratio
i32,f64,i32,f64,f64,f64,u32,f64,f64,f64
15,1.0,15,0.5,1.072121,0.999998,6049841,0.842336,0.903086,0.842334
