Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions demo/image_demo.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser

from mmengine.utils import revert_sync_batchnorm

from mmseg.apis import inference_model, init_model, show_result_pyplot
from mmseg.utils import get_palette
from mmseg.utils import register_all_modules


def main():
Expand All @@ -13,27 +15,35 @@ def main():
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--palette',
default='cityscapes',
help='Color palette used for segmentation map')
'--save-dir',
default=None,
help='Save file dir for all storage backends.')
parser.add_argument(
'--opacity',
type=float,
default=0.5,
help='Opacity of painted segmentation map. In (0, 1] range.')
parser.add_argument(
'--title', default='result', help='The image identifier.')
args = parser.parse_args()

register_all_modules()

# build the model from a config file and a checkpoint file
model = init_model(args.config, args.checkpoint, device=args.device)
if args.device == 'cpu':
model = revert_sync_batchnorm(model)
# test a single image
result = inference_model(model, args.img)
# show the results
show_result_pyplot(
model,
args.img,
result,
get_palette(args.palette),
opacity=args.opacity)
args.img, [result],
title=args.title,
opacity=args.opacity,
draw_gt=False,
show=False if args.save_dir is not None else True,
save_dir=args.save_dir)


if __name__ == '__main__':
Expand Down
18 changes: 14 additions & 4 deletions demo/inference_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
},
"outputs": [],
"source": [
"import torch\n",
"from mmengine import revert_sync_batchnorm\n",
"from mmseg.apis import init_model, inference_model, show_result_pyplot\n",
"from mmseg.utils import get_palette"
"from mmseg.utils import register_all_modules\n",
"register_all_modules()"
]
},
{
Expand Down Expand Up @@ -56,6 +59,8 @@
"source": [
"# test a single image\n",
"img = 'demo.png'\n",
"if not torch.cuda.is_available():\n",
" model = revert_sync_batchnorm(model)\n",
"result = inference_model(model, img)"
]
},
Expand All @@ -66,7 +71,7 @@
"outputs": [],
"source": [
"# show the results\n",
"show_result_pyplot(model, img, result, get_palette('cityscapes'))"
"show_result_pyplot(model, img, result)"
]
},
{
Expand All @@ -79,7 +84,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3.10.4 ('pt1.11-v2')",
"language": "python",
"name": "python3"
},
Expand All @@ -93,7 +98,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
"version": "3.10.4"
},
"pycharm": {
"stem_cell": {
Expand All @@ -103,6 +108,11 @@
},
"source": []
}
},
"vscode": {
"interpreter": {
"hash": "fdab7187f8cbd4ce42bbf864ddb4c4693e7329271a15a7fa96e4bdb82b9302c9"
}
}
},
"nbformat": 4,
Expand Down
17 changes: 10 additions & 7 deletions demo/video_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
from argparse import ArgumentParser

import cv2
from mmengine.utils import revert_sync_batchnorm

from mmseg.apis import inference_model, init_model
from mmseg.utils import get_palette
from mmseg.apis.inference import show_result_pyplot
from mmseg.utils import register_all_modules


def main():
Expand Down Expand Up @@ -51,10 +53,16 @@ def main():
assert args.show or args.output_file, \
'At least one output should be enabled.'

register_all_modules()

# build the model from a config file and a checkpoint file
model = init_model(args.config, args.checkpoint, device=args.device)
if args.device == 'cpu':
model = revert_sync_batchnorm(model)

# build input video
if args.video.isdigit():
args.video = int(args.video)
cap = cv2.VideoCapture(args.video)
assert (cap.isOpened())
input_height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
Expand Down Expand Up @@ -86,12 +94,7 @@ def main():
result = inference_model(model, frame)

# blend raw image and prediction
draw_img = model.show_result(
frame,
result,
palette=get_palette(args.palette),
show=False,
opacity=args.opacity)
draw_img = show_result_pyplot(model, frame, [result])

if args.show:
cv2.imshow('video_demo', draw_img)
Expand Down
160 changes: 91 additions & 69 deletions mmseg/apis/inference.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
import matplotlib.pyplot as plt
from typing import Sequence, Union

import mmcv
import numpy as np
import torch
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from mmengine import Config
from mmengine.dataset import Compose

from mmseg.datasets.transforms import Compose
from mmseg.models import build_segmentor
from mmseg.data import SegDataSample
from mmseg.models import BaseSegmentor
from mmseg.registry import MODELS
from mmseg.utils import SampleList
from mmseg.visualization import SegLocalVisualizer


def init_model(config, checkpoint=None, device='cuda:0'):
Expand All @@ -23,13 +29,13 @@ def init_model(config, checkpoint=None, device='cuda:0'):
nn.Module: The constructed segmentor.
"""
if isinstance(config, str):
config = mmcv.Config.fromfile(config)
config = Config.fromfile(config)
elif not isinstance(config, mmcv.Config):
raise TypeError('config must be a filename or Config object, '
'but got {}'.format(type(config)))
config.model.pretrained = None
config.model.train_cfg = None
model = build_segmentor(config.model, test_cfg=config.get('test_cfg'))
model = MODELS.build(config.model)
if checkpoint is not None:
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
model.CLASSES = checkpoint['meta']['CLASSES']
Expand All @@ -40,34 +46,41 @@ def init_model(config, checkpoint=None, device='cuda:0'):
return model


class LoadImage:
"""A simple pipeline to load image."""
ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]


def _preprare_data(imgs: ImageType, model: BaseSegmentor):

def __call__(self, results):
"""Call function to load images into results.
cfg = model.cfg
if dict(type='LoadAnnotations') in cfg.test_pipeline:
cfg.test_pipeline.remove(dict(type='LoadAnnotations'))

