In [None]:
import os
import sys
# Point GLEAMS_HOME at the project checkout below the user's home directory.
gleams_home = os.path.join(os.environ['HOME'], 'Projects/gleams')
os.environ['GLEAMS_HOME'] = gleams_home
# Make sure all code is in the PATH.
sys.path.append(os.path.normpath(os.path.join(gleams_home, 'src')))

In [None]:
import warnings
from sklearn.exceptions import EfficiencyWarning
# Silence scikit-learn's EfficiencyWarning and any FutureWarning so they do
# not flood the notebook output.
warnings.simplefilter('ignore', EfficiencyWarning)
warnings.simplefilter('ignore', FutureWarning)

In [None]:
import joblib
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd
import seaborn as sns
import skopt

In [None]:
import logging
# Lower the threshold of the 'gleams' logger to DEBUG.
logger = logging.getLogger('gleams')
logger.setLevel('DEBUG')

In [None]:
from gleams.dag import dag

from gleams import config
from gleams.cluster import cluster

In [None]:
# Plot styling.
# Matplotlib 3.6 renamed its bundled seaborn style sheets to 'seaborn-v0_8-*'
# and later releases removed the old names, so fall back to the new name
# whenever the legacy one is not available.
plt.style.use([
    style if style in plt.style.available
    else style.replace('seaborn-', 'seaborn-v0_8-', 1)
    for style in ('seaborn-white', 'seaborn-paper')])
plt.rc('font', family='serif')
sns.set_palette('Set1')
sns.set_context('paper', font_scale=1.3)    # Single-column figure.

In [None]:
# Data split to analyze; it is part of every embedding, metadata, ANN and
# cluster file name used in this notebook.
split = 'train'

In [None]:
# All files live below $GLEAMS_HOME/data and are keyed by
# config.massivekb_task_id and the data split.
data_dir = os.path.join(os.environ['GLEAMS_HOME'], 'data')
# Cluster label array that is loaded when evaluating a clustering.
cluster_filename = os.path.join(
    data_dir, 'ann', f'clusters_{config.massivekb_task_id}_{split}.npy')
# The two metadata tables that are joined on dataset/filename/scan.
metadata_ident_filename = os.path.join(
    data_dir, 'metadata',
    f'metadata_{config.massivekb_task_id}_{split}.parquet')
metadata_all_filename = os.path.join(
    data_dir, 'embed', f'embed_{config.massivekb_task_id}_{split}.parquet')

In [None]:
# Both steps read the same embeddings file; the distance computation
# additionally takes the path of the FAISS index file.
embed_filename = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'embed',
    f'embed_{config.massivekb_task_id}_{split}.npy')
ann_filename = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'ann',
    f'ann_{config.massivekb_task_id}_{split}.faiss')
cluster.build_ann_index(embed_filename)
cluster.compute_pairwise_distances(embed_filename, ann_filename)

In [None]:
# Left join: keep every row of the full table and attach the matching columns
# of the identification table (missing where there is no match).
metadata = pd.read_parquet(metadata_all_filename).merge(
    pd.read_parquet(metadata_ident_filename),
    how='left', on=['dataset', 'filename', 'scan'])
# Don't disambiguate between I/L.
metadata['sequence'] = metadata['sequence'].str.replace('I', 'L')

