In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
## Load previous results of tanimoto similarities and simba results
## based on the tanimoto similarities, retrieve the spectra that corresponds to the best candidate for each query
## find if the best candidate is in the 10 first matches

In [3]:
import os
os.chdir('/Users/sebas/projects/metabolomics')
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

In [4]:
import dill
from src.preprocessor import Preprocessor
from src.load_data import LoadData
from src.config import Config
from rdkit import Chem
import numpy as np

## Parameters

In [5]:
data_folder= '/Users/sebas/projects/data/'
dataset_path= data_folder +'edit_distance_neurips_nist_exhaustive.pkl'
output_janssen_file= data_folder + 'all_spectrums_janssen.pkl'
casmi_file= data_folder + 'processed_massformer/spec_df.pkl' 
USE_CASMI=True
USE_MCES_RAW=False

In [6]:
similarities_files= [
#'simba_EDIT_DISTANCE_MCES20_NEURIPS_20241115',
#'simba_EDIT_DISTANCE_MCES20_NEURIPS_precursor_randomized',
    'modified_cosine_casmi_NEURIPS',
    'ms2deepscore_v2_casmi_NEURIPS',
    #'simba_generated_data_20250126',
    'simba_generated_data_20250130',
'spec2vec_casmi_NEURIPS',
                    ]

#model_results_path='simba_EDIT_DISTANCE_MCES20_NEURIPS'
#model_results_path='spec2vec_casmi_NEURIPS'
#model_results_path='ms2deepscore_casmi_NEURIPS'

In [7]:
tanimoto_results_path = 'tanimoto_similarity_casmi_NEURIPS'

In [8]:
format_file_unknown = "./notebooks/discovery_search/results/{}_results_analog_discovery_unknwon_compounds.pkl"

In [9]:
config=Config()

## Load reference dataset

In [10]:
pp=Preprocessor()


In [11]:
with open(dataset_path, 'rb') as file:
            dataset = dill.load(file)

In [12]:
all_spectrums_reference_original= dataset['molecule_pairs_train'].spectrums_original +\
                    dataset['molecule_pairs_val'].spectrums_original + \
                        dataset['molecule_pairs_test'].spectrums_original

In [13]:
import copy
all_spectrums_reference_processed= [copy.deepcopy(s) for s in all_spectrums_reference_original]

In [14]:
all_spectrums_reference_processed = [pp.preprocess_spectrum(
            s,
            fragment_tol_mass=10,
            fragment_tol_mode="ppm",
            min_intensity=0.01,
            max_num_peaks=100,
            scale_intensity=None,
        ) for s in all_spectrums_reference_processed]
    
all_spectrums_reference = [s_original for s_original, s_processed in zip(all_spectrums_reference_original,all_spectrums_reference_processed) if len(s_processed.mz)>=6]

## Load query spectra

In [15]:
if USE_CASMI:
    all_spectrums_janssen_original=LoadData.get_all_spectrums_casmi(
        casmi_file,
        config=config,
    )
else:
    loader_saver = LoaderSaver(
            block_size=100,
            pickle_nist_path='',
            pickle_gnps_path='',
            pickle_janssen_path=output_janssen_file,
        )
    all_spectrums_janssen_original = loader_saver.get_all_spectrums(
                janssen_path,
                100000000,
                use_tqdm=True,
                use_nist=False,
                config=config,
                use_janssen=True,
            )

In [16]:
all_spectrums_janssen_processed= [copy.deepcopy(s) for s in all_spectrums_janssen_original]

In [17]:
pp=Preprocessor()
### remove extra peaks in janssen
all_spectrums_janssen_processed = [pp.preprocess_spectrum(
            s,
            fragment_tol_mass=10,
            fragment_tol_mode="ppm",
            min_intensity=0.01,
            max_num_peaks=100,
            scale_intensity=None,
        ) for s in all_spectrums_janssen_processed]

all_spectrums_janssen = [s_original for s_original, s_processed in zip(all_spectrums_janssen_original,all_spectrums_janssen_processed) if len(s_processed.mz)>=6]

