# BLINK Speed Benchmarking

In [None]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import sys
sys.path.insert(0, '../')

import blink

import time
import pickle
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

import matchms as mms
from matchms.similarity import CosineGreedy, ModifiedCosine

from ms_entropy import FlashEntropySearch

# Load Test Data

In [None]:
def verify_mz_order(spec):
    """
    Report whether the m/z row of a spectrum is in non-decreasing order.

    ``spec[0]`` (the m/z values) is compared with itself shifted by one
    position; the spectrum counts as sorted when every value is at least
    as large as its predecessor. Returns a numpy boolean.
    """
    mz = spec[0]
    return np.all(mz[1:] >= mz[:-1])

def create_mms_spectra(row):
    """
    Build a matchms ``Spectrum`` from one library row.

    The row's m/z values (``row['spectrum'][0]``) and intensities
    (``row['spectrum'][1]``) are cast to float arrays; the name, precursor
    m/z, InChI, SMILES and spectrum-ID columns are attached as metadata.
    """
    metadata_keys = ['name', 'precursor_mz','inchi', 'smiles', 'spectrumid']
    spectrum_metadata = row[metadata_keys].to_dict()

    peaks = row['spectrum']
    mz_values = np.array(peaks[0], dtype="float")
    intensity_values = np.array(peaks[1], dtype="float")

    return mms.Spectrum(mz=mz_values, intensities=intensity_values, metadata=spectrum_metadata)

def generate_sample_spectra(query_size, ref_size, msms_library, random_state=None):
    """
    Draw a random query subset and a random reference subset from a library.

    Parameters
    ----------
    query_size : int
        Number of rows drawn for the query set.
    ref_size : int
        Number of rows drawn for the reference set.
    msms_library : pandas.DataFrame
        Library to sample from. Each draw is without replacement (pandas'
        ``DataFrame.sample`` default); the two draws are independent and may
        share rows.
    random_state : int, numpy.random.RandomState, numpy.random.Generator or None
        Makes the draws reproducible. ``None`` (default) keeps the original
        behaviour of sampling from numpy's global RNG, so an upstream
        ``np.random.seed`` still applies. An integer seed is turned into a
        single ``RandomState`` shared by both draws, so the query and
        reference sets are not identical copies of each other.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(query_sample, ref_sample)``.
    """
    rng = random_state
    if isinstance(random_state, (int, np.integer)):
        # One generator for both draws: the same int seed for each draw would
        # return identical leading rows for the query and reference sets.
        rng = np.random.RandomState(random_state)

    query_sample = msms_library.sample(query_size, random_state=rng)
    ref_sample = msms_library.sample(ref_size, random_state=rng)

    return query_sample, ref_sample

# (duplicate definition of create_mms_spectra removed; it is defined once above)

def remove_noise_ions(s, min_ratio=0.01):
    """
    Remove low-intensity ions relative to the base peak.

    An ion is kept only when its intensity is strictly greater than
    ``min_ratio`` times the base-peak (maximum) intensity. With the default
    ``min_ratio=0.01``, ions at or below 1% of the base peak are dropped.

    Parameters
    ----------
    s : array-like, shape (2, n)
        Row 0 holds m/z values, row 1 the matching intensities.
    min_ratio : float, optional
        Relative intensity threshold (fraction of the base peak).

    Returns
    -------
    numpy.ndarray, shape (2, k)
        The retained ions; an empty spectrum is returned unchanged (as 2 x 0).
    """
    s = np.asarray(s)
    intensities = s[1]
    if intensities.size == 0:
        # ``max()`` of an empty array raises; there is nothing to filter.
        return np.array([s[0], s[1]])

    keep = intensities / intensities.max() > min_ratio
    return np.array([s[0][keep], s[1][keep]])

