In [2]:
import os
import sys
# Make sure all code is in the PATH.
# Two candidate locations for the project sources are registered: one relative
# to the notebook and one below the user's home directory.
_src_dirs = ["../src"]
# `HOME` is the only thing that can legitimately be missing here, so check for
# it explicitly instead of hiding every possible error behind a bare `except`.
if "HOME" in os.environ:
    _src_dirs.append(
        os.path.normpath(
            os.path.join(
                os.environ["HOME"], "Projects", "cosine_neutral_loss", "src"
            )
        )
    )
for _src_dir in _src_dirs:
    # Keep the cell idempotent: re-running it must not pile up duplicates.
    if _src_dir not in sys.path:
        sys.path.append(_src_dir)

In [3]:
import functools
import lzma
import pathlib
import re

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numba as nb
import numpy as np
import pandas as pd
import pyteomics.mgf
import seaborn as sns
import spectrum_utils.spectrum as sus
from matplotlib.colors import LogNorm
from matplotlib.gridspec import GridSpec
from rdkit import Chem, DataStructs, RDLogger
from tqdm.autonotebook import tqdm

import similarity


# Turn off RDKit logging for every channel matching 'rdApp.*'.
RDLogger.DisableLog('rdApp.*')
# Hook tqdm into pandas. NOTE(review): no pandas progress call is visible in
# this notebook, so this registration may be unused — confirm before removing.
tqdm.pandas()

  from tqdm.autonotebook import tqdm


In [4]:
# Plot styling.
# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and
# later releases dropped the old names, so pick whichever spelling the
# installed version provides.
plt.style.use(
    [
        style_name
        if style_name in plt.style.available
        else style_name.replace("seaborn-", "seaborn-v0_8-", 1)
        for style_name in ("seaborn-white", "seaborn-paper")
    ]
)
plt.rc("font", family="sans-serif")
sns.set_palette(["#9e0059", "#6da7de", "#dee000"])
sns.set_context("paper", font_scale=1.)

## Analysis settings

In [5]:
# Selection criteria for individual spectra and for spectrum pairs.
charges = (0, 1)    # precursor charges that are accepted when loading
min_n_peaks = 6    # spectra with fewer peaks are skipped
fragment_mz_tolerance = 0.1    # passed to every spectrum similarity call
min_mass_diff, max_mass_diff = 1, 200    # Da

In [8]:
# Profile spectra contain 0 intensity values, so a spectrum only counts as
# centroided when every intensity is strictly positive.
@nb.njit
def is_centroid(intensity_array):
    """Return whether all values in `intensity_array` are greater than zero."""
    return (intensity_array > 0).all()


# Assumes that the spectra are sorted by ascending precursor m/z.
@nb.njit
def generate_pairs(
    spectrum_indexes, masses, min_mass_diff, max_mass_diff
):
    """
    Yield the identifiers of all spectrum pairs within a mass difference window.

    For every position `i`, later positions `j` are paired with it as long as
    `masses[j] - masses[i]` is at least `min_mass_diff` and strictly below
    `max_mass_diff` (this relies on `masses` being sorted ascending).

    The output is a flat stream: for each pair, first `spectrum_indexes[i]`
    and then `spectrum_indexes[j]` is yielded, so consumers have to regroup
    the values two by two.

    Parameters
    ----------
    spectrum_indexes
        Identifiers to emit, aligned position by position with `masses`.
    masses
        Ascending values whose pairwise differences are compared.
    min_mass_diff
        Inclusive lower bound on the difference of an emitted pair.
    max_mass_diff
        Exclusive upper bound on the difference of an emitted pair.
    """
    for i in range(len(spectrum_indexes)):
        j = i + 1
        # Skip the partners that are still too close to spectrum i.
        while (
            j < len(spectrum_indexes) and
            masses[j] - masses[i] < min_mass_diff
        ):
            j += 1
        # Emit partners until the difference reaches the upper bound.
        while (
            j < len(spectrum_indexes) and
            masses[j] - masses[i] < max_mass_diff
        ):
            yield spectrum_indexes[i]
            yield spectrum_indexes[j]
            j += 1
            
            
