Skip to content

Commit

Permalink
Visualize results on image demo (open-mmlab#58)
Browse files Browse the repository at this point in the history
* visualize results on image demo

* add matplotlib in requirements
  • Loading branch information
yl-1993 committed Oct 10, 2020
1 parent 9282d3a commit 9b425aa
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 7 deletions.
6 changes: 3 additions & 3 deletions demo/image_demo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from argparse import ArgumentParser

from mmcls.apis import inference_model, init_model
from mmcls.apis import inference_model, init_model, show_result_pyplot


def main():
Expand All @@ -16,8 +16,8 @@ def main():
model = init_model(args.config, args.checkpoint, device=args.device)
# test a single image
result = inference_model(model, args.img)
# print result on terminal
print(result)
# show the results
show_result_pyplot(model, args.img, result)


if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions mmcls/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .inference import inference_model, init_model
from .inference import inference_model, init_model, show_result_pyplot
from .test import multi_gpu_test, single_gpu_test
from .train import set_random_seed, train_model

__all__ = [
'set_random_seed', 'train_model', 'init_model', 'inference_model',
'multi_gpu_test', 'single_gpu_test'
'multi_gpu_test', 'single_gpu_test', 'show_result_pyplot'
]
22 changes: 20 additions & 2 deletions mmcls/apis/inference.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import warnings

import matplotlib.pyplot as plt
import mmcv
import numpy as np
import torch
Expand Down Expand Up @@ -74,6 +75,23 @@ def inference_model(model, img):
scores = model(return_loss=False, **data)
pred_score = np.max(scores, axis=1)[0]
pred_label = np.argmax(scores, axis=1)[0]
result = {'pred_label': pred_label, 'pred_score': pred_score}
result['class_name'] = model.CLASSES[result['pred_label']]
result = {'pred_label': pred_label, 'pred_score': float(pred_score)}
result['pred_class'] = model.CLASSES[result['pred_label']]
return result


def show_result_pyplot(model, img, result, fig_size=(15, 10)):
"""Visualize the classification results on the image.
Args:
model (nn.Module): The loaded classifier.
img (str or np.ndarray): Image filename or loaded image.
result (list): The classification result.
fig_size (tuple): Figure size of the pyplot figure.
"""
if hasattr(model, 'module'):
model = model.module
img = model.show_result(img, result, show=False)
plt.figure(figsize=fig_size)
plt.imshow(mmcv.bgr2rgb(img))
plt.show()
61 changes: 61 additions & 0 deletions mmcls/models/classifiers/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict

import cv2
import mmcv
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv import color_val
from mmcv.utils import print_log


Expand Down Expand Up @@ -155,3 +159,60 @@ def val_step(self, data, optimizer):
loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))

return outputs

def show_result(self,
img,
result,
text_color='green',
font_scale=0.5,
row_width=20,
show=False,
win_name='',
wait_time=0,
out_file=None):
"""Draw `result` over `img`.
Args:
img (str or Tensor): The image to be displayed.
result (Tensor): The classification results to draw over `img`.
text_color (str or tuple or :obj:`Color`): Color of texts.
font_scale (float): Font scales of texts.
row_width (int): width between each row of results on the image.
show (bool): Whether to show the image.
Default: False.
win_name (str): The window name.
wait_time (int): Value of waitKey param.
Default: 0.
out_file (str or None): The filename to write the image.
Default: None.
Returns:
img (Tensor): Only if not `show` or `out_file`
"""
img = mmcv.imread(img)
img = img.copy()

# write results on left-top of the image
x, y = 0, row_width
text_color = color_val(text_color)
for k, v in result.items():
if isinstance(v, float):
v = f'{v:.2f}'
label_text = f'{k}: {v}'
cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
font_scale, text_color)
y += row_width

# if out_file specified, do not show image in window
if out_file is not None:
show = False

if show:
mmcv.imshow(img, win_name, wait_time)
if out_file is not None:
mmcv.imwrite(img, out_file)

if not (show or out_file):
warnings.warn('show==False and out_file is not specified, only '
'result image will be returned')
return img
1 change: 1 addition & 0 deletions requirements/runtime.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
mmcv
numpy
matplotlib

0 comments on commit 9b425aa

Please sign in to comment.