From 9224be50eee3cef679c4776d4268dff7f748a278 Mon Sep 17 00:00:00 2001 From: Harry Slatyer Date: Fri, 2 Dec 2016 16:25:54 +1100 Subject: [PATCH] Fix importing/creation of NN impl We need to specify nnlearner as a package. More subtly, because of TF we can only run NNI in the same process in which it's created. This means we need to wait until the run() method of the learner is called before constructing the impl. --- mloop/__init__.py | 2 +- mloop/learners.py | 5 ++++- mloop/nnlearner.py | 2 ++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/mloop/__init__.py b/mloop/__init__.py index 9e53155..dd553f6 100644 --- a/mloop/__init__.py +++ b/mloop/__init__.py @@ -13,4 +13,4 @@ import os __version__= "2.1.1" -__all__ = ['controllers','interfaces','launchers','learners','testing','utilities','visualizations','cmd'] \ No newline at end of file +__all__ = ['controllers','interfaces','launchers','learners','nnlearner','testing','utilities','visualizations','cmd'] diff --git a/mloop/learners.py b/mloop/learners.py index 8f8038d..12f6ed5 100644 --- a/mloop/learners.py +++ b/mloop/learners.py @@ -1587,7 +1587,6 @@ def __init__(self, self.cost_has_noise = True self.noise_level = 1 - self.neural_net_impl = NeuralNetImpl(self.num_params) # TODO: What are these? self.generation_num = 4 if (self.default_bad_cost is None) and (self.default_bad_uncertainty is None): @@ -1799,6 +1798,10 @@ def run(self): #current solution is to only log to the console for warning and above from a process self.log = mp.log_to_stderr(logging.WARNING) + # The network needs to be created in the same process in which it runs + import mloop.nnlearner as mlnn + self.neural_net_impl = mlnn.NeuralNetImpl(self.num_params) + try: while not self.end_event.is_set(): self.log.debug('Learner waiting for new params event') diff --git a/mloop/nnlearner.py b/mloop/nnlearner.py index 56debd2..14b612e 100644 --- a/mloop/nnlearner.py +++ b/mloop/nnlearner.py @@ -7,6 +7,8 @@ class NeuralNetImpl(): ''' Neural network implementation. + This must run in the same process in which it's created. + Args: num_params (int): The number of params. '''