Skip to content

Commit

Permalink
Feature/accuracy movie (#160)
Browse files Browse the repository at this point in the history
* Add accuracy subplot to informationplane movie

* Change to test data evaluation

* Include reveiw changes
  • Loading branch information
felixmzd committed Sep 5, 2018
1 parent dae09ae commit 8376ba7
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 30 deletions.
7 changes: 7 additions & 0 deletions deep_bottleneck/callbacks/loggingreporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,10 @@ def on_epoch_end(self, epoch, logs=None):
str(i),
data=self.layerfuncs[i]([self.dataset.train.examples])[0]
)

self.file_all_activations[str(epoch)].create_group('accuracy')
self.file_all_activations[f'{epoch}/accuracy']['training'] = float(logs['acc'])
try:
self.file_all_activations[f'{epoch}/accuracy']['validation'] = float(logs['val_acc'])
except KeyError:
print('Validation not enabled. Validation metrics cannot be logged')
2 changes: 1 addition & 1 deletion deep_bottleneck/configs/basic.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"architecture": [10, 7, 5, 4, 3],
"optimizer": "adam",
"learning_rate": 0.0004,
"calculate_mi_for": "full_dataset",
"calculate_mi_for": "test",
"activation_fn": "relu",
"model": "models.feedforward",
"dataset": "datasets.harmonics",
Expand Down
4 changes: 0 additions & 4 deletions deep_bottleneck/mi_estimator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,6 @@ def __init__(self, discretization_range, data, architecture, calculate_mi_for):
def compute_mi(self, file_all_activations) -> pd.DataFrame:
print(f'*** Start running {self.__class__.__name__}. ***')

print(file_all_activations["2"])
print(f'len of file activations: {len(file_all_activations)}')
for i in file_all_activations: print(file_all_activations[str(i)])

labels, one_hot_labels = self._construct_dataset()
# Proportion of instances that have a certain label.
label_weights = np.mean(one_hot_labels, axis=0)
Expand Down
109 changes: 84 additions & 25 deletions deep_bottleneck/plotter/informationplane_movie.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ class InformationPlaneMoviePlotter(BasePlotter):
plotname = 'infoplane_movie'
filename = f'plots/{plotname}.mp4'

num_layers = None
total_number_of_epochs = None
epoch_indexes = None
layers_colors = None

def __init__(self, run, dataset):
self.dataset = dataset
self.run = run
Expand All @@ -24,43 +29,97 @@ def generate(self, measures_summary):
self.plot(measures_summary)
self.run.add_artifact(self.filename, name=self.plotname)

def plot(self, measures_summary):
def setup_infoplane_subplot(self, ax_infoplane):
if self.dataset == 'datasets.mnist' or self.dataset == 'datasets.fashion_mnist':
ax_infoplane.set(xlim=[0, 14], ylim=[0, 3.5])
else:
ax_infoplane.set(xlim=[0, 12], ylim=[0, 1])

measures = measures_summary['measures_all_runs']
ax_infoplane.set(xlabel='I(X;M)', ylabel='I(Y;M)')

os.makedirs('plots/', exist_ok=True)
scatter = ax_infoplane.scatter([], [], s=20, edgecolor='none')
text = ax_infoplane.text(0, 1.05, "", fontsize=12)

plt.set_cmap("hsv")
fig, ax = plt.subplots()
if self.dataset == 'datasets.mnist' or self.dataset == 'datasets.fashion_mnist':
ax.set(xlim=[0, 14], ylim=[0, 3.5])
else:
ax.set(xlim=[0, 12], ylim=[0, 1])
return scatter, text

def fill_infoplane_subplot(self, ax_infoplane, mi_epoch):
xmvals = mi_epoch['MI_XM']
ymvals = mi_epoch['MI_YM']

points = np.array([xmvals, ymvals]).transpose()
colors = self.layers_colors[mi_epoch.index]

ax_infoplane.set_offsets(points)
ax_infoplane.set_array(colors)

return ax_infoplane

def setup_accuracy_subplot(self, ax_accuracy):
[acc_line] = ax_accuracy.plot([], [], 'b', label="training accuracy")
[val_acc_line] = ax_accuracy.plot([], [], 'g', label="validation accuracy")

ax_accuracy.set_ylim(0, 1)
ax_accuracy.set_xlim(0, self.total_number_of_epochs)

xticks_positions = range(0, self.total_number_of_epochs, int(self.total_number_of_epochs / 20))
ax_accuracy.set_xticks(xticks_positions)
ax_accuracy.set_xticklabels(self.epoch_indexes[xticks_positions], rotation=90)

ax.set(xlabel='I(X;M)', ylabel='I(Y;M)')
handles, labels = ax_accuracy.get_legend_handles_labels()
ax_accuracy.legend(handles, labels, loc=4)

scatter = ax.scatter([], [], s=20, edgecolor='none')
text = ax.text(0, 1.05, "", fontsize=12)
ax_accuracy.set_xlabel('Epoch')
ax_accuracy.set_ylabel('Accuracy')

num_layers = measures.index.get_level_values(1).nunique()
layers_colors = np.linspace(0, 1, num_layers)
return acc_line, val_acc_line

writer = FFMpegWriter(fps=10)
def fill_accuracy_subplot(self, acc_line, val_acc_line, activations_summary, epoch_number, acc, val_acc):
epoch_accuracy = np.asarray(activations_summary[f'{epoch_number}/accuracy/']['training'])
epoch_val_accuracy = np.asarray(activations_summary[f'{epoch_number}/accuracy/']['validation'])

acc.append(epoch_accuracy)
val_acc.append(epoch_val_accuracy)

xs = range(len(acc))
acc_line.set_data(xs, acc)
val_acc_line.set_data(xs, val_acc)

return acc, val_acc, acc_line, val_acc_line

def get_specifications(self, measures):
self.num_layers = measures.index.get_level_values(1).nunique()
self.layers_colors = np.linspace(0, 1, self.num_layers)
self.epoch_indexes = measures.index.get_level_values('epoch').unique()
self.total_number_of_epochs = len(self.epoch_indexes)

def plot(self, measures_summary):
os.makedirs('plots/', exist_ok=True)

measures = measures_summary['measures_all_runs']
activations_summary = measures_summary['activations_summary']
self.get_specifications(measures)

plt.set_cmap("hsv")
fig, (ax_infoplane, ax_accuracy) = plt.subplots(2, 1, figsize=(6, 9),
gridspec_kw={'height_ratios': [2, 1]})

acc = []
val_acc = []

scatter, text = self.setup_infoplane_subplot(ax_infoplane)
acc_line, val_acc_line = self.setup_accuracy_subplot(ax_accuracy)

writer = FFMpegWriter(fps=7)
with writer.saving(fig, self.filename, 600):
for epoch_number, mi_epoch in measures.groupby(level=0):
# Drop outer index level corresponding to the epoch.
mi_epoch.index = mi_epoch.index.droplevel()

xmvals = mi_epoch['MI_XM']
ymvals = mi_epoch['MI_YM']

points = np.array([xmvals, ymvals]).transpose()
colors = layers_colors[mi_epoch.index]

scatter.set_offsets(points)
scatter.set_array(colors)
scatter = self.fill_infoplane_subplot(scatter, mi_epoch)
text.set_text(f'Epoch: {epoch_number}')

text.set_text(f"Epoch: {epoch_number}")
acc, val_acc, acc_line, val_acc_line = self.fill_accuracy_subplot(acc_line, val_acc_line,
activations_summary, epoch_number,
acc, val_acc)

writer.grab_frame()
writer.grab_frame()

0 comments on commit 8376ba7

Please sign in to comment.