From 2fd48035c831489ee8fb7cd09b6126045976e883 Mon Sep 17 00:00:00 2001
From: MengzhangLI
Date: Mon, 18 Jul 2022 12:55:24 +0800
Subject: [PATCH 01/12] [Feature] Add SegVisualizer

---
 mmseg/engine/hooks/__init__.py                |   4 +
 mmseg/engine/hooks/visualization_hook.py      | 101 ++++++++++
 mmseg/visualization/__init__.py               |   4 +
 mmseg/visualization/local_visualizer.py       | 172 ++++++++++++++++
 tests/test_engine/test_visualization_hook.py  |  62 ++++++
 .../test_local_visualizer.py                  | 187 ++++++++++++++++++
 tools/browse_dataset.py                       | 166 +++-------------
 tools/test.py                                 |  32 +++
 8 files changed, 591 insertions(+), 137 deletions(-)
 create mode 100644 mmseg/engine/hooks/__init__.py
 create mode 100644 mmseg/engine/hooks/visualization_hook.py
 create mode 100644 mmseg/visualization/__init__.py
 create mode 100644 mmseg/visualization/local_visualizer.py
 create mode 100644 tests/test_engine/test_visualization_hook.py
 create mode 100644 tests/test_visualization/test_local_visualizer.py

diff --git a/mmseg/engine/hooks/__init__.py b/mmseg/engine/hooks/__init__.py
new file mode 100644
index 0000000000..c6048088a7
--- /dev/null
+++ b/mmseg/engine/hooks/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .visualization_hook import SegVisualizationHook
+
+__all__ = ['SegVisualizationHook']
diff --git a/mmseg/engine/hooks/visualization_hook.py b/mmseg/engine/hooks/visualization_hook.py
new file mode 100644
index 0000000000..40725cda68
--- /dev/null
+++ b/mmseg/engine/hooks/visualization_hook.py
@@ -0,0 +1,101 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import warnings
+from typing import Sequence
+
+import mmcv
+from mmengine.hooks import Hook
+from mmengine.runner import Runner
+
+from mmseg.data import SegDataSample
+from mmseg.registry import HOOKS
+from mmseg.visualization import SegLocalVisualizer
+
+
+@HOOKS.register_module()
+class SegVisualizationHook(Hook):
+    """Segmentation Visualization Hook. Used to visualize prediction
+    results during validation and testing.
+
+    In the testing phase:
+
+    1. If ``show`` is True, only the prediction results are visualized
+       without storing any data, so ``vis_backends`` needs to be
+       excluded.
+
+    Args:
+        draw (bool): Whether to draw prediction results. If False,
+            nothing will be drawn. Defaults to False.
+        interval (int): The interval of visualization. Defaults to 50.
+        show (bool): Whether to display the drawn image. Defaults to False.
+        wait_time (float): The interval of show (s). Defaults to 0.
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmcv.fileio.FileClient` for details.
+            Defaults to ``dict(backend='disk')``.
+    """
+
+    def __init__(self,
+                 draw: bool = False,
+                 interval: int = 50,
+                 show: bool = False,
+                 wait_time: float = 0.,
+                 file_client_args: dict = dict(backend='disk')):
+        self._visualizer: SegLocalVisualizer = \
+            SegLocalVisualizer.get_current_instance()
+        self.interval = interval
+        self.show = show
+        if self.show:
+            # No need to think about vis backends.
+            self._visualizer._vis_backends = {}
+            warnings.warn('show is True: only the prediction results '
+                          'will be visualized, no data will be stored, '
+                          'so vis_backends has been excluded.')
+
+        self.wait_time = wait_time
+        self.file_client_args = file_client_args.copy()
+        self.file_client = None
+        self.draw = draw
+        if not self.draw:
+            warnings.warn('draw is False: the visualization hook will '
+                          'not take effect, and the results will NOT '
+                          'be visualized or stored.')
+
+    def after_iter(self,
+                   runner: Runner,
+                   batch_idx: int,
+                   data_batch: Sequence[dict],
+                   outputs: Sequence[SegDataSample],
+                   mode: str = 'val') -> None:
+        """Run after every ``self.interval`` validation iterations.
+
+        Args:
+            runner (:obj:`Runner`): The runner of the validation process.
+            batch_idx (int): The index of the current batch in the val loop.
+            data_batch (Sequence[dict]): Data from dataloader.
+            outputs (Sequence[:obj:`SegDataSample`]): Outputs from model.
+            mode (str): Current mode of runner. Defaults to 'val'.
+        """
+        if self.draw is False or mode == 'train':
+            return
+
+        if self.file_client is None:
+            self.file_client = mmcv.FileClient(**self.file_client_args)
+
+        if self.every_n_inner_iters(batch_idx, self.interval):
+            for input_data, output in zip(data_batch, outputs):
+                img_path = input_data['data_sample'].img_path
+                img_bytes = self.file_client.get(img_path)
+                img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
+                window_name = f'{mode}_{osp.basename(img_path)}'
+
+                gt_sample = input_data['data_sample']
+                self._visualizer.add_datasample(
+                    window_name,
+                    img,
+                    gt_sample=gt_sample,
+                    pred_sample=output,
+                    show=self.show,
+                    wait_time=self.wait_time,
+                    step=runner.iter)
diff --git a/mmseg/visualization/__init__.py b/mmseg/visualization/__init__.py
new file mode 100644
index 0000000000..8cbb211e52
--- /dev/null
+++ b/mmseg/visualization/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .local_visualizer import SegLocalVisualizer
+
+__all__ = ['SegLocalVisualizer']
diff --git a/mmseg/visualization/local_visualizer.py b/mmseg/visualization/local_visualizer.py
new file mode 100644
index 0000000000..66f2fafd44
--- /dev/null
+++ b/mmseg/visualization/local_visualizer.py
@@ -0,0 +1,172 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+from mmengine import Visualizer
+from mmengine.data import PixelData
+from mmengine.dist import master_only
+
+from mmseg.data import SegDataSample
+from mmseg.registry import VISUALIZERS
+
+
+@VISUALIZERS.register_module()
+class SegLocalVisualizer(Visualizer):
+    """MMSegmentation Local Visualizer.
+
+    Args:
+        name (str): Name of the instance. Defaults to 'visualizer'.
+        image (np.ndarray, optional): The original image to draw. The
+            format should be RGB. Defaults to None.
+        vis_backends (list, optional): Visual backend config list.
+            Defaults to None.
+        save_dir (str, optional): Save file dir for all storage backends.
+            If it is None, the backend storage will not save any data.
+        alpha (int, float): The transparency of segmentation mask.
+            Defaults to 0.8.
+
+    Examples:
+        >>> import numpy as np
+        >>> import torch
+        >>> from mmengine.data import PixelData
+        >>> from mmseg.data import SegDataSample
+        >>> from mmseg.visualization import SegLocalVisualizer
+
+        >>> seg_local_visualizer = SegLocalVisualizer()
+        >>> image = np.random.randint(0, 256,
+        ...                           size=(10, 12, 3)).astype('uint8')
+        >>> gt_sem_seg_data = dict(data=torch.randint(0, 2, (1, 10, 12)))
+        >>> gt_sem_seg = PixelData(**gt_sem_seg_data)
+        >>> gt_seg_data_sample = SegDataSample()
+        >>> gt_seg_data_sample.gt_sem_seg = gt_sem_seg
+        >>> seg_local_visualizer.dataset_meta = dict(
+        >>>     classes=('background', 'foreground'),
+        >>>     palette=[[120, 120, 120], [6, 230, 230]])
+        >>> seg_local_visualizer.add_datasample('out_file_name',
+        ...                                     image, gt_seg_data_sample)
+        >>> seg_local_visualizer.add_datasample(
+        ...     
'out_file_name', image, gt_seg_data_sample, + ... show=True) + """ + + def __init__(self, + name: str = 'visualizer', + image: Optional[np.ndarray] = None, + vis_backends: Optional[Dict] = None, + save_dir: Optional[str] = None, + alpha: float = 0.8, + **kwargs): + super().__init__(name, image, vis_backends, save_dir, **kwargs) + self.alpha = alpha + # Set default value. When calling + # `SegLocalVisualizer().dataset_meta=xxx`, + # it will override the default value. + self.dataset_meta = {} + + def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData, + classes: Optional[Tuple[str]], + palette: Optional[List[List[int]]]) -> np.ndarray: + """Draw semantic seg of GT or prediction. + + Args: + image (np.ndarray): The image to draw. + sem_seg (:obj:`PixelData`): Data structure for + pixel-level annotations or predictions. + classes (Tuple[str], optional): Category information. + palette (List[List[int]], optional): The palette of + segmentation map. + + Returns: + np.ndarray: the drawn image which channel is RGB. + """ + num_classes = len(classes) + + sem_seg = sem_seg.data + ids = np.unique(sem_seg)[::-1] + legal_indices = ids < num_classes + ids = ids[legal_indices] + labels = np.array(ids, dtype=np.int64) + + colors = [palette[label] for label in labels] + + self.set_image(image) + + # draw semantic masks + for label, color in zip(labels, colors): + self.draw_binary_masks( + sem_seg == label, colors=[color], alphas=self.alpha) + + return self.get_image() + + @master_only + def add_datasample(self, + name: str, + image: np.ndarray, + gt_sample: Optional[SegDataSample] = None, + pred_sample: Optional[SegDataSample] = None, + draw_gt: bool = True, + draw_pred: bool = True, + show: bool = False, + wait_time: float = 0, + step: int = 0) -> None: + """Draw datasample and save to all backends. + + - If GT and prediction are plotted at the same time, they are + displayed in a stitched image where the left image is the + ground truth and the right image is the prediction. + - If ``show`` is True, all storage backends are ignored, and + the images will be displayed in a local window. + + Args: + name (str): The image identifier. + image (np.ndarray): The image to draw. + gt_sample (:obj:`SegDataSample`, optional): GT SegDataSample. + Defaults to None. + pred_sample (:obj:`SegDataSample`, optional): Prediction + SegDataSample. Defaults to None. + draw_gt (bool): Whether to draw GT SegDataSample. Default to True. + draw_pred (bool): Whether to draw Prediction SegDataSample. + Defaults to True. + show (bool): Whether to display the drawn image. Default to False. + wait_time (float): The interval of show (s). Defaults to 0. + step (int): Global step value to record. Defaults to 0. + """ + classes = self.dataset_meta.get('classes', None) + palette = self.dataset_meta.get('palette', None) + + gt_img_data = None + pred_img_data = None + + if draw_gt and gt_sample is not None: + gt_img_data = image + if 'gt_sem_seg' in gt_sample: + assert classes is not None, 'class information is ' \ + 'not provided when ' \ + 'visualizing semantic ' \ + 'segmentation results.' + gt_img_data = self._draw_sem_seg(gt_img_data, + gt_sample.gt_sem_seg, classes, + palette) + + if draw_pred and pred_sample is not None: + pred_img_data = image + if 'pred_sem_seg' in pred_sample: + assert classes is not None, 'class information is ' \ + 'not provided when ' \ + 'visualizing semantic ' \ + 'segmentation results.' 
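
As an aside on ``_draw_sem_seg`` above: the per-label loop amounts to an
alpha blend of each class color over the pixels carrying that label. A
minimal, self-contained NumPy sketch of the same operation (the helper name
``overlay_sem_seg`` is hypothetical, and ``draw_binary_masks`` may differ in
blending details):

import numpy as np

def overlay_sem_seg(image, sem_seg, palette, alpha=0.8):
    # Blend each class color over the pixels carrying that label,
    # mirroring the draw_binary_masks loop in _draw_sem_seg.
    out = image.astype(np.float32)
    for label, color in enumerate(palette):
        mask = sem_seg == label
        out[mask] = (1 - alpha) * out[mask] + alpha * np.asarray(color, np.float32)
    return out.astype(np.uint8)

image = np.random.randint(0, 256, (10, 12, 3), dtype=np.uint8)
sem_seg = np.random.randint(0, 2, (10, 12))
drawn = overlay_sem_seg(image, sem_seg, [[120, 120, 120], [6, 230, 230]])
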
+ pred_img_data = self._draw_sem_seg(pred_img_data, + pred_sample.pred_sem_seg, + classes, palette) + + if gt_img_data is not None and pred_img_data is not None: + drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1) + elif gt_img_data is not None: + drawn_img = gt_img_data + else: + drawn_img = pred_img_data + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + else: + self.add_image(name, drawn_img, step) diff --git a/tests/test_engine/test_visualization_hook.py b/tests/test_engine/test_visualization_hook.py new file mode 100644 index 0000000000..f5b6ba352d --- /dev/null +++ b/tests/test_engine/test_visualization_hook.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase +from unittest.mock import Mock + +import torch +from mmengine.data import PixelData + +from mmseg.data import SegDataSample +from mmseg.engine.hooks import SegVisualizationHook +from mmseg.visualization import SegLocalVisualizer + + +class TestVisualizationHook(TestCase): + + def setUp(self) -> None: + + h = 288 + w = 512 + num_class = 2 + + SegLocalVisualizer.get_instance('visualizer') + SegLocalVisualizer.dataset_meta = dict( + classes=('background', 'foreground'), + palette=[[120, 120, 120], [6, 230, 230]]) + + data_sample = SegDataSample() + data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'}) + self.data_batch = [{'data_sample': data_sample}] * 2 + + pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w))) + pred_sem_seg = PixelData(**pred_sem_seg_data) + pred_seg_data_sample = SegDataSample() + pred_seg_data_sample.pred_sem_seg = pred_sem_seg + self.outputs = [pred_seg_data_sample] * 2 + + def test_after_iter(self): + runner = Mock() + runner.iter = 1 + hook = SegVisualizationHook(draw=True, interval=1) + hook._after_iter( + runner, 1, self.data_batch, self.outputs, mode='train') + hook._after_iter(runner, 1, self.data_batch, self.outputs, mode='val') + hook._after_iter(runner, 1, self.data_batch, self.outputs, mode='test') + + def test_after_val_iter(self): + runner = Mock() + runner.iter = 2 + hook = SegVisualizationHook(interval=1) + hook.after_val_iter(runner, 1, self.data_batch, self.outputs) + + hook = SegVisualizationHook(draw=True, interval=1) + hook.after_val_iter(runner, 1, self.data_batch, self.outputs) + + hook = SegVisualizationHook( + draw=True, interval=1, show=True, wait_time=1) + hook.after_val_iter(runner, 1, self.data_batch, self.outputs) + + def test_after_test_iter(self): + runner = Mock() + runner.iter = 3 + hook = SegVisualizationHook(draw=True, interval=1) + hook.after_iter(runner, 1, self.data_batch, self.outputs) diff --git a/tests/test_visualization/test_local_visualizer.py b/tests/test_visualization/test_local_visualizer.py new file mode 100644 index 0000000000..b1c20edfbd --- /dev/null +++ b/tests/test_visualization/test_local_visualizer.py @@ -0,0 +1,187 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
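
The hook test above and the visualizer test that follows build their sample
inputs the same way; a standalone sketch of the pattern, with values
mirroring ``setUp``:

import torch
from mmengine.data import PixelData
from mmseg.data import SegDataSample

# Wrap an integer label map in PixelData, attach it to a SegDataSample,
# and record metainfo such as the image path.
gt_sem_seg = PixelData(data=torch.randint(0, 2, (1, 10, 12)))
data_sample = SegDataSample()
data_sample.gt_sem_seg = gt_sem_seg
data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'})
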
+import os +import os.path as osp +from unittest import TestCase + +import cv2 +import mmcv +import numpy as np +import torch +from mmengine.data import PixelData + +from mmseg.data import SegDataSample +from mmseg.visualization import SegLocalVisualizer + + +class TestSegLocalVisualizer(TestCase): + + def test_add_datasample(self): + h = 10 + w = 12 + num_class = 2 + out_file = 'out_file' + + image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8') + + # test gt_sem_seg + gt_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w))) + gt_sem_seg = PixelData(**gt_sem_seg_data) + + gt_seg_data_sample = SegDataSample() + gt_seg_data_sample.gt_sem_seg = gt_sem_seg + + seg_local_visualizer = SegLocalVisualizer( + vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir') + seg_local_visualizer.dataset_meta = dict( + classes=('background', 'foreground'), + palette=[[120, 120, 120], [6, 230, 230]]) + seg_local_visualizer.add_datasample(out_file, image, + gt_seg_data_sample) + + # test out_file + seg_local_visualizer.add_datasample(out_file, image, + gt_seg_data_sample) + + assert os.path.exists( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png')) + drawn_img = cv2.imread( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png')) + assert drawn_img.shape == (h, w, 3) + + os.remove( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png')) + os.rmdir('temp_dir' + '/vis_data/vis_image') + + # test gt_instances and pred_instances + pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w))) + pred_sem_seg = PixelData(**pred_sem_seg_data) + + pred_seg_data_sample = SegDataSample() + pred_seg_data_sample.pred_sem_seg = pred_sem_seg + + seg_local_visualizer.add_datasample(out_file, image, + gt_seg_data_sample, + pred_seg_data_sample) + self._assert_image_and_shape( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'), + (h, w * 2, 3)) + + seg_local_visualizer.add_datasample( + out_file, + image, + gt_seg_data_sample, + pred_seg_data_sample, + draw_gt=False) + self._assert_image_and_shape( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'), + (h, w, 3)) + + seg_local_visualizer.add_datasample( + out_file, + image, + gt_seg_data_sample, + pred_seg_data_sample, + draw_pred=False) + self._assert_image_and_shape( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'), + (h, w, 3)) + os.rmdir('temp_dir/vis_data') + os.rmdir('temp_dir') + + def test_cityscapes_add_datasample(self): + h = 128 + w = 256 + num_class = 19 + out_file = 'out_file_cityscapes' + + image = mmcv.imread( + osp.join( + osp.dirname(__file__), + '../data/pseudo_cityscapes_dataset/leftImg8bit/frankfurt_000000_000294_leftImg8bit.png' # noqa + ), + 'color') + sem_seg = mmcv.imread( + osp.join( + osp.dirname(__file__), + '../data/pseudo_cityscapes_dataset/gtFine/frankfurt_000000_000294_gtFine_labelTrainIds.png' # noqa + ), + 'unchanged') + sem_seg = torch.unsqueeze(torch.from_numpy(sem_seg), 0) + gt_sem_seg_data = dict(data=sem_seg) + gt_sem_seg = PixelData(**gt_sem_seg_data) + + gt_seg_data_sample = SegDataSample() + gt_seg_data_sample.gt_sem_seg = gt_sem_seg + + seg_local_visualizer = SegLocalVisualizer( + vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir') + seg_local_visualizer.dataset_meta = dict( + classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', + 'traffic light', 'traffic sign', 'vegetation', 'terrain', + 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', + 'motorcycle', 'bicycle'), + 
palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70], + [102, 102, 156], [190, 153, 153], [153, 153, 153], + [250, 170, 30], [220, 220, 0], [107, 142, 35], + [152, 251, 152], [70, 130, 180], [220, 20, 60], + [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], + [0, 80, 100], [0, 0, 230], [119, 11, 32]]) + seg_local_visualizer.add_datasample(out_file, image, + gt_seg_data_sample) + + # test out_file + seg_local_visualizer.add_datasample(out_file, image, + gt_seg_data_sample) + assert os.path.exists( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png')) + drawn_img = cv2.imread( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png')) + assert drawn_img.shape == (h, w, 3) + + os.remove( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png')) + os.rmdir('temp_dir/vis_data/vis_image') + + # test gt_instances and pred_instances + pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w))) + pred_sem_seg = PixelData(**pred_sem_seg_data) + + pred_seg_data_sample = SegDataSample() + pred_seg_data_sample.pred_sem_seg = pred_sem_seg + + seg_local_visualizer.add_datasample(out_file, image, + gt_seg_data_sample, + pred_seg_data_sample) + self._assert_image_and_shape( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'), + (h, w * 2, 3)) + + seg_local_visualizer.add_datasample( + out_file, + image, + gt_seg_data_sample, + pred_seg_data_sample, + draw_gt=False) + self._assert_image_and_shape( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'), + (h, w, 3)) + + seg_local_visualizer.add_datasample( + out_file, + image, + gt_seg_data_sample, + pred_seg_data_sample, + draw_pred=False) + self._assert_image_and_shape( + osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'), + (h, w, 3)) + + os.rmdir('temp_dir/vis_data') + os.rmdir('temp_dir') + + def _assert_image_and_shape(self, out_file, out_shape): + assert os.path.exists(out_file) + drawn_img = cv2.imread(out_file) + assert drawn_img.shape == out_shape + os.remove(out_file) + os.rmdir('temp_dir/vis_data/vis_image') diff --git a/tools/browse_dataset.py b/tools/browse_dataset.py index 64fe695859..b5e0dc978d 100644 --- a/tools/browse_dataset.py +++ b/tools/browse_dataset.py @@ -1,48 +1,29 @@ # Copyright (c) OpenMMLab. All rights reserved. 
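
The path assertions in these tests pin down the on-disk layout that
``LocalVisBackend`` uses for ``add_image``. A small sketch of that
convention (the helper name is hypothetical and the layout assumes a POSIX
path separator):

import os.path as osp

def vis_image_path(save_dir, name, step=0):
    # {save_dir}/vis_data/vis_image/{name}_{step}.png, as asserted above.
    return osp.join(save_dir, 'vis_data', 'vis_image', f'{name}_{step}.png')

assert vis_image_path('temp_dir', 'out_file') == \
    'temp_dir/vis_data/vis_image/out_file_0.png'
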
import argparse -import os -import warnings -from pathlib import Path +import os.path as osp import mmcv -import numpy as np from mmcv import Config, DictAction from mmseg.datasets import DATASETS +from mmseg.registry import VISUALIZERS +from mmseg.utils import register_all_modules def parse_args(): parser = argparse.ArgumentParser(description='Browse a dataset') parser.add_argument('config', help='train config file path') - parser.add_argument( - '--show-origin', - default=False, - action='store_true', - help='if True, omit all augmentation in pipeline,' - ' show origin image and seg map') - parser.add_argument( - '--skip-type', - type=str, - nargs='+', - default=['DefaultFormatBundle', 'Normalize', 'Collect'], - help='skip some useless pipeline,if `show-origin` is true, ' - 'all pipeline except `Load` will be skipped') parser.add_argument( '--output-dir', - default='./output', + default=None, type=str, help='If there is no display interface, you can save it') - parser.add_argument('--show', default=False, action='store_true') + parser.add_argument('--not-show', default=False, action='store_true') parser.add_argument( '--show-interval', - type=int, - default=999, - help='the interval of show (ms)') - parser.add_argument( - '--opacity', type=float, - default=0.5, - help='the opacity of semantic map') + default=2, + help='the interval of show (s)') parser.add_argument( '--cfg-options', nargs='+', @@ -57,124 +38,35 @@ def parse_args(): return args -def imshow_semantic(img, - seg, - class_names, - palette=None, - win_name='', - show=False, - wait_time=0, - out_file=None, - opacity=0.5): - """Draw `result` over `img`. - - Args: - img (str or Tensor): The image to be displayed. - seg (Tensor): The semantic segmentation results to draw over - `img`. - class_names (list[str]): Names of each classes. - palette (list[list[int]]] | np.ndarray | None): The palette of - segmentation map. If None is given, random palette will be - generated. Default: None - win_name (str): The window name. - wait_time (int): Value of waitKey param. - Default: 0. - show (bool): Whether to show the image. - Default: False. - out_file (str or None): The filename to write the image. - Default: None. - opacity(float): Opacity of painted segmentation map. - Default 0.5. - Must be in (0, 1] range. 
- Returns: - img (Tensor): Only if not `show` or `out_file` - """ - img = mmcv.imread(img) - img = img.copy() - if palette is None: - palette = np.random.randint(0, 255, size=(len(class_names), 3)) - palette = np.array(palette) - assert palette.shape[0] == len(class_names) - assert palette.shape[1] == 3 - assert len(palette.shape) == 2 - assert 0 < opacity <= 1.0 - color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) - for label, color in enumerate(palette): - color_seg[seg == label, :] = color - # convert to BGR - color_seg = color_seg[..., ::-1] - - img = img * (1 - opacity) + color_seg * opacity - img = img.astype(np.uint8) - # if out_file specified, do not show image in window - if out_file is not None: - show = False - - if show: - mmcv.imshow(img, win_name, wait_time) - if out_file is not None: - mmcv.imwrite(img, out_file) +def main(): + args = parse_args() + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) - if not (show or out_file): - warnings.warn('show==False and out_file is not specified, only ' - 'result image will be returned') - return img + # register all modules in mmseg into the registries + register_all_modules() + dataset = DATASETS.build(cfg.train_dataloader.dataset) -def _retrieve_data_cfg(_data_cfg, skip_type, show_origin): - if show_origin is True: - # only keep pipeline of Loading data and ann - _data_cfg['pipeline'] = [ - x for x in _data_cfg.pipeline if 'Load' in x['type'] - ] - else: - _data_cfg['pipeline'] = [ - x for x in _data_cfg.pipeline if x['type'] not in skip_type - ] + visualizer = VISUALIZERS.build(cfg.visualizer) + visualizer.dataset_meta = dataset.METAINFO + progress_bar = mmcv.ProgressBar(len(dataset)) + for item in dataset: + img = item['inputs'].permute(1, 2, 0).numpy() + data_sample = item['data_sample'].numpy() + img_path = osp.basename(item['data_sample'].img_path) -def retrieve_data_cfg(config_path, skip_type, cfg_options, show_origin=False): - cfg = Config.fromfile(config_path) - if cfg_options is not None: - cfg.merge_from_dict(cfg_options) - train_data_cfg = cfg.data.train - if isinstance(train_data_cfg, list): - for _data_cfg in train_data_cfg: - while 'dataset' in _data_cfg and _data_cfg[ - 'type'] != 'MultiImageMixDataset': - _data_cfg = _data_cfg['dataset'] - if 'pipeline' in _data_cfg: - _retrieve_data_cfg(_data_cfg, skip_type, show_origin) - else: - raise ValueError - else: - while 'dataset' in train_data_cfg and train_data_cfg[ - 'type'] != 'MultiImageMixDataset': - train_data_cfg = train_data_cfg['dataset'] - _retrieve_data_cfg(train_data_cfg, skip_type, show_origin) - return cfg + img = img[..., [2, 1, 0]] # bgr to rgb + visualizer.add_datasample( + osp.basename(img_path), + img, + data_sample, + show=not args.not_show, + wait_time=args.show_interval) -def main(): - args = parse_args() - cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options, - args.show_origin) - dataset = DATASETS.build(cfg.data.train) - progress_bar = mmcv.ProgressBar(len(dataset)) - for item in dataset: - filename = os.path.join(args.output_dir, - Path(item['filename']).name - ) if args.output_dir is not None else None - imshow_semantic( - item['img'], - item['gt_semantic_seg'], - dataset.CLASSES, - dataset.PALETTE, - show=args.show, - wait_time=args.show_interval, - out_file=filename, - opacity=args.opacity, - ) progress_bar.update() diff --git a/tools/test.py b/tools/test.py index e4e1b5d4d7..59f7c7095b 100644 --- a/tools/test.py +++ b/tools/test.py @@ -19,6 +19,15 
@@ def parse_args(): '--work-dir', help=('if specified, the evaluation metric results will be dumped' 'into the directory as json')) + parser.add_argument( + '--show', action='store_true', help='show prediction results') + parser.add_argument( + '--show-dir', + help='directory where painted images will be saved. ' + 'If specified, it will be automatically saved ' + 'to the work_dir/timestamp/show_dir') + parser.add_argument( + '--wait-time', type=float, default=2, help='the interval of show (s)') parser.add_argument( '--cfg-options', nargs='+', @@ -42,6 +51,26 @@ def parse_args(): return args +def trigger_visualization_hook(cfg, args): + default_hooks = cfg.default_hooks + if 'visualization' in default_hooks: + visualization_hook = default_hooks['visualization'] + # Turn on visualization + visualization_hook['draw'] = True + if args.show: + visualization_hook['show'] = True + visualization_hook['wait_time'] = args.wait_time + if args.show_dir: + visualization_hook['test_out_dir'] = args.show_dir + else: + raise RuntimeError( + 'VisualizationHook must be included in default_hooks.' + 'refer to usage ' + '"visualization=dict(type=\'VisualizationHook\')"') + + return cfg + + def main(): args = parse_args() @@ -66,6 +95,9 @@ def main(): cfg.load_from = args.checkpoint + if args.show or args.show_dir: + cfg = trigger_visualization_hook(cfg, args) + # build the runner from config runner = Runner.from_cfg(cfg) From 761eccd8964220e27efdc820e71a777b956e52b3 Mon Sep 17 00:00:00 2001 From: MengzhangLI Date: Mon, 18 Jul 2022 14:47:40 +0800 Subject: [PATCH 02/12] change name to visualizer_example --- mmseg/visualization/local_visualizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmseg/visualization/local_visualizer.py b/mmseg/visualization/local_visualizer.py index 66f2fafd44..58d1d2b866 100644 --- a/mmseg/visualization/local_visualizer.py +++ b/mmseg/visualization/local_visualizer.py @@ -42,7 +42,7 @@ class SegLocalVisualizer(Visualizer): >>> seg_local_visualizer.dataset_meta = dict( >>> classes=('background', 'foreground'), >>> palette=[[120, 120, 120], [6, 230, 230]]) - >>> seg_local_visualizer.add_datasample('out_file_name', + >>> seg_local_visualizer.add_datasample('visualizer_example', ... image, gt_seg_data_sample) >>> seg_local_visualizer.add_datasample( ... 'out_file_name', image, gt_seg_data_sample, From 769c23ce6f5df19d308a7b26847e1735687b993f Mon Sep 17 00:00:00 2001 From: xiexinch Date: Mon, 18 Jul 2022 19:42:27 +0800 Subject: [PATCH 03/12] fix inference api --- demo/image_demo.py | 4 +++- mmseg/apis/inference.py | 53 ++++++++--------------------------------- 2 files changed, 13 insertions(+), 44 deletions(-) diff --git a/demo/image_demo.py b/demo/image_demo.py index 5cde1ac9cd..fc5496b7cf 100644 --- a/demo/image_demo.py +++ b/demo/image_demo.py @@ -2,7 +2,7 @@ from argparse import ArgumentParser from mmseg.apis import inference_model, init_model, show_result_pyplot -from mmseg.utils import get_palette +from mmseg.utils import get_palette, register_all_modules def main(): @@ -23,6 +23,8 @@ def main(): help='Opacity of painted segmentation map. 
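
To see how the new ``--show``/``--show-dir`` flags reach the hook, here is a
minimal sketch exercising ``trigger_visualization_hook`` from tools/test.py
above, assuming that function is in scope; the cfg is the smallest layout it
accepts:

from argparse import Namespace
from mmengine import Config

cfg = Config(
    dict(default_hooks=dict(
        visualization=dict(type='SegVisualizationHook', draw=False))))
args = Namespace(show=True, wait_time=2.0, show_dir='work_dirs/show')
cfg = trigger_visualization_hook(cfg, args)
assert cfg.default_hooks.visualization.draw is True
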
In (0, 1] range.') args = parser.parse_args() + register_all_modules() + # build the model from a config file and a checkpoint file model = init_model(args.config, args.checkpoint, device=args.device) # test a single image diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index bdbae1d0cc..f42d5c9070 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -2,11 +2,10 @@ import matplotlib.pyplot as plt import mmcv import torch -from mmcv.parallel import collate, scatter from mmcv.runner import load_checkpoint from mmseg.datasets.transforms import Compose -from mmseg.models import build_segmentor +from mmseg.registry import MODELS def init_model(config, checkpoint=None, device='cuda:0'): @@ -29,7 +28,7 @@ def init_model(config, checkpoint=None, device='cuda:0'): 'but got {}'.format(type(config))) config.model.pretrained = None config.model.train_cfg = None - model = build_segmentor(config.model, test_cfg=config.get('test_cfg')) + model = MODELS.build(config.model) if checkpoint is not None: checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') model.CLASSES = checkpoint['meta']['CLASSES'] @@ -40,33 +39,6 @@ def init_model(config, checkpoint=None, device='cuda:0'): return model -class LoadImage: - """A simple pipeline to load image.""" - - def __call__(self, results): - """Call function to load images into results. - - Args: - results (dict): A result dict contains the file name - of the image to be read. - - Returns: - dict: ``results`` will be returned containing loaded image. - """ - - if isinstance(results['img'], str): - results['filename'] = results['img'] - results['ori_filename'] = results['img'] - else: - results['filename'] = None - results['ori_filename'] = None - img = mmcv.imread(results['img']) - results['img'] = img - results['img_shape'] = img.shape - results['ori_shape'] = img.shape - return results - - def inference_model(model, img): """Inference image(s) with the segmentor. @@ -79,23 +51,18 @@ def inference_model(model, img): (list[Tensor]): The segmentation result. 
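
End to end, the refactored API is used as in demo/image_demo.py; a sketch
with hypothetical config/checkpoint paths:

from mmseg.apis import inference_model, init_model
from mmseg.utils import register_all_modules

config_file = 'configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py'  # assumed path
checkpoint_file = 'checkpoints/pspnet_r50.pth'  # assumed path

register_all_modules()
model = init_model(config_file, checkpoint_file, device='cuda:0')
result = inference_model(model, 'demo/demo.png')
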
""" cfg = model.cfg - device = next(model.parameters()).device # model device - # build the data pipeline - test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] - test_pipeline = Compose(test_pipeline) + if dict(type='LoadAnnotations') in cfg.test_pipeline: + cfg.test_pipeline.remove(dict(type='LoadAnnotations')) + pipeline = Compose(cfg.test_pipeline) + # prepare data - data = dict(img=img) - data = test_pipeline(data) - data = collate([data], samples_per_gpu=1) - if next(model.parameters()).is_cuda: - # scatter to specified GPU - data = scatter(data, [device])[0] - else: - data['img_metas'] = [i.data[0] for i in data['img_metas']] + data, data_samples = model.data_preprocessor( + [pipeline(dict(img_path=img))], False) # forward the model with torch.no_grad(): - result = model(return_loss=False, rescale=True, **data) + result = model.forward(data, data_samples, mode='predict') + return result From e86b5ee50dc35c2e5e4b66850bc9fe3b34bdbb9e Mon Sep 17 00:00:00 2001 From: xiexinch Date: Thu, 21 Jul 2022 17:30:08 +0800 Subject: [PATCH 04/12] fix video demo and refine inference api --- demo/image_demo.py | 21 +++++-- demo/inference_demo.ipynb | 20 +++++-- demo/video_demo.py | 17 +++--- mmseg/apis/inference.py | 117 ++++++++++++++++++++++++++++---------- 4 files changed, 126 insertions(+), 49 deletions(-) diff --git a/demo/image_demo.py b/demo/image_demo.py index fc5496b7cf..a52e37c83e 100644 --- a/demo/image_demo.py +++ b/demo/image_demo.py @@ -1,8 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from argparse import ArgumentParser +from mmengine.utils import revert_sync_batchnorm + from mmseg.apis import inference_model, init_model, show_result_pyplot -from mmseg.utils import get_palette, register_all_modules +from mmseg.utils import register_all_modules def main(): @@ -13,20 +15,24 @@ def main(): parser.add_argument( '--device', default='cuda:0', help='Device used for inference') parser.add_argument( - '--palette', - default='cityscapes', - help='Color palette used for segmentation map') + '--save-dir', + default=None, + help='Save file dir for all storage backends.') parser.add_argument( '--opacity', type=float, default=0.5, help='Opacity of painted segmentation map. 
In (0, 1] range.') + parser.add_argument( + '--title', default='result', help='The image identifier.') args = parser.parse_args() register_all_modules() # build the model from a config file and a checkpoint file model = init_model(args.config, args.checkpoint, device=args.device) + if args.device == 'cpu': + model = revert_sync_batchnorm(model) # test a single image result = inference_model(model, args.img) # show the results @@ -34,8 +40,11 @@ def main(): model, args.img, result, - get_palette(args.palette), - opacity=args.opacity) + title=args.title, + opacity=args.opacity, + draw_gt=False, + show=False if args.save_dir is not None else True, + save_dir=args.save_dir) if __name__ == '__main__': diff --git a/demo/inference_demo.ipynb b/demo/inference_demo.ipynb index e54d509ff7..8a90bd9507 100644 --- a/demo/inference_demo.ipynb +++ b/demo/inference_demo.ipynb @@ -20,8 +20,11 @@ }, "outputs": [], "source": [ + "import torch\n", + "from mmengine import revert_sync_batchnorm\n", "from mmseg.apis import init_model, inference_model, show_result_pyplot\n", - "from mmseg.utils import get_palette" + "from mmseg.utils import register_all_modules\n", + "register_all_modules()" ] }, { @@ -45,7 +48,7 @@ "outputs": [], "source": [ "# build the model from a config file and a checkpoint file\n", - "model = init_model(config_file, checkpoint_file, device='cuda:0')" + "model = init_model(config_file, checkpoint_file, device='cpu')" ] }, { @@ -56,6 +59,8 @@ "source": [ "# test a single image\n", "img = 'demo.png'\n", + "if not torch.cuda.is_available():\n", + " model = revert_sync_batchnorm(model)\n", "result = inference_model(model, img)" ] }, @@ -66,7 +71,7 @@ "outputs": [], "source": [ "# show the results\n", - "show_result_pyplot(model, img, result, get_palette('cityscapes'))" + "show_result_pyplot(model, img, result)" ] }, { @@ -79,7 +84,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3.10.4 ('pt1.11-v2')", "language": "python", "name": "python3" }, @@ -93,7 +98,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.0" + "version": "3.10.4" }, "pycharm": { "stem_cell": { @@ -103,6 +108,11 @@ }, "source": [] } + }, + "vscode": { + "interpreter": { + "hash": "fdab7187f8cbd4ce42bbf864ddb4c4693e7329271a15a7fa96e4bdb82b9302c9" + } } }, "nbformat": 4, diff --git a/demo/video_demo.py b/demo/video_demo.py index 5b844f1617..cc3853ff46 100644 --- a/demo/video_demo.py +++ b/demo/video_demo.py @@ -2,9 +2,11 @@ from argparse import ArgumentParser import cv2 +from mmengine.utils import revert_sync_batchnorm from mmseg.apis import inference_model, init_model -from mmseg.utils import get_palette +from mmseg.apis.inference import show_result_pyplot +from mmseg.utils import register_all_modules def main(): @@ -51,10 +53,16 @@ def main(): assert args.show or args.output_file, \ 'At least one output should be enabled.' 
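
``revert_sync_batchnorm`` matters here because SyncBN layers cannot run
outside a distributed GPU job; the demos above therefore convert them to
plain BN before CPU inference. A sketch, with hypothetical paths as before:

import torch
from mmengine.utils import revert_sync_batchnorm
from mmseg.apis import init_model

config_file = 'configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py'  # assumed path
checkpoint_file = 'checkpoints/pspnet_r50.pth'  # assumed path
model = init_model(config_file, checkpoint_file, device='cpu')
if not torch.cuda.is_available():
    # Replace SyncBN layers with ordinary BatchNorm for CPU inference.
    model = revert_sync_batchnorm(model)
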
+ register_all_modules() + # build the model from a config file and a checkpoint file model = init_model(args.config, args.checkpoint, device=args.device) + if args.device == 'cpu': + model = revert_sync_batchnorm(model) # build input video + if args.video.isdigit(): + args.video = int(args.video) cap = cv2.VideoCapture(args.video) assert (cap.isOpened()) input_height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) @@ -86,12 +94,7 @@ def main(): result = inference_model(model, frame) # blend raw image and prediction - draw_img = model.show_result( - frame, - result, - palette=get_palette(args.palette), - show=False, - opacity=args.opacity) + draw_img = show_result_pyplot(model, frame, result) if args.show: cv2.imshow('video_demo', draw_img) diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index f42d5c9070..b3bc9516a6 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -1,11 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. -import matplotlib.pyplot as plt +from typing import Union + import mmcv +import numpy as np import torch from mmcv.runner import load_checkpoint from mmseg.datasets.transforms import Compose +from mmseg.models import BaseSegmentor from mmseg.registry import MODELS +from mmseg.utils import SampleList +from mmseg.visualization import SegLocalVisualizer def init_model(config, checkpoint=None, device='cuda:0'): @@ -39,7 +44,50 @@ def init_model(config, checkpoint=None, device='cuda:0'): return model -def inference_model(model, img): +def _preprare_data(img, model: BaseSegmentor): + + cfg = model.cfg + if dict(type='LoadAnnotations') in cfg.test_pipeline: + cfg.test_pipeline.remove(dict(type='LoadAnnotations')) + # TODO: Consider using the singleton pattern to avoid building + # a pipeline for each inference + pipeline = Compose(cfg.test_pipeline) + + if isinstance(img, str): + data, data_samples = model.data_preprocessor( + [pipeline(dict(img_path=img))], False) + elif isinstance(img, np.ndarray): + pipeline.transforms.pop(0) + data, data_samples = model.data_preprocessor([ + pipeline( + dict( + img=img, img_shape=img.shape[:2], ori_shape=img.shape[:2])) + ], False) + elif isinstance(img, list) and len(img) >= 0: + if isinstance(img[0], str): + data, data_samples = model.data_preprocessor( + [pipeline(dict(img_path=_img)) for _img in img], False) + elif isinstance(img[0], np.ndarray): + pipeline.transforms.pop(0) + data, data_samples = model.data_preprocessor([ + pipeline( + dict( + img=img, + img_shape=img.shape[:2], + ori_shape=img.shape[:2])) for _img in img + ], False) + else: + raise f'Unexpected img type, support str/ndarray\ + but got {type(img)}' + + else: + raise f'Unexpected img type, support str/ndarray or list[str/ndarray]\ + but got {type(img)}' + + return data, data_samples + + +def inference_model(model: BaseSegmentor, img): """Inference image(s) with the segmentor. Args: @@ -50,14 +98,8 @@ def inference_model(model, img): Returns: (list[Tensor]): The segmentation result. 
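
Because ``_preprare_data`` dispatches on the input type, a file path and a
loaded array take the same route through ``inference_model``. A sketch,
assuming ``model`` was built with ``init_model`` as in the earlier example:

import mmcv
from mmseg.apis import inference_model

result_from_path = inference_model(model, 'demo/demo.png')

frame = mmcv.imread('demo/demo.png')  # np.ndarray in BGR order
result_from_array = inference_model(model, frame)
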
""" - cfg = model.cfg - if dict(type='LoadAnnotations') in cfg.test_pipeline: - cfg.test_pipeline.remove(dict(type='LoadAnnotations')) - pipeline = Compose(cfg.test_pipeline) - # prepare data - data, data_samples = model.data_preprocessor( - [pipeline(dict(img_path=img))], False) + data, data_samples = _preprare_data(img, model) # forward the model with torch.no_grad(): @@ -66,38 +108,51 @@ def inference_model(model, img): return result -def show_result_pyplot(model, - img, - result, - palette=None, - fig_size=(15, 10), - opacity=0.5, - title='', - block=True): +def show_result_pyplot(model: BaseSegmentor, + img: Union[str, np.ndarray], + result: SampleList, + opacity: float = 0.5, + title: str = '', + draw_gt: bool = True, + draw_pred: bool = True, + wait_time: float = 0, + show: bool = True, + save_dir=None): """Visualize the segmentation results on the image. Args: model (nn.Module): The loaded segmentor. img (str or np.ndarray): Image filename or loaded image. - result (list): The segmentation result. - palette (list[list[int]]] | None): The palette of segmentation - map. If None is given, random palette will be generated. - Default: None - fig_size (tuple): Figure size of the pyplot figure. + result (list): The prediction SegDataSample result. opacity(float): Opacity of painted segmentation map. Default 0.5. Must be in (0, 1] range. title (str): The title of pyplot figure. Default is ''. - block (bool): Whether to block the pyplot figure. - Default is True. + show (bool): Whether to display the drawn image. + Default to True. """ if hasattr(model, 'module'): model = model.module - img = model.show_result( - img, result, palette=palette, show=False, opacity=opacity) - plt.figure(figsize=fig_size) - plt.imshow(mmcv.bgr2rgb(img)) - plt.title(title) - plt.tight_layout() - plt.show(block=block) + if isinstance(img, str): + image = mmcv.imread(img) + else: + image = img + if save_dir is not None: + mmcv.mkdir_or_exist(save_dir) + # init visualizer + visualizer = SegLocalVisualizer( + vis_backends=[dict(type='LocalVisBackend')], + save_dir=save_dir, + alpha=opacity) + visualizer.dataset_meta = dict( + classes=model.CLASSES, palette=model.PALETTE) + visualizer.add_datasample( + name=title, + image=image, + pred_sample=result[0], + draw_gt=draw_gt, + draw_pred=draw_pred, + wait_time=wait_time, + show=show) + return visualizer.get_image() From ea2fcd318f7b7856da4c445ba17380b442e03318 Mon Sep 17 00:00:00 2001 From: xiexinch Date: Mon, 25 Jul 2022 14:42:40 +0800 Subject: [PATCH 05/12] fix --- mmseg/apis/inference.py | 53 +++++++++++++++-------------------------- 1 file changed, 19 insertions(+), 34 deletions(-) diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index b3bc9516a6..5746d65e64 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -5,8 +5,8 @@ import numpy as np import torch from mmcv.runner import load_checkpoint +from mmcv.transforms import Compose -from mmseg.datasets.transforms import Compose from mmseg.models import BaseSegmentor from mmseg.registry import MODELS from mmseg.utils import SampleList @@ -44,47 +44,32 @@ def init_model(config, checkpoint=None, device='cuda:0'): return model -def _preprare_data(img, model: BaseSegmentor): +def _preprare_data(imgs, model: BaseSegmentor): cfg = model.cfg if dict(type='LoadAnnotations') in cfg.test_pipeline: cfg.test_pipeline.remove(dict(type='LoadAnnotations')) + + if not isinstance(imgs, (list, tuple)): + imgs = [imgs] + + if isinstance(imgs[0], np.ndarray): + cfg.test_pipeline[0].type = 
'LoadImageFromNDArray' + # TODO: Consider using the singleton pattern to avoid building # a pipeline for each inference pipeline = Compose(cfg.test_pipeline) - if isinstance(img, str): - data, data_samples = model.data_preprocessor( - [pipeline(dict(img_path=img))], False) - elif isinstance(img, np.ndarray): - pipeline.transforms.pop(0) - data, data_samples = model.data_preprocessor([ - pipeline( - dict( - img=img, img_shape=img.shape[:2], ori_shape=img.shape[:2])) - ], False) - elif isinstance(img, list) and len(img) >= 0: - if isinstance(img[0], str): - data, data_samples = model.data_preprocessor( - [pipeline(dict(img_path=_img)) for _img in img], False) - elif isinstance(img[0], np.ndarray): - pipeline.transforms.pop(0) - data, data_samples = model.data_preprocessor([ - pipeline( - dict( - img=img, - img_shape=img.shape[:2], - ori_shape=img.shape[:2])) for _img in img - ], False) + data = [] + for img in imgs: + if isinstance(img, np.ndarray): + data_ = dict(img=img) else: - raise f'Unexpected img type, support str/ndarray\ - but got {type(img)}' - - else: - raise f'Unexpected img type, support str/ndarray or list[str/ndarray]\ - but got {type(img)}' + data_ = dict(img_path=img) + data_ = pipeline(data_) + data.append(data_) - return data, data_samples + return data def inference_model(model: BaseSegmentor, img): @@ -99,11 +84,11 @@ def inference_model(model: BaseSegmentor, img): (list[Tensor]): The segmentation result. """ # prepare data - data, data_samples = _preprare_data(img, model) + data = _preprare_data(img, model) # forward the model with torch.no_grad(): - result = model.forward(data, data_samples, mode='predict') + result = model.test_step(data) return result From dda7f4b0cec6046b8dc9fb7d3f54fde8f51186ff Mon Sep 17 00:00:00 2001 From: xiexinch Date: Mon, 25 Jul 2022 14:45:03 +0800 Subject: [PATCH 06/12] mmseg compose --- mmseg/apis/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index 5746d65e64..2e2cfe2c4e 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -5,8 +5,8 @@ import numpy as np import torch from mmcv.runner import load_checkpoint -from mmcv.transforms import Compose +from mmseg.datasets.transforms.compose import Compose from mmseg.models import BaseSegmentor from mmseg.registry import MODELS from mmseg.utils import SampleList From 5d701f912e98a6fc8d8d8e8998acdd3887ec0333 Mon Sep 17 00:00:00 2001 From: xiexinch Date: Wed, 27 Jul 2022 19:24:13 +0800 Subject: [PATCH 07/12] set default device to cuda:0 --- demo/inference_demo.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/demo/inference_demo.ipynb b/demo/inference_demo.ipynb index 8a90bd9507..97a6dc923a 100644 --- a/demo/inference_demo.ipynb +++ b/demo/inference_demo.ipynb @@ -48,7 +48,7 @@ "outputs": [], "source": [ "# build the model from a config file and a checkpoint file\n", - "model = init_model(config_file, checkpoint_file, device='cpu')" + "model = init_model(config_file, checkpoint_file, device='cuda:0')" ] }, { From 4cdea1c04b03640d7c9d8bb1beef1bf369f97dfe Mon Sep 17 00:00:00 2001 From: xiexinch Date: Thu, 28 Jul 2022 10:45:17 +0800 Subject: [PATCH 08/12] fix import --- mmseg/apis/inference.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index 2e2cfe2c4e..c0e0d05253 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -5,12 +5,13 @@ import numpy as np import torch from mmcv.runner import 
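
After the patch 05 rewrite, ``_preprare_data`` normalizes its input to a
list, so a batch can be scored in one call. A sketch, assuming ``model``
from ``init_model`` as above; whether one SegDataSample comes back per image
follows from the list handling shown earlier:

from mmseg.apis import inference_model

imgs = ['demo/demo.png', 'demo/demo.png']  # hypothetical frame list
results = inference_model(model, imgs)
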
load_checkpoint +from mmengine import Config +from mmengine.dataset import Compose -from mmseg.datasets.transforms.compose import Compose +from mmseg.engine import SegLocalVisualizer from mmseg.models import BaseSegmentor from mmseg.registry import MODELS from mmseg.utils import SampleList -from mmseg.visualization import SegLocalVisualizer def init_model(config, checkpoint=None, device='cuda:0'): @@ -27,7 +28,7 @@ def init_model(config, checkpoint=None, device='cuda:0'): nn.Module: The constructed segmentor. """ if isinstance(config, str): - config = mmcv.Config.fromfile(config) + config = Config.fromfile(config) elif not isinstance(config, mmcv.Config): raise TypeError('config must be a filename or Config object, ' 'but got {}'.format(type(config))) From fcb3a86aeab644d598382e5372da6a8d15da8aab Mon Sep 17 00:00:00 2001 From: xiexinch Date: Thu, 28 Jul 2022 17:48:03 +0800 Subject: [PATCH 09/12] update dir --- mmseg/apis/inference.py | 2 +- mmseg/engine/__init__.py | 3 +- mmseg/engine/hooks/visualization_hook.py | 2 +- mmseg/engine/visualization/__init__.py | 4 - .../engine/visualization/local_visualizer.py | 172 ------------------ mmseg/visualization/local_visualizer.py | 6 +- tests/test_engine/test_local_visualizer.py | 2 +- tests/test_engine/test_visualization_hook.py | 2 +- 8 files changed, 8 insertions(+), 185 deletions(-) delete mode 100644 mmseg/engine/visualization/__init__.py delete mode 100644 mmseg/engine/visualization/local_visualizer.py diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index c0e0d05253..39bac7758e 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -8,10 +8,10 @@ from mmengine import Config from mmengine.dataset import Compose -from mmseg.engine import SegLocalVisualizer from mmseg.models import BaseSegmentor from mmseg.registry import MODELS from mmseg.utils import SampleList +from mmseg.visualization import SegLocalVisualizer def init_model(config, checkpoint=None, device='cuda:0'): diff --git a/mmseg/engine/__init__.py b/mmseg/engine/__init__.py index 517f811d5d..ada4057012 100644 --- a/mmseg/engine/__init__.py +++ b/mmseg/engine/__init__.py @@ -2,9 +2,8 @@ from .hooks import SegVisualizationHook from .optimizers import (LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor) -from .visualization import SegLocalVisualizer __all__ = [ 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor', - 'SegVisualizationHook', 'SegLocalVisualizer' + 'SegVisualizationHook' ] diff --git a/mmseg/engine/hooks/visualization_hook.py b/mmseg/engine/hooks/visualization_hook.py index bd9bc2f3a4..40725cda68 100644 --- a/mmseg/engine/hooks/visualization_hook.py +++ b/mmseg/engine/hooks/visualization_hook.py @@ -8,8 +8,8 @@ from mmengine.runner import Runner from mmseg.data import SegDataSample -from mmseg.engine.visualization import SegLocalVisualizer from mmseg.registry import HOOKS +from mmseg.visualization import SegLocalVisualizer @HOOKS.register_module() diff --git a/mmseg/engine/visualization/__init__.py b/mmseg/engine/visualization/__init__.py deleted file mode 100644 index 8cbb211e52..0000000000 --- a/mmseg/engine/visualization/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
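
With the module move in this patch, ``mmseg.visualization`` becomes the
stable import path, and building through the registry keeps working
unchanged. A sketch, where the save directory is a hypothetical location:

from mmseg.registry import VISUALIZERS
from mmseg.utils import register_all_modules

register_all_modules()
visualizer = VISUALIZERS.build(
    dict(
        type='SegLocalVisualizer',
        vis_backends=[dict(type='LocalVisBackend')],
        save_dir='work_dirs/vis'))
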
-from .local_visualizer import SegLocalVisualizer - -__all__ = ['SegLocalVisualizer'] diff --git a/mmseg/engine/visualization/local_visualizer.py b/mmseg/engine/visualization/local_visualizer.py deleted file mode 100644 index ea966fa5b2..0000000000 --- a/mmseg/engine/visualization/local_visualizer.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Optional, Tuple - -import numpy as np -from mmengine import Visualizer -from mmengine.data import PixelData -from mmengine.dist import master_only - -from mmseg.data import SegDataSample -from mmseg.registry import VISUALIZERS - - -@VISUALIZERS.register_module() -class SegLocalVisualizer(Visualizer): - """MMSegmentation Local Visualizer. - - Args: - name (str): Name of the instance. Defaults to 'visualizer'. - image (np.ndarray, optional): the origin image to draw. The format - should be RGB. Defaults to None. - vis_backends (list, optional): Visual backend config list. - Defaults to None. - save_dir (str, optional): Save file dir for all storage backends. - If it is None, the backend storage will not save any data. - alpha (int, float): The transparency of segmentation mask. - Defaults to 0.8. - - Examples: - >>> import numpy as np - >>> import torch - >>> from mmengine.data import PixelData - >>> from mmseg.data import SegDataSample - >>> from mmseg.engine.visualization import SegLocalVisualizer - - >>> seg_local_visualizer = SegLocalVisualizer() - >>> image = np.random.randint(0, 256, - ... size=(10, 12, 3)).astype('uint8') - >>> gt_sem_seg_data = dict(data=torch.randint(0, 2, (1, 10, 12))) - >>> gt_sem_seg = PixelData(**gt_sem_seg_data) - >>> gt_seg_data_sample = SegDataSample() - >>> gt_seg_data_sample.gt_sem_seg = gt_sem_seg - >>> seg_local_visualizer.dataset_meta = dict( - >>> classes=('background', 'foreground'), - >>> palette=[[120, 120, 120], [6, 230, 230]]) - >>> seg_local_visualizer.add_datasample('visualizer_example', - ... image, gt_seg_data_sample) - >>> seg_local_visualizer.add_datasample( - ... 'visualizer_example', image, - ... gt_seg_data_sample, show=True) - """ - - def __init__(self, - name: str = 'visualizer', - image: Optional[np.ndarray] = None, - vis_backends: Optional[Dict] = None, - save_dir: Optional[str] = None, - alpha: float = 0.8, - **kwargs): - super().__init__(name, image, vis_backends, save_dir, **kwargs) - self.alpha = alpha - # Set default value. When calling - # `SegLocalVisualizer().dataset_meta=xxx`, - # it will override the default value. - self.dataset_meta = {} - - def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData, - classes: Optional[Tuple[str]], - palette: Optional[List[List[int]]]) -> np.ndarray: - """Draw semantic seg of GT or prediction. - - Args: - image (np.ndarray): The image to draw. - sem_seg (:obj:`PixelData`): Data structure for - pixel-level annotations or predictions. - classes (Tuple[str], optional): Category information. - palette (List[List[int]], optional): The palette of - segmentation map. - - Returns: - np.ndarray: the drawn image which channel is RGB. 
- """ - num_classes = len(classes) - - sem_seg = sem_seg.data - ids = np.unique(sem_seg)[::-1] - legal_indices = ids < num_classes - ids = ids[legal_indices] - labels = np.array(ids, dtype=np.int64) - - colors = [palette[label] for label in labels] - - self.set_image(image) - - # draw semantic masks - for label, color in zip(labels, colors): - self.draw_binary_masks( - sem_seg == label, colors=[color], alphas=self.alpha) - - return self.get_image() - - @master_only - def add_datasample(self, - name: str, - image: np.ndarray, - gt_sample: Optional[SegDataSample] = None, - pred_sample: Optional[SegDataSample] = None, - draw_gt: bool = True, - draw_pred: bool = True, - show: bool = False, - wait_time: float = 0, - step: int = 0) -> None: - """Draw datasample and save to all backends. - - - If GT and prediction are plotted at the same time, they are - displayed in a stitched image where the left image is the - ground truth and the right image is the prediction. - - If ``show`` is True, all storage backends are ignored, and - the images will be displayed in a local window. - - Args: - name (str): The image identifier. - image (np.ndarray): The image to draw. - gt_sample (:obj:`SegDataSample`, optional): GT SegDataSample. - Defaults to None. - pred_sample (:obj:`SegDataSample`, optional): Prediction - SegDataSample. Defaults to None. - draw_gt (bool): Whether to draw GT SegDataSample. Default to True. - draw_pred (bool): Whether to draw Prediction SegDataSample. - Defaults to True. - show (bool): Whether to display the drawn image. Default to False. - wait_time (float): The interval of show (s). Defaults to 0. - step (int): Global step value to record. Defaults to 0. - """ - classes = self.dataset_meta.get('classes', None) - palette = self.dataset_meta.get('palette', None) - - gt_img_data = None - pred_img_data = None - - if draw_gt and gt_sample is not None: - gt_img_data = image - if 'gt_sem_seg' in gt_sample: - assert classes is not None, 'class information is ' \ - 'not provided when ' \ - 'visualizing semantic ' \ - 'segmentation results.' - gt_img_data = self._draw_sem_seg(gt_img_data, - gt_sample.gt_sem_seg, classes, - palette) - - if draw_pred and pred_sample is not None: - pred_img_data = image - if 'pred_sem_seg' in pred_sample: - assert classes is not None, 'class information is ' \ - 'not provided when ' \ - 'visualizing semantic ' \ - 'segmentation results.' 
- pred_img_data = self._draw_sem_seg(pred_img_data, - pred_sample.pred_sem_seg, - classes, palette) - - if gt_img_data is not None and pred_img_data is not None: - drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1) - elif gt_img_data is not None: - drawn_img = gt_img_data - else: - drawn_img = pred_img_data - - if show: - self.show(drawn_img, win_name=name, wait_time=wait_time) - else: - self.add_image(name, drawn_img, step) diff --git a/mmseg/visualization/local_visualizer.py b/mmseg/visualization/local_visualizer.py index 58d1d2b866..ea966fa5b2 100644 --- a/mmseg/visualization/local_visualizer.py +++ b/mmseg/visualization/local_visualizer.py @@ -30,7 +30,7 @@ class SegLocalVisualizer(Visualizer): >>> import torch >>> from mmengine.data import PixelData >>> from mmseg.data import SegDataSample - >>> from mmseg.visualization import SegLocalVisualizer + >>> from mmseg.engine.visualization import SegLocalVisualizer >>> seg_local_visualizer = SegLocalVisualizer() >>> image = np.random.randint(0, 256, @@ -45,8 +45,8 @@ class SegLocalVisualizer(Visualizer): >>> seg_local_visualizer.add_datasample('visualizer_example', ... image, gt_seg_data_sample) >>> seg_local_visualizer.add_datasample( - ... 'out_file_name', image, gt_seg_data_sample, - ... show=True) + ... 'visualizer_example', image, + ... gt_seg_data_sample, show=True) """ def __init__(self, diff --git a/tests/test_engine/test_local_visualizer.py b/tests/test_engine/test_local_visualizer.py index 6100fe8561..b1c20edfbd 100644 --- a/tests/test_engine/test_local_visualizer.py +++ b/tests/test_engine/test_local_visualizer.py @@ -10,7 +10,7 @@ from mmengine.data import PixelData from mmseg.data import SegDataSample -from mmseg.engine.visualization import SegLocalVisualizer +from mmseg.visualization import SegLocalVisualizer class TestSegLocalVisualizer(TestCase): diff --git a/tests/test_engine/test_visualization_hook.py b/tests/test_engine/test_visualization_hook.py index a70fb612e4..f5b6ba352d 100644 --- a/tests/test_engine/test_visualization_hook.py +++ b/tests/test_engine/test_visualization_hook.py @@ -7,7 +7,7 @@ from mmseg.data import SegDataSample from mmseg.engine.hooks import SegVisualizationHook -from mmseg.engine.visualization import SegLocalVisualizer +from mmseg.visualization import SegLocalVisualizer class TestVisualizationHook(TestCase): From dd9b3caa95fcc425a191b4c45959fcc4633c05b2 Mon Sep 17 00:00:00 2001 From: xiexinch Date: Fri, 29 Jul 2022 11:18:27 +0800 Subject: [PATCH 10/12] rm engine/visualizer ut --- tests/test_engine/test_local_visualizer.py | 187 --------------------- 1 file changed, 187 deletions(-) delete mode 100644 tests/test_engine/test_local_visualizer.py diff --git a/tests/test_engine/test_local_visualizer.py b/tests/test_engine/test_local_visualizer.py deleted file mode 100644 index b1c20edfbd..0000000000 --- a/tests/test_engine/test_local_visualizer.py +++ /dev/null @@ -1,187 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-import os
-import os.path as osp
-from unittest import TestCase
-
-import cv2
-import mmcv
-import numpy as np
-import torch
-from mmengine.data import PixelData
-
-from mmseg.data import SegDataSample
-from mmseg.visualization import SegLocalVisualizer
-
-
-class TestSegLocalVisualizer(TestCase):
-
-    def test_add_datasample(self):
-        h = 10
-        w = 12
-        num_class = 2
-        out_file = 'out_file'
-
-        image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')
-
-        # test gt_sem_seg
-        gt_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
-        gt_sem_seg = PixelData(**gt_sem_seg_data)
-
-        gt_seg_data_sample = SegDataSample()
-        gt_seg_data_sample.gt_sem_seg = gt_sem_seg
-
-        seg_local_visualizer = SegLocalVisualizer(
-            vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir')
-        seg_local_visualizer.dataset_meta = dict(
-            classes=('background', 'foreground'),
-            palette=[[120, 120, 120], [6, 230, 230]])
-        seg_local_visualizer.add_datasample(out_file, image,
-                                            gt_seg_data_sample)
-
-        # test out_file
-        seg_local_visualizer.add_datasample(out_file, image,
-                                            gt_seg_data_sample)
-
-        assert os.path.exists(
-            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
-        drawn_img = cv2.imread(
-            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
-        assert drawn_img.shape == (h, w, 3)
-
-        os.remove(
-            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
-        os.rmdir('temp_dir' + '/vis_data/vis_image')
-
-        # test gt_instances and pred_instances
-        pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
-        pred_sem_seg = PixelData(**pred_sem_seg_data)
-
-        pred_seg_data_sample = SegDataSample()
-        pred_seg_data_sample.pred_sem_seg = pred_sem_seg
-
-        seg_local_visualizer.add_datasample(out_file, image,
-                                            gt_seg_data_sample,
-                                            pred_seg_data_sample)
-        self._assert_image_and_shape(
-            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
-            (h, w * 2, 3))
-
-        seg_local_visualizer.add_datasample(
-            out_file,
-            image,
-            gt_seg_data_sample,
-            pred_seg_data_sample,
-            draw_gt=False)
-        self._assert_image_and_shape(
-            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
-            (h, w, 3))
-
-        seg_local_visualizer.add_datasample(
-            out_file,
-            image,
-            gt_seg_data_sample,
-            pred_seg_data_sample,
-            draw_pred=False)
-        self._assert_image_and_shape(
-            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
-            (h, w, 3))
-        os.rmdir('temp_dir/vis_data')
-        os.rmdir('temp_dir')
-
-    def test_cityscapes_add_datasample(self):
-        h = 128
-        w = 256
-        num_class = 19
-        out_file = 'out_file_cityscapes'
-
-        image = mmcv.imread(
-            osp.join(
-                osp.dirname(__file__),
-                '../data/pseudo_cityscapes_dataset/leftImg8bit/frankfurt_000000_000294_leftImg8bit.png'  # noqa
-            ),
-            'color')
-        sem_seg = mmcv.imread(
-            osp.join(
-                osp.dirname(__file__),
-                '../data/pseudo_cityscapes_dataset/gtFine/frankfurt_000000_000294_gtFine_labelTrainIds.png'  # noqa
-            ),
-            'unchanged')
-        sem_seg = torch.unsqueeze(torch.from_numpy(sem_seg), 0)
-        gt_sem_seg_data = dict(data=sem_seg)
-        gt_sem_seg = PixelData(**gt_sem_seg_data)
-
-        gt_seg_data_sample = SegDataSample()
-        gt_seg_data_sample.gt_sem_seg = gt_sem_seg
-
-        seg_local_visualizer = SegLocalVisualizer(
-            vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir')
-        seg_local_visualizer.dataset_meta = dict(
-            classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
-                     'traffic light', 'traffic sign', 'vegetation', 'terrain',
-                     'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
-                     'motorcycle', 'bicycle'),
-            palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70],
-                     [102, 102, 156], [190, 153, 153], [153, 153, 153],
-                     [250, 170, 30], [220, 220, 0], [107, 142, 35],
-                     [152, 251, 152], [70, 130, 180], [220, 20, 60],
-                     [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
-                     [0, 80, 100], [0, 0, 230], [119, 11, 32]])
-        seg_local_visualizer.add_datasample(out_file, image,
-                                            gt_seg_data_sample)
-
-        # test out_file
-        seg_local_visualizer.add_datasample(out_file, image,
-                                            gt_seg_data_sample)
-        assert os.path.exists(
-            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
-        drawn_img = cv2.imread(
-            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
-        assert drawn_img.shape == (h, w, 3)
-
-        os.remove(
-            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
-        os.rmdir('temp_dir/vis_data/vis_image')
-
-        # test gt_instances and pred_instances
-        pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
-        pred_sem_seg = PixelData(**pred_sem_seg_data)
-
-        pred_seg_data_sample = SegDataSample()
-        pred_seg_data_sample.pred_sem_seg = pred_sem_seg
-
-        seg_local_visualizer.add_datasample(out_file, image,
-                                            gt_seg_data_sample,
-                                            pred_seg_data_sample)
-        self._assert_image_and_shape(
-            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
-            (h, w * 2, 3))
-
-        seg_local_visualizer.add_datasample(
-            out_file,
-            image,
-            gt_seg_data_sample,
-            pred_seg_data_sample,
-            draw_gt=False)
-        self._assert_image_and_shape(
-            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
-            (h, w, 3))
-
-        seg_local_visualizer.add_datasample(
-            out_file,
-            image,
-            gt_seg_data_sample,
-            pred_seg_data_sample,
-            draw_pred=False)
-        self._assert_image_and_shape(
-            osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
-            (h, w, 3))
-
-        os.rmdir('temp_dir/vis_data')
-        os.rmdir('temp_dir')
-
-    def _assert_image_and_shape(self, out_file, out_shape):
-        assert os.path.exists(out_file)
-        drawn_img = cv2.imread(out_file)
-        assert drawn_img.shape == out_shape
-        os.remove(out_file)
-        os.rmdir('temp_dir/vis_data/vis_image')

From 4b5fe84b157d14296520a74abb9034f49a9f4858 Mon Sep 17 00:00:00 2001
From: xiexinch
Date: Fri, 29 Jul 2022 18:28:11 +0800
Subject: [PATCH 11/12] refine inference api and docs

---
 demo/image_demo.py      |  3 +--
 demo/video_demo.py      |  2 +-
 mmseg/apis/inference.py | 34 ++++++++++++++++++++++++----------
 3 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/demo/image_demo.py b/demo/image_demo.py
index a52e37c83e..3791ec9416 100644
--- a/demo/image_demo.py
+++ b/demo/image_demo.py
@@ -38,8 +38,7 @@ def main():
     # show the results
     show_result_pyplot(
         model,
-        args.img,
-        result,
+        args.img, [result],
         title=args.title,
         opacity=args.opacity,
         draw_gt=False,
diff --git a/demo/video_demo.py b/demo/video_demo.py
index cc3853ff46..285a3c9215 100644
--- a/demo/video_demo.py
+++ b/demo/video_demo.py
@@ -94,7 +94,7 @@ def main():
         result = inference_model(model, frame)
 
         # blend raw image and prediction
-        draw_img = show_result_pyplot(model, frame, result)
+        draw_img = show_result_pyplot(model, frame, [result])
 
         if args.show:
             cv2.imshow('video_demo', draw_img)
diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py
index 39bac7758e..0f5f602aa0 100644
--- a/mmseg/apis/inference.py
+++ b/mmseg/apis/inference.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Union
+from typing import Sequence, Union
 
 import mmcv
 import numpy as np
@@ -8,6 +8,7 @@
 from mmengine import Config
 from mmengine.dataset import Compose
 
+from mmseg.data import SegDataSample
 from mmseg.models import BaseSegmentor
 from mmseg.registry import MODELS
 from mmseg.utils import SampleList
@@ -45,14 +46,19 @@ def init_model(config, checkpoint=None, device='cuda:0'):
     return model
 
 
-def _preprare_data(imgs, model: BaseSegmentor):
+ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
+
+
+def _preprare_data(imgs: ImageType, model: BaseSegmentor):
 
     cfg = model.cfg
     if dict(type='LoadAnnotations') in cfg.test_pipeline:
         cfg.test_pipeline.remove(dict(type='LoadAnnotations'))
 
+    is_batch = True
     if not isinstance(imgs, (list, tuple)):
         imgs = [imgs]
+        is_batch = False
 
     if isinstance(imgs[0], np.ndarray):
         cfg.test_pipeline[0].type = 'LoadImageFromNDArray'
@@ -70,10 +76,11 @@ def _preprare_data(imgs, model: BaseSegmentor):
         data_ = pipeline(data_)
         data.append(data_)
 
-    return data
+    return data, is_batch
 
 
-def inference_model(model: BaseSegmentor, img):
+def inference_model(model: BaseSegmentor,
+                    img: ImageType) -> Union[SegDataSample, SampleList]:
     """Inference image(s) with the segmentor.
 
     Args:
@@ -82,16 +89,18 @@
         images.
 
     Returns:
-        (list[Tensor]): The segmentation result.
+        :obj:`SegDataSample` or list[:obj:`SegDataSample`]:
+        If imgs is a list or tuple, the same length list type results
+        will be returned, otherwise return the detection results directly.
     """
     # prepare data
-    data = _preprare_data(img, model)
+    data, is_batch = _preprare_data(img, model)
 
     # forward the model
     with torch.no_grad():
-        result = model.test_step(data)
+        results = model.test_step(data)
 
-    return result
+    return results if is_batch else results[0]
 
 
 def show_result_pyplot(model: BaseSegmentor,
@@ -111,12 +120,17 @@
         img (str or np.ndarray): Image filename or loaded image.
         result (list): The prediction SegDataSample result.
        opacity(float): Opacity of painted segmentation map.
-            Default 0.5.
-            Must be in (0, 1] range.
+            Default 0.5. Must be in (0, 1] range.
         title (str): The title of pyplot figure.
             Default is ''.
+        draw_gt (bool): Whether to draw GT SegDataSample. Defaults to True.
+        draw_pred (bool): Whether to draw Prediction SegDataSample.
+            Defaults to True.
+        wait_time (float): The interval of show (s). Defaults to 0.
         show (bool): Whether to display the drawn image.
             Default to True.
+        save_dir (str, optional): Save file dir for all storage backends.
+            If it is None, the backend storage will not save any data.
     """
     if hasattr(model, 'module'):
         model = model.module

From 5da2eb98f6533c1a7dd4e179dff4dc98b8a00361 Mon Sep 17 00:00:00 2001
From: xiexinch
Date: Fri, 29 Jul 2022 18:30:38 +0800
Subject: [PATCH 12/12] rename

---
 mmseg/apis/inference.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py
index 0f5f602aa0..d26b1c87a2 100644
--- a/mmseg/apis/inference.py
+++ b/mmseg/apis/inference.py
@@ -91,7 +91,7 @@ def inference_model(model: BaseSegmentor,
     Returns:
         :obj:`SegDataSample` or list[:obj:`SegDataSample`]:
         If imgs is a list or tuple, the same length list type results
-        will be returned, otherwise return the detection results directly.
+        will be returned, otherwise the segmentation result is returned directly.
     """
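
Usage note (not part of the patches above): a minimal sketch of the refined inference API. It assumes init_model, inference_model and show_result_pyplot are re-exported from mmseg.apis, as the demo scripts suggest; the config and checkpoint paths are placeholders for any real pair.

    from mmseg.apis import inference_model, init_model, show_result_pyplot

    # Placeholder paths; substitute a real config/checkpoint pair.
    model = init_model('configs/example_config.py', 'example_checkpoint.pth',
                       device='cuda:0')

    # A single image (path or ndarray) returns one SegDataSample ...
    result = inference_model(model, 'demo/demo.png')

    # ... while a list of images returns a list of results of the same length.
    results = inference_model(model, ['demo/demo.png', 'demo/demo.png'])
    assert len(results) == 2

    # show_result_pyplot consumes the results as a list, matching the updated
    # demos; draw_gt=False because no ground truth is attached to the result.
    show_result_pyplot(model, 'demo/demo.png', [result],
                       opacity=0.5, draw_gt=False)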
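
Usage note (not part of the patches above): the stitched GT/prediction behavior of add_datasample can be exercised with a short standalone script. This is a minimal sketch, assuming the gt_sample/pred_sample keyword signature shown earlier in this series; the shapes, class names, and the 'temp_dir' save directory are illustrative only.

    import numpy as np
    import torch
    from mmengine.data import PixelData

    from mmseg.data import SegDataSample
    from mmseg.visualization import SegLocalVisualizer

    h, w = 10, 12
    image = np.random.randint(0, 256, size=(h, w, 3), dtype=np.uint8)

    # GT and prediction are both (1, h, w) label maps wrapped in PixelData.
    gt_sample = SegDataSample()
    gt_sample.gt_sem_seg = PixelData(data=torch.randint(0, 2, (1, h, w)))

    pred_sample = SegDataSample()
    pred_sample.pred_sem_seg = PixelData(data=torch.randint(0, 2, (1, h, w)))

    visualizer = SegLocalVisualizer(
        vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir')
    visualizer.dataset_meta = dict(
        classes=('background', 'foreground'),
        palette=[[120, 120, 120], [6, 230, 230]])

    # With both samples drawn, the stored image is a left(GT)/right(pred)
    # stitch of shape (h, 2 * w, 3), as asserted by the removed unit test.
    visualizer.add_datasample(
        'example', image, gt_sample=gt_sample, pred_sample=pred_sample)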