Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] support pose tracking in webcam API #1393

Merged
merged 3 commits into from
May 25, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 12 additions & 1 deletion mmpose/apis/inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import warnings

Expand Down Expand Up @@ -109,7 +110,17 @@ def _inference_single_pose_model(model,
cfg.data.test.data_cfg.frame_weight_test)

# build the data pipeline
test_pipeline = Compose(cfg.test_pipeline)
_test_pipeline = copy.deepcopy(cfg.test_pipeline)

has_bbox_xywh2cs = False
for transform in _test_pipeline:
if transform['type'] == 'TopDownGetBboxCenterScale':
has_bbox_xywh2cs = True
break
if not has_bbox_xywh2cs:
_test_pipeline.insert(
0, dict(type='TopDownGetBboxCenterScale', padding=1.25))
test_pipeline = Compose(_test_pipeline)
_pipeline_gpu_speedup(test_pipeline, next(model.parameters()).device)

assert len(bboxes[0]) in [4, 5]
Expand Down
97 changes: 97 additions & 0 deletions tools/webcam/configs/examples/pose_tracking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Webcam runner configuration for the pose-tracking example.
#
# Data flow between nodes (buffer names as configured below):
#   _input_ -> PoseTracker -> human_pose
#   _frame_ + human_pose -> ResultBinder -> frame
#   frame -> Visualizer -> vis -> Helper -> vis_notice
#   vis_notice -> Monitor -> display -> Recorder -> _display_
runner = dict(
    # Basic configurations of the runner
    name='Pose Estimation',
    camera_id=0,
    camera_fps=20,
    synchronous=False,
    # Define nodes.
    # The configuration of a node usually includes:
    #   1. 'type': Node class name
    #   2. 'name': Node name
    #   3. I/O buffers (e.g. 'input_buffer', 'output_buffer'): specify the
    #       input and output buffer names. This may depend on the node class.
    #   4. 'enable_key': assign a hot-key to toggle enable/disable this node.
    #       This may depend on the node class.
    #   5. Other class-specific arguments
    nodes=[
        # 'PoseTrackerNode':
        # This node is configured with both a detection model and a pose
        # model. It reads from the runner-reserved input buffer and writes
        # its result to 'human_pose'.
        # NOTE(review): `det_interval` presumably sets how often (in frames)
        # the detector is run - confirm against PoseTrackerNode.
        dict(
            type='PoseTrackerNode',
            name='PoseTracker',
            det_model_config='demo/mmdetection_cfg/'
            'ssdlite_mobilenetv2_scratch_600e_coco.py',
            det_model_checkpoint='https://download.openmmlab.com'
            '/mmdetection/v2.0/ssd/'
            'ssdlite_mobilenetv2_scratch_600e_coco/ssdlite_mobilenetv2_'
            'scratch_600e_coco_20210629_110627-974d9307.pth',
            pose_model_config='configs/wholebody/2d_kpt_sview_rgb_img/'
            'topdown_heatmap/coco-wholebody/'
            'vipnas_mbv3_coco_wholebody_256x192_dark.py',
            pose_model_checkpoint='https://download.openmmlab.com/mmpose/'
            'top_down/vipnas/vipnas_mbv3_coco_wholebody_256x192_dark'
            '-e2158108_20211205.pth',
            det_interval=10,
            cls_names=['person'],
            smooth=True,
            device='cuda:0',
            input_buffer='_input_',  # `_input_` is a runner-reserved buffer
            output_buffer='human_pose'),
        # 'ModelResultBindingNode':
        # This node binds the latest model inference result with the current
        # frame. (This means the frame image and inference result may be
        # asynchronous).
        dict(
            type='ModelResultBindingNode',
            name='ResultBinder',
            frame_buffer='_frame_',  # `_frame_` is a runner-reserved buffer
            result_buffer='human_pose',
            output_buffer='frame'),
        # 'PoseVisualizerNode':
        # This node draws the pose visualization result in the frame image.
        # Pose results are needed.
        dict(
            type='PoseVisualizerNode',
            name='Visualizer',
            enable_key='v',
            frame_buffer='frame',
            output_buffer='vis'),
        # 'NoticeBoardNode':
        # This node shows a notice board with given content, e.g. help
        # information.
        # Only hot-keys that are actually bound by a node in this config
        # (plus "q" to exit) are listed, so the on-screen help stays
        # accurate.
        dict(
            type='NoticeBoardNode',
            name='Helper',
            enable_key='h',
            enable=True,
            frame_buffer='vis',
            output_buffer='vis_notice',
            content_lines=[
                'This is a demo for pose tracking and visualization. '
                'Have fun!', '', 'Hot-keys:',
                '"v": Pose estimation result visualization',
                '"h": Show help information',
                '"m": Show diagnostic information', '"q": Exit'
            ],
        ),
        # 'MonitorNode':
        # This node shows diagnostic information in the frame image. It can
        # be used for debugging or monitoring system resource status.
        dict(
            type='MonitorNode',
            name='Monitor',
            enable_key='m',
            enable=False,
            frame_buffer='vis_notice',
            output_buffer='display'),
        # 'RecorderNode':
        # This node saves the output video into a file.
        dict(
            type='RecorderNode',
            name='Recorder',
            out_video_file='record.mp4',
            frame_buffer='display',
            output_buffer='_display_'
            # `_display_` is a runner-reserved buffer
        )
    ])
