Skip to content

Commit

Permalink
Merge pull request #374 from shihchengli/fix_constraints
Browse files Browse the repository at this point in the history
Set atom and bond constraints when loading model
  • Loading branch information
oscarwumit committed Mar 22, 2023
2 parents 44af7b1 + 04c8590 commit 0c129b9
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions chemprop/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ def __init__(self, *args, **kwargs):
self._bond_features_size = 0
self._atom_descriptors_size = 0
self._bond_descriptors_size = 0
self._atom_constraints = []
self._bond_constraints = []

@property
def device(self) -> torch.device:
Expand Down Expand Up @@ -598,23 +600,25 @@ def atom_constraints(self) -> List[bool]:
A list of booleans indicating whether constraints applied to output of atomic properties.
"""
if self.is_atom_bond_targets and self.constraints_path:
header = chemprop.data.utils.get_header(self.constraints_path)
atom_constraints = [target in header for target in self.atom_targets]
if not self._atom_constraints:
header = chemprop.data.utils.get_header(self.constraints_path)
self._atom_constraints = [target in header for target in self.atom_targets]
else:
atom_constraints = [False] * len(self.atom_targets)
return atom_constraints
self._atom_constraints = [False] * len(self.atom_targets)
return self._atom_constraints

@property
def bond_constraints(self) -> List[bool]:
"""
A list of booleans indicating whether constraints applied to output of bond properties.
"""
if self.is_atom_bond_targets and self.constraints_path:
header = chemprop.data.utils.get_header(self.constraints_path)
bond_constraints = [target in header for target in self.bond_targets]
if not self._bond_constraints:
header = chemprop.data.utils.get_header(self.constraints_path)
self._bond_constraints = [target in header for target in self.bond_targets]
else:
bond_constraints = [False] * len(self.bond_targets)
return bond_constraints
self._bond_constraints = [False] * len(self.bond_targets)
return self._bond_constraints

def process_args(self) -> None:
super(TrainArgs, self).process_args()
Expand Down

0 comments on commit 0c129b9

Please sign in to comment.