def filter_spectra(row, decimal=4, precursor_window=14):
    """
    Drop ions close to the precursor and round m/z values.

    Ions whose m/z lies within ``precursor_window`` Da of the row's precursor
    m/z (``|m/z - precursor_mz| <= precursor_window``, on either side of the
    precursor) are removed. The remaining m/z values are rounded to
    ``decimal`` places. Intensities are returned unrounded: rounding them
    could turn very small intensities into zero-intensity ions after
    zero-intensity ions have already been filtered out.

    Parameters
    ----------
    row : mapping / pandas.Series
        Needs ``'spectrum'`` (ndarray, row 0 = m/z, row 1 = intensity) and
        ``'precursor_mz'``.
    decimal : int, optional
        Number of decimal places kept for m/z values.
    precursor_window : float, optional
        Half-width (Da) of the window around the precursor m/z to remove.

    Returns
    -------
    numpy.ndarray
        The filtered spectrum; the input array is not modified.
    """
    spectrum = row['spectrum']
    keep = np.abs(spectrum[0] - row['precursor_mz']) > precursor_window
    # Boolean indexing returns a copy, so rounding in place is safe.
    filtered = spectrum[:, keep]
    filtered[0] = np.round(filtered[0], decimal)

    return filtered

def round_precursor_mz(row, decimal=4):
    """
    Return the row's precursor m/z rounded to ``decimal`` places
    (Python's built-in ``round``).
    """
    precursor = row['precursor_mz']
    return round(precursor, decimal)

def create_entropy_spectra(spec):
    """
    Convert a (m/z row, intensity row) spectrum into an ms_entropy peak list.

    Returns ``[[mz_0, intensity_0], [mz_1, intensity_1], ...]``, pairing
    ``spec[0]`` with ``spec[1]`` position by position.
    """
    mz_values, intensities = spec[0], spec[1]
    return [[mz, intensity] for mz, intensity in zip(mz_values, intensities)]

In [None]:
# Load the full GNPS MS/MS library and clean it before benchmarking.
# NOTE(review): site-specific absolute path; adjust when running elsewhere.
gnps_all = blink.open_msms_file('/global/cfs/cdirs/metatlas/projects/spectral_libraries/ALL_GNPS_20221017.mgf')

#remove spectra whose m/z values are not in non-decreasing order (see verify_mz_order)
#and spectra with precursor m/z <= 60
gnps_all['is_sorted'] = gnps_all.spectrum.apply(verify_mz_order)
gnps_all = gnps_all[gnps_all.is_sorted == True]
gnps_all = gnps_all[gnps_all.precursor_mz > 60]

#remove all zero intensity ions
#(blink private helper; behavior as stated by this comment -- confirm against blink)
gnps_all.spectrum = blink.spectral_normalization._filter_spectra(gnps_all.spectrum)

#remove fragment ions within 14 Da of the precursor m/z (on either side of the
#precursor) and round values to 4 decimal places; see filter_spectra
gnps_all.spectrum = gnps_all.apply(lambda x: filter_spectra(x), axis=1)

#remove empty spectra (no ions left after the filtering above)
gnps_all['spec_size'] = gnps_all.spectrum.apply(lambda x: len(x[1]))
gnps_all = gnps_all[gnps_all.spec_size > 0]

#remove probable profile mode spectra: keep spectra whose median spacing between
#consecutive m/z values is > 0.8 and that have fewer than 800 ions.
#NOTE: a single-ion spectrum has an empty np.diff, whose median is NaN (numpy emits a
#RuntimeWarning); NaN > 0.8 is False, so single-ion spectra are also dropped here.
#reset_index() keeps the pre-filter row labels in a new 'index' column, which the
#benchmark cells below select by name.
gnps_all['median_mz_diff'] = gnps_all.spectrum.apply(lambda x: np.median(np.diff(x[0])))
gnps_all['num_ions'] = gnps_all.spectrum.apply(lambda x: len(x[0]))
gnps_all = gnps_all[(gnps_all['median_mz_diff'] > 0.8) & (gnps_all['num_ions'] < 800)].reset_index()

#remove ions at or below 1% of base peak intensity for higher quality scores
#(see remove_noise_ions)
gnps_all.spectrum = gnps_all.spectrum.apply(remove_noise_ions)

#remove duplicate noise ions
#(blink private helper called with min_diff=0.01; exact semantics live in blink -- confirm)
gnps_all.spectrum = blink.spectral_normalization._remove_duplicate_ions(gnps_all.spectrum, min_diff=0.01)

# Speed Comparison

### Define Speed Benchmarking Parameters

