# Demo of Classification

This notebook demonstrates the code needed to train different classifiers and to use them to predict whether new MS/MS spectra are of "interest" or "other".

## Setup

In [None]:
# Standard library
import os
from collections import OrderedDict

# AnnoMe pipeline steps used by the cells below
from AnnoMe.Prediction import (
    add_mzmine_metainfos,
    add_mzmine_quant,
    add_sirius_canopus,
    add_sirius_fingerprints,
    add_sirius_predictions,
    generate_embedding_plots,
    generate_ms2deepscore_embeddings,
    generate_prediction_overview,
    remove_invalid_CEs,
    set_random_seeds,
    show_dataset_overview,
    train_and_classify,
)

# Seed the random number generators once, before any data is processed, so
# repeated runs are comparable (assumes set_random_seeds covers every RNG the
# pipeline relies on - TODO confirm in AnnoMe.Prediction).
set_random_seeds(42)

## Parameters

In [None]:
# fmt: off
# parameters

# Root folder; every input and output path below is built relative to it
base_folder = "../../../"

# Trained MS2DeepScore model file used to embed the MS/MS spectra
model_file_name = f"{base_folder}/models/ms2deepscore_model.pt"

# Folder that receives the pipeline's plots and tables
output_dir = f"{base_folder}/output_PrenylatedCompounds_BOKUDBs/"

# Input datasets keyed by a short label. Insertion order is preserved
# (OrderedDict). In this configuration each entry's "type" is one of
# "interesting", "other" or "inference"; "colour" is its plotting colour.
datasets = OrderedDict({
    "query CID pos": {
        "name": "prenylated flavonoids",
        "type": "interesting",
        "file": f"{base_folder}/../results/CID_pos__sirius.mgf",
        "fragmentation_method": "fragmentation_method",
        "colour": "#D41F11",
        "canopus_file": f"{base_folder}/../results/CID_pos__sirius/canopus_formula_summary.tsv",
    },
    "query CID neg": {
        "name": "prenylated flavonoids",
        "type": "interesting",
        "file": f"{base_folder}/../results/CID_neg__sirius.mgf",
        "fragmentation_method": "fragmentation_method",
        "colour": "#F37A00",
        "canopus_file": f"{base_folder}/../results/CID_neg__sirius/canopus_formula_summary.tsv",
    },
    "query HCD pos": {
        "name": "prenylated flavonoids",
        "type": "interesting",
        "file": f"{base_folder}/../results/HCD_pos__sirius.mgf",
        "fragmentation_method": "fragmentation_method",
        "colour": "#10ADC2",
        "canopus_file": f"{base_folder}/../results/HCD_pos__sirius/canopus_formula_summary.tsv",
    },
    "query HCD neg": {
        "name": "prenylated flavonoids",
        "type": "interesting",
        "file": f"{base_folder}/../results/HCD_neg__sirius.mgf",
        "fragmentation_method": "fragmentation_method",
        "colour": "#017192",
        "canopus_file": f"{base_folder}/../results/HCD_neg__sirius/canopus_formula_summary.tsv",
    },
    "MB BOKU": {
        "name": "MassBank BOKU",
        "type": "other",
        "file": f"{base_folder}/data/BOKU_iBAM.mgf",
        "fragmentation_method": "fragmentation_method",
        "colour": "#8B8989",
        "fingerprintFile": "::SIRIUS",
    },
    "Samp1_neg": {
        "name": "Samp1_neg",
        "type": "inference",
        "file": f"{base_folder}/../results/Samp1_neg__sirius.mgf",
        "fragmentation_method": "fragmentation_method",
        "colour": "#80BF02",
        "fingerprintFile": f"{base_folder}/../results/Samp1_neg__sirius_fingerprints.json",
        "canopus_file": f"{base_folder}/../results/Samp1_neg__sirius/canopus_formula_summary.tsv",
        "quant_file": f"{base_folder}/../results/Samp1_neg__full_feature_table.csv",
        "sirius_file": f"{base_folder}/../results/Samp1_neg__sirius/structure_identifications_top-15.tsv",
        "mzmine_meta_table": f"{base_folder}/../results/Samp1_neg__full_feature_table.csv",
    },
})

