In [None]:
import os

# Pick the physical GPU BEFORE importing torch / Lightning / project modules:
# CUDA_VISIBLE_DEVICES only takes effect if nothing has initialised CUDA (or
# cached the device list) yet. Inside this process the chosen GPU is device 0.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import sys

import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import selfies as sf
from train_selfies import MS2Gen, MSDataModule
import pytorch_lightning as pl
from tqdm import tqdm
from pyteomics import mass
from rdkit import Chem
# Descriptors.ExactMolWt is used for candidate masses during reranking.
from rdkit.Chem import Descriptors
from rdkit.Chem import AllChem
from msdatasets import MSDataset

config = {
    "dataset_path": "datasets/",
    "batch_size": 2,
}

# Root directory holding the val/ and smiles/ sub-directories used by the
# temperature sweep. NOTE: if left empty, f"{path}/val/..." resolves to the
# absolute path "/val/..." — set this before running the sweep.
path = ""

In [None]:
pl.seed_everything(42)

# Data module; only the test stage is set up below.
dataset = MSDataModule(
    config["dataset_path"],
    config["batch_size"],
)

# Load the trained generative model for inference only. No Trainer is needed:
# evaluation calls the model directly on each batch.
CHECKPOINT_PATH = './trained_models/ms2_paper_generative_best.ckpt'
model = MS2Gen.load_from_checkpoint(CHECKPOINT_PATH)
model.cuda()
model.temperature = 1.0
# Candidates generated per spectrum. evaluate_data groups predictions into rows of
# 100 by default, so keep these two values in sync.
model.num_sequence = 100

pl.seed_everything(42)
dataset.setup(stage='test')
test_loader = dataset.test_dataloader()
val_loader = dataset.val_dataloader()

# Inference mode for layers such as dropout.
model.eval()

In [29]:
# Adduct name -> offset that evaluate_data ADDS to the observed precursor m/z to get a
# candidate neutral mass M. The reranking step compares every predicted structure's exact
# mass against all of these candidate masses and keeps the smallest difference.
# For singly charged adducts the relation is [M+X]+ / [M+X]- -> M = m/z - mass(X), and
# [M-H]- -> M = m/z + mass(H).
# NOTE(review): calculate_mass(formula=...) is called without a charge, so it presumably
# returns neutral masses (e.g. an H atom, ~1.0078 Da, rather than a proton, ~1.0073 Da).
# The ~0.5 mDa electron-mass difference is therefore ignored — confirm this is acceptable.
# NOTE(review): the three 2+ entries are approximations. For a doubly charged ion
# M = 2*(m/z) - mass(adducts), which cannot be expressed as one additive offset on m/z,
# so these entries do not recover M exactly.
adducts_mass_adjustments = {
    "[M]+": 0,                                        # ion observed as-is: no adjustment
    "[M+H]+": -mass.calculate_mass(formula='H'),       # M = m/z - H
    "[M+Na]+": -mass.calculate_mass(formula='Na'),     # M = m/z - Na
    "[M+K]+": -mass.calculate_mass(formula='K'),       # M = m/z - K
    "[M+NH4]+": -mass.calculate_mass(formula='NH4'),   # M = m/z - NH4
    "[M+2H]2+": -mass.calculate_mass(formula='H'),     # NOTE(review): same offset as [M+H]+; not exact for a 2+ ion
    "[M+H+Na]2+": -(mass.calculate_mass(formula='H') + mass.calculate_mass(formula='Na')) / 2, # NOTE(review): averaged H/Na shift heuristic; not an exact 2+ correction
    "[M+2Na]2+": -mass.calculate_mass(formula='Na'),   # NOTE(review): same offset as [M+Na]+; not exact for a 2+ ion
    "[M-H]-": mass.calculate_mass(formula='H'),        # a hydrogen was lost, so add it back: M = m/z + H
    "[M+Cl]-": -mass.calculate_mass(formula='Cl'),     # M = m/z - Cl
    "[M+FA]-": -mass.calculate_mass(formula='CHO2'),   # formate adduct [M+HCOO]-: M = m/z - CHO2
    "[M+Br]-": -mass.calculate_mass(formula='Br')      # M = m/z - Br
}

