In [1]:
from pathlib import Path
from typing import NamedTuple

import igraph as ig
import numpy as np
import scipy.sparse as sp
import tables as tb
from faiss import (
    METRIC_INNER_PRODUCT,
    IndexFlatIP,
    IndexFlatL2,
    IndexIVFFlat,
    omp_set_num_threads,
)
from numpy.typing import NDArray
from tqdm import tqdm

In [2]:
%load_ext watermark
%watermark -vp igraph,numpy,scipy,tables,faiss,tqdm

Python implementation: CPython
Python version       : 3.11.3
IPython version      : 8.13.2

igraph: 0.11.3
numpy : 1.23.5
scipy : 1.11.4
tables: 3.8.0
faiss : 1.8.0
tqdm  : 4.65.0



In [3]:
FloatArray = NDArray[np.float32]
INTTYPE = np.uint32
LongArray = NDArray[INTTYPE]
THREADS = 128

omp_set_num_threads(THREADS)

class KnnGraphBase(NamedTuple):
    sim: FloatArray
    idx: LongArray

In [4]:
def build_index(
    data: FloatArray, normalized: bool = True
) -> IndexIVFFlat | IndexFlatIP | IndexFlatL2:
    dim = data.shape[-1]

    n_cells = max(data.shape[0] // 39, 1)

    if normalized:
        if n_cells == 1:
            index = IndexFlatIP(dim)
        else:
            quantizer = IndexFlatIP(dim)
            index = IndexIVFFlat(quantizer, dim, n_cells, METRIC_INNER_PRODUCT)
    else:
        if n_cells == 1:
            index = IndexFlatL2(dim)
        else:
            quantizer = IndexFlatL2(dim)
            index = IndexIVFFlat(quantizer, dim, n_cells)

    index.train(data)  # type: ignore
    index.add(data)  # type: ignore
    return index

In [5]:
def to_csr(edge_index: LongArray, edge_attr: FloatArray) -> sp.csr_array:
    n_nodes = edge_index.max() + 1
    spadj = sp.coo_matrix(
        (edge_attr, (edge_index[0], edge_index[1])), shape=(n_nodes, n_nodes)
    ).tocsr()
    return spadj


def to_sparse(idx: LongArray, sim: FloatArray) -> sp.csr_array:
    k = idx.shape[-1]
    edge_index = np.vstack(
        (
            np.repeat(np.arange(idx.shape[0], dtype=INTTYPE), k),
            idx.flatten(),
        )
    )

    edge_attr = sim.flatten()
    mask = edge_attr > 0.0

    edge_index = edge_index[:, mask]
    edge_attr = edge_attr[mask]

    rev_edge_index = np.vstack((edge_index[1], edge_index[0]), dtype=INTTYPE)
    edge_index = np.hstack((edge_index, rev_edge_index), dtype=INTTYPE)
    edge_attr = np.concatenate((edge_attr, edge_attr))

    edge_index, sort_idx = np.unique(edge_index, return_index=True, axis=1)
    edge_attr = edge_attr[sort_idx]

    return to_csr(edge_index, edge_attr)


def to_graph(idx: LongArray, sim: FloatArray) -> ig.Graph:
    adj = to_sparse(idx, sim)

    graph = ig.Graph.Weighted_Adjacency(adj, mode="undirected", attr="weight")

    return graph


def knn(
    data: FloatArray,
    index: IndexIVFFlat | IndexFlatIP | IndexFlatL2,
    k: int,
) -> KnnGraphBase:
    keff = k + 1
    dist: FloatArray
    idx: LongArray
    dist, idx = index.search(data, k=keff)  # type: ignore

    idx = idx.astype(INTTYPE)

    if index.metric_type == METRIC_INNER_PRODUCT:
        return KnnGraphBase(sim=dist, idx=idx)

    scale = 1 / np.sqrt(data.shape[-1])
    rbf = np.exp(-(np.square(dist.astype(np.float64))) * scale).astype(np.float32)
    return KnnGraphBase(sim=rbf, idx=idx)


def leiden_clustering(knn_graph: ig.Graph, resolution: float = 1.0) -> LongArray:
    # result df has two columns: vertex and partition, ie genome id and cluster id
    result = knn_graph.community_leiden(
        objective_function="cpm", weights="weight", resolution=resolution
    )

    return np.array(result.membership, dtype=INTTYPE)


def read_data(file: str | Path, loc: str, normalize: bool = True) -> FloatArray:
    with tb.open_file(file) as fp:
        data = fp.root[loc][:]

    if normalize:
        norms = np.linalg.norm(data, axis=1, keepdims=True)
        data /= norms
    return data

Proteins are only clustered if their source genomes cluster together.

In [6]:
def cluster_proteins(
    ptn_embeddings: FloatArray,
    protein_k: int,
    protein_resolution: float,
    genome_ptr: LongArray,
    genome_clusters: LongArray,
    is_normalized: bool = True,
) -> tuple[LongArray, LongArray]:
    uniq_genome_clusters, genome_cluster_sizes = np.unique(genome_clusters, return_counts=True)

    n_ptns = ptn_embeddings.shape[0]
    ptn_cluster_labels = np.arange(n_ptns, dtype=INTTYPE)

    genome_cluster_labels = np.arange(n_ptns, dtype=INTTYPE)

    pbar = tqdm(total=uniq_genome_clusters.size)
    for genome_cluster, genome_cluster_size in zip(uniq_genome_clusters, genome_cluster_sizes):
        # only cluster proteins within the same genome cluster
        genome_idx = np.where(genome_clusters == genome_cluster)[0]

        global_starts = genome_ptr[genome_idx]
        global_ends = genome_ptr[genome_idx + 1]
        genome_sizes = global_ends - global_starts

        local_ptr = np.concatenate((
            [0],
            np.cumsum(genome_sizes)
        ))

        local_starts = local_ptr[:-1]
        local_ends = local_ptr[1:]

        for start, end in zip(global_starts, global_ends):
            genome_cluster_labels[start:end] = genome_cluster

        if genome_cluster_size < 2:
            continue

        local_ptn_embeddings = np.concatenate([
            ptn_embeddings[start:end] for start, end in zip(global_starts, global_ends)
        ])

        index = build_index(local_ptn_embeddings, normalized=is_normalized)

        knn_base = knn(local_ptn_embeddings, index, protein_k)
        graph = to_graph(knn_base.idx, knn_base.sim)
        labels = leiden_clustering(graph, protein_resolution)

        for start, end, local_start, local_end in zip(global_starts, global_ends, local_starts, local_ends):
            ptn_cluster_labels[start:end] = labels[local_start:local_end]

        pbar.update(1)

    return ptn_cluster_labels, genome_cluster_labels

In [7]:
with tb.open_file("../genome_embeddings/evaluations/genome_clusters.h5") as fp:
    genome_clusters = fp.root.test["pst-large"].data[14, :]

genome_clusters

array([   0,    1,    2, ..., 2643, 3327, 1490])

In [8]:
ptn_embeddings = read_data(
    "IMGVRv4_test_set_esm-small.h5", "data", normalize=True
)

ptn_embeddings.shape

(7182220, 320)

In [9]:
with tb.open_file("evaluations/IMGVRv4_test_set_esm-large_embeddings.h5") as fp:
    genome_ptr = fp.root.ptr[:]

# shape: (num genomes + 1,)
genome_ptr.shape

(151256,)

In [10]:
pclu, gclu = cluster_proteins(
    ptn_embeddings,
    15,
    0.5,
    genome_ptr,
    genome_clusters,
    is_normalized=True,
)

pclu, gclu

 98%|█████████▊| 3812/3891 [18:56<00:23,  3.35it/s]  


(array([   0,    1,    2, ..., 1239, 2821, 2799], dtype=uint32),
 array([   0,    0,    0, ..., 1490, 1490, 1490], dtype=uint32))

In [11]:
pclu.shape, gclu.shape

((7182220,), (7182220,))