In [1]:
import os
from pathlib import Path
import itertools as it

import numpy as np
import tables as tb

from functools import reduce

os.environ["POLARS_MAX_THREADS"] = "128"

import polars as pl  # noqa: E402

In [2]:
%load_ext watermark
%watermark -vp numpy,polars,tables

Python implementation: CPython
Python version       : 3.11.3
IPython version      : 8.13.2

numpy : 1.23.5
polars: 0.20.6
tables: 3.8.0



In [3]:
# Protein-level annotation table (supplementary table 2), restricted to the
# test split.
ptn_info = pl.read_csv(
    "supplementary_tables/supplementary_table_2.tsv",
    separator="\t",
).filter(pl.col("dataset").eq("test"))

ptn_info

ptn,ptn_id,genome,genome_id,vog_bitscore,vog_annot,vog_category,phrog_bitscore,phrog_annot,phrog_category,dataset
str,i64,str,i64,f64,str,str,i64,str,str,str
"""IMGVR_UViG_256…",0,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",1,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",2,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",3,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",4,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",5,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",6,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",7,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",8,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",58,"""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",9,"""IMGVR_UViG_256…",0,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test"""


In [4]:
def add_protein_clustering(
    ptn_info: pl.DataFrame,
    cluster_file: str | Path,
    genome_k: int,
    genome_res: float,
    protein_k: int,
    protein_res: float,
    genome_embedding_type: str,
    protein_embedding_type: str,
) -> pl.DataFrame:
    """Add genome- and protein-cluster labels from one stored clustering run.

    In the HDF5 group named `genome_embedding_type`, find the row of the
    `metadata` table whose fields equal
    ``(genome_k, genome_res, protein_k, protein_res)``. Then read that row
    index from the `protein_embedding_type` and `genome` arrays.

    Parameters
    ----------
    ptn_info : protein table to annotate. NOTE(review): the stored label
        arrays are assumed to be in the same row order as `ptn_info`; only
        their length is checked here.
    cluster_file : path to the HDF5 clustering file (opened with PyTables).
    genome_k, genome_res, protein_k, protein_res : hyper-parameters that
        identify the clustering run. They are cast to the metadata dtype and
        compared for exact equality.
    genome_embedding_type : name of the HDF5 group to read from.
    protein_embedding_type : name of the protein-cluster array in that group.

    Returns
    -------
    pl.DataFrame
        `ptn_info` with two extra columns: `pclu` (protein-cluster labels)
        and `gclu` (genome-cluster labels).

    Raises
    ------
    ValueError
        If no clustering run matches the parameters, or if the stored label
        arrays do not have exactly one entry per row of `ptn_info`.
    """
    with tb.open_file(cluster_file) as fp:
        group = fp.root[genome_embedding_type]
        clustering_metadata = group.metadata[:]
        # Cast the query to the table's structured dtype so `==` compares it
        # field-by-field against every metadata row.
        clustering_query = np.array(
            (genome_k, genome_res, protein_k, protein_res),
            dtype=clustering_metadata.dtype,
        )

        matches = np.flatnonzero(clustering_metadata == clustering_query)
        if matches.size == 0:
            raise ValueError(
                f"Clustering not found with parameters: {genome_k=}, {genome_res=}, {protein_k=}, {protein_res=}"
            )
        # If several runs share these parameters, the first match is used.
        query_idx = matches[0]

        protein_clusters = group[protein_embedding_type][query_idx, :]
        genome_clusters = group["genome"][query_idx, :]

    # Fail loudly on a mismatch, instead of relying on polars (which would
    # either raise a shape error or broadcast a length-1 array).
    n_rows = ptn_info.height
    if len(protein_clusters) != n_rows or len(genome_clusters) != n_rows:
        raise ValueError(
            f"Cluster labels do not match ptn_info: {len(protein_clusters)} protein and "
            f"{len(genome_clusters)} genome labels for {n_rows} rows"
        )

    return ptn_info.with_columns(pclu=protein_clusters, gclu=genome_clusters)

In [5]:
# Regex fragments used to re-categorize PHROG annotations into more specific
# categories for this analysis. Each key names the broad PHROG category that
# is being split up. Each list is combined into one case-insensitive
# alternation (see `join_annot`) and searched for in the `phrog_annot` column.
# NOTE(review): the patterns are applied to every protein's `phrog_annot`,
# not only to proteins already in the keyed category. Broad fragments such as
# "terminal" or "repl" may therefore also hit annotations from other
# categories; confirm this is intended.

re_categorize = {
    "nucleotide metabolism": [ # matching annotations are relabelled "replication"
        "DNA pol", 
        "single strand DNA binding", 
        "Par[AB]", 
        "DNA primase",
        "(DNA)?[ ]?helicase",
        "repl",
        "primosom",
        "terminal",
        "ribonucleo[st]ide(.*)?reductase",
        "NDP reductase",
    ],
    "head and packaging": [ # matching annotations are relabelled "packaging"
        "terminase", "portal"
    ],
}

def join_annot(annots: list[str]) -> str:
    """Combine regex fragments into one case-insensitive alternation pattern.

    Example: ["terminase", "portal"] -> "(?i)terminase|portal"
    """
    alternation = "|".join(annots)
    return "(?i)" + alternation

# Show the combined pattern that each recategorization rule produces.
for pattern in map(join_annot, re_categorize.values()):
    print(pattern)

(?i)DNA pol|single strand DNA binding|Par[AB]|DNA primase|(DNA)?[ ]?helicase|repl|primosom|terminal|ribonucleo[st]ide(.*)?reductase|NDP reductase
(?i)terminase|portal


For this example, we use only the genome clusters computed from `pst-large` genome embeddings, and compare the protein clusters computed from `esm-large`, `pst-large`, and `genslm` protein embeddings.

In [6]:
clustering_hparams = dict(
    genome_k=15,
    genome_res=1.0,
    protein_k=5,
    protein_res=0.5,
)

protein_clusters_file = "datasets/protein_clusters/embedding-based_protein_clusters_per_genome_cluster.h5"

# One copy of the test-set proteins per protein embedding method. Each copy is
# labelled with that method's protein clusters. Genome clusters always come
# from the `pst-large` genome embeddings.
dfs: list[pl.DataFrame] = [
    add_protein_clustering(
        ptn_info,
        protein_clusters_file,
        genome_embedding_type="pst-large",
        protein_embedding_type=protein_embedding_type,
        **clustering_hparams,
    ).with_columns(protein_method=pl.lit(protein_embedding_type))
    for protein_embedding_type in ["esm-large", "pst-large", "genslm"]
]

# Case-insensitive patterns that split two broad PHROG categories into the
# more specific "replication" and "packaging" labels.
replication_regex = join_annot(re_categorize["nucleotide metabolism"])
packaging_regex = join_annot(re_categorize["head and packaging"])

ptn_annotations = (
    pl.concat(dfs)
    .with_columns(
        # Cluster sizes are computed per (genome cluster, protein cluster,
        # method), before any filtering.
        proteins_per_cluster=(
            pl.col("ptn_id")
            .len()
            .over("gclu", "pclu", "protein_method")
        ),
        genomes_per_cluster=(
            pl.col("genome_id")
            .n_unique()
            .over("gclu", "pclu", "protein_method")
        ),
        # The first matching rule wins; otherwise keep the original category.
        phrog_category=(
            pl.when(pl.col("phrog_annot").str.contains(replication_regex))
            .then(pl.lit("replication"))
            .when(pl.col("phrog_annot").str.contains(packaging_regex))
            .then(pl.lit("packaging"))
            .otherwise(pl.col("phrog_category"))
        ),
    )
    # Keep only protein clusters with more than one protein spanning more
    # than one genome.
    .filter(
        (pl.col("proteins_per_cluster") > 1)
        & (pl.col("genomes_per_cluster") > 1)
    )
)

ptn_annotations

ptn,ptn_id,genome,genome_id,vog_bitscore,vog_annot,vog_category,phrog_bitscore,phrog_annot,phrog_category,dataset,pclu,gclu,protein_method,proteins_per_cluster,genomes_per_cluster
str,i64,str,i64,f64,str,str,i64,str,str,str,u32,u32,str,u32,u32
"""IMGVR_UViG_256…",26,"""IMGVR_UViG_256…",1,,"""unknown functi…","""unknown""",147,"""NA""","""unknown functi…","""test""",2,1,"""esm-large""",2,2
"""IMGVR_UViG_256…",29,"""IMGVR_UViG_256…",1,225.97,"""REFSEQ putativ…","""gene expressio…",239,"""tail assembly …","""tail""","""test""",3,1,"""esm-large""",2,2
"""IMGVR_UViG_257…",45,"""IMGVR_UViG_257…",2,,"""unknown functi…","""unknown""",238,"""transcriptiona…","""gene expressio…","""test""",2,2,"""esm-large""",3,2
"""IMGVR_UViG_264…",58,"""IMGVR_UViG_264…",4,129.82,"""REFSEQ DNA rep…","""replication""",168,"""replication in…","""replication""","""test""",0,4,"""esm-large""",3,2
"""IMGVR_UViG_264…",66,"""IMGVR_UViG_264…",4,,"""unknown functi…","""unknown""",153,"""NA""","""unknown functi…","""test""",3,4,"""esm-large""",6,2
"""IMGVR_UViG_264…",69,"""IMGVR_UViG_264…",4,,"""unknown functi…","""unknown""",160,"""NA""","""unknown functi…","""test""",3,4,"""esm-large""",6,2
"""IMGVR_UViG_264…",70,"""IMGVR_UViG_264…",4,285.43,"""sp|P03660|REP_…","""replication""",202,"""replication pr…","""replication""","""test""",3,4,"""esm-large""",6,2
"""IMGVR_UViG_267…",72,"""IMGVR_UViG_267…",5,185.6,"""REFSEQ ORF5""","""other""",238,"""NA""","""unknown functi…","""test""",0,5,"""esm-large""",4,2
"""IMGVR_UViG_267…",73,"""IMGVR_UViG_267…",5,,"""unknown functi…","""unknown""",,"""NA""","""unknown functi…","""test""",1,5,"""esm-large""",4,4
"""IMGVR_UViG_267…",74,"""IMGVR_UViG_267…",5,,"""unknown functi…","""unknown""",86,"""tail protein""","""tail""","""test""",2,5,"""esm-large""",5,3


In [7]:
# Per database, one row per protein cluster. Each row holds the sorted list of
# informative functional categories in the cluster and the number of
# annotated proteins behind them.
annot_counts = {}
grouping_cols = ["gclu", "pclu", "protein_method"]
for db in ["vog", "phrog"]:
    cat_col = f"{db}_category"
    counts = (
        ptn_annotations
        # Drop proteins whose category is uninformative.
        .filter(
            (~pl.col(cat_col).str.contains("unknown"))
            & (pl.col(cat_col) != "other")
        )
        # One row per (protein cluster, category, annotation), with the
        # number of proteins carrying it.
        .group_by(grouping_cols + [cat_col, f"{db}_annot"])
        .agg(
            pl.first("proteins_per_cluster", "genomes_per_cluster"),
            count=pl.len(),
        )
        .with_columns(
            num_categories=pl.col(cat_col).n_unique().over(grouping_cols),
        )
        .sort("protein_method", "gclu", "pclu", cat_col)
        # Collapse to one row per protein cluster. Without maintain_order,
        # group_by and unique() ignore the sort above, so the row order and
        # the order of categories inside each list would be arbitrary.
        .group_by(grouping_cols, maintain_order=True)
        .agg(
            pl.col(cat_col).unique(maintain_order=True),
            pl.sum("count"),
            pl.first("proteins_per_cluster", "genomes_per_cluster", "num_categories"),
        )
    )

    annot_counts[db] = counts

annot_counts["vog"]

gclu,pclu,protein_method,vog_category,count,proteins_per_cluster,genomes_per_cluster,num_categories
u32,u32,str,list[str],u32,u32,u32,u32
7,73,"""esm-large""","[""exit""]",1,3,2,1
9,30,"""esm-large""","[""structural""]",2,7,3,1
37,14,"""esm-large""","[""replication""]",1,4,2,1
38,32,"""esm-large""","[""integration""]",1,4,4,1
47,46,"""esm-large""","[""structural"", ""replication""]",2,5,3,2
79,7,"""esm-large""","[""replication""]",1,4,3,1
86,81,"""esm-large""","[""packaging"", ""structural"", … ""exit""]",5,6,3,4
105,68,"""esm-large""","[""structural""]",1,2,2,1
120,4,"""esm-large""","[""integration"", ""structural""]",2,3,2,2
136,13,"""esm-large""","[""integration""]",1,4,3,1


In [8]:
# Total number of protein clusters (after the size filter applied to
# ptn_annotations) per protein embedding method. Distinct protein-cluster ids
# are counted within each genome cluster, then summed across genome clusters.
num_clusters = (
    ptn_annotations
    .group_by(["gclu", "protein_method"])
    .agg(pl.col("pclu").n_unique().alias("num_clusters"))
    .group_by("protein_method")
    .agg(pl.col("num_clusters").sum())
)

In [9]:
def _distinct_categories(category_col: str, ignored: set[str]) -> list[str]:
    """Return the distinct values of `category_col` in ptn_annotations, minus `ignored`."""
    return (
        ptn_annotations.lazy()
        .filter(pl.col(category_col).is_in(ignored).not_())
        .select(category_col)
        .unique()
        .collect()
        .to_series()
        .to_list()
    )


all_vog_cats = _distinct_categories("vog_category", {"unknown", "other"})
all_phrog_cats = _distinct_categories("phrog_category", {"unknown function", "other"})

# database name -> informative categories present in the filtered clusters
all_categories = {"vog": all_vog_cats, "phrog": all_phrog_cats}

def get_exprs(
    database: str,
    categories: list[str],
    all_cats: set[str] | list[str] | None = None,
) -> pl.Expr:
    """Build a filter selecting protein clusters that match one functional module.

    Intended for the per-cluster frames in `annot_counts`. In those frames,
    `<database>_category` is a list of the informative categories in each
    cluster, and `count` is the number of annotated proteins in the cluster.

    Parameters
    ----------
    database : annotation database prefix ("vog" or "phrog"). It selects the
        `<database>_category` column and, by default, the category universe
        `all_categories[database]`.
    categories : the categories that make up the module.
    all_cats : optional universe of categories. Defaults to the module-level
        `all_categories[database]`. Every category in this universe that is
        not in `categories` must be absent from a matching cluster.

    Returns
    -------
    pl.Expr
        Boolean expression. With one module category, a cluster matches if it
        contains that category and has more than one annotated protein. With
        several module categories, it must contain at least two of them. In
        both cases the cluster must contain no other category from the
        universe.

    Raises
    ------
    ValueError
        If `categories` is empty (previously a cryptic TypeError from `reduce`).
    """
    module_cats = set(categories)
    if not module_cats:
        raise ValueError("categories must contain at least one category")

    cat_col = f"{database}_category"
    universe = set(all_categories[database] if all_cats is None else all_cats)

    def contains(cat: str) -> pl.Expr:
        return pl.col(cat_col).list.contains(cat)

    if len(module_cats) == 1:
        (category,) = module_cats
        # Require more than one annotated protein in the cluster. Other
        # categories are excluded below, so they all carry this category.
        expr = contains(category) & pl.col("count").gt(1)
    else:
        # The cluster contains at least one pair of the module's categories.
        expr = reduce(
            lambda x, y: x | y,
            [contains(a) & contains(b) for a, b in it.combinations(module_cats, 2)],
        )

    # In either case, exclude protein clusters that include any other category.
    for cat in universe - module_cats:
        expr &= ~contains(cat)

    return expr

# Functional modules per annotation database: module name -> the categories
# that define it. Each (database, categories) pair is passed to `get_exprs`
# to select the protein clusters counted for that module.
modules = {
    "phrog": {
        # "packaging" and "replication" include the labels assigned from
        # `re_categorize` when building ptn_annotations.
        "packaging": ["packaging"],
        "late genes": ["tail", "head and packaging", "connector", "lysis"],
        "replication": ["replication"],
        "DNA-interacting": ["gene expression", "lysogeny", "nucleotide metabolism"],
    },

    "vog": {
        "packaging": ["packaging"],
        "late genes": ["structural", "exit", "packaging"],
        "replication": ["replication"],
        "DNA-interacting": ["replication", "integration", "packaging", "gene expression"],
    }
}

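The `proportion_clusters` column is the proportion of protein clusters that fit the criteria for each functional module. The denominator is the total number of protein clusters for that protein embedding method, counting only clusters with more than one protein spanning more than one genome.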

In [10]:
# One lazy frame per (database, module): the protein clusters from
# annot_counts[database] that satisfy that module's category criteria.
dfs: list[pl.LazyFrame] = []
for database, module_dict in modules.items():
    src_df: pl.DataFrame = annot_counts[database]
    category_col = f"{database}_category"
    for module, categories in module_dict.items():
        expr = get_exprs(database, categories)
        module_frame = (
            src_df.lazy()
            .filter(expr)
            .with_columns(
                pl.lit(module).alias("module"),
                pl.lit(database).alias("database"),
            )
            # vog and phrog frames name this column differently; drop it so
            # all frames share one schema for pl.concat.
            .drop(category_col)
        )
        dfs.append(module_frame)

count_summary = (
    pl.concat(pl.collect_all(dfs))
    .group_by(["protein_method", "module", "database"])
    .agg(
        pl.col("count").sum().alias("ptn_count"),
        pl.len().alias("included_clusters"),
    )
    # Attach the per-method total number of protein clusters as denominator.
    .join(num_clusters, on="protein_method")
    .with_columns(
        (pl.col("included_clusters") / pl.col("num_clusters")).alias("proportion_clusters"),
    )
    .sort(["database", "module", "protein_method"])
)

count_summary

protein_method,module,database,ptn_count,included_clusters,num_clusters,proportion_clusters
str,str,str,u32,u32,u32,f64
"""esm-large""","""DNA-interactin…","""phrog""",3019,1457,1115203,0.001306
"""genslm""","""DNA-interactin…","""phrog""",2865,1352,827398,0.001634
"""pst-large""","""DNA-interactin…","""phrog""",2607,1240,501656,0.002472
"""esm-large""","""late genes""","""phrog""",12916,6089,1115203,0.00546
"""genslm""","""late genes""","""phrog""",9637,4205,827398,0.005082
"""pst-large""","""late genes""","""phrog""",9265,4224,501656,0.00842
"""esm-large""","""packaging""","""phrog""",274,137,1115203,0.000123
"""genslm""","""packaging""","""phrog""",475,218,827398,0.000263
"""pst-large""","""packaging""","""phrog""",234,116,501656,0.000231
"""esm-large""","""replication""","""phrog""",802,400,1115203,0.000359
