# Demo of Classification

This notebook demonstrates the code needed to train different classifiers and use them to classify new MS/MS spectra as either "of interest" or "other".

## Setup

In [None]:
from AnnoMe.Classification import (
    generate_ms2deepscore_embeddings,
    add_mzmine_metainfos,
    add_sirius_fingerprints,
    add_sirius_canopus,
    add_sirius_predictions,
    add_mzmine_quant,
    remove_invalid_CEs,
    show_dataset_overview,
    generate_embedding_plots,
    train_and_classify,
    generate_prediction_overview,
    generate_ml_metrics_overview,
    set_random_seeds,
)

import pandas as pd

from IPython.display import display

from collections import OrderedDict
import os

# Use a fixed seed (42) so that repeated runs of this notebook are comparable.
# NOTE(review): which random number generators are seeded is defined by
# AnnoMe.Classification.set_random_seeds — confirm there.
set_random_seeds(42)

## Parameters

In [None]:
# fmt: off
# parameters
#
# User-editable settings for this notebook. Every path below is derived from
# `base_folder`, so relocating the project only requires changing that value.

# Root folder that the input and output paths are relative to
base_folder = "../../../"

# Trained MS2DeepScore model used to embed the MS/MS spectra
model_file_name = f"{base_folder}/models/ms2deepscore_model.pt"

# Folder that receives the cached embeddings and all per-subset results
output_dir = f"{base_folder}/output_PrenylatedCompounds_BOKUDBs/"

# Input MS/MS files (MGF). "type" tags each file as training material
# ("train - relevant" / "train - other") or as spectra to classify ("inference");
# "colour" is the colour used for that dataset name. Optional keys
# (canopus_file, fingerprintFile, quant_file, sirius_file, mzmine_meta_table)
# point to matching SIRIUS / mzmine outputs for the same spectra.
datasets = [
    # Reference MS/MS of prenylated flavonoids (CID and HCD, both polarities)
    dict(
        name="prenylated flavonoids - gt relevant",
        type="train - relevant",
        file=f"{base_folder}/../results/CID_pos__sirius.mgf",
        fragmentation_method="fragmentation_method",
        colour="#D41F11",
        canopus_file=f"{base_folder}/../results/CID_pos__sirius/canopus_formula_summary.tsv",
    ),
    dict(
        name="prenylated flavonoids - gt relevant",
        type="train - relevant",
        file=f"{base_folder}/../results/CID_neg__sirius.mgf",
        fragmentation_method="fragmentation_method",
        colour="#D41F11",
        canopus_file=f"{base_folder}/../results/CID_neg__sirius/canopus_formula_summary.tsv",
    ),
    dict(
        name="prenylated flavonoids - gt relevant",
        type="train - relevant",
        file=f"{base_folder}/../results/HCD_pos__sirius.mgf",
        fragmentation_method="fragmentation_method",
        colour="#D41F11",
        canopus_file=f"{base_folder}/../results/HCD_pos__sirius/canopus_formula_summary.tsv",
    ),
    dict(
        name="prenylated flavonoids - gt relevant",
        type="train - relevant",
        file=f"{base_folder}/../results/HCD_neg__sirius.mgf",
        fragmentation_method="fragmentation_method",
        colour="#D41F11",
        canopus_file=f"{base_folder}/../results/HCD_neg__sirius/canopus_formula_summary.tsv",
    ),

    # Wheat sample MS/MS, used as "other" training examples
    dict(
        name="wheat metabolites - other",
        type="train - other",
        file=f"{base_folder}/../results/Wheat_pos__sirius.mgf",
        fragmentation_method="fragmentation_method",
        colour="#017192",
    ),
    dict(
        name="wheat metabolites - other",
        type="train - other",
        file=f"{base_folder}/../results/Wheat_neg__sirius.mgf",
        fragmentation_method="fragmentation_method",
        colour="#017192",
    ),

    # BOKU MassBank spectra, split by whether the structure matches
    dict(
        name="MassBank BOKU - gt relevant",
        type="train - relevant",
        file=f"{base_folder}/data/derived/BOKU_iBAM_MB___StructureOfInterest__MatchingSmiles.mgf",
        fragmentation_method="fragmentation_method",
        colour="#80BF02",
        fingerprintFile="::SIRIUS",
    ),
    dict(
        name="MassBank BOKU - gt other",
        type="train - other",
        file=f"{base_folder}/data/derived/BOKU_iBAM_MB___StructureOfInterest__NonMatchingSmiles.mgf",
        fragmentation_method="fragmentation_method",
        colour="#80BF02",
        fingerprintFile="::SIRIUS",
    ),

    # Spectra that are only classified (inference)
    dict(
        name="PaulowinaTomentosa",
        type="inference",
        file=f"{base_folder}/../results/Samp1_neg__sirius.mgf",
        fragmentation_method="fragmentation_method",
        colour="#8B0773",
        fingerprintFile=f"{base_folder}/../results/Samp1_neg__sirius_fingerprints.json",
        canopus_file=f"{base_folder}/../results/Samp1_neg__sirius/canopus_formula_summary.tsv",
        quant_file=f"{base_folder}/../results/Samp1_neg__full_feature_table.csv",
        sirius_file=f"{base_folder}/../results/Samp1_neg__sirius/structure_identifications_top-15.tsv",
        mzmine_meta_table=f"{base_folder}/../results/Samp1_neg__full_feature_table.csv",
    ),
    dict(
        name="PT22CH",
        type="inference",
        file=f"{base_folder}/../results/PT22CH_pos__sirius.mgf",
        fragmentation_method="fragmentation_method",
        colour="#8B0773",
    ),
    dict(
        name="PT22CH",
        type="inference",
        file=f"{base_folder}/../results/PT22CH_neg__sirius.mgf",
        fragmentation_method="fragmentation_method",
        colour="#8B0773",
    ),
]

