In [1]:
import torch
import sys
sys.path.append('MOSAIC/PredictionUtils') # This is where the library is located

from Transformation_Model import KernelMetricNetwork
from ChemUtils import create_rxn_Mix_FP

def load_model(model, filename, map_location='cpu'):
    """Load a saved ``state_dict`` from ``filename`` into ``model``.

    Args:
        model: A ``torch.nn.Module`` whose parameters match the checkpoint.
        filename: Path to a checkpoint written with
            ``torch.save(model.state_dict(), ...)``.
        map_location: Where to place the deserialized tensors. Defaults to
            ``'cpu'`` so that a checkpoint saved on a GPU can still be read
            on a machine without CUDA; the caller moves the model to its
            target device afterwards.

    Returns:
        The same ``model`` instance, with the loaded weights.
    """
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs, and avoids the torch.load warning about
    # loading arbitrary pickled objects.
    state_dict = torch.load(filename, map_location=map_location, weights_only=True)
    model.load_state_dict(state_dict)
    return model

import pickle  # imported here because this cell runs before the later cell that imports pickle

KMN_cpth = 'best_model_50ep_4096batchsize_AdamW.pth'

# 2285 is the number of prediction classes (the constructor reports it when
# run). NOTE(review): 2048*3 is presumably the input width, i.e. three
# concatenated 2048-wide fingerprint blocks -- confirm against KernelMetricNetwork.
model = load_model(KernelMetricNetwork(2048*3, 2285), KMN_cpth)
model.eval();  # Set to evaluation mode
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model.to(device);

# This is the set of all features already processed by KMN.
# pickle.load needs an open binary file object, not a path string.
# NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
with open('Features.pkl', 'rb') as features_file:
    features = pickle.load(features_file)

Using 2285 classes predictions


  model.load_state_dict(torch.load(filename))


In [3]:
# Shrink the working set for this test: rebind `features` to its first 1000
# rows. These are KMN-extracted features that were computed earlier, while
# building the FAISS index for MOSAIC.
demo_size = 1000
features = features[:demo_size]
print(features.shape)

(1000, 128)


In [4]:
import numpy as np
import faiss
import pickle
from faiss import read_index

# ============= LOAD ORIGINAL INDEX AND DATA =============
print('Loading original index...')
index = read_index("RSFP_Index.index")  # CPU copy of the original (Level 1) index
res = faiss.StandardGpuResources()
# Clone the index that was just read onto GPU 0. The source must be `index`:
# `index_level1` does not exist yet on a fresh kernel, so passing it here
# raises NameError under Restart & Run All.
index_level1 = faiss.index_cpu_to_gpu(res, 0, index)

# Extract feature dimension from index
d = index_level1.d
print(f'Feature dimension: {d}')

# ============= RANDOMLY SAMPLE NEW FEATURES (Level 2, the new data that needs to be incrementally added) =============
np.random.seed(2)  # fixed seed so the same 500 rows are drawn on every run
sample_indices = np.random.choice(len(features), size=500, replace=False)

# FAISS operates on float32 vectors, so make the dtype explicit.
new_features = np.array([features[i] for i in sample_indices], dtype=np.float32)

print(f'Sampled {len(new_features)} new features')

# ============= BUILD LEVEL 2 INDEX =============
nlist_new = 5  # Creating 5 clusters for demonstration
print(f'Building Level 2 index with {nlist_new} clusters...')

# IVF index over the new data: a flat L2 quantizer holds the nlist_new
# centroids that the hierarchical search queries later.
quantizer_level2 = faiss.IndexFlatL2(d)
index_level2 = faiss.IndexIVFFlat(quantizer_level2, d, nlist_new)
res2 = faiss.StandardGpuResources()
index_level2 = faiss.index_cpu_to_gpu(res2, 0, index_level2)

print('Training Level 2 index...')
index_level2.train(new_features)
assert index_level2.is_trained

print('Adding features to Level 2...')
index_level2.add(new_features)

# ============= HIERARCHICAL SEARCH FUNCTION =============
def hierarchical_search(query_feat, k=3):
    """Query both levels' quantizers and return the k closest hits overall.

    Reads the module-level ``index_level1`` and ``index_level2``.

    Args:
        query_feat: Query vector; it is reshaped to a single row before
            searching.
        k: Number of nearest entries requested from each level's quantizer
            (default 3, i.e. the 3 nearest Voronoi regions) and also the
            number of merged results returned.

    Returns:
        ``(distances, ids)`` for the k smallest distances across both levels,
        in ascending distance order. Level 2 ids are shifted up by
        ``index_level1.ntotal`` so they do not overlap with Level 1 ids;
        callers subtract that offset to recover the Level 2 id.
    """
    query_row = query_feat.reshape(1, -1)

    # One quantizer lookup per level.
    level1_dist, level1_ids = index_level1.quantizer.search(query_row, k)
    level2_dist, level2_ids = index_level2.quantizer.search(query_row, k)

    # Shift Level 2 ids past the Level 1 range before merging.
    level2_ids_shifted = level2_ids[0] + index_level1.ntotal

    merged_dist = np.concatenate([level1_dist[0], level2_dist[0]])
    merged_ids = np.concatenate([level1_ids[0], level2_ids_shifted])

    # Rank the combined candidates by distance and keep the best k.
    ranking = np.argsort(merged_dist)
    return merged_dist[ranking][:k], merged_ids[ranking][:k]

