In [None]:
import os
import sys
# Location of the local GLEAMS checkout. A GLEAMS_HOME that was already
# exported in the environment is respected; otherwise fall back to the default
# location `~/Projects/gleams`. `expanduser` is used instead of reading the
# HOME variable directly so that this also works when HOME is not set.
os.environ.setdefault(
    'GLEAMS_HOME',
    os.path.join(os.path.expanduser('~'), 'Projects', 'gleams'))
# Make sure all code is in the PATH.
src_dir = os.path.normpath(os.path.join(os.environ['GLEAMS_HOME'], 'src'))
if src_dir not in sys.path:
    sys.path.append(src_dir)

In [None]:
import collections
import itertools
import math
from typing import Iterator, List, Tuple

import joblib
import matplotlib.pyplot as plt
import numba as nb
import numpy as np
import pandas as pd
import scipy.spatial.distance as ssd
import seaborn as sns
import tqdm.notebook as tqdm
from spectrum_utils import utils as suu
from tensorflow.keras.utils import Sequence

In [None]:
import warnings
# Suppress all FutureWarning messages for the remainder of the notebook.
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
# Initialize logging.
from gleams import logger as glogger
glogger.init()
# Initialize all random seeds before importing any packages.
# NOTE: the statement order in this cell is deliberate: the remaining GLEAMS
# modules below are only imported after `rndm.set_seeds()` has been called.
from gleams import rndm
rndm.set_seeds()

# GLEAMS functionality used in this notebook: configuration values, spectrum
# preprocessing and dot product, pair generation, and peak file reading.
from gleams import config
from gleams.feature import spectrum
from gleams.metadata import metadata
from gleams.ms_io import ms_io

In [None]:
import logging
# Logger named 'gleams' that is used for all messages emitted by this
# notebook; its level is lowered so that debug messages are processed as well.
logger = logging.getLogger('gleams')
logger.setLevel(logging.DEBUG)

In [None]:
# Plot styling.
# Matplotlib 3.6 renamed its bundled seaborn style sheets to `seaborn-v0_8-*`
# and later releases dropped the old names, so use whichever variant of each
# style this Matplotlib installation provides.
plt.style.use([
    style if style in plt.style.available
    else style.replace('seaborn-', 'seaborn-v0_8-', 1)
    for style in ('seaborn-white', 'seaborn-paper')])
plt.rc('font', family='serif')
sns.set_palette('Set1')
sns.set_context('paper', font_scale=1.3)    # Single-column figure.

In [None]:
def generate_pairs_unknown(metadata_filename: str,
                           mz_tolerance: float) -> None:
    """
    Generate index pairs for unknown pairs for the given metadata file.

    The unknown pairs consist of all pairs with a precursor m/z difference
    that does not exceed the given m/z tolerance, and for which at least one
    spectrum is not identified (i.e. has no peptide sequence).
    Pairs of row numbers in the metadata file for each unknown pair are stored
    as a two-column NumPy array in a file whose name is derived from the
    metadata file name by replacing `.parquet` with `_pairs_unk.npy`.
    If this file already exists it will _not_ be recreated.

    Parameters
    ----------
    metadata_filename : str
        The metadata file name containing information for all spectra. Should
        be a Parquet file that contains (at least) the columns `sequence`,
        `charge`, and `mz`.
    mz_tolerance : float
        Maximum precursor m/z tolerance in ppm for two spectra to be
        considered an unknown pair.
    """
    pairs_filename = metadata_filename.replace('.parquet', '_pairs_unk.npy')
    if not os.path.isfile(pairs_filename):
        logger.info('Generate unknown pair indexes for metadata file %s',
                    metadata_filename)
        # `reset_index` exposes the original row labels as an `index` column
        # so that they survive the sorting below. The local name deliberately
        # differs from the imported `metadata` module to avoid shadowing it.
        spectra_md = (pd.read_parquet(metadata_filename,
                                      columns=['sequence', 'charge', 'mz'])
                      .reset_index())
        # Sort by charge and precursor m/z so that candidate pairs are
        # adjacent and can be found with a forward scan.
        spectra_md = (spectra_md.sort_values(['charge', 'mz'])
                      .reset_index(drop=True))
        row_nums = spectra_md['index'].values
        mzs = spectra_md['mz'].values
        # List because Numba can't handle object (string) arrays.
        # Spectra without a sequence are marked with the 'unknown' sentinel.
        sequences = nb.typed.List(spectra_md['sequence'].fillna('unknown'))
        logger.debug('Save unknown pair indexes to %s', pairs_filename)
        # The generator yields a flat stream of row numbers; every two
        # consecutive values are reshaped into one (row1, row2) pair.
        np.save(pairs_filename,
                np.fromiter(
                    _generate_pairs_unknown(row_nums, mzs, sequences,
                                            mz_tolerance),
                    np.uint32).reshape((-1, 2)))


