# Visualize spectra using embedding umap
This notebook creates a 2D visualization of the case study spectrum set. UMAP is used to reduce the multidimensional embeddings to 2 dimensions. Annotations are added to make an interactive visualization possible.

# Download model, spectra and annotations

In [3]:
import requests
import os
from tqdm import tqdm

def download_file(link, file_name):
    """Download ``link`` to ``file_name`` while showing a progress bar.

    Nothing is downloaded (and no network request is made) when ``file_name``
    already exists.

    Parameters
    ----------
    link:
        URL of the file to download.
    file_name:
        Path the downloaded content is written to.

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status. No file is created in
        that case.
    """
    # Check first, so an already downloaded file does not cost a network request.
    if os.path.exists(file_name):
        print(f"The file {file_name} already exists, the file won't be downloaded")
        return

    # Download to a temporary name and rename at the end. An interrupted download
    # would otherwise leave a truncated file that later runs treat as complete.
    partial_file_name = f"{file_name}.part"
    with requests.get(link, stream=True) as response:
        # Fail loudly instead of storing an error page under file_name.
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))

        with open(partial_file_name, "wb") as f, tqdm(desc="Downloading file", total=total_size, unit='B', unit_scale=True, unit_divisor=1024,) as bar:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    f.write(chunk)
                    bar.update(len(chunk))  # Update progress bar by the chunk size
    os.replace(partial_file_name, file_name)
    
# Path of the trained MS2DeepScore model. Note that this file is read from a
# local relative path; it is not fetched by the downloads below.
model_file_name = "../../../../data/pytorch/new_corinna_included/trained_models/positive_mode_precursor_mz_ionmode_10000_layers_500_embedding_2024_11_26_11_30_47/ms2deepscore_model.pt"

# Local file names for the downloaded case study data.
case_study_spectra_file_name = "case_study_spectra.mgf"
ms2query_annotations = "ms2query_annotations.csv"

# (url, local file name) pairs; files that already exist locally are skipped by download_file.
for link, target_file_name in (
    ("https://zenodo.org/records/14290920/files/settings.json?download=1", "ms2deepscore_settings.json"),
    ("https://zenodo.org/records/14535374/files/cleaned_spectra_pos_neg_with_numbering.mgf?download=1", case_study_spectra_file_name),
    ("https://zenodo.org/records/14535374/files/ms2query_annotations.csv?download=1", ms2query_annotations),
):
    download_file(link, target_file_name)



The file ms2deepscore_settings.json already exists, the file won't be downloaded
The file case_study_spectra.mgf already exists, the file won't be downloaded
The file ms2query_annotations.csv already exists, the file won't be downloaded


### Load ms2deepscore model

In [4]:
from ms2deepscore.models import load_model
# Load the trained model from the path configured in the download cell.
model = load_model(model_file_name)

In [5]:
from ms2deepscore import MS2DeepScore
# Wrap the loaded model; this wrapper is used further down to compute the
# embeddings and the pairwise MS2DeepScore of the example spectra.
ms2ds_model = MS2DeepScore(model)

### load spectra
The spectra are preprocessed in the notebook pre_processing_spectra.ipynb

In [6]:
from matchms.importing import load_from_mgf
from tqdm import tqdm

# Read all spectra from the mgf file into a list, with a progress counter.
spectrum_reader = load_from_mgf(case_study_spectra_file_name)
spectra = [spectrum for spectrum in tqdm(spectrum_reader)]

2909it [00:00, 3226.33it/s]


### Create embeddings
MS2Deepscore embeddings are a 500 dimensional representation of a spectrum. Similar embeddings should correspond to similar molecules.

In [7]:
# Compute the embeddings for all spectra; the result is a 2D array (see the shape check in the next cell).
ms2ds_embeddings = ms2ds_model.get_embedding_array(spectra)

2909it [00:15, 186.71it/s]


In [8]:
# Sanity check: expected to be (number of spectra, embedding dimensions)
ms2ds_embeddings.shape

(2909, 500)

### Fit Umap
UMAP is a method for reducing multidimensional data to 2 dimensions. This is done in a way that aims to keep the distances between embeddings close to the truth. However, since we are reducing from 500 dimensions, some distortion always takes place. Reducing to 2 dimensions allows us to visualize the embeddings.

The code below tries to learn how to best transform 500 dimensions to 2 dimensions.

In [9]:
import umap

