If you have the Kernel Metric Network, you can use it to create Reaction-Specific Fingerprints (RSFP).

These fingerprints can then be divided into Voronoi regions using FAISS.

Lastly, record the Voronoi/Expert indices in the Pistachio DataFrame so that individual experts can be trained.


In [None]:
import torch
import torch.nn as nn
from scipy.special import softmax

    
class KernelMetricNetwork(nn.Module):
    """Fully connected classifier whose second hidden layer doubles as an embedding.

    Layout: ``input_dim -> 256 -> 128 -> num_classes``. Each hidden stage is
    Linear -> ReLU -> BatchNorm1d; ``forward`` additionally applies dropout
    (p=0.3) after each hidden stage, ``get_embeddings`` does not.

    The attribute names (``fc1``, ``fc2``, ``fc3``, ``batch_norm1``,
    ``batch_norm2``) are the keys of the module's ``state_dict`` and must stay
    as they are for saved checkpoints to load.
    """

    def __init__(self, input_dim, num_classes):
        """Build the layers for inputs of width ``input_dim`` and ``num_classes`` outputs."""
        super().__init__()
        print('Using', num_classes, 'classes predictions')
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.batch_norm1 = nn.BatchNorm1d(256)
        self.batch_norm2 = nn.BatchNorm1d(128)

    def _hidden(self, x, with_dropout):
        """Run both hidden stages, applying dropout after each one only if requested."""
        stages = ((self.fc1, self.batch_norm1), (self.fc2, self.batch_norm2))
        for linear, norm in stages:
            x = norm(self.relu(linear(x)))
            if with_dropout:
                x = self.dropout(x)
        return x

    def forward(self, x):
        """Return the ``num_classes`` output scores for ``x`` (no softmax is applied here)."""
        return self.fc3(self._hidden(x, with_dropout=True))

    def get_embeddings(self, x):
        """Return the 128-wide output of the second hidden stage, computed without dropout."""
        return self._hidden(x, with_dropout=False)

def load_model(model, filename, map_location="cpu"):
    """Load the parameters saved in ``filename`` into ``model`` and return ``model``.

    Args:
        model: Module whose ``load_state_dict`` receives the saved parameters.
        filename: Path of a checkpoint containing a ``state_dict``.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so that
            a checkpoint saved from a GPU can also be read on a host without
            that GPU; the caller moves the model to its target device afterwards.

    Returns:
        The same ``model`` object, with the loaded parameters.
    """
    state_dict = torch.load(filename, map_location=map_location)
    model.load_state_dict(state_dict)
    return model

# Build the network (input_dim=2048*3, num_classes=2285) and load the trained
# weights from the checkpoint file.
model = load_model(KernelMetricNetwork(2048*3, 2285), "best_model_50ep_4096batchsize_AdamW.pth")

model.eval()  # Set to evaluation mode


# Prefer the first CUDA GPU, but fall back to the CPU so that this cell does
# not fail at ``model.to`` on a host without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

model.to(device);

In [None]:
import pandas as pd
import numpy as np
from rdkit import Chem
from tqdm import tqdm
from rdkit.Chem import AllChem,DataStructs
import matplotlib.pyplot as plt
import pickle
import faiss
from typing import List,Tuple
from rdkit import RDLogger     
from faiss import write_index, read_index
RDLogger.DisableLog('rdApp.*')   
import pickle
import pandas as pd



In [None]:
# Both of these fingerprint arrays can be generated efficiently with multiprocessing
# by the Get_Molecukar_FP_MultiProcessing_LargeRAM.ipynb notebook.
# NOTE: unpickling can execute arbitrary code; only load files you trust.

# Array of features used to train the FAISS clustering of Voronoi regions.
with open('/global/cfs/cdirs/m410/haoteli/LLaMA/Mixture_Expert_Prediction_Preparation/MixFP_Reactant_Features_p4_r2_update_1024_dim.pkl', 'rb') as fh:
    mix_fp_features = pickle.load(fh)

# Array of features that is used to train the Llama3.1 model.
with open('/global/cfs/cdirs/m410/haoteli/LLaMA/Mixture_Expert_Prediction_Preparation/MixFP_Train_Xprt_Reactant_Features_p4_r2_update_1024_dim.pkl', 'rb') as fh:
    mix_fp_features_to_add = pickle.load(fh)

In [None]:
from tqdm.notebook import tqdm

# Process all (concatenated/mixed) fingerprints into RSFP.
# The following is executed on a single A100 on a shared node with 128GB RAM.

