In [1]:
import pandas as pd
import numpy as np
import os
import random
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, roc_auc_score, average_precision_score
import gc
from tqdm import tqdm
import time
from collections import defaultdict
import copy
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler

import pdb

import wandb

# Defining global parameters

In [2]:
def get_available_device():
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"
    
CONFIG = {
    "seed": 42,
    "epochs": 12, # 42, ~MAX 20 hours of training
    "train_batch_size": 16,
    "valid_batch_size": 64,
    "learning_rate": 5e-5,
    "scheduler": 'CosineAnnealingLR',
    "min_lr": 5e-7,
    "T_max": 12,
    "weight_decay": 1e-6,
    "fold" : 0,
    "n_fold": 5,
    "n_accumulate": 1,
    "device": get_available_device(),
    'labels': ['unable_to_assess', 'not_close_match','close_match', 'near_exact_match', 'exact_match'], # 'unable_to_assess', 'not_close_match','close_match', 'near_exact_match', 'exact_match'
    'FP': 'molformer', # 'fp', 'molformer', 'ECFP', 'grover'
}

def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(seed=CONFIG['seed'])

# Read the data that has duplicates REMOVED
## Validation set of 20% of the size

In [3]:
train_df = pd.read_pickle('../data/train_split_remove_duplicates_all_embeddings.pkl')
valid_df = pd.read_pickle('../data/valid_split_remove_duplicates_all_embeddings.pkl')

In [4]:
if CONFIG['FP'] == 'ECFP':
    CONFIG['input_size'] = 2048
elif CONFIG['FP'] == 'molformer':
    CONFIG['input_size'] = 768
elif CONFIG['FP'] == 'fp':
    CONFIG['input_size'] = 2215
elif CONFIG['FP'] == 'grover':
    CONFIG['input_size'] = 5000


## Wrapping in Pytorch Dataset

In [8]:
class EnvedaDataset(Dataset):
    def __init__(self, CONFIG, dataframe, labels = []):
        """
        Args:
            dataframe (pd.DataFrame): A DataFrame containing 'ground_truth_embeddings', 
                                       'predicted_embeddings', and output columns.
        """
        self.dataframe = dataframe
        # pdb.set_trace()
        if CONFIG['FP'] == 'molformer':
            # Convert Molformer embeddings to tensors
            self.ground_truth_embeddings = torch.tensor(dataframe['ground_truth_embeddings'].tolist(), dtype=torch.float32)
            self.predicted_embeddings = torch.tensor(dataframe['predicted_embeddings'].tolist(), dtype=torch.float32)
    
        # fingerprints
        elif CONFIG['FP'] == 'ECFP':
            self.ground_truth_embeddings = torch.tensor(dataframe['ground_truth_ECFP'].tolist(), dtype=torch.float32)
            self.predicted_embeddings = torch.tensor(dataframe['predicted_ECFP'].tolist(), dtype=torch.float32)
        elif CONFIG['FP'] == 'fp':
            self.ground_truth_embeddings = dataframe['ground_truth_fp'].tolist()
            self.predicted_embeddings = dataframe['predicted_fp'].tolist()

        elif CONFIG['FP'] == 'grover':
            self.ground_truth_embeddings = torch.tensor(dataframe['ground_truth_grover_fp'].tolist(), dtype=torch.float32)
            self.predicted_embeddings = torch.tensor(dataframe['predicted_grover_fp'].tolist(), dtype=torch.float32)
    
        
        self.labels_text = labels
        # Convert labels to tensor
        self.labels = torch.tensor(dataframe[labels].values, dtype=torch.float32)

    def __len__(self):
        """Returns the total number of samples."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Generates one sample of data."""
        if CONFIG['FP'] == 'molformer':
            return self.ground_truth_embeddings[idx].squeeze(0), self.predicted_embeddings[idx].squeeze(0), self.labels[idx]
        elif CONFIG['FP'] == 'fp':
            return self.ground_truth_embeddings.iloc[idx].squeeze(0), self.predicted_embeddings.iloc[idx].squeeze(0), self.labels[idx]
        elif CONFIG['FP'] == 'ECFP':
            return self.ground_truth_embeddings[idx], self.predicted_embeddings[idx], self.labels[idx]
        elif CONFIG['FP'] == 'grover':
            return self.ground_truth_embeddings[idx].squeeze(0), self.predicted_embeddings[idx].squeeze(0), self.labels[idx]