# Configure the UMAP reducer and fit it on the MS2DeepScore embeddings.
reducer = umap.UMAP(random_state=42,  # fixed seed so the layout can be reproduced; any number works
                    n_neighbors=50 ,  # key parameter: controls how global or local the layout is (the original note lists 30 and 50 as candidate values)
                    min_dist=0.2 , # controls how closely dots may be packed; larger values move them apart (the original note lists 0.1 and 0.2 as candidate values)
                    )
reducer.fit(ms2ds_embeddings)

  warn(


### get 2d coordinates from embeddings
Here we use the fitted UMAP model to transform each embedding into an x and a y coordinate.

In [10]:
# Project the same embeddings the reducer was fitted on to their 2D coordinates.
embedding_umap = reducer.transform(ms2ds_embeddings)

In [27]:
# Inspect the coordinates: one row per spectrum, two columns (used as x and y below).
embedding_umap

array([[ 6.399148  ,  5.249808  ],
       [ 5.9670463 ,  0.29087648],
       [ 4.300889  ,  3.7659795 ],
       ...,
       [18.725214  , -1.9337422 ],
       [10.805029  , -9.491422  ],
       [18.697523  , -1.9080055 ]], dtype=float32)

### Convert to pandas df
Use the `query_spectrum_nr` metadata field of the spectra as the index of the pandas dataframe.

In [11]:
import pandas as pd

# Label every coordinate pair with the query_spectrum_nr of its spectrum.
spectrum_numbers = [spectrum.get("query_spectrum_nr") for spectrum in spectra]
embedding_umap_df = pd.DataFrame(embedding_umap, index=spectrum_numbers, columns=["x", "y"])

# Also expose the identifier as a regular column, so it can be used as a merge key.
indexes = embedding_umap_df.index
embedding_umap_df["query_spectrum_nr"] = indexes

### Visualize embeddings
We use plotly to visualize the embeddings of the spectra

In [12]:
import plotly.io as pio
# Use the 'iframe' renderer as the default for every fig.show() call in this notebook.
pio.renderers.default = 'iframe'

In [13]:
import plotly.express as px
import numpy as np

# Create scatter plot using Plotly Express: one dot per spectrum, no annotations yet.
# NOTE(review): size_max presumably has no effect here, because no `size` column is passed - confirm.
fig = px.scatter(embedding_umap_df,
    x="x",
    y="y",
    size_max=50,
    opacity=0.5,
    title='UMAP projection MS2DeepScore embeddings',
    width=800,
    height=800,
)

fig.show()

### Add annotations to visualization
Visualizing just the embeddings is not very informative. Below we add the ms2query annotations, but any information can be added in principle. 


In [14]:
# The download cell stores the CSV file name in `ms2query_annotations`, and this
# cell replaces it with the loaded table. Remember the file name under its own
# variable first, so that re-running this cell does not pass the DataFrame to read_csv.
if isinstance(ms2query_annotations, str):
    ms2query_annotations_file_name = ms2query_annotations
ms2query_annotations = pd.read_csv(ms2query_annotations_file_name)

### Merge annotations with umap table

In [15]:
# Join coordinates and annotations on the spectrum identifier. The outer join keeps
# rows that occur in only one of the two tables; their missing columns are filled with NaN.
merged_data = pd.merge(embedding_umap_df, ms2query_annotations, on="query_spectrum_nr", how= "outer")

In [16]:
import plotly.express as px
import numpy as np

# Colour per ionization mode (light blue = positive, light red = negative).
ionmode_colors = {"positive": "#C6DBEF", "negative": "#FCBBA1"}

# Columns shown (True) or hidden (False) in the hover box of each dot.
hover_columns = {
    "x": False,
    "y": False,
    "precursor_mz_difference": False,
    "ionmode": True,
    "ms2query_model_prediction": False,
    "smiles": True,
    "analog_compound_name": True,
    "precursor_mz_query_spectrum": True,
    "rtinminutes": True,
    "query_spectrum_nr": True,
}

# The rows are passed in reversed order (presumably to change which dots are drawn on top).
reversed_rows = merged_data.iloc[::-1]

fig = px.scatter(
    reversed_rows,
    x="x",
    y="y",
    color="ionmode",
    color_discrete_map=ionmode_colors,
    size_max=50,
    opacity=1.0,
    title='UMAP projection MS2DeepScore embeddings',
    hover_data=hover_columns,
    width=800,
    height=800,
)

# White plot area and white figure background.
fig.update_layout(plot_bgcolor='white', paper_bgcolor='white')

# Gray grid lines and gray zero line on both axes.
for update_axes in (fig.update_xaxes, fig.update_yaxes):
    update_axes(gridcolor='gray', zerolinecolor='gray')

# Use a fixed marker size of 5 for every dot.
fig.update_traces(marker=dict(size=5))

fig.show()

In [None]:
# Save the ionmode-coloured figure as a standalone interactive HTML file.
fig.write_html("embedding_visualization_red_blue.html")

### Create zoomed in versions

In [None]:
# Zoom both axes to the 0-18 window and place a tick every 2 units (adjust as needed).
fig.update_layout(
    xaxis=dict(range=[0, 18],
              dtick=2),  # x-axis: visible range and tick spacing
    yaxis=dict(range=[0, 18],
              dtick=2)   # y-axis: visible range and tick spacing
)

In [None]:
# Percentage of UMAP dots that fall outside the zoomed-in window.
# Keep these limits in sync with the axis range used for the zoomed figure.
zoom_min, zoom_max = 0, 18
is_outside_range = (
    (merged_data['x'] > zoom_max) | (merged_data['x'] < zoom_min)
    | (merged_data['y'] > zoom_max) | (merged_data['y'] < zoom_min)
)
nr_of_spectra_outside_range = int(is_outside_range.sum())
percentage_outside_range = nr_of_spectra_outside_range / merged_data.shape[0] * 100
# The printed value is a percentage, so label it as such (it was labelled "nr of spectra").
print(f"percentage of spectra outside range of {zoom_min}-{zoom_max}: {percentage_outside_range:.1f} %")


In [None]:
# Zoom further in on a small window of the projection (adjust the ranges as needed).
fig.update_layout(
    xaxis_range=[7.99, 10.01],
    yaxis_range=[3.99, 6.01],
)
# Enlarge the markers to size 15 for the zoomed-in view.
fig.update_traces(marker_size=15)

fig.show()

### Find interesting examples with reliable MS2Query annotations in both ionmodes

In [None]:
# Keep only the annotations with an MS2Query model prediction above 0.8
high_ms2query_score = merged_data[(merged_data['ms2query_model_prediction'] > 0.8)]
# ... and a precursor m/z difference below 0.1 Da. The column is converted to float once.
# NOTE(review): if the difference can be negative, an abs() is needed here - confirm.
precursor_mz_difference = high_ms2query_score['precursor_mz_difference'].astype(float)
high_ms2query_exact_matches = high_ms2query_score[precursor_mz_difference < 0.1]
# Per inchikey: True when it occurs in at least one positive and one negative mode spectrum
has_pos_and_neg = high_ms2query_exact_matches.groupby('inchikey')['ionmode'].apply(lambda x: {'positive', 'negative'}.issubset(set(x)))
# Keep all rows that belong to an inchikey found in both ionization modes
result = high_ms2query_exact_matches[high_ms2query_exact_matches['inchikey'].isin(has_pos_and_neg[has_pos_and_neg].index)]


In [None]:
# Show every column and the full cell content. These options are set globally,
# so they also affect every dataframe displayed after this cell.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', None)

# Sorting by inchikey puts the spectra annotated with the same compound next to each other.
result.sort_values(by='inchikey')


# plot example spectra
Based on these results we pick the first pos and neg example

In [None]:
from matchms.plotting.spectrum_plots import plot_spectra_mirror
import matplotlib.pyplot as plt
selected_spectra =[spectrum for spectrum in spectra if spectrum.get("query_spectrum_nr") in ["neg_1231", "pos_761", ]]
# Look the two spectra up by identifier. The order of `selected_spectra` follows the
# order of the spectra in the file, so positional indexing could swap top and bottom
# and thereby mismatch the ionmode colours.
spectra_by_nr = {spectrum.get("query_spectrum_nr"): spectrum for spectrum in selected_spectra}
negative_spectrum = spectra_by_nr["neg_1231"]
positive_spectrum = spectra_by_nr["pos_761"]
# Set compound name to empty string, since matchms by default prints the compound name.
# Note that this modifies the spectrum objects in `spectra` in place.
negative_spectrum.set("compound_name", "")
positive_spectrum.set("compound_name", "")
plt.figure(figsize=(5, 5))
# Same colours as the UMAP plot: "positive": "#C6DBEF", "negative": "#FCBBA1".
# The result is stored under its own name, so the Plotly figure in `fig` is not overwritten.
mirror_plot = plot_spectra_mirror(negative_spectrum, positive_spectrum, grid=False,
                                  max_mz=420,
                                  color_top = "#FCBBA1",
                                  color_bottom="#C6DBEF"
                                 )
plt.savefig("./spectrum_comparison_cholin_acid.svg", bbox_inches='tight')

In [None]:
# Print the metadata of both selected spectra. Their order follows the order of `spectra`.
print(selected_spectra[0].metadata)
print(selected_spectra[1].metadata)

#### Get (modified) cosine score and MS2Deepscore for example spectrum pair

In [None]:
import matchms.similarity as mssim
# NOTE(review): this cell selects "pos_1558", while the mirror plot cell above selects
# "pos_761" - confirm which positive spectrum is the intended example.
selected_spectra =[spectrum for spectrum in spectra if spectrum.get("query_spectrum_nr") in ["neg_1231", "pos_1558", ]]

# Compare the same spectrum pair with three similarity measures (fragment tolerance 0.1 for both cosine variants).
similarity_cosine = mssim.CosineGreedy(tolerance=0.1).pair(selected_spectra[0], selected_spectra[1])
similarity_modified_cosine = mssim.ModifiedCosine(tolerance=0.1).pair(selected_spectra[0], selected_spectra[1])
ms2deepscore_score = ms2ds_model.pair(selected_spectra[0], selected_spectra[1])
print(f"MS2Deepscore: {ms2deepscore_score:.3f}")
# The cosine results carry the score under the 'score' key.
print(f"Cosine score: {similarity_cosine['score']:.3f}")
print(f"Modified cosine score: {similarity_modified_cosine['score']:.3f}")

# Plot with superclass overlay
MS2Query adds classifier annotations. These can be used to visualize the spectra. 

In [26]:
import plotly.express as px
import numpy as np

# Create scatter plot using Plotly Express, coloured by the compound superclass from the
# MS2Query annotations. This cell previously coloured by "ionmode", which matched neither
# the "unknown" entry of the colour map nor the name of the exported html file.
fig = px.scatter(merged_data,
    x="x",
    y="y",
    color="cf_superclass",
    # Spectra without a class annotation are drawn in light grey
    color_discrete_map={"unknown": 'lightgrey'},
    size_max=50,
    opacity=0.5,
    title='UMAP projection MS2DeepScore embeddings',
    # Columns shown (True) or hidden (False) in the hover box of each dot
    hover_data={"x": False,
                "y": False,
                "precursor_mz_difference": True,
                "ionmode": True,
                "ms2query_model_prediction": True,
                "smiles": True,
                "analog_compound_name": True,
                "rtinminutes": True},
    width=1000,
    height=800,
)

fig.show()


In [None]:
# Save the current figure as a standalone interactive HTML file.
fig.write_html("embedding_visualization_compound_classes.html")

### Visualize molecular structures
Below you can create an interactive plot that draws the molecular structure directly on top of the plot.
The molecular structures are drawn based on the MS2Query annotations.

In [18]:
# Keep only the spectra that have a structure annotation. The explicit copy makes the
# result an independent dataframe, so the column assignment in the next cell does not
# operate on a slice of the original (which triggers pandas' SettingWithCopyWarning).
merged_data = merged_data[merged_data["smiles"] != "unknown"].copy()

In [19]:
# Convert the retention time to a numeric column. `assign` returns a new dataframe with
# the column replaced, so the float dtype is actually applied; assigning through
# `.loc[:, ...]` writes the values into the existing column and can keep its old dtype.
merged_data = merged_data.assign(rtinminutes=merged_data["rtinminutes"].astype(float))

In [24]:
# Base port for the Dash app; the plotting cell below increases it by one on every run.
port = 8800

In [25]:
import pandas as pd
import plotly.express as px

import molplotly

# Base scatter plot: colour encodes the ionmode, marker symbol the compound superclass.
fig = px.scatter(merged_data,
    x="x",
    y="y",
    color="ionmode",
    symbol="cf_superclass",
    width=1200,
    height=800,
)

# add molecules to the plotly graph - returns a Dash app
# color_col and symbol_col are given the same columns as used for the figure above.
app = molplotly.add_molecules(fig=fig,
                              df=merged_data,
                              smiles_col='smiles',
                              color_col="ionmode",
                              caption_cols=['rtinminutes', "ionmode", "precursor_mz_query_spectrum", "analog_compound_name"],
                              show_coords=False,
                              symbol_col="cf_superclass"
                            )

# Every run of this cell uses the next port number: re-creating the plot needs a port
# that has not been used before, and incrementing here makes that hard to forget.
port += 1

# Start the Dash app with mode='external' on the chosen port.
# NOTE(review): the recorded output of this cell reports that JupyterDash is deprecated.
app.run_server(mode='external', port=port, height=1000)

Dash app running on http://127.0.0.1:8801/



JupyterDash is deprecated, use Dash instead.
See https://dash.plotly.com/dash-in-jupyter for more details.

