Skip to content

Commit

Permalink
add random_state
Browse files Browse the repository at this point in the history
  • Loading branch information
Carlos Hernandez committed Aug 5, 2016
1 parent 1a11835 commit 6cba6c7
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 7 deletions.
7 changes: 7 additions & 0 deletions osprey/config.py
Expand Up @@ -13,6 +13,7 @@
serialized to a database specified in this section.
- cv: specification for cross-validation.
- scoring: the score function used in cross-validation. (optional)
- random_seed: integer seed passed to np.random.seed before fitting and scoring. (optional)
"""

import sys
Expand Down Expand Up @@ -50,6 +51,7 @@
'strategy': ['name', 'params'],
'cv': (int, dict),
'scoring': (str, type(None)),
'random_seed': (int, type(None)),
}


Expand Down Expand Up @@ -309,6 +311,11 @@ def scoring(self):
assert isinstance(scoring, (str, type(None)))
return scoring

def random_seed(self):
    """Return the value of the ``random_seed`` configuration section.

    The value is expected to be either an integer or ``None``; any other
    type trips the sanity-check assertion below.
    """
    seed = self.get_section('random_seed')
    # Sanity check only: the seed must be an int or absent (None).
    assert seed is None or isinstance(seed, int)
    return seed

def cv(self, X, y=None):
cv = self.get_section('cv')
if isinstance(cv, int):
Expand Down
4 changes: 3 additions & 1 deletion osprey/data/default_config.yaml
@@ -1,4 +1,6 @@
trials:
project_name: default

scoring: !!null
scoring: !!null

random_seed: !!null
2 changes: 2 additions & 0 deletions osprey/data/grid_example.yaml
Expand Up @@ -29,3 +29,5 @@ dataset_loader:

trials:
uri: sqlite:///osprey-trials.db

random_seed: 42
2 changes: 2 additions & 0 deletions osprey/data/random_example.yaml
Expand Up @@ -25,3 +25,5 @@ dataset_loader:

trials:
uri: sqlite:///osprey-trials.db

random_seed: 42
2 changes: 2 additions & 0 deletions osprey/data/sklearn_skeleton_config.yaml
Expand Up @@ -27,3 +27,5 @@ dataset_loader:

trials:
uri: sqlite:///osprey-trials.db

random_seed: 42
10 changes: 6 additions & 4 deletions osprey/execute_worker.py
Expand Up @@ -35,6 +35,7 @@ def execute(args, parser):
strategy = config.strategy()
config_sha1 = config.sha1()
scoring = config.scoring()
random_seed = config.random_seed()
project_name = config.project_name()

if is_msmbuilder_estimator(estimator):
Expand Down Expand Up @@ -78,7 +79,7 @@ def signal_hander(signum, frame):

s = run_single_trial(
estimator=estimator, params=params, trial_id=trial_id,
scoring=scoring, X=X, y=y, cv=cv,
scoring=scoring, random_seed=random_seed, X=X, y=y, cv=cv,
sessionbuilder=config.trialscontext)

statuses[i] = s
Expand Down Expand Up @@ -122,14 +123,15 @@ def initialize_trial(strategy, searchspace, estimator, config_sha1,
return trial_id, params


def run_single_trial(estimator, params, trial_id, scoring, X, y, cv,
sessionbuilder):
def run_single_trial(estimator, params, trial_id, scoring, random_seed,
X, y, cv, sessionbuilder):

status = None

try:
score = fit_and_score_estimator(
estimator, params, cv=cv, scoring=scoring, X=X, y=y, verbose=1)
estimator, params, cv=cv, scoring=scoring, random_seed=random_seed,
X=X, y=y, verbose=1)
with sessionbuilder() as session:
trial = session.query(Trial).get(trial_id)
trial.mean_test_score = score['mean_test_score']
Expand Down
7 changes: 5 additions & 2 deletions osprey/fit_estimator.py
Expand Up @@ -19,8 +19,8 @@


def fit_and_score_estimator(estimator, parameters, cv, X, y=None, scoring=None,
iid=True, n_jobs=1, verbose=1,
pre_dispatch='2*n_jobs'):
random_seed=None, iid=True, n_jobs=1,
verbose=1, pre_dispatch='2*n_jobs'):
"""Fit and score an estimator with cross-validation
This function is basically a copy of sklearn's
Expand All @@ -39,6 +39,9 @@ def fit_and_score_estimator(estimator, parameters, cv, X, y=None, scoring=None,
The scores on the training and test sets, as well as the mean test set
score.
"""

np.random.seed(random_seed)

scorer = check_scoring(estimator, scoring=scoring)
n_samples = num_samples(X)
X, y = check_arrays(X, y, allow_lists=True, sparse_format='csr',
Expand Down

0 comments on commit 6cba6c7

Please sign in to comment.