# Datamodule code

This section loads the data.

In [None]:
from ms_uq.data import RetrievalDataset_PrecompFPandInchi
from massspecgym.data.data_module import MassSpecDataModule
from massspecgym.data.transforms import MolFingerprinter, SpecBinner
import os

# Directory holding the MassSpecGym table and all precomputed helper files.
helper_files_dir = "/data/home/gaetandw/msms/data/"


dataset = RetrievalDataset_PrecompFPandInchi(
    # Spectra are binned up to m/z 1005 with a bin width of 0.1.
    spec_transform=SpecBinner(max_mz=1005, bin_width=0.1, to_rel_intensities=True),
    # Molecules are turned into fingerprints of size 4096.
    mol_transform=MolFingerprinter(fp_size=4096),
    # Resolved from helper_files_dir like every other input, so moving the
    # data only requires changing that single variable.
    pth=os.path.join(helper_files_dir, "MassSpecGym.tsv"),
    fp_pth=os.path.join(helper_files_dir, "fp_4096.npy"),
    inchi_pth=os.path.join(helper_files_dir, "inchis.npy"),
    candidates_pth=os.path.join(helper_files_dir, "MassSpecGym_retrieval_candidates_formula.json"),
    candidates_fp_pth=os.path.join(helper_files_dir, "MassSpecGym_retrieval_candidates_formula_fps.npz"),
    candidates_inchi_pth=os.path.join(helper_files_dir, "MassSpecGym_retrieval_candidates_formula_inchi.npz"),
)

data_module = MassSpecDataModule(
    dataset=dataset,
    batch_size=32,
    num_workers=8,
)

# Called without a stage argument; the test split needs an explicit
# data_module.setup("test").
data_module.setup()

This only loads the training and validation data.

If you want to use the test set, also call:

`data_module.setup("test")`

Example of a batch:

In [None]:
# Pull the first batch out of the training dataloader, then show what
# iterating over it yields (the field names, if the batch is a dict).
train_loader = data_module.train_dataloader()
batch = next(iter(train_loader))
list(batch)

Input spectra:

In [None]:
# Show the "spec" entry of the batch together with its shape.
spec_batch = batch["spec"]
spec_batch, spec_batch.shape

Output fingerprint bitvectors:

In [None]:
# Show the "mol" entry of the batch together with its shape.
mol_batch = batch["mol"]
mol_batch, mol_batch.shape

All candidates in the batch:

In [None]:
# Show the "candidates" entry of the batch together with its shape.
candidate_batch = batch["candidates"]
candidate_batch, candidate_batch.shape

These are by default flattened to a single list of all candidates for all spectra in the batch.

To get separate candidate and label lists per spectrum (input):

In [None]:
from torch_geometric.utils import unbatch
import massspecgym.utils as utils

# Turn the batch pointer into one index per candidate, then use those indexes
# to split the flattened labels and candidates into per-spectrum pieces.
indexes = utils.batch_ptr_to_batch_idx(batch["batch_ptr"])
labels_per_spectrum, cands_per_spectrum = (
    unbatch(batch[field], indexes) for field in ("labels", "candidates")
)

# Candidates belonging to the first spectrum, plus their shape.
first_cands = cands_per_spectrum[0]
first_cands, first_cands.shape

In [None]:
# Labels belonging to the first spectrum, plus their shape.
first_labels = labels_per_spectrum[0]
first_labels, first_labels.shape

# Model

This is an example for the biencoder model.

In [None]:
from ms_uq.models import FingerprintPredicter, batch_samplewise_tanimoto
import torch
import torch.nn as nn
# Number of input bins, derived from the max m/z (1005) and bin width (0.1)
# that are also passed to SpecBinner when the dataset is built.
num_spectrum_bins = int(1005/0.1)

# Settings handed to the rank-wise loss.
rankwise_options = {
    "temp" : 1.0,
    "n_bits" : 4096,
    "dropout" : 0.0,
    "sim_func" : "cossim",
    "projector": False,
}

model = FingerprintPredicter(
    # Network shape.
    n_in=num_spectrum_bins,
    layer_dims=[512, 512, 512],  # hidden layer sizes
    n_bits=4096,  # fingerprint size
    layer_or_batchnorm="layer",
    dropout=0.25,
    # Optimisation.
    lr=1e-5,
    weight_decay=0,
    df_test_path=None,
    # Loss selection: only the rank-wise loss is switched on here.
    bitwise_loss=None,  # alternatives: "bce", "fl"
    fpwise_loss=None,  # alternatives: "cossim", "iou"
    rankwise_loss="bienc",  # alternatives: "bienc", "cross"
    rankwise_lambd=1,  # loss weighting
    rankwise_kwargs=rankwise_options,
)

