New loss functions and metrics (#238)
* Split metrics and loss functions off of utils

* Initial add of f1 and mcc metrics

* Add metrics args

* Set f1 for multiclass to use micro averaging

* Change args input of loss function and get_loss_func

* Remove data_weight as a default datapoint attribute

* Add f1, mcc, and bounded_mse loss functions

* Correct multiclass mcc loss

* Add get data functions for inequality targets

* Add bounded metrics

* Disable f1 loss function

* Update readme

* Move loss_functions and metrics into train directory

* Fix import error

* New loss function gpu support

* Correct a vector dimension error

* Add inequalities to get_data function

* More compact argument spacing in test_integration

* Overwrite the loss_function default None

* Correct errors in mcc loss implementation

* Add testing for mcc and bounded_mse losses

* Make bounded_mse default metric when used as loss function

* Fix variable argument name in get_data

* Remove f1 as a loss function altogether

* Fix multiclass mcc dimensionality

* Description for loss functions and small dimensionality change for mcc
cjmcgill committed Feb 10, 2022
1 parent 8392031 commit 3302509
Showing 16 changed files with 1,910 additions and 467 deletions.
135 changes: 75 additions & 60 deletions README.md

Large diffs are not rendered by default.

37 changes: 28 additions & 9 deletions chemprop/args.py
@@ -13,7 +13,7 @@
from chemprop.features import get_available_features_generators


Metric = Literal['auc', 'prc-auc', 'rmse', 'mae', 'mse', 'r2', 'accuracy', 'cross_entropy', 'binary_cross_entropy', 'sid', 'wasserstein']
Metric = Literal['auc', 'prc-auc', 'rmse', 'mae', 'mse', 'r2', 'accuracy', 'cross_entropy', 'binary_cross_entropy', 'sid', 'wasserstein', 'f1', 'mcc', 'bounded_rmse', 'bounded_mae', 'bounded_mse']
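
For reference, the two new classification metrics map directly onto scikit-learn routines. A minimal sketch of how they can be computed from model probabilities (the helper names, imports, and 0.5 threshold are illustrative assumptions, not part of this diff):

```python
import numpy as np
from sklearn.metrics import f1_score, matthews_corrcoef

def f1_metric(targets, preds, threshold=0.5, multiclass=False):
    # Hypothetical helper: F1 from probabilities; micro averaging for multiclass, per this PR.
    if multiclass:
        hard_preds = np.argmax(preds, axis=1)  # preds: (n_samples, n_classes) softmax outputs
        return f1_score(targets, hard_preds, average='micro')
    hard_preds = (np.asarray(preds) > threshold).astype(int)  # preds: sigmoid outputs
    return f1_score(targets, hard_preds)

def mcc_metric(targets, preds, threshold=0.5):
    # Hypothetical helper: Matthews correlation coefficient from binary probabilities.
    hard_preds = (np.asarray(preds) > threshold).astype(int)
    return matthews_corrcoef(targets, hard_preds)
```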


def get_checkpoint_paths(checkpoint_path: Optional[str] = None,
@@ -228,7 +228,9 @@ class TrainArgs(CommonArgs):
ignore_columns: List[str] = None
"""Name of the columns to ignore when :code:`target_columns` is not provided."""
dataset_type: Literal['regression', 'classification', 'multiclass', 'spectra']
"""Type of dataset. This determines the loss function used during training."""
"""Type of dataset. This determines the default loss function used during training."""
loss_function: Literal['mse', 'bounded_mse', 'binary_cross_entropy','cross_entropy', 'mcc', 'sid', 'wasserstein'] = None
"""Choice of loss function. Loss functions are limited to compatible dataset types."""
multiclass_num_classes: int = 3
"""Number of classes when running multiclass classification."""
separate_val_path: str = None
@@ -398,8 +400,6 @@ class TrainArgs(CommonArgs):
"""Indicates which function to use in dataset_type spectra training to constrain outputs to be positive."""
spectra_target_floor: float = 1e-8
"""Values in targets for dataset type spectra are replaced with this value, intended to be a small positive number used to enforce positive values."""
alternative_loss_function: Literal['wasserstein'] = None
"""Option to replace the default loss function, with an alternative. Only currently applied for spectra data type and wasserstein loss."""
overwrite_default_atom_features: bool = False
"""
Overwrites the default atom descriptors with the new ones instead of concatenating them.
@@ -441,7 +441,7 @@ def metrics(self) -> List[str]:
@property
def minimize_score(self) -> bool:
"""Whether the model should try to minimize the score metric or maximize it."""
return self.metric in {'rmse', 'mae', 'mse', 'cross_entropy', 'binary_cross_entropy', 'sid', 'wasserstein'}
return self.metric in {'rmse', 'mae', 'mse', 'cross_entropy', 'binary_cross_entropy', 'sid', 'wasserstein', 'bounded_mse', 'bounded_mae', 'bounded_rmse'}

@property
def use_input_features(self) -> bool:
@@ -542,19 +542,38 @@ def process_args(self) -> None:
self.metric = 'cross_entropy'
elif self.dataset_type == 'spectra':
self.metric = 'sid'
else:
elif self.dataset_type == 'regression' and self.loss_function == 'bounded_mse':
self.metric = 'bounded_mse'
elif self.dataset_type == 'regression':
self.metric = 'rmse'
else:
raise ValueError(f'Dataset type {self.dataset_type} is not supported.')

if self.metric in self.extra_metrics:
raise ValueError(f'Metric {self.metric} is both the metric and is in extra_metrics. '
f'Please only include it once.')

for metric in self.metrics:
if not any([(self.dataset_type == 'classification' and metric in ['auc', 'prc-auc', 'accuracy', 'binary_cross_entropy']),
(self.dataset_type == 'regression' and metric in ['rmse', 'mae', 'mse', 'r2']),
(self.dataset_type == 'multiclass' and metric in ['cross_entropy', 'accuracy']),
if not any([(self.dataset_type == 'classification' and metric in ['auc', 'prc-auc', 'accuracy', 'binary_cross_entropy', 'f1', 'mcc']),
(self.dataset_type == 'regression' and metric in ['rmse', 'mae', 'mse', 'r2', 'bounded_rmse', 'bounded_mae', 'bounded_mse']),
(self.dataset_type == 'multiclass' and metric in ['cross_entropy', 'accuracy', 'f1', 'mcc']),
(self.dataset_type == 'spectra' and metric in ['sid','wasserstein'])]):
raise ValueError(f'Metric "{metric}" invalid for dataset type "{self.dataset_type}".')

if self.loss_function is None:
if self.dataset_type == 'classification':
self.loss_function = 'binary_cross_entropy'
elif self.dataset_type == 'multiclass':
self.loss_function = 'cross_entropy'
elif self.dataset_type == 'spectra':
self.loss_function = 'sid'
elif self.dataset_type == 'regression':
self.loss_function = 'mse'
else:
raise ValueError(f'Default loss function not configured for dataset type {self.dataset_type}.')

if self.loss_function != 'bounded_mse' and any(metric in ['bounded_mse', 'bounded_rmse', 'bounded_mae'] for metric in self.metrics):
raise ValueError('Bounded metrics can only be used in conjunction with the regression loss function bounded_mse.')

# Validate class balance
if self.class_balance and self.dataset_type != 'classification':
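
The bounded_mse loss that these checks gate treats an inequality target as satisfied, and therefore error-free, whenever the prediction falls on the allowed side of the bound. A minimal sketch assuming PyTorch tensors (not the literal implementation in this commit):

```python
import torch

def bounded_mse_loss(predictions: torch.Tensor,
                     targets: torch.Tensor,
                     less_than_target: torch.Tensor,
                     greater_than_target: torch.Tensor) -> torch.Tensor:
    # Predictions already below a "<x" bound (or above a ">x" bound) are
    # clamped to the target value, so they contribute zero squared error.
    predictions = torch.where(torch.logical_and(predictions < targets, less_than_target),
                              targets, predictions)
    predictions = torch.where(torch.logical_and(predictions > targets, greater_than_target),
                              targets, predictions)
    return torch.nn.functional.mse_loss(predictions, targets, reduction='none')
```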
69 changes: 65 additions & 4 deletions chemprop/data/data.py
@@ -58,7 +58,9 @@ def __init__(self,
smiles: List[str],
targets: List[Optional[float]] = None,
row: OrderedDict = None,
data_weight: float = 1,
data_weight: float = None,
gt_targets: List[bool] = None,
lt_targets: List[bool] = None,
features: np.ndarray = None,
features_generator: List[str] = None,
phase_features: List[float] = None,
@@ -72,6 +74,8 @@ def __init__(self,
:param targets: A list of targets for the molecule (contains None for unknown target values).
:param row: The raw CSV row containing the information for this molecule.
:param data_weight: Weighting of the datapoint for the loss function.
:param gt_targets: Indicates whether the targets are an inequality regression target of the form ">x".
:param lt_targets: Indicates whether the targets are an inequality regression target of the form "<x".
:param features: A numpy array containing additional features (e.g., Morgan fingerprint).
:param features_generator: A list of features generators to use.
:param phase_features: A one-hot vector indicating the phase of the data, as used in spectra data.
@@ -87,7 +91,6 @@ def __init__(self,
self.smiles = smiles
self.targets = targets
self.row = row
self.data_weight = data_weight
self.features = features
self.features_generator = features_generator
self.phase_features = phase_features
@@ -99,7 +102,13 @@ def __init__(self,
self.is_reaction = is_reaction()
self.is_explicit_h = is_explicit_h()
self.is_adding_hs = is_adding_hs()


if data_weight is not None:
self.data_weight = data_weight
if gt_targets is not None:
self.gt_targets = gt_targets
if lt_targets is not None:
self.lt_targets = lt_targets

# Generate additional features if given a generator
if self.features_generator is not None:
@@ -372,8 +381,11 @@ def bond_features(self) -> List[np.ndarray]:

def data_weights(self) -> List[float]:
"""
Returns the loss weighting associated with each molecule
Returns the loss weighting associated with each datapoint.
"""
if not hasattr(self._data[0], 'data_weight'):
return [1. for d in self._data]

return [d.data_weight for d in self._data]

def targets(self) -> List[List[Optional[float]]]:
@@ -384,6 +396,24 @@ def targets(self) -> List[List[Optional[float]]]:
"""
return [d.targets for d in self._data]

def gt_targets(self) -> List[np.ndarray]:
"""
"""
if not hasattr(self._data[0], 'gt_targets'):
return None

return [d.gt_targets for d in self._data]

def lt_targets(self) -> List[np.ndarray]:
"""
"""
if not hasattr(self._data[0], 'lt_targets'):
return None

return [d.lt_targets for d in self._data]

def num_tasks(self) -> int:
"""
Returns the number of prediction tasks.
@@ -669,6 +699,37 @@ def targets(self) -> List[List[Optional[float]]]:

return [self._dataset[index].targets for index in self._sampler]

@property
def gt_targets(self) -> List[List[Optional[bool]]]:
"""
Returns booleans indicating, for each molecule, whether each target is a greater-than inequality rather than a single value target.
:return: A list of lists of booleans (or None) marking which targets are '>' inequalities.
"""
if self._class_balance or self._shuffle:
raise ValueError('Cannot safely extract targets when class balance or shuffle are enabled.')

if not hasattr(self._dataset[0],'gt_targets'):
return None

return [self._dataset[index].gt_targets for index in self._sampler]

@property
def lt_targets(self) -> List[List[Optional[bool]]]:
"""
Returns booleans indicating, for each molecule, whether each target is a less-than inequality rather than a single value target.
:return: A list of lists of booleans (or None) marking which targets are '<' inequalities.
"""
if self._class_balance or self._shuffle:
raise ValueError('Cannot safely extract targets when class balance or shuffle are enabled.')

if not hasattr(self._dataset[0],'lt_targets'):
return None

return [self._dataset[index].lt_targets for index in self._sampler]


@property
def iter_size(self) -> int:
"""Returns the number of data points included in each full iteration through the :class:`MoleculeDataLoader`."""
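
To illustrate the new optional attributes, here is a hypothetical datapoint carrying a single ">5.0" target (a sketch; the SMILES and values are invented):

```python
from chemprop.data import MoleculeDatapoint, MoleculeDataset

# One molecule whose target column contained ">5.0":
dp = MoleculeDatapoint(
    smiles=['CCO'],
    targets=[5.0],       # numeric part of the inequality
    gt_targets=[True],   # flags the ">" bound
    lt_targets=[False],
)
dataset = MoleculeDataset([dp])
print(dataset.gt_targets())    # [[True]]
print(dataset.data_weights())  # [1.0] -- fallback, since no data_weight was set
```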
71 changes: 59 additions & 12 deletions chemprop/data/utils.py
@@ -227,6 +227,7 @@ def get_data(path: str,
max_data_size: int = None,
store_row: bool = False,
logger: Logger = None,
loss_function: str = None,
skip_none_targets: bool = False) -> MoleculeDataset:
"""
Gets SMILES and target values from a CSV file.
@@ -252,6 +253,7 @@
:param store_row: Whether to store the raw CSV row in each :class:`~chemprop.data.data.MoleculeDatapoint`.
:param skip_none_targets: Whether to skip targets that are all 'None'. This is mostly relevant when --target_columns
are passed in, so only a subset of tasks are examined.
:param loss_function: The loss function to be used in training.
:return: A :class:`~chemprop.data.MoleculeDataset` containing SMILES and target values along
with other info such as additional features when desired.
"""
@@ -270,6 +272,7 @@
bond_features_path = bond_features_path if bond_features_path is not None \
else args.bond_features_path
max_data_size = max_data_size if max_data_size is not None else args.max_data_size
loss_function = loss_function if loss_function is not None else args.loss_function

if not isinstance(smiles_columns, list):
smiles_columns = preprocess_smiles_columns(path=path, smiles_columns=smiles_columns)
@@ -303,24 +306,41 @@
else:
data_weights = None

# By default, the targets columns are all the columns except the SMILES column
if target_columns is None:
target_columns = get_task_names(
path=path,
smiles_columns=smiles_columns,
target_columns=target_columns,
ignore_columns=ignore_columns,
)

# Find targets provided as inequalities
if loss_function == 'bounded_mse':
gt_targets, lt_targets = get_inequality_targets(path=path, target_columns=target_columns)
else:
gt_targets, lt_targets = None, None

# Load data
with open(path) as f:
reader = csv.DictReader(f)

# By default, the targets columns are all the columns except the SMILES column
if target_columns is None:
target_columns = get_task_names(
path=path,
smiles_columns=smiles_columns,
target_columns=target_columns,
ignore_columns=ignore_columns,
)

all_smiles, all_targets, all_rows, all_features, all_phase_features, all_weights = [], [], [], [], [], []
all_smiles, all_targets, all_rows, all_features, all_phase_features, all_weights, all_gt, all_lt = [], [], [], [], [], [], [], []
for i, row in enumerate(tqdm(reader)):
smiles = [row[c] for c in smiles_columns]

targets = [float(row[column]) if row[column] not in ['','nan'] else None for column in target_columns]
targets = []
for column in target_columns:
value = row[column]
if value in ['','nan']:
targets.append(None)
elif '>' in value or '<' in value:
if loss_function == 'bounded_mse':
targets.append(float(value.strip('<>')))
else:
raise ValueError('Inequality found in target data. To use inequality targets (> or <), the regression loss function bounded_mse must be used.')
else:
targets.append(float(value))

# Check whether all targets are None and skip if so
if skip_none_targets and all(x is None for x in targets):
@@ -338,6 +358,12 @@ def get_data(path: str,
if data_weights is not None:
all_weights.append(data_weights[i])

if gt_targets is not None:
all_gt.append(gt_targets[i])

if lt_targets is not None:
all_lt.append(lt_targets[i])

if store_row:
all_rows.append(row)

@@ -369,7 +395,9 @@
smiles=smiles,
targets=targets,
row=all_rows[i] if store_row else None,
data_weight=all_weights[i] if data_weights is not None else 1.,
data_weight=all_weights[i] if data_weights is not None else None,
gt_targets=all_gt[i] if gt_targets is not None else None,
lt_targets=all_lt[i] if lt_targets is not None else None,
features_generator=features_generator,
features=all_features[i] if features_data is not None else None,
phase_features=all_phase_features[i] if phase_features is not None else None,
@@ -427,6 +455,25 @@ def get_data_from_smiles(smiles: List[List[str]],
return data


def get_inequality_targets(path: str, target_columns: List[str] = None) -> Tuple[List[List[bool]], List[List[bool]]]:
"""
Reads the target columns of a CSV file and flags, row by row, which values are '>' or '<' inequality targets.
"""
gt_targets = []
lt_targets = []

with open(path) as f:
reader = csv.DictReader(f)
for line in reader:
values = [line[col] for col in target_columns]
gt_targets.append(['>' in val for val in values])
lt_targets.append(['<' in val for val in values])
if any(['<' in val and '>' in val for val in values]):
raise ValueError(f'A target value in csv file {path} contains both ">" and "<" symbols. Inequality targets may only be bounded on one side and cannot express a range.')

return gt_targets, lt_targets
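
A quick hypothetical check of the parser above (file contents invented; assumes get_inequality_targets is importable from chemprop.data.utils):

```python
import csv
import tempfile

from chemprop.data.utils import get_inequality_targets

# Two rows: one greater-than inequality target, one plain value.
rows = [{'smiles': 'CCO', 'solubility': '>5.0'},
        {'smiles': 'CCN', 'solubility': '3.2'}]
with tempfile.NamedTemporaryFile('w', suffix='.csv', delete=False, newline='') as f:
    writer = csv.DictWriter(f, fieldnames=['smiles', 'solubility'])
    writer.writeheader()
    writer.writerows(rows)

gt_targets, lt_targets = get_inequality_targets(path=f.name, target_columns=['solubility'])
print(gt_targets)  # [[True], [False]]
print(lt_targets)  # [[False], [False]]
```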


def split_data(data: MoleculeDataset,
split_type: str = 'random',
sizes: Tuple[float, float, float] = (0.8, 0.1, 0.1),
12 changes: 8 additions & 4 deletions chemprop/models/model.py
@@ -22,6 +22,10 @@ def __init__(self, args: TrainArgs):

self.classification = args.dataset_type == 'classification'
self.multiclass = args.dataset_type == 'multiclass'

# When using cross-entropy losses, skip the sigmoid/softmax normalization during training; the MCC loss, however, needs normalized outputs.
if self.classification or self.multiclass:
self.no_training_normalization = args.loss_function in ['cross_entropy', 'binary_cross_entropy']

self.output_size = args.num_tasks
if self.multiclass:
@@ -172,12 +176,12 @@ def forward(self,
output = self.ffn(self.encoder(batch, features_batch, atom_descriptors_batch,
atom_features_batch, bond_features_batch))

# Don't apply sigmoid during training b/c using BCEWithLogitsLoss
if self.classification and not self.training:
# Don't apply sigmoid during training when using BCEWithLogitsLoss
if self.classification and not (self.training and self.no_training_normalization):
output = self.sigmoid(output)
if self.multiclass:
output = output.reshape((output.size(0), -1, self.num_classes)) # batch size x num targets x num classes per target
if not self.training:
output = self.multiclass_softmax(output) # to get probabilities during evaluation, but not during training as we're using CrossEntropyLoss
if not (self.training and self.no_training_normalization):
output = self.multiclass_softmax(output) # to get probabilities during evaluation, but not during training when using CrossEntropyLoss

return output
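
The normalization matters because the MCC loss is built from soft confusion-matrix counts, which require probabilities rather than raw logits. A hedged sketch of a binary soft-MCC loss (assuming PyTorch; the argument names and epsilon are illustrative):

```python
import torch

def mcc_class_loss(probs: torch.Tensor,
                   targets: torch.Tensor,
                   mask: torch.Tensor,
                   eps: float = 1e-8) -> torch.Tensor:
    # Soft confusion-matrix entries from sigmoid probabilities in [0, 1];
    # mask zeroes out missing targets.
    tp = (probs * targets * mask).sum(dim=0)
    fp = (probs * (1 - targets) * mask).sum(dim=0)
    tn = ((1 - probs) * (1 - targets) * mask).sum(dim=0)
    fn = ((1 - probs) * targets * mask).sum(dim=0)
    mcc = (tp * tn - fp * fn) / torch.sqrt(
        (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + eps)
    return 1 - mcc  # lower loss for stronger correlation
```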
