Skip to content

Commit

Permalink
Fixing data modification and torch seed issues
Browse files (browse the repository at this point in the history)
  • Loading branch information
swansonk14 committed Aug 28, 2020
1 parent 74fba47 commit 56bf598
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 2 additions & 4 deletions chemprop/train/cross_validate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import defaultdict
+ from copy import deepcopy
import csv
from logging import Logger
import os
Expand Down Expand Up
@@ -58,9 +59,6 @@ def cross_validate(args: TrainArgs,
# Save args
args.save(os.path.join(args.save_dir, 'args.json'))

- # Set pytorch seed for random initial weights
- torch.manual_seed(args.pytorch_seed)
-
# Get data
debug('Loading data')
data = get_data(path=args.data_path, args=args, logger=logger, skip_none_targets=True)
Expand All
@@ -75,7 +73,7 @@ def cross_validate(args: TrainArgs,
args.seed = init_seed + fold_num
args.save_dir = os.path.join(save_dir, f'fold_{fold_num}')
makedirs(args.save_dir)
- model_scores = train_func(args, data, logger)
+ model_scores = train_func(args, deepcopy(data), logger) # deepcopy since data may be modified
for metric, scores in model_scores.items():
all_scores[metric].append(scores)
all_scores = dict(all_scores)
Expand Down
4 changes: 3 additions & 1 deletion chemprop/train/run_training.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from logging import Logger
import os
- import sys
from typing import Dict, List

import numpy as np
Expand Down Expand Up
@@ -40,6 +39,9 @@ def run_training(args: TrainArgs,
else:
debug = info = print

+ # Set pytorch seed for random initial weights
+ torch.manual_seed(args.pytorch_seed)
+
# Split data
debug(f'Splitting data with seed {args.seed}')
if args.separate_test_path:
Expand Down

0 comments on commit 56bf598

Please sign in to comment.