From e6e83e815ace6ed5543e5bde92a9b5b5531b0508 Mon Sep 17 00:00:00 2001 From: Harry Slatyer Date: Fri, 9 Dec 2016 16:54:03 +1100 Subject: [PATCH] Use TF gradient when minimizing NN cost function estimate --- mloop/learners.py | 30 +++++++++++++++++++++++++----- mloop/nnlearner.py | 12 ++++++++++++ 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/mloop/learners.py b/mloop/learners.py index 8d80d77..ad6e5b1 100644 --- a/mloop/learners.py +++ b/mloop/learners.py @@ -1703,6 +1703,16 @@ def predict_cost(self,params): ''' return self.neural_net_impl.predict_cost(params) + def predict_cost_gradient(self,params): + ''' + Produces a prediction of the gradient of the cost function at params. + + Returns: + float : Predicted gradient at parameters + ''' + # scipy.optimize.minimize doesn't seem to like a 32-bit Jacobian, so we convert to 64 + return np.float64(self.neural_net_impl.predict_cost_gradient(params)) + def predict_costs_from_param_array(self,params): ''' @@ -1871,7 +1881,11 @@ def find_next_parameters(self): next_params = None next_cost = float('inf') for start_params in self.search_params: - result = so.minimize(self.predict_cost, start_params, bounds = self.search_region, tol=self.search_precision) + result = so.minimize(fun = self.predict_cost, + x0 = start_params, + jac = self.predict_cost_gradient, + bounds = self.search_region, + tol = self.search_precision) if result.fun < next_cost: next_params = result.x next_cost = result.fun @@ -1942,8 +1956,11 @@ def find_global_minima(self): search_bounds = list(zip(self.min_boundary, self.max_boundary)) for start_params in search_params: - # TODO: Take advantage of the fact that we get the gradient for free, so can use that to speed up the minimizer. 
- result = so.minimize(self.predict_cost, start_params, bounds = search_bounds, tol=self.search_precision) + result = so.minimize(fun = self.predict_cost, + x0 = start_params, + jac = self.predict_cost_gradient, + bounds = search_bounds, + tol = self.search_precision) curr_best_params = result.x curr_best_cost = result.fun if curr_best_cost