Skip to content

Commit

Permalink
Save smiles splits in sklearn_train
Browse files: browse the repository at this point in the history
  • Loading branch information
swansonk14 committed Sep 10, 2020
1 parent f6da997 commit 99fd202
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 10 deletions.
13 changes: 11 additions & 2 deletions chemprop/sklearn_train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from logging import Logger
import os
import pickle
from pprint import pformat
from typing import Dict, List, Union

import numpy as np
Expand All @@ -10,9 +9,10 @@
from tqdm import trange, tqdm

from chemprop.args import SklearnTrainArgs
from chemprop.data import get_data, get_task_names, MoleculeDataset, split_data
from chemprop.data import MoleculeDataset, split_data
from chemprop.features import get_features_generator
from chemprop.train import cross_validate, evaluate_predictions
from chemprop.utils import save_smiles_splits


def predict(model: Union[RandomForestRegressor, RandomForestClassifier, SVR, SVC],
Expand Down Expand Up @@ -192,6 +192,15 @@ def run_sklearn(args: SklearnTrainArgs,
args=args
)

if args.save_smiles_splits:
save_smiles_splits(
data_path=args.data_path,
save_dir=args.save_dir,
train_data=train_data,
test_data=test_data,
smiles_column=args.smiles_column
)

debug(f'Total size = {len(data):,} | train size = {len(train_data):,} | test size = {len(test_data):,}')

debug('Computing morgan fingerprints')
Expand Down
4 changes: 2 additions & 2 deletions chemprop/train/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ def run_training(args: TrainArgs,

if args.save_smiles_splits:
save_smiles_splits(
data_path=args.data_path,
save_dir=args.save_dir,
train_data=train_data,
val_data=val_data,
test_data=test_data,
data_path=args.data_path,
save_dir=args.save_dir,
smiles_column=args.smiles_column
)

Expand Down
15 changes: 9 additions & 6 deletions chemprop/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,21 +392,21 @@ def wrap(*args, **kwargs) -> Any:
return timeit_decorator


def save_smiles_splits(train_data: MoleculeDataset,
val_data: MoleculeDataset,
test_data: MoleculeDataset,
data_path: str,
def save_smiles_splits(data_path: str,
save_dir: str,
train_data: MoleculeDataset = None,
val_data: MoleculeDataset = None,
test_data: MoleculeDataset = None,
smiles_column: str = None) -> None:
"""
Saves the SMILES of each provided train/val/test split to a CSV file and the split indices as a pickle file.
:param data_path: Path to data CSV file.
:param save_dir: Path where pickle files will be saved.
:param train_data: Train :class:`~chemprop.data.data.MoleculeDataset`.
:param val_data: Validation :class:`~chemprop.data.data.MoleculeDataset`.
:param test_data: Test :class:`~chemprop.data.data.MoleculeDataset`.
:param data_path: Path to data CSV file.
:param smiles_column: The name of the column containing SMILES. By default, uses the first column.
:param save_dir: Path where pickle files will be saved.
"""
makedirs(save_dir)

Expand All @@ -428,6 +428,9 @@ def save_smiles_splits(train_data: MoleculeDataset,

all_split_indices = []
for dataset, name in [(train_data, 'train'), (val_data, 'val'), (test_data, 'test')]:
if dataset is None:
continue

with open(os.path.join(save_dir, f'{name}_smiles.csv'), 'w') as f:
writer = csv.writer(f)
writer.writerow(['smiles'])
Expand Down

0 comments on commit 99fd202

Please sign in to comment.