In [1]:
import tables as tb
import torch
from torch_geometric.utils import segment

In [2]:
# Record the interpreter and key package versions so the outputs below can be reproduced.
%load_ext watermark
%watermark -v -p torch,torch_geometric,torch_scatter,tables

Python implementation: CPython
Python version       : 3.10.14
IPython version      : 8.24.0

torch          : 2.2.2
torch_geometric: 2.5.2
torch_scatter  : 2.1.2
tables         : 3.9.2



In [3]:
data_file = "../../../data/raw/esm2_t6_8M/hq_viruses_cleaned.filtered.GRAPHFMT.h5"

# Both arrays are read fully into memory and wrapped as torch tensors:
# torch_geometric.utils.segment needs torch tensors, NOT numpy arrays.
with tb.open_file(data_file) as fp:
    ptn_embed, genome_ptr = (
        torch.from_numpy(node[:]) for node in (fp.root.data, fp.root.ptr)
    )

# The protein embeddings are stored as one stacked batch:
# a 2D tensor of shape (num proteins, embed_dim).
ptn_embed.shape

torch.Size([6391562, 320])

In [4]:
# genome_ptr records which rows of the stacked batch belong to which genome.
# It is a 1D tensor of shape (num genomes + 1): the first element is always 0
# and the last element is the total number of proteins, so the proteins of
# genome i are the rows genome_ptr[i]:genome_ptr[i + 1].
# This layout is analogous to the row-pointer array of the CSR format used for
# sparse matrices, and is the batching convention used by PyTorch-Geometric.
(genome_ptr.shape, genome_ptr)

(torch.Size([103590]),
 tensor([      0,       2,       4,  ..., 6387660, 6389580, 6391562]))

For more information about this data-handling and batching procedure, see the PyTorch-Geometric documentation on [mini-batches](https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html#mini-batches).

In [5]:
# The number of proteins encoded by each genome is the difference between
# consecutive entries of genome_ptr.
torch.diff(genome_ptr)

tensor([   2,    2,    2,  ..., 1912, 1920, 1982])

In [6]:
# Genome indices are 0-based, so index 12345 addresses the 12346th genome.
# Its 8 protein embeddings are the contiguous rows of ptn_embed between two
# consecutive genome_ptr entries.
genome_idx = 12345
start, end = genome_ptr[genome_idx], genome_ptr[genome_idx + 1]
data = ptn_embed[start:end]
(data, data.shape)

(tensor([[ 0.0674, -0.1024,  0.2181,  ...,  0.1273,  0.1046, -0.0729],
         [-0.0429, -0.0993,  0.1606,  ...,  0.1879,  0.0067, -0.0889],
         [ 0.0924,  0.1377,  0.0870,  ...,  0.1892, -0.3167, -0.1418],
         ...,
         [-0.0493, -0.1621,  0.1583,  ...,  0.0697,  0.0931, -0.0173],
         [-0.0351, -0.1719,  0.2076,  ...,  0.0286,  0.0916,  0.0069],
         [ 0.1466, -0.0786,  0.0592,  ...,  0.1793, -0.1845,  0.0206]]),
 torch.Size([8, 320]))

In [7]:
# Reductions over the stacked batch are done with pytorch-geometric's `segment`,
# which reduces the rows of ptn_embed within each range given by genome_ptr.
# With reduce="mean" this averages the protein embeddings of every genome,
# giving one row per genome.
genome_embed = segment(ptn_embed, genome_ptr, reduce="mean")
(genome_embed, genome_embed.shape)

(tensor([[ 0.0060, -0.2644,  0.1481,  ...,  0.0706,  0.1421,  0.0293],
         [-0.1472, -0.1224,  0.0080,  ...,  0.1619,  0.0151, -0.0125],
         [ 0.0055, -0.2752,  0.1366,  ...,  0.2125,  0.0188, -0.0243],
         ...,
         [-0.0525, -0.1741,  0.0631,  ...,  0.1884,  0.1247, -0.0510],
         [-0.0386, -0.2544,  0.1082,  ...,  0.1716,  0.1657, -0.0124],
         [-0.0468, -0.1754,  0.0623,  ...,  0.1929,  0.1073, -0.0620]]),
 torch.Size([103589, 320]))