In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import auc, roc_curve

In [None]:
# Plot styling.
# Matplotlib 3.6 renamed the bundled seaborn style sheets to
# 'seaborn-v0_8-<name>' and later releases dropped the old spelling, so pick
# whichever name the installed version actually provides.
style_sheets = []
for style_base in ('white', 'paper'):
    for style_name in (f'seaborn-v0_8-{style_base}', f'seaborn-{style_base}'):
        if style_name in plt.style.available:
            style_sheets.append(style_name)
            break
plt.style.use(style_sheets)
plt.rc('font', family='serif')
sns.set_palette('Set1')
sns.set_context('paper', font_scale=1.3)    # Single-column figure.

In [None]:
# Load the pair table and recode the pair labels as 0/1 integers.
# Labels outside the mapping become NaN and are removed by dropna(), together
# with every other row that has a missing value in any column.
dot_embed = (
    pd.read_parquet('dot_embed.parquet')
    .assign(pair_type=lambda pairs: pairs['pair_type'].map(
        {'Positive': 1, 'Negative': 0}))
    .dropna()
    .astype({'pair_type': np.uint8})
)

### AU(C)ROC plots

In [None]:
def concentrate_fpr(fpr, alpha):
    """Apply the exponential CROC magnification to false positive rates.

    Evaluates ``(1 - exp(-alpha * fpr)) / (1 - exp(-alpha))``. The mapping
    sends 0 to 0 and 1 to 1; for positive ``alpha`` it stretches the region
    close to 0.

    Parameters
    ----------
    fpr : float, sequence of float, or array-like
        False positive rate(s) to transform.
    alpha : float
        Magnification factor. ``alpha == 0`` is treated as the limiting case
        of the formula, i.e. the identity, instead of producing 0 / 0.

    Returns
    -------
    Transformed rate(s) with the same shape as ``fpr``.
    """
    if isinstance(fpr, (list, tuple)):
        # ``-alpha * [..]`` would be list repetition, not arithmetic.
        fpr = np.asarray(fpr, dtype=float)
    if alpha == 0:
        # The closed form degenerates to 0 / 0; its limit is fpr itself.
        return fpr * 1.0
    # expm1(x) = exp(x) - 1 without cancellation for small |x|. The minus
    # signs of numerator and denominator cancel, giving the same ratio.
    return np.expm1(-alpha * fpr) / np.expm1(-alpha)

In [None]:
# Value passed as the `alpha` argument of concentrate_fpr for every CROC
# curve in this notebook; it is also printed in the x-axis labels of the
# per-charge and per-m/z figures.
alpha = 14

In [None]:
# Concentrated ROC (CROC) curves for the three similarity scores.
width = 7
fig, ax = plt.subplots(figsize=(width, width))

pair_labels = dot_embed['pair_type']
# Turn the embedding distance into a score: divide by the largest distance
# and subtract from 1, so smaller distances give larger scores.
embed_score = 1 - dot_embed['gleams_dist'] / dot_embed['gleams_dist'].max()

# Plain ROC coordinates; the fpr_* / tpr_* arrays are reused by the ROC cell.
fpr_high_res, tpr_high_res, _ = roc_curve(pair_labels,
                                          dot_embed['dot_high_res'])
fpr_low_res, tpr_low_res, _ = roc_curve(pair_labels,
                                        dot_embed['dot_low_res'])
fpr_embed, tpr_embed, _ = roc_curve(pair_labels, embed_score)

# Same curves with the false positive rate axis concentrated.
croc_fpr_high_res = concentrate_fpr(fpr_high_res, alpha)
croc_fpr_low_res = concentrate_fpr(fpr_low_res, alpha)
croc_fpr_embed = concentrate_fpr(fpr_embed, alpha)

# Plotting order is kept, so each curve receives the same palette colour.
for curve_name, croc_fpr, tpr in (
        ('Dot product high res', croc_fpr_high_res, tpr_high_res),
        ('Dot product low res', croc_fpr_low_res, tpr_low_res),
        ('Embedding', croc_fpr_embed, tpr_embed)):
    ax.plot(croc_fpr, tpr,
            label=f'{curve_name} (AUCROC = {auc(croc_fpr, tpr):.2%})')

# Diagonal of the plain ROC plot, drawn in the concentrated coordinates.
diagonal = np.arange(0, 1.01, 0.01)
ax.plot(concentrate_fpr(diagonal, alpha), diagonal,
        color='black', linestyle='--')

ax.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05],
       xlabel='Concentrated false positive rate',
       ylabel='True positive rate')

# Legend sits below the axes.
ax.legend(loc='lower center', bbox_to_anchor=(0.5, -0.3))

sns.despine()

plt.savefig('croc.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Plain ROC curves; relies on the fpr_* / tpr_* arrays computed in the
# CROC cell, which therefore has to run first.
width = 7
fig, ax = plt.subplots(figsize=(width, width))

roc_curves = (
    ('Dot product high res', fpr_high_res, tpr_high_res),
    ('Dot product low res', fpr_low_res, tpr_low_res),
    ('Embedding', fpr_embed, tpr_embed),
)
for curve_name, fpr, tpr in roc_curves:
    ax.plot(fpr, tpr,
            label=f'{curve_name} (AUROC = {auc(fpr, tpr):.2%})')

# Diagonal reference line.
diagonal = np.arange(0, 1.01, 0.01)
ax.plot(diagonal, diagonal, color='black', linestyle='--')

ax.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05],
       xlabel='False positive rate',
       ylabel='True positive rate')

