# Train Decision Trees & Convert to MassQL Queries

## Initial Setup

In [None]:
import os
import pickle
import json
import pandas as pd
from rdkit.Chem import AllChem
from tqdm.auto import tqdm

# Register tqdm's pandas integration so Series/DataFrame gain `.progress_apply`
# (used below when encoding the library SMILES).
tqdm.pandas()

from chemecho.utils import load_processed_gnps_data, merge_in_nist
from chemecho.featurization import subformula_featurization, build_feature_matrix, feature_reduction, save_featurized_spectra, load_featurized_spectra
from chemecho.train_predict import filter_failed_idxs, train_substructure_tree, convert_tree_to_massql

In [None]:
polarity = 'negative'  # either 'negative' or 'positive'
workdir = '/pscratch/sd/t/tharwood/massql_constructor'  # storage location for feature matrices, MS2 libraries, etc.

vector_assignment_method = 'blur'  # either 'blur' or 'top'. blur assigns all subformulae within tolerance, top takes only the best one
max_ppm_error = 10  # max ppm error of the subformula assignment
min_feature_occurence = 6  # minimum number of feature occurrences to keep

# Fail fast on typos: later cells build file names from `polarity` and only
# special-case the exact string 'negative', so e.g. 'Negative' would otherwise
# be silently treated as positive mode.
if polarity not in ('negative', 'positive'):
    raise ValueError(f"polarity must be 'negative' or 'positive', got {polarity!r}")
if vector_assignment_method not in ('blur', 'top'):
    raise ValueError(f"vector_assignment_method must be 'blur' or 'top', got {vector_assignment_method!r}")

## Build Training Feature Matrix

In [None]:
# Load the cleaned GNPS library for this polarity, then merge in the cleaned NIST library.
gnps_cleaned = load_processed_gnps_data(
    gnps_cleaned_path=f'{workdir}/gnps_cleaned.tsv',
    convert_spectra=True,
    polarity=polarity,
)
merged_lib = merge_in_nist(
    gnps_cleaned,
    nist_cleaned_path=f'{workdir}/nist_cleaned.tsv',
    convert_spectra=True,
    polarity=polarity,
)

In [None]:
# Assign subformulae to every spectrum in the merged library; returns peak
# vectors and `nl_` (presumably neutral-loss) vectors.
(peak_subformula_vectors,
 nl_subformula_vectors) = subformula_featurization(
    merged_lib,
    vector_assignment=vector_assignment_method,
    max_ppm_error=max_ppm_error,
)

In [None]:
# Indices of spectra with no subformula assigned (subformula_featurization
# leaves a None entry for those).
failed_spectra_idxs = [idx for idx, vec in enumerate(peak_subformula_vectors) if vec is None]

In [None]:
# Combine the peak and nl subformula vectors into one feature matrix plus an
# index map (presumably feature label -> column index).
(featurized_spectral_data,
 feature_vector_index_map) = build_feature_matrix(
    peak_subformula_vectors,
    nl_subformula_vectors,
)

In [None]:
# Drop rare features using the `min_feature_occurence` threshold
# (NOTE(review): exact inclusive/exclusive semantics live in feature_reduction).
(featurized_spectral_data,
 feature_vector_index_map) = feature_reduction(
    featurized_spectral_data,
    feature_vector_index_map,
    min_occurence=min_feature_occurence,
)

In [None]:
# save embeddings: feature matrix, index map and failed indices go under
# `workdir`, tagged by polarity (overwrite=False is passed explicitly).
save_featurized_spectra(
    featurized_spectral_data,
    feature_vector_index_map,
    failed_spectra_idxs,
    workdir,
    overwrite=False,
    polarity=polarity,
)

## Load Featurized MS2 Data

In [None]:
# load embeddings (feature matrix, index map, failed indices) saved for this polarity
(featurized_spectral_data,
 feature_vector_index_map,
 failed_spectra_idxs) = load_featurized_spectra(workdir, polarity=polarity)

## Load Substructure Definitions

In [None]:
# substructures were generated using group selfies; load the pickled grammar.
# NOTE: pickle is only safe for trusted files (this one is assumed to come from our own pipeline).
grammar_path = f'{workdir}/{polarity}_grammar_fragment.pkl'
with open(grammar_path, 'rb') as fh:
    grammar_fragment = pickle.load(fh)

In [None]:
def encode_substructure(smiles):
    """Encode a SMILES string with the loaded group-SELFIES grammar.

    Parameters
    ----------
    smiles : str
        SMILES string of the molecule to encode.

    Returns
    -------
    The value returned by ``grammar_fragment.encoder`` for the molecule and its
    extracted groups, or ``None`` when ``smiles`` is not a string (e.g. a NaN
    from pandas) or RDKit cannot parse it.
    """
    # Missing SMILES arrive as NaN floats from pandas; RDKit rejects non-strings.
    if not isinstance(smiles, str):
        return None

    mol = AllChem.MolFromSmiles(smiles)
    # MolFromSmiles returns None (rather than raising) on unparseable input;
    # handing None to the grammar is meaningless and would likely abort the
    # whole progress_apply over the library.
    if mol is None:
        return None

    groups = grammar_fragment.extract_groups(mol)
    return grammar_fragment.encoder(mol, groups)