# Unbounded cache: the number of distinct SMILES is limited by the number of
# library spectra, whereas the default 128 entries would be evicted constantly
# when scoring randomly sampled pairs.
@functools.lru_cache(maxsize=None)
def _smiles_to_mol(smiles):
    """
    Convert a SMILES string to an RDKit molecule.

    Parameters
    ----------
    smiles
        The SMILES string (must be hashable because results are cached).

    Returns
    -------
    The value returned by `Chem.MolFromSmiles`, or None if RDKit rejects the
    argument type (for example a missing value instead of a string).
    """
    try:
        return Chem.MolFromSmiles(smiles)
    except TypeError:
        # RDKit reports bad argument types through Boost.Python's
        # ArgumentError, which derives from TypeError. The bare name
        # `ArgumentError` was never defined in this notebook, so the previous
        # handler raised a NameError instead of returning None.
        return None
    

@functools.lru_cache(maxsize=None)
def _smiles_to_fingerprint(smiles):
    """
    Return the RDKit fingerprint for `smiles`, or None without a molecule.

    Cached per SMILES so that a molecule occurring in many spectrum pairs is
    fingerprinted only once instead of once per pair.
    """
    mol = _smiles_to_mol(smiles)
    return None if mol is None else Chem.RDKFingerprint(mol)


@functools.lru_cache
def tanimoto(smiles1, smiles2):
    """
    Compute the Tanimoto similarity between two molecules given as SMILES.

    Both molecules are described by `Chem.RDKFingerprint` fingerprints, which
    are compared with `DataStructs.TanimotoSimilarity`.

    Parameters
    ----------
    smiles1, smiles2
        SMILES strings of the two molecules.

    Returns
    -------
    The Tanimoto similarity, or NaN if either SMILES could not be converted
    to a molecule.
    """
    fp1 = _smiles_to_fingerprint(smiles1)
    fp2 = _smiles_to_fingerprint(smiles2)
    if fp1 is None or fp2 is None:
        return np.nan
    return DataStructs.TanimotoSimilarity(fp1, fp2)

## Data IO

In [None]:
# Compress the plain-text MGF once; later cells read the .xz version.
filename_text = ("../data/external/ALL_GNPS_NO_PROPOGATED.mgf")
filename = ("../data/external/ALL_GNPS_NO_PROPOGATED.mgf.xz")
if not pathlib.Path(filename).exists():
    # Write to a temporary sibling file and rename it only after everything
    # was written. Otherwise an interrupted run leaves a truncated archive
    # that the existence check above would treat as complete next time.
    filename_partial = pathlib.Path(filename + ".part")
    try:
        with open(filename_text, "rt") as file:
            with lzma.open(filename_partial, "wt") as out:
                for line in tqdm(file):
                    out.write(line)
        filename_partial.replace(filename)
    finally:
        # No-op after a successful rename; cleans up after a failure.
        filename_partial.unlink(missing_ok=True)
    print("exported compressed file")
else:
    print("compressed file already available")

0it [00:00, ?it/s]

In [None]:
# Read all spectra from the MGF.
# ALL_GNPS_NO_PROPOGATED (retrieved on 2022-05-12) downloaded from
# https://gnps-external.ucsd.edu/gnpslibrary
#
# A spectrum is only kept if it passes every quality check below, evaluated
# in this order:
#   1. Not a propagated spectrum (LIBRARYQUALITY of at most 3).
#   2. Precursor charge is one of the accepted `charges`.
#   3. Valid (strictly positive) precursor m/z.
#   4. At least `min_n_peaks` peaks.
#   5. Positive ion mode.
#   6. [M+H]+ adduct (the name ends with " M+H").
#   7. Centroid data (no zero intensity values).
spectra = []
filename = ("../data/external/ALL_GNPS_NO_PROPOGATED.mgf.xz")
with lzma.open(filename, "rt") as xz_in, pyteomics.mgf.MGF(xz_in) as f_in:
    for spectrum_dict in tqdm(f_in):
        params = spectrum_dict["params"]
        if int(params["libraryquality"]) > 3:
            continue
        charge = int(params["charge"][0])
        if charge not in charges:
            continue
        precursor_mz = float(params["pepmass"][0])
        if not (precursor_mz > 0):
            continue
        if len(spectrum_dict["m/z array"]) < min_n_peaks:
            continue
        if params["ionmode"] != "Positive":
            continue
        if not params["name"].rstrip().endswith(" M+H"):
            continue
        if not is_centroid(spectrum_dict["intensity array"]):
            continue

        spec = sus.MsmsSpectrum(
            params["spectrumid"],
            precursor_mz,
            # Re-assign charge 0 to 1.
            max(charge, 1),
            spectrum_dict["m/z array"],
            spectrum_dict["intensity array"],
        )
        # Carry along the metadata that later cells read back.
        spec.library = params["organism"]
        spec.smiles = params["smiles"]
        spec.remove_precursor_peak(0.1, "Da")
        spectra.append(spec)

