# Demo of Classification

This notebook demonstrates how to train different classifiers and use them to classify new MS/MS spectra as either "relevant" (i.e., of interest) or "other".

## Setup

In [None]:
from AnnoMe.Classification import (
    generate_embeddings,
    add_all_metadata,
    show_dataset_overview,
    generate_embedding_plots,
    train_and_classify,
    generate_prediction_overview,
    generate_ml_metrics_overview,
    set_random_seeds,
)

import pandas as pd

from IPython.display import display

from collections import OrderedDict
import os

# Fix the random seeds through AnnoMe's helper so that repeated runs of this
# notebook are reproducible (which generators are seeded is defined in
# AnnoMe.Classification.set_random_seeds — not visible here).
set_random_seeds(42)

## Parameters

In [None]:
# fmt: off
# parameters

# Output directory
output_dir = "./output/PrenylatedCompounds_BOKUDB/"

# Main datasets to process for classification.
# Each entry describes one MGF file:
#   name                 - dataset label; entries with the same name share one colour (see `colors`),
#                          and the summary cells split "<library> - gt <type>" names at " - "
#   type                 - "train - relevant", "train - other" or "inference"
#   file                 - path to the MGF file
#   fragmentation_method - passed unchanged to the AnnoMe pipeline
#                          (presumably the MGF field holding the fragmentation method — confirm in AnnoMe)
#   colour               - plot colour for this dataset name
base_folder = "../resources/libraries_other"
datasets = [
    # MS/MS of reference prenylated flavones
    {
        "name": "prenylated flavonoids - gt relevant",
        "type": "train - relevant",
        "file": os.path.join(base_folder, "HCD_pos__sirius.mgf"),
        "fragmentation_method": "fragmentation_method",
        "colour": "#D41F11",
    },
    {
        "name": "prenylated flavonoids - gt relevant",
        "type": "train - relevant",
        "file": os.path.join(base_folder, "HCD_neg__sirius.mgf"),
        "fragmentation_method": "fragmentation_method",
        "colour": "#D41F11",
    },

    # MS/MS of wheat samples
    {
        "name": "wheat metabolites - gt other",
        "type": "train - other",
        "file": os.path.join(base_folder, "Wheat_pos__sirius.mgf"),
        "fragmentation_method": "fragmentation_method",
        "colour": "#017192",
    },
    {
        "name": "wheat metabolites - gt other",
        "type": "train - other",
        "file": os.path.join(base_folder, "Wheat_neg__sirius.mgf"),
        "fragmentation_method": "fragmentation_method",
        "colour": "#017192",
    },
    {
        "name": "wheat metabolites - gt other",
        "type": "train - other",
        "file": os.path.join(base_folder, "n_wheat_pos__sirius.mgf"),
        "fragmentation_method": "fragmentation_method",
        "colour": "#017192",
    },
    {
        "name": "wheat metabolites - gt other",
        "type": "train - other",
        "file": os.path.join(base_folder, "n_wheat_neg__sirius.mgf"),
        "fragmentation_method": "fragmentation_method",
        "colour": "#017192",
    },

    ## MS/MS of BOKU MassBank
    {
        "name": "MassBank BOKU - gt relevant",
        "type": "train - relevant",
        "file": os.path.join("..", "resources", "libraries_filtered", "BOKU_iBAM___prenyl_flavonoid_or_chalcone__MatchingSmiles.mgf"),
        "fragmentation_method": "fragmentation_method",
        "colour": "#80BF02",
    },
    {
        "name": "MassBank BOKU - gt other",
        "type": "train - other",
        "file": os.path.join("..", "resources", "libraries_filtered", "BOKU_iBAM___prenyl_flavonoid_or_chalcone__NonMatchingSmiles.mgf"),
        "fragmentation_method": "fragmentation_method",
        "colour": "#80BF02",
    },

    # MS/MS for inference
    # NOTE(review): the botanical names are spelled "Paulownia tomentosa" and
    # "Glycyrrhiza uralensis". The labels below are kept as-is because they are
    # used as keys (e.g. in `colors`) and may already appear in cached/previous outputs.
    {
        "name": "Paulowina tomentosa",
        "type": "inference",
        "file": os.path.join(base_folder, "n_PT22_neg__sirius.mgf"),
        "fragmentation_method": "fragmentation_method",
        "colour": "#8B0773",
    },
    {
        "name": "Paulowina tomentosa",
        "type": "inference",
        "file": os.path.join(base_folder, "n_PT22_pos__sirius.mgf"),
        "fragmentation_method": "fragmentation_method",
        "colour": "#8B0773",
    },
    {
        "name": "Paulowina tomentosa",
        "type": "inference",
        "file": os.path.join(base_folder, "n_PT24_neg__sirius.mgf"),
        "fragmentation_method": "fragmentation_method",
        "colour": "#8B0773",
    },
    {
        "name": "Paulowina tomentosa",
        "type": "inference",
        "file": os.path.join(base_folder, "n_PT24_pos__sirius.mgf"),
        "fragmentation_method": "fragmentation_method",
        "colour": "#8B0773",
    },
    {
        "name": "Paulowina tomentosa",
        "type": "inference",
        "file": os.path.join(base_folder, "PT22CH_neg__sirius.mgf"),
        "fragmentation_method": "fragmentation_method",
        "colour": "#8B0773",
    },
    {
        "name": "Paulowina tomentosa",
        "type": "inference",
        "file": os.path.join(base_folder, "PT22CH_pos__sirius.mgf"),
        "fragmentation_method": "fragmentation_method",
        "colour": "#8B0773",
    },

    {
        "name": "Glycyrrhizza uralensis",
        "type": "inference",
        "file": os.path.join(base_folder, "n_GU_neg__sirius.mgf"),
        "fragmentation_method": "fragmentation_method",
        "colour": "#8B0773",
    },
    {
        "name": "Glycyrrhizza uralensis",
        "type": "inference",
        "file": os.path.join(base_folder, "n_GU_pos__sirius.mgf"),
        "fragmentation_method": "fragmentation_method",
        "colour": "#8B0773",
    },

    {
        "name": "Samp1",
        "type": "inference",
        "file": os.path.join(base_folder, "Samp1_neg__sirius.mgf"),
        "fragmentation_method": "fragmentation_method",
        "colour": "#8B0773",
    }
]