def process_features_in_batches(model, mix_fp_features, batch_size=512, device=device):
    """Embed ``mix_fp_features`` with ``model.get_embeddings``, one batch at a time.

    The model is switched to evaluation mode and gradients are disabled, so
    only one batch of inputs lives on ``device`` at any moment.

    Args:
        model: Object exposing ``eval()`` and ``get_embeddings(tensor)``.
        mix_fp_features: Array (or array-like, converted with ``np.array``)
            whose first axis indexes the samples.
        batch_size: Number of samples embedded per call; must be positive.
        device: Torch device each batch is moved to before it is embedded.

    Returns:
        A numpy array with the embeddings of all samples in input order, or
        ``np.array([])`` when there are no samples.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Convert the input list/array to a numpy array if it isn't already
    if not isinstance(mix_fp_features, np.ndarray):
        mix_fp_features = np.array(mix_fp_features)

    n_samples = len(mix_fp_features)

    # One array per batch; collecting whole batches (instead of individual
    # rows) avoids creating a Python object for every sample.
    embedded_batches = []

    model.eval()
    with torch.no_grad():
        # Stepping by batch_size gives the same number of iterations as
        # ceil(n_samples / batch_size); slicing clips the final batch.
        for start_idx in tqdm(range(0, n_samples, batch_size)):
            batch = mix_fp_features[start_idx:start_idx + batch_size]
            batch_tensor = torch.from_numpy(batch).float().to(device)

            batch_features = model.get_embeddings(batch_tensor)
            embedded_batches.append(batch_features.cpu().numpy())

    if not embedded_batches:
        # np.concatenate rejects an empty list; keep returning an empty array.
        return np.array([])
    return np.concatenate(embedded_batches, axis=0)

# Embed both fingerprint arrays with the same batch size.
batch_size = 1024  # Adjust based on your GPU memory and model size

# The training features for FAISS
features = process_features_in_batches(model, mix_fp_features, batch_size=batch_size)

# The training features for Llama3.1
xprt_features = process_features_in_batches(model, mix_fp_features_to_add, batch_size=batch_size)

In [None]:
# Build the FAISS IVF index whose cells serve as the Voronoi regions.
# Size of the last axis of ``features``, i.e. the width of one embedding.
d = features.shape[-1]

nlist = 2500  # how many cells
print('nlist =',nlist)
# Flat L2 index handed to IndexIVFFlat as its quantizer.
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFFlat(quantizer, d, nlist)

# ``res`` holds the GPU resources handed to index_cpu_to_gpu below.
res = faiss.StandardGpuResources()
print('Converting to GPU index')
index = faiss.index_cpu_to_gpu(res, 0, index)  # comment this line out if you do not have (or do not want to use) a GPU


# The index is trained on ``features`` but only ``xprt_features`` are added to it.
print('Training FAISS')
index.train(features)
print('Finished Training')
assert index.is_trained  # This has to be True, otherwise something is wrong
print('Adding features')
# Disabled alternative: index.add(features)
index.add(xprt_features)
print('Finished adding features')

In [None]:
# Convert the GPU index back to a CPU index so that it can be saved later.
cpu_index = faiss.index_gpu_to_cpu(index)


In [None]:

# Load the Pistachio database and keep only the usable reactions. The Voronoi
# indices are added to this DataFrame later, for training.
# NOTE: unpickling can execute arbitrary code; only load files you trust.
with open('/global/cfs/cdirs/m410/haoteli/EnLatent_Challenge/DataReaction_Match/FPCompatible_Cleaned_Pistachio.pkl', 'rb') as fh:
    df = pickle.load(fh)

# Drop reactions without a description.
df = df[~df['paragraphText'].isna()]

# Drop reactions whose agent, solvent or yield field equals the string '[]'.
for column in ('agent', 'agent_name', 'solvent', 'solvent_name', 'yield'):
    df = df[~(df[column] == '[]')]

# Drop reactions whose description contains 'Example'. This runs after the NaN
# filter above because ``in`` cannot be applied to a missing description.
# The index is renumbered once, after all filters, which gives the same
# result as renumbering after every single filter.
df = df[['Example' not in p for p in df['paragraphText']]].reset_index(drop=True)




In [None]:
# Calculating the Voronoi index assignments from FAISS
print('Calculating Dataset Quadrants')
# Search the index's quantizer with k=1, i.e. ask for a single result per
# feature vector; the reshape makes the input 2-D with rows of width d.
# NOTE(review): presumably the returned id is the nearest cell centroid - confirm.
distance, cell_ids = index.quantizer.search(xprt_features.reshape((-1,d)), k=1)
cell_ids = cell_ids.flatten()  # collapse to 1-D so it can be stored as a column
print('Adding Expert IDs to Pistachio')
# NOTE(review): assumes the rows of df and of xprt_features correspond one to
# one and in the same order - confirm against how the features were generated.
df['quadrant'] = cell_ids
print('Finished Calculations')

In [None]:
# Save the dataset with its expert IDs for LLM training. The ``with`` block
# closes the file, so the pickle is completely written before it is read back.
with open('RSFP_Train_Expert_df.pkl', 'wb') as fh:
    pickle.dump(df, fh)

In [None]:
# Save the CPU index to disk so that it can be loaded with read_index next time.
write_index(cpu_index, "RSFP_Index.index")

In [None]:
# Use the code below to test that the saved index and DataFrame can be read back.
index = read_index("RSFP_Index.index")

# Move the reloaded index to GPU 0.
res = faiss.StandardGpuResources()
index = faiss.index_cpu_to_gpu(res, 0, index)

with open('RSFP_Train_Expert_df.pkl', 'rb') as fh:
    df = pickle.load(fh)