In [None]:
# Extract the metadata (library identifier and precursor charge and m/z).
# The columns are built directly from the spectra. The previous version
# collected them in module-level lists, one of which was called `charges` and
# thereby replaced the accepted-charges setting used by the loading cell, so
# re-running that cell afterwards filtered on different values.
metadata = pd.DataFrame(
    {
        "id": [spectrum.identifier for spectrum in spectra],
        "library": [spectrum.library for spectrum in spectra],
        "smiles": [spectrum.smiles for spectrum in spectra],
        "charge": [spectrum.precursor_charge for spectrum in spectra],
        "mz": [spectrum.precursor_mz for spectrum in spectra],
    }
)

## Compute spectrum-spectrum similarities

In [None]:
# Collect, per precursor charge, the index pairs of all spectra whose
# precursor m/z difference lies within the configured window.
pairs_per_charge = []
lowest_charge = metadata["charge"].min()
highest_charge = metadata["charge"].max()
for charge in np.arange(lowest_charge, highest_charge + 1):
    # `generate_pairs` needs ascending precursor m/z; `reset_index` keeps the
    # original row labels in the "index" column so the pairs refer back to
    # `metadata` / `spectra`.
    metadata_charge = metadata[metadata["charge"] == charge].copy()
    metadata_charge = metadata_charge.sort_values("mz").reset_index()
    flat_indexes = np.fromiter(
        generate_pairs(
            metadata_charge["index"].values,
            metadata_charge["mz"].values,
            min_mass_diff,
            max_mass_diff,
        ),
        np.int32,
    )
    # The generator emits both members of a pair one after the other.
    pairs_per_charge.append(flat_indexes.reshape((-1, 2)))
pairs = np.vstack(pairs_per_charge)

In [None]:
# Randomly subsample the pairs so it remains computationally tractable.
max_n_pairs = 10_000_000
np.random.seed(1)
# Never request more pairs than exist: sampling without replacement raises a
# ValueError in that case. With at least `max_n_pairs` pairs available this
# draws exactly the same sample as before.
n_sampled_pairs = min(pairs.shape[0], max_n_pairs)
pairs = pairs[
    np.random.choice(pairs.shape[0], n_sampled_pairs, replace=False)
]

In [None]:
# Score every sampled spectrum pair with the three spectrum similarities and
# with the Tanimoto index of the two structures.
score_columns = [
    "cosine",
    "cosine_explained",
    "modified_cosine",
    "modified_cosine_explained",
    "neutral_loss",
    "neutral_loss_explained",
    "tanimoto",
]
scores = []
for first, second in tqdm(pairs):
    spectrum1, spectrum2 = spectra[first], spectra[second]
    cos = similarity.cosine(spectrum1, spectrum2, fragment_mz_tolerance)
    mod_cos = similarity.modified_cosine(
        spectrum1, spectrum2, fragment_mz_tolerance
    )
    nl = similarity.neutral_loss(spectrum1, spectrum2, fragment_mz_tolerance)
    tan = tanimoto(
        metadata.at[first, "smiles"], metadata.at[second, "smiles"]
    )
    # One row per pair, in the order of `score_columns`.
    scores.append(
        (cos[0], cos[1], mod_cos[0], mod_cos[1], nl[0], nl[1], tan)
    )


def _pair_metadata(member, column):
    """Return `column` of `metadata` for one member (0 or 1) of every pair."""
    return metadata.loc[pairs[:, member], column].values


similarities = pd.DataFrame(
    {
        "pair1": pairs[:, 0],
        "pair2": pairs[:, 1],
        "id1": _pair_metadata(0, "id"),
        "id2": _pair_metadata(1, "id"),
        "smiles1": _pair_metadata(0, "smiles"),
        "smiles2": _pair_metadata(1, "smiles"),
        "charge1": _pair_metadata(0, "charge"),
        "charge2": _pair_metadata(1, "charge"),
        "mz1": _pair_metadata(0, "mz"),
        "mz2": _pair_metadata(1, "mz"),
    }
)
similarities[score_columns] = scores
similarities.to_parquet("gnps_libraries.parquet")

## Results plotting