# meta-data to add to the output from the MS/MS spectra
# (output column -> candidate fields of the spectra metadata, in order of preference)
data_to_add = OrderedDict(
    [
        ("name", ["feature_id", "name", "title", "compound_name"]),
        ("formula", ["formula"]),
        ("smiles", ["smiles"]),
        ("adduct", ["adduct", "precursor_type"]),
        ("ionMode", ["ionmode"]),
        ("RTINSECONDS", ["rtinseconds", "retention_time"]),
        ("precursor_mz", ["pepmass", "precursor_mz"]),
        ("fragmentation_method", ["fragmentation_method", "fragmentation_mode"]),
        ("CE", ["collision_energy"]),
    ]
)


def _hcd_subset(ion_mode, collision_energy=None):
    """Return a predicate selecting HCD spectra of the given ion mode.

    The predicate receives one row of the spectra table and checks the
    "fragmentation_method", "ionMode" and (optionally) "CE" columns.
    `collision_energy` may be None (no CE restriction), a string (exact match)
    or a list of strings (any of them). CE values are compared as strings.
    """
    if collision_energy is None:
        return lambda x: (x["fragmentation_method"] == "hcd") & (x["ionMode"] == ion_mode)
    if isinstance(collision_energy, str):
        return lambda x: (x["fragmentation_method"] == "hcd") & (x["ionMode"] == ion_mode) & (x["CE"] == collision_energy)
    return lambda x: (x["fragmentation_method"] == "hcd") & (x["ionMode"] == ion_mode) & (x["CE"] in collision_energy)


# CE labels of the stepped (20/45/70 eV) acquisitions; "45.0" presumably is how some
# files report the stepped CE — TODO confirm
_stepped_ce = ["45.0", "stepped20,45,70ev(absolute)"]

