In [1]:
# git clone 

In [2]:
import os
import warnings

import cv2
import mmcv

from mmpose.apis import (
    collect_multi_frames,
    inference_top_down_pose_model,
    init_pose_model,
    process_mmdet_results,
    vis_pose_result,
)
from mmpose.datasets import DatasetInfo

try:
    from mmdet.apis import inference_detector, init_detector

    has_mmdet = True
except (ImportError, ModuleNotFoundError):
    has_mmdet = False

  from .autonotebook import tqdm as notebook_tqdm


## Config

In [3]:
# set detection config
det_config = "../mmdetection_cfg/faster_rcnn_r50_fpn_coco.py"
det_checkpoint = "https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_2x_coco/faster_rcnn_r50_fpn_2x_coco_bbox_mAP-0.384_20200504_210434-a5d8aa15.pth"

# set pose config
pose_config = "../../configs/animal/2d_kpt_sview_rgb_img/topdown_heatmap/horse10/res50_horse10_256x256-split1.py"
pose_checkpoint = "https://download.openmmlab.com/mmpose/animal/resnet/res50_horse10_256x256_split1-3a3dc37e_20210405.pth"

In [4]:
# set data path
!mkdir data output
# video_path = "./data/clip.mp4"
video_path = "/home/yata/Videos/ooi/IMG_2437.MOV"
out_video_root = "./output/"

mkdir: cannot create directory ‘data’: File exists
mkdir: cannot create directory ‘output’: File exists


In [5]:
# set params
# bbox_thr = 0.1 # box thr
# kpt_thr = 0.4 # keypint thr
bbox_thr = 0.4 # box thr
kpt_thr = 0.4 # keypint thr
det_cat_id = 18
radius = 6
thickness = 1

show = False
use_multi_frames = False
online = False

## Build model

In [6]:
print("Initializing model...")
# build the detection model from a config file and a checkpoint file
det_model = init_detector(det_config, det_checkpoint, device="cuda:0")

Initializing model...
load checkpoint from http path: https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_2x_coco/faster_rcnn_r50_fpn_2x_coco_bbox_mAP-0.384_20200504_210434-a5d8aa15.pth


In [7]:
# build the pose model from a config file and a checkpoint file
pose_model = init_pose_model(pose_config, pose_checkpoint, device="cuda:0")

load checkpoint from http path: https://download.openmmlab.com/mmpose/animal/resnet/res50_horse10_256x256_split1-3a3dc37e_20210405.pth


## Load data

In [8]:
dataset = pose_model.cfg.data["test"]["type"]
print(dataset)

AnimalHorse10Dataset


In [9]:
# get datasetinfo
dataset_info = pose_model.cfg.data["test"].get("dataset_info", None)
if dataset_info is None:
    warnings.warn(
        "Please set `dataset_info` in the config."
        "Check https://github.com/open-mmlab/mmpose/pull/663 for details.",
        DeprecationWarning,
    )
else:
    dataset_info = DatasetInfo(dataset_info)

In [10]:
# read video
video = mmcv.VideoReader(video_path)
assert video.opened, f"Faild to load video file {video_path}"

if out_video_root == "":
    save_out_video = False
else:
    os.makedirs(out_video_root, exist_ok=True)
    save_out_video = True

if save_out_video:
    fps = video.fps
    size = (video.width, video.height)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    videoWriter = cv2.VideoWriter(
        os.path.join(out_video_root, f"vis_{os.path.basename(video_path)}"),
        fourcc,
        fps,
        size,
    )

In [11]:
# マルチフレームとは？
# frame index offsets for inference, used in multi-frame inference setting
if use_multi_frames:
    assert "frame_indices_test" in pose_model.cfg.data.test.data_cfg
    indices = pose_model.cfg.data.test.data_cfg["frame_indices_test"]

In [12]:
# heatmapとは？
# whether to return heatmap, optional
# return_heatmap = False
return_heatmap = True

# return the output of some desired layers,
# e.g. use ('backbone', ) to return backbone feature
output_layer_names = None

In [13]:
# obtain basic information
print(len(video))
print(video.width, video.height, video.resolution, video.fps)

# shape: 1フレーム width x height x 3

175
3840 2160 (3840, 2160) 30.0


In [14]:
key_points = []
print("Running inference...")
for frame_id, cur_frame in enumerate(mmcv.track_iter_progress(video)):
    print(frame_id)
    # get the detection results of current frame
    # the resulting box is (x1, y1, x2, y2)
    mmdet_results = inference_detector(det_model, cur_frame)

    # keep the person class bounding boxes.
    person_results = process_mmdet_results(mmdet_results, det_cat_id)

    if use_multi_frames:
        frames = collect_multi_frames(video, frame_id, indices, online)

    # test a single image, with a list of bboxes.
    pose_results, returned_outputs = inference_top_down_pose_model(
        pose_model,
        frames if use_multi_frames else cur_frame,
        person_results,
        bbox_thr=bbox_thr,
        format="xyxy",
        dataset=dataset,
        dataset_info=dataset_info,
        return_heatmap=return_heatmap,
        outputs=output_layer_names,
    )
    key_points.append(pose_results)

    # show the results
    vis_frame = vis_pose_result(
        pose_model,
        cur_frame,
        pose_results,
        dataset=dataset,
        dataset_info=dataset_info,
        kpt_score_thr=kpt_thr,
        radius=radius,
        thickness=thickness,
        show=False,
    )

    if show:
        cv2.imshow("Frame", vis_frame)

    if save_out_video:
        videoWriter.write(vis_frame)

    if show and cv2.waitKey(1) & 0xFF == ord("q"):
        break

if save_out_video:
    videoWriter.release()
if show:
    cv2.destroyAllWindows()

Running inference...
[                                                  ] 0/175, elapsed: 0s, ETA:0
[                                 ] 1/175, 0.8 task/s, elapsed: 1s, ETA:   206s1
[                                 ] 2/175, 1.5 task/s, elapsed: 1s, ETA:   119s2
[                                 ] 3/175, 1.9 task/s, elapsed: 2s, ETA:    90s3
[                                 ] 4/175, 2.3 task/s, elapsed: 2s, ETA:    75s4
[                                 ] 5/175, 2.5 task/s, elapsed: 2s, ETA:    67s5
[>                                ] 6/175, 2.8 task/s, elapsed: 2s, ETA:    61s6
[>                                ] 7/175, 2.9 task/s, elapsed: 2s, ETA:    57s7
[>                                ] 8/175, 3.1 task/s, elapsed: 3s, ETA:    55s8
[>                                ] 9/175, 3.2 task/s, elapsed: 3s, ETA:    52s9
[>                               ] 10/175, 3.3 task/s, elapsed: 3s, ETA:    50s10
[>>                              ] 11/175, 3.4 task/s, elapsed: 3s, ETA:    48s11
[>>    

In [17]:
import pickle

with open("./output/keypoint.pickle", mode='wb') as f:
    pickle.dump(key_points, f)
