diff --git a/mloop/learners.py b/mloop/learners.py index ae56fa9..f5169a4 100644 --- a/mloop/learners.py +++ b/mloop/learners.py @@ -1670,6 +1670,13 @@ def __init__(self, #Remove logger so gaussian process can be safely picked for multiprocessing on Windows self.log = None + def create_neural_net(self): + ''' + Creates the neural net. Must be called from the same process as fit_neural_net, predict_cost and predict_costs_from_param_array. + ''' + import mloop.nnlearner as mlnn + self.neural_net_impl = mlnn.NeuralNetImpl(self.num_params) + def fit_neural_net(self): ''' Determine the appropriate number of layers for the NN given the data. @@ -1871,8 +1878,7 @@ def run(self): self.log = mp.log_to_stderr(logging.WARNING) # The network needs to be created in the same process in which it runs - import mloop.nnlearner as mlnn - self.neural_net_impl = mlnn.NeuralNetImpl(self.num_params) + self.create_neural_net() try: while not self.end_event.is_set():