## Metrics

In [1]:
from lightning import pytorch as pl
import numpy as np
import torch
from chemprop import data, models, nn
from chemprop.nn.metrics import MetricRegistry

### Available metric functions

Chemprop provides several metrics. Each metric function computes a single value that measures model performance. Users only need to select which metric(s) to use; Chemprop and the Lightning trainer handle the rest. The trainer reports the first metric on the validation set during training and all metrics on the test set during testing. See the [source code](https://github.com/chemprop/chemprop/blob/main/chemprop/nn/metrics.py) for details of how each metric is computed.

In [2]:
for metric in MetricRegistry:
    print(metric)

mae
mse
rmse
bounded-mae
bounded-mse
bounded-rmse
r2
roc
prc
accuracy
f1
bce
ce
binary-mcc
multiclass-mcc
sid
wasserstein


### Multiple metrics

To use multiple metrics, pass a list of metrics to the model at creation. The first metric in the list is reported during validation. Note that the list replaces the default metric rather than being added to it.

In [3]:
from chemprop.nn.metrics import MSEMetric, MAEMetric, RMSEMetric

mets = [MSEMetric(), MAEMetric(), RMSEMetric()]
model = models.MPNN(nn.BondMessagePassing(), nn.MeanAggregation(), nn.RegressionFFN(), metrics=mets)

### Batch averaged metric

Note that the Lightning trainer reports batch-averaged metrics: the metric is evaluated on each batch and the per-batch values are then averaged. For some metrics, such as RMSE, this can differ from the metric computed over the whole dataset. Usually this is fine, but if the whole-dataset metric is desired, it can be calculated manually from the predictions.

In [4]:
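# Toy dataset: ten linear alkanes (methane through decane) with random targets.
smis = ["C" * i for i in range(1, 11)]
ys = np.random.rand(len(smis), 1)
dset = data.MoleculeDataset([data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)])
# Five batches of two molecules each.
dataloader = data.build_dataloader(dset, shuffle=False, batch_size=2)

trainer = pl.Trainer(logger=False, enable_checkpointing=False, max_epochs=1)
# trainer.test reports the batch-averaged metrics.
batch_avg_result = trainer.test(model, dataloader)
# Collect the per-batch predictions into a single tensor covering the whole dataset.
preds = trainer.predict(model, dataloader)
preds = torch.concat(preds)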

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 5/5 [00:00<00:00,  9.83it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
 batch_averaged_test/mae    0.5207622647285461
 batch_averaged_test/mse    0.34527701139450073
batch_averaged_test/rmse    0.5413810610771179
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 5/5 [00:00<00:00,  9.57it/s]


In [5]:
whole_result = RMSEMetric()(preds, torch.from_numpy(dset.Y), None, None, None, None)
print("Batch averaged / Whole dataset")
print(f"{batch_avg_result[0]['batch_averaged_test/rmse']:.4f} / {whole_result.item():.4f}")

Batch averaged / Whole dataset
0.5414 / 0.5876
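
The gap between the two numbers comes from taking the square root before averaging rather than after. The following is a minimal sketch in plain Python, independent of Chemprop and Lightning, using made-up predictions and targets; the helper `rmse` is defined here purely for illustration.

```python
import math


def rmse(preds, targets):
    """Root-mean-square error over paired predictions and targets."""
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds))


# Made-up values chosen so that the two batches have very different errors.
preds = [0.0, 0.0, 0.0, 0.0]
targets = [1.0, 1.0, 3.0, 3.0]
batch_size = 2

# Evaluate the metric on each batch, then average the per-batch values.
batch_rmses = [
    rmse(preds[i : i + batch_size], targets[i : i + batch_size])
    for i in range(0, len(preds), batch_size)
]
batch_averaged = sum(batch_rmses) / len(batch_rmses)

# Evaluate the metric once over every data point.
whole_dataset = rmse(preds, targets)

print("Batch averaged / Whole dataset")
print(f"{batch_averaged:.4f} / {whole_dataset:.4f}")  # 2.0000 / 2.2361
```

With equal-sized batches, the averaged per-batch MSE in this toy example equals the whole-dataset MSE (both are 5), so the discrepancy here is introduced only by the square root.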


### Batch normalization

Note that if your model has a batch normalization layer, the computed metric will differ depending on whether the model is in training or evaluation mode. In training mode, a batch normalization layer uses a biased estimator to calculate the standard deviation, but the value stored and used during evaluation is calculated with the unbiased estimator. Lightning takes care of this if the `Trainer()` is used.
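
To make the two estimators concrete, here is a small sketch using only the Python standard library. It is not PyTorch or Chemprop code and does not reproduce how a batch normalization layer tracks its running statistics; it only shows how the biased and unbiased variance estimates differ on one made-up batch of values.

```python
import statistics

# Made-up activations for a single feature across one batch of four samples.
batch = [1.0, 2.0, 3.0, 4.0]
n = len(batch)

# Biased estimator: divides the sum of squared deviations by n.
biased_var = statistics.pvariance(batch)
# Unbiased estimator: divides the sum of squared deviations by n - 1.
unbiased_var = statistics.variance(batch)

print(f"biased:   {biased_var:.4f}")    # biased:   1.2500
print(f"unbiased: {unbiased_var:.4f}")  # unbiased: 1.6667
print(f"ratio:    {unbiased_var / biased_var:.4f}")  # ratio:    1.3333, i.e. n / (n - 1)
```

The ratio n / (n - 1) approaches 1 as the batch grows, so the discrepancy is most noticeable with small batches.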

### Regression

There are several metric options for regression. `MSEMetric` is the default. Bounded versions are also available for all of them except `R2Metric`, similar to the bounded versions of the [loss functions](./loss_functions.ipynb).
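
As a rough illustration of what "bounded" means, the sketch below computes a bounded MAE in plain Python. It reflects our reading of the bounded loss functions and is not Chemprop's implementation: the function name `bounded_abs_errors` is hypothetical, and the mask semantics (an `lt_mask` entry marks a target reported as "less than" its value, a `gt_mask` entry marks "greater than") are assumptions that should be checked against the source code.

```python
def bounded_abs_errors(preds, targets, lt_mask, gt_mask):
    """Per-point absolute error with zero penalty on the allowed side of a bound.

    Assumed semantics (illustrative only):
      lt_mask[i] is True when target i is an upper bound (true value < target).
      gt_mask[i] is True when target i is a lower bound (true value > target).
    """
    errors = []
    for p, t, lt, gt in zip(preds, targets, lt_mask, gt_mask):
        if (lt and p < t) or (gt and p > t):
            # The prediction is consistent with the inequality, so no error is counted.
            errors.append(0.0)
        else:
            errors.append(abs(p - t))
    return errors


# Made-up data: every target is 5.0, with different bound types.
preds = [4.0, 6.0, 6.0, 4.0]
targets = [5.0, 5.0, 5.0, 5.0]
lt_mask = [True, True, False, False]   # first two targets mean "< 5.0"
gt_mask = [False, False, True, False]  # third target means "> 5.0"

errors = bounded_abs_errors(preds, targets, lt_mask, gt_mask)
bounded_mae = sum(errors) / len(errors)
print(errors)       # [0.0, 1.0, 0.0, 1.0]
print(bounded_mae)  # 0.5
```

The fourth point has no bound, so it is scored exactly as the unbounded MAE would score it.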

In [6]:
from chemprop.nn.metrics import MSEMetric, MAEMetric, RMSEMetric, R2Metric

In [7]:
from chemprop.nn.metrics import BoundedMAEMetric, BoundedMSEMetric, BoundedRMSEMetric

### Classification

There are metrics for both binary and multiclass classification.

In [8]:
from chemprop.nn.metrics import (
    BinaryAUROCMetric,
    BinaryAUPRCMetric,
    BinaryAccuracyMetric,
    BinaryF1Metric,
    BCEMetric,
    BinaryMCCMetric,
)

In [9]:
from chemprop.nn.metrics import CrossEntropyMetric, MulticlassMCCMetric

### Spectra

The metrics for spectra prediction are currently beta versions and will be finalized for v2.1.

In [10]:
from chemprop.nn.metrics import SIDMetric, WassersteinMetric