@nb.njit
def _generate_pairs_unknown(
    row_nums: np.ndarray, mzs: np.ndarray, sequences: nb.typed.List,
    precursor_mz_tol: float) -> Iterator[int]:
    """
    Numba utility function to efficiently generate row numbers for unknown
    pairs.

    For every spectrum the subsequent spectra are scanned for as long as
    their precursor m/z difference stays within the tolerance. A pair is
    emitted if at least one of its two spectra has the sequence 'unknown'.
    The scan stops at the first spectrum outside the tolerance, so the input
    has to be ordered by precursor m/z (`generate_pairs_unknown` sorts by
    charge and m/z before calling this function).

    Parameters
    ----------
    row_nums : np.ndarray
        A NumPy array of row numbers for each spectrum.
    mzs : np.ndarray
        A NumPy array of precursor m/z values for each spectrum.
    sequences : nb.typed.List
        A list of peptide sequences for each spectrum. Unidentified spectra
        must have the value 'unknown'.
    precursor_mz_tol : float
        Maximum precursor m/z tolerance in ppm for two spectra to be
        considered an unknown pair.

    Returns
    -------
    Iterator[int]
        A generator of row numbers for the spectrum pairs. The values are
        yielded as a flat sequence in which every two consecutive values
        (positions `2k` and `2k + 1`) form one pair.
    """
    for row_num1 in range(len(row_nums)):
        # Only look forward so that each pair is produced once.
        row_num2 = row_num1 + 1
        # NOTE(review): the `False` argument presumably makes `mass_diff`
        # return the difference in ppm rather than Dalton, in line with the
        # ppm tolerance -- confirm against spectrum_utils.
        while (row_num2 < len(mzs) and
               (abs(suu.mass_diff(mzs[row_num1], mzs[row_num2], False))
                <= precursor_mz_tol)):
            # Pairs of two identified spectra are not "unknown" and skipped.
            if (sequences[row_num1] == 'unknown'
                    or sequences[row_num2] == 'unknown'):
                yield row_nums[row_num1]
                yield row_nums[row_num2]
            row_num2 += 1

