# Forward model training

In this notebook, we'll train the forward models NEIMS and RASSP that we use for generating synthetic datasets. To work through this notebook, you first need to have the **Python environments** set up (see the README or the previous notebook). You also need the **NIST library** as an .msp file.

In [1]:
import sys 
sys.path.append('..')

## NIST splitting
**TODO:** the NIST splitting is currently done in the 3rd notebook; this part should be organized better.

In [2]:
nist_train_path = "../data/nist/train.msp"
nist_test_path = "../data/nist/test.msp"
nist_valid_path = "../data/nist/valid.msp"

## NEIMS


### Conversion from MSP to SDF
First, we need to convert the NIST library to the SDF format supported by the NEIMS codebase.
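`msp2sdf` comes from `utils.spectra_process_utils` and is not reproduced here. For orientation, the sketch below reads MSP records in pure Python. It assumes the common NIST layout: `Key: value` header lines, a `Num Peaks` line followed by peak lines, and blank lines between records. `parse_msp` is a hypothetical helper, not the parser `msp2sdf` actually uses, and the two records are made-up toy data.

```python
from typing import Dict, List, Tuple

Record = Tuple[Dict[str, str], List[Tuple[float, float]]]


def parse_msp(text: str) -> List[Record]:
    """Split MSP text into (metadata, peaks) records.

    Assumes "Key: value" header lines, a "Num Peaks: n" line,
    then peak lines; records are separated by blank lines.
    """
    records: List[Record] = []
    meta: Dict[str, str] = {}
    peaks: List[Tuple[float, float]] = []
    in_peaks = False
    # the extra "" at the end flushes the last record
    for line in text.splitlines() + [""]:
        line = line.strip()
        if not line:
            if meta:
                records.append((meta, peaks))
            meta, peaks, in_peaks = {}, [], False
        elif in_peaks:
            # "m/z intensity" pairs, possibly several per line separated by ";"
            for pair in line.split(";"):
                parts = pair.split()
                if len(parts) >= 2:
                    peaks.append((float(parts[0]), float(parts[1])))
        else:
            key, _, value = line.partition(":")
            meta[key.strip()] = value.strip()
            if key.strip().lower() == "num peaks":
                in_peaks = True
    return records


example = "Name: Methanol\nNum Peaks: 2\n31 999\n32 650\n\nName: Ethanol\nNum Peaks: 1\n31 999\n"
for meta, peaks in parse_msp(example):
    print(meta["Name"], peaks)
```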

In [14]:
# takes approximately 15 mins to run
# each call writes an .sdf file next to the .msp file; all three are read in the next cell
from utils.spectra_process_utils import msp2sdf

msp2sdf(nist_test_path)
msp2sdf(nist_valid_path)
msp2sdf(nist_train_path)

In [15]:
# change the format of the spectral information to the one expected by NEIMS
# this function might differ in your case, depending on the metadata format of your copy of the NIST library
# for us, it took about 4 mins to run

# load sdf files
from rdkit import Chem
from rdkit.Chem import Descriptors
from tqdm import tqdm
from pathlib import Path
import ast

def neims_mol_filter(mol,
                     max_atoms=100,
                     max_mass_peak_loc=1000,
                     filter_max_mass_charge_peak_weight_cutoff=3.0):
    """
    Adapted from the NEIMS codebase to filter molecules.
    We didn't use the original function because of environment incompatibilities.

    The filtering thresholds are set to the NEIMS default values.
    """
    if mol is None:
        return False
    elif max_atoms is not None and mol.GetNumAtoms() > max_atoms:
        return False
    elif not mol.HasProp("MASS SPECTRAL PEAKS") or not mol.GetProp("MASS SPECTRAL PEAKS"):
        return False
    elif not mol.HasProp("smiles") or not mol.GetProp("smiles"):
        return False
    # the peak block ends with "\n", so [-2] is the last peak line;
    # peaks are assumed to be sorted by ascending m/z, making this the highest m/z
    max_peak = float(mol.GetProp("MASS SPECTRAL PEAKS").split("\n")[-2].split(" ")[0])
    if max_peak > max_mass_peak_loc:
        return False
    # reject spectra whose highest m/z exceeds the exact mass by more than the cutoff factor
    if max_peak / mol.GetDoubleProp("EXACT MASS") > filter_max_mass_charge_peak_weight_cutoff:
        return False

    return True


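def transform_to_neims_sdf_format(input_path, output_path):
    """Rewrite the properties of every record in `input_path` into the layout
    expected by NEIMS and write the records that pass `neims_mol_filter`
    to `output_path`."""
    supplier = Chem.SDMolSupplier(input_path)
    writer = Chem.SDWriter(output_path)
    num_filtered = 0

    for mol in tqdm(supplier):
        # SDMolSupplier yields None for records RDKit cannot parse
        if mol is None:
            num_filtered += 1
            continue

        # peaks are stored as a list of (m/z, intensity) pairs; NEIMS expects
        # one "m/z intensity" line per peak with integer values
        old_format_peaks = ast.literal_eval(mol.GetProp("peaks_json"))
        new_format_peaks = "".join([f"{round(mz)} {round(i)}\n" for (mz, i) in old_format_peaks])
        exact_mass = Descriptors.ExactMolWt(mol)
        inchikey = mol.GetProp("inchikey")
        if not inchikey:
            print("WARNING: No inchikey found for", mol.GetProp("iupac_name"))
            inchikey = Chem.inchi.MolToInchiKey(mol)

        mol.SetProp("MASS SPECTRAL PEAKS", new_format_peaks)
        mol.SetDoubleProp("EXACT MASS", exact_mass)
        mol.SetProp("INCHIKEY", inchikey)
        mol.SetProp("NAME", mol.GetProp("iupac_name"))

        # drop the original property names so that only the NEIMS ones remain
        mol.ClearProp("peaks_json")
        mol.ClearProp("inchikey")
        mol.ClearProp("iupac_name")

        if neims_mol_filter(mol):
            writer.write(mol)
        else:
            num_filtered += 1

    # close the writer so that all buffered records are flushed to disk
    writer.close()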

output_train = "../data/nist/neims_training_data/train_neims.sdf"
output_test = "../data/nist/neims_training_data/test_neims.sdf"
output_valid = "../data/nist/neims_training_data/valid_neims.sdf"

Path(output_train).parent.mkdir(parents=True, exist_ok=True)

# run the transformation
transform_to_neims_sdf_format(nist_train_path.replace(".msp", ".sdf"), 
                              output_train)

transform_to_neims_sdf_format(nist_test_path.replace(".msp", ".sdf"),
                              output_test)

transform_to_neims_sdf_format(nist_valid_path.replace(".msp", ".sdf"),
                              output_valid)

100%|██████████| 232025/232025 [03:14<00:00, 1193.42it/s]
100%|██████████| 29218/29218 [00:24<00:00, 1190.00it/s]
100%|██████████| 29053/29053 [00:24<00:00, 1195.80it/s]
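To make the format change concrete, here is a dependency-free sketch of the two per-record steps above. It rewrites a peak list into the NEIMS text block and applies the m/z checks from `neims_mol_filter`. The helper names `to_neims_peaks` and `passes_mass_checks` are hypothetical, and the toy peaks are made up for illustration, not taken from NIST.

```python
from typing import List, Tuple


def to_neims_peaks(peaks: List[Tuple[float, float]]) -> str:
    # one "m/z intensity" line per peak, both rounded to integers,
    # as in transform_to_neims_sdf_format
    return "".join(f"{round(mz)} {round(i)}\n" for mz, i in peaks)


def passes_mass_checks(peak_block: str, exact_mass: float,
                       max_mass_peak_loc: float = 1000,
                       cutoff: float = 3.0) -> bool:
    # the block ends with "\n", so split("\n")[-2] is the last peak line;
    # like neims_mol_filter, this assumes peaks sorted by ascending m/z
    max_peak = float(peak_block.split("\n")[-2].split(" ")[0])
    return max_peak <= max_mass_peak_loc and max_peak / exact_mass <= cutoff


toy_peaks = [(15.0, 120.4), (29.0, 450.0), (31.0, 999.0), (32.0, 650.6)]
block = to_neims_peaks(toy_peaks)
print(repr(block))                        # '15 120\n29 450\n31 999\n32 651\n'
print(passes_mass_checks(block, 32.026))  # True: 32 / 32.026 is well below 3.0
```

Python 3's `round` rounds halves to even, so `round(2.5)` is `2` and `round(3.5)` is `4`. This only matters for intensities or m/z values that sit exactly on a .5 boundary.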


### Extract replicates
We'll extract the replicates from the NIST library, creating two sets: one for training the NEIMS model (mainlib) and one for the database retrieval testing scenario (replicates). This took about 3 mins.
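The split rule is "first record per InChIKey goes to the main library, the rest are replicates". A minimal pandas sketch on a toy frame shows it. `split_replicates` is a hypothetical helper. It is equivalent to the `drop_duplicates` plus `index.isin` approach used below as long as the frame index is unique.

```python
import pandas as pd


def split_replicates(df: pd.DataFrame, key: str = "INCHIKEY"):
    # True for every record whose key already appeared earlier in the frame
    is_replicate = df.duplicated(subset=[key], keep="first")
    # .copy() returns independent frames, so adding columns later does not
    # trigger pandas' SettingWithCopyWarning
    main = df[~is_replicate].copy()
    replicates = df[is_replicate].copy()
    return main, replicates


toy = pd.DataFrame({"INCHIKEY": ["AAA", "BBB", "AAA", "AAA"],
                    "record": [1, 2, 3, 4]})
main, replicates = split_replicates(toy)
print(main["record"].tolist())        # [1, 2]
print(replicates["record"].tolist())  # [3, 4]
```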

In [19]:
import pandas as pd
from tqdm import tqdm 
from rdkit.Chem import PandasTools
from utils.spectra_process_utils import smiles_to_inchikey

tqdm.pandas()

def extract_replicates(path_sdf):
    """Extract replicates from the main library and save them in a separate file.
    For every molecule (InChIKey) that occurs more than once, the first record stays
    in the main library and all other records are moved to the replicates file.

    Args:
        path_sdf (str): path to the main library sdf file
    
    Returns:
        None
    """

    df = PandasTools.LoadSDF(path_sdf)

    # make sure all mols have InChIKeys (fill missing ones from SMILES)
    df['INCHIKEY'] = df.progress_apply(lambda row: smiles_to_inchikey(row['smiles']) if pd.isna(row['INCHIKEY']) else row['INCHIKEY'], axis=1)
    # .copy() gives independent frames and avoids pandas' SettingWithCopyWarning below
    unique_df = df.drop_duplicates(subset=['INCHIKEY'], keep='first').copy()
    replicates_df = df[~df.index.isin(unique_df.index)].copy()
    PandasTools.AddMoleculeColumnToFrame(replicates_df, smilesCol='smiles', molCol='ROMol')
    PandasTools.AddMoleculeColumnToFrame(unique_df, smilesCol='smiles', molCol='ROMol')
    
    path_mainlib_sdf = path_sdf.replace(".sdf", "_main.sdf")
    path_replicates_sdf = path_sdf.replace(".sdf", "_replicates.sdf")

    PandasTools.WriteSDF(unique_df, path_mainlib_sdf, properties=list(
        unique_df.columns))
    PandasTools.WriteSDF(replicates_df, path_replicates_sdf, properties=list(
        replicates_df.columns))


extract_replicates(output_train)


100%|██████████| 231983/231983 [00:01<00:00, 132650.67it/s]


In [20]:
# now we can delete the original intermediate sdf files and keep only the NEIMS formatted ones
Path("../data/nist/train.sdf").unlink()
Path("../data/nist/test.sdf").unlink()
Path("../data/nist/valid.sdf").unlink()

### Data split and preprocessing
Now we can use a NEIMS script to split the data and preprocess it. Run the following bash command to do so:


```bash
cd deep-molecular-massspec
TARGET_PATH_NAME=tmp/massspec_predictions
conda activate NEIMSpy3

python make_train_test_split.py --main_sdf_name=../data/nist/neims_training_data/train_neims_main.sdf \
                                --replicates_sdf_name=../data/nist/neims_training_data/train_neims_replicates.sdf \
                                --output_master_dir=$TARGET_PATH_NAME/spectra_tf_records
```

### Model training

The preprocessed data is now ready for training. We'll use the following command to train the model:

```bash
python molecule_estimator.py --dataset_config_file=$TARGET_PATH_NAME/spectra_tf_records/query_replicates_val_predicted_replicates_val.json \
                             --train_steps=100000 \
                             --model_dir=$TARGET_PATH_NAME/models/output \
                             --hparams=make_spectra_plots=True,batch_size=100 \
                             --alsologtostderr
```
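If you prefer to launch training from Python (for example, to vary the number of steps or the batch size across runs), the same command can be assembled as an argument list. This is a sketch. `neims_train_cmd` is a hypothetical helper, and the sketch assumes the process runs inside the `NEIMSpy3` environment from the `deep-molecular-massspec` directory, as in the bash commands above. Only the flags shown above are used.

```python
import shlex
from typing import List


def neims_train_cmd(target_path: str, train_steps: int = 100000,
                    batch_size: int = 100) -> List[str]:
    """Build the molecule_estimator.py call above as an argument list."""
    records_dir = f"{target_path}/spectra_tf_records"
    return [
        "python", "molecule_estimator.py",
        f"--dataset_config_file={records_dir}/query_replicates_val_predicted_replicates_val.json",
        f"--train_steps={train_steps}",
        f"--model_dir={target_path}/models/output",
        f"--hparams=make_spectra_plots=True,batch_size={batch_size}",
        "--alsologtostderr",
    ]


cmd = neims_train_cmd("tmp/massspec_predictions")
print(shlex.join(cmd))
# to actually launch it (from inside the NEIMSpy3 environment):
# import subprocess; subprocess.run(cmd, cwd="deep-molecular-massspec", check=True)
```

Passing a list rather than a single string avoids shell quoting issues with the comma-separated `--hparams` value.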


## RASSP

**TODO: we need the .jsonl files from the split, but we don't have them at this point yet.**

For the purpose of this work, we keep a [dedicated branch](https://github.com/ljocha/rassp-public/tree/ljocha) of our [fork](https://github.com/ljocha/rassp-public/) of the original RASSP git repository. Currently, the code itself is not modified.

RASSP requires a rather specific environment (old Python and PyTorch versions) to run smoothly; therefore, we provide a Docker container (cerit.io/ljocha/rassp:nvidia-2024-1) with everything prepared.
The Dockerfile we used to build it is available in the repository (see above).

### Filtering NIST for RASSP

This step follows the splitting and cleanup of the NIST library; we expect the train and validation sets in the .jsonl format here (the test set is not used).

RASSP, in its published implementation, is restricted to molecules with at most 48 atoms (including hydrogens), containing only the elements H, C, N, O, F, P, S, and Cl, and expanding to at most 4096 combinatorial subformulae (see https://doi.org/10.1021/acs.analchem.2c02093 for details). Filtering the NIST records and converting them to the Parquet format required for RASSP training is done with:

```bash
python prepare_rassp_train.py train.jsonl train-rassp-small.pq
python prepare_rassp_train.py valid.jsonl valid-rassp-small.pq
```

The produced .pq files should be approximately 44 MB (train) and 5.5 MB (valid).
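The three RASSP limits can be illustrated on a plain molecular formula. This is a sketch, not the logic of `prepare_rassp_train.py`. The helpers `formula_counts` and `passes_rassp_limits` are hypothetical. The subformula count assumes every combination of 0 to n_i atoms per element, i.e. the product of (n_i + 1); check the RASSP paper for the exact definition used there.

```python
import math
import re
from collections import Counter

# elements allowed by the published RASSP implementation
ALLOWED_ELEMENTS = {"H", "C", "N", "O", "F", "P", "S", "Cl"}


def formula_counts(formula: str) -> Counter:
    # parse a plain molecular formula such as "C6H5Cl" (no charges, isotopes or brackets)
    counts: Counter = Counter()
    for element, n in re.findall(r"([A-Z][a-z]?)(\d*)", formula):
        counts[element] += int(n) if n else 1
    return counts


def passes_rassp_limits(formula: str, max_atoms: int = 48,
                        max_subformulae: int = 4096) -> bool:
    counts = formula_counts(formula)
    if not set(counts) <= ALLOWED_ELEMENTS:
        return False
    # the atom count includes hydrogens, which are explicit in the formula
    if sum(counts.values()) > max_atoms:
        return False
    # ASSUMPTION: subformulae = every combination of 0..n_i atoms per element
    n_subformulae = math.prod(n + 1 for n in counts.values())
    return n_subformulae <= max_subformulae


for f in ["C6H6O", "C20H40", "C6H5Br", "C12H22N2O5"]:
    print(f, passes_rassp_limits(f))
```

`C20H40` fails on the atom count (60 > 48), `C6H5Br` on the element set, and `C12H22N2O5` (41 atoms) on the subformula limit under this assumption (13 · 23 · 3 · 6 = 5382 > 4096).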

### Training RASSP

Training RASSP requires a GPU with at least 8 GB of memory. The RASSP call is wrapped in the `run_train_rassp.py` script; you may need to adjust the path to the repository clone, the number of OMP threads, and the number of dataloader workers. Then, the training is run in the Docker container with suitable GPU settings with:

```bash
make run-train-nvidia
```

With common GPUs (e.g., an Nvidia GeForce RTX 2080), expect approximately 2 minutes per epoch. Training progress can be monitored in the `validate-small.ipynb` notebook (run with `make run-nvidia` in the same directory). In our experiments, the dot product of the predicted spectra with respect to the ground truth on the validation set nearly plateaued at approximately 0.85 ± 0.12 by epoch 400, but it kept improving up to 3000 epochs, which amounts to several days of training.
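The text does not spell out how the dot product is normalized. A common choice for comparing mass spectra is the cosine similarity of the intensity vectors, sketched below for spectra given as `{m/z: intensity}` dicts. We assume this definition here; the validation notebook may weight or transform intensities differently. `spectral_dot` is a hypothetical helper.

```python
import math
from typing import Dict


def spectral_dot(pred: Dict[int, float], true: Dict[int, float]) -> float:
    """Cosine similarity of two spectra given as {m/z: intensity} dicts."""
    # peaks missing from the ground truth contribute zero to the dot product
    dot = sum(i * true.get(mz, 0.0) for mz, i in pred.items())
    norm_pred = math.sqrt(sum(i * i for i in pred.values()))
    norm_true = math.sqrt(sum(i * i for i in true.values()))
    if norm_pred == 0 or norm_true == 0:
        return 0.0  # an empty spectrum matches nothing
    return dot / (norm_pred * norm_true)


print(spectral_dot({31: 999.0, 32: 650.0}, {31: 999.0, 32: 650.0}))  # identical spectra, 1.0 up to rounding
print(spectral_dot({31: 1.0}, {31: 1.0, 32: 1.0}))                  # about 0.707
```

Because both vectors are normalized, the score ranges from 0 (no shared peaks) to 1 (identical relative intensities) and does not depend on the absolute intensity scale.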

The training output is stored in the `checkpoints/` directory. Based on the validation results, choose the best checkpoint and take the corresponding `.meta` and `.model` files for the following prediction step.
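Once a validation score is available per checkpoint, picking the file pair is a one-liner. This sketch assumes that the `.meta` and `.model` files of one checkpoint share the same file stem. The helper `best_checkpoint_files` and the scores are made up for illustration, so check the actual names in `checkpoints/`.

```python
from pathlib import Path
from typing import Dict, Tuple


def best_checkpoint_files(val_scores: Dict[str, float],
                          checkpoint_dir: str = "checkpoints") -> Tuple[Path, Path]:
    """Return the .meta and .model paths of the best-scoring checkpoint.

    val_scores maps a checkpoint stem (file name without extension) to its
    validation score; the shared-stem naming is an assumption.
    """
    best = max(val_scores, key=val_scores.get)
    base = Path(checkpoint_dir)
    return base / f"{best}.meta", base / f"{best}.model"


meta_path, model_path = best_checkpoint_files({"ckpt_a": 0.80, "ckpt_b": 0.83})
print(meta_path, model_path)
```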