In [None]:
# Encode every library SMILES (progress bar via tqdm's progress_apply).
merged_lib['frag_encoded_selfie'] = merged_lib['smiles'].progress_apply(encode_substructure)

In [None]:
# Cache SMILES + encodings so the encoding step can be skipped on later runs.
encoded_cache_path = f'{workdir}/{polarity}_frag_encoded_selfies.csv'
merged_lib.loc[:, ['smiles', 'frag_encoded_selfie']].to_csv(encoded_cache_path)

In [None]:
# load pre encoded selfies and attach them to merged_lib.
# The cache CSV was written together with its index, so read it back as the
# index. Assign positionally (after checking the rows line up) rather than with
# pd.concat(axis=1): concat aligns on index labels, which misaligns / NaN-pads
# rows when merged_lib's index is not 0..n-1, and it duplicates the column when
# merged_lib already has 'frag_encoded_selfie'.
encoded_selfies = pd.read_csv(f'{workdir}/{polarity}_frag_encoded_selfies.csv', index_col=0)

cached_smiles = encoded_selfies['smiles'].fillna('').astype(str).to_numpy()
current_smiles = merged_lib['smiles'].fillna('').astype(str).to_numpy()
if len(cached_smiles) != len(current_smiles) or (cached_smiles != current_smiles).any():
    raise ValueError('cached frag_encoded_selfies.csv does not match merged_lib row-for-row; '
                     're-run the encoding cells to regenerate it')

merged_lib['frag_encoded_selfie'] = encoded_selfies['frag_encoded_selfie'].to_numpy()

## Train Decision Trees to Predict Substructures

In [None]:
# Training thresholds
min_positive_unique = 10  # minimum number of unique structures required for training
min_frag_count = 1  # minimum number of occurrences of the substructure in a molecule for it to be labeled a positive sample

# every fragment in the group-SELFIES vocabulary is a training target
frags = list(grammar_fragment.vocab)

In [None]:
# remove spectra which failed vectorization (returns filtered copies of the
# feature data and the library table, presumably in matching row order)
(filtered_spectra_data,
 filtered_merged_lib) = filter_failed_idxs(featurized_spectral_data, merged_lib, failed_spectra_idxs)

In [None]:
# train model for each group selfie frag
# for single core:
# Train on the filtered library/features produced by filter_failed_idxs so that
# spectra which failed vectorization are excluded, and use the thresholds
# defined in the training-thresholds cell instead of duplicated literals.
for frag in frags:
    train_substructure_tree(frag, filtered_merged_lib, filtered_spectra_data, workdir, polarity,
                            frag_type='group_selfies',
                            max_depth=3,
                            min_frag_count=min_frag_count,
                            min_positive_unique=min_positive_unique,
                            save_model=True)

In [None]:
# for multicore:
from concurrent.futures import ProcessPoolExecutor, as_completed

# simple wrapper function
def multicore_substructure_tree(frag):
    """Train (and save) the substructure decision tree for a single fragment.

    Defined at module level so ProcessPoolExecutor workers can run it. Reads the
    filtered library/features and the training thresholds from notebook globals.
    """
    train_substructure_tree(
        frag,
        filtered_merged_lib,
        filtered_spectra_data,
        workdir,
        polarity,
        frag_type='group_selfies',
        max_depth=3,
        min_frag_count=min_frag_count,
        min_positive_unique=min_positive_unique,
        save_model=True,
    )

# An exception raised in a worker is only re-raised when its result is
# retrieved; never retrieving results silently hides failures. Collect them
# per fragment so every fragment is still attempted and failures are reported.
failed_frags = {}
with ProcessPoolExecutor(max_workers=16) as executor:
    future_to_frag = {executor.submit(multicore_substructure_tree, frag): frag for frag in frags}
    for future in as_completed(future_to_frag):
        exc = future.exception()
        if exc is not None:
            failed_frags[future_to_frag[future]] = exc

if failed_frags:
    print(f'{len(failed_frags)} fragment(s) failed to train:')
    for failed_frag, exc in failed_frags.items():
        print(f'  {failed_frag}: {exc!r}')

In [None]:
# evaluate the success of the model training: gather per-class metrics from each
# saved report (fragments without a report file are skipped)
model_reports = []
for frag in frags:
    report_path = f"{workdir}/models/{polarity}_{frag}_report.json"
    if not os.path.isfile(report_path):
        continue
    with open(report_path, "r") as f:
        json_report = json.load(f)
    neg_stats = json_report['False']
    pos_stats = json_report['True']
    model_reports.append({
        'frag': frag,
        'false_precision': neg_stats['precision'],
        'false_recall': neg_stats['recall'],
        'false_f1': neg_stats['f1-score'],
        'true_precision': pos_stats['precision'],
        'true_recall': pos_stats['recall'],
        'true_f1': pos_stats['f1-score'],
        'macro_avg_f1': json_report['macro avg']['f1-score'],
        'pos_smiles': json_report['pos_smiles'],
    })

