# Introduction
This notebook documents the Kernel Metric Network (KMN), which defines the reaction-specific fingerprint (RSFP). It is intended to be a standalone implementation for ease of reproduction.

Author: Haote Li, haote.li@yale.edu

In [1]:
import copy
import math

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import LabelEncoder
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm.notebook import tqdm

class CustomDataset(Dataset):
    """In-memory dataset pairing feature rows with integer-encoded class labels.

    Features are copied into a float32 tensor. Labels are mapped to integer
    ids by a ``LabelEncoder`` fit on ``labels``. The fitted encoder is kept on
    ``self.label_encoder`` (e.g. for ``inverse_transform`` of predicted ids).

    Args:
        features: array-like indexed by sample along its first axis.
        labels: sequence of class labels, one per feature row.

    Raises:
        ValueError: if ``features`` and ``labels`` have different lengths.
    """

    def __init__(self, features, labels):
        # Fail fast on misaligned inputs. Otherwise __len__ (based on labels)
        # would silently pair rows with the wrong labels or drop rows.
        if len(features) != len(labels):
            raise ValueError(
                f"features and labels must have the same length, "
                f"got {len(features)} and {len(labels)}"
            )
        # NOTE: this is a full float32 copy; for an int8 input it needs 4x the
        # memory of the source array.
        self.features = torch.tensor(features, dtype=torch.float32)
        self.label_encoder = LabelEncoder()
        self.labels = torch.tensor(self.label_encoder.fit_transform(labels), dtype=torch.long)

    def __len__(self):
        """Number of samples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(feature_row, encoded_label)`` pair at ``idx``."""
        return self.features[idx], self.labels[idx]

     
    
class KernelMetricNetwork(nn.Module):
    """Two-hidden-layer MLP classifier used as the Kernel Metric Network (KMN).

    Each hidden block is Linear -> ReLU -> BatchNorm1d -> Dropout. A final
    Linear layer produces raw (unnormalized) class logits.

    Args:
        input_dim: number of input features per sample.
        num_classes: number of output classes (length of the logit vector).
        hidden_dim1: width of the first hidden layer (default 256).
        hidden_dim2: width of the second hidden layer (default 128).
        dropout: dropout probability after each hidden block (default 0.3).

    The submodule attribute names (fc1, fc2, fc3, batch_norm1, batch_norm2) are
    fixed, so state dicts saved from the default configuration stay loadable.
    """

    def __init__(self, input_dim, num_classes, hidden_dim1=256, hidden_dim2=128, dropout=0.3):
        super(KernelMetricNetwork, self).__init__()
        print('Using', num_classes, 'classes predictions')
        self.fc1 = nn.Linear(input_dim, hidden_dim1)
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.fc3 = nn.Linear(hidden_dim2, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim1)
        self.batch_norm2 = nn.BatchNorm1d(hidden_dim2)

    def forward(self, x):
        """Map a 2-D ``(batch, input_dim)`` tensor to ``(batch, num_classes)`` logits."""
        # ReLU is applied before BatchNorm; this order is part of the trained model.
        x = self.dropout(self.batch_norm1(self.relu(self.fc1(x))))
        x = self.dropout(self.batch_norm2(self.relu(self.fc2(x))))
        return self.fc3(x)


