Skip to content

Commit

Permalink
Adding Unit Testing and Some Changes to Integration Testing (#232)
Browse files Browse the repository at this point in the history
* url test error

* save the test results

* New test values

* Starting a unit test file

* Add **kwargs to patching

* get_header and pre_process_smiles tests

* change assert error

* Moving unit tests to a separate directory

* Tests for get_smiles

* add filter_invalid_smiles function

* Add pytest -v setting

* Add get_data tests

* split_data tests

* Remove an empty test file

* Move key index error to process args

* Remove standin for cv splitting

* Update test values for shorter number of epochs

* Increasing delta margin to 1.5%

* Add tests for split size argument errors
  • Loading branch information
cjmcgill committed Mar 21, 2022
1 parent 7aa766b commit 7b2ceb1
Show file tree
Hide file tree
Showing 6 changed files with 594 additions and 98 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ jobs:
- name: Test with pytest
shell: bash -l {0}
run: |
pytest
pytest -v
6 changes: 5 additions & 1 deletion chemprop/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,8 @@ class TrainArgs(CommonArgs):
split_sizes: List[float] = None
"""Split proportions for train/validation/test sets."""
split_key_molecule: int = 0
"""The index of the key molecule used for splitting when multiple molecules are present and constrained split_type is used, like scaffold_balanced or random_with_repeated_smiles."""
"""The index of the key molecule used for splitting when multiple molecules are present and constrained split_type is used, like scaffold_balanced or random_with_repeated_smiles.
Note that this index begins with zero for the first molecule."""
num_folds: int = 1
"""Number of folds when performing cross validation."""
folds_file: str = None
Expand Down Expand Up @@ -722,6 +723,9 @@ def process_args(self) -> None:
if min(self.target_weights) < 0:
raise ValueError('Provided target weights must be non-negative.')

# check if key molecule index is outside of the number of molecules
if self.split_key_molecule >= self.number_of_molecules:
raise ValueError('The index provided with the argument `--split_key_molecule` must be less than the number of molecules. Note that this index begins with 0 for the first molecule. ')

class PredictArgs(CommonArgs):
""":class:`PredictArgs` includes :class:`CommonArgs` along with additional arguments used for predicting with a Chemprop model."""
Expand Down
1 change: 1 addition & 0 deletions chemprop/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
'filter_invalid_smiles',
'get_class_sizes',
'get_data',
'get_data_weights',
'get_data_from_smiles',
'get_data_weights',
'get_invalid_smiles_from_file',
Expand Down
46 changes: 23 additions & 23 deletions chemprop/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from logging import Logger
import pickle
from random import Random
from typing import List, Optional, Set, Tuple, Union
from typing import List, Set, Tuple, Union
import os

from rdkit import Chem
Expand All @@ -15,12 +15,25 @@
from chemprop.args import PredictArgs, TrainArgs
from chemprop.features import load_features, load_valid_atom_or_bond_features, is_mol

def get_header(path: str) -> List[str]:
"""
Returns the header of a data CSV file.
:param path: Path to a CSV file.
:return: A list of strings containing the strings in the comma-separated header.
"""
with open(path) as f:
header = next(csv.reader(f))

return header


def preprocess_smiles_columns(path: str,
smiles_columns: Optional[Union[str, List[Optional[str]]]],
number_of_molecules: int = 1) -> List[Optional[str]]:
smiles_columns: Union[str, List[str]] = None,
number_of_molecules: int = 1) -> List[str]:
"""
Preprocesses the :code:`smiles_columns` variable to ensure that it is a list of column
headings corresponding to the columns in the data file holding SMILES.
headings corresponding to the columns in the data file holding SMILES. Assumes file has a header.
:param path: Path to a CSV file.
:param smiles_columns: The names of the columns containing SMILES.
Expand Down Expand Up @@ -84,19 +97,6 @@ def get_task_names(path: str,
return target_names


def get_header(path: str) -> List[str]:
"""
Returns the header of a data CSV file.
:param path: Path to a CSV file.
:return: A list of strings containing the strings in the comma-separated header.
"""
with open(path) as f:
header = next(csv.reader(f))

return header


def get_data_weights(path: str) -> List[float]:
"""
Returns the list of data weights for the loss function as stored in a CSV file.
Expand All @@ -120,6 +120,7 @@ def get_data_weights(path: str) -> List[float]:

def get_smiles(path: str,
smiles_columns: Union[str, List[str]] = None,
number_of_molecules: int = 1,
header: bool = True,
flatten: bool = False
) -> Union[List[str], List[List[str]]]:
Expand All @@ -129,22 +130,24 @@ def get_smiles(path: str,
:param path: Path to a CSV file.
:param smiles_columns: A list of the names of the columns containing SMILES.
By default, uses the first :code:`number_of_molecules` columns.
:param number_of_molecules: The number of molecules for each data point. Not necessary if
the names of smiles columns are previously processed.
:param header: Whether the CSV file contains a header.
:param flatten: Whether to flatten the returned SMILES to a list instead of a list of lists.
:return: A list of SMILES or a list of lists of SMILES, depending on :code:`flatten`.
"""
if smiles_columns is not None and not header:
raise ValueError('If smiles_column is provided, the CSV file must have a header.')

if not isinstance(smiles_columns, list):
smiles_columns = preprocess_smiles_columns(path=path, smiles_columns=smiles_columns)
if not isinstance(smiles_columns, list) and header:
smiles_columns = preprocess_smiles_columns(path=path, smiles_columns=smiles_columns, number_of_molecules=number_of_molecules)

with open(path) as f:
if header:
reader = csv.DictReader(f)
else:
reader = csv.reader(f)
smiles_columns = 0
smiles_columns = list(range(number_of_molecules))

smiles = [[row[c] for c in smiles_columns] for row in reader]

Expand Down Expand Up @@ -517,9 +520,6 @@ def split_data(data: MoleculeDataset,
else:
folds_file = val_fold_index = test_fold_index = None

if key_molecule_index >= args.number_of_molecules:
raise ValueError('The index provided with the argument `--split_key_molecule` must be less than the number of molecules. Note that this index begins with 0 for the first molecule. ')

if split_type == 'crossval':
index_set = args.crossval_index_sets[args.seed]
data_split = []
Expand Down

0 comments on commit 7b2ceb1

Please sign in to comment.