Skip to content

Commit

Permalink
Merge pull request #693 from chemprop/conformal_fix
Browse files Browse the repository at this point in the history
conformal quantile prediction bug fix
  • Loading branch information
kevingreenman committed Mar 4, 2024
2 parents 0f0da06 + 274bfe6 commit d5a84d1
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 22 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ Calibrated regression outputs can be in the form of a standard deviation or an i
* `zelikman_interval` Assumes that the error distribution is the same for each prediction but scaled by the uncalibrated standard deviation for each. Multiplies the uncalibrated standard deviation by a factor necessary to cover the specified interval of the calibration set. Does not assume a Gaussian distribution. Intended for use with intervals but can return a stdev as well. (https://arxiv.org/abs/2005.12496)
* `mve_weighting` For use with ensembles of models trained with mve or evidential loss function. Uses a weighted average of the predicted variances to achieve a minimum negative log likelihood of predictions. (https://doi.org/10.1186/s13321-021-00551-x)
* `conformal_regression` Generates a symmetric interval of fixed size for each prediction such that the actual value has probability $1-\alpha$ of falling in the interval. The desired error rate is controlled using the parameter `--conformal_alpha <float>` which is set by default to 0.1. (https://arxiv.org/abs/2107.07511)
* `conformal_quantile_regression` Similar to `conformal_regression` but generates an interval of variable size for each prediction based on quantile predictions of the data. The model should be trained with parameters `--loss_function quantile_interval` and `--quantile_loss_alpha <float>` where $\alpha$ is the desired error rate of the quantile interval. The trained model will output the center of the $\alpha/2$ and $1-\alpha/2$ quantiles according to pinball loss as the predicted value and return the half range of the interval as the interval value in the uncertainty. The parameter `--conformal_alpha <float>` should be included to specify the desired error rate of the conformal method during inference. (https://arxiv.org/abs/2107.07511)
* `conformal_quantile_regression` Similar to `conformal_regression` but generates an interval of variable size for each prediction based on quantile predictions of the data. The model should be trained with parameters `--loss_function quantile_interval` and `--quantile_loss_alpha <float>` where $\alpha$ is the desired error rate of the quantile interval. The trained model will output the center of the $\alpha/2$ and $1-\alpha/2$ quantiles according to pinball loss as the predicted value and return the half range of the interval as the uncertainty quantification. The parameter `--conformal_alpha <float>` should be included to specify the desired error rate of the conformal method during inference. (https://arxiv.org/abs/2107.07511)

**Classification**
* `platt` Uses a linear scaling before the sigmoid function in prediction to minimize the negative log likelihood of the predictions. If the model checkpoint was generated after Chemprop v1.5.0, then a Bayesian correction is applied to account for the class balance in the training set during prediction. Implemented for classification but not multiclass datasets. (https://arxiv.org/abs/1706.04599)
Expand Down
22 changes: 19 additions & 3 deletions chemprop/train/cross_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,15 @@ def cross_validate(args: TrainArgs,
f'{multitask_mean(scores=scores[fold_num], metric=metric, ignore_nan_metrics=args.ignore_nan_metrics):.6f}')

if args.show_individual_scores:
for task_name, score in zip(args.task_names, scores[fold_num]):
if args.loss_function == "quantile_interval" and metric == "quantile":
num_tasks = len(args.task_names) // 2
task_names = args.task_names[:num_tasks]
task_names = [f"{task_name} lower" for task_name in task_names] + [
f"{task_name} upper" for task_name in task_names]
else:
task_names = args.task_names

for task_name, score in zip(task_names, scores[fold_num]):
info(f'\t\tSeed {init_seed + fold_num} ==> test {task_name} {metric} = {score:.6f}')
if np.isnan(score):
contains_nan_scores = True
Expand All @@ -163,7 +171,7 @@ def cross_validate(args: TrainArgs,
info(f'Overall test {metric} = {mean_score:.6f} +/- {std_score:.6f}')

if args.show_individual_scores:
for task_num, task_name in enumerate(args.task_names):
for task_num, task_name in enumerate(task_names):
info(f'\tOverall test {task_name} {metric} = '
f'{np.mean(scores[:, task_num]):.6f} +/- {np.std(scores[:, task_num]):.6f}')

Expand Down Expand Up @@ -194,7 +202,15 @@ def cross_validate(args: TrainArgs,
row += [mean, std] + task_scores.tolist()
writer.writerow(row)
else: # all other data types, separate scores by task
for task_num, task_name in enumerate(args.task_names):
if args.loss_function == "quantile_interval" and metric == "quantile":
num_tasks = len(args.task_names) // 2
task_names = args.task_names[:num_tasks]
task_names = [f"{task_name} (lower quantile)" for task_name in task_names] + [
f"{task_name} (upper quantile)" for task_name in task_names]
else:
task_names = args.task_names

for task_num, task_name in enumerate(task_names):
row = [task_name]
for metric, scores in all_scores.items():
task_scores = scores[:, task_num]
Expand Down
21 changes: 14 additions & 7 deletions chemprop/train/make_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ def predict_and_save(
calibrator=calibrator
) # preds and unc are lists of shape(data,tasks)

if args.loss_function == "quantile_interval":
task_names = task_names[:len(task_names) // 2]

if calibrator is not None and args.is_atom_bond_targets and args.calibration_method == "isotonic":
unc = reshape_values(unc, test_data, len(args.atom_targets), len(args.bond_targets))

Expand Down Expand Up @@ -218,8 +221,6 @@ def predict_and_save(

if evaluators is not None:
evaluations = []
if args.loss_function == "quantile_interval":
task_names = task_names[:len(task_names) // 2]
print(f"Evaluating uncertainty for tasks {task_names}")
for evaluator in evaluators:
evaluation = evaluator.evaluate(
Expand Down Expand Up @@ -297,12 +298,10 @@ def predict_and_save(
# Add predictions columns
if args.uncertainty_method == "spectra_roundrobin":
unc_names = [estimator.label]
elif args.calibration_method == "conformal_regression":
unc_names = [f"{name}_{estimator.label}_interval" for name in task_names]
elif args.calibration_method == "conformal_quantile_regression":
unc_names = [f"{name}_{estimator.label}_quantile_interval" for name in task_names]
elif args.uncertainty_method == "conformal_quantile_regression" and args.calibration_method is None:
unc_names = [f"{name}_conformal_regression_{args.conformal_alpha}_interval" for name in task_names]
unc_names = [f"{name}_{args.conformal_alpha}_half_interval" for name in task_names]
elif args.calibration_method == "conformal_regression" and args.calibration_path is None:
unc_names = []
elif args.calibration_method == "conformal" and args.dataset_type == "classification":
unc_names = [f"{name}_{estimator.label}_in_set" for name in task_names] + [
f"{name}_{estimator.label}_out_set" for name in task_names
Expand Down Expand Up @@ -423,8 +422,16 @@ def make_predictions(
if args.dataset_type in ["classification", "multiclass"]:
args.uncertainty_method = "classification"
elif args.calibration_method == "conformal_regression":
if args.loss_function == "quantile_interval":
raise ValueError(
"For a model trained on the `quantile_interval` loss function, the calibration method should be assigned as `conformal_quantile_regression` instead of `conformal_regression`."
)
args.uncertainty_method = "conformal_regression"
elif args.calibration_method == "conformal_quantile_regression":
if args.loss_function != "quantile_interval":
raise ValueError(
"The calibration method `conformal_quantile_regression` only supports regression models trained on the `quantile_interval` loss function."
)
args.uncertainty_method = "conformal_quantile_regression"
else:
raise ValueError(
Expand Down
32 changes: 27 additions & 5 deletions chemprop/train/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def run_training(args: TrainArgs,
val_data=val_data,
test_data=test_data,
smiles_columns=args.smiles_columns,
loss_function=args.loss_function,
logger=logger,
)

Expand Down Expand Up @@ -329,8 +330,15 @@ def run_training(args: TrainArgs,
writer.add_scalar(f'validation_{metric}', mean_val_score, n_iter)

if args.show_individual_scores:
if args.loss_function == "quantile_interval" and metric == "quantile":
num_tasks = len(args.task_names) // 2
task_names = args.task_names[:num_tasks]
task_names = [f"{task_name} lower" for task_name in task_names] + [
f"{task_name} upper" for task_name in task_names]
else:
task_names = args.task_names
# Individual validation scores
for task_name, val_score in zip(args.task_names, scores):
for task_name, val_score in zip(task_names, scores):
debug(f'Validation {task_name} {metric} = {val_score:.6f}')
writer.add_scalar(f'validation_{task_name}_{metric}', val_score, n_iter)

Expand Down Expand Up @@ -386,7 +394,7 @@ def run_training(args: TrainArgs,

if args.show_individual_scores and args.dataset_type != 'spectra':
# Individual test scores
for task_name, test_score in zip(args.task_names, scores):
for task_name, test_score in zip(task_names, scores):
info(f'Model {model_idx} test {task_name} {metric} = {test_score:.6f}')
writer.add_scalar(f'test_{task_name}_{metric}', test_score, n_iter)
writer.close()
Expand Down Expand Up @@ -423,7 +431,7 @@ def run_training(args: TrainArgs,

# Individual ensemble scores
if args.show_individual_scores:
for task_name, ensemble_score in zip(args.task_names, scores):
for task_name, ensemble_score in zip(task_names, scores):
info(f'Ensemble test {task_name} {metric} = {ensemble_score:.6f}')

# Save scores
Expand All @@ -446,8 +454,22 @@ def run_training(args: TrainArgs,
values = [list(v) for v in values]
test_preds_dataframe[bond_target] = values
else:
for i, task_name in enumerate(args.task_names):
test_preds_dataframe[task_name] = [pred[i] for pred in avg_test_preds]
if args.loss_function == "quantile_interval" and metric == "quantile":
num_tasks = len(args.task_names) // 2
task_names = args.task_names[:num_tasks]
avg_test_preds = np.array(avg_test_preds)
num_data = avg_test_preds.shape[0]
preds = avg_test_preds.reshape(num_data, 2, num_tasks).mean(axis=1)
intervals = abs(np.diff(avg_test_preds.reshape(num_data, 2, num_tasks), axis=1) / 2)
intervals = intervals.reshape(num_data, num_tasks)
for i, task_name in enumerate(task_names):
test_preds_dataframe[task_name] = [pred[i] for pred in preds]
for i, task_name in enumerate(task_names):
task_name = f"{task_name}_{args.quantile_loss_alpha}_half_interval"
test_preds_dataframe[task_name] = [interval[i] for interval in intervals]
else:
for i, task_name in enumerate(task_names):
test_preds_dataframe[task_name] = [pred[i] for pred in avg_test_preds]

test_preds_dataframe.to_csv(os.path.join(args.save_dir, 'test_preds.csv'), index=False)

Expand Down
8 changes: 4 additions & 4 deletions chemprop/uncertainty/uncertainty_calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,7 +903,7 @@ class ConformalMulticlassCalibrator(UncertaintyCalibrator):

@property
def label(self):
return "conformal"
return f"conformal_{self.conformal_alpha}"

def raise_argument_errors(self):
super().raise_argument_errors()
Expand Down Expand Up @@ -973,7 +973,7 @@ class ConformalAdaptiveMulticlassCalibrator(ConformalMulticlassCalibrator):

@property
def label(self):
return "conformal_adaptive"
return f"conformal_adaptive_{self.conformal_alpha}"

def raise_argument_errors(self):
super().raise_argument_errors()
Expand Down Expand Up @@ -1011,7 +1011,7 @@ class ConformalMultilabelCalibrator(UncertaintyCalibrator):

@property
def label(self):
return f"conformal_{self.conformal_alpha}"
return f"conformal_multilabel_{self.conformal_alpha}"

def raise_argument_errors(self):
super().raise_argument_errors()
Expand Down Expand Up @@ -1099,7 +1099,7 @@ class ConformalRegressionCalibrator(UncertaintyCalibrator):

@property
def label(self):
return f"conformal_regression_{self.conformal_alpha}"
return f"conformal_regression_{self.conformal_alpha}_half_interval"

def raise_argument_errors(self):
super().raise_argument_errors()
Expand Down
16 changes: 14 additions & 2 deletions chemprop/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,8 +608,9 @@ def save_smiles_splits(
train_data: MoleculeDataset = None,
val_data: MoleculeDataset = None,
test_data: MoleculeDataset = None,
logger: logging.Logger = None,
smiles_columns: List[str] = None,
loss_function: str = None,
logger: logging.Logger = None,
) -> None:
"""
Saves a csv file with train/val/test splits of target data and additional features.
Expand All @@ -626,6 +627,7 @@ def save_smiles_splits(
:param val_data: Validation :class:`~chemprop.data.data.MoleculeDataset`.
:param test_data: Test :class:`~chemprop.data.data.MoleculeDataset`.
:param smiles_columns: The name of the column containing SMILES. By default, uses the first column.
:param loss_function: The loss function to be used in training.
:param logger: A logger for recording output.
"""
makedirs(save_dir)
Expand All @@ -652,7 +654,15 @@ def save_smiles_splits(
indices_by_smiles[smiles] = i

if task_names is None:
task_names = get_task_names(path=data_path, smiles_columns=smiles_columns)
task_names = get_task_names(
path=data_path,
smiles_columns=smiles_columns,
loss_function=loss_function,
)

if loss_function == "quantile_interval":
num_tasks = len(task_names) // 2
task_names = task_names[:num_tasks]

features_header = []
if features_path is not None:
Expand Down Expand Up @@ -689,6 +699,8 @@ def save_smiles_splits(
dataset_targets = dataset.targets()
for i, smiles in enumerate(dataset.smiles()):
targets = [x.tolist() if isinstance(x, np.ndarray) else x for x in dataset_targets[i]]
# correct the number of targets when running quantile regression
targets = targets[:len(task_names)]
writer.writerow(smiles + targets)

if features_path is not None:
Expand Down

0 comments on commit d5a84d1

Please sign in to comment.