## Dataset object

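If you're on deepthought, you can keep the paths below as they are; otherwise, point them to your local copy of the MassSpecGym data and helper files.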

In [1]:
from ms_uq.data import RetrievalDataset_PrecompFPandInchi
from massspecgym.data.data_module import MassSpecDataModule
from massspecgym.data.transforms import MolFingerprinter, SpecBinner
import os

# Directory holding MassSpecGym.tsv and the precomputed fingerprint / InChI helper files.
# Every path below is derived from it, so relocating the data only needs this one edit.
helper_files_dir = "/data/home/gaetandw/msms/data/"

dataset = RetrievalDataset_PrecompFPandInchi(
    # Bin m/z values up to 1005 in 0.1-wide bins (10050 bins per spectrum).
    spec_transform=SpecBinner(max_mz=1005, bin_width=0.1, to_rel_intensities=True),
    mol_transform=MolFingerprinter(fp_size=4096),
    pth=os.path.join(helper_files_dir, "MassSpecGym.tsv"),
    fp_pth=os.path.join(helper_files_dir, "fp_4096.npy"),
    inchi_pth=os.path.join(helper_files_dir, "inchis.npy"),
    candidates_pth=os.path.join(helper_files_dir, "MassSpecGym_retrieval_candidates_formula.json"),
    candidates_fp_pth=os.path.join(helper_files_dir, "MassSpecGym_retrieval_candidates_formula_fps.npz"),
    candidates_inchi_pth=os.path.join(helper_files_dir, "MassSpecGym_retrieval_candidates_formula_inchi.npz"),
)

data_module = MassSpecDataModule(
    dataset=dataset,
    batch_size=32,
    num_workers=8,
)

data_module.setup()

Example of a training batch:

In [4]:
# Pull a single batch from the training loader and show which fields it carries.
batch = next(iter(data_module.train_dataloader()))
list(batch.keys())

['spec',
 'mol',
 'precursor_mz',
 'adduct',
 'mol_freq',
 'identifier',
 'smiles',
 'candidates',
 'labels',
 'batch_ptr',
 'candidates_smiles']

In [8]:
# Binned spectra: one row per spectrum in the batch, one column per m/z bin.
batch["spec"], batch["spec"].size()

(tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 torch.Size([32, 10050]))

In [7]:
# Fingerprints of each sample's molecule (fp_size=4096 in MolFingerprinter).
batch["mol"], batch["mol"].size()

(tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 torch.Size([32, 4096]))

In [9]:
# Candidate fingerprints of all samples, stacked row-wise (row count = sum of batch_ptr).
batch["candidates"], batch["candidates"].size()

(tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 torch.Size([7347, 4096]))

In [10]:
# One boolean per candidate row; presumably True marks the sample's correct molecule.
batch["labels"], batch["labels"].size()

(tensor([ True, False, False,  ..., False, False, False]), torch.Size([7347]))

In [11]:
# Despite its name, batch_ptr holds the number of candidates of each sample (not
# cumulative offsets). The per-sample repeat/split steps below rely on these counts
# adding up to the number of candidate (and label) rows, so check that here.
assert (
    int(batch["batch_ptr"].sum())
    == batch["candidates"].shape[0]
    == batch["labels"].shape[0]
), "batch_ptr counts do not match the number of candidate rows"
batch["batch_ptr"], batch["batch_ptr"].shape

(tensor([256, 153, 256, 256, 256, 256, 256, 256, 104, 256, 256, 149, 256, 166,
         256,  65, 256, 256, 256, 256, 217,  93, 256, 256, 256, 256, 256, 256,
         256, 256, 256, 256]),
 torch.Size([32]))

## Model

In [35]:
from ms_uq.models import FingerprintPredicter, batch_samplewise_tanimoto
import torch
import torch.nn as nn
model = FingerprintPredicter(
    # Number of spectrum bins: max_mz / bin_width of the SpecBinner (10050).
    # round(), not int(): int() truncates, so a float quotient landing just below an
    # integer (e.g. int(0.7 / 0.1) == 6) would silently drop a bin.
    n_in=round(1005 / 0.1),
    layer_dims=[512, 512, 512],  # hidden layer sizes
    n_bits=4096,  # fingerprint size, must match MolFingerprinter(fp_size=...)
    layer_or_batchnorm="layer",
    dropout=0.25,
    lr=1e-5,
    weight_decay=0,
    df_test_path=None,
    bitwise_loss=None,  # "bce", "fl"
    fpwise_loss=None,  # "cossim", "iou"
    rankwise_loss="bienc",  # "bienc", "cross"
    bitwise_lambd=0,
    fpwise_lambd=0,
    rankwise_lambd=1,
    rankwise_kwargs={
        "temp": 1.0,
        "n_bits": 4096,
        "dropout": 0.0,
        "sim_func": "cossim",
        "projector": False,
    },
)

Single forward pass through the freshly initialised (untrained) model:

In [36]:
x = batch["spec"]
fp_true = batch["mol"]
cands = batch["candidates"].int()
batch_ptr = batch["batch_ptr"]

# This is inference only: switch off dropout (dropout=0.25 above) so predictions are
# deterministic, and disable autograd so no computation graph is kept around.
model.eval()
with torch.no_grad():
    # Predict fingerprint: embed the spectrum, then squash the prediction head's
    # outputs into (0, 1) with a sigmoid.
    embedding = model(x)
    fp_pred = torch.sigmoid(model.loss.fp_pred_head(embedding))

# Per-sample Tanimoto similarity between predicted and true fingerprints
# (reduce=False keeps one value per sample; the mean is taken separately).
tanimotos = batch_samplewise_tanimoto(fp_pred, fp_true, reduce=False)

In [37]:
# Per-sample Tanimoto similarities and their batch mean.
tanimotos, torch.mean(tanimotos)

(tensor([0.0039, 0.0189, 0.0034, 0.0087, 0.0089, 0.0113, 0.0087, 0.0057, 0.0043,
         0.0152, 0.0117, 0.0148, 0.0038, 0.0128, 0.0048, 0.0207, 0.0084, 0.0129,
         0.0131, 0.0092, 0.0096, 0.0120, 0.0063, 0.0132, 0.0154, 0.0176, 0.0139,
         0.0096, 0.0136, 0.0121, 0.0080, 0.0087]),
 tensor(0.0107))

In [38]:
# sim scores between true fp and candidates
# Similarity between each sample's *predicted* fingerprint and each of its candidates:
# row i of fp_pred is repeated batch_ptr[i] times so it lines up with that sample's
# candidate rows, then cosine similarity is taken row by row (one score per candidate).
fp_pred_repeated = fp_pred.repeat_interleave(batch_ptr, dim=0)
scores = nn.functional.cosine_similarity(fp_pred_repeated, cands, dim=1)

In [39]:
from torch_geometric.utils import unbatch
import massspecgym.utils as utils
from torchmetrics.functional.retrieval import retrieval_hit_rate

def retrieve(scores, batch_ptr, labels, top_k=20, generator=None):
    """Compute the per-sample top-k hit rate of a candidate ranking.

    Args:
        scores: 1-D tensor with one similarity score per candidate, the candidates
            of all samples concatenated in sample order.
        batch_ptr: 1-D tensor with the number of candidates of each sample.
        labels: tensor aligned with ``scores``; presumably True marks the correct
            candidate(s) of each sample.
        top_k: number of top-ranked candidates considered by the hit rate.
        generator: optional ``torch.Generator`` for the tie-breaking noise, to make
            results reproducible. ``None`` uses the global RNG (original behavior).

    Returns:
        1-D tensor with one hit rate per sample, on ``batch_ptr``'s device
        (empty if there are no samples).
    """
    # Tiny noise breaks ties, so that if all scores are equal, random retrieval
    # accuracy is obtained instead of an order-dependent one. The noise is scaled by
    # max(1, |score|) so it is not rounded away for scores above 1; for scores in
    # [-1, 1] (e.g. cosine similarity) it is plain eps-sized noise.
    eps = torch.finfo(scores.dtype).eps
    noise = torch.randn(
        scores.shape, generator=generator, dtype=scores.dtype, device=scores.device
    )
    noisy_scores = scores + noise * eps * scores.abs().clamp(min=1.0)

    # Split the flat score / label vectors back into one chunk per sample.
    sample_idx = utils.batch_ptr_to_batch_idx(batch_ptr)
    per_sample_scores = unbatch(noisy_scores, sample_idx)
    per_sample_labels = unbatch(labels, sample_idx)

    hit_rates = [
        retrieval_hit_rate(sample_scores, sample_labels, top_k=top_k)
        for sample_scores, sample_labels in zip(per_sample_scores, per_sample_labels)
    ]
    if not hit_rates:
        return torch.empty(0, device=batch_ptr.device)
    # stack keeps the metric values as tensors instead of re-parsing them one by one.
    return torch.stack(hit_rates).to(batch_ptr.device)

In [40]:
# Call retrieve only once: it adds random tie-breaking noise, so two separate calls
# can disagree and the displayed mean would not match the displayed hit rates.
hit_rates = retrieve(scores, batch_ptr, batch["labels"])
hit_rates, hit_rates.mean()

(tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0.]),
 tensor(0.0938))