From 7f98eed284b25e76eb3e57ed615b1b362ad3f217 Mon Sep 17 00:00:00 2001 From: Adeel Hassan Date: Tue, 20 Jun 2023 14:30:47 -0400 Subject: [PATCH] generalize visualizers to 4D chips --- .../visualizer/classification_visualizer.py | 9 +- .../visualizer/object_detection_visualizer.py | 5 +- .../visualizer/regression_visualizer.py | 10 ++- .../semantic_segmentation_visualizer.py | 12 ++- .../dataset/visualizer/visualizer.py | 86 +++++++++++++++---- .../pytorch_learner/utils/utils.py | 6 +- .../test_classification_visualizer.py | 50 +++++++++++ .../test_object_detection_visualizer.py | 61 +++++++++++++ .../visualizer/test_regression_visualizer.py | 50 +++++++++++ .../test_semantic_segmentation_visualizer.py | 52 +++++++++++ .../dataset/visualizer/test_visualizer.py | 15 ++++ 11 files changed, 322 insertions(+), 34 deletions(-) create mode 100644 tests/pytorch_learner/dataset/visualizer/test_classification_visualizer.py create mode 100644 tests/pytorch_learner/dataset/visualizer/test_object_detection_visualizer.py create mode 100644 tests/pytorch_learner/dataset/visualizer/test_regression_visualizer.py create mode 100644 tests/pytorch_learner/dataset/visualizer/test_semantic_segmentation_visualizer.py diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/classification_visualizer.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/classification_visualizer.py index e182127a48..3c03934dd6 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/classification_visualizer.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/classification_visualizer.py @@ -15,7 +15,8 @@ def plot_xyz(self, axs: Sequence, x: torch.Tensor, y: int, - z: Optional[int] = None) -> None: + z: Optional[int] = None, + plot_title: bool = True) -> None: channel_groups = self.get_channel_display_groups(x.shape[1]) img_axes = axs[:-1] @@ -23,7 +24,8 @@ def plot_xyz(self, # plot image imgs = channel_groups_to_imgs(x, channel_groups) - plot_channel_groups(img_axes, imgs, channel_groups) + plot_channel_groups( + img_axes, imgs, channel_groups, plot_title=plot_title) # plot label class_names = self.class_names @@ -66,7 +68,8 @@ def plot_xyz(self, label_ax.set_xlim((0, 1)) label_ax.xaxis.grid(linestyle='--', alpha=1) label_ax.set_xlabel('Probability') - label_ax.set_title('Prediction') + if plot_title: + label_ax.set_title('Prediction') def get_plot_ncols(self, **kwargs) -> int: x = kwargs['x'] diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/object_detection_visualizer.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/object_detection_visualizer.py index df0d06f099..e6f9fbe0b1 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/object_detection_visualizer.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/object_detection_visualizer.py @@ -19,7 +19,8 @@ def plot_xyz(self, axs: Sequence, x: torch.Tensor, y: BoxList, - z: Optional[BoxList] = None) -> None: + z: Optional[BoxList] = None, + plot_title: bool = True) -> None: y = y if z is None else z channel_groups = self.get_channel_display_groups(x.shape[1]) @@ -28,4 +29,4 @@ def plot_xyz(self, imgs = channel_groups_to_imgs(x, channel_groups) imgs = [draw_boxes(img, y, class_names, class_colors) for img in imgs] - plot_channel_groups(axs, imgs, channel_groups) + plot_channel_groups(axs, imgs, channel_groups, plot_title=plot_title) diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/regression_visualizer.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/regression_visualizer.py index 1b4af2fe36..3dc076f493 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/regression_visualizer.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/regression_visualizer.py @@ -17,7 +17,8 @@ def plot_xyz(self, axs: Sequence, x: torch.Tensor, y: int, - z: Optional[int] = None) -> None: + z: Optional[int] = None, + plot_title: bool = True) -> None: channel_groups = self.get_channel_display_groups(x.shape[1]) img_axes = axs[:-1] @@ -25,7 +26,8 @@ def plot_xyz(self, # plot image imgs = channel_groups_to_imgs(x, channel_groups) - plot_channel_groups(img_axes, imgs, channel_groups) + plot_channel_groups( + img_axes, imgs, channel_groups, plot_title=plot_title) # plot label class_names = self.class_names @@ -36,8 +38,8 @@ def plot_xyz(self, y=class_names, width=y, color='lightgray', edgecolor='black') # show values on the end of bars label_ax.bar_label(bars_gt, fmt='%.3f', padding=3) - - label_ax.set_title('Ground truth') + if plot_title: + label_ax.set_title('Ground truth') else: # display targets and predictions as a grouped horizontal bar plot bar_thickness = 0.35 diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/semantic_segmentation_visualizer.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/semantic_segmentation_visualizer.py index a5666488d7..9febf75292 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/semantic_segmentation_visualizer.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/semantic_segmentation_visualizer.py @@ -17,7 +17,8 @@ def plot_xyz(self, axs: Sequence, x: torch.Tensor, y: Union[torch.Tensor, np.ndarray], - z: Optional[torch.Tensor] = None) -> None: + z: Optional[torch.Tensor] = None, + plot_title: bool = True) -> None: channel_groups = self.get_channel_display_groups(x.shape[1]) img_axes = axs[:len(channel_groups)] @@ -25,7 +26,8 @@ def plot_xyz(self, # plot image imgs = channel_groups_to_imgs(x, channel_groups) - plot_channel_groups(img_axes, imgs, channel_groups) + plot_channel_groups( + img_axes, imgs, channel_groups, plot_title=plot_title) # plot labels class_colors = self.class_colors @@ -38,7 +40,8 @@ def plot_xyz(self, label_ax.imshow( y, vmin=0, vmax=len(colors), cmap=cmap, interpolation='none') - label_ax.set_title(f'Ground truth') + if plot_title: + label_ax.set_title(f'Ground truth') label_ax.set_xticks([]) label_ax.set_yticks([]) @@ -52,7 +55,8 @@ def plot_xyz(self, vmax=len(colors), cmap=cmap, interpolation='none') - pred_ax.set_title(f'Predicted labels') + if plot_title: + pred_ax.set_title(f'Predicted labels') pred_ax.set_xticks([]) pred_ax.set_yticks([]) diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/visualizer.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/visualizer.py index d2bc46cf45..7bff864e2a 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/visualizer.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/visualizer.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: from torch.utils.data import Dataset + from matplotlib.figure import Figure class Visualizer(ABC): @@ -66,7 +67,12 @@ def __init__(self, channel_display_groups) @abstractmethod - def plot_xyz(self, axs, x: Tensor, y, z=None): + def plot_xyz(self, + axs, + x: Tensor, + y: Sequence, + z: Optional[Sequence] = None, + plot_title: bool = True): """Plot image, ground truth labels, and predicted labels. Args: @@ -95,11 +101,62 @@ def plot_batch(self, """ params = self.get_plot_params( x=x, y=y, z=z, output_path=output_path, batch_limit=batch_limit) - if params['fig_args']['nrows'] == 0: + if params['subplot_args']['nrows'] == 0: return - fig, axs = plt.subplots(**params['fig_args']) + if x.ndim == 4: + fig, axs = plt.subplots(**params['fig_args'], + **params['subplot_args']) + plot_xyz_args = params['plot_xyz_args'] + self._plot_batch(fig, axs, plot_xyz_args, x, y=y, z=z) + elif x.ndim == 5: + # If a temporal dimension is present, we divide the figure into + # multiple subfigures--one for each batch. Then, in each subfigure, + # we plot all timesteps as if they were a single batch. To + # delineate the boundary b/w batch items, we adopt the convention + # of only displaying subplot titles once per batch (above the first + # row in each batch). + batch_sz, T, *_ = x.shape + params['fig_args']['figsize'][1] *= T + fig = plt.figure(**params['fig_args']) + subfigs = fig.subfigures(nrows=batch_sz, ncols=1, hspace=0.0) + subfig_axs = [ + subfig.subplots( + nrows=T, ncols=params['subplot_args']['ncols']) + for subfig in subfigs.flat + ] + for i, axs in enumerate(subfig_axs): + plot_xyz_args = [ + dict(params['plot_xyz_args'][i]) for _ in range(T) + ] + plot_xyz_args[0]['plot_title'] = True + for args in plot_xyz_args[1:]: + args['plot_title'] = False + _x = x[i] + _y = [y[i]] * T + _z = None if z is None else [z[i]] * T + self._plot_batch(fig, axs, plot_xyz_args, _x, y=_y, z=_z) + else: + raise ValueError('Expected x to have 4 or 5 dims, but found ' + f'x.shape: {x.shape}') + if show: + plt.show() + if output_path is not None: + make_dir(output_path, use_dirname=True) + plt.savefig(output_path, bbox_inches='tight', pad_inches=0.2) + + plt.close(fig) + + def _plot_batch( + self, + fig: 'Figure', + axs: Sequence, + plot_xyz_args: List[dict], + x: Tensor, + y: Optional[Sequence] = None, + z: Optional[Sequence] = None, + ): # (N, c, h, w) --> (N, h, w, c) x = x.permute(0, 2, 3, 1) @@ -109,20 +166,9 @@ def plot_batch(self, imgs = [tf(image=img)['image'] for img in x.numpy()] x = torch.from_numpy(np.stack(imgs)) - plot_xyz_args = params['plot_xyz_args'] for i, row_axs in enumerate(axs): - if z is None: - self.plot_xyz(row_axs, x[i], y[i], **plot_xyz_args) - else: - self.plot_xyz(row_axs, x[i], y[i], z=z[i], **plot_xyz_args) - - if show: - plt.show() - if output_path is not None: - make_dir(output_path, use_dirname=True) - plt.savefig(output_path, bbox_inches='tight', pad_inches=0.2) - - plt.close(fig) + _z = None if z is None else z[i] + self.plot_xyz(row_axs, x[i], y[i], z=_z, **plot_xyz_args[i]) def get_channel_display_groups( self, nb_img_channels: int @@ -179,12 +225,14 @@ def get_plot_params(self, **kwargs) -> dict: ncols = self.get_plot_ncols(**kwargs) params = { 'fig_args': { + 'constrained_layout': True, + 'figsize': np.array((self.scale * ncols, self.scale * nrows)), + }, + 'subplot_args': { 'nrows': nrows, 'ncols': ncols, - 'constrained_layout': True, - 'figsize': (self.scale * ncols, self.scale * nrows), 'squeeze': False }, - 'plot_xyz_args': {} + 'plot_xyz_args': [{} for _ in range(nrows)] } return params diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/utils/utils.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/utils/utils.py index 3417be8c24..001e59e4dd 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/utils/utils.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/utils/utils.py @@ -322,10 +322,12 @@ def adjust_conv_channels(old_conv: nn.Conv2d, def plot_channel_groups(axs: Iterable, imgs: Iterable[Union[np.array, torch.Tensor]], - channel_groups: dict) -> None: + channel_groups: dict, + plot_title: bool = True) -> None: for title, ax, img in zip(channel_groups.keys(), axs, imgs): ax.imshow(img) - ax.set_title(title) + if plot_title: + ax.set_title(title) ax.set_xticks([]) ax.set_yticks([]) diff --git a/tests/pytorch_learner/dataset/visualizer/test_classification_visualizer.py b/tests/pytorch_learner/dataset/visualizer/test_classification_visualizer.py new file mode 100644 index 0000000000..cb49cf406f --- /dev/null +++ b/tests/pytorch_learner/dataset/visualizer/test_classification_visualizer.py @@ -0,0 +1,50 @@ +from typing import Callable +import unittest + +import torch + +from rastervision.pytorch_learner.dataset import ClassificationVisualizer + + +class TestClassificationVisualizer(unittest.TestCase): + def assertNoError(self, fn: Callable, msg: str = ''): + try: + fn() + except Exception: + self.fail(msg) + + def test_plot_batch(self): + # w/o z + viz = ClassificationVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + x = torch.randn(size=(2, 4, 256, 256)) + y = torch.tensor([0, 1]) + self.assertNoError(lambda: viz.plot_batch(x, y)) + + # w/ z + viz = ClassificationVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + x = torch.randn(size=(2, 4, 256, 256)) + y = torch.tensor([0, 1]) + z = torch.tensor([[0.9, 0.1], [0.6, 0.4]]) + self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) + + def test_plot_batch_temporal(self): + # w/o z + viz = ClassificationVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + x = torch.randn(size=(2, 3, 4, 256, 256)) + y = torch.tensor([0, 1]) + self.assertNoError(lambda: viz.plot_batch(x, y)) + + # w/ z + viz = ClassificationVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + x = torch.randn(size=(2, 3, 4, 256, 256)) + y = torch.tensor([0, 1]) + z = torch.tensor([[0.9, 0.1], [0.6, 0.4]]) + self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) diff --git a/tests/pytorch_learner/dataset/visualizer/test_object_detection_visualizer.py b/tests/pytorch_learner/dataset/visualizer/test_object_detection_visualizer.py new file mode 100644 index 0000000000..5ec6a362b4 --- /dev/null +++ b/tests/pytorch_learner/dataset/visualizer/test_object_detection_visualizer.py @@ -0,0 +1,61 @@ +from typing import Callable +import unittest + +import torch + +from rastervision.core.box import Box +from rastervision.pytorch_learner.dataset import (ObjectDetectionVisualizer, + BoxList) + + +def random_boxlist(x, nboxes: int = 5) -> BoxList: + extent = Box(0, 0, *x.shape[-2:]) + boxes = [extent.make_random_square(50) for _ in range(nboxes)] + npboxes = torch.from_numpy(Box.to_npboxes(boxes)) + class_ids = torch.randint(0, 2, size=(nboxes, )) + scores = torch.rand(nboxes) + return BoxList(npboxes, class_ids=class_ids, scores=scores) + + +class TestClassificationVisualizer(unittest.TestCase): + def assertNoError(self, fn: Callable, msg: str = ''): + try: + fn() + except Exception: + self.fail(msg) + + def test_plot_batch(self): + # w/o z + viz = ObjectDetectionVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + x = torch.randn(size=(2, 4, 256, 256)) + y = [random_boxlist(_x) for _x in x] + self.assertNoError(lambda: viz.plot_batch(x, y)) + + # w/ z + viz = ObjectDetectionVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + x = torch.randn(size=(2, 4, 256, 256)) + y = [random_boxlist(_x) for _x in x] + z = [random_boxlist(_x) for _x in x] + self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) + + def test_plot_batch_temporal(self): + # w/o z + viz = ObjectDetectionVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + x = torch.randn(size=(2, 3, 4, 256, 256)) + y = [random_boxlist(_x) for _x in x] + self.assertNoError(lambda: viz.plot_batch(x, y)) + + # w/ z + viz = ObjectDetectionVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + x = torch.randn(size=(2, 3, 4, 256, 256)) + y = [random_boxlist(_x) for _x in x] + z = [random_boxlist(_x) for _x in x] + self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) diff --git a/tests/pytorch_learner/dataset/visualizer/test_regression_visualizer.py b/tests/pytorch_learner/dataset/visualizer/test_regression_visualizer.py new file mode 100644 index 0000000000..80ed357156 --- /dev/null +++ b/tests/pytorch_learner/dataset/visualizer/test_regression_visualizer.py @@ -0,0 +1,50 @@ +from typing import Callable +import unittest + +import torch + +from rastervision.pytorch_learner.dataset import RegressionVisualizer + + +class TestClassificationVisualizer(unittest.TestCase): + def assertNoError(self, fn: Callable, msg: str = ''): + try: + fn() + except Exception: + self.fail(msg) + + def test_plot_batch(self): + # w/o z + viz = RegressionVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + x = torch.randn(size=(2, 4, 256, 256)) + y = torch.tensor([0.2, 1.3]) + self.assertNoError(lambda: viz.plot_batch(x, y)) + + # w/ z + viz = RegressionVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + x = torch.randn(size=(2, 4, 256, 256)) + y = torch.tensor([0.2, 1.3]) + z = torch.tensor([0.1, 2]) + self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) + + def test_plot_batch_temporal(self): + # w/o z + viz = RegressionVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + x = torch.randn(size=(2, 3, 4, 256, 256)) + y = torch.tensor([0.2, 1.3]) + self.assertNoError(lambda: viz.plot_batch(x, y)) + + # w/ z + viz = RegressionVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + x = torch.randn(size=(2, 3, 4, 256, 256)) + y = torch.tensor([0.2, 1.3]) + z = torch.tensor([0.1, 2]) + self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) diff --git a/tests/pytorch_learner/dataset/visualizer/test_semantic_segmentation_visualizer.py b/tests/pytorch_learner/dataset/visualizer/test_semantic_segmentation_visualizer.py new file mode 100644 index 0000000000..e7e44272eb --- /dev/null +++ b/tests/pytorch_learner/dataset/visualizer/test_semantic_segmentation_visualizer.py @@ -0,0 +1,52 @@ +from typing import Callable +import unittest + +import torch + +from rastervision.pytorch_learner.dataset import SemanticSegmentationVisualizer + + +class TestClassificationVisualizer(unittest.TestCase): + def assertNoError(self, fn: Callable, msg: str = ''): + try: + fn() + except Exception: + self.fail(msg) + + def test_plot_batch(self): + # w/o z + viz = SemanticSegmentationVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + x = torch.randn(size=(2, 4, 256, 256)) + y = (torch.randn(size=(2, 256, 256)) > 0).long() + self.assertNoError(lambda: viz.plot_batch(x, y)) + + # w/ z + viz = SemanticSegmentationVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + num_classes = 2 + x = torch.randn(size=(2, 4, 256, 256)) + y = (torch.randn(size=(2, 256, 256)) > 0).long() + z = torch.randn(size=(2, num_classes, 256, 256)).softmax(dim=-3) + self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) + + def test_plot_batch_temporal(self): + # w/o z + viz = SemanticSegmentationVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + x = torch.randn(size=(2, 3, 4, 256, 256)) + y = (torch.randn(size=(2, 256, 256)) > 0).long() + self.assertNoError(lambda: viz.plot_batch(x, y)) + + # w/ z + viz = SemanticSegmentationVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + num_classes = 2 + x = torch.randn(size=(2, 3, 4, 256, 256)) + y = (torch.randn(size=(2, 256, 256)) > 0).long() + z = torch.randn(size=(2, num_classes, 256, 256)).softmax(dim=-3) + self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) diff --git a/tests/pytorch_learner/dataset/visualizer/test_visualizer.py b/tests/pytorch_learner/dataset/visualizer/test_visualizer.py index faf7b12dd0..6c4cff9f69 100644 --- a/tests/pytorch_learner/dataset/visualizer/test_visualizer.py +++ b/tests/pytorch_learner/dataset/visualizer/test_visualizer.py @@ -1,5 +1,7 @@ import unittest +import torch + from rastervision.pytorch_learner.dataset.visualizer import SemanticSegmentationVisualizer @@ -12,3 +14,16 @@ def test_get_batch_from_empty_dataset(self): with self.assertRaises(ValueError): viz.get_batch(ds) + + def test_plot_batch_invalid_x_shape(self): + viz = SemanticSegmentationVisualizer(class_names=['bg', 'fg']) + + y = (torch.randn(size=(2, 256, 256)) > 0).long() + + x = torch.randn(size=(2, 1, 3, 4, 256, 256)) + with self.assertRaises(ValueError): + viz.plot_batch(x, y) + + x = torch.randn(size=(4, 256, 256)) + with self.assertRaises(ValueError): + viz.plot_batch(x, y)