## Get only the spectra that is not present

In [18]:
canon_smiles_reference = [Chem.CanonSmiles(s.smiles) for s in all_spectrums_reference]
canon_smiles_janssen =   [Chem.CanonSmiles(s.smiles) for s in all_spectrums_janssen]
janssen_indexes_in_ref= [i for i,s in enumerate(canon_smiles_janssen) if s in canon_smiles_reference]
janssen_indexes_not_in_ref = [i for i,s in enumerate(canon_smiles_janssen) if s not in canon_smiles_reference]

In [19]:
len(janssen_indexes_in_ref),len(janssen_indexes_not_in_ref)

(26, 132)

In [20]:
all_spectrums_janssen = [all_spectrums_janssen[index] for index in janssen_indexes_not_in_ref]

## Load similarity results

## Load results for simba and spec2vec

In [21]:
model_results = {}
for sim in similarities_files:
    with open(format_file_unknown.format(sim), 'rb') as f:
                model_results[sim]= dill.load(f)

In [22]:
with open(format_file_unknown.format(tanimoto_results_path), 'rb') as f:
                tanimoto_results= dill.load(f)

## Get similarities of mod cosine

In [23]:
similarities_modcos=model_results['modified_cosine_casmi_NEURIPS']['similarities']

## Get similarities of simba

In [44]:
similarities_simba= model_results['simba_generated_data_20250130']['similarities']

In [60]:
similarities_simba_integer = np.array((similarities_simba*40), dtype=np.int32)

In [24]:
# load the mces
similarities1=model_results['simba_generated_data_20250130']['similarities1']
similarities2=model_results['simba_generated_data_20250130']['similarities2']

In [25]:
similarities2_norm = np.argmax(similarities2, axis=2)/5

In [26]:
similarities2_integer = similarities2_norm*5

In [27]:
similarities1_integer = np.array((similarities1*40), dtype=np.int32)

## Binarize modcos

In [28]:
similarities_modcos_bin =similarities_modcos.copy()
condition= (similarities_modcos_bin>0.5)
similarities_modcos_bin[condition]=1
similarities_modcos_bin[~condition]=0

In [29]:
plt.hist(similarities_modcos[0])

NameError: name 'plt' is not defined

In [32]:
similarities1

array([[0.12229509, 0.13228802, 0.1703687 , ..., 0.00555182, 0.00893946,
        0.00569157],
       [0.36447302, 0.35750628, 0.39829499, ..., 0.09419428, 0.10468589,
        0.20606658],
       [0.22454047, 0.22922032, 0.27092642, ..., 0.01546701, 0.02174977,
        0.02201709],
       ...,
       [0.56217504, 0.58433819, 0.59804761, ..., 0.43300492, 0.43076649,
        0.46994081],
       [0.75193834, 0.69264817, 0.71157503, ..., 0.55995375, 0.54186875,
        0.52859634],
       [0.55429602, 0.52895939, 0.5171144 , ..., 0.40146884, 0.39670485,
        0.38093945]])

In [33]:
import numpy as np

def compute_ranking(similarities1, similarities1_integer, similarities2_integer, max_value_2_int=1):
    # Preallocate the ranking array with the same shape as similarities1.
    ranking_total = np.zeros(similarities1.shape, dtype=int)
    
    # Process each row (or each set of values) individually.
    for row_index, (row_sim, row_int, row_int2) in enumerate(zip(similarities1, similarities1_integer, similarities2_integer)):
        # Use lexsort with a composite key:
        #   - Primary: similarities1_integer (ascending)
        #   - Secondary: similarities2_integer (ascending)
        #   - Tertiary: similarities1 (descending, so use -row_sim)
        #
        # Note: np.lexsort uses the last key as the primary key.
        sorted_indices = np.lexsort( ( row_sim, row_int2, row_int ) )
        
        # Now assign ranking values based on sorted order.
        # Here the best (first in sorted_indices) gets rank 0,
        # the next gets rank 1, etc.
        ranking = np.empty_like(sorted_indices)
        ranking[sorted_indices] = np.arange(len(row_sim))
        
        # Store the ranking for this row.
        ranking_total[row_index] = ranking

    #normalizing
    ranking_total =ranking_total/ranking_total.shape[1]
    return ranking_total

