In [1]:
import itertools as it
import os
from collections import defaultdict

import tables as tb
import networkx as nx
import igraph as ig
import numpy as np

os.environ["POLARS_MAX_THREADS"] = "128"

import polars as pl  # noqa: E402

In [2]:
%load_ext watermark
%watermark -vp tables,networkx,igraph,numpy,polars

Python implementation: CPython
Python version       : 3.11.3
IPython version      : 8.13.2

tables  : 3.8.0
networkx: 3.1
igraph  : 0.11.3
numpy   : 1.23.5
polars  : 0.20.6



In [3]:
def get_background_counts(database: str = "PHROG") -> pl.DataFrame:
    if database.upper() == "PHROG":
        file = "supplementary_tables/supplementary_table_5.tsv"
        cat_col = "new_category"
    elif database.upper() == "VOG":
        file = "supplementary_tables/supplementary_table_6.tsv"
        cat_col = "function"
    else:
        raise ValueError(f"Invalid database: {database}. Must be 'PHROG' or 'VOG'")
    
    annot_metadata = pl.read_csv(file, separator="\t")

    annot_counts = (
        annot_metadata
        .filter(~pl.col(cat_col).str.contains("unknown"))
        [cat_col]
        .value_counts()
    )

    background_dict = defaultdict(int)
    for u, v in it.combinations(annot_counts.iter_rows(named=True), r=2):
        u_cat = u[cat_col]
        v_cat = v[cat_col]
        u_count = u["count"]
        v_count = v["count"]

        background_dict[(u_cat, v_cat)] = u_count * v_count

    to_df = {
        "source": [],
        "target": [],
        "co_occur_count": [],
    }

    for (u_cat, v_cat), count in background_dict.items():
        to_df["source"].append(u_cat)
        to_df["target"].append(v_cat)
        to_df["co_occur_count"].append(count)

    background = pl.DataFrame(to_df).sort("co_occur_count", descending=True)

    return background

background = get_background_counts("PHROG")
background

