Commit 7c1fa96
Strip out dead code, improve docs
isarandi committed Mar 27, 2020
1 parent fc025ea commit 7c1fa96
Showing 29 changed files with 451 additions and 707 deletions.
24 changes: 13 additions & 11 deletions README.md
@@ -1,26 +1,28 @@
-## MeTRo 3D Human Pose Estimator
+# MeTRo 3D Human Pose Estimator
 
-#### What is this?
+### What is this?
 
 Code to train and evaluate the MeTRo method, proposed in our paper
 "Metric-Scale Truncation-Robust Heatmaps for 3D Human Pose Estimation" (Sárándi et al., 2020).
 A preprint of the paper is on arXiv: https://arxiv.org/abs/2003.02953
 
-#### What does it do?
+### What does it do?
 
 It takes a single **RGB image of a person as input** and returns the **3D coordinates of 17 body joints** relative to the pelvis. The coordinates are estimated in millimeters directly. Also, it always returns a complete pose by guessing joint positions even outside of the image boundaries (truncation).
 
-#### How do I run it?
-
-Stay tuned for inference instructions and pretrained models!
+### How do I run it?
+There is a small, self-contained script `inference.py` with minimal dependencies (just TensorFlow + NumPy), which can perform inference with a pretrained, exported model. Use it as follows:
+
+```bash
+wget https://omnomnom.vision.rwth-aachen.de/data/metro-pose3d/coco_mpii_h36m_3dhp_cmu_3dpw_resnet50_upperbodyaug_stride16.pb
+./inference.py --model-path=coco_mpii_h36m_3dhp_cmu_3dpw_resnet50_upperbodyaug_stride16.pb
+```
 
-#### How do I train it?
-See [DATASETS.md](docs/DATASETS.md) on how to download and prepare the training and test data.
-
-Then see [TRAINING.md](docs/TRAINING.md) for instructions on running experiments.
+### How do I train it?
+See [DEPENDENCIES.md](docs/DEPENDENCIES.md) for installing the dependencies. Then follow [DATASETS.md](docs/DATASETS.md) to download and prepare the training and test data.
+Finally, see [TRAINING.md](docs/TRAINING.md) for instructions on running experiments.
 
-#### How do I cite it?
-
+### How do I cite it?
 If you use this work, please cite it as:
 
 ```bibtex
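The README above says the model returns 17 pelvis-relative joint coordinates in millimeters. As a quick illustration of consuming such an output, here is a sketch with made-up joint indices (the README does not list the joint order):

```python
import numpy as np

# Placeholder for a real prediction: (batch, 17 joints, xyz) in millimeters,
# relative to the pelvis, as described in the README.
poses = np.zeros((1, 17, 3), dtype=np.float32)

# Hypothetical indices for illustration only; check the model's joint order.
I_LHIP, I_LKNEE = 8, 9
thigh_mm = np.linalg.norm(poses[0, I_LKNEE] - poses[0, I_LHIP])
print(f'Left thigh length: {thigh_mm:.1f} mm')
```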
24 changes: 14 additions & 10 deletions datasets/all.sh
@@ -1,20 +1,24 @@
 #!/usr/bin/env bash
 set -euo pipefail
 
-ask(){
+ask() {
   while true; do
     read -rp "$1" yn
     case $yn in
-      [Yy]* ) echo y; break;;
-      [Nn]* ) echo n;;
-      * ) echo "Please answer yes or no.";;
+      [Yy]*)
+        echo y
+        break
+        ;;
+      [Nn]*) echo n ;;
+      *) echo "Please answer yes or no." ;;
     esac
   done
 }
 
-if [[ $(ask "The Human3.6M, MPII, Pascal VOC, MPI-INF-3DHP and INRIA Holidays datasets are each from third parties. Have you read and do you agree with their respective licenses? [y/n]") != 'y' ]]; then
+if [[ $(ask "The Human3.6M, MPII, Pascal VOC, MPI-INF-3DHP and INRIA Holidays datasets are each from third parties.
+Have you read and agreed to their respective licenses? [y/n] ") != 'y' ]]; then
   echo "Then no cookies for you! Go read all the licenses!"
   exit 1
 fi
 
 # Find out the location of this script and cd into it
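For readers more comfortable in Python, the `ask` confirmation loop in this script corresponds roughly to the following sketch (not part of the repository):

```python
def ask(prompt):
    # Keep prompting until the answer starts with y/Y or n/N.
    while True:
        answer = input(prompt)
        if answer[:1].lower() == 'y':
            return 'y'
        if answer[:1].lower() == 'n':
            return 'n'
        print('Please answer yes or no.')


if ask('Have you read and agreed to the dataset licenses? [y/n] ') != 'y':
    raise SystemExit('Then no cookies for you! Go read all the licenses!')
```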
2 changes: 1 addition & 1 deletion datasets/get_3dhp.sh
@@ -20,4 +20,4 @@ mv mpi_inf_3dhp_test_set/mpi_inf_3dhp_test_set/TS* ./
 mv mpi_inf_3dhp_test_set/mpi_inf_3dhp_test_set/test_util ./
 mv mpi_inf_3dhp_test_set/mpi_inf_3dhp_test_set/README.txt ./README_testset.txt
 rmdir mpi_inf_3dhp_test_set/mpi_inf_3dhp_test_set
-rmdir mpi_inf_3dhp_test_set
\ No newline at end of file
+rmdir mpi_inf_3dhp_test_set
75 changes: 46 additions & 29 deletions docs/DEPENDENCIES.md
@@ -1,52 +1,69 @@
 ## Dependencies
 
-***Note: All instructions were tested on Ubuntu 18.04.3.***
+*Note: this was tested on Ubuntu 18.04.3.*
 
-Anaconda is the easiest way to install the dependencies. If you don't have it installed yet, open a Bash shell and install Miniconda as follows:
+### All-in-one script
+
+On a freshly installed Ubuntu 18.04, just run:
+
+```bash
+./install_dependencies.sh
+```
+
+This should take care of everything.
+
+----
+### Step-by-step explanation
+
+[Anaconda](https://anaconda.com) is the simplest way to install most of the dependencies. If you don't have it installed yet, open a Bash shell and install Miniconda as follows:
 
 ```bash
-wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
-bash ./Miniconda3-latest-Linux-x86_64.sh
-eval "$($HOME/miniconda3/bin/conda shell.bash hook)"
+$ wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
+$ bash ./Miniconda3-latest-Linux-x86_64.sh -b
+$ eval "$($HOME/miniconda3/bin/conda shell.bash hook)"
 ```
 
 Create a new environment and install the dependencies:
 
 ```bash
-conda create --name metro-pose3d python=3.6 Cython matplotlib pillow imageio ffmpeg scikit-image scikit-learn tqdm numba
-conda activate metro-pose3d
-conda install opencv3 -c menpo
-pip install tensorflow-gpu==1.13.1 attrdict jpeg4py transforms3d more_itertools spacepy
+$ conda create --yes --name metro-pose3d python=3.6 Cython matplotlib pillow imageio ffmpeg scikit-image scikit-learn tqdm numba
+$ conda activate metro-pose3d
+$ conda install --yes opencv3 -c menpo
+$ pip install tensorflow-gpu==1.13.1 attrdict jpeg4py transforms3d more_itertools spacepy
 ```
-### COCO tools
+#### COCO tools
 
-Install the [COCO tools](https://github.com/pdollar/coco) (used for managing runlength-encoded masks):
+Install the [COCO tools](https://github.com/cocodataset/cocoapi) (used for managing runlength-encoded masks):
 
 ```
-git clone https://github.com/cocodataset/cocoapi
-pushd coco/PythonAPI
-make
-python setup.py install
-popd
-rm -rf cocoapi
+$ git clone https://github.com/cocodataset/cocoapi
+$ cd cocoapi/PythonAPI
+$ make
+$ python setup.py install
+$ cd ../..
+$ rm -rf cocoapi
 ```
 
-### CDF
-If you also want to train the model, you'll need to install the CDF library because
-Human3.6M supplies the annotations as cdf files. We read them using the [SpacePy](https://spacepy.github.io/) Python library,
-which in turn depends on the CDF library.
+#### CDF
+We need to install the [CDF library](https://cdf.gsfc.nasa.gov/) because Human3.6M supplies the annotations as cdf files.
+We read them using the [SpacePy](https://spacepy.github.io/) Python library, which in turn depends on the CDF library.
 
 ```bash
-wget https://spdf.sci.gsfc.nasa.gov/pub/software/cdf/dist/cdf37_0/linux/cdf37_1-dist-cdf.tar.gz
-tar xf cdf37_1-dist-cdf.tar.gz
-rm cdf37_1-dist-cdf.tar.gz
-cd cdf37_1-dist
-make OS=linux ENV=gnu CURSES=yes FORTRAN=no UCOPTIONS=-O2 SHARED=yes -j4 all
+$ wget https://spdf.sci.gsfc.nasa.gov/pub/software/cdf/dist/cdf37_1/linux/cdf37_1-dist-cdf.tar.gz
+$ tar xf cdf37_1-dist-cdf.tar.gz
+$ rm cdf37_1-dist-cdf.tar.gz
+$ cd cdf37_1-dist
+$ make OS=linux ENV=gnu CURSES=yes FORTRAN=no UCOPTIONS=-O2 SHARED=yes -j4 all
 ```
 
 If you have sudo rights, simply run `sudo make install`. If you have no `sudo` rights, make sure to add
 `cdf37_1-dist/src/lib` to the `LD_LIBRARY_PATH` environment variable (add it to ~/.bashrc for permanent effect), or use GNU Stow.
 
-### libjpeg-turbo
-
-Install libjpeg-turbo to make JPEG loading faster. TODO
+#### libjpeg-turbo (optional)
+Install libjpeg-turbo to make JPEG loading faster.
+
+```bash
+$ git clone https://github.com/libjpeg-turbo/libjpeg-turbo.git
+$ cmake -G"Unix Makefiles" .
+$ make
+```
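Once the CDF library is installed and visible (via `make install` or `LD_LIBRARY_PATH`), SpacePy can open Human3.6M's cdf annotation files. A minimal sketch follows; the file path and the `'Pose'` variable name are assumptions for illustration, not taken from this repository:

```python
from spacepy import pycdf  # fails here if the CDF library cannot be found

# Hypothetical path to one Human3.6M annotation file.
with pycdf.CDF('Directions.cdf') as cdf:
    print(cdf)             # lists the variables contained in the file
    data = cdf['Pose'][:]  # assumed variable name; copies the data to a NumPy array
    print(data.shape)
```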
5 changes: 0 additions & 5 deletions docs/INFERENCE.md
@@ -1,5 +0,0 @@
-## Inference
-
-This guide is about how to run a pretrained model on new images.
-
-*--Coming soon--*
72 changes: 72 additions & 0 deletions inference.py
@@ -0,0 +1,72 @@
+#!/usr/bin/env python3
+
+import argparse
+
+import numpy as np
+import skimage.data
+import skimage.transform
+import tensorflow as tf
+
+
+def main():
+    parser = argparse.ArgumentParser(description='MeTRo-Pose3D', allow_abbrev=False)
+    parser.add_argument('--model-path', type=str, required=True)
+    opts = parser.parse_args()
+
+    images_numpy = np.stack([skimage.transform.resize(skimage.data.astronaut(), (256, 256))])
+    images_tensor = tf.convert_to_tensor(images_numpy)
+    poses_tensor = estimate_pose(images_tensor, opts.model_path)
+
+    with tf.Session() as sess:
+        poses_arr = sess.run(poses_tensor)
+    edges = [(1, 0), (0, 18), (0, 2), (2, 3), (3, 4), (0, 8), (8, 9), (9, 10), (18, 5), (5, 6),
+             (6, 7), (18, 11), (11, 12), (12, 13), (15, 14), (14, 1), (17, 16), (16, 1)]
+    visualize_pose(image=images_numpy[0], coords=poses_arr[0], edges=edges)
+
+
+def estimate_pose(im, model_path):
+    graph_def = tf.GraphDef()
+    with tf.gfile.GFile(model_path, 'rb') as f:
+        graph_def.ParseFromString(f.read())
+
+    im_t = tf.cast(im, tf.float32)  # / 255 *2 -1
+    im_t = tf.transpose(im_t, [0, 3, 1, 2])  # NHWC -> NCHW
+    return tf.import_graph_def(
+        graph_def, input_map={'input:0': im_t}, return_elements=['pred'])[0].outputs[0]
+
+
+def visualize_pose(image, coords, edges):
+    import matplotlib.pyplot as plt
+    plt.switch_backend('TkAgg')
+    # noinspection PyUnresolvedReferences
+    from mpl_toolkits.mplot3d import Axes3D
+
+    # Matplotlib interprets the Z axis as vertical, but our pose
+    # has Y as the vertical axis.
+    # Therefore we do a 90 degree rotation around the horizontal (X) axis.
+    coords2 = coords.copy()
+    coords[:, 1], coords[:, 2] = coords2[:, 2], -coords2[:, 1]
+
+    fig = plt.figure(figsize=(10, 5))
+    image_ax = fig.add_subplot(1, 2, 1)
+    image_ax.set_title('Input')
+    image_ax.imshow(image)
+
+    pose_ax = fig.add_subplot(1, 2, 2, projection='3d')
+    pose_ax.set_title('Prediction')
+    range_ = 800
+    pose_ax.set_xlim3d(-range_, range_)
+    pose_ax.set_ylim3d(-range_, range_)
+    pose_ax.set_zlim3d(-range_, range_)
+
+    for i_start, i_end in edges:
+        pose_ax.plot(*zip(coords[i_start], coords[i_end]), marker='o', markersize=2)
+
+    pose_ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2], s=2)
+
+    fig.tight_layout()
+    plt.show()
+
+
+if __name__ == '__main__':
+    main()
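`estimate_pose` above splices the preprocessed image tensor into the frozen graph's `input` node via `input_map` and returns the output of the `pred` op. If you adapt this to a different exported model and don't know its node names, a standard TF 1.x way to list them (my addition, not part of the commit) is:

```python
import tensorflow as tf


def list_graph_nodes(model_path):
    """Print the op names in a frozen GraphDef, e.g. to locate input/output nodes."""
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(model_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    for node in graph_def.node:
        print(node.name, node.op)


list_graph_nodes('coco_mpii_h36m_3dhp_cmu_3dpw_resnet50_upperbodyaug_stride16.pb')
```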
31 changes: 31 additions & 0 deletions install_dependencies.sh
@@ -0,0 +1,31 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+sudo apt install build-essential --yes wget curl gfortran git ncurses-dev unzip tar
+
+wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
+bash ./Miniconda3-latest-Linux-x86_64.sh -b
+eval "$("$HOME/miniconda3/bin/conda" shell.bash hook)"
+
+conda create --yes --name metro-pose3d python=3.6 Cython matplotlib pillow imageio ffmpeg scikit-image scikit-learn tqdm numba
+conda activate metro-pose3d
+conda install --yes opencv3 -c menpo
+pip install tensorflow-gpu==1.13.1 attrdict jpeg4py transforms3d more_itertools spacepy
+
+git clone https://github.com/cocodataset/cocoapi
+cd cocoapi/PythonAPI
+make
+python setup.py install
+cd ../..
+rm -rf cocoapi
+
+wget https://spdf.sci.gsfc.nasa.gov/pub/software/cdf/dist/cdf37_1/linux/cdf37_1-dist-cdf.tar.gz
+tar xf cdf37_1-dist-cdf.tar.gz
+rm cdf37_1-dist-cdf.tar.gz
+cd cdf37_1-dist
+make OS=linux ENV=gnu CURSES=yes FORTRAN=no UCOPTIONS=-O2 SHARED=yes -j4 all
+
+export LD_LIBRARY_PATH=$PWD/src/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
+
+# Optional:
+# wget https://sourceforge.net/projects/libjpeg-turbo/files/2.0.4/libjpeg-turbo-2.0.4.tar.gz
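After the script finishes, a quick smoke test (my suggestion, not included in the script) is to check the pinned TensorFlow build and the CDF linkage from within the `metro-pose3d` environment:

```python
import tensorflow as tf

print(tf.__version__)              # expected: 1.13.1
print(tf.test.is_gpu_available())  # True if the GPU setup works

from spacepy import pycdf  # noqa: F401 -- raises if the CDF library is missing
print('CDF library loaded OK')
```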
19 changes: 12 additions & 7 deletions src/augmentation/background.py
@@ -1,6 +1,8 @@
 import functools
 
-import geom
+import numpy as np
+
+import cameralib
 import improc
 import paths
 import util
@@ -15,11 +17,14 @@ def get_inria_holiday_background_paths():
 
 def augment_background(im, fgmask, rng):
     path = util.choice(get_inria_holiday_background_paths(), rng)
-    background_im = improc.imread_jpeg_fast(path)
+    background_im = improc.imread_jpeg(path)
 
-    tr = geom.SimTransform()
-    imside = im.shape[0]
-    tr = (tr.center_fill(background_im.shape[:2], im.shape[:2], factor=rng.uniform(1.2, 1.5)).
-          translate(rng.uniform(-imside * 0.1, imside * 0.1, size=2)))
-    warped_background_im = tr.transform_image(background_im, dst_shape=im.shape[:2])
+    cam = cameralib.Camera.create2D(background_im.shape)
+    cam_new = cam.copy()
+
+    zoom_aug_factor = rng.uniform(1.2, 1.5)
+    cam_new.zoom(zoom_aug_factor * np.max(im.shape[:2] / np.asarray(background_im.shape[:2])))
+    cam_new.center_principal_point(im.shape)
+    cam_new.shift_image(util.random_uniform_disc(rng) * im.shape[0] * 0.1)
+    warped_background_im = cameralib.reproject_image(background_im, cam, cam_new, im.shape)
     return improc.blend_image(warped_background_im, im, fgmask)
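In the new version above, the background is warped by reprojecting between two synthetic cameras instead of applying a similarity transform. The zoom factor `rng.uniform(1.2, 1.5) * max(im_hw / bg_hw)` scales the background so it at least covers the target crop, with margin left over for the random shift. A quick numeric check with made-up shapes:

```python
import numpy as np

im_hw = np.array([256, 256])   # target crop size (made-up)
bg_hw = np.array([768, 1024])  # background photo size (made-up)

zoom_aug_factor = 1.35  # the real code samples this from uniform(1.2, 1.5)
zoom = zoom_aug_factor * np.max(im_hw / bg_hw)
print(zoom)  # 0.45: enough to cover the 256px crop from the 768px side, plus margin
```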
19 changes: 9 additions & 10 deletions src/augmentation/voc_loader.py
@@ -10,15 +10,6 @@
 import paths
 import util
 
-morph_elem = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (8, 8))
-
-
-def soften_mask(mask):
-    eroded = cv2.erode(mask, morph_elem)
-    result = mask.astype(np.float32)
-    result[eroded < result] = 0.75
-    return result
-
 
 @functools.lru_cache()
 @util.cache_result_on_disk(
@@ -53,7 +44,7 @@ def load_occluders():
         path = f'{pascal_root}/JPEGImages/{image_filename}'
         seg_path = f'{pascal_root}/SegmentationObject/{segmentation_filename}'
 
-        im = improc.imread_jpeg_fast(path)
+        im = improc.imread_jpeg(path)
         labels = np.asarray(PIL.Image.open(seg_path))
 
         for i_obj, (xmin, ymin, xmax, ymax) in boxes:
@@ -71,3 +62,11 @@ def load_occluders():
         image_paths.append(path)
 
     return image_mask_pairs
+
+
+def soften_mask(mask):
+    morph_elem = improc.get_structuring_element(cv2.MORPH_ELLIPSE, (8, 8))
+    eroded = cv2.erode(mask, morph_elem)
+    result = mask.astype(np.float32)
+    result[eroded < result] = 0.75
+    return result
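`soften_mask`, now moved below `load_occluders`, erodes the binary occluder mask and assigns the eroded-away boundary band an opacity of 0.75, so pasted occluders blend slightly at their edges. A toy demonstration of the effect (using `cv2.getStructuringElement` directly, where the code above goes through an `improc` wrapper):

```python
import cv2
import numpy as np

mask = np.zeros((20, 20), np.uint8)
mask[2:18, 2:18] = 1  # a solid square "occluder"

elem = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (8, 8))
eroded = cv2.erode(mask, elem)

soft = mask.astype(np.float32)
soft[eroded < soft] = 0.75  # boundary band becomes semi-transparent
print(np.unique(soft))      # [0.   0.75 1.  ]
```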