# Metadata copied from the MS/MS spectra into the output table:
# output column -> list of candidate spectrum keys
# (presumably tried in the listed order - TODO confirm in AnnoMe)
data_to_add = OrderedDict({
    "name": ["feature_id", "name", "title", "compound_name"],
    "formula": ["formula"],
    "smiles": ["smiles"],
    "adduct": ["adduct", "precursor_type"],
    "ionMode": ["ionmode"],
    "RTINSECONDS": ["rtinseconds", "retention_time"],
    "precursor_mz": ["pepmass", "precursor_mz"],
    "fragmentation_method": ["fragmentation_method", "fragmentation_mode"],
    "CE": ["collision_energy"],
})

# Training subsets: subset name -> predicate over one spectrum record `x`
# (presumably a row of the combined table - TODO confirm). `&` is used rather
# than `and` so the predicates also work element-wise on pandas objects.
# Uncomment any of the alternatives below to add further subsets.
training_subsets = {
    ## "all"          : lambda x: True,
    ## "CE30"         : lambda x: x["CE"] == "30.0",
    ## "CE50"         : lambda x: x["CE"] == "50.0",
    ## "CE70"         : lambda x: x["CE"] == "70.0",
    ## "positive"     : lambda x: x["ionMode"] == "positive",
    ## "negative"     : lambda x: x["ionMode"] == "negative",
    ## "CID"          : lambda x: x["fragmentation_method"] == "CID",
    ## "HCD"          : lambda x: x["fragmentation_method"] == "HCD",
    ## "positive_CE30": lambda x: (x["ionMode"] == "positive") & (x["CE"] == "30.0"),
    ## "positive_CE50": lambda x: (x["ionMode"] == "positive") & (x["CE"] == "50.0"),
    ## "positive_CE70": lambda x: (x["ionMode"] == "positive") & (x["CE"] == "70.0"),
    ## "negative_CE30": lambda x: (x["ionMode"] == "negative") & (x["CE"] == "30.0"),
    ## "negative_CE50": lambda x: (x["ionMode"] == "negative") & (x["CE"] == "50.0"),
    ## "negative_CE70": lambda x: (x["ionMode"] == "negative") & (x["CE"] == "70.0"),
    ## "CID_CE30"     : lambda x: (x["fragmentation_method"] == "CID") & (x["CE"] == "30.0"),
    ## "CID_CE50"     : lambda x: (x["fragmentation_method"] == "CID") & (x["CE"] == "50.0"),
    ## "CID_CE70"     : lambda x: (x["fragmentation_method"] == "CID") & (x["CE"] == "70.0"),
    ## "HCD_CE30"     : lambda x: (x["fragmentation_method"] == "HCD") & (x["CE"] == "30.0"),
    ## "HCD_CE50"     : lambda x: (x["fragmentation_method"] == "HCD") & (x["CE"] == "50.0"),
    ## "HCD_CE70"     : lambda x: (x["fragmentation_method"] == "HCD") & (x["CE"] == "70.0"),
    ## "CID_positive" : lambda x: (x["fragmentation_method"] == "CID") & (x["ionMode"] == "positive"),
    ## "CID_negative" : lambda x: (x["fragmentation_method"] == "CID") & (x["ionMode"] == "negative"),
    ## "HCD_positive" : lambda x: (x["fragmentation_method"] == "HCD") & (x["ionMode"] == "positive"),
    ## "HCD_negative" : lambda x: (x["fragmentation_method"] == "HCD") & (x["ionMode"] == "negative"),
    ## "CID_positive_CE30" : lambda x: (x["fragmentation_method"] == "CID") & (x["ionMode"] == "positive") & (x["CE"] == "30.0"),
    ## "CID_positive_CE50" : lambda x: (x["fragmentation_method"] == "CID") & (x["ionMode"] == "positive") & (x["CE"] == "50.0"),
    ## "CID_positive_CE70" : lambda x: (x["fragmentation_method"] == "CID") & (x["ionMode"] == "positive") & (x["CE"] == "70.0"),
    ## "CID_negative_CE30" : lambda x: (x["fragmentation_method"] == "CID") & (x["ionMode"] == "negative") & (x["CE"] == "30.0"),
    ## "CID_negative_CE50" : lambda x: (x["fragmentation_method"] == "CID") & (x["ionMode"] == "negative") & (x["CE"] == "50.0"),
    ## "CID_negative_CE70" : lambda x: (x["fragmentation_method"] == "CID") & (x["ionMode"] == "negative") & (x["CE"] == "70.0"),
    ## "HCD_positive_CE30" : lambda x: (x["fragmentation_method"] == "HCD") & (x["ionMode"] == "positive") & (x["CE"] == "30.0"),
    ## "HCD_positive_CE50" : lambda x: (x["fragmentation_method"] == "HCD") & (x["ionMode"] == "positive") & (x["CE"] == "50.0"),
    ## "HCD_positive_CE70" : lambda x: (x["fragmentation_method"] == "HCD") & (x["ionMode"] == "positive") & (x["CE"] == "70.0"),
    "HCD_negative_CE20": lambda x: (x["fragmentation_method"] == "HCD") & (x["ionMode"] == "negative") & (x["CE"] == "20.0"),
    "HCD_negative_CE30": lambda x: (x["fragmentation_method"] == "HCD") & (x["ionMode"] == "negative") & (x["CE"] == "30.0"),
    "HCD_negative_CE40": lambda x: (x["fragmentation_method"] == "HCD") & (x["ionMode"] == "negative") & (x["CE"] == "40.0"),
    ## "HCD_negative_CE50" : lambda x: (x["fragmentation_method"] == "HCD") & (x["ionMode"] == "negative") & (x["CE"] == "50.0"),
    ## "HCD_negative_CE70" : lambda x: (x["fragmentation_method"] == "HCD") & (x["ionMode"] == "negative") & (x["CE"] == "70.0"),
}


