In [2]:
# Set the current directory to be the repo's parent directory
# This needs to be be updated for your local machine
import os
os.chdir("/Users/erjo3868/repos/chem-embed/chem-embed")

# Imports

In [3]:
from src.constants import DATA_DIR, EMBEDDINGS_DIR, INDICES_DIR
from src.create_and_query_index import add_and_query_index
from src.utils import pickle_and_compress, decompress_and_unpickle
import faiss
import glob
import numpy as np
from sklearn.model_selection import train_test_split
import time
from pathlib import Path

# Define indices to test with

In [None]:
# Constants
rand_seed = 42069
num_query_vecs = 1000
num_neighbors = 20

# Variable that will keep track of all results

# Load embedding
embeddings = glob.glob(str(EMBEDDINGS_DIR) + "/*.npy")

for embedding in embeddings:
    print(f"Embedding = {embedding}")
    embedding = Path(embedding)
    data = np.load(embedding)
    embed_dim = data.shape[1]

    # Define indices
    indices = {
        "flat": faiss.IndexFlatL2(embed_dim),
        "flat-binary": faiss.IndexBinaryFlat(embed_dim)
    }
    m_vals = [32]
    for m in m_vals:
        indices[f"hnsw-{m}"] = faiss.IndexHNSWFlat(embed_dim, m)

    # Split data into vectors that will go into the index and those that we'll use to
    # query the index
    index_vecs, query_vecs = train_test_split(
        data, test_size=num_query_vecs, random_state=rand_seed
    )
    # Add vectors to index and query it
    results = {}
    for name, index in indices.items():
        index_file_path = INDICES_DIR / f"{embedding.stem}.{name}.index"
        result_file_path = INDICES_DIR / f"{embedding.stem}.{name}.result.pkl"
        if (not index_file_path.exists()) or (not result_file_path.exists()):
            try:
                print(f"Running index '{name}'...")
                t0 = time.time()
                result, index = add_and_query_index(
                    index_vecs=index_vecs,
                    query_vecs=query_vecs,
                    index=index,
                    num_neighbors=num_neighbors,
                )
                print(f"\t Took {round(time.time()-t0, 2)} seconds")
                faiss.write_index(index, str(index_file_path))
                pickle_and_compress(
                    obj=result,
                    file_path=result_file_path
                )
            except:
                print(f"Index '{name}' failed!")
                continue


Embedding = /Users/erjo3868/repos/chem-embed/chem-embed/data/embeddings/chembl_v23_smiles_morgan_radius=1_dim=2048.npy
Running index 'flat-binary'...
Index 'flat-binary' failed!
Embedding = /Users/erjo3868/repos/chem-embed/chem-embed/data/embeddings/chembl_v23_smiles_morgan_radius=1_dim=1024.npy
Running index 'flat-binary'...
Index 'flat-binary' failed!
Running index 'hnsw-32'...
	 Took 222.68 seconds
Embedding = /Users/erjo3868/repos/chem-embed/chem-embed/data/embeddings/chembl_v23_smiles_morgan_radius=0_dim=1024.npy
Running index 'flat'...
	 Took 13.45 seconds
Running index 'flat-binary'...
Index 'flat-binary' failed!
Running index 'hnsw-32'...
	 Took 593.9 seconds
Embedding = /Users/erjo3868/repos/chem-embed/chem-embed/data/embeddings/chembl_v23_smiles_morgan_radius=0_dim=2048.npy
Running index 'flat'...
	 Took 28.79 seconds
Running index 'flat-binary'...
Index 'flat-binary' failed!
Running index 'hnsw-32'...
