In [None]:
import os
import sys
os.environ['GLEAMS_HOME'] = os.path.join(
    os.environ['HOME'], 'Projects', 'gleams')
# Make sure all code is in the PATH.
src_dir = os.path.normpath(os.path.join(os.environ['GLEAMS_HOME'], 'src'))
if src_dir not in sys.path:
    sys.path.append(src_dir)

In [None]:
import collections
import itertools

import joblib
import matplotlib.pyplot as plt
import numba as nb
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import auc, roc_curve

In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
# Initialize logging through the GLEAMS logging helper.
from gleams import logger as glogger
glogger.init()
# Set the random seeds before importing the remaining GLEAMS modules below.
# (numpy, numba and the other third-party packages were already imported in an
# earlier cell; only the imports that follow run after seeding.)
# NOTE(review): presumably rndm.set_seeds() seeds numpy / random / the deep
# learning backend — confirm in gleams.rndm.
from gleams import rndm
rndm.set_seeds()

from gleams import config
from gleams.feature import spectrum
from gleams.ms_io import ms_io
from gleams.nn import embedder, data_generator, nn

In [None]:
import logging

# Module-wide handle on the 'gleams' logger, emitting everything from DEBUG up.
logger = logging.getLogger('gleams')
logger.setLevel('DEBUG')

In [None]:
# Plot styling.
# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-<name>'
# and later releases dropped the old 'seaborn-<name>' aliases, so use the new
# name when it exists and fall back to the old one on older matplotlib.
plt.style.use([
    f'seaborn-v0_8-{style_name}'
    if f'seaborn-v0_8-{style_name}' in plt.style.available
    else f'seaborn-{style_name}'
    for style_name in ('white', 'paper')])
plt.rc('font', family='serif')
sns.set_palette('Set1')
sns.set_context('paper', font_scale=1.3)    # Single-column figure.

In [None]:
# Data split to evaluate, and the pair count handed to PairSequence below.
split, num_pairs = 'test', 10_000_000

In [None]:
# All feature files of the selected split share one directory and file prefix.
feature_prefix = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'feature',
    f'feature_{config.massivekb_task_id}_{split}')
pair_generator = data_generator.PairSequence(
    f'{feature_prefix}.npz',
    f'{feature_prefix}_pairs_pos.npy',
    f'{feature_prefix}_pairs_neg.npy',
    config.batch_size, nn._get_feature_split(), num_pairs,
    False)  # NOTE(review): positional flag, presumably "shuffle" — confirm.

In [None]:
metadata_path = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'feature',
    f'feature_{config.massivekb_task_id}_{split}.parquet')
# Only the columns needed to locate every spectrum in its peak file.
pair_metadata = pd.read_parquet(metadata_path,
                                columns=['dataset', 'filename', 'scan'])

In [None]:
def _get_spectra_from_file(dataset, filename, scans):
    """Read and preprocess the requested spectra from a single peak file.

    Parameters
    ----------
    dataset
        Dataset directory name under ``$GLEAMS_HOME/data/peak``.
    filename
        Name of the peak file inside that dataset directory.
    scans
        Scans to read, forwarded unchanged to ``ms_io.get_spectra``.

    Returns
    -------
    dict
        Maps ``"{dataset}/{filename}/{identifier}"`` to the spectrum returned
        by ``spectrum.preprocess``. Empty, with a warning logged, when the
        peak file does not exist.
    """
    peak_path = os.path.join(os.environ['GLEAMS_HOME'], 'data', 'peak',
                             dataset, filename)
    if not os.path.isfile(peak_path):
        logger.warning('Missing peak file %s, no spectra read', filename)
        return {}

    processed_spectra = {}
    for raw_spec in ms_io.get_spectra(peak_path, scans):
        processed = spectrum.preprocess(raw_spec, config.fragment_mz_min,
                                        config.fragment_mz_max)
        spec_key = f'{dataset}/{filename}/{raw_spec.identifier}'
        processed_spectra[spec_key] = processed
    return processed_spectra

In [None]:
dataset_total = pair_metadata['dataset'].nunique()
# Collect every preprocessed spectrum into one flat dict. Keys embed the
# dataset and filename, and each (dataset, filename) group is read exactly
# once, so the per-file dicts never share keys. A single dict gives O(1)
# lookups later, whereas a ChainMap would scan one mapping per file on every
# lookup.
spectra = {}
for dataset_i, (dataset, md_dataset) in enumerate(
        pair_metadata.groupby('dataset', sort=False), 1):
    # Log through the 'gleams' logger (set to DEBUG above) rather than the
    # root logger, which (unless reconfigured elsewhere) defaults to WARNING
    # and would drop these INFO progress messages.
    logger.info('Process dataset %s (%d files) [%3d/%3d]', dataset,
                md_dataset['filename'].nunique(), dataset_i, dataset_total)
    file_results = joblib.Parallel(n_jobs=-1, backend='multiprocessing')(
        joblib.delayed(_get_spectra_from_file)(dataset, filename,
                                               md_file['scan'])
        for filename, md_file in md_dataset.groupby(
            'filename', sort=False))
    for file_spectra in file_results:
        spectra.update(file_spectra)

In [None]:
@nb.njit(parallel=True)
def dot(spectra_arr1, spectra_arr2, out, fragment_mz_tol):
    """Score each spectrum pair with ``spectrum.dot``, in parallel.

    Row ``k`` of ``spectra_arr1`` and of ``spectra_arr2`` holds the two
    spectra of pair ``k``. Index 0 along the second axis is passed as the
    m/z array and index 1 as the intensity array.

    The score of pair ``k`` is written to ``out[k]``, and ``out`` is returned.
    """
    n_spectrum_pairs = spectra_arr1.shape[0]
    for pair_idx in nb.prange(n_spectrum_pairs):
        mz_a, intensity_a = spectra_arr1[pair_idx, 0], spectra_arr1[pair_idx, 1]
        mz_b, intensity_b = spectra_arr2[pair_idx, 0], spectra_arr2[pair_idx, 1]
        out[pair_idx] = spectrum.dot(mz_a, intensity_a, mz_b, intensity_b,
                                     fragment_mz_tol)
    return out

In [None]:
# Lookup key "dataset/filename/scan" of every metadata row, built once with
# vectorized string ops instead of three scalar `.at` lookups per spectrum.
spectrum_keys = (pair_metadata['dataset'].astype(str) + '/'
                 + pair_metadata['filename'].astype(str) + '/'
                 + pair_metadata['scan'].astype(str))


def _padded_peaks(spec):
    """Return the peaks of ``spec`` as a (2, config.max_peaks_used) array.

    Row 0 holds the m/z values and row 1 the intensities. Zeros are
    prepended (left padding) up to ``config.max_peaks_used`` columns.

    Raises
    ------
    ValueError
        If the spectrum has more than ``config.max_peaks_used`` peaks.
        Without this check, ``np.pad`` would fail with an opaque
        negative-padding error.
    """
    num_pad = config.max_peaks_used - len(spec.mz)
    if num_pad < 0:
        raise ValueError(
            f'Spectrum {spec.identifier} has {len(spec.mz)} peaks, more than '
            f'config.max_peaks_used ({config.max_peaks_used})')
    return np.pad([spec.mz, spec.intensity], ((0, 0), (num_pad, 0)),
                  'constant')


# spectra_arr[0] / spectra_arr[1] collect the first / second spectrum of each
# pair: all positive pairs first, followed by all negative pairs.
spectra_arr = [], []
for pair1, pair2 in itertools.chain(pair_generator.pairs_pos,
                                    pair_generator.pairs_neg):
    for arr_i, pair_i in enumerate((pair1, pair2)):
        spec = spectra[spectrum_keys.at[pair_i]]
        spectra_arr[arr_i].append(_padded_peaks(spec))

# Fragment m/z tolerances for the high- and low-resolution dot products.
fragment_mz_tol_high_res, fragment_mz_tol_low_res = 0.05, 0.8
# Ground truth in the same order as the pairs above: 1 = positive pair,
# 0 = negative pair.
labels = np.hstack((np.ones(len(pair_generator.pairs_pos), np.uint8),
                    np.zeros(len(pair_generator.pairs_neg), np.uint8)))
spectra_arr1 = np.asarray(spectra_arr[0])
spectra_arr2 = np.asarray(spectra_arr[1])
dot_high_res = dot(
    spectra_arr1, spectra_arr2, np.zeros(spectra_arr1.shape[0], np.float32),
    fragment_mz_tol_high_res)
dot_low_res = dot(
    spectra_arr1, spectra_arr2, np.zeros(spectra_arr1.shape[0], np.float32),
    fragment_mz_tol_low_res)

In [None]:
# Build the embedder from the configured dimensions and learning rate, then
# load the model (presumably the trained weights at config.model_filename).
emb = embedder.Embedder(config.num_precursor_features,
                        config.num_fragment_features,
                        config.num_ref_spectra,
                        config.lr,
                        config.model_filename)
emb.load()

In [None]:
# Score every batch produced by pair_generator with the siamese network.
# Each item of pair_generator is already one batch, so predict_on_batch is
# used. Keras documents Model.predict as meant for large inputs, not for
# repeated calls inside a loop.
labels_batches, scores_batches = [], []
for batch_i in range(len(pair_generator)):
    batch_x, batch_y = pair_generator[batch_i]
    labels_batches.append(np.asarray(batch_y).reshape(-1))
    scores_batches.append(np.asarray(
        emb.siamese_model_parallel.predict_on_batch(batch_x)).reshape(-1))
labels_embed = np.concatenate(labels_batches)
scores_embed = np.concatenate(scores_batches)
# NOTE(review): the model output appears to be a distance. Rescale it so that
# smaller distances give higher scores (1 - d / max). Guard against an all-zero
# output, which would otherwise divide by zero and turn every score into NaN.
max_score = scores_embed.max()
if max_score != 0:
    scores_embed = 1 - scores_embed / max_score
else:
    scores_embed = 1 - scores_embed

In [None]:
# Cache labels and scores so the figure can be regenerated without rerunning
# the expensive scoring cells.
aucroc_results = [labels, labels_embed, dot_high_res, dot_low_res,
                  scores_embed]
joblib.dump(aucroc_results, 'aucroc_dot.joblib')

In [None]:
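# Uncomment to restore the cached labels and scores from 'aucroc_dot.joblib'
# instead of recomputing them (e.g. to re-plot in a fresh kernel):
# labels, labels_embed, dot_high_res, dot_low_res, scores_embed = \
#     joblib.load('aucroc_dot.joblib')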

In [None]:
def concentrate_fpr(fpr, alpha):
    """Map false positive rates onto the concentrated ROC (CROC) x-axis.

    Applies the exponential magnifier
    ``f(x) = (1 - exp(-alpha * x)) / (1 - exp(-alpha))``, which maps 0 to 0
    and 1 to 1. For positive ``alpha`` it stretches the low-FPR region.

    Parameters
    ----------
    fpr : float or array-like
        False positive rate(s) to transform.
    alpha : float
        Magnification factor. ``alpha == 0`` returns ``fpr`` unchanged,
        which is the limit of ``f`` as ``alpha -> 0`` (a plain ROC axis).

    Returns
    -------
    numpy.ndarray or numpy.float64
        The transformed values, with the same shape as ``fpr``.
    """
    # asarray makes list input work: `-alpha * [..]` would repeat the list.
    fpr = np.asarray(fpr, dtype=np.float64)
    if alpha == 0:
        # Identity limit; multiplying returns a new array/scalar like the
        # general branch does.
        return fpr * 1.0
    # (1 - e^(-a x)) / (1 - e^(-a)) == expm1(-a x) / expm1(-a). expm1 avoids
    # cancellation when alpha * fpr is small, and the alpha == 0 case (0 / 0)
    # is handled above.
    return np.expm1(-alpha * fpr) / np.expm1(-alpha)

In [None]:
# CROC magnification used by concentrate_fpr: with alpha = 14, an FPR of 0.05
# lands at about 0.5 on the concentrated x-axis.
alpha = 14

In [None]:
width = 7
# height = width / 1.618
fig, ax = plt.subplots(figsize=(width, width))

# (legend label, ground-truth labels, scores) for each compared method.
croc_curves = [
    ('Dot product high res', labels, dot_high_res),
    ('Dot product low res', labels, dot_low_res),
    ('Embedding', labels_embed, scores_embed),
]
for curve_label, y_true, y_score in croc_curves:
    fpr, tpr, _ = roc_curve(y_true, y_score)
    croc_fpr = concentrate_fpr(fpr, alpha)
    # The x values are concentrated FPRs, so this is the area under the CROC
    # curve, not the standard ROC AUC.
    ax.plot(croc_fpr, tpr,
            label=f'{curve_label} (AUC-CROC = {auc(croc_fpr, tpr):.2%})')

# Random-classifier baseline (TPR == FPR) drawn in concentrated coordinates.
random_rate = np.linspace(0, 1, 101)
ax.plot(concentrate_fpr(random_rate, alpha), random_rate,
        color='black', linestyle='--')

ax.set_xlim([-0.05, 1.05])
ax.set_ylim([-0.05, 1.05])

ax.set_xlabel(f'False positive rate (concentrated, $\\alpha$ = {alpha})')
ax.set_ylabel('True positive rate')

ax.legend(loc='lower center', bbox_to_anchor=(0.5, -0.3))

sns.despine(ax=ax)

fig.savefig('aucroc_dot.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close(fig)

In [None]:
# End of notebook: flush and close all logging handlers.
logging.shutdown()