In [None]:
# Bin the Tanimoto index into five fixed intervals of width 0.2.
# Explicit edges are required here: passing only the number of bins makes
# `pd.cut` split the *observed* range (minimum to maximum) into equal parts,
# which only matches the interval labels if the data happens to span exactly
# 0 to 1. `include_lowest` keeps a Tanimoto index of exactly 0 in the first
# interval.
tanimoto_bin_edges = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
similarities["tanimoto_interval"] = pd.cut(
    similarities["tanimoto"],
    bins=tanimoto_bin_edges,
    labels=["0.0–0.2", "0.2–0.4", "0.4–0.6", "0.6–0.8", "0.8–1.0"],
    include_lowest=True,
)
# Long format (one row per pair and similarity measure) for the violin plot.
similarities_tanimoto = pd.melt(
    similarities,
    id_vars="tanimoto_interval",
    value_vars=["cosine", "neutral_loss", "modified_cosine"],
)

In [None]:
# Summarise how often the neutral loss similarity exceeds the other two
# measures; scores are rounded to 5 decimals before they are compared.
n_pairs = len(similarities)
neutral_loss_rounded = similarities["neutral_loss"].round(5)
frac_above_cosine = (
    neutral_loss_rounded > similarities["cosine"].round(5)
).sum() / n_pairs
frac_above_modified_cosine = (
    neutral_loss_rounded > similarities["modified_cosine"].round(5)
).sum() / n_pairs

print(f"Number of spectrum pairs: {n_pairs:,}")
print(
    f"Spectrum pairs where neutral loss outperforms cosine: "
    f"{frac_above_cosine:.1%}"
)
print(
    f"Spectrum pairs where neutral loss outperforms modified cosine: "
    f"{frac_above_modified_cosine:.1%}"
)

In [None]:
# Layout of one panel: "2" is the heatmap, "1" the marginal density on top of
# it and "3" the marginal density to its right.
mosaic = """
11111.
222223
222223
222223
222223
222223
"""

bins = 100
tick_locators = mticker.FixedLocator(np.arange(0, bins + 1, bins / 4))
tick_labels = np.asarray([f"{a:.2f}" for a in np.arange(0, 1.01, 0.25)])


def _new_heatmap_locator():
    """
    Return a new locator with the same tick positions as `tick_locators`.

    Matplotlib locators must not be shared between axes, so every heatmap
    axis gets its own instance.
    """
    return mticker.FixedLocator(np.arange(0, bins + 1, bins / 4))


