Skip to content

Commit

Permalink
fix bug in visualizer when plotting temporal data w/ batch size 1
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Oct 12, 2023
1 parent e4e10ad commit 43e3164
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ def plot_batch(self,
batch_sz, T, *_ = x.shape
params['fig_args']['figsize'][1] *= T
fig = plt.figure(**params['fig_args'])
subfigs = fig.subfigures(nrows=batch_sz, ncols=1, hspace=0.0)
subfigs = fig.subfigures(
nrows=batch_sz, ncols=1, hspace=0.0, squeeze=False)
subfig_axs = [
subfig.subplots(
nrows=T, ncols=params['subplot_args']['ncols'])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def test_plot_batch_temporal(self):
x = torch.randn(size=(2, 3, 4, 256, 256))
y = (torch.randn(size=(2, 256, 256)) > 0).long()
self.assertNoError(lambda: viz.plot_batch(x, y))
# w/o z, batch size = 1
self.assertNoError(lambda: viz.plot_batch(x[[0]], y[[0]]))

# w/ z
viz = SemanticSegmentationVisualizer(
Expand All @@ -50,3 +52,5 @@ def test_plot_batch_temporal(self):
y = (torch.randn(size=(2, 256, 256)) > 0).long()
z = torch.randn(size=(2, num_classes, 256, 256)).softmax(dim=-3)
self.assertNoError(lambda: viz.plot_batch(x, y, z=z))
# w/ z, batch size = 1
self.assertNoError(lambda: viz.plot_batch(x[[0]], y[[0]]))

0 comments on commit 43e3164

Please sign in to comment.