# Legend sits below the axes.
ax.legend(loc='lower center', bbox_to_anchor=(0.5, -0.3))

sns.despine()

# Saving this figure is currently disabled:
# plt.savefig('roc.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# One CROC panel per precursor charge; the last panel pools charge 5 and up.
fig, axes = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(14, 14))

for ax, charge in zip(axes.ravel(), [2, 3, 4, 5]):
    if charge < 5:
        in_charge_bin = dot_embed['charge'] == charge
    else:
        in_charge_bin = dot_embed['charge'] >= charge
    dot_embed_bin = dot_embed[in_charge_bin]

    # Distances are rescaled by the maximum within this panel's subset and
    # flipped, so smaller distances give larger scores.
    gleams_score = (1 - dot_embed_bin['gleams_dist']
                    / dot_embed_bin['gleams_dist'].max())

    for method, score in (('Dot product', dot_embed_bin['dot_high_res']),
                          ('GLEAMS', gleams_score)):
        fpr, tpr, _ = roc_curve(dot_embed_bin['pair_type'], score)
        croc_fpr = concentrate_fpr(fpr, alpha)
        ax.plot(croc_fpr, tpr,
                label=f'{method} (AUC = {auc(croc_fpr, tpr):.2%})')

    # Diagonal of the plain ROC plot, drawn in the concentrated coordinates.
    diagonal = np.arange(0, 1.01, 0.01)
    ax.plot(concentrate_fpr(diagonal, alpha), diagonal,
            color='black', linestyle='--')

    ax.set_title(f'Precursor charge {"=" if charge < 5 else "≥"} {charge} '
                 f'({len(dot_embed_bin):,d} pairs)')
    ax.legend(loc='lower right')

# Axes are shared, so only the left column and the bottom row get labels.
for left_ax in axes[:, 0]:
    left_ax.set_ylabel('True positive rate')
for bottom_ax in axes[1]:
    bottom_ax.set_xlabel(f'Concentrated false positive rate (alpha = {alpha})')

plt.tight_layout()
sns.despine()

plt.savefig('croc_charge.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# CROC curves stratified by precursor m/z quantile, one panel per interval.
num_bins = 4
# Quantile-based m/z intervals, stored on the frame as the 'bin' column.
dot_embed['bin'] = pd.qcut(dot_embed['mz'], num_bins)

# Two panels per row; round the row count up so an odd num_bins still gets
# a panel for every interval instead of silently dropping the last one.
num_cols = 2
num_rows = -(-num_bins // num_cols)
width = 7 * num_cols
height = 7 * num_rows
# squeeze=False keeps `axes` two-dimensional even when there is one row,
# so the row/column indexing below works for any num_bins.
fig, axes = plt.subplots(num_rows, num_cols, sharex=True, sharey=True,
                         figsize=(width, height), squeeze=False)

# Passing `observed` explicitly keeps the grouping this cell has always used
# and avoids the FutureWarning that pandas 2.1+ emits for categorical keys
# when the argument is omitted.
mz_groups = dot_embed.groupby('bin', observed=False)

for ax, (label, dot_embed_bin) in zip(axes.ravel(), mz_groups):
    fpr_dot, tpr_dot, _ = roc_curve(dot_embed_bin['pair_type'],
                                    dot_embed_bin['dot_high_res'])
    croc_fpr_dot = concentrate_fpr(fpr_dot, alpha)
    ax.plot(croc_fpr_dot, tpr_dot,
            label=f'Dot product (AUC = {auc(croc_fpr_dot, tpr_dot):.2%})')

    # Distances are rescaled by the maximum within this interval and
    # flipped, so smaller distances give larger scores.
    fpr_gleams, tpr_gleams, _ = roc_curve(
        dot_embed_bin['pair_type'],
        1 - dot_embed_bin['gleams_dist'] / dot_embed_bin['gleams_dist'].max())
    croc_fpr_gleams = concentrate_fpr(fpr_gleams, alpha)
    ax.plot(croc_fpr_gleams, tpr_gleams,
            label=f'GLEAMS (AUC = {auc(croc_fpr_gleams, tpr_gleams):.2%})')

    # Diagonal of the plain ROC plot, drawn in the concentrated coordinates.
    diagonal = np.arange(0, 1.01, 0.01)
    ax.plot(concentrate_fpr(diagonal, alpha), diagonal,
            color='black', linestyle='--')

    ax.set_title(f'Precursor $m$/$z$ interval = {label}')
    ax.legend(loc='lower right')

# Hide panels that have no interval assigned (only when num_bins is odd).
for unused_ax in axes.ravel()[num_bins:]:
    unused_ax.set_visible(False)

for ax in axes[:, 0]:
    ax.set_ylabel('True positive rate')
# Label the lowest visible panel of each column. With shared x axes only the
# bottom row shows tick labels by default, so switch them on explicitly in
# case that panel is not in the last row.
for column_axes in axes.T:
    visible_axes = [ax for ax in column_axes if ax.get_visible()]
    if visible_axes:
        visible_axes[-1].set_xlabel(
            f'Concentrated false positive rate (alpha = {alpha})')
        visible_axes[-1].tick_params(labelbottom=True)

plt.tight_layout()
sns.despine()

plt.savefig('croc_mz.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()