def _plot_similarity_row(
    fig,
    subplot_specs,
    data,
    label_pairs,
    cbar_rect,
    *,
    x_title=None,
    percent_x=False,
    diagonal=False,
):
    """
    Draw one figure row of 2D score histograms with marginal densities.

    Parameters
    ----------
    fig
        The figure to draw in.
    subplot_specs
        One grid position per panel.
    data
        Data frame holding all columns named in `label_pairs`.
    label_pairs
        One `(x column, y column)` pair per panel.
    cbar_rect
        `[left, bottom, width, height]` of the colorbar axes, which is filled
        by the last panel of the row.
    x_title
        Heatmap x-axis label; by default it is derived from the x column name.
    percent_x
        Whether to format the heatmap x ticks with a `PercentFormatter`.
    diagonal
        Whether to draw a dashed line between the lower left and upper right
        corners of the heatmap.
    """
    # Create all panel axes before the colorbar axes.
    panels = [
        fig.add_subfigure(spec).subplot_mosaic(mosaic)
        for spec in subplot_specs
    ]
    cbar_ax = fig.add_axes(cbar_rect)
    last_panel = len(panels) - 1

    for i, (axes, (xlabel, ylabel)) in enumerate(zip(panels, label_pairs)):
        ax_heat, ax_top, ax_right = axes["2"], axes["1"], axes["3"]

        # Plot heatmaps.
        hist, _, _ = np.histogram2d(
            data[xlabel],
            data[ylabel],
            bins=bins,
            range=[[0, 1], [0, 1]],
        )
        hist /= len(data)
        heatmap = sns.heatmap(
            # Rotate so that the y values increase from bottom to top.
            np.rot90(hist),
            vmin=0.0,
            vmax=0.001,
            cmap="viridis",
            cbar=i == last_panel,
            cbar_kws={"format": mticker.StrMethodFormatter("{x:.3%}")},
            cbar_ax=cbar_ax if i == last_panel else None,
            square=True,
            xticklabels=False,
            yticklabels=False,
            ax=ax_heat,
            norm=LogNorm(vmax=0.001),
        )
        ax_heat.yaxis.set_major_locator(_new_heatmap_locator())
        # The heatmap rows run from top to bottom, hence the reversed labels.
        ax_heat.set_yticklabels(tick_labels[::-1])
        ax_heat.xaxis.set_major_locator(_new_heatmap_locator())
        ax_heat.set_xticklabels(tick_labels)
        if percent_x:
            ax_heat.xaxis.set_major_formatter(mticker.PercentFormatter())
        for spine in heatmap.spines.values():
            spine.set_visible(True)
        if x_title is None:
            ax_heat.set_xlabel(xlabel.replace("_", " ").capitalize())
        else:
            ax_heat.set_xlabel(x_title)
        ax_heat.set_ylabel(ylabel.replace("_", " ").capitalize())

        if diagonal:
            ax_heat.plot(
                [0, bins], [bins, 0], color="black", linestyle="dashed"
            )

        sns.despine(ax=ax_heat)

        # Plot density plots.
        sns.kdeplot(
            data=data,
            x=xlabel,
            clip=(0, 1),
            legend=True,
            color="black",
            fill=True,
            ax=ax_top,
        )
        ax_top.set_xlim(0, 1)
        ax_top.xaxis.set_ticklabels([])
        ax_top.set_yticks([])
        sns.despine(ax=ax_top, left=True)
        sns.kdeplot(
            data=data,
            y=ylabel,
            clip=(0, 1),
            legend=True,
            color="black",
            fill=True,
            ax=ax_right,
        )
        ax_right.set_ylim(0, 1)
        ax_right.yaxis.set_ticklabels([])
        ax_right.set_xticks([])
        sns.despine(ax=ax_right, bottom=True)
        for marginal_ax in (ax_top, ax_right):
            marginal_ax.set_xlabel("")
            marginal_ax.set_ylabel("")

    cbar_ax.set_ylabel("Proportion of pairs")
    cbar_ax.yaxis.set_label_position("left")
    cbar_ax.spines["outline"].set(visible=True, lw=.8, edgecolor="black")


with sns.plotting_context("paper", font_scale=1.6):
    fig = plt.figure(constrained_layout=True, figsize=(7.2 * 2, 7.2 / 1.618 * 3))
    gs = GridSpec(3, 3, figure=fig)

    # Top panel: Compare different similarities.
    _plot_similarity_row(
        fig,
        [gs[0, 0], gs[0, 1], gs[0, 2]],
        similarities,
        [
            ("cosine", "modified_cosine"),
            ("neutral_loss", "cosine"),
            ("neutral_loss", "modified_cosine"),
        ],
        [-0.04, 0.75, 0.02, 0.15],
        diagonal=True,
    )

    # Middle panel: Compare similarities vs explained intensity.
    _plot_similarity_row(
        fig,
        [gs[1, 0], gs[1, 1], gs[1, 2]],
        similarities,
        [
            ("cosine_explained", "cosine"),
            ("neutral_loss_explained", "neutral_loss"),
            ("modified_cosine_explained", "modified_cosine"),
        ],
        [-0.04, 0.45, 0.02, 0.15],
        x_title="Explained intensity",
        percent_x=True,
    )

    # Bottom panel: Evaluate similarities in terms of the Tanimoto index.
    ax = fig.add_subplot(gs[2, :])

    sns.violinplot(
        data=similarities_tanimoto,
        x="tanimoto_interval",
        y="value",
        hue="variable",
        hue_order=["cosine", "neutral_loss", "modified_cosine"],
        cut=0,
        scale="width",
        scale_hue=False,
        ax=ax,
    )
    ax.set_xlabel("Tanimoto index")
    ax.set_ylabel("Spectrum similarity")
    # Turn the column names in the legend into readable labels.
    for label in ax.legend().get_texts():
        label.set_text(label.get_text().replace("_", " ").capitalize())
    sns.move_legend(
        ax,
        "lower center",
        bbox_to_anchor=(.5, 1),
        ncol=3,
        title=None,
        frameon=False,
    )

    sns.despine(ax=ax)

    # Subplot labels.
    for y, label in zip([1, 2/3, 0.35], "abc"):
        fig.text(
            -0.05, y, label, fontdict=dict(fontsize="xx-large", weight="bold")
        )

    # Save figure.
    plt.savefig("gnps_libraries.png", dpi=300, bbox_inches="tight")
    plt.show()
    plt.close()