In [None]:
# Replicate number; it is embedded in the output file names of this run
# ('..._replicate0{replicate}...').
replicate = 3

# Size schedule for the benchmark loop: query and reference set sizes start at
# the initial values and are each scaled by sqrt(multiplier) per iteration.
iteration_num, multiplier = 7, 10
initial_query_size, initial_ref_size = 10, 10

# MatchMS cosine scorer
cos = CosineGreedy(tolerance=0.0099, intensity_power=0.5)

# BLINK discretization settings
bin_width, tolerance = 0.0001, 0.01

# Flash Entropy searcher and its fragment (MS2) tolerance in Da
entropy_search = FlashEntropySearch()
flash_tolerance = 0.0099

### Compute Comparison

In [None]:
# Benchmark BLINK, MatchMS and Flash Entropy on progressively larger random
# query/reference subsets of `gnps_all`. Both subset sizes are scaled by
# sqrt(multiplier) per iteration, so the number of query x reference
# comparisons grows by roughly `multiplier` per iteration.
#
# Durations are measured with time.perf_counter(): it is monotonic and has the
# highest available resolution, whereas time.time() is wall-clock time that can
# be adjusted (e.g. by NTP) in the middle of a measurement.
multiplier_sqrt = math.sqrt(multiplier)

# metric name -> {row index -> value}; turned into a DataFrame in the next section
speed_test_results = {'query_spectra_num':{}, 'ref_spectra_num':{}, 'blink_time':{}, 'mms_time':{}, 'flash_time':{},
                      'blink_setup_time':{}, 'mms_setup_time':{}, 'flash_setup_time':{}, 'replicate':{}}

query_size = initial_query_size
ref_size = initial_ref_size

for iteration in range(1, iteration_num + 1):
    index = iteration - 1  # row key in speed_test_results

    print("iteration {iteration} of {iteration_num} start".format(iteration=iteration, iteration_num=iteration_num))

    query_sample, ref_sample = generate_sample_spectra(query_size, ref_size, gnps_all)

    # BLINK inputs: plain lists of spectra and precursor m/z values
    query_spectra = query_sample.spectrum.tolist()
    ref_spectra = ref_sample.spectrum.tolist()
    query_precursor_mzs = query_sample.precursor_mz.tolist()
    ref_precursor_mzs = ref_sample.precursor_mz.tolist()

    # MatchMS inputs (built before timing starts, so not included in mms_time)
    MMS1 = query_sample.apply(create_mms_spectra, axis=1)
    MMS2 = ref_sample.apply(create_mms_spectra, axis=1)

    # Flash Entropy inputs: list of dicts with 'precursor_mz' and 'peaks'
    # ([[mz, intensity], ...]); this conversion is also outside the timed region
    cols = ['index', 'precursor_mz', 'spectrum', 'scans']
    e_small = query_sample[cols].copy()
    e_medium = ref_sample[cols].copy()
    e_small['spectrum'] = e_small['spectrum'].apply(lambda x: [list(peak) for peak in x.T])
    e_medium['spectrum'] = e_medium['spectrum'].apply(lambda x: [list(peak) for peak in x.T])
    e_small.rename(columns={'spectrum': 'peaks'}, inplace=True)
    e_medium.rename(columns={'spectrum': 'peaks', 'index': 'ref'}, inplace=True)
    e_small = e_small.to_dict('records')
    e_medium = e_medium.to_dict('records')

    # --- BLINK: discretization (setup) followed by sparse scoring ---
    t0 = time.perf_counter()
    S1 = blink.discretize_spectra(query_spectra, ref_spectra, query_precursor_mzs, ref_precursor_mzs,
                                  intensity_power=0.5, bin_width=bin_width, tolerance=tolerance)
    blink_setup_time = time.perf_counter() - t0

    t0 = time.perf_counter()
    S12 = blink.score_sparse_spectra(S1)
    blink_time = time.perf_counter() - t0

    # --- MatchMS: all-vs-all CosineGreedy; no separate setup phase is timed ---
    # NOTE(review): the query-derived MMS1 is passed as `references` and MMS2 as
    # `queries`. The agreement-matrix section compares the flattened BLINK and
    # MatchMS score arrays element-wise, so this orientation must match BLINK's
    # output -- confirm before swapping.
    t0 = time.perf_counter()
    MMS12 = cos.matrix(references=MMS1, queries=MMS2)
    mms_time = time.perf_counter() - t0
    mms_setup_time = 0

    # --- Flash Entropy: index the references (setup), then one open search per query ---
    t0 = time.perf_counter()
    entropy_search.build_index(e_medium)
    flash_setup_time = time.perf_counter() - t0

    t0 = time.perf_counter()
    for s in e_small:
        entropy_similarity = entropy_search.search(precursor_mz=s['precursor_mz'], peaks=s['peaks'], ms2_tolerance_in_da=flash_tolerance, method='open')
    flash_time = time.perf_counter() - t0

    speed_test_results['query_spectra_num'][index] = query_size
    speed_test_results['ref_spectra_num'][index] = ref_size
    speed_test_results['replicate'][index] = replicate
    speed_test_results['blink_time'][index] = blink_time
    speed_test_results['mms_time'][index] = mms_time
    speed_test_results['flash_time'][index] = flash_time
    speed_test_results['blink_setup_time'][index] = blink_setup_time
    speed_test_results['mms_setup_time'][index] = mms_setup_time
    speed_test_results['flash_setup_time'][index] = flash_setup_time

    query_size = round(query_size * multiplier_sqrt)
    ref_size = round(ref_size * multiplier_sqrt)

    print("iteration {iteration} of {iteration_num} end".format(iteration=iteration, iteration_num=iteration_num))

    # Keep the score matrices of the last (largest) iteration for the agreement analysis
    if iteration == iteration_num:
        with open('blink_scores_replicate0{num}.pickle'.format(num=replicate), 'wb') as output_file:
            pickle.dump(S12, output_file, protocol=pickle.HIGHEST_PROTOCOL)

        with open('mms_scores_replicate0{num}.pickle'.format(num=replicate), 'wb') as output_file:
            pickle.dump(MMS12, output_file, protocol=pickle.HIGHEST_PROTOCOL)