4 changes: 3 additions & 1 deletion tools/webcam/webcam_apis/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
from .helper_node import ModelResultBindingNode, MonitorNode, RecorderNode
from .mmdet_node import DetectorNode
from .mmpose_node import TopDownPoseEstimatorNode
from .pose_tracker_node import PoseTrackerNode
from .valentinemagic_node import ValentineMagicNode
from .xdwendwen_node import XDwenDwenNode

__all__ = [
'NODES', 'PoseVisualizerNode', 'DetectorNode', 'TopDownPoseEstimatorNode',
'MonitorNode', 'BugEyeNode', 'SunglassesNode', 'ModelResultBindingNode',
'NoticeBoardNode', 'RecorderNode', 'FaceSwapNode', 'MoustacheNode',
'SaiyanNode', 'BackgroundNode', 'XDwenDwenNode', 'ValentineMagicNode'
'SaiyanNode', 'BackgroundNode', 'XDwenDwenNode', 'ValentineMagicNode',
'PoseTrackerNode'
]
2 changes: 1 addition & 1 deletion tools/webcam/webcam_apis/nodes/frame_drawing_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self,
super().__init__(name=name, enable_key=enable_key)

# Register buffers
self.register_input_buffer(frame_buffer, 'frame', essential=True)
self.register_input_buffer(frame_buffer, 'frame', trigger=True)
self.register_output_buffer(output_buffer)