If you have a model checkpoint, do this:

In [None]:
# Placeholder location: point this at an actual checkpoint file.
checkpoint_path = "/your/path/here.ckpt"
model = FingerprintPredicter.load_from_checkpoint(checkpoint_path)

Training a model from scratch can also be done with `train_retriever.py`, e.g.:


```bash
python .../ms_mole/train_retriever.py \
    .../data/MassSpecGym.tsv \
    .../data/ \
    /path/to/logs/ \
    --bonus_challenge False \
    --batch_size 128 \
    --devices [0] \
    --precision 32-true \
    --lr 0.0001 \
    --bitwise_loss None --fpwise_loss None --rankwise_loss bienc \
    --bitwise_lambd 0.0 --fpwise_lambd 0.0 --rankwise_lambd 1.0 \
    --rankwise_temp 0.1 --rankwise_dropout 0.25 --rankwise_sim_func cossim
```

Testing a model on a single forward pass:

In [None]:
# Put the model in evaluation mode before running inference.
model.eval()

# Batch fields used in this and the following cells.
batch_ptr = batch["batch_ptr"]
cands = batch["candidates"].int()
fp_true = batch["mol"]
x = batch["spec"]

with torch.no_grad():
    # Spectrum -> embedding -> fingerprint head -> sigmoid.
    embedding = model(x)
    head_output = model.loss.fp_pred_head(embedding)
    fp_pred = torch.sigmoid(head_output)

# Tanimoto between predicted and true fingerprints. reduce=False is passed
# here and the mean is only taken when the result is displayed, so this is
# presumably one value per sample (verify against batch_samplewise_tanimoto).
tanimotos = batch_samplewise_tanimoto(fp_pred, fp_true, reduce=False)

In [None]:
# Show the Tanimoto values next to their mean.
mean_tanimoto = tanimotos.mean()
tanimotos, mean_tanimoto

Get scores (for the cross-encoder, this uses the MLP model):

In [None]:

# Repeat each predicted fingerprint according to batch_ptr so that it lines
# up row-by-row with the flattened candidate list.
fp_pred_repeated = torch.repeat_interleave(fp_pred, batch_ptr, dim=0)

# Score every (prediction, candidate) pair without tracking gradients.
with torch.no_grad():
    scores = model.loss.ranker(fp_pred_repeated, cands)

Evaluate retrieval scores for a batch:

In [None]:
from torch_geometric.utils import unbatch
import massspecgym.utils as utils
from torchmetrics.functional.retrieval import retrieval_hit_rate

def retrieve(scores, batch_ptr, labels, top_k=20):
    """Compute one top-k retrieval hit rate per spectrum of a flattened batch.

    Args:
        scores: candidate scores for the whole batch; split per spectrum with
            the indexes derived from ``batch_ptr``.
        batch_ptr: batch pointer passed to ``utils.batch_ptr_to_batch_idx``;
            its device is also used for the returned tensor.
        labels: candidate labels, split per spectrum the same way as ``scores``.
        top_k: forwarded to ``retrieval_hit_rate``.

    Returns:
        A tensor on ``batch_ptr.device`` holding one hit rate per spectrum.
    """
    # Jitter of machine-epsilon magnitude: when all scores are equal, the
    # ranking becomes random, so random retrieval accuracy is obtained.
    jitter = torch.finfo(scores.dtype).eps * torch.randn_like(scores)

    sample_index = utils.batch_ptr_to_batch_idx(batch_ptr)
    scores_by_sample = unbatch(scores + jitter, sample_index)
    labels_by_sample = unbatch(labels, sample_index)

    per_sample_hit_rate = [
        retrieval_hit_rate(sample_scores, sample_labels, top_k=top_k)
        for sample_scores, sample_labels in zip(scores_by_sample, labels_by_sample)
    ]
    return torch.tensor(per_sample_hit_rate, device=batch_ptr.device)

In [None]:
# Evaluate once and reuse the result: `retrieve` adds random tie-breaking
# noise, so two separate calls are not guaranteed to agree with each other,
# and the second call would redo the whole computation.
hit_rates = retrieve(scores, batch_ptr, batch["labels"])
hit_rates, hit_rates.mean()