In [9]:
CONFIG_ECFP = {
    "seed": 42,
    'labels': ['unable_to_assess', 'not_close_match','close_match', 'near_exact_match', 'exact_match'], # 'unable_to_assess', 'not_close_match','close_match', 'near_exact_match', 'exact_match'
    "train_batch_size": 16,
    "valid_batch_size": 64,
    'FP': 'ECFP', # 'fp', 'molformer', 'ECFP', 'grover'
    'input_size': 2048
}

CONFIG_molformer = {
    "seed": 42,
    'labels': ['unable_to_assess', 'not_close_match','close_match', 'near_exact_match', 'exact_match'], # 'unable_to_assess', 'not_close_match','close_match', 'near_exact_match', 'exact_match'
    "train_batch_size": 16,
    "valid_batch_size": 64,
    'FP': 'molformer', # 'fp', 'molformer', 'ECFP', 'grover'
    'input_size': 768
}

CONFIG_fp = {
    "seed": 42,
    'labels': ['unable_to_assess', 'not_close_match','close_match', 'near_exact_match', 'exact_match'], # 'unable_to_assess', 'not_close_match','close_match', 'near_exact_match', 'exact_match'
    "train_batch_size": 16,
    "valid_batch_size": 64,
    'FP': 'fp',  # 'fp', 'molformer', 'ECFP', 'grover'
    'input_size': 2215
}

CONFIG_grover = {
    "seed": 42,
    'labels': ['unable_to_assess', 'not_close_match','close_match', 'near_exact_match', 'exact_match'], # 'unable_to_assess', 'not_close_match','close_match', 'near_exact_match', 'exact_match'
    "train_batch_size": 16,
    "valid_batch_size": 64,
    'FP': 'grover', # 'fp', 'molformer', 'ECFP', 'grover'
    'input_size': 5000
}

In [10]:
# Initialize trainsets and validsets for each configuration
trainset_ECFP = EnvedaDataset(CONFIG_ECFP, dataframe=train_df, labels=CONFIG_ECFP['labels'])
validset_ECFP = EnvedaDataset(CONFIG_ECFP, dataframe=valid_df, labels=CONFIG_ECFP['labels'])

trainset_molformer = EnvedaDataset(CONFIG_molformer, dataframe=train_df, labels=CONFIG_molformer['labels'])
validset_molformer = EnvedaDataset(CONFIG_molformer, dataframe=valid_df, labels=CONFIG_molformer['labels'])

trainset_fp = EnvedaDataset(CONFIG_fp, dataframe=train_df, labels=CONFIG_fp['labels'])
validset_fp = EnvedaDataset(CONFIG_fp, dataframe=valid_df, labels=CONFIG_fp['labels'])

trainset_grover = EnvedaDataset(CONFIG_grover, dataframe=train_df, labels=CONFIG_grover['labels'])
validset_grover = EnvedaDataset(CONFIG_grover, dataframe=valid_df, labels=CONFIG_grover['labels'])

## Wrapping in Pytorch DataLoader

In [11]:
# trainloader, validloader = DataLoader(trainset, batch_size=CONFIG['train_batch_size']), \
#                            DataLoader(validset, batch_size=CONFIG['valid_batch_size'])


# Initialize DataLoaders for each configuration
trainloader_ECFP = DataLoader(trainset_ECFP, batch_size=CONFIG_ECFP['train_batch_size'])
validloader_ECFP = DataLoader(validset_ECFP, batch_size=CONFIG_ECFP['valid_batch_size'])

trainloader_molformer = DataLoader(trainset_molformer, batch_size=CONFIG_molformer['train_batch_size'])
validloader_molformer = DataLoader(validset_molformer, batch_size=CONFIG_molformer['valid_batch_size'])

