Skip to content

Commit

Permalink
Monitor.clusters_heatmap takes mean tensor only
Browse files Browse the repository at this point in the history
  • Loading branch information
dizcza committed Jul 5, 2021
1 parent 7bfccd9 commit cf59af7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
14 changes: 6 additions & 8 deletions mighty/monitor/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,9 @@ def update_pairwise_dist(self, mean, std):
"""
if mean is None:
return
if mean.shape != std.shape:
raise ValueError("The mean and std must have the same shape and "
"come from VarianceOnline.get_mean_std().")

l1_norm = mean.norm(p=1, dim=1).mean()
pdist = torch.pdist(mean, p=1).mean() / l1_norm
Expand All @@ -777,23 +780,18 @@ def update_pairwise_dist(self, mean, std):
))
return pdist / std

def clusters_heatmap(self, mean, std, save=False):
def clusters_heatmap(self, mean, save=False):
"""
Cluster centers distribution heatmap.
Parameters
----------
mean, std : torch.Tensor
Tensors of shape `(C, V)`.
The mean and standard deviation of `C` clusters (vectors of size
`V`).
mean : (C, V) torch.Tensor
The mean of `C` clusters (vectors of size `V`).
"""
if mean is None:
return
if mean.shape != std.shape:
raise ValueError("The mean and std must have the same shape and "
"come from VarianceOnline.get_mean_std().")

n_classes = mean.shape[0]
win = "Embedding activations heatmap"
Expand Down
2 changes: 1 addition & 1 deletion mighty/trainer/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _epoch_finished(self, loss):
self.monitor.update_l1_neuron_norm(self.online['l1_norm'].get_mean())
# mean and std can be Nones
mean, std = self.online['clusters'].get_mean_std()
self.monitor.clusters_heatmap(mean, std)
self.monitor.clusters_heatmap(mean)
self.monitor.update_pairwise_dist(mean, std)
self.monitor.embedding_hist(activations=mean)
super()._epoch_finished(loss)

0 comments on commit cf59af7

Please sign in to comment.