In [1]:
import torch
import tables as tb
import einops

import os

os.environ['POLARS_MAX_THREADS'] = '128'
import polars as pl

In [2]:
%load_ext watermark
%watermark -vp numpy,torch,tables,einops

Python implementation: CPython
Python version       : 3.10.14
IPython version      : 8.24.0

numpy : 1.26.4
torch : 2.2.2
tables: 3.9.2
einops: 0.8.0



In [3]:
def entropy(distr: torch.Tensor) -> torch.Tensor:
    summand = distr * torch.log2(distr + 1e-8)
    if len(summand.shape) == 1:
        return -torch.sum(summand)
    return -einops.reduce(summand, "nodes heads -> heads", reduction="sum")

def uniform_distr(n: int) -> torch.Tensor:
    return torch.ones(n) / n

def uniform_entropy(n: int) -> torch.Tensor:
    distr = uniform_distr(n)
    return entropy(distr)

In [4]:
with tb.open_file("pst-large_model_outputs.h5") as fp:
    attn = torch.from_numpy(fp.root.attn[:])

# shape: (num proteins, num heads)
attn.shape

torch.Size([7182220, 32])

In [5]:
# average attention per head
mean_attn = einops.reduce(attn, "nodes heads -> nodes", reduction="mean")

# shape: (num proteins,)
mean_attn.shape

torch.Size([7182220])

Since these are stacked batches of protein embeddings, we need an index pointer to keep track of the start and stop of each genome.

In [6]:
with tb.open_file("IMGVRv4_test_set_esm-large_embeddings.h5") as fp:
    genome_ptr = torch.from_numpy(fp.root.ptr[:])

# shape: (num genomes + 1,)
genome_ptr.shape

torch.Size([151256])

In [7]:
# number of proteins per genome
sizes = genome_ptr[1:] - genome_ptr[:-1]

# assign each ptn to each genome id
index = torch.cat([
    torch.tensor([i] * size.item())
    for i, size in enumerate(sizes)
])

In [8]:
background_entropy = torch.tensor(
    [
        uniform_entropy(size.item())
        for size in sizes
    ],
    dtype=torch.float,
)

background_entropy

tensor([4.5850, 3.5850, 3.4594,  ..., 8.6402, 8.8517, 8.5887])

In [9]:
genome_entropies: list[torch.Tensor] = []
for idx in range(genome_ptr.numel() - 1):
    start = genome_ptr[idx]
    end = genome_ptr[idx + 1]
    distr = mean_attn[start:end]
    ent = entropy(distr)
    genome_entropies.append(ent)

# shape: (num genomes,)
genome_entropy = torch.tensor(genome_entropies)
genome_entropy

tensor([2.0856, 1.9859, 1.8284,  ..., 2.6654, 2.6339, 2.4102])

In [10]:
gain = background_entropy - genome_entropy
gain_ratio = gain / background_entropy

# scale all attention values by the genome's 
# distance from a uniform distribution (gain_ratio)
weighted_attn = mean_attn * gain_ratio[index]

In [11]:
mmseqs_clusters = (
    pl.read_csv(
        "mmseqs_clusters.tsv",
        separator="\t",
        has_header=False,
        new_columns=["rep", "ptn"],
    )
    .with_columns(
        cluster_id = pl.col("rep").rle_id(),
        cluster_size = pl.col("ptn").len().over("rep")
    )
)

ptn_info = (
    pl.read_csv("supplementary_table_2.tsv", separator="\t")
    .filter(pl.col("dataset") == "test")
    .join(mmseqs_clusters, on="ptn")
    .with_columns(
        attn = weighted_attn.numpy(),
    )
)

ptn_info

