# Searching MGnify Protein Database Embeddings Using Faiss

## Background

So far, the encapsulin hits in this repo have come from two different search methods - either structure search against the ESM Atlas, or functional annotation search against Pfam labels in the MGnify Protein Database.

However, we know that protein language models generate meaningful vector representations of protein sequences - so why can't we use vector-based methods to search encapsulin query vectors against a target database of vectors generated for all MGnify proteins?

Meta AI have made the embeddings for all MGnify proteins available publicly for download, so we don't have to generate 6B+ vectors ourselves. We can just use ESM-2 to generate embedding vectors for a handful of encapsulin sequences and then search the database using [faiss](https://github.com/facebookresearch/faiss), a library from Facebook AI Research for efficient similarity search over dense vectors, with optional GPU acceleration.

This notebook will experiment with loading, analyzing, and searching the database, and generating the query vectors.

## Database Vectors

The MGnify Protein Database embeddings can be downloaded using URLs available from the [ESM repo](https://github.com/facebookresearch/esm). I've already downloaded them here:

In [1]:
!ls ../../sequence_DBs/atlas/embeddings | head -10

tm_.40_.50_plddt_.40_.50_00.npz
tm_.40_.50_plddt_.40_.50_01.npz
tm_.40_.50_plddt_.50_.60_00.npz
tm_.40_.50_plddt_.50_.60_01.npz
tm_.40_.50_plddt_.60_.70_00.npz
tm_.40_.50_plddt_.60_.70_01.npz
tm_.40_.50_plddt_.70_.80_00.npz
tm_.40_.50_plddt_0_.40_00.npz
tm_.40_.50_plddt_0_.40_01.npz
tm_.50_.60_plddt_.40_.50_00.npz


Let's load one of these and see what's in it:

In [2]:
import numpy as np

db = np.load("../../sequence_DBs/atlas/embeddings/tm_.40_.50_plddt_.40_.50_00.npz")
db

<numpy.lib.npyio.NpzFile at 0x7f278c4bae30>

This object is a NumPy NPZ file: a zip archive of named arrays (optionally compressed) that is read lazily, one array at a time. Let's check out what's inside:

In [3]:
for item in db.items():
    print(item)
    break

('/scratch/slurm_tmpdir/379122/tm_.40_.50_plddt_.40_.50_00/468/MGYP000160952468', array([ 0.05182 , -0.05518 , -0.009895, ..., -0.01677 , -0.0749  ,
       -0.012024], dtype=float16))


We can iterate through the NPZ file using the `.items()` method, just like a standard Python dictionary. It looks like each key is a string containing a file path that ends in the MGYP protein ID from MGnify, and each value is a NumPy array. Let's check out that array:

In [4]:
for key, value in db.items():
    print(f"ID: {key.split('/')[-1]}")
    print(f"Shape: {value.shape}")
    break

ID: MGYP000160952468
Shape: (2560,)
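
Since every key carries a scratch path in front of the protein ID, it may be handy to normalise keys in one place rather than splitting them inline. The helper below is a minimal sketch: `mgyp_id_from_key` is a hypothetical name, and the "`MGYP` followed by digits" pattern is an assumption inferred from the single key shown above.

```python
import re

# Assumed ID shape: "MGYP" followed by digits, inferred from the example key above.
_MGYP_PATTERN = re.compile(r"MGYP\d+")


def mgyp_id_from_key(key: str) -> str:
    """Return the MGYP protein ID stored in the last path component of an NPZ key."""
    candidate = key.rsplit("/", 1)[-1]
    if _MGYP_PATTERN.fullmatch(candidate) is None:
        # Fail loudly instead of silently indexing a malformed ID.
        raise ValueError(f"Unexpected key format: {key!r}")
    return candidate


example_key = "/scratch/slurm_tmpdir/379122/tm_.40_.50_plddt_.40_.50_00/468/MGYP000160952468"
print(mgyp_id_from_key(example_key))
```

Raising on an unexpected key is a design choice here: with tens of millions of entries, a silent mismatch between IDs and vectors would be hard to spot later.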


I'm curious about something - how many vectors do we have in total? Let's iterate through all the files and find out:

In [5]:
import os

embeddings_dir = "../../sequence_DBs/atlas/embeddings/"
length = 0  # running total of vectors across all NPZ files

for i, filename in enumerate(os.listdir(embeddings_dir)):
    print(f"Checking file {i+1}")
    db = np.load(os.path.join(embeddings_dir, filename))
    length += len(db.files)  # one entry per protein; no arrays are read

print(length)

Checking file 1
Checking file 2
...
Checking file 59
[remaining output truncated]

In [7]:
print(f"Memory usage in gigabytes: {(length * 2560 * 2) / 1000000000}")

Memory usage in gigabytes: 251.33513216


49,088,893 vectors: at 2560 dimensions and float16 dtype, that's about 251 GB of vectors to load into memory! Clearly we can't just brute-force everything.

## Testing `faiss` Indexes

The way `faiss` works is that we create an `Index` and add all of our vectors to it. We can then do *k*-nearest neighbours search on this index with a set of query vectors.

There are lots of different index types, and the documentation is rather cryptic about them, diving into technical jargon around clustering, transforms, search, and other details. However, there is a somewhat helpful [set of guidelines](https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index) for choosing which index to use.

The way I see it is this: the most basic, brute-force index, `IndexFlatL2`, is the most accurate, since it stores all vectors uncompressed and does an exhaustive search against every one of them. As such, it gives the "best" results but will take too long and use too much memory for our purposes.
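
To make "exhaustive search" concrete, the sketch below computes exact squared L2 distances between every query and every stored vector in plain NumPy. It is a stand-in for the idea, not faiss's actual implementation, and `exact_l2_search` is a hypothetical helper name.

```python
import numpy as np


def exact_l2_search(database: np.ndarray, queries: np.ndarray, k: int):
    """Return (squared distances, indices) of the k nearest database rows per query."""
    # ||q - x||^2 = ||q||^2 - 2 q.x + ||x||^2, evaluated for every (query, row) pair.
    q_norms = (queries ** 2).sum(axis=1, keepdims=True)  # (n_queries, 1)
    x_norms = (database ** 2).sum(axis=1)                # (n_database,)
    sq_dists = q_norms - 2.0 * queries @ database.T + x_norms
    order = np.argsort(sq_dists, axis=1)[:, :k]          # full sort of every row
    return np.take_along_axis(sq_dists, order, axis=1), order


rng = np.random.default_rng(0)
database = rng.standard_normal((500, 32)).astype(np.float32)
queries = database[[3, 42]]  # two stored rows, so each has an exact match

sq_dists, indices = exact_l2_search(database, queries, k=3)
print(indices[:, 0])  # nearest row for each query
```

Both the work per query and the memory grow linearly with the number of stored vectors, which is why this approach stops being practical at tens of millions of 2560-dimensional embeddings.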

There are lots of different ways of saving time and memory, but I've decided to test the approach [outlined here](https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index#if-quite-important-then-opqm_dpqmx4fsr) in the docs. We'd eventually like to run these searches on a GPU, so our memory limit is around 24 GB (and I'd like some headroom below this figure).
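
It can help to turn that memory limit into a per-vector budget. The back-of-the-envelope sketch below uses the vector count and dimensionality from earlier in this notebook; the 4 GB of headroom is a hypothetical figure chosen for illustration, and the calculation ignores index overhead such as stored IDs and codebooks.

```python
N_VECTORS = 49_088_893  # total counted above
DIM = 2560
GPU_BUDGET_GB = 24      # limit quoted in the text
HEADROOM_GB = 4         # hypothetical safety margin, not a figure from the text

raw_bytes_per_vector = DIM * 2  # float16 = 2 bytes per dimension
usable_bytes = (GPU_BUDGET_GB - HEADROOM_GB) * 10**9
budget_bytes_per_vector = usable_bytes // N_VECTORS
compression_needed = raw_bytes_per_vector / budget_bytes_per_vector

print(f"Raw size per vector: {raw_bytes_per_vector} bytes")
print(f"Budget per vector:   {budget_bytes_per_vector} bytes")
print(f"Compression needed:  ~{compression_needed:.1f}x")
```

Under these assumptions each vector has to shrink from 5120 bytes to roughly 400 bytes, which is the kind of reduction that compressed index types are designed for.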

Let's test this index with a subset of our data and see what happens. First, we need to load our embedding vectors into memory:

In [14]:
embeddings = np.array([db[file] for file in db.files])