In [34]:
import numpy as np

def compute_ranking_3metrics(similarities1, similarities1_integer, similarities2_integer, modcosine, max_value_2_int=1):
    # Preallocate the ranking array with the same shape as similarities1.
    ranking_total = np.zeros(similarities1.shape, dtype=int)
    
    # Process each row (or each set of values) individually.
    for row_index, (row_sim, row_int, row_int2, mod) in enumerate(zip(similarities1, similarities1_integer, similarities2_integer,modcosine)):
        # Use lexsort with a composite key:
        #   - Primary: similarities1_integer (ascending)
        #   - Secondary: similarities2_integer (ascending)
        #   - Tertiary: similarities1 (descending, so use -row_sim)
        #
        # Note: np.lexsort uses the last key as the primary key.
        sorted_indices = np.lexsort( ( row_sim, mod, row_int2, row_int ) )
        
        # Now assign ranking values based on sorted order.
        # Here the best (first in sorted_indices) gets rank 0,
        # the next gets rank 1, etc.
        ranking = np.empty_like(sorted_indices)
        ranking[sorted_indices] = np.arange(len(row_sim))
        
        # Store the ranking for this row.
        ranking_total[row_index] = ranking

    #normalizing
    ranking_total =ranking_total/ranking_total.shape[1]
    return ranking_total

In [36]:
similarities1_integer=np.around(similarities1*40).astype(int)

In [61]:
#ranking_total = compute_ranking_3metrics(similarities_simba, similarities1_integer, similarities2_integer, similarities_modcos_bin)
ranking_total = compute_ranking(similarities_simba, similarities_simba_integer, similarities_modcos_bin, )


In [62]:
ranking_total

array([[0.45200347, 0.4721797 , 0.52551628, ..., 0.02141498, 0.03519959,
        0.13493936],
       [0.78189953, 0.76243247, 0.86788818, ..., 0.04501225, 0.07053614,
        0.27617598],
       [0.48597039, 0.49756053, 0.5672688 , ..., 0.02798314, 0.03985634,
        0.04034941],
       ...,
       [0.81779002, 0.85784115, 0.87632512, ..., 0.26188614, 0.25373834,
        0.46877692],
       [0.96445648, 0.90631097, 0.93250445, ..., 0.55088494, 0.47194229,
        0.42426078],
       [0.85792942, 0.81674301, 0.78373484, ..., 0.21545251, 0.18666281,
        0.14226537]])

In [63]:
results ={ 'smiles_janssen' : model_results['simba_generated_data_20250130']['smiles_janssen'], 
        'smiles_reference' : model_results['simba_generated_data_20250130']['smiles_reference'], 
        'mces_retrieved':None,
          'similarities':ranking_total,
         }
with open('./notebooks/discovery_search/results/'+ 'simba_modcos' + '_results_analog_discovery_unknwon_compounds.pkl', 'wb') as f:
    dill.dump(results, f)

In [64]:
np.max(ranking_total[0,:])

0.9999969563695577

In [65]:
np.argmax(ranking_total[0,:])

84365

In [66]:
np.argmin(ranking_total[0,:])

920

In [67]:
similarities_simba[0,920]

0.0

In [68]:
np.argmax(similarities_simba[0,:])

84365

In [69]:
similarities_simba[0,269444]

0.9991660452587847

In [70]:
similarities_simba[0,84365]

0.9999969563695577

In [71]:
similarities_modcos_bin[0,269444]

0.0

In [72]:
similarities_modcos_bin[0,84365]

1.0

In [73]:
similarities1_integer[0,84365]

40

In [74]:
similarities1_integer[0,269444]

40

In [75]:
dsds

NameError: name 'dsds' is not defined