Skip to content

Commit

Permalink
Conformal Calibration (#304)
Browse files Browse the repository at this point in the history
Co-authored-by: Chas <charlesjmcgill@gmail.com>
Co-authored-by: Kevin Greenman <kpg@mit.edu>
Co-authored-by: Charles McGill <44245643+cjmcgill@users.noreply.github.com>
Co-authored-by: david graff <60193893+davidegraff@users.noreply.github.com>
Co-authored-by: Daniel Xu <danielxu@rosetta10.csail.mit.edu>
Co-authored-by: Shih-Cheng Li <scli@mit.edu>
Co-authored-by: Kevin Greenman <35846516+kevingreenman@users.noreply.github.com>
  • Loading branch information
8 people committed Feb 29, 2024
1 parent 7214a7b commit b6b4d93
Show file tree
Hide file tree
Showing 19 changed files with 8,955 additions and 114 deletions.
18 changes: 14 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -231,18 +231,20 @@ By default, both random and scaffold split the data into 80% train, 10% validati
### Loss functions

The loss functions available for training are dependent on the selected dataset type. Loss functions other than the defaults can be selected from the supported options with the argument `--loss_function <function>`.
* **Regression.** mse (default), bounded_mse, mve (mean-variance estimation, a.k.a. heteroscedastic loss), evidential.
* **Regression.** mse (default), bounded_mse, mve (mean-variance estimation, a.k.a. heteroscedastic loss), evidential, quantile_interval (Pinball loss, specify margins with `--quantile_loss_alpha <float>`).
* **Classification.** binary_cross_entropy (default), mcc (a soft version of Matthews Correlation Coefficient), dirichlet (a.k.a. evidential classification)
* **Multiclass.** cross_entropy (default), mcc (a soft version of Matthews Correlation Coefficient)
* **Spectra.** sid (default, spectral information divergence), wasserstein (First-order Wasserstein distance a.k.a. earthmover's distance.)


Dropout regularization can be applied regardless of loss function using the argument `--dropout <float>` and providing a dropout fraction between 0 and 1.

The regression loss functions `mve` and `evidential` function by minimizing the negative log likelihood of a predicted uncertainty distribution. If used during training, the uncertainty predictions from these loss functions can be used for uncertainty prediction during prediction tasks. A regularization specific to evidential learning can be applied using the argument `--evidential_regularization <float>`.
The regression loss functions `mve` and `evidential` function by minimizing the negative log likelihood of a predicted uncertainty distribution. If used during training, the uncertainty predictions from these loss functions can be used for uncertainty prediction during prediction tasks. A regularization specific to evidential learning can be applied using the argument `--evidential_regularization <float>`. The regression loss function `quantile_interval` trains the model with two different output heads which correspond to the `quantile_loss_alpha/2` and `1 - quantile_loss_alpha/2` quantile predictions. Since it is a symmetrical interval, return the center of the interval as the predicted value. The evaluation metric for `quantile_interval` is automatically set to the `quantile` metric.

### Metrics

Metrics are used to evaluate the success of the model against the test set as the final model score and to determine the optimal epoch to save the model at based on the validation set. The primary metric used for both purposes is selected with the argument `--metric <metric>` and additional metrics for test set score only can be added with `--extra_metrics <metric1> <metric2> ...`. Supported metrics are dependent on the dataset type. Unlike loss functions, metrics do not have to be differentiable.
* **Regression.** rmse (default), mae, mse, r2, bounded_rmse, bounded_mae, bounded_mse (default if bounded_mse is loss function).
* **Regression.** rmse (default), mae, mse, r2, bounded_rmse, bounded_mae, bounded_mse (default if bounded_mse is loss function), quantile (average of pinball loss for both output heads).
* **Classification.** auc (default), prc-auc, accuracy, binary_cross_entropy, f1, mcc, recall, precision and balanced accuracy.
* **Multiclass.** cross_entropy (default), accuracy, f1, mcc.
* **Spectra.** sid (default), wasserstein.
Expand Down Expand Up @@ -428,15 +430,22 @@ Uncertainty predictions may be calibrated to improve their performance on new pr

**Regression**

Calibrated regression outputs can be in the form of a standard deviation or an interval, as specified with the argument `--regression_calibrator_metric <"stdev" or "interval">`. The interval can be set using `--calibration_interval_percentile <float>` in the range (1,100).
Calibrated regression outputs can be in the form of a standard deviation or an interval, as specified with the argument `--regression_calibrator_metric <"stdev" or "interval">`. The interval can be set using `--calibration_interval_percentile <float>` in the range (1,100). The options mentioned above do not apply to the calibration methods `conformal_regression` and `conformal_quantile_regression`.
* `zscaling` Assumes that errors are normally distributed according to the estimated variance for each prediction. Applies a constant multiple to all stdev or interval outputs in order to minimize the negative log likelihood for the normal distributions. (https://arxiv.org/abs/1905.11659)
* `tscaling` Similar to zscaling. Assumes that the errors are normally distributed, but accounts for the ensemble size and uncertainty in the sample variance by using a sample-size reduced t-distribution in the negative log likelihood. Works best when errors are mostly due to variability between model instances and not dataset noise or model bias.
* `zelikman_interval` Assumes that the error distribution is the same for each prediction but scaled by the uncalibrated standard deviation for each. Multiplies the uncalibrated standard deviation by a factor necessary to cover the specified interval of the calibration set. Does not assume a Gaussian distribution. Intended for use with intervals but can return a stdev as well. (https://arxiv.org/abs/2005.12496)
* `mve_weighting` For use with ensembles of models trained with mve or evidential loss function. Uses a weighted average of the predicted variances to achieve a minimum negative log likelihood of predictions. (https://doi.org/10.1186/s13321-021-00551-x)
* `conformal_regression` Generates a symmetric interval of fixed size for each prediction such that the actual value has probability $1-\alpha$ of falling in the interval. The desired error rate is controlled using the parameter `--conformal_alpha <float>` which is set by default to 0.1. (https://arxiv.org/abs/2107.07511)
* `conformal_quantile_regression` Similar to `conformal_regression` but generates an interval of variable size for each prediction based on quantile predictions of the data. The model should be trained with parameters `--loss_function quantile_interval` and `--quantile_loss_alpha <float>` where $\alpha$ is the desired error rate of the quantile interval. The trained model will output the center of the $\alpha/2$ and $1-\alpha/2$ quantiles according to pinball loss as the predicted value and return the half range of the interval as the interval value in the uncertainty. The parameter `--conformal_alpha <float>` should be included to specify the desired error rate of the conformal method during inference. (https://arxiv.org/abs/2107.07511)

**Classification**
* `platt` Uses a linear scaling before the sigmoid function in prediction to minimize the negative log likelihood of the predictions. If the model checkpoint was generated after Chemprop v1.5.0, then a Bayesian correction is applied to account for the class balance in the training set during prediction. Implemented for classification but not multiclass datasets. (https://arxiv.org/abs/1706.04599)
* `isotonic` Fits an isotonic regression model to the predictions. Prediction outputs are transformed using a stepped histogram-style to match the empirical probability observed in the calibration data. Number and size of the histogram bins are procedurally decided. Histogram bins are wider in the regions of the model output that are less reliable in ordering confidence. Implemented for both classification and multiclass datasets. (https://arxiv.org/abs/1706.04599)
* `conformal` Generates a pair of sets of labels $C_{in} \subset C_{out}$ such that the true set of labels $S$ satisfies the property $C_{in} \subset S \subset C_{out}$ with probability at least $1-\alpha$. The desired error rate $\alpha$ can be controlled with the parameter `--conformal_alpha <float>` which is set by default to 0.1. (https://arxiv.org/abs/2004.10181)

**Multiclass**
* `conformal` Generates a set of possible classes for each prediction such that the true class has probability $1-\alpha$ of falling in the set. The desired error rate $\alpha$ can be controlled with the parameter `--conformal_alpha <float>` which is set by default to 0.1. Set generated using the basic conformal method. (https://arxiv.org/abs/2107.07511)
* `conformal_adaptive` Generates a set of possible classes for each prediction such that the true class has probability 1-alpha of falling in the set. The desired error rate $\alpha$ can be controlled with the parameter `--conformal_alpha <float>` which is set by default to 0.1. Set generated using the adaptive conformal method. (https://arxiv.org/abs/2107.07511)

### Uncertainty Evaluation Metrics

Expand All @@ -447,6 +456,7 @@ The performance of uncertainty predictions (calibrated or uncalibrated) as evalu
* `spearman` A regression evaluation metric. Returns the Spearman rank correlation between the predicted uncertainty and the actual error in predictions. Only considers ordering, does not assume a particular probability distribution.
* `ence` Expected normalized calibration error. A regression evaluation metric. Bins model prediction according to uncertainty prediction and compares the RMSE in each bin versus the expected error based on the predicted uncertainty variance then scaled by variance. (discussed in https://doi.org/10.1021/acs.jcim.9b00975)
* `miscalibration_area` A regression evaluation metric. Calculates the model's performance of expected probability versus realized probability at different points along the probability distribution. Values range (0, 0.5) with perfect calibration at 0. (discussed in https://doi.org/10.1021/acs.jcim.9b00975)
* `conformal_coverage` Measures the empirical coverage of the conformal methods, that is the proportion of datapoints that fall within the output set or interval. Must be used with a conformal calibration method which outputs a set or interval. The metric can be used with multiclass, multilabel, or regression conformal methods.

Different evaluation metrics consider different aspects of uncertainty. It is often appropriate to consider multiple metrics. For intance, miscalibration error is important for evaluating uncertainty magnitude but does not indicate that the uncertainty function discriminates well between different outputs. Similarly, spearman tests ordering but not prediction magnitude.

Expand Down
82 changes: 66 additions & 16 deletions chemprop/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ class TrainArgs(CommonArgs):
"""Name of the columns to ignore when :code:`target_columns` is not provided."""
dataset_type: Literal['regression', 'classification', 'multiclass', 'spectra']
"""Type of dataset. This determines the default loss function used during training."""
loss_function: Literal['mse', 'bounded_mse', 'binary_cross_entropy', 'cross_entropy', 'mcc', 'sid', 'wasserstein', 'mve', 'evidential', 'dirichlet'] = None
loss_function: Literal['mse', 'bounded_mse', 'binary_cross_entropy', 'cross_entropy', 'mcc', 'sid', 'wasserstein', 'mve', 'evidential', 'dirichlet', 'quantile_interval'] = None
"""Choice of loss function. Loss functions are limited to compatible dataset types."""
multiclass_num_classes: int = 3
"""Number of classes when running multiclass classification."""
Expand Down Expand Up @@ -474,6 +474,8 @@ class TrainArgs(CommonArgs):
evidential_regularization: float = 0
"""Value used in regularization for evidential loss function. The default value recommended by Soleimany et al.(2021) is 0.2.
Optimal value is dataset-dependent; it is recommended that users test different values to find the best value for their model."""
quantile_loss_alpha: float = 0.1
"""Target error bounds for quantile interval loss"""
overwrite_default_atom_features: bool = False
"""
Overwrites the default atom descriptors with the new ones instead of concatenating them.
Expand Down Expand Up @@ -506,6 +508,7 @@ def __init__(self, *args, **kwargs) -> None:
self._task_names = None
self._crossval_index_sets = None
self._task_names = None
self._quantiles = None
self._num_tasks = None
self._features_size = None
self._train_data_size = None
Expand Down Expand Up @@ -549,6 +552,15 @@ def num_tasks(self) -> int:
"""The number of tasks being trained on."""
return len(self.task_names) if self.task_names is not None else 0

@property
def quantiles(self) -> List[float]:
"""A list of quantiles to be being trained on."""
return self._quantiles

@quantiles.setter
def quantiles(self, quantiles: List[float]) -> None:
self._quantiles = quantiles

@property
def features_size(self) -> int:
"""The dimensionality of the additional molecule-level features."""
Expand Down Expand Up @@ -701,16 +713,18 @@ def process_args(self) -> None:

# Process and validate metric and loss function
if self.metric is None:
if self.dataset_type == 'classification':
self.metric = 'auc'
elif self.dataset_type == 'multiclass':
self.metric = 'cross_entropy'
elif self.dataset_type == 'spectra':
self.metric = 'sid'
elif self.dataset_type == 'regression' and self.loss_function == 'bounded_mse':
self.metric = 'bounded_mse'
elif self.dataset_type == 'regression':
self.metric = 'rmse'
if self.dataset_type == "classification":
self.metric = "auc"
elif self.dataset_type == "multiclass":
self.metric = "cross_entropy"
elif self.dataset_type == "spectra":
self.metric = "sid"
elif self.dataset_type == "regression" and self.loss_function == "bounded_mse":
self.metric = "bounded_mse"
elif self.dataset_type == "regression" and self.loss_function == "quantile_interval":
self.metric = "quantile"
elif self.dataset_type == "regression":
self.metric = "rmse"
else:
raise ValueError(f'Dataset type {self.dataset_type} is not supported.')

Expand All @@ -720,11 +734,14 @@ def process_args(self) -> None:

for metric in self.metrics:
if not any([(self.dataset_type == 'classification' and metric in ['auc', 'prc-auc', 'accuracy', 'binary_cross_entropy', 'f1', 'mcc', 'recall', 'precision', 'balanced_accuracy', 'confusion_matrix']),
(self.dataset_type == 'regression' and metric in ['rmse', 'mae', 'mse', 'r2', 'bounded_rmse', 'bounded_mae', 'bounded_mse']),
(self.dataset_type == 'regression' and metric in ['rmse', 'mae', 'mse', 'r2', 'bounded_rmse', 'bounded_mae', 'bounded_mse', 'quantile']),
(self.dataset_type == 'multiclass' and metric in ['cross_entropy', 'accuracy', 'f1', 'mcc']),
(self.dataset_type == 'spectra' and metric in ['sid', 'wasserstein'])]):
raise ValueError(f'Metric "{metric}" invalid for dataset type "{self.dataset_type}".')

if metric == "quantile" and self.loss_function != "quantile_interval":
raise ValueError(f'Metric quantile is only compatible with quantile_interval loss.')

if self.loss_function is None:
if self.dataset_type == 'classification':
self.loss_function = 'binary_cross_entropy'
Expand Down Expand Up @@ -871,7 +888,14 @@ def process_args(self) -> None:

# check if key molecule index is outside of the number of molecules
if self.split_key_molecule >= self.number_of_molecules:
raise ValueError('The index provided with the argument `--split_key_molecule` must be less than the number of molecules. Note that this index begins with 0 for the first molecule. ')
raise ValueError(
"The index provided with the argument `--split_key_molecule` must be less than the number of molecules. Note that this index begins with 0 for the first molecule. "
)

if not 0 <= self.quantile_loss_alpha <= 0.5:
raise ValueError(
"quantile_loss_alpha should be in the range [0, 0.5]"
)


class PredictArgs(CommonArgs):
Expand Down Expand Up @@ -900,7 +924,18 @@ class PredictArgs(CommonArgs):
'dirichlet',
] = None
"""The method of calculating uncertainty."""
calibration_method: Literal['zscaling', 'tscaling', 'zelikman_interval', 'mve_weighting', 'platt', 'isotonic'] = None
calibration_method: Literal[
"zscaling",
"tscaling",
"zelikman_interval",
"mve_weighting",
"platt",
"isotonic",
"conformal",
"conformal_adaptive",
"conformal_regression",
"conformal_quantile_regression",
] = None
"""Methods used for calibrating the uncertainty calculated with uncertainty method."""
evaluation_methods: List[str] = None
"""The methods used for evaluating the uncertainty performance if the test data provided includes targets.
Expand All @@ -909,6 +944,8 @@ class PredictArgs(CommonArgs):
"""Location to save the results of uncertainty evaluations."""
uncertainty_dropout_p: float = 0.1
"""The probability to use for Monte Carlo dropout uncertainty estimation."""
conformal_alpha: float = 0.1
"""Target error rate for conformal prediction."""
dropout_sampling_size: int = 10
"""The number of samples to use for Monte Carlo dropout uncertainty estimation. Distinct from the dropout used during training."""
calibration_interval_percentile: float = 95
Expand Down Expand Up @@ -937,6 +974,8 @@ def process_args(self) -> None:
if self.regression_calibrator_metric is None:
if self.calibration_method == 'zelikman_interval':
self.regression_calibrator_metric = 'interval'
elif self.calibration_method in ['conformal_regression', 'conformal_quantile_regression']:
self.regression_calibrator_metric = None
else:
self.regression_calibrator_metric = 'stdev'

Expand Down Expand Up @@ -985,8 +1024,19 @@ def process_args(self) -> None:
('`--atom_descriptors_path`', self.atom_descriptors_path, self.calibration_atom_descriptors_path),
('`--bond_descriptors_path`', self.bond_descriptors_path, self.calibration_bond_descriptors_path)
]:
if base_features_path is not None and self.calibration_path is not None and cal_features_path is None:
raise ValueError(f'Additional features were provided using the argument {features_argument}. The same kinds of features must be provided for the calibration dataset.')
if (
base_features_path is not None
and self.calibration_path is not None
and cal_features_path is None
):
raise ValueError(
f"Additional features were provided using the argument {features_argument}. The same kinds of features must be provided for the calibration dataset."
)

if not 0 <= self.conformal_alpha <= 1:
raise ValueError(
"conformal_alpha should be in the range [0,1]"
)


class InterpretArgs(CommonArgs):
Expand Down

0 comments on commit b6b4d93

Please sign in to comment.