Loading original index...
Feature dimension: 128
Sampled 500 new features
Building Level 2 index with 5 clusters...
Training Level 2 index...
Adding features to Level 2...


In [5]:
fp_size = 1024
clean_reactions = [
    'BrC1=CC=CC=C1.C2COCCN2>>C3(N4CCOCC4)=CC=CC=C3', # Classic Buchwald-Hartwig
    'BrC1=CC=CC=C1.CC2(C)C(C)(C)OB(B3OC(C)(C)C(C)(C)O3)O2>>CC(O4)(C)C(C)(C)OB4C5=CC=CC=C5', # Suzuki
]

# Embed each reaction SMILES with the KMN model; gradients are not needed.
test_features = []
with torch.no_grad():
    for rxn in clean_reactions:
        rxn_fp = create_rxn_Mix_FP(rxn, rxnfpsize=fp_size, pfpsize=fp_size, useChirality=True)
        # Reorder the three returned parts to the layout used here.
        rxn_fp = np.concatenate((rxn_fp[1], rxn_fp[2], rxn_fp[0]), axis=-1)  # reactant, diff, product
        model_input = torch.from_numpy(rxn_fp).view(1, -1).float().to(device)
        feat = model.get_embeddings(model_input)
        test_features.append(feat.cpu().numpy())

# One row per reaction. An explicit reshape (rather than squeeze) keeps the
# result 2-D even when only one reaction is listed, so test_feats[0] is always
# a full embedding vector.
test_feats = np.array(test_features).reshape(len(test_features), -1)

# Use the first reaction as the query.
query_rxn = clean_reactions[0]
query_feat = test_feats[0]

# Probe each level's quantizer on its own, so the per-level hits can be shown
# next to the merged ranking.
k = 3
query_row = query_feat.reshape(1, -1)
D1, I1 = index_level1.quantizer.search(query_row, k)
D2, I2 = index_level2.quantizer.search(query_row, k)

banner = '=' * 60

print(f'\nQuery: {query_rxn}')
print('\n' + banner)
print('LEVEL 1 RESULTS (Original Index)')
print(banner)
for rank, (dist, idx) in enumerate(zip(D1[0], I1[0]), start=1):
    print(f'  {rank}. Expert {idx:4d} | Distance: {dist:.2f}')

print('\n' + banner)
print('LEVEL 2 RESULTS (New Index)')
print(banner)
for rank, (dist, idx) in enumerate(zip(D2[0], I2[0]), start=1):
    print(f'  {rank}. Expert {idx:4d} | Distance: {dist:.2f}')

# Merged search: both levels concatenated, ranked, and treated as one framework.
distances, indices = hierarchical_search(query_feat, k=k)

print('\n' + banner)
print('HIERARCHICAL RESULTS (Combined & Ranked)')
print(banner)
level1_size = index_level1.ntotal
for rank, (dist, idx) in enumerate(zip(distances, indices), start=1):
    # Ids at or above level1_size are Level 2 hits carrying the offset that
    # hierarchical_search added; subtract it to recover the Level 2 id.
    if idx < level1_size:
        level, actual_idx = "Level 1", idx
    else:
        level, actual_idx = "Level 2", idx - level1_size
    print(f'  {rank}. Distance: {dist:.2f} | {level} Expert {actual_idx}')


Query: BrC1=CC=CC=C1.C2COCCN2>>C3(N4CCOCC4)=CC=CC=C3

LEVEL 1 RESULTS (Original Index)
  1. Expert   59 | Distance: 29.32
  2. Expert 1021 | Distance: 43.02
  3. Expert 1774 | Distance: 48.19

LEVEL 2 RESULTS (New Index)
  1. Expert    0 | Distance: 308.35
  2. Expert    2 | Distance: 397.04
  3. Expert    3 | Distance: 459.27

HIERARCHICAL RESULTS (Combined & Ranked)
  1. Distance: 29.32 | Level 1 Expert 59
  2. Distance: 43.02 | Level 1 Expert 1021
  3. Distance: 48.19 | Level 1 Expert 1774