# Subsets of the spectra on which separate classifiers are trained (name -> row predicate)
training_subsets = {

    "hcd_neg" : _hcd_subset("negative"),
    "hcd_pos" : _hcd_subset("positive"),

    "hcd_neg_20.0" : _hcd_subset("negative", "20.0"),
    "hcd_neg_30.0" : _hcd_subset("negative", "30.0"),
    "hcd_neg_40.0" : _hcd_subset("negative", "40.0"),
    "hcd_neg_step[20,45,70]" : _hcd_subset("negative", _stepped_ce),

    "hcd_pos_20.0" : _hcd_subset("positive", "20.0"),
    "hcd_pos_30.0" : _hcd_subset("positive", "30.0"),
    "hcd_pos_40.0" : _hcd_subset("positive", "40.0"),
    "hcd_pos_step[20,45,70]" : _hcd_subset("positive", _stepped_ce),
}

# derived, do not change
colors = {ds["name"]: ds["colour"] for ds in datasets}
# fmt: on

## Execute pipeline

In [None]:
# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Minimum prediction threshold used for both the training and the inference
# prediction overviews (defined once so both calls stay in sync)
min_prediction_threshold = 13

# The embeddings are cached in the output directory and re-used when the file exists.
# NOTE: the cache is only checked for existence — delete it after changing
# `datasets` or `data_to_add`, otherwise the stale embeddings are used.
pickle_file = f"{output_dir}/df_embeddings.pkl"
if os.path.exists(pickle_file):
    df = pd.read_pickle(pickle_file)
    show_dataset_overview(df, print_method=display)

else:
    # Import the spectra and process MS2DeepScore embeddings
    df = generate_embeddings(datasets, data_to_add)

    # add associated metadata
    df = add_all_metadata(datasets, df)

    # export dataframe for re-use (same path as checked above)
    df.to_pickle(pickle_file)

# show overview and plot of the complete dataset; the arguments do not depend on
# the subset, so the plots are generated once rather than once per subset
generate_embedding_plots(df, output_dir, colors)

# iterate over the training subsets, produces better output
for subset_name, subset_fn in training_subsets.items():
    print(f"Processing subset: {subset_name}")
    print("##############################################################################")

    # Create output directory for the subset
    c_output_dir = f"{output_dir}/subset_{subset_name}/"
    os.makedirs(c_output_dir, exist_ok=True)

    # subset the dataframe (the predicate is evaluated row by row)
    df_subset = df[df.apply(subset_fn, axis=1)].reset_index(drop=True)

    # train and predict new datasets
    df_train, df_validation, df_inference, df_metrics = train_and_classify(df_subset, subsets=subset_fn, output_dir=c_output_dir)
    generate_prediction_overview(df_subset, df_train, c_output_dir, "training", min_prediction_threshold=min_prediction_threshold)
    generate_prediction_overview(df_subset, df_inference, c_output_dir, "inference", min_prediction_threshold=min_prediction_threshold)

    # Generate an overview of the machine learning metrics
    generate_ml_metrics_overview(df_metrics, c_output_dir)

In [None]:
# Generate an overview of the training datasets obtained from the BOKU repositories
print("\nGenerating an overview of the training datasets obtained from the BOKU repositories")

# Initialize an empty list to store DataFrames
all_training_input = []

# Iterate through all folders in the output directory. The listing is sorted so the
# result does not depend on the arbitrary order returned by os.listdir.
for folder_name in sorted(os.listdir(output_dir)):
    folder_path = os.path.join(output_dir, folder_name)
    if os.path.isdir(folder_path):  # Check if it's a directory
        file_path = os.path.join(folder_path, "dataset_overview.xlsx")
        if os.path.exists(file_path):  # Check if the file exists
            # Read the contents of the sheet 'input'. A dedicated name is used so the
            # embeddings DataFrame `df` of the pipeline cell is not overwritten.
            overview_df = pd.read_excel(file_path, sheet_name="input")
            # Add a new column with the folder name
            overview_df["subset"] = folder_name
            # Append the DataFrame to the list
            all_training_input.append(overview_df)

if len(all_training_input) == 0:
    print("No training datasets found in the output directory.")
    all_training_input = None
