# Demo of Classification

This notebook demonstrates the code necessary to train different classifiers and to use them to predict whether new MS/MS spectra are of "interest" or "other".

## Setup

In [None]:
import os
from collections import OrderedDict

import pandas as pd

from AnnoMe.Classification import (
    generate_ms2deepscore_embeddings,
    add_mzmine_metainfos,
    add_sirius_fingerprints,
    add_sirius_canopus,
    add_sirius_predictions,
    add_mzmine_quant,
    remove_invalid_CEs,
    show_dataset_overview,
    generate_embedding_plots,
    train_and_classify,
    generate_prediction_overview,
    generate_ml_metrics_overview,
    set_random_seeds,
)

# Seed the random number generators with a fixed value before any sampling or
# training happens, so repeated runs of this notebook are comparable.
# NOTE(review): which generators are seeded is decided inside
# AnnoMe.Classification.set_random_seeds - confirm there.
set_random_seeds(42)

## Parameters

In [None]:
# fmt: off
# parameters
# 

# Location of the MS2DeepScore model file used to embed the spectra.
model_file_name = "./publicDBs/models/ms2deepscore_model.pt"

# Directory that receives all results (created by the pipeline cell if missing).
# The trailing slash is part of the configured value.
output_dir = "./publicDBs/output/PrenylatedCompounds_publicDBs/"

# Datasets used for classification, registered in processing order.
#
# Every entry is a dict with the same six fields:
#   name                 - label of the dataset
#   type                 - role of the dataset; the values used here are
#                          "train - other", "train - relevant", "inference"
#                          and "validation - relevant"
#   file                 - path to the MGF file holding the MS/MS spectra
#   fragmentation_method - presumably the spectrum metadata field that holds the
#                          fragmentation method - TODO confirm against AnnoMe
#   colour               - hex colour assigned to the dataset
#   randomly_sample      - presumably the maximum number of spectra drawn at
#                          random from the file - TODO confirm against AnnoMe
max_instances = 10000

datasets = OrderedDict()

datasets["MB Riken - other"] = {
    "name": "MB Riken - other",
    "type": "train - other",
    "file": "./publicDBs/libraries/derived_prenylated_compounds/MassBank_RIKEN___StructureOfInterest__NonMatchingSmiles.mgf",
    "fragmentation_method": "fragmentation_method",
    "colour": "#D41F11",
    "randomly_sample": max_instances,
}

datasets["MB Riken - relevant"] = {
    "name": "MB Riken - relevant",
    "type": "train - relevant",
    "file": "./publicDBs/libraries/derived_prenylated_compounds/MassBank_RIKEN___StructureOfInterest__MatchingSmiles.mgf",
    "fragmentation_method": "fragmentation_method",
    "colour": "#F37A00",
    "randomly_sample": max_instances,
}

datasets["MassSpecGym - other"] = {
    "name": "MassSpecGym - other",
    "type": "train - other",
    "file": "./publicDBs/libraries/derived_prenylated_compounds/MassSpecGym___StructureOfInterest__NonMatchingSmiles.mgf",
    "fragmentation_method": "fragmentation_method",
    "colour": "#017192",
    "randomly_sample": max_instances,
}

datasets["MassSpecGym - relevant"] = {
    "name": "MassSpecGym - relevant",
    "type": "train - relevant",
    "file": "./publicDBs/libraries/derived_prenylated_compounds/MassSpecGym___StructureOfInterest__MatchingSmiles.mgf",
    "fragmentation_method": "fragmentation_method",
    "colour": "#10ADC2",
    "randomly_sample": max_instances,
}

datasets["MONA - other"] = {
    "name": "MONA - other",
    "type": "train - other",
    "file": "./publicDBs/libraries/derived_prenylated_compounds/MONA___StructureOfInterest__NonMatchingSmiles.mgf",
    "fragmentation_method": "fragmentation_method",
    "colour": "#8B8989",
    "randomly_sample": max_instances,
}

datasets["Wine-DB - other"] = {
    "name": "Wine-DB - other",
    "type": "inference",
    "file": "./publicDBs/libraries/derived_prenylated_compounds/WINE-DB-ORBITRAP___StructureOfInterest__NonMatchingSmiles.mgf",
    "fragmentation_method": "fragmentation_method",
    "colour": "#80BF02",
    "randomly_sample": max_instances,
}

