Merge pull request #453 from chemprop/nan_metric_fix
Switching np.mean to np.nanmean to handle NaN metrics
kevingreenman committed Feb 2, 2024
2 parents 40e615c + 7c9e1f5 commit b08063b
Showing 4 changed files with 37 additions and 10 deletions.
2 changes: 2 additions & 0 deletions chemprop/args.py
@@ -306,6 +306,8 @@ class TrainArgs(CommonArgs):
"""
extra_metrics: List[Metric] = []
"""Additional metrics to use to evaluate the model. Not used for early stopping."""
ignore_nan_metrics: bool = False
"""Ignore invalid task metrics (NaNs) when computing average metrics across tasks."""
save_dir: str = None
"""Directory where model checkpoints will be saved."""
checkpoint_frzn: str = None
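Since `TrainArgs` is built on typed-argument-parser, the new boolean attribute should also surface as a `--ignore_nan_metrics` command-line flag. A minimal sketch of opting in from Python, under the assumption that a valid multitask dataset exists at the placeholder path (the path, dataset type, and the exact invocation are illustrative, not part of this PR):

```python
# Hedged sketch: enabling the new option through TrainArgs.
# "data/multitask.csv" is a placeholder; TrainArgs.process_args() reads the
# dataset, so a real CSV is required for this to run end to end.
from chemprop.args import TrainArgs

args = TrainArgs().parse_args([
    "--data_path", "data/multitask.csv",
    "--dataset_type", "regression",
    "--ignore_nan_metrics",  # new flag added by this PR; defaults to False
])
print(args.ignore_nan_metrics)  # True
```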
16 changes: 13 additions & 3 deletions chemprop/train/cross_validate.py
@@ -133,7 +133,8 @@ def cross_validate(args: TrainArgs,
     contains_nan_scores = False
     for fold_num in range(args.num_folds):
         for metric, scores in all_scores.items():
-            info(f'\tSeed {init_seed + fold_num} ==> test {metric} = {multitask_mean(scores[fold_num], metric):.6f}')
+            info(f'\tSeed {init_seed + fold_num} ==> test {metric} = '
+                 f'{multitask_mean(scores=scores[fold_num], metric=metric, ignore_nan_metrics=args.ignore_nan_metrics):.6f}')
 
             if args.show_individual_scores:
                 for task_name, score in zip(args.task_names, scores[fold_num]):
@@ -143,7 +144,12 @@ def cross_validate(args: TrainArgs,

     # Report scores across folds
     for metric, scores in all_scores.items():
-        avg_scores = multitask_mean(scores, axis=1, metric=metric)  # average score for each model across tasks
+        avg_scores = multitask_mean(
+            scores=scores,
+            axis=1,
+            metric=metric,
+            ignore_nan_metrics=args.ignore_nan_metrics
+        )  # average score for each model across tasks
         mean_score, std_score = np.mean(avg_scores), np.std(avg_scores)
         info(f'Overall test {metric} = {mean_score:.6f} +/- {std_score:.6f}')

@@ -188,7 +194,11 @@ def cross_validate(args: TrainArgs,
             writer.writerow(row)
 
     # Determine mean and std score of main metric
-    avg_scores = multitask_mean(all_scores[args.metric], metric=args.metric, axis=1)
+    avg_scores = multitask_mean(
+        scores=all_scores[args.metric],
+        metric=args.metric, axis=1,
+        ignore_nan_metrics=args.ignore_nan_metrics
+    )
     mean_score, std_score = np.mean(avg_scores), np.std(avg_scores)
 
     # Optionally merge and save test preds
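To illustrate what the `axis=1` calls above now do when a task has an invalid score, here is a small sketch. The score matrix is made up, and `"auc"` is assumed to be among the non-scale-dependent metrics recognized by `multitask_mean` (that list is collapsed in this diff):

```python
# Illustrative score matrix: rows are folds/models, columns are tasks.
import numpy as np

from chemprop.utils import multitask_mean

scores = np.array([[0.80, 0.75, np.nan],   # third task had no valid score in this fold
                   [0.78, 0.74, 0.70]])

# Average across tasks (axis=1) for each fold, skipping the NaN task.
print(multitask_mean(scores=scores, metric="auc", axis=1, ignore_nan_metrics=True))
# approximately [0.775, 0.74]; with ignore_nan_metrics=False (the default),
# the first entry would be nan and would propagate into np.mean(avg_scores).
```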
18 changes: 15 additions & 3 deletions chemprop/train/run_training.py
@@ -319,7 +319,11 @@ def run_training(args: TrainArgs,

         for metric, scores in val_scores.items():
             # Average validation score\
-            mean_val_score = multitask_mean(scores, metric=metric)
+            mean_val_score = multitask_mean(
+                scores=scores,
+                metric=metric,
+                ignore_nan_metrics=args.ignore_nan_metrics
+            )
             debug(f'Validation {metric} = {mean_val_score:.6f}')
             writer.add_scalar(f'validation_{metric}', mean_val_score, n_iter)
@@ -330,7 +334,11 @@ def run_training(args: TrainArgs,
                     writer.add_scalar(f'validation_{task_name}_{metric}', val_score, n_iter)
 
         # Save model checkpoint if improved validation score
-        mean_val_score = multitask_mean(val_scores[args.metric], metric=args.metric)
+        mean_val_score = multitask_mean(
+            scores=val_scores[args.metric],
+            metric=args.metric,
+            ignore_nan_metrics=args.ignore_nan_metrics
+        )
         if args.minimize_score and mean_val_score < best_score or \
                 not args.minimize_score and mean_val_score > best_score:
             best_score, best_epoch = mean_val_score, epoch
@@ -403,7 +411,11 @@ def run_training(args: TrainArgs,

     for metric, scores in ensemble_scores.items():
         # Average ensemble score
-        mean_ensemble_test_score = multitask_mean(scores, metric=metric)
+        mean_ensemble_test_score = multitask_mean(
+            scores=scores,
+            metric=metric,
+            ignore_nan_metrics=args.ignore_nan_metrics
+        )
         info(f'Ensemble test {metric} = {mean_ensemble_test_score:.6f}')
 
         # Individual ensemble scores
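One reason the new flag matters here: if the averaged validation score is NaN, the checkpoint condition above can never fire, because every comparison with NaN evaluates to False. A quick check of that behavior:

```python
# NaN compares False in both directions, so neither the minimize nor the
# maximize branch of the checkpointing condition above can trigger.
best_score = 0.5
mean_val_score = float("nan")

print(mean_val_score < best_score)  # False
print(mean_val_score > best_score)  # False
```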
11 changes: 7 additions & 4 deletions chemprop/utils.py
@@ -16,7 +16,6 @@
 from torch.optim import Adam, Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from tqdm import tqdm
-from scipy.stats.mstats import gmean
 
 from chemprop.args import PredictArgs, TrainArgs, FingerprintArgs
 from chemprop.data import StandardScaler, AtomBondScaler, MoleculeDataset, preprocess_smiles_columns, get_task_names
@@ -842,6 +841,7 @@ def multitask_mean(
     scores: np.ndarray,
     metric: str,
     axis: int = None,
+    ignore_nan_metrics: bool = False,
 ) -> float:
     """
     A function for combining the metric scores across different
@@ -853,7 +853,8 @@
     :param scores: The scores from different tasks for a single metric.
     :param metric: The metric used to generate the scores.
-    :axis: The axis along which to take the mean.
+    :param axis: The axis along which to take the mean.
+    :param ignore_nan_metrics: Ignore invalid task metrics (NaNs) when computing average metrics across tasks.
     :return: The combined score across the tasks.
     """
     scale_dependent_metrics = ["rmse", "mae", "mse", "bounded_rmse", "bounded_mae", "bounded_mse"]
@@ -863,10 +864,12 @@

     ]
 
+    mean_fn = np.nanmean if ignore_nan_metrics else np.mean
+
     if metric in scale_dependent_metrics:
-        return gmean(scores, axis=axis)
+        return np.exp(mean_fn(np.log(scores), axis=axis))
     elif metric in nonscale_dependent_metrics:
-        return np.mean(scores, axis=axis)
+        return mean_fn(scores, axis=axis)
     else:
         raise NotImplementedError(
             f"The metric used, {metric}, has not been added to the list of\
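The core change in isolation: `np.nanmean` skips invalid tasks, and the geometric mean for scale-dependent metrics is now computed as `exp(mean(log(scores)))` instead of `scipy.stats.mstats.gmean`, so the same NaN handling applies to both branches. A self-contained sketch with made-up values:

```python
import numpy as np

scores = np.array([0.52, np.nan, 0.61])      # e.g. per-task RMSEs, one task invalid

# Default behavior: the NaN propagates into the average.
print(np.mean(scores))                        # nan

# With ignore_nan_metrics=True, the invalid task is simply skipped.
print(np.nanmean(scores))                     # 0.565  (arithmetic mean, non-scale-dependent metrics)
print(np.exp(np.nanmean(np.log(scores))))     # ~0.563 (geometric mean, scale-dependent metrics)
```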
