Skip to content

Commit

Permalink
Merge pull request #456 from soulios/mol_weight_split
Browse files Browse the repository at this point in the history
new split per molecular weight
  • Loading branch information
kevingreenman committed Sep 20, 2023
2 parents f3d1bff + 4794245 commit b9fc805
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 75 deletions.
2 changes: 1 addition & 1 deletion chemprop/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ class TrainArgs(CommonArgs):
"""Path to weights for each molecule in the training data, affecting the relative weight of molecules in the loss function"""
target_weights: List[float] = None
"""Weights associated with each target, affecting the relative weight of targets in the loss function. Must match the number of target columns."""
split_type: Literal['random', 'scaffold_balanced', 'predetermined', 'crossval', 'cv', 'cv-no-test', 'index_predetermined', 'random_with_repeated_smiles'] = 'random'
split_type: Literal['random', 'scaffold_balanced', 'predetermined', 'crossval', 'cv', 'cv-no-test', 'index_predetermined', 'random_with_repeated_smiles', 'molecular_weight'] = 'random'
"""Method of splitting the data into train/val/test."""
split_sizes: List[float] = None
"""Split proportions for train/validation/test sets."""
Expand Down
10 changes: 9 additions & 1 deletion chemprop/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def __init__(self,
# for H2
elif m is not None and m.GetNumHeavyAtoms() == 0:
# not all features are equally long, so use methane as dummy molecule to determine length
self.features.extend(np.zeros(len(features_generator(Chem.MolFromSmiles('C')))))
self.features.extend(np.zeros(len(features_generator(Chem.MolFromSmiles('C')))))
else:
if m[0] is not None and m[1] is not None and m[0].GetNumHeavyAtoms() > 0:
self.features.extend(features_generator(m[0]))
Expand Down Expand Up @@ -221,6 +221,14 @@ def bond_types(self) -> List[List[float]]:
:return: A list of bond types for each molecule.
"""
return [[b.GetBondTypeAsDouble() for b in self.mol[i].GetBonds()] for i in range(self.number_of_molecules)]
@property
def max_molwt(self) -> float:
"""
Gets the maximum molecular weight among all the molecules in the :class:`MoleculeDatapoint`.
:return: The maximum molecular weight.
"""
return max(Chem.rdMolDescriptors.CalcExactMolWt(mol) for mol in self.mol)

def set_features(self, features: np.ndarray) -> None:
"""
Expand Down
18 changes: 17 additions & 1 deletion chemprop/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,6 @@ def get_inequality_targets(path: str, target_columns: List[str] = None) -> List[

return gt_targets, lt_targets


def split_data(data: MoleculeDataset,
split_type: str = 'random',
sizes: Tuple[float, float, float] = (0.8, 0.1, 0.1),
Expand Down Expand Up @@ -819,7 +818,24 @@ def split_data(data: MoleculeDataset,
test = [data[i] for i in indices[train_val_size:]]

return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test)
elif split_type == 'molecular_weight':
train_size, val_size, test_size = [int(size * len(data)) for size in sizes]

sorted_data = sorted(data._data, key=lambda x: x.max_molwt, reverse=False)
indices = list(range(len(sorted_data)))

train_end_idx = int(train_size)
val_end_idx = int(train_size + val_size)
train_indices = indices[:train_end_idx]
val_indices = indices[train_end_idx:val_end_idx]
test_indices = indices[val_end_idx:]

# Create MoleculeDataset for each split
train = MoleculeDataset([sorted_data[i] for i in train_indices])
val = MoleculeDataset([sorted_data[i] for i in val_indices])
test = MoleculeDataset([sorted_data[i] for i in test_indices])

return train, val, test
else:
raise ValueError(f'split_type "{split_type}" not supported.')

Expand Down

0 comments on commit b9fc805

Please sign in to comment.