In [None]:
# Repeat the similarity versus explained intensity comparison for pairs of
# structurally similar molecules only (Tanimoto index above 0.9).
similarities_filtered = similarities[similarities["tanimoto"] > 0.9]
print(
    f"Number of spectrum pairs with Tanimoto > 0.9: "
    f"{len(similarities_filtered):,}"
)

with sns.plotting_context("paper", font_scale=1.6):
    fig = plt.figure(constrained_layout=True, figsize=(7.2 * 2, 7.2 / 1.618))
    gs = GridSpec(1, 3, figure=fig)

    axes_left = fig.add_subfigure(gs[0]).subplot_mosaic(mosaic)
    axes_middle = fig.add_subfigure(gs[1]).subplot_mosaic(mosaic)
    axes_right = fig.add_subfigure(gs[2]).subplot_mosaic(mosaic)
    cbar_ax = fig.add_axes([-0.08, 0.2, 0.02, 0.5])

    labels = [
        ("cosine_explained", "cosine"),
        ("neutral_loss_explained", "neutral_loss"),
        ("modified_cosine_explained", "modified_cosine"),
    ]

    for i, (axes, (xlabel, ylabel)) in enumerate(
        zip([axes_left, axes_middle, axes_right], labels)
    ):
        # "2" is the heatmap, "1" and "3" are the top and right marginals.
        ax_heat, ax_top, ax_right = axes["2"], axes["1"], axes["3"]

        # Plot heatmaps.
        hist, _, _ = np.histogram2d(
            similarities_filtered[xlabel],
            similarities_filtered[ylabel],
            bins=bins,
            range=[[0, 1], [0, 1]],
        )
        hist /= len(similarities_filtered)
        # Leave empty bins blank instead of colouring them as zero.
        hist[hist == 0.0] = np.nan
        heatmap = sns.heatmap(
            # Rotate so that the y values increase from bottom to top.
            np.rot90(hist),
            vmin=0.0,
            vmax=0.001,
            cmap="viridis",
            cbar=i == 2,
            cbar_kws={"format": mticker.StrMethodFormatter("{x:.2%}")},
            cbar_ax=cbar_ax if i == 2 else None,
            square=True,
            xticklabels=False,
            yticklabels=False,
            ax=ax_heat,
        )
        # Matplotlib locators must not be shared between axes, so build a new
        # one (same positions as `tick_locators`) for every axis.
        ax_heat.yaxis.set_major_locator(
            mticker.FixedLocator(np.arange(0, bins + 1, bins / 4))
        )
        # The heatmap rows run from top to bottom, hence the reversed labels.
        ax_heat.set_yticklabels(tick_labels[::-1])
        ax_heat.xaxis.set_major_locator(
            mticker.FixedLocator(np.arange(0, bins + 1, bins / 4))
        )
        ax_heat.set_xticklabels(tick_labels)
        ax_heat.xaxis.set_major_formatter(mticker.PercentFormatter())
        for spine in heatmap.spines.values():
            spine.set_visible(True)
        ax_heat.set_xlabel("Explained intensity")
        ax_heat.set_ylabel(ylabel.replace("_", " ").capitalize())

        sns.despine(ax=ax_heat)

        # Plot density plots.
        sns.kdeplot(
            data=similarities_filtered,
            x=xlabel,
            clip=(0, 1),
            legend=True,
            color="black",
            fill=True,
            ax=ax_top,
        )
        ax_top.set_xlim(0, 1)
        ax_top.xaxis.set_ticklabels([])
        ax_top.set_yticks([])
        sns.despine(ax=ax_top, left=True)
        sns.kdeplot(
            data=similarities_filtered,
            y=ylabel,
            clip=(0, 1),
            legend=True,
            color="black",
            fill=True,
            ax=ax_right,
        )
        ax_right.set_ylim(0, 1)
        ax_right.yaxis.set_ticklabels([])
        ax_right.set_xticks([])
        sns.despine(ax=ax_right, bottom=True)
        for ax in (ax_top, ax_right):
            ax.set_xlabel("")
            ax.set_ylabel("")

    cbar_ax.set_ylabel("Proportion of pairs")
    cbar_ax.yaxis.set_label_position("left")
    cbar_ax.spines["outline"].set(visible=True, lw=.8, edgecolor="black")

    # Save figure.
    plt.savefig("gnps_libraries_tanimoto.png", dpi=300, bbox_inches="tight")
    plt.show()
    plt.close()