# Create overlap-only predictions
We have two sets of predictions for the MONA_GCMS dataset: one that includes the datapoints overlapping with NIST/NEIMS_gen and one that excludes them. This notebook creates predictions for the overlap-only cases (datapoints that also occur in NIST train) by filtering the predictions that include overlaps.

The whole process is a bit more complicated because of the datapipe filtering and the missing ground-truth (GT) info in the predictions, so it might not be obvious at first glance...

However, it will save A LOT of compute. If you want, just burn this.

In [2]:
from matchms.importing import load_from_msp
from tqdm import tqdm
import pandas as pd
import asyncio
import nest_asyncio
from rdkit import Chem
from pathlib import Path

import sys
sys.path.append('../')
from utils.data_utils import filter_datapoints, build_single_datapipe

In [3]:
# Relative path prefix prepended to every dataset file path used below.
DATA_ROOT_FOLDER = "../clean_paper/data/"

In [17]:
import json

# MONA_GCMS dataset that still contains the datapoints overlapping with NIST.
with_overlap_path = DATA_ROOT_FOLDER + "extra_libraries/MONA_GCMS/MONA_GCMS_overlaps_included.jsonl"
# NIST training set; its "smiles" column defines what counts as an overlap.
nist_train_path = DATA_ROOT_FOLDER + "nist/train.jsonl"

# Options handed to `filter_datapoints` as its second argument.
# NOTE(review): presumably these must match the settings used when the
# predictions were generated, so that the filtered datapoints line up
# one-to-one with the prediction lines — confirm against the inference config.
filtering_args = {
        "max_num_peaks": 300,
        "max_mz": 500,
        "max_mol_repr_len": 100,
        "mol_repr": "smiles",
        "log_base": 1.28,
        "log_shift": 29,
        "inference_mode": True,
        "keep_all_columns": True,
    }
def filter_dataset(dataset_path, filtering_args=filtering_args):
    """Load a JSONL dataset and keep only the datapoints accepted by the filter.

    Every line of ``dataset_path`` is parsed as JSON and passed, together with
    ``filtering_args``, to ``filter_datapoints``. Datapoints for which that call
    returns a truthy value are collected in file order.

    Args:
        dataset_path: Path to a JSON-lines file.
        filtering_args: Second argument forwarded to ``filter_datapoints``;
            defaults to the module-level ``filtering_args``.

    Returns:
        A list with the parsed datapoints that passed the filter.
    """
    with open(dataset_path, "r") as dataset_file:
        # Parse lazily so that each datapoint is filtered right after it is read.
        parsed_datapoints = (json.loads(raw_line) for raw_line in tqdm(dataset_file))
        return [
            candidate
            for candidate in parsed_datapoints
            if filter_datapoints(candidate, filtering_args)
        ]

In [19]:
# All SMILES strings present in the NIST train file, as a set for fast membership tests.
nist_smiles_set = set(pd.read_json(nist_train_path, lines=True)["smiles"])
# MONA_GCMS (overlaps included) datapoints that pass `filter_datapoints` with the default `filtering_args`.
with_overlaps_filtered_dicts = filter_dataset(with_overlap_path)

18464it [00:03, 5119.91it/s]


In [20]:
def create_overlap_only_mask(nist_smiles_set, with_overlaps_filtered_dicts):
    """Build a boolean mask marking the datapoints whose SMILES occurs in NIST train.

    The mask has one entry per element of ``with_overlaps_filtered_dicts`` (in the
    same order) and is meant to be applied to the predictions.jsonl produced for
    MONA_GCMS_overlaps_included.jsonl, keeping only the MONA/NIST-train overlap.

    Args:
        nist_smiles_set: Collection of SMILES strings taken from NIST train.jsonl.
        with_overlaps_filtered_dicts: Iterable of datapoint dicts, each with a
            "smiles" key.

    Returns:
        A boolean ``pd.Series`` that is True where the datapoint's SMILES is in
        ``nist_smiles_set``.
    """
    mona_smiles = []
    for entry in with_overlaps_filtered_dicts:
        mona_smiles.append(entry["smiles"])

    return pd.Series(mona_smiles).isin(nist_smiles_set)

# One boolean per filtered MONA datapoint: True when its SMILES also occurs in NIST train.
overlap_only_mask = create_overlap_only_mask(nist_smiles_set, with_overlaps_filtered_dicts)

In [21]:
# Number of True entries in the mask, i.e. how many filtered datapoints are overlaps.
sum(overlap_only_mask)

12758

In [22]:
from pathlib import Path
import yaml

