In [27]:
import pandas as pd
from pathlib import Path
import warnings

from lightning import pytorch as pl
from numpy.typing import ArrayLike
import torch
from torch import Tensor
import torchmetrics
from torchmetrics.functional.classification.hinge import (
    _multiclass_hinge_loss_tensor_validation,
    _multiclass_confusion_matrix_format,
)
from torchmetrics.utilities.data import to_onehot

from chemprop import data, featurizers, models, nn

In Chemprop, loss functions and metrics are both instances of `chemprop.nn.metrics.ChempropMetric`, which inherits from `torchmetrics.Metric`. This notebook shows how to adapt loss functions and metrics from `torchmetrics` to work in Chemprop. Custom loss functions and metrics that are not available in `torchmetrics` can be created following the instructions provided on the `torchmetrics` website and then adapted to Chemprop by following the example below.

Set up dataset

In [29]:
# Locate the example multiclass CSV; the chemprop checkout is assumed to be the
# parent of the current working directory.
chemprop_dir = Path.cwd().parent
input_path = chemprop_dir.joinpath("tests", "data", "classification", "mol_multiclass.csv")
df_input = pd.read_csv(input_path)

# SMILES strings as a 1-D array; targets as a 2-D (n_rows, 1) array so that each
# datapoint receives a length-1 target vector.
smis = df_input["smiles"].values
ys = df_input[["activity"]].values
all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]

# Random 80/10/10 train/validation/test split.
train_indices, val_indices, test_indices = data.make_split_indices(
    all_data, "random", (0.8, 0.1, 0.1)
)
train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)

# Wrap every split in a dataset, in train -> val -> test order.
train_dset, val_dset, test_dset = map(
    data.MoleculeDataset, (train_data, val_data, test_data)
)

# Only the training loader keeps the default shuffling behaviour.
train_loader = data.build_dataloader(train_dset)
val_loader, test_loader = (
    data.build_dataloader(dset, shuffle=False) for dset in (val_dset, test_dset)
)

