In [1]:
import os
from pathlib import Path

import numpy as np
import tables as tb
from faiss import (
    METRIC_INNER_PRODUCT,
    IndexFlatIP,
    IndexFlatL2,
    IndexIVFFlat,
    omp_set_num_threads,
)
from numpy.typing import NDArray

THREADS = 128
os.environ["POLARS_MAX_THREADS"] = str(THREADS)
import polars as pl  # noqa: E402

omp_set_num_threads(THREADS)

FloatArray = NDArray[np.float32]
LongArray = NDArray[np.int32]

In [2]:
%load_ext watermark
%watermark -vp numpy,tables,faiss,polars

Python implementation: CPython
Python version       : 3.11.3
IPython version      : 8.13.2

numpy : 1.23.5
tables: 3.8.0
faiss : 1.8.0
polars: 0.20.6



In [3]:
def build_index(
    data: FloatArray, normalized: bool = True
) -> IndexIVFFlat | IndexFlatIP | IndexFlatL2:
    dim = data.shape[-1]

    n_cells = max(data.shape[0] // 39, 1)

    if normalized:
        if n_cells == 1:
            index = IndexFlatIP(dim)
        else:
            quantizer = IndexFlatIP(dim)
            index = IndexIVFFlat(quantizer, dim, n_cells, METRIC_INNER_PRODUCT)
    else:
        if n_cells == 1:
            index = IndexFlatL2(dim)
        else:
            quantizer = IndexFlatL2(dim)
            index = IndexIVFFlat(quantizer, dim, n_cells)
            
    index.train(data)  # type: ignore
    index.add(data)  # type: ignore
    return index

In [4]:
def read_data(file: str | Path, loc: str, normalize: bool = True) -> FloatArray:
    with tb.open_file(file) as fp:
        data = fp.root[loc][:]

    if normalize:
        norms = np.linalg.norm(data, axis=1, keepdims=True)
        data /= norms
    return data

def read_itemptr(protein_embedding_file: str | Path) -> LongArray:
    with tb.open_file(protein_embedding_file) as fp:
        return fp.root.ptr[:]
    
def annotation_improvment(
    ptn_info: pl.DataFrame, 
    ptn_embedding: FloatArray, 
    genome_clusters: LongArray, 
    genome_ptr: LongArray, 
    database: str = "VOG", 
    is_normalized: bool = True
) -> pl.DataFrame:
    labeled = (~ptn_info[f"{database.lower()}_category"].str.contains("unknown")).to_numpy()

    results = {
        "unlabeled": [],
        "get_labeled": [],
        "num_proteins": [],
        "num_genomes": [],
    }

    uniq_genome_clusters, genome_cluster_sizes = np.unique(genome_clusters, return_counts=True)
    for genome_cluster, cluster_size in zip(uniq_genome_clusters, genome_cluster_sizes):
        if cluster_size < 2:
            continue

        genome_idx = np.where(genome_clusters == genome_cluster)[0]
        starts = genome_ptr[genome_idx]
        ends = genome_ptr[genome_idx + 1]

        ptn_ids = np.concatenate(
            [np.arange(start, end) for start, end in zip(starts, ends)]
        )

        local_unlabeled_idx = ~labeled[ptn_ids]
        total_unlabeled = local_unlabeled_idx.sum()

        # just skip if there are no unlabeled proteins in this genome cluster
        if total_unlabeled == 0:
            continue

        local_ptn_embedding = np.concatenate(
            [ptn_embedding[start:ends] for start, ends in zip(starts, ends)]
        )

        index = build_index(local_ptn_embedding, normalized=is_normalized)

        # get the nearest neighbors
        # in faiss, the first nearest neighbor is the query itself
        _, nn_idx = index.search(local_ptn_embedding[local_unlabeled_idx], k=2)

        how_many_get_labeled = (
            ~local_unlabeled_idx[nn_idx[:, 1]]
        ).sum()

        results["unlabeled"].append(total_unlabeled)
        results["get_labeled"].append(how_many_get_labeled)
        results["num_proteins"].append(len(ptn_ids))
        results["num_genomes"].append(cluster_size)
    return pl.DataFrame(results)

In [5]:
ptn_info = (
    pl.read_csv("supplementary_table_2.tsv", separator="\t")
    .filter(pl.col("dataset") == "test")
)
ptn_info

ptn,ptn_id,genome,genome_id,vog_annot,vog_category,phrog_annot,phrog_category,dataset
str,i64,str,i64,str,str,str,str,str
"""IMGVR_UViG_256…",0,"""IMGVR_UViG_256…",0,"""unknown functi…","""unknown""","""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",1,"""IMGVR_UViG_256…",0,"""unknown functi…","""unknown""","""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",2,"""IMGVR_UViG_256…",0,"""unknown functi…","""unknown""","""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",3,"""IMGVR_UViG_256…",0,"""unknown functi…","""unknown""","""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",4,"""IMGVR_UViG_256…",0,"""unknown functi…","""unknown""","""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",5,"""IMGVR_UViG_256…",0,"""unknown functi…","""unknown""","""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",6,"""IMGVR_UViG_256…",0,"""unknown functi…","""unknown""","""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",7,"""IMGVR_UViG_256…",0,"""unknown functi…","""unknown""","""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",8,"""IMGVR_UViG_256…",0,"""unknown functi…","""unknown""","""NA""","""unknown functi…","""test"""
"""IMGVR_UViG_256…",9,"""IMGVR_UViG_256…",0,"""unknown functi…","""unknown""","""NA""","""unknown functi…","""test"""


In [6]:
ptn_embedding = read_data("../IMGVRv4_test_set_esm-small.h5", "data", normalize=True)
ptn_embedding.shape

(7182220, 320)

In [7]:
with tb.open_file("genome_clusters.h5") as fp:
    genome_clusters = fp.root.test["pst-large"].data[14, :]

genome_clusters.shape

(151255,)

In [8]:
genome_ptr = read_itemptr("../IMGVRv4_test_set_esm-small.h5")
genome_ptr.shape

(151256,)

For the PST manuscript, we focused on VOG annotations, since the VOG HMMs had a greater number of hits. 

In [9]:
annot_improvement_results = annotation_improvment(
    ptn_info, ptn_embedding, genome_clusters, genome_ptr, database="VOG", is_normalized=True
)

annot_improvement_results

unlabeled,get_labeled,num_proteins,num_genomes
i64,i64,i64,i64
1950,408,2465,92
397,47,468,42
522,114,652,56
315,61,404,36
41,12,54,5
270,49,339,26
148,37,195,12
1086,202,1359,58
754,183,991,64
335,76,425,29


We can summarize these results by totaling the percentages of unlabeled proteins before and after this annotation transfer. In the manuscript, we used `get_labeled_prop` for Figure 4A since the total number of proteins was constant for that analysis.

In [10]:
def summarize_annotation_improvement(results: pl.DataFrame) -> pl.DataFrame:
    summary = (
        results
        .with_columns(
            group_by_var = 0,
        )
        .group_by("group_by_var")
        .agg(
            pl.sum("unlabeled"),
            pl.sum("get_labeled"),
            pl.sum("num_proteins"),
        )
        .drop("group_by_var")
        .with_columns(
            labeled_before = pl.col("num_proteins") - pl.col("unlabeled"),
            get_labeled_prop = pl.col("get_labeled") / pl.col("unlabeled"),
        )
        .with_columns(
            labeled_after = pl.col("labeled_before") + pl.col("get_labeled"),
            before_annot_prop = pl.col("labeled_before") / pl.col("num_proteins"),
        )
        .with_columns(
            after_annot_prop = pl.col("labeled_after") / pl.col("num_proteins"),
        )
    )

    return summary

In [11]:
summarize_annotation_improvement(annot_improvement_results)

unlabeled,get_labeled,num_proteins,labeled_before,get_labeled_prop,labeled_after,before_annot_prop,after_annot_prop
i64,i64,i64,i64,f64,i64,f64,f64
5257627,1338980,7177634,1920007,0.254674,3258987,0.267499,0.454048


## Annotation *rate* improvement

In the manuscript, we also computed the sensitivity of genome clustering parameters (k-nearest neighbors and Leiden resolution). We computed the sensitivity as the slope of `get_labeled_prop` over k (nearest neighbors). The interpretation is the rate of annotation proporation improvement as the number of genome neighbors increases. Positive sensitivity values indicate that the genome clustering leads to increased ability to annotate proteins.

In [12]:
with tb.open_file("genome_clusters.h5") as fp:
    # genome_clusters = fp.root.test["pst-large"].data[14, :]
    genome_resolution = 0.1
    clustering_metadata = fp.root.test["pst-large"].metadata[:]
    mask = np.array(
        [genome_res == genome_resolution for genome_res in clustering_metadata["resolution"]]
    )

    mask_idx = np.where(mask)[0]
    all_genome_clusters = fp.root.test["pst-large"].data[mask_idx, :]
    clustering_metadata = clustering_metadata[mask]

results: list[pl.DataFrame] = []
for genome_clusters, (genome_k, genome_res) in zip(all_genome_clusters, clustering_metadata):
    print(f"{genome_k=} | {genome_res=}")

    annot_improvement_results = annotation_improvment(
        ptn_info, ptn_embedding, genome_clusters, genome_ptr, database="VOG", is_normalized=True
    )

    summary = (
        summarize_annotation_improvement(annot_improvement_results)
        .with_columns(
            genome_k = genome_k,
            genome_resolution = genome_res,
        )
    )

    results.append(summary)

sensitivity_df = pl.concat(results)
sensitivity_df

genome_k=2 | genome_res=0.1
genome_k=5 | genome_res=0.1
genome_k=10 | genome_res=0.1
genome_k=15 | genome_res=0.1
genome_k=25 | genome_res=0.1
genome_k=50 | genome_res=0.1


unlabeled,get_labeled,num_proteins,labeled_before,get_labeled_prop,labeled_after,before_annot_prop,after_annot_prop,genome_k,genome_resolution
i64,i64,i64,i64,f64,i64,f64,f64,i32,f64
5245950,1316859,7161047,1915097,0.251024,3231956,0.267433,0.451325,2,0.1
5245215,1330331,7160619,1915404,0.253628,3245735,0.267491,0.453276,5,0.1
5256830,1338694,7176514,1919684,0.254658,3258378,0.267495,0.454034,10,0.1
5257627,1338980,7177634,1920007,0.254674,3258987,0.267499,0.454048,15,0.1
5257627,1338991,7177634,1920007,0.254676,3258998,0.267499,0.454049,25,0.1
5257627,1338996,7177634,1920007,0.254677,3259003,0.267499,0.45405,50,0.1


In [13]:
def compute_sensitivity(sensitivity_df: pl.DataFrame) -> float:
    sensitivity_df = sensitivity_df.sort("genome_k")

    x = sensitivity_df["genome_k"].to_numpy()
    y = sensitivity_df["get_labeled_prop"].to_numpy()
    slope = np.polyfit(x, y, 1)[0]

    return slope

The result will, of course, only make sense in the context of other sensitivity values....

In [14]:
compute_sensitivity(sensitivity_df)

4.5487924091762055e-05