### Save & Plot Results

In [None]:
# One row per benchmark iteration, one column per recorded metric.
df = pd.DataFrame(speed_test_results)
df['comparisons'] = df['query_spectra_num'].mul(df['ref_spectra_num'])

In [None]:
# Total = setup + scoring. MatchMS records no separate setup (mms_setup_time is 0),
# so its total is just its scoring time.
df = df.assign(
    blink_total_time=df['blink_time'] + df['blink_setup_time'],
    flash_total_time=df['flash_time'] + df['flash_setup_time'],
    mms_total_time=df['mms_time'],
)

In [None]:
# Quick look: mean scoring time (setup excluded) per problem size, log-log axes.
timing_cols = ['blink_time', 'mms_time', 'flash_time']
plot_df = df.groupby('comparisons')[timing_cols].mean().reset_index()

for column, label in zip(timing_cols, ['BLINK', 'MatchMS', 'Flash Entropy']):
    plt.plot(plot_df['comparisons'], plot_df[column], '-o', label=label)

plt.ylabel('Compute Time (seconds)')
plt.xlabel('# Comparisons')
plt.yscale('log')
plt.xscale('log')
plt.legend()
plt.show()

In [None]:
# Median total time (setup + scoring) per problem size, log-log axes.
total_cols = ['blink_total_time', 'mms_total_time', 'flash_total_time']
plot_df = df.groupby('comparisons')[total_cols].median().reset_index()

for column, label in zip(total_cols, ['BLINK', 'MatchMS', 'Flash Entropy']):
    plt.plot(plot_df['comparisons'], plot_df[column], '-o', label=label)

plt.ylabel('Compute Time (seconds)')
plt.xlabel('# Comparisons')
plt.yscale('log')
plt.xscale('log')
plt.legend()
plt.show()

In [None]:
# Persist this replicate's timings so several runs can be merged below.
results_csv = 'cos_speed_benchmarking_replicate0{num}.csv'.format(num=replicate)
df.to_csv(results_csv)

# Merge Replicates & Generate Final Speed Plot