`ChempropMetric`s optionally expand the capabilities of `torchmetrics.Metric`s by allowing tasks and data points to be weighted in the loss function. Additionally, targets can be masked out so that they are not used in the calculation, or marked as one-sided (meaning predictions greater than or, alternatively, less than the target won't be penalized).

The `__init__` method of custom loss functions and metrics should accept `task_weights` as an argument, although the weights do not have to be used in the actual calculation. The `update` method should accept `preds, targets, mask, weights, lt_mask, gt_mask`, even if `mask`, `weights`, `lt_mask`, and `gt_mask` are not used in the calculation. For example, the greater-than and less-than masks only apply to regression tasks.

`ChempropMetric`s should also have an `alias` property. This is used when logging the metric values.

In [34]:
class ChempropMulticlassHingeLoss(torchmetrics.classification.MulticlassHingeLoss):
    """`torchmetrics` multiclass hinge loss exposed through the Chemprop metric interface.

    The constructor accepts `task_weights` and `update` accepts a `mask`, as Chemprop
    expects; the hinge computation itself is delegated to the parent class.
    """

    def __init__(self, task_weights: ArrayLike = 1.0, **kwargs):
        """Build the parent metric and record `task_weights`.

        Parameters
        ----------
        task_weights : ArrayLike, default=1.0
            Stored as a float tensor reshaped to `(1, -1)`. The weights are not used
            by this metric; a warning is emitted if any of them differs from 1.0.
        **kwargs
            Passed unchanged to the parent class constructor (e.g. `num_classes`).
        """
        super().__init__(**kwargs)
        weights = torch.as_tensor(task_weights, dtype=torch.float)
        self.task_weights = weights.view(1, -1)

        # Non-unit weights would be silently dropped, so tell the user about it.
        if torch.any(self.task_weights != 1.0):
            warnings.warn(
                "task_weights were provided but are ignored by metric "
                f"{self.__class__.__name__}. Got {task_weights}"
            )

    def update(self, preds: Tensor, targets: Tensor, mask: Tensor | None = None, *args, **kwargs):
        """Accumulate metric state from the unmasked entries only.

        Parameters
        ----------
        preds : Tensor
            Model predictions; indexed with `mask` before being passed on.
        targets : Tensor
            Target labels; indexed with `mask` and cast to `long` before being passed on.
        mask : Tensor | None, default=None
            Selects the entries to keep. When `None`, an all-`True` boolean mask with
            the shape of `targets` is used, i.e. every entry is kept.
        *args, **kwargs
            Accepted for interface compatibility and ignored.
        """
        keep = torch.ones_like(targets, dtype=torch.bool) if mask is None else mask
        kept_preds = preds[keep]
        kept_targets = targets[keep].long()
        super().update(kept_preds, kept_targets)

    @property
    def alias(self) -> str:
        """Short name used when logging this metric's values."""
        return "hinge"

In [35]:
class ChempropMulticlassAUROC(torchmetrics.classification.MulticlassAUROC):
    """`torchmetrics` multiclass AUROC exposed through the Chemprop metric interface.

    The constructor accepts `task_weights` and `update` accepts a `mask`, as Chemprop
    expects; the AUROC computation itself is delegated to the parent class.
    """

    def __init__(self, task_weights: ArrayLike = 1.0, **kwargs):
        """Build the parent metric and record `task_weights`.

        Parameters
        ----------
        task_weights : ArrayLike, default=1.0
            Stored as a float tensor reshaped to `(1, -1)`. The weights are not used
            by this metric; a warning is emitted if any of them differs from 1.0.
        **kwargs
            Passed unchanged to the parent class constructor (e.g. `num_classes`).
        """
        super().__init__(**kwargs)
        weights = torch.as_tensor(task_weights, dtype=torch.float)
        self.task_weights = weights.view(1, -1)

        # Non-unit weights would be silently dropped, so tell the user about it.
        if torch.any(self.task_weights != 1.0):
            warnings.warn(
                "task_weights were provided but are ignored by metric "
                f"{self.__class__.__name__}. Got {task_weights}"
            )

    def update(self, preds: Tensor, targets: Tensor, mask: Tensor | None = None, *args, **kwargs):
        """Accumulate metric state from the unmasked entries only.

        Parameters
        ----------
        preds : Tensor
            Model predictions; indexed with `mask` before being passed on.
        targets : Tensor
            Target labels; indexed with `mask` and cast to `long` before being passed on.
        mask : Tensor | None, default=None
            Selects the entries to keep. When `None`, an all-`True` boolean mask with
            the shape of `targets` is used, i.e. every entry is kept.
        *args, **kwargs
            Accepted for interface compatibility and ignored.
        """
        keep = torch.ones_like(targets, dtype=torch.bool) if mask is None else mask
        kept_preds = preds[keep]
        kept_targets = targets[keep].long()
        super().update(kept_preds, kept_targets)

    @property
    def alias(self) -> str:
        """Short name used when logging this metric's values."""
        return "multiclass_auroc"

Supply the custom loss function and metric to the model:

In [36]:
# Number of classes = largest label + 1 (assumes labels are the integers 0..K-1).
# `ys.max()` reduces over the whole target array in a single vectorized call instead of
# comparing row arrays one by one with the builtin `max`, and `int(...)` guarantees a
# plain Python int even if the target column was parsed with a float dtype.
n_classes = int(ys.max()) + 1

# The custom hinge loss is used as the criterion of the multiclass predictor.
loss_function = ChempropMulticlassHingeLoss(num_classes=n_classes)
ffn = nn.MulticlassClassificationFFN(n_classes=n_classes, criterion=loss_function)

# The custom AUROC is supplied as an additional metric to be reported.
metrics = [ChempropMulticlassAUROC(num_classes=n_classes)]

model = models.MPNN(nn.BondMessagePassing(), nn.NormAggregation(), ffn, metrics=metrics)

Run training

In [37]:
# Two epochs are enough to demonstrate that the custom loss and metric are wired in.
trainer = pl.Trainer(max_epochs=2)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Last expression of the cell, so the test metrics are displayed as the cell output.
trainer.test(model, dataloaders=test_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Loading `train_dataloader` to estimate number of stepping batches.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (7) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.

  | Name            | Type                        | Params | Mode 
------------------------------------------------------------------------
0 | message_passing | BondMessagePassing          | 2

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


                                                                           



Epoch 1: 100%|██████████| 7/7 [00:01<00:00,  4.31it/s, v_num=23, train_loss_step=0.336, val_loss=0.930, train_loss_epoch=0.516]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 7/7 [00:01<00:00,  4.21it/s, v_num=23, train_loss_step=0.336, val_loss=0.930, train_loss_epoch=0.516]


/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 19.18it/s]




[{'test/multiclass_auroc': 0.0}]