else:
    # Concatenate all DataFrames into a single DataFrame
    all_training_input = pd.concat(all_training_input, ignore_index=True)
    # Split the 'combo' column into three new columns using the regex pattern.
    # NOTE(review): this assumes 'combo' is ordered polarity_method_CE, while the subset
    # folder names are ordered method_polarity_CE — confirm against the AnnoMe output.
    all_training_input[["polarity", "fragmentation_method", "collision_energy"]] = all_training_input["combo"].str.extract(r"(.*)_(.*)_(.*)")
    # Select only the required columns (the per-dataset columns contain "- " in their name)
    columns_to_select = ["fragmentation_method", "polarity", "collision_energy"] + [col for col in all_training_input.columns if "- " in col]
    all_training_input = all_training_input[columns_to_select]
    # Convert the DataFrame to long format
    all_training_input = all_training_input.melt(id_vars=["fragmentation_method", "polarity", "collision_energy"], var_name="key", value_name="value")
    # Split each 'key' at " - " into the reference library and the ground-truth type
    all_training_input[["reference_library", "gt_type"]] = all_training_input["key"].str.extract(r"(.*) - (.*)")
    all_training_input["gt_type"] = all_training_input["gt_type"].str.replace("gt ", "", regex=False)
    # Remove all rows with missing values and all rows with a count of zero
    all_training_input.dropna(inplace=True)
    all_training_input = all_training_input[all_training_input["value"] != 0]
    # Remove any substring '_neg_' or '_pos_' from the reference library names
    all_training_input["reference_library"] = all_training_input["reference_library"].str.replace(r"_neg_|_pos_", "", regex=True)

    # Pivot the table. A cell keeps its value when all contributing rows agree, and
    # becomes the list of distinct values otherwise.
    def _agg(x):
        unique_values = set(x)
        return list(unique_values) if len(unique_values) > 1 else next(iter(unique_values))

    all_training_input = all_training_input.pivot_table(
        index=["fragmentation_method", "polarity", "collision_energy"], columns=["reference_library", "gt_type"], values="value", aggfunc=_agg
    ).reset_index()
    # Transpose the table
    all_training_input = all_training_input.set_index(["fragmentation_method", "polarity", "collision_energy"]).transpose().reset_index()
    all_training_input.rename(columns={"key": "reference_library"}, inplace=True)
    # Order the DataFrame by 'reference_library' and 'gt_type'
    all_training_input.sort_values(by=["reference_library", "gt_type"], inplace=True)
    # Group by 'gt_type' and sum all numeric columns
    totals = all_training_input.groupby("gt_type").sum(numeric_only=True).reset_index()
    # Add a new column 'reference_library' with the value 'total_spectra'
    totals["reference_library"] = "total_spectra"
    # Append the totals to the end of all_training_input
    all_training_input = pd.concat([all_training_input, totals], ignore_index=True)


# Generate an overview of the validation datasets measured in-house
print("\nGenerating an overview of the validation datasets measured in-house")
# Initialize an empty list to store DataFrames
all_validation_results = []

# Iterate through all folders in the output directory (sorted for a deterministic order)
for folder_name in sorted(os.listdir(output_dir)):
    folder_path = os.path.join(output_dir, folder_name)
    if os.path.isdir(folder_path):  # Check if it's a directory
        file_path = os.path.join(folder_path, "validation_data.xlsx")
        if os.path.exists(file_path):  # Check if the file exists
            # Read the sheet 'overall' into its own name so the embeddings `df` is not overwritten
            validation_df = pd.read_excel(file_path, sheet_name="overall")
            # Add a new column with the folder name
            validation_df["subset"] = folder_name
            # Append the DataFrame to the list
            all_validation_results.append(validation_df)

if len(all_validation_results) == 0:
    print("No validation datasets found in the output directory.")
    all_validation_results = None
