Permalink
Browse files

Plot NN surface when there are 2 params

  • Loading branch information...
1 parent e6e83e8 commit 1900587707bd4818672ab9bbc97e63c8639a7ba2 @charmasaur charmasaur committed Dec 9, 2016
Showing with 27 additions and 1 deletion.
  1. +27 −1 mloop/visualizations.py
View
@@ -12,6 +12,8 @@
import matplotlib.pyplot as plt
import matplotlib as mpl
+from mpl_toolkits.mplot3d import Axes3D
+
figure_counter = 0
cmap = plt.get_cmap('hsv')
run_label = 'Run number'
@@ -570,6 +572,7 @@ def create_neural_net_learner_visualizations(filename,
visualization = NeuralNetVisualizer(filename, file_type=file_type)
if plot_cross_sections:
visualization.plot_cross_sections()
+ visualization.plot_surface()
class NeuralNetVisualizer(mll.NeuralNetLearner):
@@ -659,7 +662,7 @@ def return_cross_sections(self, points=100, cross_section_center=None):
cross_parameter_arrays = np.array(cross_parameter_arrays)
cost_arrays = self.cost_scaler.inverse_transform(np.array(cost_arrays))
return (cross_parameter_arrays,cost_arrays)
-
+
def plot_cross_sections(self):
'''
Produce a figure of the cross section about best cost and parameters
@@ -686,3 +689,26 @@ def plot_cross_sections(self):
for ind in range(self.num_params):
artists.append(plt.Line2D((0,1),(0,0), color=self.param_colors[ind], linestyle='-'))
plt.legend(artists,[str(x) for x in range(1,self.num_params+1)],loc=legend_loc)
+
+ def plot_surface(self):
+ '''
+ Produce a figure of the cost surface (only works when there are 2 parameters)
+ '''
+ if self.num_params != 2:
+ return
+ global figure_counter
+ figure_counter += 1
+ fig = plt.figure(figure_counter)
+ ax = fig.add_subplot(111, projection='3d')
+
+ points = 50
+ param_set = [ np.linspace(min_p, max_p, points) for (min_p,max_p) in zip(self.min_boundary,self.max_boundary)]
+ params = [(x,y) for x in param_set[0] for y in param_set[1]]
+ costs = self.predict_costs_from_param_array(params)
+ ax.scatter([param[0] for param in params], [param[1] for param in params], costs)
+ ax.set_zlim(top=100)
+ ax.set_xlabel('x')
+ ax.set_ylabel('y')
+ ax.set_zlabel('cost')
+
+ ax.scatter(self.all_params[:,0], self.all_params[:,1], self.all_costs, c='r')

0 comments on commit 1900587

Please sign in to comment.