self._enabled = enable
Expand Down
22 changes: 11 additions & 11 deletions tools/webcam/webcam_apis/nodes/helper_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def __init__(self, name: str, frame_buffer: str, result_buffer: str,
self.result_lag = RunningAverage(window=10)

# Register buffers
# Note that essential buffers will be set in set_runner() because
# it depends on the runner.synchronous attribute.
self.register_input_buffer(result_buffer, 'result', essential=False)
self.register_input_buffer(frame_buffer, 'frame', essential=False)
# The trigger buffer depends on the runner.synchronous attribute, thus
# it will be set later in ``set_runner``.
self.register_input_buffer(result_buffer, 'result', trigger=False)
self.register_input_buffer(frame_buffer, 'frame', trigger=False)
self.register_output_buffer(output_buffer)

def set_runner(self, runner):
Expand All @@ -51,15 +51,15 @@ def set_runner(self, runner):
# Set synchronous according to the runner
if runner.synchronous:
self.synchronous = True
essential_input = 'result'
trigger = 'result'
else:
self.synchronous = False
essential_input = 'frame'
trigger = 'frame'

# Set essential input buffer according to the synchronous setting
# Set trigger input buffer according to the synchronous setting
for buffer_info in self._input_buffers:
if buffer_info.input_name == essential_input:
buffer_info.essential = True
if buffer_info.input_name == trigger:
buffer_info.trigger = True

def process(self, input_msgs):
result_msg = input_msgs['result']
Expand Down Expand Up @@ -146,7 +146,7 @@ def __init__(self,
else:
self.ignore_items = ignore_items

self.register_input_buffer(frame_buffer, 'frame', essential=True)
self.register_input_buffer(frame_buffer, 'frame', trigger=True)
self.register_output_buffer(output_buffer)

def process(self, input_msgs):
Expand Down Expand Up @@ -234,7 +234,7 @@ def __init__(
self.vwriter = None

# Register buffers
self.register_input_buffer(frame_buffer, 'frame', essential=True)
self.register_input_buffer(frame_buffer, 'frame', trigger=True)
self.register_output_buffer(output_buffer)

# Start a new thread to write frame
Expand Down
13 changes: 7 additions & 6 deletions tools/webcam/webcam_apis/nodes/mmdet_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,24 @@ def __init__(self,
input_buffer: str,
output_buffer: Union[str, List[str]],
enable_key: Optional[Union[str, int]] = None,
enable: bool = True,
device: str = 'cuda:0'):
# Check mmdetection is installed
assert has_mmdet, 'Please install mmdet to run the demo.'
super().__init__(name=name, enable_key=enable_key, enable=True)
assert has_mmdet, \
f'MMDetection is required for {self.__class__.__name__}.'

super().__init__(name=name, enable_key=enable_key, enable=enable)

self.model_config = model_config
self.model_checkpoint = model_checkpoint
self.device = device.lower()

# Init model
self.model = init_detector(
self.model_config,
self.model_checkpoint,
device=self.device.lower())
self.model_config, self.model_checkpoint, device=self.device)

# Register buffers
self.register_input_buffer(input_buffer, 'input', essential=True)
self.register_input_buffer(input_buffer, 'input', trigger=True)
self.register_output_buffer(output_buffer)

def bypass(self, input_msgs):
Expand Down
11 changes: 4 additions & 7 deletions tools/webcam/webcam_apis/nodes/mmpose_node.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from typing import List, Optional, Union

from mmpose.apis import (get_track_id, inference_top_down_pose_model,
init_pose_model)
from mmpose.core import Smoother
from ..utils import Message
from .builder import NODES
from .node import Node

Expand Down Expand Up @@ -51,21 +50,19 @@ def __init__(
self.smoother = None
# Init model
self.model = init_pose_model(
self.model_config,
self.model_checkpoint,
device=self.device.lower())
self.model_config, self.model_checkpoint, device=self.device)

# Store history for pose tracking
self.track_info = TrackInfo()

# Register buffers
self.register_input_buffer(input_buffer, 'input', essential=True)
self.register_input_buffer(input_buffer, 'input', trigger=True)
self.register_output_buffer(output_buffer)

def bypass(self, input_msgs):
return input_msgs['input']

def process(self, input_msgs: Dict[str, Message]) -> Message:
def process(self, input_msgs):

input_msg = input_msgs['input']
img = input_msg.get_image()
Expand Down
16 changes: 8 additions & 8 deletions tools/webcam/webcam_apis/nodes/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class BufferInfo():
"""Dataclass for buffer information."""
buffer_name: str
input_name: Optional[str] = None
essential: bool = False
trigger: bool = False


@dataclass
Expand Down Expand Up @@ -119,7 +119,7 @@ def _toggle_enable(self):
def register_input_buffer(self,
buffer_name: str,
input_name: str,
essential: bool = False):
trigger: bool = False):
"""Register an input buffer, so that Node can automatically check if
data is ready, fetch data from the buffers and format the inputs to
feed into `process` method.
Expand All @@ -134,12 +134,12 @@ def register_input_buffer(self,
buffer_name (str): The name of the buffer
input_name (str): The name of the fetched message from the
corresponding buffer
essential (bool): An essential input means the node will wait
trigger (bool): A trigger input means the node will wait
until the input is ready before processing. Otherwise, an
inessential input will not block the processing, instead
a None will be fetched if the buffer is not ready.
"""
buffer_info = BufferInfo(buffer_name, input_name, essential)
buffer_info = BufferInfo(buffer_name, input_name, trigger)
self._input_buffers.append(buffer_info)

def register_output_buffer(self, buffer_name: Union[str, List[str]]):
Expand Down Expand Up @@ -213,9 +213,9 @@ def _get_input_from_buffer(self) -> Tuple[bool, Optional[Dict]]:
if buffer_manager is None:
raise ValueError(f'{self.name}: Runner not set!')

# Check that essential buffers are ready
# Check that trigger buffers are ready
for buffer_info in self._input_buffers:
if buffer_info.essential and buffer_manager.is_empty(
if buffer_info.trigger and buffer_manager.is_empty(
buffer_info.buffer_name):
return False, None

Expand All @@ -230,9 +230,9 @@ def _get_input_from_buffer(self) -> Tuple[bool, Optional[Dict]]:
result[buffer_info.input_name] = buffer_manager.get(
buffer_info.buffer_name, block=False)
except Empty:
if buffer_info.essential:
if buffer_info.trigger:
# Return unsuccessful flag if any
# essential input is unready
# trigger input is unready
return False, None

return True, result
Expand Down