In [None]:
import os
import sys
# Point GLEAMS_HOME at the local project checkout. `expanduser` resolves the
# home directory even when the HOME environment variable is not set.
os.environ['GLEAMS_HOME'] = os.path.join(os.path.expanduser('~'),
                                         'Projects/gleams')
# Make sure all code is in the PATH. The directory is only added once so that
# re-running this cell does not keep growing `sys.path`.
_gleams_src = os.path.normpath(os.path.join(os.environ['GLEAMS_HOME'], 'src'))
if _gleams_src not in sys.path:
    sys.path.append(_gleams_src)

In [None]:
import warnings
from sklearn.exceptions import EfficiencyWarning
# Suppress scikit-learn efficiency warnings and generic future warnings for
# the remainder of the notebook.
warnings.simplefilter('ignore', EfficiencyWarning)
warnings.simplefilter('ignore', FutureWarning)

In [None]:
import collections
import functools

import joblib
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd
import seaborn as sns
import skopt
from sklearn.metrics import homogeneity_score, completeness_score

In [None]:
# Initialize logging.
from gleams import logger as glogger
glogger.init()
# Initialize all random seeds before importing the remaining gleams modules.
# NOTE(review): several third-party packages have already been imported in
# earlier cells, so only the imports below happen after seeding.
from gleams import rndm
rndm.set_seeds()

# `config` holds the clustering settings that are overwritten during the
# hyperparameter search; `cluster` provides the clustering routines.
from gleams import config
from gleams.cluster import cluster

In [None]:
import logging
# Emit debug-level messages on the 'gleams' logger and keep a handle to it.
logging.getLogger('gleams').setLevel(logging.DEBUG)
logger = logging.getLogger('gleams')

In [None]:
# Plot styling.
# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-<name>"
# and later releases dropped the old names, so use whichever name is
# available in the installed version.
plt.style.use([
    style if style in plt.style.available
    else style.replace('seaborn-', 'seaborn-v0_8-', 1)
    for style in ('seaborn-white', 'seaborn-paper')])
plt.rc('font', family='serif')
sns.set_palette('Set1')
sns.set_context('paper', font_scale=1.3)    # Single-column figure.

## Initialization

In [None]:
# Clusters with fewer members than this are treated as noise during
# evaluation.
min_cluster_size = 2
# Maximum proportion of incorrectly clustered spectra that is tolerated when
# selecting clustering hyperparameters.
max_prop_clust_incorrect = 0.01

In [None]:
# Spectrum identifications; provides the 'sequence' column that is merged
# into the embedding metadata.
metadata_filename_ident = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'metadata',
    f'massivekb_ids_{config.massivekb_task_id}.parquet')

# Validation split: embedding metadata, embeddings, cluster labels, and
# pairwise distances (used for the hyperparameter search).
metadata_filename_val = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'cluster',
    f'embed_{config.massivekb_task_id}_val.parquet')
embed_filename_val = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'cluster',
    f'embed_{config.massivekb_task_id}_val.npy')
cluster_filename_val = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'cluster',
    f'clusters_{config.massivekb_task_id}_val.npy')
dist_filename_val = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'cluster',
    f'dist_{config.massivekb_task_id}_val.npz')

# Test split: same set of files (used to check the selected hyperparameters).
metadata_filename_test = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'cluster',
    f'embed_{config.massivekb_task_id}_test.parquet')
embed_filename_test = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'cluster',
    f'embed_{config.massivekb_task_id}_test.npy')
cluster_filename_test = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'cluster',
    f'clusters_{config.massivekb_task_id}_test.npy')
dist_filename_test = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'cluster',
    f'dist_{config.massivekb_task_id}_test.npz')

In [None]:
# Make sure the pairwise distances are precomputed.
# NOTE(review): the output file name is not passed in; presumably it is
# derived from the embedding file name and matches `dist_filename_*` - confirm.
cluster.compute_pairwise_distances(embed_filename_val, metadata_filename_val)
cluster.compute_pairwise_distances(embed_filename_test, metadata_filename_test)

In [None]:
def get_metadata(metadata_filename, metadata_filename_ident):
    """
    Load spectrum metadata and annotate it with peptide identifications.

    Parameters
    ----------
    metadata_filename : str
        Parquet file with the spectrum metadata. Must contain the 'dataset',
        'filename', 'scan', and 'charge' columns.
    metadata_filename_ident : str
        Parquet file with the spectrum identifications. Must contain the
        'dataset', 'filename', 'scan', and 'sequence' columns.

    Returns
    -------
    pd.DataFrame
        The spectrum metadata with an additional 'sequence' column formatted
        as "<sequence>/<charge>", with all I replaced by L. Spectra without
        an identification have a missing 'sequence'.
    """
    key_columns = ['dataset', 'filename', 'scan']
    spectra = pd.read_parquet(metadata_filename)
    identifications = pd.read_parquet(metadata_filename_ident)
    identifications = identifications[[*key_columns, 'sequence']]
    # Keep all spectra, also those that were not identified.
    metadata = pd.merge(spectra, identifications, how='left', on=key_columns,
                        copy=False)
    # Don't disambiguate between I/L.
    sequences_il = metadata['sequence'].str.replace('I', 'L')
    charges = metadata['charge'].astype(str)
    metadata['sequence'] = sequences_il + '/' + charges
    return metadata

In [None]:
# Get the metadata including all spectrum identifications.
# The 'sequence' column holds "<sequence>/<charge>" labels and is missing for
# spectra without an identification.
metadata_val = get_metadata(metadata_filename_val, metadata_filename_ident)
metadata_test = get_metadata(metadata_filename_test, metadata_filename_ident)

In [None]:
def _count_majority_label_mismatch(labels):
    labels_assigned = labels.dropna()
    if len(labels_assigned) <= 1:
        return 0
    else:
        return len(labels_assigned) - labels_assigned.value_counts().iat[0]


def evaluate_clusters(clusters, min_cluster_size=None, max_cluster_size=None):
    """
    Compute quality measures for a cluster assignment.

    Parameters
    ----------
    clusters : pd.DataFrame
        One row per spectrum with a 'cluster' column (cluster label, with -1
        denoting noise) and a 'sequence' column (peptide label, missing for
        unidentified spectra). The input is not modified.
    min_cluster_size : int, optional
        Clusters with fewer members than this are treated as noise.
    max_cluster_size : int, optional
        Clusters with at least this many members are treated as noise.

    Returns
    -------
    tuple
        - the number of clustered (non-noise) spectra,
        - the number of noise spectra,
        - the proportion of clustered spectra,
        - the proportion of identified, clustered spectra whose label differs
          from the majority label in their cluster (0 if no identified
          spectrum is clustered),
        - the homogeneity of the identified, clustered spectra,
        - the completeness of all identified spectra.
    """
    clusters = clusters.copy()
    # Only consider clusters with specific minimum (inclusive) and/or
    # maximum (exclusive) size.
    cluster_counts = clusters['cluster'].value_counts(dropna=False)
    if min_cluster_size is not None:
        too_small = cluster_counts[cluster_counts < min_cluster_size].index
        clusters.loc[clusters['cluster'].isin(too_small), 'cluster'] = -1
    if max_cluster_size is not None:
        too_large = cluster_counts[cluster_counts >= max_cluster_size].index
        clusters.loc[clusters['cluster'].isin(too_large), 'cluster'] = -1

    # Use consecutive cluster labels (largest cluster first), skipping the
    # noise points. The noise label is dropped leniently because a clustering
    # does not necessarily contain noise points, and the mapping is built
    # directly from the index so that it does not depend on how the pandas
    # version names the `value_counts` result.
    old_labels = (clusters['cluster'].value_counts(dropna=False)
                  .drop(index=-1, errors='ignore').index)
    cluster_map = collections.defaultdict(
        lambda: -1, zip(old_labels, range(len(old_labels))))
    clusters['cluster'] = clusters['cluster'].map(cluster_map)
    num_clusters = len(old_labels)

    # Reassign noise points to singleton clusters.
    noise_mask = clusters['cluster'] == -1
    num_noise = noise_mask.sum()
    clusters.loc[noise_mask, 'cluster'] = np.arange(
        num_clusters, num_clusters + num_noise)

    # Compute cluster evaluation measures.
    num_spectra = len(clusters)
    prop_clustered = ((num_spectra - num_noise) / num_spectra
                      if num_spectra > 0 else 0.0)

    clusters_ident = clusters.dropna(subset=['sequence'])
    clusters_ident_non_noise = (clusters[~noise_mask]
                                .dropna(subset=['sequence']))

    # The number of incorrectly clustered spectra is the number of PSMs that
    # differ from the majority PSM. Unidentified spectra are not considered.
    num_incorrect = sum(joblib.Parallel(n_jobs=-1)(
        joblib.delayed(_count_majority_label_mismatch)(clust['sequence'])
        for _, clust in clusters[~noise_mask].groupby('cluster')))
    # Without any identified clustered spectra nothing can be incorrect;
    # avoid dividing by zero (e.g. when everything is noise).
    num_ident_non_noise = len(clusters_ident_non_noise)
    prop_clustered_incorrect = (num_incorrect / num_ident_non_noise
                                if num_ident_non_noise > 0 else 0.0)

    # Homogeneity measures whether clusters contain only identical PSMs.
    # This is only evaluated on non-noise points, because the noise cluster
    # is highly non-homogeneous by definition.
    homogeneity = homogeneity_score(clusters_ident_non_noise['sequence'],
                                    clusters_ident_non_noise['cluster'])
    # Completeness measures whether identical PSMs are assigned to the same
    # cluster.
    # This is evaluated on all PSMs, including those clustered as noise.
    completeness = completeness_score(clusters_ident['sequence'],
                                      clusters_ident['cluster'])

    return (num_spectra - num_noise, num_noise,
            prop_clustered, prop_clustered_incorrect,
            homogeneity, completeness)

In [None]:
def run_cluster(eps, min_samples, cluster_filename, dist_filename,
                metadata_filename, sequences, min_cluster_size):
    """
    Cluster with the given hyperparameters and evaluate the result.

    Any existing file at `cluster_filename` is removed first, `config.eps`
    and `config.min_samples` are overwritten, and the clustering is run on
    the given distance and metadata files.
    NOTE(review): `cluster.cluster` is expected to write the cluster labels
    to `cluster_filename`, which is loaded afterwards - confirm.

    Parameters
    ----------
    eps : float
        Value assigned to `config.eps`.
    min_samples : int
        Value assigned to `config.min_samples`.
    cluster_filename : str
        NumPy file from which the cluster labels are loaded.
    dist_filename : str
        Pairwise distance file passed to `cluster.cluster`.
    metadata_filename : str
        Metadata file passed to `cluster.cluster`.
    sequences : pd.Series
        Peptide labels of the clustered spectra.
    min_cluster_size : int
        Minimum cluster size passed to `evaluate_clusters`.

    Returns
    -------
    tuple
        The evaluation measures returned by `evaluate_clusters`.
    """
    # Remove stale results from a previous clustering run.
    if os.path.isfile(cluster_filename):
        os.remove(cluster_filename)
    config.eps = eps
    config.min_samples = min_samples
    cluster.cluster(dist_filename, metadata_filename)
    assignments = pd.DataFrame({'sequence': sequences,
                                'cluster': np.load(cluster_filename)})
    return evaluate_clusters(assignments, min_cluster_size)

In [None]:
# Clustering on the validation set: only `eps` and `min_samples` remain free,
# in that order.
run_cluster_val = functools.partial(
    run_cluster,
    cluster_filename=cluster_filename_val,
    dist_filename=dist_filename_val,
    metadata_filename=metadata_filename_val,
    sequences=metadata_val['sequence'],
    min_cluster_size=min_cluster_size)


def cluster_optim(args):
    """
    Objective function for the clustering hyperparameter optimization.

    The proportions of clustered and incorrectly clustered spectra are
    appended to the global `props_clust` and `props_clust_incorrect` lists
    as a side effect.

    Parameters
    ----------
    args : sequence
        The hyperparameters (eps, min_samples) to evaluate on the validation
        set.

    Returns
    -------
    float
        The value to minimize: 1 if the proportion of incorrectly clustered
        spectra exceeds `max_prop_clust_incorrect`, otherwise the proportion
        of spectra that was not clustered.
    """
    (_num_clustered, _num_noise, prop_clust, prop_clust_incorrect,
     _homogeneity, _completeness) = run_cluster_val(*args)
    props_clust.append(prop_clust)
    props_clust_incorrect.append(prop_clust_incorrect)
    # Hyperparameters that produce too many errors get the worst score.
    too_many_errors = prop_clust_incorrect > max_prop_clust_incorrect
    return 1 if too_many_errors else 1 - prop_clust

## Clustering hyperparameter optimization

In [None]:
# Performance of each evaluated hyperparameter combination; filled as a side
# effect of `cluster_optim`, in evaluation order.
props_clust, props_clust_incorrect = [], []
# Search over eps in [0.01, margin / 3] and min_samples in [2, 5].
optim = skopt.gp_minimize(cluster_optim,
                          [skopt.space.Real(0.01, config.margin / 3, name='eps'),
                           skopt.space.Integer(2, 5, name='min_samples')],
                          n_jobs=-1)

In [None]:
# Remove final (potentially suboptimal) clustering.
# Guarded so that re-running this cell does not fail once the file is gone.
if os.path.isfile(cluster_filename_val):
    os.remove(cluster_filename_val)

In [None]:
# One row per evaluated hyperparameter combination, in evaluation order.
cluster_hyperparameter = pd.DataFrame(
    np.asarray(optim.x_iters)[:, :2], columns=['eps', 'min_samples']
).assign(prop_clustered=props_clust,
         prop_clustered_incorrect=props_clust_incorrect)

In [None]:
# Only hyperparameters whose proportion of incorrectly clustered spectra does
# not exceed the threshold are acceptable (same inclusive bound as applied by
# `cluster_optim`).
acceptable_hyperparam = cluster_hyperparameter[
    cluster_hyperparameter['prop_clustered_incorrect']
    <= max_prop_clust_incorrect]
if acceptable_hyperparam.empty:
    raise ValueError(
        'None of the evaluated hyperparameters resulted in at most '
        f'{max_prop_clust_incorrect:.2%} incorrectly clustered spectra')
# Among the acceptable hyperparameters, select those that cluster the most
# spectra.
best_hyperparam = acceptable_hyperparam.loc[
    acceptable_hyperparam['prop_clustered'].idxmax()]
print(f'Optimal clustering hyperparameters:\n'
      f'  - eps = {best_hyperparam["eps"]:.4f}\n'
      f'  - min_samples = {best_hyperparam["min_samples"]:.0f}\n'
      f'-> {best_hyperparam["prop_clustered"]:.2%} clustered, '
      f'{best_hyperparam["prop_clustered_incorrect"]:.2%} clustered incorrectly')

## Validate clustering hyperparameters

In [None]:
# Cluster on the test set with the optimal hyperparameters.
# `run_cluster` returns six evaluation measures; the proportions of clustered
# and incorrectly clustered spectra are the third and fourth.
# `min_samples` is stored as a float in the hyperparameter table, so it is
# converted back to an integer.
prop_clust_test, prop_clust_incorrect_test = run_cluster(
    best_hyperparam['eps'], int(best_hyperparam['min_samples']),
    cluster_filename_test, dist_filename_test, metadata_filename_test,
    metadata_test['sequence'], min_cluster_size)[2:4]

In [None]:
# Report the clustering performance on the test set.
print(f'Cluster the test dataset with optimal hyperparameters\n'
      f'-> {prop_clust_test:.2%} clustered, '
      f'{prop_clust_incorrect_test:.2%} clustered incorrectly')

In [None]:
# Store the hyperparameter search results and the test set performance in the
# current working directory so that the plot below can be recreated without
# rerunning the optimization.
joblib.dump(
    (cluster_hyperparameter, (prop_clust_test, prop_clust_incorrect_test)),
    'cluster_hyperparameter.joblib')

## Plot clustering performance

In [None]:
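# Uncomment to reload previously stored results instead of rerunning the
# hyperparameter optimization and test set clustering:
# cluster_hyperparameter, (prop_clust_test, prop_clust_incorrect_test) = \
#     joblib.load('cluster_hyperparameter.joblib')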

In [None]:
def get_pareto_frontier(arr):
    """
    Get the Pareto frontier of a set of two-dimensional points.

    The first column is minimized and the second column is maximized: a
    point is on the frontier if no other point has a first value that is at
    most as large together with a second value that is at least as large.

    Parameters
    ----------
    arr : np.ndarray
        Array of shape (n, 2).

    Returns
    -------
    np.ndarray
        The points on the Pareto frontier, sorted by increasing first (and
        thus also increasing second) column. Empty input is returned as is.
    """
    arr = np.asarray(arr)
    if arr.shape[0] == 0:
        return arr
    # Sort by the first column (ascending) and break ties by the second
    # column (descending), so that among points with an identical first value
    # the best one is always encountered first.
    arr_sorted = arr[np.lexsort((-arr[:, 1], arr[:, 0]))]
    # Iteratively add points to the Pareto frontier. Because of the sort
    # order, a point that improves the second column necessarily has a
    # strictly larger first column than the previous frontier point.
    pareto_idx = [0]
    for i in range(1, arr_sorted.shape[0]):
        if arr_sorted[i, 1] > arr_sorted[pareto_idx[-1], 1]:
            pareto_idx.append(i)
    return arr_sorted[pareto_idx]

In [None]:
# Figure dimensions in inches.
width = 7
height = width / 1.618    # golden ratio
fig, ax = plt.subplots(figsize=(width, height))

# Hyperparameter optimization: all evaluated hyperparameter combinations on
# the validation set (small dots) and their Pareto frontier (connected
# markers), i.e. the best trade-offs between few incorrectly clustered
# spectra and many clustered spectra.
clustering_pareto = get_pareto_frontier(np.column_stack(
    [cluster_hyperparameter['prop_clustered_incorrect'],
     cluster_hyperparameter['prop_clustered']]))
ax.plot(clustering_pareto[:, 0], clustering_pareto[:, 1], marker='o',
        markersize=7, label='validation')
ax.scatter(cluster_hyperparameter['prop_clustered_incorrect'],
           cluster_hyperparameter['prop_clustered'], marker='.', s=49)
# Maximum tolerated proportion of incorrectly clustered spectra.
ax.axvline(max_prop_clust_incorrect, c='darkgray', ls='--')

# Performance of optimal hyperparameters on test set.
ax.scatter(prop_clust_incorrect_test, prop_clust_test, marker='s', s=49,
           label='test', zorder=10)

ax.legend(loc='lower right')

# Hyperparameters with more than 5% incorrectly clustered spectra fall
# outside of the plotted range.
ax.set_xlim(0, 0.05)
ax.set_ylim(0, 1)

ax.set_xlabel('Incorrectly clustered spectra')
ax.set_ylabel('Clustered spectra')

# Both axes hold proportions in [0, 1]; show them as percentages.
ax.xaxis.set_major_formatter(mticker.PercentFormatter(1))
ax.yaxis.set_major_formatter(mticker.PercentFormatter(1))

sns.despine()

plt.savefig('cluster_hyperparameter.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Flush and close all logging handlers at the end of the notebook.
logging.shutdown()