# derived from `datasets` - do not edit by hand
colors = {label: spec["colour"] for label, spec in datasets.items()}
# fmt: on

## Execute pipeline

In [None]:
# Make sure the output folder exists before any step writes into it
os.makedirs(output_dir, exist_ok=True)

# Import the spectra of every dataset and embed them with the MS2DeepScore
# model; `data_to_add` selects the spectrum metadata carried into the table
df = generate_ms2deepscore_embeddings(model_file_name, datasets, data_to_add)

# Attach per-dataset side information (MZmine, SIRIUS/CANOPUS, quantification)
df = add_mzmine_metainfos(datasets, df)
# df = add_sirius_fingerprints(datasets, df)  # disabled in this demo
df = add_sirius_canopus(datasets, df)
df = add_sirius_predictions(datasets, df)
df = add_mzmine_quant(datasets, df)

# Remove spectra with invalid collision energies
df = remove_invalid_CEs(df)

# Show an overview and plot the embeddings.
# The colour map is rebuilt here from the *current* `datasets` rather than
# reusing `colors` from the parameters cell: if `datasets` is overridden after
# that cell ran (e.g. by a papermill-injected parameters cell), the old
# `colors` would no longer match the dataset labels.
plot_colors = {label: spec["colour"] for label, spec in datasets.items()}
show_dataset_overview(df)
generate_embedding_plots(df, output_dir, plot_colors)

# Train classifiers on the configured subsets and predict the new datasets
df_inference = train_and_classify(df, subsets=training_subsets)

# Summarise the predictions into the output folder
generate_prediction_overview(df, df_inference, output_dir)