In [None]:
model_reports = pd.DataFrame(model_reports)
# preview the five models with the highest precision on the positive ('True') class
ranked_by_precision = model_reports.sort_values('true_precision', ascending=False)
ranked_by_precision.head(5)

## Convert Decision Trees into MassQL Queries

### Get the best-performing models for your use case
- High precision, low recall: few hits, but with high confidence.
- Low precision, high recall: many hits, but with low confidence.
- Which trade-off is preferable depends on the use case. This is a simplified explanation, but it should illustrate the general idea. The cell below ranks models by precision on the positive class (`true_precision`).

In [None]:
ms2_tol_ppm = 5  # MS2 tolerance (ppm) passed to convert_tree_to_massql

# the five fragments whose models have the highest positive-class precision
top_frags = (
    model_reports
    .sort_values('true_precision', ascending=False)
    .head(5)['frag']
    .tolist()
)

In [None]:
import joblib

# Labels of the feature-matrix columns, used to name the trees' split features.
# NOTE(review): assumes feature_vector_index_map's key order matches the matrix
# column order — confirm against build_feature_matrix / feature_reduction.
feature_labels = list(feature_vector_index_map.keys())

top_frag_set = set(top_frags)
massql_queries = dict()
for frag in frags:
    if frag not in top_frag_set:
        continue

    frag_model_path = f"{workdir}/models/{polarity}_{frag}_model.joblib"
    # Only convert a model loaded for *this* fragment; without this guard a
    # missing file would reuse the previous iteration's model (or raise NameError).
    if not os.path.isfile(frag_model_path):
        print(f'no saved model for {frag}, skipping')
        continue
    frag_model = joblib.load(frag_model_path)

    massql_queries[frag] = convert_tree_to_massql(frag_model, feature_labels, tolerance=ms2_tol_ppm)

In [None]:
# example query for frag24: draw the substructure, then show the first query
# (convert_tree_to_massql's output appears to join queries with ' ||| ')
example_frag = 'frag24'
display(grammar_fragment.vocab[example_frag].mol)

massql_queries[example_frag].split(' ||| ')[0]

## Evaluate and Refine MassQL Queries

In [None]:
from massql import msql_engine
from massql import msql_fileloading

In [None]:
# JSON copy of merged_lib that MassQL loads (written below if it does not exist yet)
gnps_json_lib_path = '{}/{}_merged_lib.json'.format(workdir, polarity)

In [None]:
def numpy_spec_to_json(spec):
    """Serialize a spectrum to a JSON string of ``[mz, intensity]`` pairs.

    Parameters
    ----------
    spec : sequence of two sequences
        ``spec[0]`` holds the m/z values and ``spec[1]`` the matching
        intensities (e.g. a 2 x N numpy array).

    Returns
    -------
    str
        JSON text such as ``'[[100.5, 0.5], [200.25, 1.0]]'``.
    """
    # Intensities are kept as floats: casting to int would truncate fractional
    # values and turn normalized intensities (< 1) into 0.
    peak_pairs = [(float(mz), float(intensity)) for mz, intensity in zip(spec[0], spec[1])]
    return json.dumps(peak_pairs)

# write the JSON library once; later runs reuse the existing file
if not os.path.isfile(gnps_json_lib_path):
    merged_lib['peaks_json'] = merged_lib['spectrum'].apply(numpy_spec_to_json)
    json_export = merged_lib.rename(columns={'precursor_mz': 'Precursor_MZ'})
    json_export = json_export[['spectrum_id', 'Precursor_MZ', 'peaks_json']]
    json_export.to_json(gnps_json_lib_path, orient='records')

In [None]:
# Loading Data
ms1_df, ms2_df = msql_fileloading.load_data(gnps_json_lib_path)

# in negative mode, tag every MS1 and MS2 row with polarity code 2
if polarity == 'negative':
    for scan_df in (ms1_df, ms2_df):
        scan_df['polarity'] = 2

In [None]:
# testing queries: run the first frag24 query against the merged library
test_query = massql_queries['frag24'].split(' ||| ')[0]

results_df = msql_engine.process_query(
    test_query,
    gnps_json_lib_path,
    ms1_df=ms1_df,
    ms2_df=ms2_df,
)

In [None]:
# unique SMILES of library spectra whose spectrum_id appears among the query's matched scans
matched_rows = merged_lib[merged_lib['spectrum_id'].isin(results_df['scan'])]
unique_matched_smiles = set(matched_rows['smiles'].tolist())

max_i = 40
# draw structures with indices 0..max_i (at most max_i + 1 of them)
for i, smiles in enumerate(unique_matched_smiles):
    print('----------------------------------------')
    display(AllChem.MolFromSmiles(smiles))
    if i >= max_i:
        break