else:
    # Concatenate all DataFrames into a single DataFrame
    all_validation_results = pd.concat(all_validation_results, ignore_index=True)
    all_validation_results["annotated_as"] = all_validation_results["annotated_as_times:relevant"].map(lambda x: "relevant" if x != 0 else "other")
    # Collapse all rows with the same (source, annotated_as, subset) into one row, the same
    # aggregation the inference and test summaries apply; otherwise every distinct non-zero
    # 'annotated_as_times:relevant' value would produce its own 'relevant' row.
    all_validation_results = all_validation_results.groupby(["source", "annotated_as", "subset"], as_index=False).agg({"row_count": "sum"})
    # Subsets without a collision energy (e.g. 'subset_hcd_neg') get the suffix '_all' so
    # every subset name splits into fragmentation method, polarity and collision energy
    all_validation_results["subset"] = all_validation_results["subset"].str.replace(r"(_neg|_pos)$", r"\1_all", regex=True)
    all_validation_results.rename(columns={"row_count": "n_features"}, inplace=True)
    all_validation_results["percent_features"] = (100.0 * all_validation_results["n_features"] / all_validation_results.groupby(["source", "subset"])["n_features"].transform("sum")).round(1)
    # Split the 'subset' column into three new columns using the regex pattern
    all_validation_results[["fragmentation_method", "polarity", "collision_energy"]] = all_validation_results["subset"].str.extract(r"subset_(.*)_(.*)_(.*)")
    all_validation_results["source"] = all_validation_results["source"].str.replace(" - gt ", " - ", regex=False)
    all_validation_results[["source", "gt_type"]] = all_validation_results["source"].str.extract(r"(.*) - (other|relevant)")
    # Order the DataFrame by source, polarity, method, collision energy, gt_type and 'annotated_as'
    all_validation_results.sort_values(by=["source", "polarity", "fragmentation_method", "collision_energy", "gt_type", "annotated_as"], inplace=True)
    # Reorder the columns
    all_validation_results = all_validation_results[["source", "polarity", "fragmentation_method", "collision_energy", "gt_type", "annotated_as", "n_features", "percent_features"]]


# Generate an overview of the inference dataset measured in-house
print("\nGenerating an overview of the inference datasets measured in-house")
# Initialize an empty list to store DataFrames
inference_results = []

# Iterate through all folders in the output directory (sorted for a deterministic order)
for folder_name in sorted(os.listdir(output_dir)):
    folder_path = os.path.join(output_dir, folder_name)
    if os.path.isdir(folder_path):  # Check if it's a directory
        file_path = os.path.join(folder_path, "inference_data.xlsx")
        if os.path.exists(file_path):  # Check if the file exists
            # Best effort: a workbook that cannot be read is reported and skipped.
            # A dedicated name is used so the embeddings `df` is not overwritten.
            try:
                inference_df = pd.read_excel(file_path, sheet_name="overall")
            except Exception as e:
                print(f"Error reading {file_path}: {e}")
                continue
            # Add a new column with the folder name
            inference_df["subset"] = folder_name
            # Append the DataFrame to the list
            inference_results.append(inference_df)

if len(inference_results) == 0:
    print("No inference datasets found in the output directory.")
    all_inference_results = None
else:
    # Concatenate all DataFrames into a single DataFrame
    all_inference_results = pd.concat(inference_results, ignore_index=True)
    all_inference_results["annotated_as"] = all_inference_results["annotated_as_times:relevant"].map(lambda x: "relevant" if x != 0 else "other")
    all_inference_results.drop(columns=["annotated_as_times:relevant"], inplace=True)
    # One row per (source, annotated_as, subset) with the summed feature count
    all_inference_results = all_inference_results.groupby(["source", "annotated_as", "subset"], as_index=False).agg({"row_count": "sum"})
    # Subsets without a collision energy (e.g. 'subset_hcd_neg') get the suffix '_all' so
    # every subset name splits into fragmentation method, polarity and collision energy
    all_inference_results["subset"] = all_inference_results["subset"].str.replace(r"(_neg|_pos)$", r"\1_all", regex=True)
    all_inference_results[["fragmentation_method", "polarity", "collision_energy"]] = all_inference_results["subset"].str.extract(r"subset_(.*)_(.*)_(.*)")
    all_inference_results.rename(columns={"row_count": "n_features"}, inplace=True)
    all_inference_results["percent_features"] = (100.0 * all_inference_results["n_features"] / all_inference_results.groupby(["source", "subset"])["n_features"].transform("sum")).round(1)
    # Keep the reported columns, sort and renumber the rows
    all_inference_results = (
        all_inference_results[["source", "polarity", "fragmentation_method", "collision_energy", "annotated_as", "n_features", "percent_features"]]
        .sort_values(by=["source", "polarity", "fragmentation_method", "collision_energy", "annotated_as"])
        .reset_index(drop=True)
    )