source,target,co_occur_count
str,str,i64
"""tail""","""nucleotide met…",1334564
"""tail""","""head and packa…",1182290
"""nucleotide met…","""head and packa…",1029490
"""tail""","""other""",959450
"""nucleotide met…","""other""",835450
"""other""","""head and packa…",740125
"""tail""","""gene expressio…",394922
"""lysis""","""tail""",377590
"""gene expressio…","""nucleotide met…",343882
"""lysis""","""nucleotide met…",328790


In [4]:
def get_annotation_counts(
    ptn_info: pl.DataFrame, 
    protein_cluster_labels: np.ndarray,
    genome_cluster_labels: np.ndarray,
    database: str = "PHROG",
) -> pl.DataFrame:
    cat_col = f"{database.lower()}_category"

    annot_counts = (
        ptn_info
        .with_columns(
            pclu = protein_cluster_labels, # type: ignore
            gclu = genome_cluster_labels, # type: ignore
        )
        .with_columns(
            cluster_size = pl.col("ptn_id").len().over("pclu", "gclu")
        )
        .filter(
            (pl.col("cluster_size") > 1) &
            (~pl.col(cat_col).str.contains("unknown"))
        )
        .group_by("gclu", "pclu")
        .agg(
            pl.col(cat_col).value_counts()
        )
        .with_columns(
            n_categories = pl.col(cat_col).list.len(),
        )
        # only care about co-occurrence, so we need more than 1 category
        .filter(pl.col("n_categories") > 1)
        .explode(cat_col)
        .unnest(cat_col)
    )

    return annot_counts

def construct_cooccurrence_graph(annot_counts: pl.DataFrame, background: pl.DataFrame, database: str = "PHROG") -> nx.Graph:
    cat_col = f"{database.lower()}_category"

    background_key_order = [
        (row["source"], row["target"])
        for row in background.iter_rows(named=True)
    ]

    cooccurence_graph = defaultdict(int)
    for _, group_df in annot_counts.group_by("gclu", "pclu"):
        obs_lut = {
            row[cat_col]: row["count"]
            for row in group_df.iter_rows(named=True)
        }

        for (u, v) in background_key_order:
            cooccurence_graph[u, v] += obs_lut.get(u, 0) * obs_lut.get(v, 0)

    to_df = {
        "source": [],
        "target": [],
        "observed_co_occur_count": [],
    }

    for (u, v), count in cooccurence_graph.items():
        to_df["source"].append(u)
        to_df["target"].append(v)
        to_df["observed_co_occur_count"].append(count)

    df = (
        pl.DataFrame(to_df)
        .join(background, on=["source", "target"])
        .rename({"co_occur_count": "expected_co_occur_count"})
        .with_columns(
            obs_prop = pl.col("observed_co_occur_count") / pl.col("observed_co_occur_count").sum(),
            exp_prop = pl.col("expected_co_occur_count") / pl.col("expected_co_occur_count").sum(),
        )
        .with_columns(
            ratio = pl.col("obs_prop") / pl.col("exp_prop"),
        )
        .sort("ratio", descending=True)
    )

    G = nx.Graph()

    for row in df.iter_rows(named=True):
        G.add_edge(row["source"], row["target"], weight=row["ratio"])

    return G

def cluster_graph(G: nx.Graph) -> tuple[np.ndarray, np.ndarray]:
    graph: ig.Graph = ig.Graph.from_networkx(G, vertex_attr_hashable="name")

    function_clusters = np.array(
        # resolution 1.0 means focus on clustering 
        # those enriched above background
        graph.community_leiden(weights="weight", resolution=1.0)
        .membership
    )

    nodes = np.array([node["name"] for node in graph.vs])

    return nodes, function_clusters

In [5]:
protein_clusters_file = "datasets/protein_clusters/embedding-based_protein_clusters_per_genome_cluster.h5"
with tb.open_file(protein_clusters_file) as fp:
    # corresponds to (genome_k, genome_res, protein_k, protein_res) ==
    # (15, 1.0, 15, 0.5)
    clustering_idx = 27
    genome_embedding = "pst-large"
    node = fp.root[genome_embedding]

    clustering_metadata = node.metadata[clustering_idx]

    protein_embedding = "pst-large"
    protein_cluster_labels = node[protein_embedding][clustering_idx]
    genome_cluster_labels = node["genome"][clustering_idx]

protein_cluster_labels.shape

(7182220,)

In [6]:
ptn_info = (
    pl.read_csv("supplementary_tables/supplementary_table_2.tsv", separator="\t")
    .filter(pl.col("dataset") == "test")
)

ptn_info

ptn,ptn_id,genome,genome_id,vog_bitscore,vog_annot,vog_category,phrog_bitscore,phrog_annot,phrog_category,dataset
str,i64,str,i64,f64,str,str,i64,str,str,str
"""IMGVR_UViG_256…",0,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",1,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",2,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",3,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",4,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",5,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",6,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",7,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",8,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",58,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",9,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""


In [7]:
database = "PHROG"

annot_counts = get_annotation_counts(
    ptn_info, 
    protein_cluster_labels, 
    genome_cluster_labels, 
    database
)

graph = construct_cooccurrence_graph(annot_counts, background, database)

In Figure 4B and Extended Data Figure 7 of the PST manuscript, the connected components are only those that clustered together:

In [8]:
nodes, function_clusters = cluster_graph(graph)

longest_name = max(len(node) for node in nodes)
for node, cluster in zip(nodes, function_clusters):
    print(f"{node:<{longest_name}}: {cluster}")

tail                 : 0
connector            : 0
lysis                : 0
lysogeny             : 1
host takeover        : 1
gene expression      : 1
head and packaging   : 0
lysogenic conversion : 1
nucleotide metabolism: 1
metabolic gene       : 2
other                : 2