ptn,ptn_id,genome,genome_id,vog_annot,vog_category,phrog_annot,phrog_category,dataset,rep,cluster_id,cluster_size,attn
str,i64,str,i64,str,str,str,str,str,str,u32,u32,f32
"""IMGVR_UViG_GVMAG-S-1016713-123…",6997237,"""IMGVR_UViG_GVMAG-S-1016713-123…",150790,"""unknown function""","""unknown""","""NA""","""unknown function""","""test""","""IMGVR_UViG_GVMAG-S-1016713-123…",0,1,0.238486
"""IMGVR_UViG_3300020065_000052|3…",29590,"""IMGVR_UViG_3300020065_000052""",7702,"""unknown function""","""unknown""","""NA""","""unknown function""","""test""","""IMGVR_UViG_3300020065_000052|3…",1,1,0.017035
"""IMGVR_UViG_3300020068_000030|3…",29878,"""IMGVR_UViG_3300020068_000030""",7780,"""unknown function""","""unknown""","""NA""","""unknown function""","""test""","""IMGVR_UViG_3300020068_000030|3…",2,2,0.0
"""IMGVR_UViG_3300042413_000309|3…",229911,"""IMGVR_UViG_3300042413_000309""",29399,"""unknown function""","""unknown""","""NA""","""unknown function""","""test""","""IMGVR_UViG_3300020068_000030|3…",2,2,4.8572e-23
"""IMGVR_UViG_3300020070_000833|3…",30294,"""IMGVR_UViG_3300020070_000833""",7896,"""unknown function""","""unknown""","""NA""","""unknown function""","""test""","""IMGVR_UViG_3300020070_000833|3…",3,1,0.0
…,…,…,…,…,…,…,…,…,…,…,…,…
"""IMGVR_UViG_3300037124_000040|3…",2380370,"""IMGVR_UViG_3300037124_000040""",86529,"""unknown function""","""unknown""","""NA""","""unknown function""","""test""","""IMGVR_UViG_3300037124_000040|3…",1742934,2,0.0
"""IMGVR_UViG_3300019790_000089|3…",737478,"""IMGVR_UViG_3300019790_000089""",47932,"""unknown function""","""unknown""","""NA""","""unknown function""","""test""","""IMGVR_UViG_3300037124_000040|3…",1742934,2,0.0
"""IMGVR_UViG_3300037124_000040|3…",2380402,"""IMGVR_UViG_3300037124_000040""",86529,"""unknown function""","""unknown""","""NA""","""unknown function""","""test""","""IMGVR_UViG_3300037124_000040|3…",1742935,1,0.0
"""IMGVR_UViG_3300037124_000029|3…",2380498,"""IMGVR_UViG_3300037124_000029""",86530,"""unknown function""","""unknown""","""NA""","""unknown function""","""test""","""IMGVR_UViG_3300037124_000029|3…",1742936,1,0.0


In [12]:
def most_attended_functional_categories(ptn_info: pl.DataFrame, top_n: int = 50, annotation_database: str = "VOG") -> pl.DataFrame:
    annotation_database = annotation_database.upper()
    if annotation_database not in {"VOG", "PHROG"}:
        raise ValueError("Annotation database must be either VOG or PHROG")
    
    cat_col = f"{annotation_database.lower()}_category"
    summary = (
        ptn_info
        .group_by("cluster_id")
        .agg(
            pl.max("attn"),
            pl.col(cat_col).unique(),
        )
        .explode(cat_col)
        .sort("attn", descending=True)
        .group_by(cat_col)
        .head(top_n)
        .sort("attn", descending=True)
    )

    return summary

In [13]:
most_attended_functional_categories(ptn_info, top_n=50, annotation_database="VOG")

vog_category,cluster_id,attn
str,u32,f32
"""unknown""",1634748,0.341053
"""unknown""",313789,0.33946
"""unknown""",1711756,0.336021
"""unknown""",1606130,0.33581
"""other""",1617312,0.335521
…,…,…
"""integration""",1394196,0.294153
"""integration""",1392386,0.294058
"""integration""",1399225,0.294006
"""integration""",224002,0.293705