trainloader_fp = DataLoader(trainset_fp, batch_size=CONFIG_fp['train_batch_size'])
validloader_fp = DataLoader(validset_fp, batch_size=CONFIG_fp['valid_batch_size'])

trainloader_grover = DataLoader(trainset_grover, batch_size=CONFIG_grover['train_batch_size'])
validloader_grover = DataLoader(validset_grover, batch_size=CONFIG_grover['valid_batch_size'])

In [14]:
gt, pred, labels = next(iter(validloader_fp))
print(gt.shape, pred.shape, labels.shape)


torch.Size([64, 2215]) torch.Size([64, 2215]) torch.Size([64, 5])


# Defining neural network

In [15]:
# Define the Siamese Network
class SiameseNetwork(nn.Module):
    def __init__(self, input_dim=768, output_dim=4):
        super(SiameseNetwork, self).__init__()

        self.fc = nn.Sequential(
        nn.Linear(input_dim, 1024),
        nn.ReLU(inplace=True),
        nn.Dropout(0.2),
        nn.Linear(1024, 512),
        nn.ReLU(inplace=True),
        nn.Dropout(0.2),
        nn.Linear(512, 256),
        nn.ReLU(inplace=True),
        nn.Dropout(0.2),
        nn.Linear(256, 128),
        nn.ReLU(inplace=True),
        nn.Dropout(0.2))

        self.fc5 = nn.Linear(128, output_dim)

        #self.fc5 = nn.Linear(16 * 2, output_dim)  # Output layer for similarity judgements

    def forward_one(self, x):
        return self.fc(x)

    def forward(self, input1, input2):
        out1 = self.forward_one(input1)
        out2 = self.forward_one(input2)
        # pdb.set_trace()
        # Combine both outputs by concatenation
        # combined = torch.concat((out1, out2), dim=1) # concatenate embeddings
        combined = torch.sub(out1, out2)  # maybe torch.abs(out1 - out2)
        output = self.fc5(combined)                  # Outputs raw logits
        return output
    
# model = SiameseNetwork()
# model(gt, pred)

# Validation regime

In [16]:

@torch.no_grad()
def valid_one_epoch(model, dataloader, criterion, epoch):
    model.eval()
    
    dataset_size = 0
    running_loss = 0.0
    running_auroc = 0.0
    running_f1 = 0.0 
    
    bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for step, data in bar:
        
        gt, preds, targets = data
        gt, preds, targets = gt.to(CONFIG['device']), preds.to(CONFIG['device']), targets.to(CONFIG['device'])
        batch_size = gt.shape[0]
        
        with autocast():
            outputs = model(gt, preds)
            loss = criterion(outputs, targets)
            loss = loss / CONFIG['n_accumulate']
        
        probabilities = torch.softmax(outputs, dim=1).detach().cpu().numpy() if outputs.shape[1] > 1 else torch.sigmoid(outputs).detach().cpu().numpy()
        preds = np.eye(outputs.shape[1])[np.argmax(probabilities, axis=1)] if outputs.shape[1] > 1 else (probabilities > 0.5).astype(float)
        
        auroc = average_precision_score(targets.cpu().numpy(), probabilities, average='weighted')
        f1 = f1_score(targets.cpu().numpy(), preds, average='weighted')
        
        running_loss += (loss.item() * batch_size)
        running_auroc  += (auroc * batch_size)
        running_f1 += (f1 * batch_size)
        dataset_size += batch_size
        
        epoch_loss = running_loss / dataset_size
        epoch_auroc = running_auroc / dataset_size
        epoch_f1 = running_f1 / dataset_size
        
        bar.set_postfix(Epoch=epoch, Valid_Loss=epoch_loss, Valid_Auroc=epoch_auroc, 
                        Valid_F1=epoch_f1,
                        )   
    gc.collect()
    
    return epoch_loss, epoch_auroc, epoch_f1

# Initialize pretrained models