In [None]:
class PairIndexSequence(Sequence):
    """
    Batch generator of spectrum index pairs for three pair classes.

    An equal number of positive, negative, and unknown pairs is randomly
    sampled from the given pair files. Each batch contains an equal share of
    every class, so that a full batch consists of `batch_size` pairs.
    """

    # Number of pair classes (positive, negative, unknown) in each batch.
    _NUM_CLASSES = 3

    def __init__(self, filename_pairs_pos: str, filename_pairs_neg: str,
                 filename_pairs_unk: str, batch_size: int,
                 max_num_pairs: int = None, shuffle: bool = True):
        """
        Sample the pairs that will be served by this sequence.

        Parameters
        ----------
        filename_pairs_pos : str
            NumPy file with the positive pair indexes (two columns).
        filename_pairs_neg : str
            NumPy file with the negative pair indexes (two columns).
        filename_pairs_unk : str
            NumPy file with the unknown pair indexes (two columns).
        batch_size : int
            Total number of pairs (over all classes) in a full batch.
        max_num_pairs : int
            Maximum total number of pairs (over all classes) to use. If
            `None`, as many pairs as the smallest class provides are used.
        shuffle : bool
            Whether the pairs can be reshuffled in `on_epoch_end`.
        """
        # Memory-map the files; only the sampled rows are copied into memory.
        pairs_pos = np.load(filename_pairs_pos, mmap_mode='r')
        pairs_neg = np.load(filename_pairs_neg, mmap_mode='r')
        pairs_unk = np.load(filename_pairs_unk, mmap_mode='r')
        # Use the same number of pairs for each class.
        num_pairs = min(len(pairs_pos), len(pairs_neg), len(pairs_unk))
        if max_num_pairs is not None:
            # The maximum applies to the total over all classes.
            num_pairs = min(num_pairs, max_num_pairs // self._NUM_CLASSES)
        logger.info('Using %d positive, negative, and unknown feature pairs '
                    'each', num_pairs)
        # Random subsets without replacement.
        idx_pos = np.random.choice(pairs_pos.shape[0], num_pairs, False)
        self.pairs_pos = pairs_pos[idx_pos]
        idx_neg = np.random.choice(pairs_neg.shape[0], num_pairs, False)
        self.pairs_neg = pairs_neg[idx_neg]
        idx_unk = np.random.choice(pairs_unk.shape[0], num_pairs, False)
        self.pairs_unk = pairs_unk[idx_unk]

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.epoch_count = 0

    def __len__(self) -> int:
        """
        Gives the total number of batches.

        Returns
        -------
        int
            The number of batches.
        """
        return int(math.ceil(
            self._NUM_CLASSES * len(self.pairs_pos) / self.batch_size))

    def __getitem__(self, idx: int) \
            -> Tuple[Tuple[np.ndarray, np.ndarray], np.ndarray]:
        """
        Get the pair indexes and labels for the batch with the given index.

        Parameters
        ----------
        idx : int
            Index of the requested batch.

        Returns
        -------
        Tuple[Tuple[np.ndarray, np.ndarray], np.ndarray]
            A tuple of pair indexes and class labels.
            The pair indexes are two arrays with the first and second element
            of each pair respectively, ordered as positive, negative, and
            unknown pairs.
            The class labels are 1 for positive pairs, -1 for negative pairs,
            and 0 for unknown pairs.
        """
        # Each class contributes one third of the batch, which matches the
        # number of batches reported by `__len__`: all sampled pairs are
        # covered exactly once and no batch is empty.
        batch_slice = slice(idx * self.batch_size // self._NUM_CLASSES,
                            (idx + 1) * self.batch_size // self._NUM_CLASSES)
        batch_pairs_pos = self.pairs_pos[batch_slice]
        batch_pairs_neg = self.pairs_neg[batch_slice]
        batch_pairs_unk = self.pairs_unk[batch_slice]
        batch_pairs = np.vstack((batch_pairs_pos, batch_pairs_neg,
                                 batch_pairs_unk))

        batch_x1 = batch_pairs[:, 0]
        batch_x2 = batch_pairs[:, 1]
        # A signed dtype is required to represent the -1 label; multiplying
        # an unsigned array by -1 is an error in NumPy >= 2.
        batch_y = np.hstack((np.ones(len(batch_pairs_pos), np.int8),
                             np.full(len(batch_pairs_neg), -1, np.int8),
                             np.zeros(len(batch_pairs_unk), np.int8)))

        return (batch_x1, batch_x2), batch_y

    def on_epoch_end(self):
        """
        Count the finished epoch and reshuffle the pairs if applicable.

        If shuffling is enabled, the pairs of each class are shuffled
        whenever the epoch count is a multiple of the number of batches.
        """
        self.epoch_count += 1
        if self.shuffle and self.epoch_count % len(self) == 0:
            logger.debug('Shuffle the features because all pairs have been '
                         'processed after epoch %d', self.epoch_count)
            np.random.shuffle(self.pairs_pos)
            np.random.shuffle(self.pairs_neg)
            np.random.shuffle(self.pairs_unk)

In [None]:
def _get_spectra_from_file(dataset, filename, scans):
    """
    Read and preprocess the requested spectra from a single peak file.

    The peak file is expected at
    `$GLEAMS_HOME/data/peak/{dataset}/{filename}`. If it does not exist a
    warning is logged and no spectra are returned.

    Parameters
    ----------
    dataset : str
        The dataset (directory) that contains the peak file.
    filename : str
        The name of the peak file.
    scans
        The scans to read, passed on to `ms_io.get_spectra`.

    Returns
    -------
    dict
        The preprocessed spectra, keyed by
        `{dataset}/{filename}/{spectrum identifier}`.
    """
    peak_path = os.path.join(os.environ['GLEAMS_HOME'], 'data', 'peak',
                             dataset, filename)
    if not os.path.isfile(peak_path):
        logger.warning('Missing peak file %s, no spectra read', filename)
        return {}
    spectra_by_key = {}
    for raw_spec in ms_io.get_spectra(peak_path, scans):
        processed_spec = spectrum.preprocess(
            raw_spec, config.fragment_mz_min, config.fragment_mz_max)
        spec_key = f'{dataset}/{filename}/{raw_spec.identifier}'
        spectra_by_key[spec_key] = processed_spec
    return spectra_by_key

In [None]:
@nb.njit(parallel=True)
def dot(spectra_arr1, spectra_arr2, out, fragment_mz_tol):
    """
    Compute the dot product between corresponding spectra of two arrays.

    Parameters
    ----------
    spectra_arr1 : np.ndarray
        The first spectrum of each pair, indexed as `[pair, component]`.
        Components 0 and 1 of a pair are passed to `spectrum.dot` as the
        first and second argument. In this notebook these are the (padded)
        m/z and intensity values respectively.
    spectra_arr2 : np.ndarray
        The second spectrum of each pair, with the same layout as
        `spectra_arr1`.
    out : np.ndarray
        Array that receives the score of pair `i` at position `i`. It needs
        at least as many elements as `spectra_arr1` has pairs.
    fragment_mz_tol : float
        Fragment m/z tolerance that is passed on to `spectrum.dot`.

    Returns
    -------
    np.ndarray
        The `out` array, which has been filled in place.
    """
    # Iterations only write to their own `out[i]`, so the pairs are processed
    # with a Numba parallel range.
    for i in nb.prange(spectra_arr1.shape[0]):
        out[i] = spectrum.dot(
            spectra_arr1[i, 0], spectra_arr1[i, 1],
            spectra_arr2[i, 0], spectra_arr2[i, 1],
            fragment_mz_tol)
    return out

### Preprocessing

In [None]:
# Data split whose embedding and feature files are analysed.
split = 'test'
# Maximum number of pairs passed to the pair generator further below.
num_pairs = 5_000_000

In [None]:
# Metadata table for all (identified and unidentified) spectra: the metadata
# of the embedded spectra, extended with the peptide sequence (if any) from
# the feature metadata.
embed_metadata_filename = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'embed',
    f'embed_{config.massivekb_task_id}_{split}.parquet')
feature_metadata_filename = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'feature',
    f'feature_{config.massivekb_task_id}_{split}.parquet')