def _canonicalize_smiles(smiles):
    """Return RDKit's canonical SMILES for ``smiles``, or ``smiles`` unchanged if it cannot be parsed."""
    try:
        mol = Chem.MolFromSmiles(smiles)
    except Exception:  # e.g. non-string input
        return smiles
    # MolFromSmiles signals invalid SMILES by returning None rather than raising.
    return smiles if mol is None else Chem.MolToSmiles(mol)


def _exact_mass(smiles):
    """Exact (monoisotopic) mass of ``smiles`` via ``Descriptors.ExactMolWt``; 0.0 if unparseable.

    A mass of 0.0 is far from every plausible precursor mass, so mass-based reranking pushes
    unparseable candidates towards the end.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
    except Exception:
        return 0.0
    return 0.0 if mol is None else Descriptors.ExactMolWt(mol)


def _morgan_fingerprint(smiles, radius=2, n_bits=1024):
    """Morgan bit-vector fingerprint of ``smiles``, or None if it cannot be parsed."""
    try:
        mol = Chem.MolFromSmiles(smiles)
    except Exception:
        return None
    if mol is None:
        return None
    return AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)


def _fraction(mask):
    """Fraction of True entries in the boolean array ``mask`` (NaN when empty)."""
    return np.count_nonzero(mask) / mask.size if mask.size else float('nan')


def _summarize_topk(scores_by_k, exact_by_k, top_k):
    """Aggregate per-spectrum top-k scores into the flat metrics dict returned by evaluate_data.

    Key order is kept as: all means, all stds, then median/threshold stats from the
    last k down to the first, then exact-match rates. This keeps the column order of the
    saved results CSV stable.
    """
    summary = {}
    for k in top_k:
        summary[f'top_{k}_mean'] = np.mean(scores_by_k[k])
    for k in top_k:
        summary[f'top_{k}_std'] = np.std(scores_by_k[k])
    for k in reversed(top_k):
        scores = np.asarray(scores_by_k[k], dtype=float)
        summary[f'top_{k}_median'] = np.median(scores)
        summary[f'top_{k}_over40'] = _fraction(scores > 0.4)
        summary[f'top_{k}_over65'] = _fraction(scores > 0.65)
        summary[f'top_{k}_over95'] = _fraction(scores > 0.95)
        summary[f'top_{k}_100'] = _fraction(scores >= 1.)
    for k in top_k:
        summary[f'exact_{k}'] = np.mean(exact_by_k[k])
    return summary


def evaluate_data(loader, adducts_mass_adjustments, top_k=[1, 10, 100], do_reorder=False,
                  num_candidates=100, gen_model=None):
    """Generate candidate structures for every spectrum in ``loader`` and score them.

    The model's SELFIES outputs and the target SELFIES are decoded to SMILES and
    canonicalised with RDKit. Optionally the candidates are reranked by how close their
    exact mass is to any adduct-corrected precursor mass. For each k in ``top_k`` the
    function records two things over the first k candidates:
    - the best Morgan (radius 2, 1024 bits) Tanimoto similarity to the target;
    - whether an exact canonical-SMILES match occurs.

    Args:
        loader: iterable of batch dicts with tensors "mz", "inty", "precursormz" and the
            target SELFIES strings under "selfies".
        adducts_mass_adjustments: mapping adduct name -> offset added to the precursor
            m/z to obtain a candidate neutral mass.
        top_k: cut-offs to report (default [1, 10, 100]); not mutated.
        do_reorder: rerank candidates by mass difference. The sort is stable, so
            candidates with equal error keep the model's order.
        num_candidates: predictions generated per spectrum. This should match the model's
            ``num_sequence`` (100 in this notebook).
        gen_model: model to evaluate; defaults to the notebook-global ``model``.

    Returns:
        A 7-tuple:
        - results: metrics dict;
        - batch_preds: (n_spectra, num_candidates) object array of SMILES, in reranked
          order if ``do_reorder``;
        - targets: canonical target SMILES;
        - precursor_list: flat precursor m/z array;
        - similarities_list: one array per (spectrum, k);
        - reorder_list: permutation applied per spectrum;
        - diffs_list: per-candidate mass error per spectrum.

    Raises:
        ValueError: if ``loader`` yields no batches, or if the number of prediction rows
            does not match the number of spectra (e.g. wrong ``num_candidates``).
    """
    gen_model = model if gen_model is None else gen_model
    top_k = list(top_k)
    pl.seed_everything(42)

    preds, targets, precursor_chunks = [], [], []
    for batch in tqdm(loader):
        mz = batch["mz"].cuda()
        inty = batch["inty"].cuda()
        precursormz = batch["precursormz"].cuda()
        selfies = batch["selfies"]
        # The tokenizer/collator output starts on the CPU; move every tensor to the GPU.
        bert_inputs = {
            k: v.cuda() for k, v in gen_model.collator(gen_model.tokenizer(selfies)).items()
        }
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            _, _, _, reconstructed = gen_model(
                mz, inty, precursormz, bert_inputs, selfies, mode="val"
            )

        preds.append([_canonicalize_smiles(sf.decoder(x)) for x in reconstructed])
        targets.extend(_canonicalize_smiles(sf.decoder(x)) for x in selfies)
        # Ravel each batch so batches of any size (including a short last batch) concatenate cleanly.
        precursor_chunks.append(np.ravel(precursormz.detach().cpu().numpy()))

    if not preds:
        raise ValueError("evaluate_data: loader yielded no batches")

    precursor_list = np.concatenate(precursor_chunks)
    targets = np.array(targets).flatten()

    # Each batch's predictions form one flat list; split it into one row of num_candidates per spectrum.
    chunks = [
        batch_pred[start:start + num_candidates]
        for batch_pred in preds
        for start in range(0, len(batch_pred), num_candidates)
    ]
    batch_preds = np.empty((len(chunks), num_candidates), dtype=object)
    for row, chunk in enumerate(chunks):
        batch_preds[row] = chunk
    if not len(batch_preds) == len(precursor_list) == len(targets):
        raise ValueError(
            f"evaluate_data: {len(batch_preds)} prediction rows vs {len(precursor_list)} "
            f"precursors vs {len(targets)} targets; check num_candidates={num_candidates}"
        )

    adduct_offsets = np.array(list(adducts_mass_adjustments.values()), dtype=float)
    max_k = max(top_k, default=0)
    scores_by_k = {k: [] for k in top_k}
    exact_by_k = {k: [] for k in top_k}
    similarities_list, reorder_list, diffs_list = [], [], []

    # Score one spectrum at a time.
    for i in tqdm(range(len(precursor_list))):
        # Candidate neutral masses: precursor m/z plus each adduct offset (computed in float64).
        possible_precursors = float(precursor_list[i]) + adduct_offsets
        pred_masses = np.array([_exact_mass(x) for x in batch_preds[i]], dtype=float)
        # Distance from each prediction's mass to the closest candidate precursor mass.
        diffs = np.abs(pred_masses[:, None] - possible_precursors[None, :]).min(axis=1)

        if do_reorder:
            # Stable sort: candidates with equal mass error keep the model's ranking.
            reorder = np.argsort(diffs, kind='stable')
            diffs = diffs[reorder]
            batch_preds[i] = batch_preds[i][reorder]
        else:
            reorder = np.arange(len(batch_preds[i]))
        reorder_list.append(reorder)
        diffs_list.append(diffs)

        target = targets[i]
        target_fp = AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(target), 2, nBits=1024)
        candidates = batch_preds[i][:max_k]
        candidate_fps = [_morgan_fingerprint(x) for x in candidates]  # computed once, sliced per k
        for k in top_k:
            valid_fps = [fp for fp in candidate_fps[:k] if fp is not None]
            similarities = np.array(
                [Chem.DataStructs.TanimotoSimilarity(target_fp, fp) for fp in valid_fps]
            )
            if similarities.size == 0:
                similarities = np.array([0])  # no parseable candidate among the top k
            similarities_list.append(similarities)
            scores_by_k[k].append(np.max(similarities))
            exact_by_k[k].append(int(any(x == target for x in candidates[:k])))

    results = _summarize_topk(scores_by_k, exact_by_k, top_k)
    return results, batch_preds, targets, precursor_list, similarities_list, reorder_list, diffs_list

In [4]:
def save_gen_results(path, model, dataset, results, reorder = False):
    """Write every artefact returned by ``evaluate_data`` to CSV files.

    Files go to ``{path}/{model}/``, which is created if missing. Each file is named
    ``{dataset}_{rerank|no_rerank}_{artefact}.csv``, where artefact is one of results,
    batch_preds, targets, precursors, similarities, reorder or diffs.

    Args:
        path: root results directory.
        model: model name, used as a sub-directory of ``path``.
        dataset: split name, used as the file-name prefix.
        results: the 7-tuple returned by ``evaluate_data``.
        reorder: whether the predictions were reranked. It only selects the
            ``rerank`` / ``no_rerank`` tag in the file names.
    """
    # Unpack into a distinct name: reusing `reorder` here would shadow the boolean
    # argument and make every file tagged "rerank".
    (results_dict, batch_preds, targets, precursors,
     similarities, reorder_indices, diffs) = results

    reorder_string = 'rerank' if reorder else 'no_rerank'
    out_dir = os.path.join(path, model)
    os.makedirs(out_dir, exist_ok=True)

    frames = {
        'results': pd.DataFrame(results_dict, index=[0]),
        'batch_preds': pd.DataFrame(batch_preds),
        'targets': pd.DataFrame(targets),
        'precursors': pd.DataFrame(precursors),
        'similarities': pd.DataFrame(similarities),  # ragged rows are NaN-padded
        'reorder': pd.DataFrame(reorder_indices),
        'diffs': pd.DataFrame(diffs),
    }
    for suffix, frame in frames.items():
        frame.to_csv(os.path.join(out_dir, f'{dataset}_{reorder_string}_{suffix}.csv'), index=False)

In [None]:
# Rerank settings to evaluate. Each value is passed to evaluate_data(do_reorder=...)
# AND picks the "rerank"/"no_rerank" tag in the saved file names, so the two always agree.
reorder_list = [True]
model_name = ''  # sub-directory under ../results/gen/ for this model's outputs

# test_loader is the list returned by dataset.test_dataloader(). This maps each loader
# index evaluated here to the name used in the output files. Index 2 (the "disjoint"
# split) is not evaluated in this cell.
eval_splits = [('casmi2017', 3), ('casmi', 1), ('unknown', 0)]
eval_results = {}

for do_reorder in reorder_list:
    for split_name, loader_idx in eval_splits:
        # Pass the DataLoader itself (not iter()) so tqdm knows the total number of batches.
        split_results = evaluate_data(
            test_loader[loader_idx],
            adducts_mass_adjustments,
            top_k=[1, 10, 100],
            do_reorder=do_reorder,
        )
        save_gen_results('../results/gen/', model_name, split_name, split_results, reorder=do_reorder)
        eval_results[(split_name, do_reorder)] = split_results

### Temperature sweep: find the best sampling temperature

In [None]:
temperatures = [0.5, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75, 3.0]
# Sweep score for a temperature = sum of the median best Tanimoto similarity at top-100, top-10 and top-1.
coeff_keys = ['top_100_median', 'top_10_median', 'top_1_median']
N_SWEEP_SAMPLES = 250  # validation spectra scored at each temperature
results = {}

# The validation subset does not depend on the temperature, so build it once.
# Sample WITHOUT replacement: np.random.choice defaults to replace=True, which puts
# duplicate spectra in the subset and biases the medians.
full_valid_set = MSDataset(
    f"{path}/val/no_casmi_val.zarr",
    mode="gen",
    smiles_path=f"{path}/smiles/no_casmi_val_smiles.csv",
)
np.random.seed(40)
subset_idx = np.random.choice(
    len(full_valid_set), size=min(N_SWEEP_SAMPLES, len(full_valid_set)), replace=False
)
valid_set = torch.utils.data.Subset(full_valid_set, subset_idx)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=1, num_workers=4)

for t in temperatures:
    # evaluate_data uses the notebook-global `model`, so rebind it for each temperature.
    model = MS2Gen.load_from_checkpoint(f'./trained_models/{model_name}.ckpt')
    model.cuda()
    model.eval()  # same inference mode as the main evaluation cell
    model.temperature = t
    model.num_sequence = 100  # evaluate_data groups predictions in rows of 100 by default

    valid_results = evaluate_data(valid_loader, adducts_mass_adjustments, top_k=[1, 10, 100])
    results[t] = sum(valid_results[0][key] for key in coeff_keys)
    

In [6]:
# Save the sweep as a two-column table (temperature, score); create the output directory if needed.
sweep_dir = os.path.join('../results/gen', model_name)
os.makedirs(sweep_dir, exist_ok=True)
temperature_sweep_results = pd.DataFrame.from_dict(results, orient='index', columns=['score'])
temperature_sweep_results.index.name = 'temperature'
temperature_sweep_results.to_csv(os.path.join(sweep_dir, 'temperature_sweep_results.csv'))