# Model evaluation

This notebook is a simple example of how to evaluate our model on our test split of the NIST20 dataset.

We will show you how to perform **Simple evaluation**, **Evaluation in comparison with DB search** and **Visualization of predictions**.

## Simple evaluation
### Generate predictions

To generate predictions, we need to specify a configuration file and a path to the trained model. The configuration file for the NIST dataset is provided in the `configs` directory as [predict_nist.yaml](../configs/predict_nist.yaml).

The script [predict.py](../predict.py) generates the predictions. It takes the prepared `jsonl` file and outputs a `jsonl` file in which each line is a JSON object mapping the generated SMILES strings to the probabilities our model assigns to them. We use these probabilities to rank the candidates from the model's point of view.
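A minimal sketch of reading one line of such a predictions file and ranking its candidates, assuming each line is a flat JSON object of the form `{SMILES: probability}` as described above. The helper `rank_candidates` is hypothetical and the example line is made up:

```python
import json


def rank_candidates(jsonl_line: str) -> list[tuple[str, float]]:
    """Return (SMILES, probability) pairs from one predictions line, most probable first."""
    candidates = json.loads(jsonl_line)  # assumed shape: {"<SMILES>": <probability>, ...}
    return sorted(candidates.items(), key=lambda kv: kv[1], reverse=True)


# Made-up example line; a real predictions file holds one such object per query.
line = '{"CCO": 0.12, "CC(=O)O": 0.55, "CCN": 0.33}'
for smiles, prob in rank_candidates(line):
    print(f"{smiles}\t{prob:.2f}")
```

Sorting in descending order works whether the stored values are probabilities or log-probabilities, since both are monotonic in the model's preference.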

Statistics and all metadata about the run are stored in `log_file.yaml` in the same directory as the predictions.

Running `predict.py` could look like this:

```bash
CUDA_VISIBLE_DEVICES=0 python ../predict.py --checkpoint ../checkpoints/finetune/fearless-wildflower-490_rassp1_neims1_224kPretrain_148k/checkpoint-147476 \
                                            --output-folder predictions \
                                            --config-file ../configs/predict_nist.yaml
```

### Evaluate predictions
Evaluation of predictions is done with the [evaluate_predictions.py](../evaluate_predictions.py) script. It takes a path to the predictions file, a path to the ground truth file and a path to a [config file](../configs/evaluate_nist.yaml) as input and appends all the evaluation metrics to the corresponding `log_file.yaml`. The script also generates several plots and saves them in the same directory as the predictions.

Running `evaluate_predictions.py` could look like this:

```bash
python evaluate_predictions.py --predictions-path {path-to-predictions.jsonl} \
                               --labels-path data/nist/test.jsonl \
                               --config-file configs/evaluate_nist.yaml
```


## Evaluation in comparison with database search
To find out how the model performs against a baseline, you can also run the evaluation in comparison with a standard database search in the de novo scenario. As the reference library we use the NIST train set with 232 025 experimentally measured spectra, and the query library is the NIST test set.

For this scenario to work, we need to precompute, for each query spectrum, the reference spectrum with the highest cosine similarity and the fingerprint similarity between that reference spectrum's SMILES and the query's ground-truth SMILES. This fingerprint similarity of the database hit is then compared with the fingerprint similarity of our model's candidates.
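The idea can be sketched in plain Python. This is a simplified stand-in, not the code of `precompute_db_index.py`: spectra are assumed to be `{integer m/z: intensity}` dicts, fingerprints are assumed to be sets of on-bit indices (in practice they would come from, e.g., RDKit Morgan fingerprints), and the function names are hypothetical.

```python
import math


def cosine(spec_a: dict[int, float], spec_b: dict[int, float]) -> float:
    """Cosine similarity of two spectra given as {integer m/z: intensity}."""
    dot = sum(i * spec_b.get(mz, 0.0) for mz, i in spec_a.items())
    norm_a = math.sqrt(sum(i * i for i in spec_a.values()))
    norm_b = math.sqrt(sum(i * i for i in spec_b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def tanimoto(fp_a: set[int], fp_b: set[int]) -> float:
    """Tanimoto similarity of two fingerprints given as sets of on-bits."""
    union = len(fp_a | fp_b)
    return len(fp_a & fp_b) / union if union else 0.0


def db_search_score(query_spec, query_fp, reference):
    """reference: list of (spectrum, fingerprint) pairs.

    Finds the reference spectrum most cosine-similar to the query spectrum and
    returns the fingerprint similarity between its molecule and the query's molecule.
    """
    _, best_fp = max(reference, key=lambda ref: cosine(query_spec, ref[0]))
    return tanimoto(query_fp, best_fp)


# Toy library: two reference spectra with their (hypothetical) fingerprints.
reference = [
    ({41: 1.0, 43: 0.5}, {1, 2, 3}),
    ({57: 1.0, 71: 0.8}, {2, 3, 4, 5}),
]
score = db_search_score({57: 0.9, 71: 0.7, 85: 0.1}, {2, 3, 4}, reference)
print(score)  # 0.75: the second spectrum wins on cosine, Tanimoto = 3/4
```

Note that the spectral match and the structural comparison use different similarities: cosine on spectra selects the hit, Tanimoto on fingerprints scores it.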

### Precompute similarity index for NIST test

```bash
SPLIT_NAME=test
FP_TYPE=morgan
SIMIL_FUN=tanimoto

python precompute_db_index.py \
           --reference data/nist/train.jsonl \
           --query data/nist/${SPLIT_NAME}.jsonl \
           --outfile data/nist/${SPLIT_NAME}_with_db_index.jsonl \
           --num_processes 32 \
           --fingerprint_type ${FP_TYPE} \
           --fp_simil_function ${SIMIL_FUN}
```

### Evaluate including database search comparison
For this, set `do_db_search` in [evaluate_nist.yaml](../configs/evaluate_nist.yaml) to `True` and point `--labels-path` to the labels enriched with the database index.

Example:
```bash
python evaluate_predictions.py --predictions-path {path-to-predictions.jsonl} \
                               --labels-path data/nist/test_with_db_index.jsonl \
                               --config-file configs/evaluate_nist.yaml
```

## Comment on the interpretation of the statistics
The statistics are stored in the `log_file.yaml` in the same directory as the predictions. Let's see what they mean:

`similsort` - predictions are sorted by the fingerprint similarity of each candidate to the ground-truth molecule. Because this ordering uses the ground truth, it shows the upper bound on the model's performance when generating multiple candidates. This scenario can also mimic a domain expert who picks the correct candidate every time.

`probsort` - predictions are sorted by the model's probabilities. It shows the model's performance when generating multiple candidates without any other information on their quality.
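A toy sketch of the two orderings and of the per-rank mean similarity reported as `topk_probsort` / `topk_similsort` in the example below. Candidates are assumed to be `(SMILES, model probability, fingerprint similarity to the ground truth)` tuples; the helper names are hypothetical, and averaging each rank only over queries that have a candidate at that rank is an assumption, not necessarily what `evaluate_predictions.py` does.

```python
Candidate = tuple[str, float, float]  # (SMILES, model probability, fp similarity to ground truth)


def probsort(cands: list[Candidate]) -> list[Candidate]:
    """Order by the model's probability: uses no ground-truth information."""
    return sorted(cands, key=lambda c: c[1], reverse=True)


def similsort(cands: list[Candidate]) -> list[Candidate]:
    """Order by similarity to the ground truth: an oracle ranking (upper bound)."""
    return sorted(cands, key=lambda c: c[2], reverse=True)


def mean_simil_at_k(queries, sort_fn, k):
    """Mean similarity at ranks 1..k; each rank is averaged over queries that have a candidate there."""
    ranked = [[c[2] for c in sort_fn(q)] for q in queries]
    means = []
    for pos in range(k):
        vals = [r[pos] for r in ranked if len(r) > pos]
        means.append(sum(vals) / len(vals) if vals else None)
    return means


# Two made-up queries with three and two candidates.
queries = [
    [("CCO", 0.5, 0.40), ("CCN", 0.3, 1.00), ("CCC", 0.2, 0.25)],
    [("CO", 0.9, 0.20), ("C=O", 0.1, 0.60)],
]
print(mean_simil_at_k(queries, probsort, 3))   # approx. [0.30, 0.80, 0.25]
print(mean_simil_at_k(queries, similsort, 3))  # approx. [0.80, 0.30, 0.25]
```

By construction, the first similsort value is never lower than the first probsort value, which is why similsort serves as an upper bound.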


### Example
```yaml
evaluation_0:
    average_num_of_predictions: '9.52'            # mean number of UNIQUE VALID candidates generated per 
                                                  # query
    db_search:
        mean_db_score: '0.22'                     # mean fingerprint similarity of the best db candidate
        mean_fpsd_score_probsort: '0.05'          # mean difference between the best candidates of db and
                                                  # the model (probsort)
        mean_fpsd_score_similsort: '0.14'         # mean difference between the best candidates of db and
                                                  # the model (similsort)
        percentage_of_BART_wins_probsort: '0.63'  # percentage of queries where the model's best 
                                                  # candidate is better than the best db candidate 
                                                  # (probsort)
        percentage_of_BART_wins_similsort: '0.84' # percentage of queries where the model's best 
                                                  # candidate is better than the best db candidate 
                                                  # (similsort)
        percentage_of_ties_probsort: '0.14'       # percentage of queries where the model's best 
                                                  # candidate is equal to the best db candidate (probsort)
        percentage_of_ties_similsort: '0.06'      # percentage of queries where the model's best 
                                                  # candidate is equal to the best db candidate (similsort)
        ties:                                     # detailed statistics on ties
            mean_tie_simils_probsort: '0.38'
            mean_tie_simils_similsort: '0.60'
            num_of_ties_probsort: '35'
            num_of_ties_simils_equal_to_1_probsort: '4'  # number of ties with similarity equal to 1 (probsort)
            num_of_ties_simils_equal_to_1_similsort: '6' # number of ties with similarity equal to 1 (similsort)
            num_of_ties_similsort: '16'
            percentage_of_ties_simils_equal_to_1_probsort: '0.11'
            percentage_of_ties_simils_equal_to_1_similsort: '0.37'
    eval_config:                                 # log of configuration used for evaluation
        do_db_search: true
        filtering_args:
            max_mol_repr_len: 100
            max_mz: 500
            max_num_peaks: 300
            mol_repr: smiles
        fingerprint_type: morgan
        on_the_fly: true
        save_best_predictions: true
        fp_simil_function: tanimoto
        threshold: 0.85
    eval_time: 00:00:02
    formula_stats:                               # statistics on molecular formulas      
        num_all_correct_formulas: 154 / 2324        # from all generated candidates
        num_at_least_one_correct_formula: '38'      # at least one correct formula was generated for the query
        num_correct_formulas_at_best_prob: '24'     # correct formula was returned as the best candidate (probsort)
        num_correct_formulas_at_best_simil: '29'    # correct formula was returned as the best candidate (similsort)
        percentage_of_all_correct_formulas: '0.06'        
        percentage_of_at_least_one_correct_formula: '0.15'
        percentage_of_correct_formulas_at_best_prob: '0.09'
        percentage_of_correct_formulas_at_best_simil: '0.11'
    hit_at_k_prob: '[(1, 0.02), (2, 0.03), (3, 0.04),
        (4, 0.06), (5, 0.08)]'                      # percentage of queries where the correct molecule is
                                                    # among the top k candidates (probsort)
                                                    # note: this metric would not make sense for similsort
    labels_path: data/mace/MACE_r05_with_db_index.jsonl
    num_datapoints_tested: '244'
    num_empty_preds: '0'                            # number of queries with no valid candidates
    num_predictions_at_k_counter: '[244, 244, 243, 243, 242, 242, 241, 235, 222, 168]'
    precise_preds_stats:      # statistics on precise returned best candidates (exactly the same canonical SMILES)
        num_precise_preds_probsort: '4'
        num_precise_preds_similsort: '20'
        percentage_of_precise_preds_probsort: '0.01'
        percentage_of_precise_preds_similsort: '0.08'
    simil_1_hits:                              # statistics on returned best candidates with similarity 1
        counter_multiple_hits: dict_items([(3, 2), (1, 13), (2, 5)]) # monitors situations where more than one
                                                                     # candidate has similarity equal to 1. This can
                                                                     # happen because fp similarity is imperfect.
                                                                     # Format: (num of hits, num of occurrences)
        num_1_hits_as_first_probsort: '5'
        num_1_hits_as_first_similsort: '20'
        num_fp_simil_fail_prob: '1'                                  # number of queries where the best candidate 
                                                                     # has similarity 1 but is not the correct one
        num_fp_simil_fail_simil: '0'
        percentage_of_1_hits_as_first_probsort: '0.02'
        percentage_of_1_hits_as_first_similsort: '0.08'
    start_time_utc: 01/09/2024 11:19:01
    threshold_stats:            # statistics using a threshold (not backed by any theoretical
                                # or empirical reasoning, just to have a clue about "relatively good" candidates)
        num_better_than_threshold_probsort: '7'
        num_better_than_threshold_similsort: '22'
        percentage_of_better_than_threshold_probsort: '0.02'
        percentage_of_better_than_threshold_similsort: '0.09'
        threshold: '0.85'
    topk_probsort: '[0.28, 0.27, 0.27, 0.26, 0.26, 0.25, 0.24,
        0.23, 0.23, 0.22]'          # mean similarities on k-th position (probsort) 
    topk_similsort: '[0.37, 0.32, 0.29, 0.27, 0.25, 0.23, 0.22,
        0.19, 0.18, 0.16]'          # mean similarities on k-th position (similsort) 
```
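To make the `db_search` block concrete, here is a minimal sketch of how such statistics could be computed from per-query values. It assumes that `fpsd` is the model's best similarity minus the database hit's similarity (which matches the sign of the example values above) and that a tie means exactly equal similarities; the function and key names are hypothetical, not the script's API. Running it once with the probsort best candidates and once with the similsort best candidates would give the two variants.

```python
def db_search_stats(model_best: list[float], db_best: list[float]) -> dict[str, float]:
    """Compare, per query, the fingerprint similarity to the ground truth of the
    model's best candidate (model_best[i]) and of the db-search hit (db_best[i])."""
    n = len(model_best)
    if n == 0 or n != len(db_best):
        raise ValueError("need two non-empty lists of equal length")
    pairs = list(zip(model_best, db_best))
    return {
        "mean_db_score": sum(db_best) / n,
        # fpsd assumed to be model minus database: positive values favour the model
        "mean_fpsd_score": sum(m - d for m, d in pairs) / n,
        "fraction_of_model_wins": sum(m > d for m, d in pairs) / n,
        "fraction_of_ties": sum(m == d for m, d in pairs) / n,
    }


# Made-up similarities for four queries.
stats = db_search_stats(model_best=[0.5, 0.3, 1.0, 0.2], db_best=[0.4, 0.3, 0.5, 0.6])
print(stats)  # wins 0.5, ties 0.25, mean_fpsd_score approx. 0.05, mean_db_score approx. 0.45
```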

## Visualize predictions
To see what the model's predictions look like we prepared a little visualization script.

```python
import json
from pathlib import Path

from rdkit import Chem, DataStructs

from utils.eval_utils import load_labels_to_datapipe

# load labels and predictions
predictions_path = "path/to/predictions.jsonl"  # CHANGE to the actual path
labels_path = "data/nist/test_with_db_index.jsonl"

labels, _ = load_labels_to_datapipe(Path(labels_path))
labels = list(labels)  # ground-truth SMILES, one per query

with open(predictions_path) as f:
    dict_predictions = [json.loads(line) for line in f if line.strip()]

# order each query's candidates by the model's probability, most probable first (probsort)
sorted_predictions = [
    [smiles for smiles, _ in sorted(pred.items(), key=lambda x: x[1], reverse=True)]
    for pred in dict_predictions
]
```

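```python
from IPython.display import display  # available by default in Jupyter; imported for clarity

pairs_to_viz = list(zip(labels, sorted_predictions))[142]  # TODO: change the index (or use a slice) to visualize different pairs
best_n_to_viz = 5                                          # TODO: change this to visualize more/fewer predictions

if isinstance(pairs_to_viz, tuple):  # a single (gt, preds) pair was selected
    pairs_to_viz = [pairs_to_viz]

for gt_smiles, preds in pairs_to_viz:
    print("\n##################")
    print("GT smiles:", gt_smiles)
    gt_mol = Chem.MolFromSmiles(gt_smiles)
    gt_fp = Chem.RDKFingerprint(gt_mol)
    display(gt_mol)
    for i, pred_smiles in enumerate(preds[:best_n_to_viz]):
        pred_mol = Chem.MolFromSmiles(pred_smiles)
        if pred_mol is None:  # skip candidates RDKit cannot parse
            print(f"Prediction {i}: {pred_smiles} (invalid SMILES)")
            continue
        # RDKit topological fingerprint + Tanimoto; may differ from the fingerprint used in evaluation (e.g. morgan)
        simil = DataStructs.FingerprintSimilarity(gt_fp, Chem.RDKFingerprint(pred_mol))
        print(f"Prediction {i}: {pred_smiles}, similarity: {simil:.2f}")
        display(pred_mol)
```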