In [None]:
def evaluate_clusters(clusters_filename, min_peptide_size=None,
                      sequences=None):
    """
    Evaluate a clustering against the peptide identifications.

    Only spectra with a peptide sequence are used as ground truth. Within
    each cluster the most frequent peptide sequence is considered correct;
    all other identified spectra in that cluster count as incorrectly
    clustered.

    Parameters
    ----------
    clusters_filename : str
        Path of the NumPy file with one cluster label per spectrum. The
        label -1 marks noise (unclustered) spectra.
    min_peptide_size : int, optional
        If given, only spectra whose peptide sequence occurs at least this
        many times are evaluated.
    sequences : pd.Series, optional
        Peptide sequence per spectrum (missing for unidentified spectra), in
        the same order as the cluster labels. Defaults to the `sequence`
        column of the notebook-level `metadata` table.

    Returns
    -------
    Tuple[float, float]
        The proportion of evaluated spectra that is clustered (non-noise)
        and the proportion of evaluated spectra that is clustered
        incorrectly. Both are 0 if there are no spectra to evaluate.
    """
    if sequences is None:
        sequences = metadata['sequence']
    # Only consider identified spectra as clustering ground truth.
    clusters = (pd.DataFrame({'sequence': sequences,
                              'cluster': np.load(clusters_filename)})
                .dropna())
    # Possibly only consider clusters of a minimal size.
    # (Is the clustering better for small/large clusters?)
    if min_peptide_size is not None:
        peptide_counts = clusters['sequence'].value_counts()
        clusters = clusters[clusters['sequence'].isin(
            peptide_counts[peptide_counts >= min_peptide_size].index)]
    num_spectra = len(clusters)
    # Nothing to evaluate (e.g. the size filter removed every peptide): avoid
    # a division by zero and report that nothing was clustered.
    if num_spectra == 0:
        return 0., 0.
    clusters_non_noise = clusters[clusters['cluster'] != -1]
    prop_clustered = len(clusters_non_noise) / num_spectra
    # Count the spectra per (cluster, peptide) pair; the largest count in
    # each cluster is its majority peptide, everything else is incorrect.
    majority_counts = (clusters_non_noise.groupby(['cluster', 'sequence'])
                       .size().groupby(level='cluster').max())
    prop_clustered_incorrect = (
            (len(clusters_non_noise) - majority_counts.sum()) / num_spectra)

    return prop_clustered, prop_clustered_incorrect

In [None]:
# Settings of the hyperparameter search objective.
min_peptide_size = 5    # Passed on to evaluate_clusters.
max_prop_clustered_incorrect = 0.01    # Largest accepted misclustering rate.


def optimize_cluster_hyperparameters(args):
    """
    Objective function for the clustering hyperparameter search.

    Sets the hyperparameters on `config`, re-runs the clustering, evaluates
    it and records both evaluation metrics in the notebook-level lists
    `props_clustered` and `props_clustered_incorrect`.

    Parameters
    ----------
    args : sequence
        The two hyperparameters to evaluate: (eps, min_samples).

    Returns
    -------
    float
        The value to minimize: 1 if more than `max_prop_clustered_incorrect`
        of the spectra are clustered incorrectly, otherwise one minus the
        proportion of clustered spectra.
    """
    eps, min_samples = args
    config.eps = eps
    config.min_samples = min_samples
    # Drop the previous clustering output before clustering again
    # (presumably an existing file would otherwise be reused; TODO confirm).
    if os.path.isfile(cluster_filename):
        os.remove(cluster_filename)
    dist_filename = os.path.join(
        os.environ['GLEAMS_HOME'], 'data', 'ann',
        f'dist_{config.massivekb_task_id}_{split}.npz')
    cluster.cluster(dist_filename)
    clustered, clustered_incorrect = evaluate_clusters(
        cluster_filename, min_peptide_size)
    props_clustered.append(clustered)
    props_clustered_incorrect.append(clustered_incorrect)
    # Too many misclustered spectra: report the worst possible score.
    if clustered_incorrect > max_prop_clustered_incorrect:
        return 1
    return 1 - clustered

In [None]:
# Metrics of every evaluated hyperparameter combination; filled by the
# objective function in evaluation order.
props_clustered, props_clustered_incorrect = [], []
optim = skopt.gp_minimize(optimize_cluster_hyperparameters,
                          [skopt.space.Real(0.0001, 0.1, name='eps'),
                           skopt.space.Integer(2, 10, name='min_samples')],
                          # Fixed seed so that the search is reproducible.
                          random_state=42)

In [None]:
# One row per evaluated hyperparameter combination: the sampled values next
# to the metrics that the objective function recorded for them.
hyperparam_iters = np.asarray(optim.x_iters)
cluster_hyperparameter = pd.DataFrame({
    'eps': hyperparam_iters[:, 0],
    'min_samples': hyperparam_iters[:, 1],
    'prop_clustered': props_clustered,
    'prop_clustered_incorrect': props_clustered_incorrect,
})

In [None]:
# Remove final (suboptimal) clustering.
# Guarded like the removal in the objective function, so re-running this
# cell does not fail when the file is already gone.
if os.path.isfile(cluster_filename):
    os.remove(cluster_filename)