spectrum_info = pd.merge(
    pd.read_parquet(embed_metadata_filename),
    pd.read_parquet(feature_metadata_filename)[
        ['dataset', 'filename', 'scan', 'sequence']],
    how='left', on=['dataset', 'filename', 'scan'], copy=False)
# Stored without index so that the pair generation functions can refer to the
# spectra by row number.
spectrum_info.to_parquet('dot_embed_metadata.parquet', index=False)

In [None]:
# Generate all types of pairs (positive, negative, unknown) for the combined
# metadata file.
pair_metadata_filename = 'dot_embed_metadata.parquet'
metadata.generate_pairs_positive(pair_metadata_filename)
metadata.generate_pairs_negative(
    pair_metadata_filename,
    config.pair_mz_tolerance,
    config.negative_pair_fragment_tolerance,
    config.negative_pair_matching_fragments_threshold)
generate_pairs_unknown(pair_metadata_filename, config.pair_mz_tolerance)

In [None]:
# Sample an equal number of positive, negative, and unknown pairs. Shuffling
# is disabled so that repeated passes over the batches give the same order.
pair_generator = PairIndexSequence(
    filename_pairs_pos='dot_embed_metadata_pairs_pos.npy',
    filename_pairs_neg='dot_embed_metadata_pairs_neg.npy',
    filename_pairs_unk='dot_embed_metadata_pairs_unk.npy',
    batch_size=config.batch_size,
    max_num_pairs=num_pairs,
    shuffle=False)

In [None]:
# Restrict the spectrum info to the spectra that occur in at least one of the
# pairs selected by the pair generator.
spectrum_indexes = np.concatenate([
    pairs.ravel() for pairs in (pair_generator.pairs_pos,
                                pair_generator.pairs_neg,
                                pair_generator.pairs_unk)])
spectrum_info = spectrum_info.loc[np.unique(spectrum_indexes)]

### Dot product

In [None]:
# Read the selected spectra from the peak files.
dataset_total = spectrum_info['dataset'].nunique()
spectra_per_file = []
for dataset_i, (dataset, md_dataset) in enumerate(
        spectrum_info.groupby('dataset', sort=False), 1):
    # Use the notebook's 'gleams' logger, like all other messages.
    logger.info('Process dataset %s (%d files) [%3d/%3d]', dataset,
                md_dataset['filename'].nunique(), dataset_i, dataset_total)
    # The files of a dataset are read in parallel, one job per file.
    spectra_per_file.extend(
        joblib.Parallel(n_jobs=-1, backend='multiprocessing')(
            joblib.delayed(_get_spectra_from_file)(dataset, filename,
                                                   md_file['scan'])
            for filename, md_file in md_dataset.groupby(
                'filename', sort=False)))
