In [1]:
import tables as tb
import torch
from torch_geometric.utils import segment

In [2]:
# record the Python/IPython and package versions this notebook was executed with
%load_ext watermark
%watermark -v -p torch,torch_geometric,torch_scatter,tables

Python implementation: CPython
Python version       : 3.10.14
IPython version      : 8.24.0

torch          : 2.2.2
torch_geometric: 2.5.2
torch_scatter  : 2.1.2
tables         : 3.9.2



In [3]:
# HDF5 file holding the stacked protein embeddings (`data`) and the genome pointer (`ptr`)
data_file = "datasets/protein_embeddings/test_set_esm-small_inputs.graphfmt.h5"

# to use torch_geometric.utils.segment, these need to be torch tensors and NOT numpy arrays
# NOTE: `[:]` reads each array in full, so this loads the whole dataset into memory
with tb.open_file(data_file) as fp:
    ptn_embed = torch.from_numpy(fp.root.data[:])
    genome_ptr = torch.from_numpy(fp.root.ptr[:])


# the protein embeddings are stored as a stacked batch:
# a single 2D tensor of shape (num proteins, embed_dim) covering all genomes
ptn_embed.shape

torch.Size([7182220, 320])

In [4]:
# to keep track of which proteins belong to which genome, we use the genome_ptr
# the genome_ptr is a 1D tensor of shape (num genomes + 1)
# the first element is always 0, and the last element is the total number of proteins
# the proteins of genome i are rows genome_ptr[i]:genome_ptr[i + 1] of ptn_embed
# this storage is basically equivalent to the CSR format used for sparse matrices
# see the PyTorch-Geometric documentation on mini-batching
genome_ptr.shape, genome_ptr

(torch.Size([151256]),
 tensor([      0,      24,      36,  ..., 7181373, 7181835, 7182220]))

For more information about this data layout and batching procedure, see the PyTorch-Geometric documentation on [mini-batches](https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html#mini-batches).

In [5]:
# the number of proteins encoded by each genome is the difference between
# consecutive entries of genome_ptr:
torch.diff(genome_ptr)

tensor([ 24,  12,  11,  ..., 399, 462, 385])

In [6]:
# the 4 protein embeddings for the 12346th genome (0-based index 12345) can be retrieved like this:
genome_idx = 12345
# consecutive genome_ptr entries give the half-open row range [start, end) for this genome
start = genome_ptr[genome_idx]
end = genome_ptr[genome_idx + 1]
data = ptn_embed[start:end]
data, data.shape

(tensor([[-0.0280, -0.2302,  0.2222,  ...,  0.0936,  0.0269, -0.1448],
         [-0.0183, -0.2730,  0.3387,  ...,  0.0386,  0.1384,  0.0912],
         [-0.0685, -0.3271,  0.0728,  ...,  0.2096,  0.1714, -0.0694],
         [-0.2642, -0.0469,  0.0405,  ...,  0.0922, -0.2237, -0.0227]]),
 torch.Size([4, 320]))

In [7]:
# we use pytorch-geometric and pytorch-scatter to perform reductions on this stacked batch format using the genome_ptr
# to average the protein embeddings per genome:
genome_embed = segment(src=ptn_embed, ptr=genome_ptr, reduce="mean")
# the result has one row per genome, i.e. shape (num genomes, embed_dim)
genome_embed, genome_embed.shape

(tensor([[ 0.0303, -0.3255,  0.2116,  ...,  0.0552,  0.2166,  0.0266],
         [-0.0071, -0.2595,  0.2108,  ...,  0.0519,  0.2123,  0.0655],
         [ 0.0369, -0.2487,  0.2112,  ...,  0.0966,  0.1809,  0.0249],
         ...,
         [ 0.2230, -0.0103,  0.0860,  ...,  0.1798, -0.2928, -0.0601],
         [ 0.2105, -0.0035,  0.0710,  ...,  0.1934, -0.3043, -0.0727],
         [ 0.1954,  0.0189,  0.0122,  ...,  0.2291, -0.3480, -0.1056]]),
 torch.Size([151255, 320]))