Skip to content

Commit

Permalink
Fixing test data path
Browse files Browse the repository at this point in the history
  • Loading branch information
swansonk14 committed Jul 25, 2020
1 parent d006680 commit 90e471d
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from chemprop.train import chemprop_train, chemprop_predict


TEST_DATA_DIR = 'tests/data'
SEED = 0
EPOCHS = 10
NUM_FOLDS = 3
Expand All @@ -24,7 +25,7 @@ def create_raw_train_args(dataset_type: str, metric: str, save_dir: str) -> List
"""Creates a list of raw command line arguments for training."""
return [
'chemprop_train', # Note: not actually used, just a placeholder
'--data_path', f'data/{dataset_type}.csv',
'--data_path', os.path.join(TEST_DATA_DIR, f'{dataset_type}.csv'),
'--dataset_type', dataset_type,
'--epochs', str(EPOCHS),
'--num_folds', str(NUM_FOLDS),
Expand All @@ -39,7 +40,7 @@ def create_raw_predict_args(dataset_type: str, preds_path: str, checkpoint_dir:
"""Creates a list of raw command line arguments for predicting."""
return [
'chemprop_predict', # Note: not actually used, just a placeholder
'--test_path', f'data/{dataset_type}_test_smiles.csv',
'--test_path', os.path.join(TEST_DATA_DIR, f'{dataset_type}_test_smiles.csv'),
'--preds_path', preds_path,
'--checkpoint_dir', checkpoint_dir
]
Expand Down Expand Up @@ -108,7 +109,7 @@ def test_chemprop_predict_single_task_regression(self):

# Check results
pred = pd.read_csv(preds_path)
true = pd.read_csv(f'data/{dataset_type}_test_true.csv')
true = pd.read_csv(os.path.join(TEST_DATA_DIR, f'{dataset_type}_test_true.csv'))
self.assertEqual(list(pred.keys()), list(true.keys()))
self.assertEqual(list(pred['smiles']), list(true['smiles']))

Expand All @@ -129,7 +130,7 @@ def test_chemprop_predict_multi_task_classification(self):

# Check results
pred = pd.read_csv(preds_path)
true = pd.read_csv(f'data/{dataset_type}_test_true.csv')
true = pd.read_csv(os.path.join(TEST_DATA_DIR, f'{dataset_type}_test_true.csv'))
self.assertEqual(list(pred.keys()), list(true.keys()))
self.assertEqual(list(pred['smiles']), list(true['smiles']))

Expand Down

0 comments on commit 90e471d

Please sign in to comment.