In [1]:
import pandas as pd
import re
from tqdm import tqdm
import numpy as np
import itertools
from matplotlib import pyplot as plt

from ms_entropy.file_io.msp_file import read_one_spectrum
import numba as nb
import pyteomics.mgf
import sys
import importlib
import re

sys.path.append("../../src/ms_similarity_metrics/")
from create_spectrum import smile2inchi
importlib.reload(sys.modules['create_spectrum'])
from create_spectrum import smile2inchi


In [29]:
QUERY_DIR = '../../data/spec_entropy_queries/'
QUERY_PATH = 's3://enveda-data-user/chloe.engler/cosine_similarity/spec_entropy_queries/'
DATA_DIR = '../../data/'

FILE = 'weighted_filtered_10_ppm'

# Get queries

In [30]:
# List of metrics that were used in the queries
metric_names = ['bhattacharya_2', 'manhattan', 'entropy']

# Read in the queries
# queries = pd.read_csv(f'{QUERY_PATH}{UNWEIGHTED_FILE}.csv', index_col=0)
# weighted_queries = pd.read_csv(f'{QUERY_PATH}{WEIGHTED_FILE}.csv', index_col=0)
queries = pd.read_csv(f'{QUERY_DIR}{FILE}.csv', index_col=0)

In [31]:
# Reformat the library_spectra_matches column for the unweighted queries
new_columns = {name:{} for name in metric_names}
non_decimal = re.compile(r'[^\d.]+')
for query in tqdm(queries.index.values):
    metric_scores = queries.loc[query,'library_spectra_matches'].split('{')[1].split('}')[0].split(']')[:-1]
    for i in range(len(metric_scores)):
        matches = []
        current_query = metric_scores[i]
        col_name = current_query.split(':')[0].replace("'",'').replace(',','').strip()
        current_query = current_query.split('[')[1]
        current_query = current_query.replace("'", "").replace(')', '').split('(')[1:]
        for pair in current_query:
            matches.append((pair.split(',')[0], float(non_decimal.sub('', pair.split(',')[1]))))
        new_columns[col_name][query] = matches

# Create new column for each metric
for name in new_columns.keys():
    queries[name] = queries.index.map(new_columns[name])


100%|██████████| 25437/25437 [00:03<00:00, 8148.34it/s] 


# Get NIST23 data

In [5]:
# Get NIST23 library
spectra_list = []
for i,spectrum in tqdm(enumerate(read_one_spectrum(f'{DATA_DIR}NIST23-HR-MSMS.MSP'))):
    spectra_list.append(spectrum)

1934658it [03:07, 10302.96it/s]


In [6]:
# Create NIST dataframe
nist_df = pd.DataFrame(spectra_list)
nist_df = nist_df[nist_df['precursor_type'] == '[M+H]+' ]

# Get Wout data

In [7]:
# Profile spectra contain 0 intensity values.
@nb.njit
def is_centroid(intensity_array):
    return np.all(intensity_array > 0)

In [10]:
# Read all spectra from the MGF.
spectra = []

# Download from https://zenodo.org/record/6829249/files/ALL_GNPS_NO_PROPOGATED.mgf?download=1
filename = (f"{DATA_DIR}ALL_GNPS_NO_PROPOGATED.mgf")

# Get wout spectra
with pyteomics.mgf.MGF(filename) as f_in:
    for spectrum_dict in tqdm(f_in):
        spectra.append(spectrum_dict)

495600it [01:59, 4140.34it/s] 


In [11]:
# Create wout dataframe
wout_df = pd.DataFrame(spectra)
wout_df = pd.concat([wout_df.drop(['params'], axis=1), wout_df['params'].apply(pd.Series)], axis=1)
wout_df.head(2)