def apply_mask_to_predicitons(overlap_only_mask, predictions_path, output_path):
    """Write the predictions selected by a boolean mask to a new file.

    Applies the mask to the predictions.jsonl of MONA_GCMS_overlaps_included.jsonl
    so that only the datapoints that are also in NIST train.jsonl are kept.
    The i-th non-blank line of ``predictions_path`` is kept when the i-th mask
    entry is truthy; kept lines are written unchanged to ``output_path``.

    Args:
        overlap_only_mask: Iterable of booleans (e.g. a ``pd.Series``), one per
            prediction line.
        predictions_path: Path to the predictions file, one prediction per line.
        output_path: Destination file; missing parent folders are created.

    Raises:
        ValueError: If the number of predictions differs from the mask length.
            A mismatch would otherwise pair predictions with the wrong datapoints.
    """
    keep_flags = [bool(flag) for flag in overlap_only_mask]

    # Read the lines verbatim instead of going through a CSV parser: JSON lines
    # such as `null` or `NA` would be turned into NaN and quoted fields would be
    # rewritten. Blank lines are skipped, as the CSV reader did by default.
    with open(predictions_path, "r", encoding="utf-8") as f:
        predictions = [line.rstrip("\n") for line in f if line.strip()]

    # Validate before touching the output location so a bad call has no side effects.
    if len(predictions) != len(keep_flags):
        raise ValueError(
            f"Mask has {len(keep_flags)} entries but {predictions_path} "
            f"contains {len(predictions)} predictions."
        )

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(
            prediction + "\n"
            for prediction, keep in zip(predictions, keep_flags)
            if keep
        )

def copy_logfile(logfile_folder, output_folder):
    """Copy log_file.yaml from the predictions folder to the output folder.

    Only the dataset name and query data entries are changed so that the log
    describes the new (overlap-only) predictions. All ``evaluation_<N>`` entries
    are dropped because they were computed for the original predictions.

    Args:
        logfile_folder: Folder containing the source ``log_file.yaml``.
        output_folder: Folder that receives the adjusted ``log_file.yaml``;
            it is created if it does not exist.
    """
    logfile_path = Path(logfile_folder) / "log_file.yaml"
    output_logfile_path = Path(output_folder) / "log_file.yaml"

    with open(logfile_path, "r", encoding="utf-8") as f:
        old_logs = yaml.safe_load(f)

    old_logs["dataset"]["dataset_name"] = "MONA_GCMS_overlaps_only"
    old_logs["dataset"]["query_data"] = "data/extra_libraries/MONA_GCMS/MONA_GCMS_overlaps_only.jsonl"

    # Collect the keys first: deleting while iterating over the dict would fail.
    # Matching on the key name (rather than counting up from 0) also removes
    # entries that follow a gap in the numbering or whose value is empty.
    prefix = "evaluation_"
    stale_keys = [
        key
        for key in old_logs
        if isinstance(key, str)
        and key.startswith(prefix)
        and key[len(prefix):].isdigit()
    ]
    for key in stale_keys:
        del old_logs[key]

    # Make the function usable on its own, not only after the predictions were written.
    output_logfile_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_logfile_path, "w", encoding="utf-8") as f:
        yaml.dump(old_logs, f)


def create_new_predictions(overlap_only_mask, predictions_folder, output_folder):
    """Create the overlap-only counterpart of one predictions folder.

    Filters ``predictions.jsonl`` from ``predictions_folder`` with
    ``overlap_only_mask``, writes the result into ``output_folder`` and then
    copies the adjusted log file alongside it.
    """
    source_predictions = Path(predictions_folder) / "predictions.jsonl"
    target_predictions = Path(output_folder) / "predictions.jsonl"

    apply_mask_to_predicitons(overlap_only_mask, source_predictions, target_predictions)
    copy_logfile(predictions_folder, output_folder)

In [27]:
# ./clean_paper/predictions/balmy-violet-577_exp8_224_148/MONA_GCMS_overlaps_included/1730907164_all_full_beam10/predictions.jsonl
# ./clean_paper/predictions/balmy-violet-577_exp8_224_148/MONA_GCMS_overlaps_included/1730907162_all_full_beam50/predictions.jsonl
# ./clean_paper/predictions/balmy-violet-577_exp8_224_148/MONA_GCMS_overlaps_included/1730907164_all_full_greedy/predictions.jsonl
# ./clean_paper/predictions/db_search_morgan_tanimoto/MONA_GCMS_overlaps_included/1730984285_all_full_1cand/predictions.jsonl
# ./clean_paper/predictions/db_search_morgan_tanimoto/MONA_GCMS_overlaps_included/1730916155_all_full_50cand/predictions.jsonl
# ./clean_paper/predictions/db_search_morgan_tanimoto/MONA_GCMS_overlaps_included/1730983692_all_full_10cand/predictions.jsonl
# ./clean_paper/predictions/db_search_sss/MONA_GCMS_overlaps_included/1730984285_all_full_1cand/predictions.jsonl
# ./clean_paper/predictions/db_search_sss/MONA_GCMS_overlaps_included/1730916155_all_full_50cand/predictions.jsonl
# ./clean_paper/predictions/db_search_sss/MONA_GCMS_overlaps_included/1730981471_all_full_10cand/predictions.jsonl
# ./clean_paper/predictions/db_search_hss/MONA_GCMS_overlaps_included/1730984285_all_full_1cand/predictions.jsonl
# ./clean_paper/predictions/db_search_hss/MONA_GCMS_overlaps_included/1730916155_all_full_50cand/predictions.jsonl
# ./clean_paper/predictions/db_search_hss/MONA_GCMS_overlaps_included/1730981471_all_full_10cand/predictions.jsonl

