|
|
@@ -12,6 +12,8 @@ |
|
|
import matplotlib.pyplot as plt
|
|
|
import matplotlib as mpl
|
|
|
|
|
|
+from mpl_toolkits.mplot3d import Axes3D
|
|
|
+
|
|
|
figure_counter = 0
|
|
|
cmap = plt.get_cmap('hsv')
|
|
|
run_label = 'Run number'
|
|
|
@@ -570,6 +572,7 @@ def create_neural_net_learner_visualizations(filename, |
|
|
visualization = NeuralNetVisualizer(filename, file_type=file_type)
|
|
|
if plot_cross_sections:
|
|
|
visualization.plot_cross_sections()
|
|
|
+ visualization.plot_surface()
|
|
|
|
|
|
|
|
|
class NeuralNetVisualizer(mll.NeuralNetLearner):
|
|
|
@@ -659,7 +662,7 @@ def return_cross_sections(self, points=100, cross_section_center=None): |
|
|
cross_parameter_arrays = np.array(cross_parameter_arrays)
|
|
|
cost_arrays = self.cost_scaler.inverse_transform(np.array(cost_arrays))
|
|
|
return (cross_parameter_arrays,cost_arrays)
|
|
|
-
|
|
|
+
|
|
|
def plot_cross_sections(self):
|
|
|
'''
|
|
|
Produce a figure of the cross section about best cost and parameters
|
|
|
@@ -686,3 +689,26 @@ def plot_cross_sections(self): |
|
|
for ind in range(self.num_params):
|
|
|
artists.append(plt.Line2D((0,1),(0,0), color=self.param_colors[ind], linestyle='-'))
|
|
|
plt.legend(artists,[str(x) for x in range(1,self.num_params+1)],loc=legend_loc)
|
|
|
+
|
|
|
+ def plot_surface(self):
|
|
|
+ '''
|
|
|
+ Produce a figure of the cost surface (only works when there are 2 parameters)
|
|
|
+ '''
|
|
|
+ if self.num_params != 2:
|
|
|
+ return
|
|
|
+ global figure_counter
|
|
|
+ figure_counter += 1
|
|
|
+ fig = plt.figure(figure_counter)
|
|
|
+ ax = fig.add_subplot(111, projection='3d')
|
|
|
+
|
|
|
+ points = 50
|
|
|
+ param_set = [ np.linspace(min_p, max_p, points) for (min_p,max_p) in zip(self.min_boundary,self.max_boundary)]
|
|
|
+ params = [(x,y) for x in param_set[0] for y in param_set[1]]
|
|
|
+ costs = self.predict_costs_from_param_array(params)
|
|
|
+ ax.scatter([param[0] for param in params], [param[1] for param in params], costs)
|
|
|
+ ax.set_zlim(top=100)
|
|
|
+ ax.set_xlabel('x')
|
|
|
+ ax.set_ylabel('y')
|
|
|
+ ax.set_zlabel('cost')
|
|
|
+
|
|
|
+ ax.scatter(self.all_params[:,0], self.all_params[:,1], self.all_costs, c='r')
|
0 comments on commit
1900587