In [None]:
# Timing tables saved by earlier runs (presumably with replicate = 1, 2, 3).
replicate_csvs = ('cos_speed_benchmarking_replicate01.csv',
                  'cos_speed_benchmarking_replicate02.csv',
                  'cos_speed_benchmarking_replicate03.csv')
rep1, rep2, rep3 = (pd.read_csv(path, index_col=0) for path in replicate_csvs)

final_df = pd.concat([rep1, rep2, rep3])

In [None]:
# Keep only problem sizes with at least 1500 comparisons; with the default loop
# parameters this drops the two smallest iterations (100 and 1024 comparisons).
final_df = final_df.loc[final_df['comparisons'].ge(1500)]

In [None]:
# Final figure: median scoring time (setup excluded) across replicates, log-log axes.
# The duplicated tick_params/legend calls were dropped; only the last legend call
# ever took effect.
plot_df = final_df.groupby('comparisons')[['blink_time', 'mms_time', 'flash_time']].median()
plot_df.reset_index(inplace=True)

fig, ax = plt.subplots(figsize=(14,10))
ax.plot(plot_df['comparisons'].values, plot_df['blink_time'].values, '-o', label='BLINK', linewidth=3, markersize=15)
ax.plot(plot_df['comparisons'].values, plot_df['mms_time'].values, '-o', label='MatchMS', linewidth=3, markersize=15)
ax.plot(plot_df['comparisons'].values, plot_df['flash_time'].values, '-o', label='Flash Entropy', linewidth=3, markersize=15)

ax.set_ylabel('Compute Time (seconds)', fontsize=40)
ax.set_xlabel('# Comparisons', fontsize=40)
ax.tick_params(axis='both', labelsize=36)
ax.set_yscale('log')
ax.set_xscale('log')
ax.legend(loc=2, prop={'size': 36})
ax.grid()
plt.show()

# fig.savefig('cos_speed_benchmark.pdf', bbox_inches="tight")

In [None]:
# Median setup time across replicates for the two methods with a setup phase
# (BLINK discretization, Flash Entropy indexing), log-log axes.
# The duplicated tick_params/legend calls were dropped; only the last legend call
# ever took effect.
plot_df = final_df.groupby('comparisons')[['blink_setup_time', 'flash_setup_time']].median()
plot_df.reset_index(inplace=True)

fig, ax = plt.subplots(figsize=(14,10))
ax.plot(plot_df['comparisons'].values, plot_df['blink_setup_time'].values, '-o', label='BLINK', linewidth=3, markersize=15)
ax.plot(plot_df['comparisons'].values, plot_df['flash_setup_time'].values, '-o', label='Flash Entropy', linewidth=3, markersize=15, color='g')

ax.set_ylabel('Compute Time (seconds)', fontsize=40)
ax.set_xlabel('# Comparisons', fontsize=40)
ax.tick_params(axis='both', labelsize=36)
ax.set_yscale('log')
ax.set_xscale('log')
ax.legend(loc=2, prop={'size': 36})
ax.grid()
plt.show()

fig.savefig('cos_setup-speed_benchmark.pdf', bbox_inches="tight")

# Load Scores and Generate Final Agreement Matrix

