Permalink
Browse files

Tweaks to NN learner shell

  • Loading branch information...
1 parent 326f98b commit 635a5f787f4abb815a375035291cbc9c324669b8 @charmasaur charmasaur committed Nov 25, 2016
Showing with 5 additions and 3 deletions.
  1. +5 −3 mloop/learners.py
View
@@ -1675,7 +1675,7 @@ def predict_cost(self,params):
'''
return self.gaussian_process.predict(params[np.newaxis,:])
- #--- FAKE NN CONSTRUCTOR END ---#
+ #--- FAKE NN METHODS END ---#
def wait_for_new_params_event(self):
@@ -1853,7 +1853,7 @@ def run(self):
self.wait_for_new_params_event()
#self.log.debug('Gaussian process learner reading costs')
self.get_params_and_costs()
- self.fit_gaussian_process()
+ self.fit_neural_net()
for _ in range(self.generation_num):
self.log.debug('Gaussian process learner generating parameter:'+ str(self.params_count+1))
next_params = self.find_next_parameters()
@@ -1864,7 +1864,7 @@ def run(self):
pass
if self.predict_global_minima_at_end or self.predict_local_minima_at_end:
self.get_params_and_costs()
- self.fit_gaussian_process()
+ self.fit_neural_net()
end_dict = {}
if self.predict_global_minima_at_end:
self.find_global_minima()
@@ -1904,6 +1904,7 @@ def find_global_minima(self):
for start_params in search_params:
result = so.minimize(self.predict_cost, start_params, bounds = search_bounds, tol=self.search_precision)
curr_best_params = result.x
+ # TODO: Doesn't apply to NN
(curr_best_cost,curr_best_uncer) = self.gaussian_process.predict(curr_best_params[np.newaxis,:],return_std=True)
if curr_best_cost<self.predicted_best_scaled_cost:
self.predicted_best_parameters = curr_best_params
@@ -1945,6 +1946,7 @@ def find_local_minima(self):
for start_params in self.all_params:
result = so.minimize(self.predict_cost, start_params, bounds = search_bounds, tol=self.search_precision)
curr_minima_params = result.x
+ # TODO: Doesn't apply to NN.
(curr_minima_cost,curr_minima_uncer) = self.gaussian_process.predict(curr_minima_params[np.newaxis,:],return_std=True)
if all( not np.all( np.abs(params - curr_minima_params) < self.minima_tolerance ) for params in self.local_minima_parameters):
#Non duplicate point so add to the list

0 comments on commit 635a5f7

Please sign in to comment.