# Metadata columns copied from the MS/MS spectra into the output table.
# Each key is the output column; its value lists candidate field names in the
# spectrum metadata (presumably tried in order — verify in
# generate_ms2deepscore_embeddings).
data_to_add = OrderedDict(
    name=["feature_id", "name", "title", "compound_name"],
    formula=["formula"],
    smiles=["smiles"],
    adduct=["adduct", "precursor_type"],
    ionMode=["ionmode"],
    RTINSECONDS=["rtinseconds", "retention_time"],
    precursor_mz=["pepmass", "precursor_mz"],
    fragmentation_method=["fragmentation_method", "fragmentation_mode"],
    CE=["collision_energy"],
)

def _make_subset_filter(fragmentation_method, ion_mode, collision_energies):
    """Return a predicate that selects spectra by fragmentation method, ion mode and CE.

    The returned callable accepts either a single row (e.g. from
    ``df.apply(fn, axis=1)``), returning a bool, or a whole DataFrame, returning a
    boolean Series. Both modes are supported because the predicate is used
    row-wise for subsetting and is also passed on to ``train_and_classify``.

    Args:
        fragmentation_method: required value of the "fragmentation_method" column.
        ion_mode: required value of the "ionMode" column.
        collision_energies: accepted values of the "CE" column (compared as-is,
            so they must match the stored representation, e.g. "20.0").
    """
    accepted_ces = tuple(collision_energies)

    def _predicate(x):
        ce = x["CE"]
        # `in` only works on a scalar; on a column it would raise
        # "truth value of a Series is ambiguous", so use isin() there.
        if isinstance(ce, pd.Series):
            ce_match = ce.isin(accepted_ces)
        else:
            ce_match = ce in accepted_ces
        return (x["fragmentation_method"] == fragmentation_method) & (x["ionMode"] == ion_mode) & ce_match

    return _predicate


# Spectra subsets to train/evaluate separately: name -> row/frame predicate
training_subsets = {
    "hcd_neg_20.0"           : _make_subset_filter("hcd", "negative", ["20.0"]),
    "hcd_neg_30.0"           : _make_subset_filter("hcd", "negative", ["30.0"]),
    "hcd_neg_40.0"           : _make_subset_filter("hcd", "negative", ["40.0"]),
    "hcd_neg_step[20,45,70]" : _make_subset_filter("hcd", "negative", ["45.0", "stepped20,45,70ev(absolute)"]),
    "hcd_pos_step[20,45,70]" : _make_subset_filter("hcd", "positive", ["45.0", "stepped20,45,70ev(absolute)"]),
}

# derived from `datasets`, do not change: maps each dataset name to its colour.
# Entries that share a name collapse into one key (the last one listed wins).
colors = dict((entry["name"], entry["colour"]) for entry in datasets)
# fmt: on

## Execute pipeline

In [None]:
# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Embeddings plus merged metadata are cached so the MS2DeepScore step runs only once.
# NOTE: the cache is NOT invalidated when `datasets`, `data_to_add` or the model
# change — delete this file to force a rebuild. Loading a pickle can execute
# arbitrary code, so only reuse a file this notebook wrote itself.
pickle_file = f"{output_dir}/df_embeddings.pkl"
if os.path.exists(pickle_file):
    df = pd.read_pickle(pickle_file)

else:
    # Import the spectra and compute the MS2DeepScore embeddings
    df = generate_ms2deepscore_embeddings(model_file_name, datasets, data_to_add)

    # add associated metadata
    df = add_mzmine_metainfos(datasets, df)
    # df = add_sirius_fingerprints(datasets, df)  # currently disabled
    df = add_sirius_canopus(datasets, df)
    df = add_sirius_predictions(datasets, df)
    df = add_mzmine_quant(datasets, df)

    # export dataframe for re-use (same path as checked above)
    df.to_pickle(pickle_file)

# Show the overview for freshly generated and cached data alike
show_dataset_overview(df, print_method=display)

# The embedding plots depend only on the full dataframe and the global output
# folder, so draw them once instead of re-drawing them for every subset.
generate_embedding_plots(df, output_dir, colors)

# iterate over the training subsets, produces better output
for subset_name, subset_fn in training_subsets.items():
    print(f"Processing subset: {subset_name}")
    print("##############################################################################")

    # subset the dataframe
    df_subset = df[df.apply(subset_fn, axis=1)].reset_index(drop=True)
    if df_subset.empty:
        # nothing to train on; avoid passing an empty table to the classifiers
        print(f"Skipping subset {subset_name}: no matching spectra")
        continue

    # Create output directory for the subset
    c_output_dir = f"{output_dir}/subset_{subset_name}/"
    os.makedirs(c_output_dir, exist_ok=True)

    # train and predict new datasets
    # (df_validation is returned but not summarised in this notebook)
    df_train, df_validation, df_inference, df_metrics = train_and_classify(df_subset, subsets=subset_fn, output_dir=c_output_dir)
    # NOTE(review): meaning of min_prediction_threshold is defined in
    # generate_prediction_overview — confirm the value 13.
    generate_prediction_overview(df_subset, df_train, c_output_dir, "training", min_prediction_threshold=13)
    generate_prediction_overview(df_subset, df_inference, c_output_dir, "inference", min_prediction_threshold=13)

    # Generate an overview of the machine learning metrics
    generate_ml_metrics_overview(df_metrics, c_output_dir)