Commit

Rearrange ence evaluation (#302)
* Rearrange ence evaluation

* Flake 8 formatting
cjmcgill committed Jun 29, 2022
1 parent 2c171a1 commit 389bc12
Showing 2 changed files with 28 additions and 27 deletions.
39 changes: 20 additions & 19 deletions chemprop/uncertainty/uncertainty_evaluator.py
@@ -40,11 +40,14 @@ def raise_argument_errors(self):
raise NotImplementedError(
"No uncertainty evaluators implemented for spectra dataset type."
)
if self.uncertainty_method in ['ensemble', 'dropout'] and self.dataset_type in ['classification', 'multiclass']:
if self.uncertainty_method in ["ensemble", "dropout"] and self.dataset_type in [
"classification",
"multiclass",
]:
raise NotImplementedError(
'Though ensemble and dropout uncertainty methods are available for classification \
"Though ensemble and dropout uncertainty methods are available for classification \
multiclass dataset types, their outputs are not confidences and are not \
compatible with any implemented evaluation methods for classification.'
compatible with any implemented evaluation methods for classification."
)

@abstractmethod
@@ -158,7 +161,8 @@ def evaluate(
task_mask = mask[:, i]
task_unc = uncertainties[task_mask, i]
task_targets = targets[task_mask, i]
task_likelihood = task_unc * task_targets + (1 - task_unc) * (1 - task_targets)
task_likelihood = task_unc * task_targets \
+ (1 - task_unc) * (1 - task_targets)
task_nll = -1 * np.log(task_likelihood)
nll.append(task_nll.mean())
return nll
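
For orientation, the loop above treats each confidence as the probability of the positive class, so the likelihood of an observed binary target is p*y + (1 - p)*(1 - y) and the score is the mean negative log of that likelihood. A minimal standalone sketch with toy values (plain numpy, not chemprop's API):

import numpy as np

# Toy illustration of the masked per-task binary NLL computed in the hunk above.
uncertainties = np.array([[0.9], [0.2], [0.7]])  # predicted probability of class 1
targets = np.array([[1], [0], [1]])              # observed labels
mask = np.array([[True], [True], [True]])        # which entries are valid

task_unc = uncertainties[mask[:, 0], 0]
task_targets = targets[mask[:, 0], 0]
likelihood = task_unc * task_targets + (1 - task_unc) * (1 - task_targets)
nll = -np.log(likelihood).mean()  # ≈ 0.23 for these toy values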
@@ -252,7 +256,7 @@ def evaluate(
bin_unc = task_unc / original_scaling[j] * bin_scaling[i][j]
bin_fraction = np.mean(bin_unc >= task_error)
fractions[j, i] = bin_fraction

# return calibration settings to original state
self.calibrator.regression_calibrator_metric = original_metric
self.calibrator.scaling = original_scaling
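
The fractions array filled in above records, for each task, the observed fraction of errors that fall inside the uncertainty interval rescaled to each of the 100 candidate confidence levels. A hedged sketch of how such a table is commonly reduced to a single miscalibration score; the average-absolute-gap reduction below is the generic definition, not necessarily chemprop's exact implementation:

import numpy as np

def miscalibration_area(fractions: np.ndarray) -> np.ndarray:
    # fractions: shape (tasks, 100); column i holds the observed coverage when
    # the intervals are scaled for an expected coverage of i / 100.
    expected = np.arange(100) / 100
    return np.abs(fractions - expected).mean(axis=1)  # one value per task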
@@ -307,17 +311,18 @@ def evaluate(
error = np.abs(preds - targets) # shape(data, tasks)

# get stdev scaling then revert if interval
if self.calibrator is not None and self.calibration_method != "tscaling":
if self.calibrator.regression_calibrator_metric == "interval":
original_metric = self.calibrator.regression_calibrator_metric
original_scaling = self.calibrator.scaling
if self.calibrator is not None:
original_metric = self.calibrator.regression_calibrator_metric
original_scaling = self.calibrator.scaling
if (
self.calibration_method != "tscaling"
and self.calibrator.regression_calibrator_metric == "interval"
):
self.calibrator.regression_calibrator_metric = "stdev"
self.calibrator.calibrate()
stdev_scaling = self.calibrator.scaling
self.calibrator.regression_calibrator_metric = original_metric
self.calibrator.scaling = original_scaling
else: # stdev metric
stdev_scaling = self.calibrator.scaling

mean_vars = np.zeros([preds.shape[1], 100]) # shape(tasks, 100)
rmses = np.zeros_like(mean_vars)
@@ -344,9 +349,10 @@ def evaluate(
bin_var = t.var(df=self.calibrator.num_models - 1, scale=bin_unc)
mean_vars[i, j] = np.mean(bin_var)
rmses[i, j] = np.sqrt(np.mean(np.square(split_error[j])))
else: # stdev metric
else:
bin_unc = split_unc[j]
bin_unc = bin_unc / original_scaling[i] * stdev_scaling[i] # convert from interval to stdev as needed
if self.calibrator.regression_calibrator_metric == "interval":
bin_unc = bin_unc / original_scaling[i] * stdev_scaling[i] # convert from interval to stdev as needed
mean_vars[i, j] = np.mean(np.square(bin_unc))
rmses[i, j] = np.sqrt(np.mean(np.square(split_error[j])))

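For reference, mean_vars and rmses hold the per-bin statistics that the ENCE (expected normalized calibration error) metric is built from: the root mean predicted variance (RMV) and the empirical RMSE of each uncertainty-sorted bin. A minimal sketch of the usual reduction, following the standard |RMV - RMSE| / RMV form; the helper name and return shape are illustrative assumptions, not chemprop's code:

import numpy as np

def ence_from_bins(mean_vars: np.ndarray, rmses: np.ndarray) -> np.ndarray:
    # mean_vars, rmses: shape (tasks, n_bins), as built in the loop above.
    rmv = np.sqrt(mean_vars)  # root mean predicted variance per bin
    return np.mean(np.abs(rmv - rmses) / rmv, axis=1)  # one ENCE value per task
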
@@ -424,12 +430,7 @@ def build_uncertainty_evaluator(
"f1",
"mcc",
]
multiclass_metrics = [
"cross_entropy",
"accuracy",
"f1",
"mcc"
]
multiclass_metrics = ["cross_entropy", "accuracy", "f1", "mcc"]
if dataset_type == "classification" and evaluation_method in classification_metrics:
evaluator_class = MetricEvaluator
elif dataset_type == "multiclass" and evaluation_method in multiclass_metrics:
16 changes: 8 additions & 8 deletions tests/test_integration.py
@@ -1046,14 +1046,14 @@ def test_train_single_task_regression_reaction_solvent(self,
['--loss_function', 'evidential'],
[],
),
# (
# 8.843267,
# 'dropout',
# None,
# 'nll',
# ['--num_folds', '1'],
# [],
# ),
(
20.50925,
'dropout',
'zscaling',
'ence',
['--num_folds', '1'],
[],
),
(
-1.9783182,
'ensemble',