In [17]:
# Initialize models for each configuration
model_ECFP = SiameseNetwork(input_dim=CONFIG_ECFP['input_size'], output_dim=len(CONFIG_ECFP['labels'])).to(CONFIG['device'])
model_molformer = SiameseNetwork(input_dim=CONFIG_molformer['input_size'], output_dim=len(CONFIG_molformer['labels'])).to(CONFIG['device'])
model_fp = SiameseNetwork(input_dim=CONFIG_fp['input_size'], output_dim=len(CONFIG_fp['labels'])).to(CONFIG['device'])
model_grover = SiameseNetwork(input_dim=CONFIG_grover['input_size'], output_dim=len(CONFIG_grover['labels'])).to(CONFIG['device'])

In [18]:
# Load the state_dict for each model
model_ECFP.load_state_dict(torch.load('../results/best_F1_model_siamese_multiclass_sub_ECFP.bin', map_location=CONFIG['device']))
model_molformer.load_state_dict(torch.load('../results/best_F1_model_siamese_multiclass_sub_molformer.bin', map_location=CONFIG['device']))
model_fp.load_state_dict(torch.load('../results/best_F1_model_siamese_multiclass_sub_remove_duplicate.bin', map_location=CONFIG['device']))
model_grover.load_state_dict(torch.load('../results/best_F1_model_siamese_multiclass_sub_grover.bin', map_location=CONFIG['device']))

<All keys matched successfully>

## Define global criterion for evaluating all models

In [19]:
class_distribution = {'unable_to_assess': 10,
                      'not_close_match': 109, 
                      'close_match': 125, 
                      'near_exact_match': 66, 
                      'exact_match': 12,
                      'good_enough_for_prioritization': 224}

# Calculate class weights
total_samples = sum(class_distribution.values())
class_weights = {label: total_samples / (len(class_distribution) * count) for label, count in class_distribution.items()}

# Convert weights to a tensor
weights = torch.tensor([class_weights['unable_to_assess'],
                        class_weights['not_close_match'],
                        class_weights['close_match'],
                        class_weights['near_exact_match'],
                        class_weights['exact_match']], dtype=torch.float32).to(CONFIG['device'])

# find the indeces of keys in class_distribution in CONFIG['labels']
weights = torch.tensor([class_weights[c] for c in CONFIG['labels']], dtype=torch.float32).to(CONFIG['device'])

# Modify the loss function in your training loop
criterion = nn.CrossEntropyLoss(weight=weights)



# Performance of individual models

In [40]:
# Evaluate each model individually
val_epoch_loss_ECFP, val_epoch_auroc_ECFP, val_epoch_f1_ECFP = valid_one_epoch(
    model=model_ECFP,
    dataloader=validloader_ECFP,
    criterion=criterion,
    epoch=1
)



100%|██████████| 2/2 [00:00<00:00, 78.20it/s, Epoch=1, Valid_Auroc=0.358, Valid_F1=0.297, Valid_Loss=3.34]


In [41]:
val_epoch_loss_molformer, val_epoch_auroc_molformer, val_epoch_f1_molformer = valid_one_epoch(
    model=model_molformer,
    dataloader=validloader_molformer,
    criterion=criterion,
    epoch=1
)

100%|██████████| 2/2 [00:00<00:00, 93.68it/s, Epoch=1, Valid_Auroc=0.394, Valid_F1=0.403, Valid_Loss=2.96]


In [42]:
val_epoch_loss_fp, val_epoch_auroc_fp, val_epoch_f1_fp = valid_one_epoch(
    model=model_fp,
    dataloader=validloader_fp,
    criterion=criterion,
    epoch=1
)

100%|██████████| 2/2 [00:00<00:00, 97.74it/s, Epoch=1, Valid_Auroc=0.359, Valid_F1=0.352, Valid_Loss=2.16]


In [43]:

val_epoch_loss_grover, val_epoch_auroc_grover, val_epoch_f1_grover = valid_one_epoch(
    model=model_grover,
    dataloader=validloader_grover,
    criterion=criterion,
    epoch=1
)

