Commit c4675a1: Add Detectron2 support for inference in the wild
dariopavllo committed Jul 31, 2020
1 parent af776fb
Showing 2 changed files with 178 additions and 3 deletions.
29 changes: 26 additions & 3 deletions INFERENCE.md
@@ -1,14 +1,17 @@
# Inference in the wild

**Update:** we have added support for Detectron2.

In this short tutorial, we show how to run our model on arbitrary videos and visualize the predictions. Note that this feature is provided for experimentation/research purposes only and has some limitations, as this repository is meant to provide a reference implementation of the approach described in the paper (not production-ready code for inference in the wild).

Our script assumes that a video depicts *exactly* one person. If multiple people are visible at once, the script selects the person whose bounding box has the highest confidence, which may cause glitches.
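This selection is performed later in the pipeline, when the exported detections are converted into the custom 2D dataset. The snippet below is only an illustration of the rule, assuming `boxes` is an (N, 5) array of `[x1, y1, x2, y2, score]` rows like the ones exported by the inference script further down:
```
import numpy as np

def best_person(boxes):
    # Pick the detection with the highest confidence score.
    # `boxes` is an (N, 5) array of [x1, y1, x2, y2, score] rows;
    # returns None if the frame contains no detections.
    if len(boxes) == 0:
        return None
    return boxes[np.argmax(boxes[:, 4])]

# Example: two people detected in one frame; the second has higher confidence.
frame_boxes = np.array([
    [ 10.0,  20.0, 110.0, 320.0, 0.83],
    [200.0,  15.0, 310.0, 330.0, 0.97],
])
print(best_person(frame_boxes))  # prints the box with score 0.97
```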

The instructions below show how to use Detectron to infer 2D keypoints from videos, convert them to a custom dataset for our code, and infer 3D poses. For now, we do not have instructions for CPN. In the last section of this tutorial, we also provide some tips.

## Step 1: setup
The inference script requires `ffmpeg`, which you can easily install via conda, pip, or manually.
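If you are unsure whether your environment can see the binaries, a quick check such as the following (a minimal sketch, not part of the repository) confirms that both `ffmpeg` and `ffprobe` are on the PATH before you start:
```
import shutil
import subprocess

# Both binaries are needed: the inference script shells out to ffmpeg
# (frame decoding) and ffprobe (resolution probing).
for tool in ('ffmpeg', 'ffprobe'):
    if shutil.which(tool) is None:
        raise RuntimeError('{} not found on PATH'.format(tool))
    result = subprocess.run([tool, '-version'], capture_output=True, text=True)
    print(result.stdout.splitlines()[0])
```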

Download the [pretrained model](https://dl.fbaipublicfiles.com/video-pose-3d/pretrained_h36m_detectron_coco.bin) for generating 3D predictions. This model is different from the pretrained ones listed in the main README, as it expects input keypoints in COCO format (generated by the pretrained Detectron model) and outputs 3D joint positions in Human3.6M format. Put this model in the `checkpoint` directory of this repo.

**Note:** if you had downloaded `d-pt-243.bin`, you should download the new pretrained model using the link above. `d-pt-243.bin` takes the keypoint probabilities as input (in addition to the x, y coordinates), which causes problems on videos with a different resolution than that of Human3.6M. The new model is only trained on 2D coordinates and works with any resolution/aspect ratio.
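The resolution independence comes from normalizing pixel coordinates before they are fed to the model. The sketch below shows one such normalization, which maps x to [-1, 1] and scales y by the same factor so that the aspect ratio is preserved; it mirrors the convention used by this repository's data preparation, but treat the exact form as an assumption and check `common/camera.py` for the authoritative code.
```
import numpy as np

def normalize_screen_coordinates(X, w, h):
    # Map pixel coordinates to a resolution-independent range:
    # x ends up in [-1, 1], y is scaled by the same factor (aspect-ratio aware).
    assert X.shape[-1] == 2
    return X / w * 2 - np.array([1, h / w])

# 2D keypoints for one frame of a 1000x1002 video (17 COCO joints).
keypoints = np.random.uniform([0, 0], [1000, 1002], size=(17, 2))
print(normalize_screen_coordinates(keypoints, w=1000, h=1002)[:3])
```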

@@ -25,6 +28,26 @@ ffmpeg -i input.mp4 -filter "minterpolate='fps=50'" -crf 0 output.mp4
```
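The command shown in the diff context above re-encodes a video at 50 fps using motion interpolation. If you prefer to drive this step from Python rather than the shell, a thin wrapper around the same command could look like the sketch below (the file names are placeholders):
```
import subprocess

def interpolate_to_50fps(src, dst):
    # Shell out to the same ffmpeg command shown above; `src` and `dst`
    # are placeholder paths for the input and output videos.
    command = [
        'ffmpeg', '-i', src,
        '-filter', "minterpolate='fps=50'",
        '-crf', '0',
        dst,
    ]
    subprocess.run(command, check=True)

interpolate_to_50fps('input.mp4', 'output.mp4')
```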

## Step 3: inferring 2D keypoints with Detectron

### Using Detectron2 (new)
Set up [Detectron2](https://github.com/facebookresearch/detectron2) and use the script `inference/infer_video_d2.py` (no need to copy this, as it directly uses the Detectron2 API). This script provides a convenient interface to generate 2D keypoint predictions from videos without manually extracting individual frames.

To infer keypoints from all the mp4 videos in `input_directory`, run
```
cd inference
python infer_video_d2.py \
--cfg COCO-Keypoints/keypoint_rcnn_R_101_FPN_3x.yaml \
--output-dir output_directory \
--image-ext mp4 \
input_directory
```
The results will be exported to `output_directory` as custom NumPy archives (`.npz` files). You can change the video extension in `--image-ext` (ffmpeg supports a wide range of formats).
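To sanity-check an export, you can open one of the archives and inspect its contents. The keys below match what `infer_video_d2.py` saves; the path is a placeholder (one archive is written per input video, named `<video>.<ext>.npz`):
```
import numpy as np

# Placeholder path: adjust to one of the archives in your output directory.
data = np.load('output_directory/video.mp4.npz', allow_pickle=True)

print(sorted(data.files))       # ['boxes', 'keypoints', 'metadata', 'segments']
print(data['metadata'].item())  # {'w': ..., 'h': ...}, the original video resolution

boxes = data['boxes']           # one Detectron1-style entry per frame
keypoints = data['keypoints']   # per frame: [[], kps] with kps shaped (num_people, 4, 17)
print(len(boxes), 'frames')
```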

**Note:** although the architecture is the same (ResNet-101), the weights used by the Detectron2 model are not the same as those used by Detectron1. Since our pretrained model was trained on Detectron1 poses, the result might be slightly different (but it should still be pretty close).

### Using Detectron1 (old instructions)
Set up [Detectron](https://github.com/facebookresearch/Detectron) and copy the script `inference/infer_video.py` from this repo to the `tools` directory of the Detectron repo. This script provides a convenient interface to generate 2D keypoint predictions from videos without manually extracting individual frames.

Our Detectron script `infer_video.py` is a simple adaptation of `infer_simple.py` (which works on images) and has a similar command-line syntax.

To infer keypoints from all the mp4 videos in `input_directory`, run
@@ -57,6 +80,6 @@ You can also export the 3D joint positions (in camera space) to a NumPy archive.

## Limitations and tips
- The model was trained on Human3.6M cameras (which are relatively undistorted), and the results may degrade if the intrinsic parameters of your cameras differ significantly from those of Human3.6M. This may be particularly noticeable with fisheye cameras, which exhibit a high degree of non-linear lens distortion. If the camera parameters are known, consider preprocessing your videos to match those of Human3.6M as closely as possible.
- If you want multi-person tracking, you should implement a bounding box matching strategy. An example would be to use bipartite matching on the bounding box overlap (IoU) between subsequent frames (see the sketch after this list), but there are many other approaches.
- Predictions are relative to the root joint, i.e. the global trajectory is not regressed. If you need it, you may want to use another model to regress it, such as the one we use for semi-supervision.
- Predictions are always in *camera space* (regardless of whether the trajectory is available). For our visualization script, we simply take a random camera from Human3.6M, which fits most videos reasonably well as long as the camera viewport is parallel to the ground.
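As a starting point for the matching strategy mentioned above, here is a minimal sketch that associates boxes between two consecutive frames using IoU and the Hungarian algorithm (scipy is assumed; this is purely illustrative and not part of the repository):
```
import numpy as np
from scipy.optimize import linear_sum_assignment

def iou(a, b):
    # Intersection-over-union of two boxes in [x1, y1, x2, y2] format.
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter + 1e-9)

def match_boxes(prev_boxes, cur_boxes, min_iou=0.3):
    # Bipartite matching of boxes between consecutive frames on IoU.
    # Returns (prev_index, cur_index) pairs; pairs below `min_iou` are dropped.
    cost = np.zeros((len(prev_boxes), len(cur_boxes)))
    for i, a in enumerate(prev_boxes):
        for j, b in enumerate(cur_boxes):
            cost[i, j] = -iou(a, b)  # maximizing IoU = minimizing negative IoU
    rows, cols = linear_sum_assignment(cost)
    return [(int(i), int(j)) for i, j in zip(rows, cols) if -cost[i, j] >= min_iou]

prev_frame = np.array([[10, 20, 110, 320], [200, 15, 310, 330]], dtype=float)
cur_frame = np.array([[205, 18, 312, 331], [12, 25, 112, 322]], dtype=float)
print(match_boxes(prev_frame, cur_frame))  # [(0, 1), (1, 0)]
```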
152 changes: 152 additions & 0 deletions inference/infer_video_d2.py
@@ -0,0 +1,152 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

"""Perform inference on a single video or all videos with a certain extension
(e.g., .mp4) in a folder.
"""

import detectron2
from detectron2.utils.logger import setup_logger
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor

import subprocess as sp
import numpy as np
import time
import argparse
import sys
import os
import glob

def parse_args():
    parser = argparse.ArgumentParser(description='End-to-end inference')
    parser.add_argument(
        '--cfg',
        dest='cfg',
        help='cfg model file (/path/to/model_config.yaml)',
        default=None,
        type=str
    )
    parser.add_argument(
        '--output-dir',
        dest='output_dir',
        help='directory for visualization pdfs (default: /tmp/infer_simple)',
        default='/tmp/infer_simple',
        type=str
    )
    parser.add_argument(
        '--image-ext',
        dest='image_ext',
        help='image file name extension (default: mp4)',
        default='mp4',
        type=str
    )
    parser.add_argument(
        'im_or_folder', help='image or folder of images', default=None
    )
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args()

def get_resolution(filename):
    command = ['ffprobe', '-v', 'error', '-select_streams', 'v:0',
               '-show_entries', 'stream=width,height', '-of', 'csv=p=0', filename]
    pipe = sp.Popen(command, stdout=sp.PIPE, bufsize=-1)
    for line in pipe.stdout:
        w, h = line.decode().strip().split(',')
        return int(w), int(h)

def read_video(filename):
    w, h = get_resolution(filename)

    command = ['ffmpeg',
               '-i', filename,
               '-f', 'image2pipe',
               '-pix_fmt', 'bgr24',
               '-vsync', '0',
               '-vcodec', 'rawvideo', '-']

    pipe = sp.Popen(command, stdout=sp.PIPE, bufsize=-1)
    while True:
        data = pipe.stdout.read(w*h*3)
        if not data:
            break
        yield np.frombuffer(data, dtype='uint8').reshape((h, w, 3))


def main(args):

    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(args.cfg))
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(args.cfg)
    predictor = DefaultPredictor(cfg)

    if os.path.isdir(args.im_or_folder):
        im_list = glob.iglob(args.im_or_folder + '/*.' + args.image_ext)
    else:
        im_list = [args.im_or_folder]

    for video_name in im_list:
        out_name = os.path.join(
            args.output_dir, os.path.basename(video_name)
        )
        print('Processing {}'.format(video_name))

        boxes = []
        segments = []
        keypoints = []

        for frame_i, im in enumerate(read_video(video_name)):
            t = time.time()
            outputs = predictor(im)['instances'].to('cpu')

            print('Frame {} processed in {:.3f}s'.format(frame_i, time.time() - t))

            has_bbox = False
            if outputs.has('pred_boxes'):
                bbox_tensor = outputs.pred_boxes.tensor.numpy()
                if len(bbox_tensor) > 0:
                    has_bbox = True
                    scores = outputs.scores.numpy()[:, None]
                    bbox_tensor = np.concatenate((bbox_tensor, scores), axis=1)
            if has_bbox:
                kps = outputs.pred_keypoints.numpy()
                kps_xy = kps[:, :, :2]
                kps_prob = kps[:, :, 2:3]
                kps_logit = np.zeros_like(kps_prob)  # Dummy
                kps = np.concatenate((kps_xy, kps_logit, kps_prob), axis=2)
                kps = kps.transpose(0, 2, 1)
            else:
                kps = []
                bbox_tensor = []

            # Mimic Detectron1 format
            cls_boxes = [[], bbox_tensor]
            cls_keyps = [[], kps]

            boxes.append(cls_boxes)
            segments.append(None)
            keypoints.append(cls_keyps)

        # Video resolution
        metadata = {
            'w': im.shape[1],
            'h': im.shape[0],
        }

        np.savez_compressed(out_name, boxes=boxes, segments=segments, keypoints=keypoints, metadata=metadata)


if __name__ == '__main__':
    setup_logger()
    args = parse_args()
    main(args)
