Skip to content

Commit

Permalink
make dill dependency optional
Browse files Browse the repository at this point in the history
  • Loading branch information
jakebian committed Feb 8, 2018
1 parent b31b7d8 commit 2990722
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 15 deletions.
12 changes: 9 additions & 3 deletions hyperopt/fmin.py
Expand Up @@ -3,7 +3,6 @@
from future import standard_library
from builtins import str
from builtins import object
import dill

import functools
import logging
Expand All @@ -21,6 +20,13 @@
logger = logging.getLogger(__name__)


try:
import dill as pickler
except Exception as e:
logger.info('Failed to load dill, try installing dill via "pip install dill" for enhanced pickling support.')
import six.moves.cPickle as pickler


def fmin_pass_expr_memo_ctrl(f):
"""
Mark a function as expecting kwargs 'expr', 'memo' and 'ctrl' from
Expand Down Expand Up @@ -73,9 +79,9 @@ def __init__(self, algo, domain, trials, rstate, async=None,
if self.async:
if 'FMinIter_Domain' in trials.attachments:
logger.warn('over-writing old domain trials attachment')
msg = dill.dumps(domain)
msg = pickler.dumps(domain)
# -- sanity check for unpickling
dill.loads(msg)
pickler.loads(msg)
trials.attachments['FMinIter_Domain'] = msg

def serial_evaluate(self, N=-1):
Expand Down
12 changes: 9 additions & 3 deletions hyperopt/main.py
Expand Up @@ -10,11 +10,17 @@
from . import utils
from .base import SerialExperiment
import sys
import dill

standard_library.install_aliases()
logger = logging.getLogger(__name__)


try:
import dill as pickler
except Exception as e:
logger.info('Failed to load dill, try installing dill via "pip install dill" for enhanced pickling support.')
import six.moves.cPickle as pickler

__authors__ = "James Bergstra"
__license__ = "3-clause BSD License"
__contact__ = "github.com/hyperopt/hyperopt"
Expand Down Expand Up @@ -71,7 +77,7 @@ def main_search():
if not options.load:
raise IOError()
handle = open(options.load, 'rb')
self = dill.load(handle)
self = pickler.load(handle)
handle.close()
except IOError:
bandit = utils.get_obj(bandit_json, argfile=options.bandit_argfile)
Expand All @@ -84,7 +90,7 @@ def main_search():
self.run(int(options.steps))
finally:
if options.save:
dill.dump(self, open(options.save, 'wb'))
pickler.dump(self, open(options.save, 'wb'))


def main(cmd, fn_pos=1):
Expand Down
17 changes: 11 additions & 6 deletions hyperopt/mongoexp.py
Expand Up @@ -127,7 +127,6 @@
import six
from six.moves import map
from six.moves import range
import dill

__authors__ = ["James Bergstra", "Dan Yamins"]
__license__ = "3-clause BSD License"
Expand All @@ -136,6 +135,12 @@
standard_library.install_aliases()
logger = logging.getLogger(__name__)

try:
import dill as pickler
except Exception as e:
logger.info('Failed to load dill, try installing dill via "pip install dill" for enhanced pickling support.')
import six.moves.cPickle as pickler


class OperationFailure(Exception):
"""Proxy that could be factored out if we also want to use CouchDB and
Expand Down Expand Up @@ -741,7 +746,7 @@ def refresh_tids(self, tids):
str(numpy.random.randint(1e8)) + '.pkl')
logger.error('HYPEROPT REFRESH ERROR: writing error file to %s' % reportpath)
_file = open(reportpath, 'w')
dill.dump({'db_data': db_data,
pickler.dump({'db_data': db_data,
'existing_data': existing_data},
_file)
_file.close()
Expand Down Expand Up @@ -1040,9 +1045,9 @@ def run_one(self,
cmd_protocol = cmd[0]
try:
if cmd_protocol == 'cpickled fn':
worker_fn = dill.loads(cmd[1])
worker_fn = pickler.loads(cmd[1])
elif cmd_protocol == 'call evaluate':
bandit = dill.loads(cmd[1])
bandit = pickler.loads(cmd[1])
worker_fn = bandit.evaluate
elif cmd_protocol == 'token_load':
cmd_toks = cmd[1].split('.')
Expand All @@ -1054,14 +1059,14 @@ def run_one(self,
elif cmd_protocol == 'driver_attachment':
# name = 'driver_attachment_%s' % job['exp_key']
blob = ctrl.trials.attachments[cmd[1]]
bandit_name, bandit_args, bandit_kwargs = dill.loads(blob)
bandit_name, bandit_args, bandit_kwargs = pickler.loads(blob)
worker_fn = json_call(bandit_name,
args=bandit_args,
kwargs=bandit_kwargs).evaluate
elif cmd_protocol == 'domain_attachment':
blob = ctrl.trials.attachments[cmd[1]]
try:
domain = dill.loads(blob)
domain = pickler.loads(blob)
except BaseException as e:
logger.info(
'Error while unpickling.')
Expand Down
9 changes: 7 additions & 2 deletions hyperopt/utils.py
Expand Up @@ -13,11 +13,16 @@
import numpy
from . import pyll
from contextlib import contextmanager
import dill

standard_library.install_aliases()
logger = logging.getLogger(__name__)

try:
import dill as pickler
except Exception as e:
logger.info('Failed to load dill, try installing dill via "pip install dill" for enhanced pickling support.')
import six.moves.cPickle as pickler


def import_tokens(tokens):
# XXX Document me
Expand Down Expand Up @@ -85,7 +90,7 @@ def get_obj(f, argfile=None, argstr=None, args=(), kwargs=None):
if argfile is not None:
argstr = open(argfile).read()
if argstr is not None:
argd = dill.loads(argstr)
argd = pickler.loads(argstr)
else:
argd = {}
args = args + argd.get('args', ())
Expand Down
5 changes: 4 additions & 1 deletion setup.py
Expand Up @@ -144,6 +144,9 @@ def find_package_data(packages):
keywords='Bayesian optimization hyperparameter model selection',
package_data=package_data,
include_package_data=True,
install_requires=['numpy', 'scipy', 'nose', 'pymongo', 'six', 'dill', 'networkx', 'future'],
install_requires=['numpy', 'scipy', 'nose', 'pymongo', 'six', 'networkx', 'future'],
extras_require={
'dill': 'dill'
},
zip_safe=False
)

0 comments on commit 2990722

Please sign in to comment.