100%|██████████| 2/2 [00:00<00:00, 44.20it/s, Epoch=1, Valid_Auroc=0.425, Valid_F1=0.328, Valid_Loss=1.84]



# Ensemble of Siamese network with different feature encoders

In [45]:
@torch.no_grad()
def valid_ensemble(models, dataloaders, criterion, epoch):
    # Set models to evaluation mode
    for model in models:
        model.eval()
    
    dataset_size = 0
    running_loss = 0.0
    running_auroc = 0.0
    running_f1 = 0.0 

    # Initialize a list to accumulate model outputs
    all_outputs = []

    # Initialize a tensor for targets (assuming they are consistent across all dataloaders)
    targets_list = []
    model_list_names = ['ecfp', 'molformer', 'fp', 'grover']
    # Iterate over each dataloader and evaluate the corresponding model
    for i, (dataloader, model) in enumerate(zip(dataloaders, models)):
        bar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f'Model {model_list_names[i]}')
        
        for step, data in bar:
            gt, preds, targets = data  # Get targets from each batch
            gt, preds, targets = gt.to(CONFIG['device']), preds.to(CONFIG['device']), targets.to(CONFIG['device'])
            
            with autocast():
                outputs = model(gt, preds)
                all_outputs.append(outputs)  # Store outputs on CPU for later averaging
            
            if i == 0:  # Only collect targets from the first dataloader
                targets_list.append(targets)

    # After processing all batches for all models, concatenate the outputs
    all_outputs_tensor = torch.cat(all_outputs, dim=0)  # Shape will be (total_samples, 5)

    # Average the outputs from all models
    ensemble_outputs = torch.mean(all_outputs_tensor.view(len(models), -1, 5), dim=0)  # Shape should be (88, 5)

    # Concatenate targets into a single tensor (assuming they are consistent)
    targets_tensor = torch.cat(targets_list, dim=0)  # Shape should be (88, 5)

    # Calculate loss and metrics using the ensemble outputs
    loss = criterion(ensemble_outputs, targets_tensor)

    probabilities = torch.softmax(ensemble_outputs, dim=1).detach().cpu().numpy() if ensemble_outputs.shape[1] > 1 else torch.sigmoid(ensemble_outputs).detach().cpu().numpy()
    preds = np.eye(ensemble_outputs.shape[1])[np.argmax(probabilities, axis=1)] if ensemble_outputs.shape[1] > 1 else (probabilities > 0.5).astype(float)

    auroc = average_precision_score(targets_tensor.cpu().numpy(), probabilities, average='weighted')
    f1 = f1_score(targets_tensor.cpu().numpy(), preds, average='weighted')

    running_loss += loss.item() * len(targets_tensor)
    running_auroc += auroc * len(targets_tensor)
    running_f1 += f1 * len(targets_tensor)
    dataset_size += len(targets_tensor)

    epoch_loss = running_loss / dataset_size
    epoch_auroc = running_auroc / dataset_size
    epoch_f1 = running_f1 / dataset_size

    print(f'Epoch: {epoch}, Valid Loss: {epoch_loss}, Valid AUROC: {epoch_auroc}, Valid F1: {epoch_f1}')
    
    gc.collect()
    
    return epoch_loss, epoch_auroc, epoch_f1

# Example usage:
models = [model_ECFP, model_molformer, model_fp, model_grover]
dataloaders = [validloader_ECFP, validloader_molformer, validloader_fp, validloader_grover]

# Call the ensemble evaluation function for a specific epoch
loss, auroc, f1 = valid_ensemble(models, dataloaders, criterion, 1)


Model ecfp: 100%|██████████| 2/2 [00:00<00:00, 179.53it/s]
Model molformer: 100%|██████████| 2/2 [00:00<00:00, 385.65it/s]
Model fp: 100%|██████████| 2/2 [00:00<00:00, 398.40it/s]
Model grover: 100%|██████████| 2/2 [00:00<00:00, 67.66it/s]

Epoch: 1, Valid Loss: 2.067840814590454, Valid AUROC: 0.366906338762862, Valid F1: 0.3713195435584391



