Skip to content

Commit

Permalink
Reset features and targets after each fold
Browse files (browse the repository at this point in the history)
  • Loading branch information
swansonk14 committed Sep 4, 2020
1 parent f152dbb commit 4c8f5e7
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
15 changes: 8 additions & 7 deletions chemprop/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ def set_targets(self, targets: List[Optional[float]]):
"""
self.targets = targets

def reset_features_and_targets(self) -> None:
"""
Resets the features and targets to their raw values.

Rebinds ``self.features`` to ``self.raw_features`` and ``self.targets`` to ``self.raw_targets``.
The raw objects are assigned directly, not copied, so after this call both attribute pairs
refer to the same objects.
"""
self.features, self.targets = self.raw_features, self.raw_targets


class MoleculeDataset(Dataset):
r"""A :class:`MoleculeDataset` contains a list of :class:`MoleculeDatapoint`\ s with access to their attributes."""
Expand Down Expand Up @@ -250,13 +254,10 @@ def set_targets(self, targets: List[List[Optional[float]]]) -> None:
for i in range(len(self._data)):
self._data[i].set_targets(targets[i])

def sort(self, key: Callable) -> None:
"""
Sorts the dataset using the provided key.

Delegates to ``self._data.sort(key=key)``; presumably ``self._data`` is a list, in which case
the datapoints are reordered in place -- confirm against the constructor.

:param key: A function on a :class:`MoleculeDatapoint` to determine the sorting order.
"""
self._data.sort(key=key)
def reset_features_and_targets(self) -> None:
"""
Resets the features and targets of every datapoint to their raw values.

Iterates over ``self._data`` and calls ``reset_features_and_targets()`` on each element.
"""
for d in self._data:
d.reset_features_and_targets()

def __len__(self) -> int:
"""
Expand Down
2 changes: 1 addition & 1 deletion chemprop/train/cross_validate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from collections import defaultdict
from copy import deepcopy
import csv
from logging import Logger
import os
Expand Down Expand Up @@ -72,6 +71,7 @@ def cross_validate(args: TrainArgs,
args.seed = init_seed + fold_num
args.save_dir = os.path.join(save_dir, f'fold_{fold_num}')
makedirs(args.save_dir)
data.reset_features_and_targets()

This comment has been minimized.

Copy link
@fhvermei

fhvermei Sep 10, 2020

Contributor

RDKit features are removed and not generated again?
The following test fails:
python train.py --data_path data/freesolv.csv --dataset_type regression --epochs 5 --features_generator rdkit_2d_normalized --no_features_scaling

This comment has been minimized.

Copy link
@swansonk14

swansonk14 Sep 12, 2020

Author Contributor

Good catch! I just made a fix here: 3472968

model_scores = train_func(args, data, logger)
for metric, scores in model_scores.items():
all_scores[metric].append(scores)
Expand Down
1 change: 0 additions & 1 deletion chemprop/train/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def run_training(args: TrainArgs,
if args.dataset_type == 'regression':
debug('Fitting scaler')
scaler = train_data.normalize_targets()
print(scaler.means, scaler.stds)
else:
scaler = None

Expand Down

0 comments on commit 4c8f5e7

Please sign in to comment.