Skip to content

Commit

Permalink
build dataset from video
Browse files Browse the repository at this point in the history
  • Loading branch information
yysijie committed Oct 3, 2019
1 parent 34980a5 commit 670a454
Show file tree
Hide file tree
Showing 13 changed files with 236 additions and 14 deletions.
2 changes: 1 addition & 1 deletion configs/pose_estimation/pose_demo.yaml
Expand Up @@ -2,7 +2,7 @@ processor_cfg:
name: ".processor.pose_demo.inference"
gpus: 1
worker_per_gpu: 2
video_file: resource/media/skateboarding.mp4
video_file: resource/data_example/skateboarding.mp4
save_dir: "work_dir/pose_demo"

detection_cfg:
Expand Down
2 changes: 1 addition & 1 deletion configs/pose_estimation/pose_demo_HD.yaml
Expand Up @@ -2,7 +2,7 @@ processor_cfg:
name: ".processor.pose_demo.inference"
gpus: 1
worker_per_gpu: 1
video_file: resource/media/skateboarding.mp4
video_file: resource/data_example/skateboarding.mp4
save_dir: "work_dir/pose_demo_HD"

detection_cfg:
Expand Down
14 changes: 13 additions & 1 deletion configs/recognition/recognition_demo.yaml
Expand Up @@ -2,13 +2,14 @@ processor_cfg:
name: ".processor.recognition_demo.inference"
gpus: 1
worker_per_gpu: 2
video_file: resource/media/skateboarding.mp4
video_file: resource/data_example/skateboarding.mp4
save_dir: "work_dir/recognition_demo"

detection_cfg:
model_cfg: configs/mmdet/cascade_rcnn_r50_fpn_1x.py
checkpoint_file: mmskeleton://mmdet/cascade_rcnn_r50_fpn_20e
bbox_thre: 0.8

estimation_cfg:
model_cfg: configs/pose_estimation/hrnet/pose_hrnet_w32_256x192_test.yaml
checkpoint_file: mmskeleton://pose_estimation/pose_hrnet_w32_256x192
Expand All @@ -27,6 +28,17 @@ processor_cfg:
- 0.225
post_process: true

recognition_cfg:
checkpoint_file: mmskeleton://st_gcn/kinetics-skeleton
model_cfg:
name: ".models.backbones.ST_GCN"
in_channels: 3
num_class: 400
edge_importance_weighting: True
graph_cfg:
layout: "openpose"
strategy: "spatial"

argparse_cfg:
gpus:
bind_to: processor_cfg.gpus
Expand Down
48 changes: 48 additions & 0 deletions configs/utils/build_skeleton_dataset.yaml
@@ -0,0 +1,48 @@
# Configuration for building a skeleton dataset from a folder of videos.
# The keys under processor_cfg (detection_cfg, estimation_cfg, video_dir,
# out_dir, gpus, worker_per_gpu, category_annotation) match the parameters of
# `build` in mmskeleton/processor/skeleton_dataset.py.
processor_cfg:
# Processor entry point to invoke.
name: ".processor.skeleton_dataset.build"
# gpus * worker_per_gpu worker processes are started by `build`.
gpus: 1
worker_per_gpu: 2
# Folder whose entries are read as input videos.
video_dir: resource/data_example
# Folder that receives one "<video file name>.json" per video.
out_dir: "data/dataset_example"
# JSON file whose "annotations" mapping gives a category_id per video file
# name; videos missing from it are written with category_id -1.
category_annotation: resource/category_annotation_example.json

# Person detector settings.
detection_cfg:
model_cfg: configs/mmdet/cascade_rcnn_r50_fpn_1x.py
checkpoint_file: mmskeleton://mmdet/cascade_rcnn_r50_fpn_20e
# NOTE(review): presumably the minimum detection score for a person box
# to be kept -- confirm against the estimation API.
bbox_thre: 0.8
# Pose estimator settings.
estimation_cfg:
model_cfg: configs/pose_estimation/hrnet/pose_hrnet_w32_256x192_test.yaml
checkpoint_file: mmskeleton://pose_estimation/pose_hrnet_w32_256x192
data_cfg:
# NOTE(review): presumably (width, height) of the estimator input, given
# the "256x192" model name -- confirm.
image_size:
- 192
- 256
pixel_std: 200
# NOTE(review): presumably per-channel normalization statistics -- confirm.
image_mean:
- 0.485
- 0.456
- 0.406
image_std:
- 0.229
- 0.224
- 0.225
post_process: true

