From 1900587707bd4818672ab9bbc97e63c8639a7ba2 Mon Sep 17 00:00:00 2001 From: Harry Slatyer Date: Fri, 9 Dec 2016 17:05:38 +1100 Subject: [PATCH] Plot NN surface when there are 2 params --- mloop/visualizations.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/mloop/visualizations.py b/mloop/visualizations.py index 494a458..8219fbc 100644 --- a/mloop/visualizations.py +++ b/mloop/visualizations.py @@ -12,6 +12,8 @@ import matplotlib.pyplot as plt import matplotlib as mpl +from mpl_toolkits.mplot3d import Axes3D + figure_counter = 0 cmap = plt.get_cmap('hsv') run_label = 'Run number' @@ -570,6 +572,7 @@ def create_neural_net_learner_visualizations(filename, visualization = NeuralNetVisualizer(filename, file_type=file_type) if plot_cross_sections: visualization.plot_cross_sections() + visualization.plot_surface() class NeuralNetVisualizer(mll.NeuralNetLearner): @@ -659,7 +662,7 @@ def return_cross_sections(self, points=100, cross_section_center=None): cross_parameter_arrays = np.array(cross_parameter_arrays) cost_arrays = self.cost_scaler.inverse_transform(np.array(cost_arrays)) return (cross_parameter_arrays,cost_arrays) - + def plot_cross_sections(self): ''' Produce a figure of the cross section about best cost and parameters @@ -686,3 +689,26 @@ def plot_cross_sections(self): for ind in range(self.num_params): artists.append(plt.Line2D((0,1),(0,0), color=self.param_colors[ind], linestyle='-')) plt.legend(artists,[str(x) for x in range(1,self.num_params+1)],loc=legend_loc) + + def plot_surface(self): + ''' + Produce a figure of the cost surface (only works when there are 2 parameters) + ''' + if self.num_params != 2: + return + global figure_counter + figure_counter += 1 + fig = plt.figure(figure_counter) + ax = fig.add_subplot(111, projection='3d') + + points = 50 + param_set = [ np.linspace(min_p, max_p, points) for (min_p,max_p) in zip(self.min_boundary,self.max_boundary)] + params = [(x,y) for x in param_set[0] for y in param_set[1]] + costs = self.predict_costs_from_param_array(params) + ax.scatter([param[0] for param in params], [param[1] for param in params], costs) + ax.set_zlim(top=100) + ax.set_xlabel('x') + ax.set_ylabel('y') + ax.set_zlabel('cost') + + ax.scatter(self.all_params[:,0], self.all_params[:,1], self.all_costs, c='r')