def create_data_loaders(features, labels, batch_size, train_ratio=0.9, seed=2):
    """Randomly split the data and wrap both parts in DataLoaders.

    Args:
        features: per-sample feature rows, passed to ``CustomDataset``.
        labels: per-sample class labels, passed to ``CustomDataset``.
        batch_size: mini-batch size for both loaders.
        train_ratio: fraction of samples for the training split (size is
            rounded down); the remainder becomes the evaluation split.
        seed: value passed to ``torch.manual_seed`` before splitting
            (default 2, matching the original hard-coded seed).

    Returns:
        ``(train_loader, eval_loader)``. The train loader shuffles; the
        evaluation loader does not.

    Raises:
        ValueError: if the split would leave either subset empty.
    """
    dataset = CustomDataset(features, labels)
    train_size = int(train_ratio * len(dataset))
    eval_size = len(dataset) - train_size
    if train_size <= 0 or eval_size <= 0:
        raise ValueError(
            f"train_ratio={train_ratio} on {len(dataset)} samples gives "
            f"train_size={train_size}, eval_size={eval_size}; both must be positive"
        )
    # This seeds the *global* torch RNG, not a private generator. The split is
    # reproducible, and so is every later draw from the global RNG
    # (e.g. model initialisation and train-loader shuffling).
    torch.manual_seed(seed)
    train_dataset, eval_dataset = random_split(dataset, [train_size, eval_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, eval_loader

class WarmupCosineSchedule(optim.lr_scheduler._LRScheduler):
    """Learning-rate schedule: linear warmup followed by cosine decay to 0.

    While ``last_epoch < warmup_epochs`` the rate is
    ``base_lr * last_epoch / warmup_epochs``. After that it is
    ``base_lr * 0.5 * (1 + cos(pi * elapsed / decay_span))``, where ``elapsed``
    counts steps since warmup ended. It reaches 0 at ``total_epochs`` and
    stays there on any later step.

    NOTE: the warmup factor at ``last_epoch == 0`` is exactly 0. When stepped
    once per epoch, parameters are therefore not updated during the first
    epoch (the logged run shows 0.0004 validation accuracy after epoch 1).
    This is kept unchanged so existing runs stay reproducible.

    Args:
        optimizer: wrapped optimizer.
        warmup_epochs: number of scheduler steps spent in linear warmup.
        total_epochs: step count at which the cosine phase reaches 0.
        last_epoch: index of the last step, as in other torch schedulers.
    """

    def __init__(self, optimizer, warmup_epochs, total_epochs, last_epoch=-1):
        # Must be set before the base __init__, which already calls get_lr().
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        super(WarmupCosineSchedule, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of each param group for ``self.last_epoch``."""
        if self.last_epoch < self.warmup_epochs:
            return [base_lr * (self.last_epoch / self.warmup_epochs) for base_lr in self.base_lrs]
        # max(1, ...) avoids ZeroDivisionError when total_epochs == warmup_epochs.
        # min(...) keeps the rate at 0 instead of rising again after total_epochs.
        # Within the normal range this is numerically identical to the
        # unclamped formula.
        decay_span = max(1, self.total_epochs - self.warmup_epochs)
        elapsed = min(self.last_epoch - self.warmup_epochs, decay_span)
        return [base_lr * 0.5 * (1 + math.cos(math.pi * elapsed / decay_span))
                for base_lr in self.base_lrs]

def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return the mean per-batch loss."""
    model.train()
    epoch_loss = 0.0
    for batch_features, batch_labels in tqdm(loader):
        batch_features, batch_labels = batch_features.to(device), batch_labels.to(device)
        optimizer.zero_grad()
        outputs = model(batch_features)
        loss = criterion(outputs, batch_labels)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
    return epoch_loss / len(loader)


def _evaluate(model, loader, criterion, device):
    """Return ``(mean per-batch loss, accuracy)`` of ``model`` on ``loader`` without gradients."""
    model.eval()
    correct = 0
    total = 0
    eval_loss = 0.0
    with torch.no_grad():
        for batch_features, batch_labels in loader:
            batch_features, batch_labels = batch_features.to(device), batch_labels.to(device)
            outputs = model(batch_features)
            eval_loss += criterion(outputs, batch_labels).item()
            predicted = outputs.argmax(dim=1)
            total += batch_labels.size(0)
            correct += (predicted == batch_labels).sum().item()
    return eval_loss / len(loader), correct / total


def train_model(model, train_loader, eval_loader, num_epochs, lr, device, warmup_epochs):
    """Train ``model`` with AdamW and a warmup/cosine schedule, tracking the best epoch.

    Each epoch runs one pass over ``train_loader``, steps the LR schedule
    once, evaluates on ``eval_loader``, and prints a one-line summary.
    ``model`` itself is left with the final-epoch weights.

    Args:
        model: classifier producing logits; it must already be on ``device``.
        train_loader / eval_loader: loaders yielding ``(features, labels)`` batches.
        num_epochs: number of training epochs (also the schedule length).
        lr: peak learning rate for AdamW.
        device: device that batches are moved to.
        warmup_epochs: number of linear-warmup epochs.

    Returns:
        ``(best_model_state, best_eval_acc)``. The first is an independent copy
        of ``model.state_dict()`` from the epoch with the highest validation
        accuracy; the second is that accuracy. ``best_model_state`` is None
        only when ``num_epochs`` is 0.

    Raises:
        ZeroDivisionError: if either loader yields no batches.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = WarmupCosineSchedule(optimizer, warmup_epochs, num_epochs)

    best_eval_acc = 0.0
    best_model_state = None

    for epoch in range(num_epochs):
        # The loss is accumulated fresh per epoch; previously the running sum
        # carried over the previous epoch's average.
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        scheduler.step()

        eval_loss, eval_acc = _evaluate(model, eval_loader, criterion, device)
        print(f'Epoch {epoch+1}/{num_epochs}, Validation Accuracy: {eval_acc:.4f}, Train_loss {train_loss:.4f} ,Eval_Loss: {eval_loss:.4f}')
        print()
        if best_model_state is None or eval_acc > best_eval_acc:
            best_eval_acc = eval_acc
            # state_dict() returns references to the live tensors. Without a copy,
            # later epochs would overwrite the "best" weights.
            best_model_state = copy.deepcopy(model.state_dict())

    return best_model_state, best_eval_acc

def save_model(model_state, filename):
    """Serialise a model ``state_dict`` to ``filename`` using ``torch.save``.

    Raises:
        ValueError: if ``model_state`` is None (e.g. no checkpoint was ever
            selected). Such a file could not be loaded back into a model.
    """
    if model_state is None:
        raise ValueError("model_state is None; nothing to save")
    torch.save(model_state, filename)

def load_model(model, filename, map_location=None):
    """Load a state dict from ``filename`` into ``model`` in place and return ``model``.

    Args:
        model: module whose parameters and buffers are overwritten.
        filename: path previously written with ``torch.save``.
        map_location: forwarded to ``torch.load``. When None (the default),
            tensors are mapped to the device of ``model``'s first parameter or
            buffer, or to CPU if it has neither. This lets a checkpoint saved on
            a GPU load on a CPU-only machine.
    """
    if map_location is None:
        first_tensor = next(model.parameters(), None)
        if first_tensor is None:
            first_tensor = next(model.buffers(), None)
        map_location = first_tensor.device if first_tensor is not None else torch.device("cpu")
    model.load_state_dict(torch.load(filename, map_location=map_location))
    return model

def main(features, labels, num_epochs=50, batch_size=1024, lr=0.001, warmup_epochs=5,
         save_path="best_model_50ep_4096batchsize_AdamW.pth"):
    """Train a KernelMetricNetwork on ``(features, labels)`` and save the best checkpoint.

    Args:
        features: feature matrix; ``features.shape[1]`` sets the input width.
        labels: per-row class labels; the number of distinct values sets the
            number of output classes.
        num_epochs, batch_size, lr, warmup_epochs: training hyperparameters,
            forwarded to ``create_data_loaders`` / ``train_model``.
        save_path: where the best state dict is written. The default is the
            original hard-coded filename, which encodes 50 epochs / batch 4096
            whatever the actual arguments are. Pass an explicit path when
            changing hyperparameters.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader, eval_loader = create_data_loaders(features, labels, batch_size)

    input_dim = features.shape[1]
    num_classes = len(set(labels))
    model = KernelMetricNetwork(input_dim, num_classes).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Number of parameters: {total_params}")

    best_model_state, best_acc = train_model(model, train_loader, eval_loader, num_epochs, lr, device, warmup_epochs)

    print(f"Best validation accuracy: {best_acc:.4f}")

    save_model(best_model_state, save_path)

In [2]:
import pickle
import numpy as np
import pandas as pd

In [3]:

# Load the premade features (per the author: an int8 numpy array, one row per data entry).
# NOTE(review): the width was previously documented as 1024*3, but the parameter count
# printed by `main` below (1,901,549 for 2285 classes) implies input_dim = 6144 = 1024*6.
# Confirm against the feature-generation code.
# SECURITY: pickle.load can execute arbitrary code; only load trusted files.
with open('MixFP_Reactant_Features_p4_r2_update_1024_dim.pkl', 'rb') as fh:
    features = pickle.load(fh)



In [4]:
# Load Pistachio to obtain the named-reaction label of every entry.
# SECURITY: pickle.load can execute arbitrary code; only load trusted files.
with open('FPCompatible_Cleaned_Pistachio.pkl', 'rb') as fh:
    df = pickle.load(fh)

# Keep only reactions that have a name; each distinct name becomes one class.
# NOTE(review): assumes the rows of `features` correspond 1:1, in order, to these
# filtered rows. Confirm against the feature-generation step.
labels = df['namerxndef'].dropna().values
del df  # the full DataFrame is no longer needed

In [5]:
# Sanity check
# Sanity check: every label should be a string before it reaches LabelEncoder.
# Fail fast here rather than with a less obvious error during training.
valid_indices = [isinstance(label, str) for label in labels]
n_invalid = valid_indices.count(False)
assert n_invalid == 0, f'{n_invalid} of {len(valid_indices)} labels are not strings'

In [5]:
main(
    features,
    labels,
    num_epochs=50,
    batch_size=4096,
    lr=0.001,
)

Using 2285 classes predictions
Number of parameters: 1901549


  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 1/50, Validation Accuracy: 0.0004, Train_loss 7.9725 ,Eval_Loss: 7.8573



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 2/50, Validation Accuracy: 0.8239, Train_loss 2.1127 ,Eval_Loss: 0.8009



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 3/50, Validation Accuracy: 0.8783, Train_loss 0.7126 ,Eval_Loss: 0.4413



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 4/50, Validation Accuracy: 0.8915, Train_loss 0.5237 ,Eval_Loss: 0.3657



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 5/50, Validation Accuracy: 0.8970, Train_loss 0.4601 ,Eval_Loss: 0.3362



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 6/50, Validation Accuracy: 0.9019, Train_loss 0.4293 ,Eval_Loss: 0.3182



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 7/50, Validation Accuracy: 0.9039, Train_loss 0.4013 ,Eval_Loss: 0.3070



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 8/50, Validation Accuracy: 0.9066, Train_loss 0.3838 ,Eval_Loss: 0.2956



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 9/50, Validation Accuracy: 0.9089, Train_loss 0.3714 ,Eval_Loss: 0.2878



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 10/50, Validation Accuracy: 0.9097, Train_loss 0.3615 ,Eval_Loss: 0.2827



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 11/50, Validation Accuracy: 0.9115, Train_loss 0.3527 ,Eval_Loss: 0.2783



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 12/50, Validation Accuracy: 0.9118, Train_loss 0.3448 ,Eval_Loss: 0.2754



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 13/50, Validation Accuracy: 0.9128, Train_loss 0.3383 ,Eval_Loss: 0.2724



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 14/50, Validation Accuracy: 0.9138, Train_loss 0.3320 ,Eval_Loss: 0.2688



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 15/50, Validation Accuracy: 0.9141, Train_loss 0.3261 ,Eval_Loss: 0.2672



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 16/50, Validation Accuracy: 0.9152, Train_loss 0.3209 ,Eval_Loss: 0.2642



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 17/50, Validation Accuracy: 0.9162, Train_loss 0.3158 ,Eval_Loss: 0.2620



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 18/50, Validation Accuracy: 0.9170, Train_loss 0.3115 ,Eval_Loss: 0.2584



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 19/50, Validation Accuracy: 0.9174, Train_loss 0.3071 ,Eval_Loss: 0.2570



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 20/50, Validation Accuracy: 0.9180, Train_loss 0.3031 ,Eval_Loss: 0.2547



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 21/50, Validation Accuracy: 0.9184, Train_loss 0.2994 ,Eval_Loss: 0.2543



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 22/50, Validation Accuracy: 0.9192, Train_loss 0.2951 ,Eval_Loss: 0.2522



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 23/50, Validation Accuracy: 0.9194, Train_loss 0.2921 ,Eval_Loss: 0.2515



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 24/50, Validation Accuracy: 0.9198, Train_loss 0.2890 ,Eval_Loss: 0.2490



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 25/50, Validation Accuracy: 0.9202, Train_loss 0.2857 ,Eval_Loss: 0.2490



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 26/50, Validation Accuracy: 0.9206, Train_loss 0.2819 ,Eval_Loss: 0.2465



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 27/50, Validation Accuracy: 0.9208, Train_loss 0.2789 ,Eval_Loss: 0.2474



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 28/50, Validation Accuracy: 0.9214, Train_loss 0.2764 ,Eval_Loss: 0.2458



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 29/50, Validation Accuracy: 0.9219, Train_loss 0.2736 ,Eval_Loss: 0.2438



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 30/50, Validation Accuracy: 0.9225, Train_loss 0.2712 ,Eval_Loss: 0.2427



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 31/50, Validation Accuracy: 0.9226, Train_loss 0.2683 ,Eval_Loss: 0.2420



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 32/50, Validation Accuracy: 0.9225, Train_loss 0.2655 ,Eval_Loss: 0.2419



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 33/50, Validation Accuracy: 0.9227, Train_loss 0.2628 ,Eval_Loss: 0.2414



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 34/50, Validation Accuracy: 0.9231, Train_loss 0.2609 ,Eval_Loss: 0.2397



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 35/50, Validation Accuracy: 0.9232, Train_loss 0.2581 ,Eval_Loss: 0.2396



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 36/50, Validation Accuracy: 0.9235, Train_loss 0.2557 ,Eval_Loss: 0.2387



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 37/50, Validation Accuracy: 0.9241, Train_loss 0.2537 ,Eval_Loss: 0.2378



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 38/50, Validation Accuracy: 0.9243, Train_loss 0.2524 ,Eval_Loss: 0.2371



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 39/50, Validation Accuracy: 0.9243, Train_loss 0.2499 ,Eval_Loss: 0.2373



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 40/50, Validation Accuracy: 0.9245, Train_loss 0.2485 ,Eval_Loss: 0.2364



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 41/50, Validation Accuracy: 0.9246, Train_loss 0.2472 ,Eval_Loss: 0.2362



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 42/50, Validation Accuracy: 0.9248, Train_loss 0.2458 ,Eval_Loss: 0.2358



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 43/50, Validation Accuracy: 0.9251, Train_loss 0.2443 ,Eval_Loss: 0.2356



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 44/50, Validation Accuracy: 0.9248, Train_loss 0.2429 ,Eval_Loss: 0.2355



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 45/50, Validation Accuracy: 0.9249, Train_loss 0.2425 ,Eval_Loss: 0.2353



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 46/50, Validation Accuracy: 0.9250, Train_loss 0.2411 ,Eval_Loss: 0.2352



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 47/50, Validation Accuracy: 0.9251, Train_loss 0.2403 ,Eval_Loss: 0.2352



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 48/50, Validation Accuracy: 0.9252, Train_loss 0.2403 ,Eval_Loss: 0.2349



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 49/50, Validation Accuracy: 0.9252, Train_loss 0.2398 ,Eval_Loss: 0.2351



  0%|          | 0/817 [00:00<?, ?it/s]

Epoch 50/50, Validation Accuracy: 0.9252, Train_loss 0.2400 ,Eval_Loss: 0.2351

Best validation accuracy: 0.9252
