Skip to content

Commit

Permalink
generalize visualizers to 4D chips
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Jun 21, 2023
1 parent db24598 commit 7f98eed
Show file tree
Hide file tree
Showing 11 changed files with 322 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@ def plot_xyz(self,
axs: Sequence,
x: torch.Tensor,
y: int,
z: Optional[int] = None) -> None:
z: Optional[int] = None,
plot_title: bool = True) -> None:
channel_groups = self.get_channel_display_groups(x.shape[1])

img_axes = axs[:-1]
label_ax = axs[-1]

# plot image
imgs = channel_groups_to_imgs(x, channel_groups)
plot_channel_groups(img_axes, imgs, channel_groups)
plot_channel_groups(
img_axes, imgs, channel_groups, plot_title=plot_title)

# plot label
class_names = self.class_names
Expand Down Expand Up @@ -66,7 +68,8 @@ def plot_xyz(self,
label_ax.set_xlim((0, 1))
label_ax.xaxis.grid(linestyle='--', alpha=1)
label_ax.set_xlabel('Probability')
label_ax.set_title('Prediction')
if plot_title:
label_ax.set_title('Prediction')

def get_plot_ncols(self, **kwargs) -> int:
x = kwargs['x']
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def plot_xyz(self,
axs: Sequence,
x: torch.Tensor,
y: BoxList,
z: Optional[BoxList] = None) -> None:
z: Optional[BoxList] = None,
plot_title: bool = True) -> None:
y = y if z is None else z
channel_groups = self.get_channel_display_groups(x.shape[1])

Expand All @@ -28,4 +29,4 @@ def plot_xyz(self,

imgs = channel_groups_to_imgs(x, channel_groups)
imgs = [draw_boxes(img, y, class_names, class_colors) for img in imgs]
plot_channel_groups(axs, imgs, channel_groups)
plot_channel_groups(axs, imgs, channel_groups, plot_title=plot_title)
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@ def plot_xyz(self,
axs: Sequence,
x: torch.Tensor,
y: int,
z: Optional[int] = None) -> None:
z: Optional[int] = None,
plot_title: bool = True) -> None:
channel_groups = self.get_channel_display_groups(x.shape[1])

img_axes = axs[:-1]
label_ax = axs[-1]

# plot image
imgs = channel_groups_to_imgs(x, channel_groups)
plot_channel_groups(img_axes, imgs, channel_groups)
plot_channel_groups(
img_axes, imgs, channel_groups, plot_title=plot_title)

# plot label
class_names = self.class_names
Expand All @@ -36,8 +38,8 @@ def plot_xyz(self,
y=class_names, width=y, color='lightgray', edgecolor='black')
# show values on the end of bars
label_ax.bar_label(bars_gt, fmt='%.3f', padding=3)

label_ax.set_title('Ground truth')
if plot_title:
label_ax.set_title('Ground truth')
else:
# display targets and predictions as a grouped horizontal bar plot
bar_thickness = 0.35
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@ def plot_xyz(self,
axs: Sequence,
x: torch.Tensor,
y: Union[torch.Tensor, np.ndarray],
z: Optional[torch.Tensor] = None) -> None:
z: Optional[torch.Tensor] = None,
plot_title: bool = True) -> None:
channel_groups = self.get_channel_display_groups(x.shape[1])

img_axes = axs[:len(channel_groups)]
label_ax = axs[len(channel_groups)]

# plot image
imgs = channel_groups_to_imgs(x, channel_groups)
plot_channel_groups(img_axes, imgs, channel_groups)
plot_channel_groups(
img_axes, imgs, channel_groups, plot_title=plot_title)

