In [None]:
import os
import sys
# Point GLEAMS_HOME at ~/Projects/gleams. The home directory is resolved with
# `expanduser` rather than by reading $HOME directly, so the lookup does not
# raise a KeyError on platforms where HOME is not defined (on POSIX it still
# resolves through $HOME whenever that variable is set).
os.environ['GLEAMS_HOME'] = os.path.join(os.path.expanduser('~'),
                                         'Projects/gleams')
# Make sure all code is in the PATH.
# Append only once so that re-running this cell does not keep growing sys.path.
gleams_src = os.path.normpath(os.path.join(os.environ['GLEAMS_HOME'], 'src'))
if gleams_src not in sys.path:
    sys.path.append(gleams_src)

In [None]:
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.sparse as ss
import seaborn as sns

In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
# Initialize logging.
from gleams import logger as glogger
glogger.init()
# Initialize all random seeds before importing any packages.
# The statement order in this cell is deliberate: keep the seed initialization
# ahead of the remaining gleams imports below.
# NOTE(review): the np.random.choice sampling further down presumably relies
# on these seeds for reproducibility — confirm that set_seeds() covers NumPy's
# global RNG.
from gleams import rndm
rndm.set_seeds()

# `config` supplies `massivekb_task_id` (used to build the data file names);
# `cluster` supplies `compute_pairwise_distances`.
from gleams import config
from gleams.cluster import cluster

In [None]:
import logging
# Enable debug-level output on the 'gleams' logger and keep a handle to it for
# the progress messages emitted in the cells below.
logging.getLogger('gleams').setLevel(logging.DEBUG)
logger = logging.getLogger('gleams')

In [None]:
# Plot styling.
# Matplotlib 3.6 renamed its bundled seaborn style sheets to
# 'seaborn-v0_8-<name>' and later releases dropped the old names. Prefer the
# renamed sheet when this Matplotlib provides it and fall back to the legacy
# name otherwise, so the cell works on both old and new installations.
plt.style.use([
    versioned if versioned in plt.style.available else legacy
    for legacy, versioned in (('seaborn-white', 'seaborn-v0_8-white'),
                              ('seaborn-paper', 'seaborn-v0_8-paper'))])
plt.rc('font', family='serif')
sns.set_palette('Set1')
sns.set_context('paper', font_scale=1.3)    # Single-column figure.

In [None]:
# Name of the data split to analyze; it is interpolated into the embedding and
# distance file names used in the cells below.
split = 'test'

In [None]:
# Compute the pairwise distances for the selected split from its embedding
# array (.npy) and the accompanying table (.parquet).
# NOTE(review): the distance matrix loaded later from data/cluster/dist_*.npz
# is presumably written by this call — confirm.
embed_dir = os.path.join(os.environ['GLEAMS_HOME'], 'data', 'embed')
embeddings_file = os.path.join(
    embed_dir, f'embed_{config.massivekb_task_id}_{split}.npy')
embeddings_metadata_file = os.path.join(
    embed_dir, f'embed_{config.massivekb_task_id}_{split}.parquet')
cluster.compute_pairwise_distances(embeddings_file, embeddings_metadata_file)

In [None]:
# Left-join the per-embedding table with the identification table on
# (filename, scan) and keep only the rows that have a sequence. The right-hand
# table is de-duplicated on the join keys first. The row labels that remain
# after dropna are used further down to index the pairwise distance matrix.
embed_metadata_file = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'embed',
    f'embed_{config.massivekb_task_id}_{split}.parquet')
identifications_file = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'metadata',
    f'massivekb_ids_{config.massivekb_task_id}.parquet')
metadata = (
    pd.read_parquet(embed_metadata_file)
    .merge(pd.read_parquet(identifications_file)
           .drop_duplicates(['filename', 'scan']),
           how='left', on=['filename', 'scan'], copy=False)
    .dropna(subset=['sequence']))

In [None]:
# Draw at most 10 million row labels without replacement and restrict the
# metadata to those rows, in the order in which they were drawn.
num_samples = min(len(metadata), 10_000_000)
idx_sample = np.random.choice(a=metadata.index, size=num_samples,
                              replace=False)
metadata = metadata.loc[idx_sample]

In [None]:
# Load the sparse pairwise distance matrix of this split and restrict it to
# the sampled embeddings: rows first, then columns, both in sample order.
dist_file = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'cluster',
    f'dist_{config.massivekb_task_id}_{split}.npz')
pairwise_distances = ss.load_npz(dist_file)
pairwise_distances = pairwise_distances[metadata.index]
pairwise_distances = pairwise_distances[:, metadata.index]
logger.info('Using %d non-zero pairwise distances between %d randomly '
            'selected embeddings', pairwise_distances.count_nonzero(),
            len(metadata))

In [None]:
logger.info('Verify whether neighbors have the same peptide label')
# Row positions, column positions and values of the non-zero distances.
rows, columns, dist = ss.find(pairwise_distances)
# Label of each sampled embedding: "<sequence>/<charge>". The index is reset
# so the labels line up with the positional indices returned by ss.find.
sequences = ((metadata['sequence'] + '/' + metadata['charge'].astype(str))
             .reset_index(drop=True))
labels = sequences.to_numpy()
# For every stored pair, check whether both ends carry the same label.
same_label = labels[rows] == labels[columns]
# Order the pairs by increasing distance.
order = np.argsort(dist)
dist, same_label = np.asarray(dist)[order], np.asarray(same_label)[order]
# Running fraction of same-label pairs among the k closest pairs, for each k.
num_pairs = np.arange(1, len(same_label) + 1)
prop_same_label = np.cumsum(same_label) / num_pairs

In [None]:
# Save the sorted distances and the running same-label proportion to disk in
# the current working directory.
joblib.dump(value=[dist, prop_same_label], filename='nn_dist.joblib')

In [None]:
# Uncomment to reuse previously saved results instead of recomputing them.
# dist, prop_same_label = joblib.load('nn_dist.joblib')

In [None]:
# Figure dimensions: the height follows from the width via the golden ratio.
width = 7
height = width / 1.618    # golden ratio
fig, ax = plt.subplots(figsize=(width, height))

# Running proportion of same-label pairs as a function of embedded distance.
ax.plot(dist, prop_same_label)
ax.set(xlim=(0, 1.2),
       xlabel='Embedded distance',
       ylabel='Proportion same peptide')

sns.despine(ax=ax)

# Write the figure to disk, then display it and release it.
plt.savefig('nn_dist.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close(fig)

In [None]:
# Flush and close all logging handlers; this is the last cell of the notebook.
logging.shutdown()