Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Visualize results on image demo #58

Merged
merged 2 commits into from
Oct 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions demo/image_demo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from argparse import ArgumentParser

from mmcls.apis import inference_model, init_model
from mmcls.apis import inference_model, init_model, show_result_pyplot


def main():
Expand All @@ -16,8 +16,8 @@ def main():
model = init_model(args.config, args.checkpoint, device=args.device)
# test a single image
result = inference_model(model, args.img)
# print result on terminal
print(result)
# show the results
show_result_pyplot(model, args.img, result)


if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions mmcls/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .inference import inference_model, init_model
from .inference import inference_model, init_model, show_result_pyplot
from .test import multi_gpu_test, single_gpu_test
from .train import set_random_seed, train_model

__all__ = [
'set_random_seed', 'train_model', 'init_model', 'inference_model',
'multi_gpu_test', 'single_gpu_test'
'multi_gpu_test', 'single_gpu_test', 'show_result_pyplot'
]
22 changes: 20 additions & 2 deletions mmcls/apis/inference.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import warnings

import matplotlib.pyplot as plt
import mmcv
import numpy as np
import torch
Expand Down Expand Up @@ -74,6 +75,23 @@ def inference_model(model, img):
scores = model(return_loss=False, **data)
pred_score = np.max(scores, axis=1)[0]
pred_label = np.argmax(scores, axis=1)[0]
result = {'pred_label': pred_label, 'pred_score': pred_score}
result['class_name'] = model.CLASSES[result['pred_label']]
result = {'pred_label': pred_label, 'pred_score': float(pred_score)}
result['pred_class'] = model.CLASSES[result['pred_label']]
return result


def show_result_pyplot(model, img, result, fig_size=(15, 10)):
"""Visualize the classification results on the image.

Args:
model (nn.Module): The loaded classifier.
img (str or np.ndarray): Image filename or loaded image.
result (list): The classification result.
fig_size (tuple): Figure size of the pyplot figure.
"""
if hasattr(model, 'module'):
model = model.module
img = model.show_result(img, result, show=False)
plt.figure(figsize=fig_size)
plt.imshow(mmcv.bgr2rgb(img))
plt.show()
61 changes: 61 additions & 0 deletions mmcls/models/classifiers/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict

import cv2
import mmcv
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv import color_val
from mmcv.utils import print_log


Expand Down Expand Up @@ -155,3 +159,60 @@ def val_step(self, data, optimizer):
loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))

return outputs

def show_result(self,
img,
result,
text_color='green',
font_scale=0.5,
row_width=20,
show=False,
win_name='',
wait_time=0,
out_file=None):
"""Draw `result` over `img`.

Args:
img (str or Tensor): The image to be displayed.
result (Tensor): The classification results to draw over `img`.
text_color (str or tuple or :obj:`Color`): Color of texts.
font_scale (float): Font scales of texts.
row_width (int): width between each row of results on the image.
show (bool): Whether to show the image.
Default: False.
win_name (str): The window name.
wait_time (int): Value of waitKey param.
Default: 0.
out_file (str or None): The filename to write the image.
Default: None.

Returns:
img (Tensor): Only if not `show` or `out_file`
"""
img = mmcv.imread(img)
img = img.copy()

# write results on left-top of the image
x, y = 0, row_width
text_color = color_val(text_color)
for k, v in result.items():
if isinstance(v, float):
v = f'{v:.2f}'
label_text = f'{k}: {v}'
cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
font_scale, text_color)
y += row_width

# if out_file specified, do not show image in window
if out_file is not None:
show = False

if show:
mmcv.imshow(img, win_name, wait_time)
if out_file is not None:
mmcv.imwrite(img, out_file)

if not (show or out_file):
warnings.warn('show==False and out_file is not specified, only '
'result image will be returned')
return img
1 change: 1 addition & 0 deletions requirements/runtime.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
mmcv
numpy
matplotlib