Skip to content

Commit

Permalink
Merge pull request #337 from shihchengli/save_features_npy
Browse files Browse the repository at this point in the history
Save loaded molecular features into .npy files
  • Loading branch information
oscarwumit committed Feb 4, 2023
2 parents ad97341 + 22ea9c9 commit a952312
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 16 deletions.
23 changes: 14 additions & 9 deletions chemprop/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,11 +603,13 @@ def save_smiles_splits(

     features_header = []
     if features_path is not None:
-        for feat_path in features_path:
-            with open(feat_path, "r") as f:
-                reader = csv.reader(f)
-                feat_header = next(reader)
-                features_header.extend(feat_header)
+        extension_sets = set([os.path.splitext(feat_path)[1] for feat_path in features_path])
+        if extension_sets == {'.csv'}:
+            for feat_path in features_path:
+                with open(feat_path, "r") as f:
+                    reader = csv.reader(f)
+                    feat_header = next(reader)
+                    features_header.extend(feat_header)
 
     all_split_indices = []
     for dataset, name in [(train_data, "train"), (val_data, "val"), (test_data, "test")]:
Expand All @@ -632,10 +634,13 @@ def save_smiles_splits(

         if features_path is not None:
             dataset_features = dataset.features()
-            with open(os.path.join(save_dir, f"{name}_features.csv"), "w") as f:
-                writer = csv.writer(f)
-                writer.writerow(features_header)
-                writer.writerows(dataset_features)
+            if extension_sets == {'.csv'}:
+                with open(os.path.join(save_dir, f"{name}_features.csv"), "w") as f:
+                    writer = csv.writer(f)
+                    writer.writerow(features_header)
+                    writer.writerows(dataset_features)
+            else:
+                np.save(os.path.join(save_dir, f"{name}_features.npy"), dataset_features)

if save_split_indices:
split_indices = []
Expand Down
14 changes: 7 additions & 7 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def fingerprint(self,
'chemprop',
'rmse',
2.14015989,
- ['--features_path', os.path.join(TEST_DATA_DIR, 'regression.npz'), '--no_features_scaling']
+ ['--features_path', os.path.join(TEST_DATA_DIR, 'regression.npz'), '--no_features_scaling', '--save_smiles_splits']
),
(
'chemprop_features_generator_features_path',
Expand Down Expand Up @@ -322,7 +322,7 @@ def test_train_single_task_regression(self,
'chemprop',
'auc',
0.466828424,
- ['--features_path', os.path.join(TEST_DATA_DIR, 'classification.npz'), '--no_features_scaling', '--class_balance', '--split_sizes', '0.4', '0.3', '0.3']
+ ['--features_path', os.path.join(TEST_DATA_DIR, 'classification.npz'), '--no_features_scaling', '--class_balance', '--split_sizes', '0.4', '0.3', '0.3', '--save_smiles_splits']
),
(
'chemprop_features_generator_features_path',
Expand Down Expand Up @@ -416,7 +416,7 @@ def test_train_multi_task_classification(self,
'chemprop_rdkit_features_path',
'chemprop',
1.51978455,
- ['--features_path', os.path.join(TEST_DATA_DIR, 'regression.npz'), '--no_features_scaling'],
+ ['--features_path', os.path.join(TEST_DATA_DIR, 'regression.npz'), '--no_features_scaling', '--save_smiles_splits'],
['--features_path', os.path.join(TEST_DATA_DIR, 'regression_test.npz'), '--no_features_scaling']
),
(
Expand Down Expand Up @@ -508,7 +508,7 @@ def test_predict_individual_ensemble(self):
'chemprop_rdkit_features_path',
'chemprop',
0.307159229,
- ['--features_path', os.path.join(TEST_DATA_DIR, 'classification.npz'), '--no_features_scaling', '--class_balance', '--split_sizes', '0.4', '0.3', '0.3'],
+ ['--features_path', os.path.join(TEST_DATA_DIR, 'classification.npz'), '--no_features_scaling', '--class_balance', '--split_sizes', '0.4', '0.3', '0.3', '--save_smiles_splits'],
['--features_path', os.path.join(TEST_DATA_DIR, 'classification_test.npz'), '--no_features_scaling']
),
(
Expand Down Expand Up @@ -681,7 +681,7 @@ def test_chemprop_web(self):
[
'--data_path', os.path.join(TEST_DATA_DIR, 'spectra.csv'),
'--features_path', os.path.join(TEST_DATA_DIR, 'spectra_features.csv'),
- '--split_type', 'random_with_repeated_smiles'
+ '--split_type', 'random_with_repeated_smiles', '--save_smiles_splits'
]
),
(
Expand All @@ -691,7 +691,7 @@ def test_chemprop_web(self):
[
'--data_path', os.path.join(TEST_DATA_DIR, 'spectra_exclusions.csv'),
'--features_path', os.path.join(TEST_DATA_DIR, 'spectra_features.csv'),
- '--split_type', 'random_with_repeated_smiles'
+ '--split_type', 'random_with_repeated_smiles', '--save_smiles_splits'
]
),
(
Expand Down Expand Up @@ -739,7 +739,7 @@ def test_train_spectra(self,
[
'--data_path', os.path.join(TEST_DATA_DIR, 'spectra.csv'),
'--features_path', os.path.join(TEST_DATA_DIR, 'spectra_features.csv'),
- '--split_type', 'random_with_repeated_smiles'
+ '--split_type', 'random_with_repeated_smiles', '--save_smiles_splits'
],
[
'--features_path', os.path.join(TEST_DATA_DIR, 'spectra_features.csv'),
Expand Down

0 comments on commit a952312

Please sign in to comment.