is_batch = True
if not isinstance(imgs, (list, tuple)):
imgs = [imgs]
is_batch = False

Args:
results (dict): A result dict contains the file name
of the image to be read.
if isinstance(imgs[0], np.ndarray):
cfg.test_pipeline[0].type = 'LoadImageFromNDArray'

Returns:
dict: ``results`` will be returned containing loaded image.
"""
# TODO: Consider using the singleton pattern to avoid building
# a pipeline for each inference
pipeline = Compose(cfg.test_pipeline)

if isinstance(results['img'], str):
results['filename'] = results['img']
results['ori_filename'] = results['img']
data = []
for img in imgs:
if isinstance(img, np.ndarray):
data_ = dict(img=img)
else:
results['filename'] = None
results['ori_filename'] = None
img = mmcv.imread(results['img'])
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
return results
data_ = dict(img_path=img)
data_ = pipeline(data_)
data.append(data_)

return data, is_batch


def inference_model(model, img):
def inference_model(model: BaseSegmentor,
img: ImageType) -> Union[SegDataSample, SampleList]:
"""Inference image(s) with the segmentor.

Args:
Expand All @@ -76,61 +89,70 @@ def inference_model(model, img):
images.

Returns:
(list[Tensor]): The segmentation result.
:obj:`SegDataSample` or list[:obj:`SegDataSample`]:
If imgs is a list or tuple, the same length list type results
will be returned, otherwise return the segmentation results directly.
"""
cfg = model.cfg
device = next(model.parameters()).device # model device
# build the data pipeline
test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
test_pipeline = Compose(test_pipeline)
# prepare data
data = dict(img=img)
data = test_pipeline(data)
data = collate([data], samples_per_gpu=1)
if next(model.parameters()).is_cuda:
# scatter to specified GPU
data = scatter(data, [device])[0]
else:
data['img_metas'] = [i.data[0] for i in data['img_metas']]
data, is_batch = _preprare_data(img, model)

# forward the model
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)
return result


def show_result_pyplot(model,
img,
result,
palette=None,
fig_size=(15, 10),
opacity=0.5,
title='',
block=True):
results = model.test_step(data)

return results if is_batch else results[0]


def show_result_pyplot(model: BaseSegmentor,
img: Union[str, np.ndarray],
result: SampleList,
opacity: float = 0.5,
title: str = '',
draw_gt: bool = True,
draw_pred: bool = True,
wait_time: float = 0,
Comment on lines +111 to +113
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might add draw_gt, draw_pred, and wait_time in docstring

show: bool = True,
save_dir=None):
"""Visualize the segmentation results on the image.

Args:
model (nn.Module): The loaded segmentor.
img (str or np.ndarray): Image filename or loaded image.
result (list): The segmentation result.
palette (list[list[int]]] | None): The palette of segmentation
map. If None is given, random palette will be generated.
Default: None
fig_size (tuple): Figure size of the pyplot figure.
result (list): The prediction SegDataSample result.
opacity(float): Opacity of painted segmentation map.
Default 0.5.
Must be in (0, 1] range.
Default 0.5. Must be in (0, 1] range.
title (str): The title of pyplot figure.
Default is ''.
block (bool): Whether to block the pyplot figure.
Default is True.
draw_gt (bool): Whether to draw GT SegDataSample. Default to True.
draw_pred (bool): Whether to draw Prediction SegDataSample.
Defaults to True.
wait_time (float): The interval of show (s). Defaults to 0.
show (bool): Whether to display the drawn image.
Default to True.
save_dir (str, optional): Save file dir for all storage backends.
If it is None, the backend storage will not save any data.
"""
if hasattr(model, 'module'):
model = model.module
img = model.show_result(
img, result, palette=palette, show=False, opacity=opacity)
plt.figure(figsize=fig_size)
plt.imshow(mmcv.bgr2rgb(img))
plt.title(title)
plt.tight_layout()
plt.show(block=block)
if isinstance(img, str):
image = mmcv.imread(img)
else:
image = img
if save_dir is not None:
mmcv.mkdir_or_exist(save_dir)
# init visualizer
visualizer = SegLocalVisualizer(
vis_backends=[dict(type='LocalVisBackend')],
save_dir=save_dir,
alpha=opacity)
visualizer.dataset_meta = dict(
classes=model.CLASSES, palette=model.PALETTE)
visualizer.add_datasample(
name=title,
image=image,
pred_sample=result[0],
draw_gt=draw_gt,
draw_pred=draw_pred,
wait_time=wait_time,
show=show)
return visualizer.get_image()
3 changes: 1 addition & 2 deletions mmseg/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
from .hooks import SegVisualizationHook
from .optimizers import (LayerDecayOptimizerConstructor,
LearningRateDecayOptimizerConstructor)
from .visualization import SegLocalVisualizer

__all__ = [
'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor',
'SegVisualizationHook', 'SegLocalVisualizer'
'SegVisualizationHook'
]
2 changes: 1 addition & 1 deletion mmseg/engine/hooks/visualization_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from mmengine.runner import Runner

from mmseg.data import SegDataSample
from mmseg.engine.visualization import SegLocalVisualizer
from mmseg.registry import HOOKS
from mmseg.visualization import SegLocalVisualizer


@HOOKS.register_module()
Expand Down
File renamed without changes.
Loading