Skip to content

Commit

Permalink
reset online measures before each full_forward_pass
Browse files Browse the repository at this point in the history
  • Loading branch information
dizcza committed Jul 19, 2021
1 parent 9090f54 commit 518a320
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 2 additions & 4 deletions mighty/monitor/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,11 +814,9 @@ def clusters_heatmap(self, mean, title=None, save=False):
)
if n_classes <= self.n_classes_format_ytickstep_1:
opts.update(ytickstep=1)
self.viz.heatmap(mean.cpu(), win=win, opts=opts)
if save:
self.viz.heatmap(mean.cpu(),
win=f"{win}. Epoch {self.timer.epoch}",
opts=opts)
win = f"{win}. Epoch {self.timer.epoch}"
self.viz.heatmap(mean.cpu(), win=win, opts=opts)

def update_l1_neuron_norm(self, l1_norm: torch.Tensor):
"""
Expand Down
2 changes: 2 additions & 0 deletions mighty/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,8 @@ def full_forward_pass(self, train=True):
mode_saved = self.model.training
self.model.train(False)
self.accuracy_measure.reset_labels()
for online_measure in self.online.values():
online_measure.reset()
loss_online = MeanOnline()

if train:
Expand Down

0 comments on commit 518a320

Please sign in to comment.