# Merge the per-file dictionaries into a single dictionary so that each
# spectrum lookup is a single hash lookup instead of a scan over one
# dictionary per file. The keys contain the dataset and file name, so
# dictionaries of different files do not share keys.
spectra = {}
for file_spectra in spectra_per_file:
    spectra.update(file_spectra)

In [None]:
# Compute the dot products (high/low resolution) between all spectra pairs.
# First collect the peaks of the first and second spectrum of every pair (in
# batch order) as fixed-size arrays: row 0 holds the m/z values, row 1 the
# intensities, both zero-padded at the front up to `config.max_peaks_used`.
spectra_arr = [], []
for batch_i in tqdm.tqdm(range(len(pair_generator)),
                         desc='Spectra converted', unit='batch'):
    (batch_i1, batch_i2), _ = pair_generator[batch_i]
    for arr_i, batch_indexes in enumerate((batch_i1, batch_i2)):
        for pair_i in batch_indexes:
            spec = spectra[f"{spectrum_info.at[pair_i, 'dataset']}/"
                           f"{spectrum_info.at[pair_i, 'filename']}/"
                           f"{spectrum_info.at[pair_i, 'scan']}"]
            num_padding = config.max_peaks_used - len(spec.mz)
            spectra_arr[arr_i].append(
                np.pad([spec.mz, spec.intensity],
                       ((0, 0), (num_padding, 0)), 'constant'))

fragment_mz_tol_high_res, fragment_mz_tol_low_res = 0.05, 0.8
spectra_arr1, spectra_arr2 = (np.asarray(arr) for arr in spectra_arr)

logger.info('Compute dot product (high resolution; '
            'fragment m/z tolerance = %.2f)', fragment_mz_tol_high_res)
dot_high_res = dot(spectra_arr1, spectra_arr2,
                   np.zeros(len(spectra_arr1), np.float32),
                   fragment_mz_tol_high_res)

logger.info('Compute dot product (low resolution; '
            'fragment m/z tolerance = %.2f)', fragment_mz_tol_low_res)
dot_low_res = dot(spectra_arr1, spectra_arr2,
                  np.zeros(len(spectra_arr1), np.float32),
                  fragment_mz_tol_low_res)

### GLEAMS Euclidean distance

In [None]:
# Compute the Euclidean distances between all embeddings pairs, in the same
# batch order as the dot products.
embeddings_filename = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'embed',
    f'embed_{config.massivekb_task_id}_{split}.npy')
embeddings = np.load(embeddings_filename)

scores_embed = []
for batch_i in tqdm.tqdm(range(len(pair_generator)),
                         desc='GLEAMS distances calculated', unit='batch'):
    (batch_i1, batch_i2), _ = pair_generator[batch_i]
    scores_embed.extend(
        ssd.euclidean(embeddings[pair1], embeddings[pair2])
        for pair1, pair2 in zip(batch_i1, batch_i2))

### Dot product versus GLEAMS Euclidean distance

In [None]:
# Class labels of all pairs, concatenated in batch order. The dot products
# and GLEAMS distances were computed by iterating over the batches in the
# same order, so all columns line up row by row.
labels = np.hstack([pair_generator[batch_i][1]
                    for batch_i in range(len(pair_generator))])
# NOTE(review): sorting does not reset the index, so each row keeps its
# original position as index label. The precursor columns that are added
# below have a fresh 0..n-1 index; this presumably relies on pandas aligning
# the assignment on the index labels -- confirm before changing the order of
# these statements.
dot_embed = pd.DataFrame({'dot_low_res': dot_low_res,
                          'dot_high_res': dot_high_res,
                          'gleams_dist': scores_embed,
                          'pair_type': labels}).sort_values('pair_type')
# Convert pair type to nice labels.
dot_embed['pair_type'] = dot_embed['pair_type'].map(
    {1: 'Positive', -1: 'Negative', 0: 'Unknown'})
# Add precursor information.
# The charge and m/z are taken from the first spectrum of each pair.
spec_idx = np.hstack([pair_generator[batch_i][0][0]
                      for batch_i in range(len(pair_generator))])