# NOTE(review): this is the only entry whose key differs from its "name".
datasets["PrenFlav stds - relevant"] = {
    "name": "Prenylated-Flavonoids standards - relevant",
    "type": "validation - relevant",
    "file": "./../../../../results/Samp1_neg__sirius.mgf",
    "fragmentation_method": "fragmentation_method",
    "colour": "#FF00FF",
    "randomly_sample": max_instances,
}

# Metadata taken over from the MS/MS spectra into the output.
# Each key maps to a list of spectrum metadata field names.
# NOTE(review): presumably the key is the output column and the list holds the
# candidate field names looked up in each spectrum - confirm against
# generate_ms2deepscore_embeddings.
data_to_add = OrderedDict(
    {
        "name": ["feature_id", "name", "title", "compound_name"],
        "formula": ["formula"],
        "smiles": ["smiles"],
        "adduct": ["adduct", "precursor_type"],
        "ionMode": ["ionmode"],
        "RTINSECONDS": ["rtinseconds", "retention_time"],
        "precursor_mz": ["pepmass", "precursor_mz"],
        "fragmentation_method": ["fragmentation_method", "fragmentation_mode"],
        "CE": ["collision_energy"],
    }
)

# Named subsets handed to train_and_classify. Each value is a predicate that
# receives the data (`x`) and compares its columns to select the members of
# the subset.
training_subsets = {
    # Examples of narrower subsets (fragmentation method + polarity + collision energy):
    #"HCD_negative_CE20": lambda x: (x["fragmentation_method"] == "HCD") & (x["ionMode"] == "negative") & (x["CE"] == "20.0"),
    #"HCD_negative_CE30": lambda x: (x["fragmentation_method"] == "HCD") & (x["ionMode"] == "negative") & (x["CE"] == "30.0"),
    #"HCD_negative_CE40": lambda x: (x["fragmentation_method"] == "HCD") & (x["ionMode"] == "negative") & (x["CE"] == "40.0"),

    # One subset per ion mode.
    "negative": lambda x: x["ionMode"] == "negative",
    "positive": lambda x: x["ionMode"] == "positive",
}

# Derived from `datasets` (do not edit by hand): dataset key -> hex colour.
colors = {key: spec["colour"] for key, spec in datasets.items()}
# fmt: on

## Execute pipeline

In [None]:
# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# The embedded and annotated spectra are cached as a pickle so that re-running
# this cell skips the import/embedding/annotation steps. Delete the file to
# force a re-computation, e.g. after changing `datasets` or `data_to_add`.
pickle_file = os.path.join(output_dir, "df_embeddings.pkl")
if os.path.exists(pickle_file):
    # Re-use the cached dataframe from a previous run.
    # NOTE: only load pickle files produced by this notebook; unpickling
    # files from untrusted sources can execute arbitrary code.
    df = pd.read_pickle(pickle_file)

else:
    # Import the spectra and process MS2DeepScore embeddings
    df = generate_ms2deepscore_embeddings(model_file_name, datasets, data_to_add)

    # add associated metadata
    df = add_mzmine_metainfos(datasets, df)
    # df = add_sirius_fingerprints(datasets, df)
    df = add_sirius_canopus(datasets, df)
    df = add_sirius_predictions(datasets, df)
    df = add_mzmine_quant(datasets, df)

    # remove invalid collision energies
    df = remove_invalid_CEs(df)

    # show overview and plot (only done when the dataframe is freshly computed)
    show_dataset_overview(df)
    generate_embedding_plots(df, output_dir, colors)

    # export dataframe for re-use; same path as the cache check above
    df.to_pickle(pickle_file)

# train and predict new datasets
# NOTE(review): df_validation is returned but not used further in this cell.
df_train, df_validation, df_inference, df_metrics = train_and_classify(df, subsets=training_subsets)
generate_prediction_overview(df, df_train, output_dir, "training")
generate_prediction_overview(df, df_inference, output_dir, "inference")

# Generate an overview of the machine learning metrics
generate_ml_metrics_overview(df_metrics, output_dir)