# Folders holding the predictions made on MONA_GCMS with overlaps included.
predictions_folders = [
    "../clean_paper/predictions/balmy-violet-577_exp8_224_148/MONA_GCMS_overlaps_included/1730907164_all_full_beam10/",
    "../clean_paper/predictions/balmy-violet-577_exp8_224_148/MONA_GCMS_overlaps_included/1730907162_all_full_beam50/",
    "../clean_paper/predictions/balmy-violet-577_exp8_224_148/MONA_GCMS_overlaps_included/1730907164_all_full_greedy/",
    "../clean_paper/predictions/db_search_morgan_tanimoto/MONA_GCMS_overlaps_included/1730984285_all_full_1cand/",
    "../clean_paper/predictions/db_search_morgan_tanimoto/MONA_GCMS_overlaps_included/1730916155_all_full_50cand/",
    "../clean_paper/predictions/db_search_morgan_tanimoto/MONA_GCMS_overlaps_included/1730983692_all_full_10cand/",
    "../clean_paper/predictions/db_search_sss/MONA_GCMS_overlaps_included/1730984285_all_full_1cand/",
    "../clean_paper/predictions/db_search_sss/MONA_GCMS_overlaps_included/1730916155_all_full_50cand/",
    "../clean_paper/predictions/db_search_sss/MONA_GCMS_overlaps_included/1730981471_all_full_10cand/",
    "../clean_paper/predictions/db_search_hss/MONA_GCMS_overlaps_included/1730984285_all_full_1cand/",
    "../clean_paper/predictions/db_search_hss/MONA_GCMS_overlaps_included/1730916155_all_full_50cand/",
    "../clean_paper/predictions/db_search_hss/MONA_GCMS_overlaps_included/1730981471_all_full_10cand/",
]

# Each output folder mirrors its source folder with only the dataset directory
# renamed. Deriving the list keeps both in lockstep; two hand-written lists
# could drift apart, and `zip` would then silently drop or mispair folders.
output_folders = [
    source_folder.replace("MONA_GCMS_overlaps_included", "MONA_GCMS_overlaps_only")
    for source_folder in predictions_folders
]

for predictions_folder, output_folder in zip(predictions_folders, output_folders):
    create_new_predictions(overlap_only_mask, predictions_folder, output_folder)

## Create filtered dataset
In a similar manner create a dataset with only overlap cases. This will be used to evaluate the overlap-only predictions.

Keep only the filtered labels there, so we KNOW FOR SURE that the number of labels is the same as the number of predictions.

In [9]:
# Load both datasets with `filtering_args=None` forwarded to `filter_datapoints`.
# NOTE(review): presumably None disables the filtering (hence "nonfiltered") —
# confirm against utils.data_utils.filter_datapoints.
nist_nonfiltered_dicts = filter_dataset(nist_train_path, filtering_args=None)
# NOTE(review): this variable is not used by any of the cells visible here.
with_overlaps_nonfiltered_dicts = filter_dataset(with_overlap_path, filtering_args=None)

In [10]:
# Source dataset (overlaps included); same file as `with_overlap_path`.
old_data_path = DATA_ROOT_FOLDER + "extra_libraries/MONA_GCMS/MONA_GCMS_overlaps_included.jsonl"
# Destination for the overlap-only dataset written below.
new_data_path = DATA_ROOT_FOLDER + "extra_libraries/MONA_GCMS/MONA_GCMS_overlaps_only.jsonl"

In [15]:
# Reload NIST train and collect its SMILES strings into a set for membership tests.
nist_train_df = pd.read_json(nist_train_path, lines=True)
nist_train_smiles_set = set(nist_train_df["smiles"])

In [24]:
import json
## filter the data
# Copy to the new file every line whose datapoint passes `filter_datapoints`
# and whose SMILES is present in NIST train; `counter` tracks how many were kept.
counter = 0
with open(old_data_path, "r") as source_file, open(new_data_path, "w") as target_file:
    for raw_line in tqdm(source_file):
        record = json.loads(raw_line)
        # Run the filter first so the SMILES lookup only happens for accepted datapoints.
        if not filter_datapoints(record, filtering_args):
            continue
        if record["smiles"] not in nist_train_smiles_set:
            continue
        # Write the original line untouched rather than re-serializing the dict.
        target_file.write(raw_line)
        counter += 1

18464it [00:03, 5008.02it/s]


In [26]:
# Sizes of the NIST train list loaded with filtering_args=None and of the
# *filtered* MONA list (not `with_overlaps_nonfiltered_dicts`).
len(nist_nonfiltered_dicts), len(with_overlaps_filtered_dicts)

(232025, 17812)