# Command-line options. Each entry's `bind_to` names the config field the
# option is bound to; `help`, where given, is the option's help text.
argparse_cfg:
gpus:
bind_to: processor_cfg.gpus
help: number of gpus
video_dir:
bind_to: processor_cfg.video_dir
help: folder for videos
worker_per_gpu:
bind_to: processor_cfg.worker_per_gpu
help: number of workers for each gpu
skeleton_model:
bind_to: processor_cfg.estimation_cfg.model_cfg
skeleton_checkpoint:
bind_to: processor_cfg.estimation_cfg.checkpoint_file
detection_model:
bind_to: processor_cfg.detection_cfg.model_cfg
detection_checkpoint:
bind_to: processor_cfg.detection_cfg.checkpoint_file
4 changes: 2 additions & 2 deletions mmskeleton/apis/estimation.py
Expand Up @@ -50,8 +50,8 @@ def inference_pose_estimator(pose_estimator, image):
has_return = False
preds, maxvals, meta = None, None, None

result = dict(position_preds=preds,
position_maxvals=maxvals,
result = dict(joint_preds=preds,
joint_scores=maxvals,
meta=meta,
has_return=has_return,
person_bbox=person_bbox)
Expand Down
5 changes: 3 additions & 2 deletions mmskeleton/processor/pose_demo.py
Expand Up @@ -42,6 +42,7 @@ def render(image, pred, person_bbox, bbox_thre):

def worker(inputs, results, gpu, detection_cfg, estimation_cfg, render_image):
worker_id = current_process()._identity[0] - 1
global pose_estimators
if worker_id not in pose_estimators:
pose_estimators[worker_id] = init_pose_estimator(detection_cfg,
estimation_cfg,
Expand All @@ -56,7 +57,7 @@ def worker(inputs, results, gpu, detection_cfg, estimation_cfg, render_image):
res['frame_index'] = idx

if render_image:
res['render_image'] = render(image, res['position_preds'],
res['render_image'] = render(image, res['joint_preds'],
res['person_bbox'],
detection_cfg.bbox_thre)
results.put(res)
Expand All @@ -81,7 +82,7 @@ def inference(detection_cfg,
res = inference_pose_estimator(model, image)
res['frame_index'] = i
if save_dir is not None:
res['render_image'] = render(image, res['position_preds'],
res['render_image'] = render(image, res['joint_preds'],
res['person_bbox'],
detection_cfg.bbox_thre)
all_result.append(res)
Expand Down
4 changes: 0 additions & 4 deletions mmskeleton/processor/recognition.py
Expand Up @@ -136,7 +136,3 @@ def weights_init(model):
elif classname.find('BatchNorm') != -1:
model.weight.data.normal_(1.0, 0.02)
model.bias.data.fill_(0)


def demo(pose_estimation_cfg):
pass
30 changes: 27 additions & 3 deletions mmskeleton/processor/recognition_demo.py
Expand Up @@ -7,16 +7,40 @@
from time import time
from mmcv.utils import ProgressBar
from .pose_demo import inference as pose_inference
from mmskeleton.utils import call_obj, load_checkpoint


def init_recognizer(recognition_cfg, device):
    """Build the recognition model and load its checkpoint into it.

    Args:
        recognition_cfg: Config object exposing ``model_cfg`` (keyword
            arguments forwarded to ``call_obj``) and ``checkpoint_file``.
        device: Value forwarded to ``load_checkpoint`` as ``map_location``.

    Returns:
        The model object created by ``call_obj``.
    """
    recognizer = call_obj(**recognition_cfg.model_cfg)
    checkpoint = recognition_cfg.checkpoint_file
    load_checkpoint(recognizer, checkpoint, map_location=device)
    return recognizer


def inference(detection_cfg,
              estimation_cfg,
              recognition_cfg,
              video_file,
              gpus=1,
              worker_per_gpu=1,
              save_dir=None):
    """Run pose estimation on a video and assemble a skeleton sequence.

    Args:
        detection_cfg: Person detector config, forwarded to the pose demo.
        estimation_cfg: Pose estimator config, forwarded to the pose demo.
        recognition_cfg: Recognition config passed to ``init_recognizer``.
        video_file: Path of the video to process.
        gpus: Number of GPUs, forwarded to the pose demo.
        worker_per_gpu: Workers per GPU, forwarded to the pose demo.
        save_dir: Accepted for interface compatibility; not used here.

    Returns:
        The per-frame result list produced by the pose demo ``inference``.
    """
    recognizer = init_recognizer(recognition_cfg, 0)
    resolution = mmcv.VideoReader(video_file).resolution

    # Pose estimation is the expensive step, so it is run exactly once.
    results = pose_inference(detection_cfg, estimation_cfg, video_file, gpus,
                             worker_per_gpu)

    # Sequence buffer indexed as (1, channel, frame, joint, 1) with channels
    # x, y and score. Coordinates are divided by the video resolution; only
    # the first detected person (index 0) of each frame is used, and frames
    # without predictions stay zero.
    seq = np.zeros((1, 3, len(results), 17, 1))
    for i, r in enumerate(results):
        if r['joint_preds'] is not None:
            seq[0, 0, i, :, 0] = r['joint_preds'][0, :, 0] / resolution[0]
            seq[0, 1, i, :, 0] = r['joint_preds'][0, :, 1] / resolution[1]
            seq[0, 2, i, :, 0] = r['joint_scores'][0, :, 0]

    # TODO: `recognizer` is not applied to `seq` yet; only the pose results
    # are returned.
    return results
123 changes: 123 additions & 0 deletions mmskeleton/processor/skeleton_dataset.py
@@ -0,0 +1,123 @@
import os
import json
import mmcv
import numpy as np
import ntpath
from mmskeleton.apis.estimation import init_pose_estimator, inference_pose_estimator
from multiprocessing import current_process, Process, Manager
from mmskeleton.utils import cache_checkpoint
from mmskeleton.processor.apis import save_batch_image_with_joints
from mmcv.utils import ProgressBar

pose_estimators = dict()


def worker(inputs, results, gpu, detection_cfg, estimation_cfg):
    """Serve pose-estimation requests from a queue until told to stop.

    Args:
        inputs: Queue yielding ``(frame_index, image)`` pairs; an item whose
            image is ``None`` is the end signal.
        results: Queue receiving one result dict per processed frame, with
            the frame index stored under ``'frame_index'``.
        gpu: Device passed to ``init_pose_estimator``.
        detection_cfg: Person detector config.
        estimation_cfg: Pose estimator config.
    """
    # The process identity, shifted down by one, keys the estimator cache so
    # the estimator is initialized at most once per key.
    slot = current_process()._identity[0] - 1
    if slot not in pose_estimators:
        pose_estimators[slot] = init_pose_estimator(detection_cfg,
                                                    estimation_cfg,
                                                    device=gpu)
    estimator = pose_estimators[slot]

    while True:
        frame_index, frame = inputs.get()
        if frame is None:
            # end signal
            break

        output = inference_pose_estimator(estimator, frame)
        output['frame_index'] = frame_index
        results.put(output)


def _flatten_keypoints(joint_preds, joint_scores, score_decimals=3):
    """Flatten one person's joints into ``[x1, y1, s1, x2, y2, s2, ...]``.

    Coordinates are rounded to integer pixels. Scores are kept as floats
    (rounded to ``score_decimals`` places) so the confidence is preserved
    instead of collapsing to 0 or 1.
    """
    coords = np.asarray(joint_preds).round().astype(int).tolist()
    scores = np.asarray(joint_scores, dtype=float)
    if scores.ndim > 1:
        # One score column per joint, e.g. shape (num_joints, 1).
        scores = scores[:, 0]

    flat = []
    for point, score in zip(coords, scores.tolist()):
        flat.extend((point[0], point[1], round(score, score_decimals)))
    return flat


def build(detection_cfg,
          estimation_cfg,
          video_dir,
          out_dir,
          gpus=1,
          worker_per_gpu=1,
          video_max_length=10000,
          category_annotation=None):
    """Extract per-frame skeletons from every video in a folder.

    For each entry of ``video_dir`` the frames are sent to the worker
    processes, the returned poses are collected, and one
    ``<video file name>.json`` file is written to ``out_dir``.

    Args:
        detection_cfg: Person detector config (needs ``checkpoint_file``).
        estimation_cfg: Pose estimator config (needs ``checkpoint_file``).
        video_dir: Folder whose entries are read as videos.
        out_dir: Output folder; created if it does not exist.
        gpus: Number of GPUs; worker ``i`` is given device ``i % gpus``.
        worker_per_gpu: Number of worker processes per GPU.
        video_max_length: Maximum number of frames read per video; also the
            capacity of the input and result queues.
        category_annotation: Optional path to a JSON file whose
            ``'annotations'`` mapping gives a ``category_id`` per video file
            name. Videos missing from it get ``category_id`` -1.

    Returns:
        The info dict written for the last processed video, or ``None`` when
        ``video_dir`` has no entries.
    """
    cache_checkpoint(detection_cfg.checkpoint_file)
    cache_checkpoint(estimation_cfg.checkpoint_file)
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    if category_annotation is None:
        video_categories = dict()
    else:
        with open(category_annotation) as f:
            video_categories = json.load(f)['annotations']

    inputs = Manager().Queue(video_max_length)
    results = Manager().Queue(video_max_length)

    num_worker = gpus * worker_per_gpu
    procs = []
    for i in range(num_worker):
        p = Process(target=worker,
                    args=(inputs, results, i % gpus, detection_cfg,
                          estimation_cfg))
        procs.append(p)
        p.start()

    # Stays None when the folder is empty so the final return is always valid.
    video_info = None
    try:
        video_file_list = os.listdir(video_dir)
        prog_bar = ProgressBar(len(video_file_list))
        for video_file in video_file_list:

            reader = mmcv.VideoReader(os.path.join(video_dir, video_file))
            video_frames = reader[:video_max_length]
            annotations = []

            for i, image in enumerate(video_frames):
                inputs.put((i, image))

            # Exactly one result comes back per submitted frame, in
            # arbitrary order.
            for _ in video_frames:
                t = results.get()
                if not t['has_return']:
                    continue

                num_person = len(t['joint_preds'])
                assert len(t['person_bbox']) == num_person

                for j in range(num_person):
                    keypoints = _flatten_keypoints(t['joint_preds'][j],
                                                   t['joint_scores'][j])

                    person_info = dict(
                        person_bbox=t['person_bbox']
                        [j].round().astype(int).tolist(),
                        frame_index=t['frame_index'],
                        id=j,
                        person_id=None,
                        keypoints=keypoints,
                        num_keypoints=t['joint_scores'].shape[1])

                    annotations.append(person_info)

            # output results
            annotations = sorted(annotations, key=lambda x: x['frame_index'])
            category_id = video_categories[video_file][
                'category_id'] if video_file in video_categories else -1
            info = dict(video_name=video_file,
                        resolution=reader.resolution,
                        version='1.0')
            video_info = dict(info=info,
                              category_id=category_id,
                              annotations=annotations)
            with open(os.path.join(out_dir, video_file + '.json'), 'w') as f:
                json.dump(video_info, f)

            prog_bar.update()
    finally:
        # Always release the workers, even if a video failed, so they do not
        # stay blocked on the input queue.
        # send end signals
        for _ in procs:
            inputs.put((-1, None))
        # wait to finish
        for p in procs:
            p.join()

    print('\nBuild skeleton dataset to {}.'.format(out_dir))
    return video_info
18 changes: 18 additions & 0 deletions resource/category_annotation_example.json
@@ -0,0 +1,18 @@
{
"categories": [
"skateboarding",
"clean_and_jerk",
"ta_chi"
],
"annotations": {
"clean_and_jerk.mp4": {
"category_id": 1
},
"skateboarding.mp4": {
"category_id": 0
},
"ta_chi.mp4": {
"category_id": 2
}
}
}
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 comments on commit 670a454

Please sign in to comment.