In [None]:
def _load_pickle(path):
    """Unpickle and return the object stored at *path*."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)

# Score matrices written by the benchmark loop at its final iteration.
# NOTE(review): the loop names these files after `replicate` (3 in the parameter
# cell); replicate04 must come from a separate run -- confirm the intended file.
# pickle.load can execute arbitrary code, so only load files this notebook wrote.
S12 = _load_pickle('blink_scores_replicate04.pickle')
MMS12 = _load_pickle('mms_scores_replicate04.pickle')

In [None]:
#filter scores using GNPS default
# GNPS-default thresholds: a pair is "similar" when its score is >= good_score
# and it has at least good_matches matched ions.
good_score, good_matches = 0.7, 6

In [None]:
# Binarize both score matrices with the thresholds above: a pair counts as
# "Similar" when its score >= good_score AND its matched-ion count >= good_matches.
# The matrices are flattened and compared element-wise, so BLINK and MatchMS
# results must share the same (query, reference) orientation.
idx1 = S12['mzi'].toarray().flatten() >= good_score
idx2 = S12['mzc'].toarray().flatten() >= good_matches

idx3 = MMS12['score'].flatten() >= good_score
idx4 = MMS12['matches'].flatten() >= good_matches

blink_ids = np.logical_and(idx1, idx2)
matchms_ids = np.logical_and(idx3, idx4)

# Rows = MatchMS, columns = BLINK. Passing labels keeps the matrix 2x2 even when
# one class never occurs, so the two tick labels below always fit.
cm = confusion_matrix(matchms_ids, blink_ids, labels=[False, True])

# Column-normalize (fraction of each BLINK call), like sklearn's normalize='pred';
# an empty column yields 0 instead of NaN from 0/0.
col_totals = cm.sum(axis=0, keepdims=True).astype(float)
cm_norm = np.divide(cm, col_totals, out=np.zeros(cm.shape, dtype=float), where=col_totals != 0)

# NOTE: this rebinds `df` (previously the speed-benchmark table) to the raw counts.
df = pd.DataFrame(cm)
perc = pd.DataFrame(data=cm_norm * 100)
# Each cell shows "<column percent>%\n<count>"
annot = perc.round(2).astype(str) + "%" + "\n" + df.astype(str)

fig, ax = plt.subplots(figsize=(11.5, 10))
ax = sns.heatmap(cm_norm, annot=annot, fmt='', vmin=0, vmax=1, cmap="Blues",
                 annot_kws={"fontsize":36}, linewidth=1, linecolor='black',
                 xticklabels=['Dissimilar', 'Similar'], yticklabels=['Dissimilar', 'Similar'])

ax.tick_params(labelsize=36)
cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=36)
cbar.set_ticks([0, .2, .75, 1])
cbar.set_ticklabels(['low', '20%', '75%', '100%'])

plt.xlabel('BLINK', fontsize=40)
plt.ylabel('MatchMS', fontsize=40)
plt.show()

# fig.savefig('cos_confusion_matrix.pdf')

# Max Score and Count Differences of True Positives

In [None]:
# True positives: pairs that both BLINK and MatchMS call "Similar".
c = np.logical_and(blink_ids, matchms_ids)

blink_scores = S12['mzi'].toarray().flatten()
blink_counts = S12['mzc'].toarray().flatten()
mms_scores = MMS12['score'].flatten()
mms_counts = MMS12['matches'].flatten()

# Largest BLINK-minus-MatchMS difference in score and matched-ion count over the
# true positives. The messages previously said "mean" although a max is computed,
# and a hard-coded "+" printed "+-x" for negative values; {:+} prints the real sign.
if c.any():
    max_score_diff = np.max(blink_scores[c] - mms_scores[c])
    max_count_diff = np.max(blink_counts[c] - mms_counts[c])
    print("max BLINK score difference: {:+}".format(max_score_diff))
    print("max BLINK count difference: {:+}".format(max_count_diff))
else:
    # np.max of an empty selection would raise
    print("no pairs were called similar by both BLINK and MatchMS")

# Identity Search Speed (1 query vs. ~450 thousand library spectra, precursor-filtered)

In [None]:
#precursor mz tolerance in dalton
pmz_tol = 0.05

search_entries = gnps_all.sample(50)

search_pmzs = search_entries.precursor_mz.tolist()
search_spectra = search_entries.spectrum.tolist()

ref_pmzs = gnps_all.precursor_mz.tolist()

In [None]:
# Identity-search benchmark: for each query, restrict the library to entries whose
# precursor m/z is within pmz_tol Da, then time BLINK, MatchMS and Flash Entropy on
# that candidate set. Durations use time.perf_counter() (monotonic, high resolution).
id_speed_test_results = {'blink_time':[], 'mms_time':[], 'flash_time':[],
                         'blink_setup_time':[], 'mms_setup_time':[], 'flash_setup_time':[], 'replicate':[]}

# Convert once rather than inside np.isclose on every iteration.
ref_pmz_array = np.asarray(ref_pmzs)

for i, pmz in enumerate(search_pmzs):
    print('Start Search: ' + str(i))
    # rtol=0 makes this a pure +/- pmz_tol Da window; np.isclose's default
    # rtol=1e-05 would widen it by 1e-05 * m/z, unlike Flash's absolute
    # ms1_tolerance_in_da used below.
    pmz_filter = np.isclose(pmz, ref_pmz_array, rtol=0, atol=pmz_tol).nonzero()[0]

    search_row = search_entries.iloc[i]
    filtered_ref = gnps_all.iloc[pmz_filter]

    MMS1 = create_mms_spectra(search_row)
    MMS2 = filtered_ref.apply(create_mms_spectra, axis=1)

    # ms_entropy inputs: dicts with 'precursor_mz' and 'peaks' ([[mz, intensity], ...]).
    # The query dict is built directly instead of assigning a list into a copied row.
    e_small = {
        'index': search_row['index'],
        'precursor_mz': search_row['precursor_mz'],
        'peaks': [list(peak) for peak in search_row['spectrum'].T],
        'scans': search_row['scans'],
    }
    cols = ['index','precursor_mz','spectrum','scans']
    e_medium = filtered_ref[cols].copy()
    e_medium['spectrum'] = e_medium['spectrum'].apply(lambda x: [list(peak) for peak in x.T])
    e_medium.rename(columns={'spectrum':'peaks','index':'ref'},inplace=True)
    e_medium = e_medium.to_dict('records')

    # Query spectrum and precursor are both passed as one-element lists, matching
    # the list form used in the main benchmark.
    # NOTE(review): unlike the main benchmark, intensity_power / bin_width are not
    # passed here, so BLINK's own defaults apply -- confirm that is intended.
    t0 = time.perf_counter()
    d = blink.discretize_spectra([search_row.spectrum], filtered_ref.spectrum.tolist(),
                                 [float(search_row.precursor_mz)], filtered_ref.precursor_mz.tolist(),
                                 tolerance=0.01)
    blink_setup_time = time.perf_counter() - t0

    t0 = time.perf_counter()
    blink_scores = blink.score_sparse_spectra(d)
    blink_time = time.perf_counter() - t0

    t0 = time.perf_counter()
    MMS12 = cos.matrix(references=[MMS1], queries=MMS2)
    mms_time = time.perf_counter() - t0
    mms_setup_time = 0

    t0 = time.perf_counter()
    entropy_search.build_index(e_medium)
    flash_setup_time = time.perf_counter() - t0

    t0 = time.perf_counter()
    entropy_similarity = entropy_search.search(precursor_mz=e_small['precursor_mz'], peaks=e_small['peaks'], ms2_tolerance_in_da=0.01, ms1_tolerance_in_da=pmz_tol, method='identity')
    flash_time = time.perf_counter() - t0

    id_speed_test_results['replicate'].append(replicate)
    id_speed_test_results['blink_time'].append(blink_time)
    id_speed_test_results['mms_time'].append(mms_time)
    id_speed_test_results['flash_time'].append(flash_time)
    id_speed_test_results['blink_setup_time'].append(blink_setup_time)
    id_speed_test_results['mms_setup_time'].append(mms_setup_time)
    id_speed_test_results['flash_setup_time'].append(flash_setup_time)

In [None]:
# Box plot of per-query scoring times (setup excluded) for the identity search.
keys = ['blink_time', 'mms_time', 'flash_time']
plot_dict = {key: id_speed_test_results[key] for key in keys}

fig, ax = plt.subplots()
ax.boxplot(plot_dict.values())
ax.set_xticklabels(plot_dict.keys())
ax.set_ylabel('Compute Time (seconds)')
ax.set_yscale('log')
plt.show()

In [None]:
# Per-query total time = setup + scoring (MatchMS records no separate setup).
timings = {name: np.array(values) for name, values in id_speed_test_results.items()}
plot_dict = {
    'blink_total_time': timings['blink_time'] + timings['blink_setup_time'],
    'mms_total_time': timings['mms_time'],
    'flash_total_time': timings['flash_time'] + timings['flash_setup_time'],
}

In [None]:
# Box plot of per-query total times (setup + scoring) for the identity search.
fig, ax = plt.subplots()
ax.boxplot(plot_dict.values())
ax.set_xticklabels(plot_dict.keys())
ax.set_ylabel('Compute Time (seconds)')
ax.set_yscale('log')
plt.show()