Skip to content

Commit

Permalink
Adding atom_features_size property to CommonArgs
Browse files Browse the repository at this point in the history
  • Loading branch information
swansonk14 committed Sep 12, 2020
1 parent 41bad1c commit 7e09f31
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
10 changes: 10 additions & 0 deletions chemprop/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class CommonArgs(Tap):

def __init__(self, *args, **kwargs):
super(CommonArgs, self).__init__(*args, **kwargs)
self._atom_features_size = 0
self._atom_descriptors_size = 0

@property
Expand Down Expand Up @@ -129,6 +130,15 @@ def features_scaling(self) -> bool:
"""Whether to apply normalization with a :class:`~chemprop.data.scaler.StandardScaler` to the additional molecule-level features."""
return not self.no_features_scaling

@property
def atom_features_size(self) -> int:
"""The size of the atom features."""
return self._atom_features_size

@atom_features_size.setter
def atom_features_size(self, atom_features_size: int) -> None:
self._atom_features_size = atom_features_size

@property
def atom_descriptors_size(self) -> int:
"""The size of the atom descriptors."""
Expand Down
4 changes: 2 additions & 2 deletions chemprop/train/cross_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@ def cross_validate(args: TrainArgs,
)
validate_dataset_type(data, dataset_type=args.dataset_type)
args.features_size = data.features_size()
args.atom_features_size = 0

if args.atom_descriptors == 'descriptor':
args.atom_descriptors_size = data.atom_descriptors_size()
args.ffn_hidden_size += args.atom_descriptors_size
elif args.atom_descriptors == 'feature':
set_extra_atom_fdim(data.atom_features_size())
args.atom_features_size = data.atom_features_size()
set_extra_atom_fdim(args.atom_features_size)

debug(f'Number of tasks = {args.num_tasks}')

Expand Down

0 comments on commit 7e09f31

Please sign in to comment.