# plot labels
class_colors = self.class_colors
Expand All @@ -38,7 +40,8 @@ def plot_xyz(self,

label_ax.imshow(
y, vmin=0, vmax=len(colors), cmap=cmap, interpolation='none')
label_ax.set_title(f'Ground truth')
if plot_title:
label_ax.set_title(f'Ground truth')
label_ax.set_xticks([])
label_ax.set_yticks([])

Expand All @@ -52,7 +55,8 @@ def plot_xyz(self,
vmax=len(colors),
cmap=cmap,
interpolation='none')
pred_ax.set_title(f'Predicted labels')
if plot_title:
pred_ax.set_title(f'Predicted labels')
pred_ax.set_xticks([])
pred_ax.set_yticks([])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

if TYPE_CHECKING:
from torch.utils.data import Dataset
from matplotlib.figure import Figure


class Visualizer(ABC):
Expand Down Expand Up @@ -66,7 +67,12 @@ def __init__(self,
channel_display_groups)

@abstractmethod
def plot_xyz(self, axs, x: Tensor, y, z=None):
def plot_xyz(self,
axs,
x: Tensor,
y: Sequence,
z: Optional[Sequence] = None,
plot_title: bool = True):
"""Plot image, ground truth labels, and predicted labels.
Args:
Expand Down Expand Up @@ -95,11 +101,62 @@ def plot_batch(self,
"""
params = self.get_plot_params(
x=x, y=y, z=z, output_path=output_path, batch_limit=batch_limit)
if params['fig_args']['nrows'] == 0:
if params['subplot_args']['nrows'] == 0:
return

fig, axs = plt.subplots(**params['fig_args'])
if x.ndim == 4:
fig, axs = plt.subplots(**params['fig_args'],
**params['subplot_args'])
plot_xyz_args = params['plot_xyz_args']
self._plot_batch(fig, axs, plot_xyz_args, x, y=y, z=z)
elif x.ndim == 5:
# If a temporal dimension is present, we divide the figure into
# multiple subfigures--one for each batch. Then, in each subfigure,
# we plot all timesteps as if they were a single batch. To
# delineate the boundary b/w batch items, we adopt the convention
# of only displaying subplot titles once per batch (above the first
# row in each batch).
batch_sz, T, *_ = x.shape
params['fig_args']['figsize'][1] *= T
fig = plt.figure(**params['fig_args'])
subfigs = fig.subfigures(nrows=batch_sz, ncols=1, hspace=0.0)
subfig_axs = [
subfig.subplots(
nrows=T, ncols=params['subplot_args']['ncols'])
for subfig in subfigs.flat
]
for i, axs in enumerate(subfig_axs):
plot_xyz_args = [
dict(params['plot_xyz_args'][i]) for _ in range(T)
]
plot_xyz_args[0]['plot_title'] = True
for args in plot_xyz_args[1:]:
args['plot_title'] = False
_x = x[i]
_y = [y[i]] * T
_z = None if z is None else [z[i]] * T
self._plot_batch(fig, axs, plot_xyz_args, _x, y=_y, z=_z)
else:
raise ValueError('Expected x to have 4 or 5 dims, but found '
f'x.shape: {x.shape}')

if show:
plt.show()
if output_path is not None:
make_dir(output_path, use_dirname=True)
plt.savefig(output_path, bbox_inches='tight', pad_inches=0.2)

plt.close(fig)

def _plot_batch(
self,
fig: 'Figure',
axs: Sequence,
plot_xyz_args: List[dict],
x: Tensor,
y: Optional[Sequence] = None,
z: Optional[Sequence] = None,
):
# (N, c, h, w) --> (N, h, w, c)
x = x.permute(0, 2, 3, 1)

Expand All @@ -109,20 +166,9 @@ def plot_batch(self,
imgs = [tf(image=img)['image'] for img in x.numpy()]
x = torch.from_numpy(np.stack(imgs))

plot_xyz_args = params['plot_xyz_args']
for i, row_axs in enumerate(axs):
if z is None:
self.plot_xyz(row_axs, x[i], y[i], **plot_xyz_args)
else:
self.plot_xyz(row_axs, x[i], y[i], z=z[i], **plot_xyz_args)

if show:
plt.show()
if output_path is not None:
make_dir(output_path, use_dirname=True)
plt.savefig(output_path, bbox_inches='tight', pad_inches=0.2)

plt.close(fig)
_z = None if z is None else z[i]
self.plot_xyz(row_axs, x[i], y[i], z=_z, **plot_xyz_args[i])

def get_channel_display_groups(
self, nb_img_channels: int
Expand Down Expand Up @@ -179,12 +225,14 @@ def get_plot_params(self, **kwargs) -> dict:
ncols = self.get_plot_ncols(**kwargs)
params = {
'fig_args': {
'constrained_layout': True,
'figsize': np.array((self.scale * ncols, self.scale * nrows)),
},
'subplot_args': {
'nrows': nrows,
'ncols': ncols,
'constrained_layout': True,
'figsize': (self.scale * ncols, self.scale * nrows),
'squeeze': False
},
'plot_xyz_args': {}
'plot_xyz_args': [{} for _ in range(nrows)]
}
return params
Original file line number Diff line number Diff line change
Expand Up @@ -322,10 +322,12 @@ def adjust_conv_channels(old_conv: nn.Conv2d,

def plot_channel_groups(axs: Iterable,
imgs: Iterable[Union[np.array, torch.Tensor]],
channel_groups: dict) -> None:
channel_groups: dict,
plot_title: bool = True) -> None:
for title, ax, img in zip(channel_groups.keys(), axs, imgs):
ax.imshow(img)
ax.set_title(title)
if plot_title:
ax.set_title(title)
ax.set_xticks([])
ax.set_yticks([])

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Callable
import unittest

import torch

from rastervision.pytorch_learner.dataset import ClassificationVisualizer


class TestClassificationVisualizer(unittest.TestCase):
def assertNoError(self, fn: Callable, msg: str = ''):
try:
fn()
except Exception:
self.fail(msg)

def test_plot_batch(self):
# w/o z
viz = ClassificationVisualizer(
class_names=['bg', 'fg'],
channel_display_groups=dict(RGB=[0, 1, 2], IR=[3]))
x = torch.randn(size=(2, 4, 256, 256))
y = torch.tensor([0, 1])
self.assertNoError(lambda: viz.plot_batch(x, y))

# w/ z
viz = ClassificationVisualizer(
class_names=['bg', 'fg'],
channel_display_groups=dict(RGB=[0, 1, 2], IR=[3]))
x = torch.randn(size=(2, 4, 256, 256))
y = torch.tensor([0, 1])
z = torch.tensor([[0.9, 0.1], [0.6, 0.4]])
self.assertNoError(lambda: viz.plot_batch(x, y, z=z))

def test_plot_batch_temporal(self):
# w/o z
viz = ClassificationVisualizer(
class_names=['bg', 'fg'],
channel_display_groups=dict(RGB=[0, 1, 2], IR=[3]))
x = torch.randn(size=(2, 3, 4, 256, 256))
y = torch.tensor([0, 1])
self.assertNoError(lambda: viz.plot_batch(x, y))

# w/ z
viz = ClassificationVisualizer(
class_names=['bg', 'fg'],
channel_display_groups=dict(RGB=[0, 1, 2], IR=[3]))
x = torch.randn(size=(2, 3, 4, 256, 256))
y = torch.tensor([0, 1])
z = torch.tensor([[0.9, 0.1], [0.6, 0.4]])
self.assertNoError(lambda: viz.plot_batch(x, y, z=z))
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import Callable
import unittest

import torch

from rastervision.core.box import Box
from rastervision.pytorch_learner.dataset import (ObjectDetectionVisualizer,
BoxList)


def random_boxlist(x, nboxes: int = 5) -> BoxList:
extent = Box(0, 0, *x.shape[-2:])
boxes = [extent.make_random_square(50) for _ in range(nboxes)]
npboxes = torch.from_numpy(Box.to_npboxes(boxes))
class_ids = torch.randint(0, 2, size=(nboxes, ))
scores = torch.rand(nboxes)
return BoxList(npboxes, class_ids=class_ids, scores=scores)


class TestClassificationVisualizer(unittest.TestCase):
def assertNoError(self, fn: Callable, msg: str = ''):
try:
fn()
except Exception:
self.fail(msg)

def test_plot_batch(self):
# w/o z
viz = ObjectDetectionVisualizer(
class_names=['bg', 'fg'],
channel_display_groups=dict(RGB=[0, 1, 2], IR=[3]))
x = torch.randn(size=(2, 4, 256, 256))
y = [random_boxlist(_x) for _x in x]
self.assertNoError(lambda: viz.plot_batch(x, y))

# w/ z
viz = ObjectDetectionVisualizer(
class_names=['bg', 'fg'],
channel_display_groups=dict(RGB=[0, 1, 2], IR=[3]))
x = torch.randn(size=(2, 4, 256, 256))
y = [random_boxlist(_x) for _x in x]
z = [random_boxlist(_x) for _x in x]
self.assertNoError(lambda: viz.plot_batch(x, y, z=z))

def test_plot_batch_temporal(self):
# w/o z
viz = ObjectDetectionVisualizer(
class_names=['bg', 'fg'],
channel_display_groups=dict(RGB=[0, 1, 2], IR=[3]))
x = torch.randn(size=(2, 3, 4, 256, 256))
y = [random_boxlist(_x) for _x in x]
self.assertNoError(lambda: viz.plot_batch(x, y))

# w/ z
viz = ObjectDetectionVisualizer(
class_names=['bg', 'fg'],
channel_display_groups=dict(RGB=[0, 1, 2], IR=[3]))
x = torch.randn(size=(2, 3, 4, 256, 256))
y = [random_boxlist(_x) for _x in x]
z = [random_boxlist(_x) for _x in x]
self.assertNoError(lambda: viz.plot_batch(x, y, z=z))

0 comments on commit 7f98eed

Please sign in to comment.