## Loss functions

In [1]:
from lightning import pytorch as pl
import numpy as np
from chemprop import data, models, nn
from chemprop.nn.loss import LossFunctionRegistry

### Available functions

Chemprop provides several loss functions. During training, the gradient of the chosen loss with respect to the model weights is used to update those weights. Users only need to select a loss function; Chemprop and the Lightning trainer handle the rest, and the trainer reports the training loss during model fitting. Note that the loss function is not used for evaluation on the validation or test sets; a [metric](./metrics.ipynb) is used in those cases.

In [2]:
for lossfunction in LossFunctionRegistry:
    print(lossfunction)

mse
bounded-mse
mve
evidential
bce
ce
binary-mcc
multiclass-mcc
binary-dirichlet
multiclass-dirichlet
sid
earthmovers
wasserstein


### Task weights

A model can predict multiple targets (tasks) at the same time. For example, a model may predict both solubility and melting point. Task weights can be specified when accuracy on some tasks matters more than on others. The weight for each task defaults to 1.

In [3]:
from chemprop.nn.loss import MSELoss

predictor = nn.RegressionFFN(criterion=MSELoss(task_weights=[0.1, 0.5, 1.0]))
model = models.MPNN(nn.BondMessagePassing(), nn.MeanAggregation(), predictor)
predictor.criterion

MSELoss(task_weights=[[0.10000000149011612, 0.5, 1.0]])
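
To make the effect of task weights concrete, here is a minimal NumPy sketch of a task-weighted mean squared error. It only illustrates the idea and is not Chemprop's internal implementation; the prediction and target values are made up for the example.

```python
import numpy as np

# Hypothetical predictions and targets: 2 molecules x 3 tasks.
preds = np.array([[1.0, 2.0, 3.0], [2.0, 2.0, 2.0]])
targets = np.array([[0.0, 2.0, 1.0], [2.0, 0.0, 2.0]])
task_weights = np.array([0.1, 0.5, 1.0])

squared_error = (preds - targets) ** 2         # shape (2, 3), one entry per target
weighted_error = squared_error * task_weights  # weights broadcast across molecules
weighted_mse = weighted_error.mean()

# Per-task contribution before and after weighting.
print(squared_error.sum(axis=0))
print(weighted_error.sum(axis=0))
print(weighted_mse)
```

With these weights, an error on the first task contributes one tenth as much to the loss as the same error on the third task.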

### Mean squared error and bounded mean squared error

`MSELoss` is the default loss function for regression tasks.

In [4]:
predictor = nn.RegressionFFN()
model = models.MPNN(nn.BondMessagePassing(), nn.MeanAggregation(), predictor)
predictor.criterion

MSELoss(task_weights=[[1.0]])

`BoundedMSELoss` is useful when target values are bounds rather than exact measurements ("less than" or "greater than"), i.e. a prediction counts as correct as long as it falls below/above the target value. Datapoints have `lt_mask` and `gt_mask` properties that keep track of bounded targets. Note that, like target values, the masks used to make datapoints are 1-D numpy arrays of bools rather than single bools. This is because a single datapoint can have multiple target values, and the masks are defined for each target value separately.

In [5]:
from chemprop.nn.loss import BoundedMSELoss

smis = ["C" * i for i in range(1, 6)]
ys = np.random.rand(len(smis), 1)
lt_mask = np.array([[True], [False], [False], [False], [True]])
gt_mask = np.array([[False], [True], [False], [True], [False]])
datapoints = [
    data.MoleculeDatapoint.from_smi(smi, y, lt_mask=lt, gt_mask=gt)
    for smi, y, lt, gt in zip(smis, ys, lt_mask, gt_mask)
]
bounded_dataset = data.MoleculeDataset(datapoints)
bounded_dataset.lt_mask

array([[ True],
       [False],
       [False],
       [False],
       [ True]])

In [6]:
predictor = nn.RegressionFFN(criterion=BoundedMSELoss())
model = models.MPNN(nn.BondMessagePassing(), nn.MeanAggregation(), predictor)
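
The bounded behavior described above can be sketched in plain NumPy. This is an illustration of the idea under the assumption that a prediction on the allowed side of a bound incurs zero loss; the helper name `bounded_squared_error` is hypothetical and is not part of Chemprop.

```python
import numpy as np


def bounded_squared_error(preds, targets, lt_mask, gt_mask):
    """Squared error that is zero when a prediction satisfies its bound."""
    preds = np.asarray(preds, dtype=float)
    targets = np.asarray(targets, dtype=float)
    lt_mask = np.asarray(lt_mask, dtype=bool)
    gt_mask = np.asarray(gt_mask, dtype=bool)
    # "Less than" target: a prediction below the target is treated as exact.
    preds = np.where((preds < targets) & lt_mask, targets, preds)
    # "Greater than" target: a prediction above the target is treated as exact.
    preds = np.where((preds > targets) & gt_mask, targets, preds)
    return (preds - targets) ** 2


targets = np.array([[1.0], [1.0], [1.0], [1.0]])
preds = np.array([[0.5], [0.5], [1.5], [1.5]])
lt = np.array([[True], [False], [False], [False]])
gt = np.array([[False], [False], [True], [False]])

errors = bounded_squared_error(preds, targets, lt, gt)
print(errors.ravel())
```

Rows 1 and 3 satisfy their bounds and contribute nothing, while rows 2 and 4 are ordinary targets and are penalized as in plain MSE.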

### Binary cross entropy and cross entropy

`BCELoss` is the default loss function for binary classification and `CrossEntropyLoss` is the default for multiclass classification.

In [7]:
predictor = nn.BinaryClassificationFFN()
predictor.criterion

BCELoss(task_weights=[[1.0]])

In [8]:
predictor = nn.MulticlassClassificationFFN(n_classes=3)
predictor.criterion

CrossEntropyLoss(task_weights=[[1.0, 1.0, 1.0]])
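
For reference, binary cross entropy can be written out in a few lines of NumPy. This sketch assumes the model emits raw logits; whether Chemprop's `BCELoss` consumes logits or probabilities is not stated here, and the helper name `bce_from_logits` is hypothetical.

```python
import numpy as np


def bce_from_logits(logits, targets):
    """Elementwise binary cross entropy, computed stably from logits."""
    logits = np.asarray(logits, dtype=float)
    targets = np.asarray(targets, dtype=float)
    # Equivalent to -[y * log(p) + (1 - y) * log(1 - p)] with p = sigmoid(logits),
    # rearranged so that large |logits| do not overflow.
    return np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))


# Three predictions for a positive example: unsure, confidently right, confidently wrong.
losses = bce_from_logits([0.0, 4.0, -4.0], [1.0, 1.0, 1.0])
print(losses)
```

The loss is smallest for the confident correct prediction and largest for the confident wrong one.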

### Matthews correlation coefficient

MCC loss is useful for imbalanced classification data. More details coming soon.

In [9]:
from chemprop.nn.loss import BinaryMCCLoss, MulticlassMCCLoss
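
To see why an MCC-based loss helps with imbalanced data, consider the following NumPy sketch. It assumes a "soft" confusion matrix built from predicted probabilities and defines the loss as `1 - MCC`; the helper `soft_binary_mcc_loss` is hypothetical and may differ from how `BinaryMCCLoss` is implemented.

```python
import numpy as np


def soft_binary_mcc_loss(probs, targets):
    """Return 1 - MCC using soft (probability-weighted) confusion counts."""
    probs = np.asarray(probs, dtype=float)
    targets = np.asarray(targets, dtype=float)
    tp = np.sum(probs * targets)
    fp = np.sum(probs * (1 - targets))
    tn = np.sum((1 - probs) * (1 - targets))
    fn = np.sum((1 - probs) * targets)
    denom = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # MCC is undefined when a row or column of the confusion matrix is empty;
    # treat that case as "no correlation".
    mcc = (tp * tn - fp * fn) / denom if denom > 0 else 0.0
    return 1.0 - mcc


# One positive among eight examples.
targets = [1, 0, 0, 0, 0, 0, 0, 0]
perfect = soft_binary_mcc_loss(targets, targets)
always_negative = soft_binary_mcc_loss([0.0] * 8, targets)
uninformative = soft_binary_mcc_loss([0.5] * 8, targets)
print(perfect, always_negative, uninformative)
```

A model that always predicts the majority class is correct on seven of the eight examples, yet it receives the same loss here as an uninformative model.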

### Uncertainty and spectral loss functions

Beta versions of the loss functions for uncertainty and spectra prediction are available and will be finalized for v2.1.

In [10]:
from chemprop.nn.loss import MVELoss, EvidentialLoss, BinaryDirichletLoss, MulticlassDirichletLoss
from chemprop.nn.loss import SIDLoss, WassersteinLoss

### Custom loss functions

A custom loss function can be made by inheriting from `LossFunction` and overriding its methods as needed.

In [11]:
import torch
from torch import Tensor
from chemprop.nn.loss import LossFunction


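class CustomLoss(LossFunction):
    def __init__(self, task_weights=1.0, norm: float = 1.0):
        super().__init__(task_weights)
        # Register `norm` as a buffer so it moves with the module across devices.
        norm = torch.as_tensor(norm)
        self.register_buffer("norm", norm)

    def _calc_unreduced_loss(self, preds, targets, mask, weights, lt_mask, gt_mask) -> Tensor:
        # Return the elementwise loss with the same shape as `preds`, without
        # summing; the reduction is left to the base class.
        return (preds - targets) ** 2 / self.norm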