# Generate an overview of the training dataset (test only) measured in-house
print("\nGenerating an overview of the training dataset (test only) measured in-house")
# Initialize an empty list to store DataFrames
all_test_results = []

# Iterate through all folders in the output directory (sorted for a deterministic order)
for folder_name in sorted(os.listdir(output_dir)):
    folder_path = os.path.join(output_dir, folder_name)
    if os.path.isdir(folder_path):  # Check if it's a directory
        file_path = os.path.join(folder_path, "training_data.xlsx")
        if os.path.exists(file_path):  # Check if the file exists
            # Read the sheet 'overall' into its own name so the embeddings `df` is not overwritten
            test_df = pd.read_excel(file_path, sheet_name="overall")
            # Add a new column with the folder name
            test_df["subset"] = folder_name
            # Append the DataFrame to the list
            all_test_results.append(test_df)

if len(all_test_results) == 0:
    print("No training datasets found in the output directory.")
    all_test_results = None
else:
    # Concatenate all DataFrames into a single DataFrame
    all_test_results = pd.concat(all_test_results, ignore_index=True)
    all_test_results["annotated_as"] = all_test_results["annotated_as_times:relevant"].map(lambda x: "relevant" if x != 0 else "other")
    all_test_results.drop(columns=["annotated_as_times:relevant"], inplace=True)
    # One row per (source, annotated_as, subset) with the summed feature count
    all_test_results = all_test_results.groupby(["source", "annotated_as", "subset"], as_index=False).agg({"row_count": "sum"})
    # Subsets without a collision energy (e.g. 'subset_hcd_neg') get the suffix '_all' so
    # every subset name splits into fragmentation method, polarity and collision energy
    all_test_results["subset"] = all_test_results["subset"].str.replace(r"(_neg|_pos)$", r"\1_all", regex=True)
    all_test_results[["fragmentation_method", "polarity", "collision_energy"]] = all_test_results["subset"].str.extract(r"subset_(.*)_(.*)_(.*)")
    all_test_results.rename(columns={"row_count": "n_features"}, inplace=True)
    all_test_results["percent_features"] = (100.0 * all_test_results["n_features"] / all_test_results.groupby(["source", "subset"])["n_features"].transform("sum")).round(1)
    # Split "<source> - gt <type>" into the source name and the ground-truth type
    all_test_results[["source", "gt_type"]] = all_test_results["source"].str.extract(r"(.*) - gt (.*)")
    # Keep the reported columns, sort and renumber the rows
    all_test_results = (
        all_test_results[["source", "polarity", "fragmentation_method", "collision_energy", "gt_type", "annotated_as", "n_features", "percent_features"]]
        .sort_values(by=["source", "polarity", "fragmentation_method", "collision_energy", "gt_type", "annotated_as"])
        .reset_index(drop=True)
    )


# Export every summary table that could be built to one Excel file (one sheet per table)
output_excel_file = os.path.join(output_dir, "summary_tables.xlsx")
summary_tables = {
    "all_training_input": all_training_input,
    "all_validation_results": all_validation_results,
    "all_inference_results": all_inference_results,
    "all_test_results": all_test_results,
}
# Tables set to None above had no input files and are skipped
summary_tables = {sheet_name: table for sheet_name, table in summary_tables.items() if table is not None}

if not summary_tables:
    # openpyxl cannot save a workbook without any worksheet, so skip the export
    # instead of failing when closing the writer
    print("No summary tables to export.")
else:
    with pd.ExcelWriter(output_excel_file, engine="openpyxl") as writer:
        for sheet_name, table in summary_tables.items():
            table.to_excel(writer, sheet_name=sheet_name)

    print(f"Exported tables to {output_excel_file}")