In [None]:
import os
import sys
# Point GLEAMS_HOME at the local checkout under the user's home directory.
# HOME is used when it is set (unchanged behavior); otherwise fall back to
# `expanduser`, because HOME is not guaranteed to exist in every environment
# and a bare os.environ['HOME'] lookup would raise KeyError there.
home_dir = os.environ.get('HOME') or os.path.expanduser('~')
os.environ['GLEAMS_HOME'] = os.path.join(home_dir, 'Projects/gleams')
# Make the GLEAMS sources importable by adding them to sys.path.
gleams_src = os.path.normpath(os.path.join(os.environ['GLEAMS_HOME'], 'src'))
# Skip the append when the entry already exists, so re-running this cell does
# not pile up duplicate sys.path entries.
if gleams_src not in sys.path:
    sys.path.append(gleams_src)

In [None]:
import math

import joblib
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd
import seaborn as sns
import tqdm.notebook as tqdm

In [None]:
import logging
# Fetch the logger named 'gleams' and lower its threshold so that records from
# DEBUG level upwards are accepted by this logger.
logger = logging.getLogger(name='gleams')
logger.setLevel(level=logging.DEBUG)

In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
from gleams.dag import dag

from gleams import config
from gleams.cluster import cluster
from gleams.nn import nn

In [None]:
# Plot styling.
# Matplotlib 3.6 renamed its bundled seaborn style sheets to
# 'seaborn-v0_8-<name>' and later releases dropped the old names. Use the
# original name when the installed Matplotlib still ships it and fall back to
# the renamed sheet otherwise, so the cell works on both old and new versions.
seaborn_styles = ['seaborn-white', 'seaborn-paper']
plt.style.use([
    name if name in plt.style.available
    else name.replace('seaborn-', 'seaborn-v0_8-', 1)
    for name in seaborn_styles])
plt.rc('font', family='serif')
sns.set_palette('Set1')
sns.set_context('paper', font_scale=1.3)    # Single-column figure.

In [None]:
# Locations of the embedding artifacts and the peptide identifications for the
# configured MassIVE-KB task.
embed_dir = os.path.join(os.environ['GLEAMS_HOME'], 'data', 'embed')
embed_npy_path = os.path.join(
    embed_dir, f'embed_{config.massivekb_task_id}.npy')
embed_parquet_path = os.path.join(
    embed_dir, f'embed_{config.massivekb_task_id}.parquet')
ids_parquet_path = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'metadata',
    f'massivekb_ids_{config.massivekb_task_id}.parquet')

# Memory-map the embedding matrix read-only instead of loading it in full.
embeddings = np.load(embed_npy_path, mmap_mode='r')

# Left-join the identifications onto the embedding metadata by
# (filename, scan), keeping a single identification row per key, and retain
# only the rows that received a peptide sequence. The frames are read inline
# so that no extra references to the intermediate tables are kept alive.
metadata = pd.merge(
    left=pd.read_parquet(embed_parquet_path),
    right=pd.read_parquet(ids_parquet_path)
    .drop_duplicates(['filename', 'scan']),
    how='left', on=['filename', 'scan'], copy=False,
).dropna(subset=['sequence'])

In [None]:
# Build ANN index for efficient nearest neighbor querying.
mz = metadata['mz'].sort_values()
# Integer bounds enclosing every precursor m/z (mz is sorted ascending).
min_mz = math.floor(mz.iat[0])
max_mz = math.ceil(mz.iat[-1])
mz_span = max_mz - min_mz
# Both the m/z interval and the precursor tolerance are widened to the full
# observed m/z range, and a single split starting at min_mz is requested.
# NOTE(review): presumably this places all labeled embeddings in one index so
# neighbors are not restricted by precursor m/z - confirm against
# cluster._build_ann_index.
config.mz_interval = mz_span
config.precursor_tol_mass = mz_span
config.precursor_tol_mode = 'Da'
cluster._build_ann_index('nn_dist.faiss', embeddings, mz, [min_mz])

In [None]:
# Load the ANN index stored under the same 'nn_dist.faiss' name that the build
# step used.
# NOTE(review): the second argument (0) presumably selects the first (and
# only) m/z split - confirm against cluster._load_ann_index.
index = cluster._load_ann_index('nn_dist.faiss', 0)

In [None]:
# Number of query embeddings: at most one million, or every labeled embedding
# when fewer are available.
num_samples = min(len(metadata), 1_000_000)
# Number of nearest neighbors retrieved per query embedding.
n_neighbors = 100

In [None]:
# Draw the query embeddings uniformly at random, without replacement, from the
# labeled metadata rows. A dedicated generator with a fixed seed makes the
# sample (and therefore the resulting curve) reproducible across kernel
# restarts, and avoids depending on or mutating NumPy's global random state.
rng = np.random.default_rng(42)
idx_sample = rng.choice(metadata.index.to_numpy(), num_samples, replace=False)

In [None]:
logger.info('Find the labeled nearest neighbors for %d labeled embeddings',
            num_samples)
# One entry per query embedding: its neighbor distances / neighbor identifiers.
distances, neighbors = [], []
batch_size = min(num_samples, config.batch_size_dist)
batch_starts = range(0, num_samples, batch_size)
for batch_start in tqdm.tqdm(batch_starts, desc='Batches processed',
                             leave=False, unit='batch'):
    # idx_sample holds exactly num_samples entries, so the slice is clamped
    # at the end of the array for the final, possibly shorter, batch.
    query_ids = idx_sample[batch_start:batch_start + batch_size]
    found_distances, found_neighbors = index.search(
        embeddings[query_ids], n_neighbors)
    # Extending with a 2D result appends one row per query.
    distances.extend(found_distances)
    neighbors.extend(found_neighbors)

In [None]:
logger.info('Verify whether neighbors have the same peptide label')
sequences = metadata['sequence']
same_label = []
# Iterate the queries in the order in which they were searched, so each row of
# neighbors is paired with the embedding that produced it.
for query_id, nn_id in tqdm.tqdm(zip(idx_sample, neighbors),
                                 desc='Embeddings processed', leave=False,
                                 total=num_samples, unit='emb'):
    # Compare against the query's own label rather than the label of the
    # first returned neighbor: an approximate index, or duplicate embeddings,
    # do not guarantee that the query itself is ranked first.
    # NOTE(review): if the index can return placeholder ids (e.g. -1) when
    # fewer than n_neighbors hits exist, the lookup below raises KeyError -
    # confirm against the index configuration.
    query_label = sequences.at[query_id]
    same_label.append(
        (sequences.loc[nn_id] == query_label).to_numpy(dtype=bool))
# Flatten the per-query rows straight into arrays instead of building Python
# lists with one element per neighbor.
embed_dist = np.concatenate(distances)
same_label = np.concatenate(same_label)
# Sort all (query, neighbor) pairs by ascending embedded distance.
order = np.argsort(embed_dist)
embed_dist = embed_dist[order]
same_label = same_label[order]
# Running fraction of same-label pairs among all pairs up to each distance.
prop_same_label = np.cumsum(same_label) / np.arange(1, len(same_label) + 1)

In [None]:
# Persist the sorted distances and the cumulative same-label proportion so the
# figure can be regenerated without repeating the neighbor search.
joblib.dump([embed_dist, prop_same_label], 'nn_dist.joblib')

In [None]:
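# Optional shortcut: uncomment the line below to reload previously saved
# results instead of recomputing them.
# embed_dist, prop_same_label = joblib.load('nn_dist.joblib')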

In [None]:
# Figure dimensions: fixed width, height derived from the golden ratio.
golden_ratio = 1.618
width = 7
height = width / golden_ratio
fig, ax = plt.subplots(figsize=(width, height))

# Cumulative proportion of same-label neighbor pairs by embedded distance.
ax.plot(embed_dist, prop_same_label)
ax.set(xlabel='Embedded distance', ylabel='Proportion same label')

# The y values are fractions, so a value of 1 is rendered as 100%.
ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1))

sns.despine(fig=fig)

plt.savefig('nn_dist.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close(fig)

In [None]:
# Flush and close all logging handlers now that the analysis has finished.
logging.shutdown()