dot_embed[['charge', 'mz']] = (spectrum_info.loc[spec_idx, ['charge', 'mz']]
                               .reset_index(drop=True))

In [None]:
# Persist the score table (without its index) to disk.
dot_embed.to_parquet('dot_embed.parquet', index=False)

In [None]:
# Uncomment to reload the previously stored score table instead of
# recomputing it.
# dot_embed = pd.read_parquet('dot_embed.parquet')

In [None]:
# Size of the (square) figure in inches.
width = 7
# Order of the pair types. It is used both as hue order and for the legend
# labels, so that the relabelled legend entries always match the plotted
# groups regardless of the row order of `dot_embed`.
pair_type_order = ['Negative', 'Unknown', 'Positive']

# Dot product versus GLEAMS distance, coloured by pair type.
jg = sns.jointplot(data=dot_embed, x='dot_high_res', y='gleams_dist',
                   hue='pair_type', hue_order=pair_type_order,
                   palette='Set1', height=width,
                   s=1, marker='.', rasterized=True,
                   joint_kws={'alpha': 0.1})

# Recreate the legend with a title, reusing the existing handles.
jg.ax_joint.legend(jg.ax_joint.get_legend_handles_labels()[0],
                   pair_type_order, title='Pair type')
jg.set_axis_labels('Spectrum dot product', 'GLEAMS euclidean distance')

plt.savefig('dot_embed_type.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Size of the (square) figure in inches.
height = 7

# Only show pairs with a precursor charge of at most 4.
dot_embed_charge = dot_embed.loc[dot_embed['charge'] <= 4]

# Dot product versus GLEAMS distance, coloured by precursor charge, with a
# separate density per charge in the margins.
jg = sns.JointGrid(height=height)
charge_kws = {'data': dot_embed_charge, 'hue': 'charge', 'palette': 'Set1'}
sns.scatterplot(x='dot_high_res', y='gleams_dist', alpha=0.1, s=1,
                marker='.', rasterized=True, ax=jg.ax_joint, **charge_kws)
marginal_kws = {**charge_kws, 'legend': False, 'common_norm': False,
                'fill': True}
sns.kdeplot(x='dot_high_res', ax=jg.ax_marg_x, **marginal_kws)
sns.kdeplot(y='gleams_dist', ax=jg.ax_marg_y, **marginal_kws)

jg.ax_joint.legend(loc='upper right', title='Precursor charge')
jg.set_axis_labels('Spectrum dot product', 'Embedded euclidean distance')

plt.savefig('dot_embed_charge.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Size of the (square) figure in inches.
height = 7

# Dot product versus GLEAMS distance, coloured by precursor m/z.
jg = sns.JointGrid(height=height)
# `plt.get_cmap` is used because `plt.cm.get_cmap` was deprecated and has
# been removed from recent Matplotlib versions.
sns.scatterplot(data=dot_embed, x='dot_high_res', y='gleams_dist',
                alpha=0.1, s=1, c=dot_embed['mz'], marker='.',
                cmap=plt.get_cmap('YlGnBu'), rasterized=True,
                ax=jg.ax_joint)
sns.kdeplot(data=dot_embed, x='dot_high_res', color='black', legend=False,
            common_norm=False, fill=True, ax=jg.ax_marg_x)
sns.kdeplot(data=dot_embed, y='gleams_dist', color='black', legend=False,
            common_norm=False, fill=True, ax=jg.ax_marg_y)

# Colorbar to the right of the figure, vertically inset by 0.05 (figure
# fraction) from the bottom and top of the joint axes. `add_axes` takes
# [left, bottom, width, height], so the bottom has to be derived from `y0`.
ax_joint_pos = jg.ax_joint.get_position()
cbar_ax = jg.fig.add_axes([1.025, ax_joint_pos.y0 + 0.05,
                           0.025, ax_joint_pos.height - 0.1])
# The scatter plot is the only collection on the joint axes.
colorbar = jg.fig.colorbar(jg.ax_joint.collections[0], cax=cbar_ax)
# Show the colorbar fully opaque although the points are transparent.
colorbar.solids.set(alpha=1)
colorbar.set_label('Precursor m/z', size='large', labelpad=15)

jg.set_axis_labels('Spectrum dot product', 'Embedded euclidean distance')

plt.savefig('dot_embed_mz.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Flush and close all logging handlers at the end of the notebook.
logging.shutdown()