Unnamed: 0,m/z array,intensity array,charge array,pepmass,charge,mslevel,source_instrument,filename,seq,ionmode,...,pi,datacollector,smiles,inchi,inchiaux,pubmed,submituser,libraryquality,spectrumid,scans
0,"[289.286377, 295.545288, 298.489624, 317.32495...","[8068.0, 22507.0, 3925.0, 18742.0, 8604.0, 804...","[--, --, --, --, --, --, --, --, --, --, --, -...","(981.54, None)",[0+],2,LC-ESI-qTof,130618_Ger_Jenia_WT-3-Des-MCLR_MH981.4-qb.1.1....,*..*,Positive,...,Gerwick,Jenia,CC(C)CC1NC(=O)C(C)NC(=O)C(=C)N(C)C(=O)CCC(NC(=...,,,,mwang87,1,CCMSLIB00000001547,1
1,"[278.049927, 278.957642, 281.258667, 291.99609...","[35793.0, 47593.0, 95495.0, 115278.0, 91752.0,...","[--, --, --, --, --, --, --, --, --, --, --, -...","(940.25, None)",[0+],2,LC-ESI-qTof,20111105_Anada_Ger_HoiamideB_MH940_qb.1.1..mgf,*..*,Positive,...,Gerwick,Amanda,CCC[C@@H](C)[C@@H]([C@H](C)[C@@H]1[C@H]([C@H](...,InChI=1S/C45H73N5O10S3/c1-14-17-24(6)34(52)26(...,,,mwang87,1,CCMSLIB00000001548,1


In [12]:
# Get wout metadata
metadata = pd.read_csv(
    'https://zenodo.org/record/6829249/files/gnps_libraries_metadata.csv?download=1'
)
metadata.set_index('id', inplace=True)

In [13]:
wout_df = wout_df.set_index('spectrumid')

# Get library for each query spectra

In [32]:
# Remove wout spectra from NIST
queries['wout_library'] = list(metadata.loc[queries['wout_identifier'],'library'])
queries = queries[queries['wout_library'] != 'GNPS-NIST14-MATCHES']

# Get smiles

In [33]:
# Get nist smiles dict
nist_smiles_dict = {}
for index in tqdm(nist_df.index.values):
    nist_smiles_dict[index] = nist_df.loc[index,'smiles']

# Get wout smiles dict
wout_smiles_dict = {}
for index in tqdm(wout_df.index.values):
    wout_smiles_dict[index] = wout_df.loc[index,'smiles']

100%|██████████| 567631/567631 [00:04<00:00, 133887.31it/s]
100%|██████████| 495600/495600 [00:03<00:00, 133250.52it/s]


In [34]:
# Get wout smiles
for i in tqdm(queries.index.values):
    wout_id = queries.loc[i, 'wout_identifier']
    queries.loc[i, 'wout_smiles'] = wout_smiles_dict[wout_id]

100%|██████████| 25437/25437 [00:01<00:00, 13682.57it/s]


# Get inchikeys

In [35]:
# Get NIST23 partial inchikeys for unweighted queries
inchi_dict = {name:{} for name in metric_names}
for i in tqdm(queries.index.values):
    for name in metric_names:
        inchi_list = []
        for pair in queries.loc[i, name]:
            index = int(pair[0].split('_')[0])
            inchi_list.append(nist_df.loc[int(pair[0].split('_')[0]),'inchikey'][:14])
        inchi_dict[name][i] = inchi_list
for name in metric_names:
    queries[f'{name}_inchis'] = queries.index.map(inchi_dict[name])

  0%|          | 0/25437 [00:00<?, ?it/s]

100%|██████████| 25437/25437 [00:10<00:00, 2377.98it/s]


In [36]:
# Get wout partial inchikeys for queries
for i in tqdm(queries.index.values):
    inchi = smile2inchi(queries.loc[i, 'wout_smiles'])
    queries.loc[i, 'wout_inchi'] = inchi[:14]

100%|██████████| 25437/25437 [00:18<00:00, 1342.32it/s]


# Look at exact matches

In [37]:
# Get indexes of exact matches for queries
all_matches = {name:{} for name in metric_names}
for i in tqdm(queries.index.values):
    for name in metric_names:
        exact_matches = np.where(np.array(list(queries.loc[i, f'{name}_inchis'])) == queries.loc[i, 'wout_inchi'])[0]
        all_matches[name][i] = exact_matches
for name in metric_names:
    queries[f'{name}_exact_matches'] = queries.index.map(all_matches[name])


  exact_matches = np.where(np.array(list(queries.loc[i, f'{name}_inchis'])) == queries.loc[i, 'wout_inchi'])[0]
100%|██████████| 25437/25437 [00:01<00:00, 15679.33it/s]


In [38]:
# Get percent of exact matches for queries
nist_df['partial_inchikey'] = nist_df['inchikey'].str[:14]
for i in tqdm(queries.index.values):
    queries.loc[i, 'num_inchi_matches'] = len(nist_df[nist_df['partial_inchikey'] == queries.loc[i, 'wout_inchi']])
    for name in metric_names:
        queries.loc[i, f'{name}_percent_exact'] = len(queries.loc[i, f'{name}_exact_matches'])/queries.loc[i, 'num_inchi_matches']

100%|██████████| 25437/25437 [12:55<00:00, 32.79it/s]


# Look at top n matches

In [39]:
from rdkit import Chem, DataStructs
import functools

@functools.lru_cache
def _smiles_to_mol(smiles):
    try:
        return Chem.MolFromSmiles(smiles)
    except:
        return None
@functools.lru_cache
def tanimoto(smiles1, smiles2):
    mol1, mol2 = _smiles_to_mol(smiles1), _smiles_to_mol(smiles2)
    if mol1 is None or mol2 is None:
        return np.nan
    fp1, fp2 = Chem.RDKFingerprint(mol1), Chem.RDKFingerprint(mol2)
    return DataStructs.TanimotoSimilarity(fp1, fp2)

In [40]:
# Get NIST23 smiles for queries
all_smiles_dict = {name: {} for name in metric_names}
for i in tqdm(queries.index.values):
    for name in metric_names:
        current_matches = queries.loc[i, name]
        if len(current_matches) == 0:
            all_smiles_dict[name][i] = []
        else:
            nist_indexes = [int(x.split('_')[0]) for x in np.array(queries.loc[i, name])[:,0]]
            smiles_list = nist_df.loc[nist_indexes, 'smiles'].values
            all_smiles_dict[name][i] = smiles_list

for name in metric_names:
    queries[f'{name}_smiles'] = queries.index.map(all_smiles_dict[name])

  0%|          | 0/25437 [00:00<?, ?it/s]

100%|██████████| 25437/25437 [00:20<00:00, 1221.85it/s]


In [41]:
# Get dictionary of tanimoto scores
all_tanimotos = {name:{} for name in metric_names}

# Get tanimoto scores for queries
for i in tqdm(queries.index.values):
    for name in metric_names:
        tanimotos = []
        query_smiles = queries.loc[i, 'wout_smiles']
        if len(queries.loc[i, f'{name}_smiles']) != 0:
            for library_smiles in queries.loc[i, f'{name}_smiles']:
                tanimotos.append(tanimoto(query_smiles, library_smiles))
            all_tanimotos[name][i] = [x for x in tanimotos if not pd.isna(x)]
        else:
            all_tanimotos[name][i] = []
for name in metric_names:
    queries[f'{name}_tanimoto'] = queries.index.map(all_tanimotos[name])

100%|██████████| 25437/25437 [02:19<00:00, 182.93it/s] 


# Look at AUC scores

In [42]:
from sklearn.metrics import roc_curve, auc

no_matches = 0

# Get AUC scores for queries
for index in tqdm(queries.index.values):
    for name in metric_names:
        if len(np.array(list(queries.loc[index,name]))) != 0:
            prob = np.array(list(queries.loc[index,name]))[:,1].astype('float')
            y_true = np.zeros(len(prob))
            y_true[queries.loc[index,f'{name}_exact_matches']] = 1

            # check if there arent any 1.0s in y_true
            if np.sum(y_true) == 0:
                queries.loc[index, f'{name}_auc'] = 0
            # check if all values are 1.0
            elif np.sum(y_true) == len(y_true):
                queries.loc[index, f'{name}_auc'] = 1
            else:
                fpr, tpr, thresholds = roc_curve(y_true, prob)
                queries.loc[index, f'{name}_auc'] = auc(fpr, tpr)
        else:
            no_matches += 1
            queries.loc[index, f'{name}_auc'] = np.nan

100%|██████████| 25437/25437 [00:37<00:00, 674.18it/s] 


In [43]:
queries.to_csv(f'{QUERY_PATH}{FILE}_with_stats.csv')
queries.to_csv(f'{QUERY_DIR}{FILE}_with_stats.csv')