In [None]:
# Store the search results in the working directory; they are loaded again
# from this file for the analysis that follows.
joblib.dump(cluster_hyperparameter, 'cluster_hyperparameter.joblib')

In [None]:
# Reload the stored search results.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code; only load files that this notebook (or a trusted source) produced.
cluster_hyperparameter = joblib.load('cluster_hyperparameter.joblib')

In [None]:
def get_pareto_frontier(arr):
    """
    Get the Pareto frontier of a set of two-dimensional points.

    The first column is treated as a cost (smaller is better) and the second
    column as a gain (larger is better). A point is on the frontier if no
    other point has a first value that is at most as large together with a
    second value that is at least as large. Duplicate points are reported
    once.

    Parameters
    ----------
    arr : np.ndarray
        Array of shape (n, 2) with one point per row.

    Returns
    -------
    np.ndarray
        The frontier points, sorted by increasing first (and thus also
        increasing second) column. Empty input is returned unchanged.
    """
    arr = np.asarray(arr)
    if arr.shape[0] == 0:
        return arr
    # Sort by the first column; among equal first values put the largest
    # second value first so that dominated ties can never be selected.
    arr_sorted = arr[np.lexsort((-arr[:, 1], arr[:, 0]))]
    # Iteratively add points to the Pareto frontier.
    pareto_idx = [0]
    for i in range(1, arr_sorted.shape[0]):
        # Because of the sort order a larger second value implies a strictly
        # larger first value than the last frontier point.
        if arr_sorted[i, 1] > arr_sorted[pareto_idx[-1], 1]:
            pareto_idx.append(i)
    return arr_sorted[pareto_idx]

In [None]:
# Only hyperparameters within the misclustering budget are eligible. Use the
# same threshold and the same inclusive comparison as the search objective,
# which only rejects values strictly above the threshold.
acceptable_hyperparam = cluster_hyperparameter[
    cluster_hyperparameter['prop_clustered_incorrect']
    <= max_prop_clustered_incorrect]
if acceptable_hyperparam.empty:
    # idxmax on an empty selection fails with an unhelpful message.
    raise ValueError(
        f'No hyperparameter combination clusters at most '
        f'{max_prop_clustered_incorrect:.2%} of the spectra incorrectly')
# Among the eligible settings pick the one that clusters the most spectra.
best_hyperparam = acceptable_hyperparam.loc[
    acceptable_hyperparam['prop_clustered'].idxmax()]
print(f'Optimal clustering hyperparameters:\n'
      f'  - eps = {best_hyperparam["eps"]:.4f}\n'
      f'  - min_samples = {best_hyperparam["min_samples"]:.0f}\n'
      f'-> {best_hyperparam["prop_clustered"]:.2%} clustered, '
      f'{best_hyperparam["prop_clustered_incorrect"]:.2%} clustered incorrectly')

In [None]:
# Figure with a golden-ratio aspect.
width = 7
height = width / 1.618    # golden ratio
fig, ax = plt.subplots(figsize=(width, height))

# x: proportion of incorrectly clustered spectra, y: proportion clustered.
x_incorrect = cluster_hyperparameter['prop_clustered_incorrect']
y_clustered = cluster_hyperparameter['prop_clustered']

# Pareto frontier as a connected line, all evaluated settings as dots and
# the accepted misclustering threshold as a dashed vertical line.
clustering_pareto = get_pareto_frontier(
    np.column_stack([x_incorrect, y_clustered]))
ax.plot(clustering_pareto[:, 0], clustering_pareto[:, 1], marker='o')
scatter = ax.scatter(x_incorrect, y_clustered, marker='.')
ax.axvline(max_prop_clustered_incorrect, c='darkgray', ls='--')

ax.set_xlim(-0.005, 0.1)
ax.set_xlabel('Incorrectly clustered spectra')
ax.set_ylim(-0.05, 1)
ax.set_ylabel('Clustered spectra')

# Both axes hold proportions in [0, 1]; show them as percentages.
for axis in (ax.xaxis, ax.yaxis):
    axis.set_major_formatter(mticker.PercentFormatter(1))

sns.despine()

plt.savefig('cluster_hyperparameter.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# End of the notebook: flush and close all logging handlers.
logging.shutdown()