Commit cc83cf5: Initial code release
dariopavllo committed Oct 13, 2020 (1 parent: 9790d5b)
Showing 41 changed files with 17,039 additions and 0 deletions.
11 changes: 11 additions & 0 deletions .gitignore
@@ -0,0 +1,11 @@
checkpoints_gan/
checkpoints_recon/
tensorboard_gan/
tensorboard_recon/
cache/
datasets/
output/
__pycache__/
.ipynb_checkpoints/
*.zip
*.py[cod]
106 changes: 106 additions & 0 deletions README.md
@@ -0,0 +1,106 @@
# Convolutional Generation of Textured 3D Meshes

This is the reference implementation of "Convolutional Generation of Textured 3D Meshes", accepted at **NeurIPS 2020** with oral presentation.

> Dario Pavllo, Graham Spinks, Thomas Hofmann, Marie-Francine Moens, and Aurelien Lucchi. [Convolutional Generation of Textured 3D Meshes](https://arxiv.org/abs/2006.07660). In Neural Information Processing Systems (NeurIPS), 2020.

The paper proposes a novel GAN framework for generating 3D triangle meshes and corresponding texture maps, leveraging recent advances in differentiable rendering. The model can be conditioned on a variety of inputs, including class labels, attributes, and text via an attention mechanism.

![](images/teaser.jpg)
<img src="images/animation.gif" width="768px" alt="" />

## Setup
Instructions on how to set up dependencies, datasets, and pretrained models can be found in [SETUP.md](SETUP.md).

## Evaluating pretrained models
In order to test our pretrained models, the minimal setup described in [SETUP.md](SETUP.md) is sufficient. No dataset setup is required.
We provide an interface for evaluating FID scores, as well as one for exporting a sample of generated 3D meshes (both as a grid of renderings and as .obj meshes).

### Exporting a sample
You can export a sample of generated meshes using the `--export_sample` flag. For a quick demo on CUB, run:
```
python run_generation.py --name pretrained_cub_512x512_class --conditional_class --dataset cub --gpu_ids 0 --batch_size 16 --export_sample
```
This will generate a sample of 16 meshes, render them from random viewpoints, and export the final result to the `output` directory as a png image. In addition, the script will export the meshes as .obj files (along with material and texture). These can be imported into Blender or other modeling tools.

![](images/export.jpg)
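
To inspect the exported meshes programmatically rather than in Blender, a minimal sketch using the third-party `trimesh` package (not a dependency of this repo; the file name below is hypothetical) might look like this:
```
import trimesh

# Load one of the exported .obj files (hypothetical name); trimesh also picks up
# the accompanying .mtl material and texture image if they sit next to the .obj.
mesh = trimesh.load("output/mesh_000.obj")

print(mesh.vertices.shape, mesh.faces.shape)  # (num_vertices, 3), (num_faces, 3)
mesh.show()  # opens an interactive viewer (requires pyglet)
```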

### Evaluating FID scores on pretrained models
For CUB, we provide pretrained models where the generator is conditioned on class labels, captions, or nothing. You can evaluate them as follows:
```
python run_generation.py --name pretrained_cub_512x512_class --conditional_class --dataset cub --gpu_ids 0,1,2,3 --batch_size 64 --evaluate
python run_generation.py --name pretrained_cub_512x512_text --conditional_text --dataset cub --gpu_ids 0,1,2,3 --batch_size 64 --evaluate
python run_generation.py --name pretrained_cub_512x512_uncond --dataset cub --gpu_ids 0,1,2,3 --batch_size 64 --evaluate
```

For Pascal3D+, the pretrained models are conditioned on class labels, class+color, or nothing:
```
python run_generation.py --name pretrained_p3d_512x512_class --conditional_class --dataset p3d --gpu_ids 0,1,2,3 --batch_size 64 --evaluate
python run_generation.py --name pretrained_p3d_512x512_class_color --conditional_class --conditional_color --dataset p3d --gpu_ids 0,1,2,3 --batch_size 64 --evaluate
python run_generation.py --name pretrained_p3d_512x512_uncond --dataset p3d --gpu_ids 0,1,2,3 --batch_size 64 --evaluate
```
You can of course adjust the number of GPUs and batch size to suit your computational resources. For evaluation, 16 elements per GPU is a sensible choice. You can also tune the number of data-loading threads using the `--num_workers` argument (default: 4 threads). If unspecified, the truncation sigma for evaluation is autodetected depending on the dataset. In case you want to specify it manually, you can do so through the `--truncation_sigma` argument (e.g. to disable truncation, specify a large value like 1000).
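
For example, a quick single-GPU evaluation that disables truncation and uses more data-loading threads (illustrative values) would look like:
```
python run_generation.py --name pretrained_cub_512x512_class --conditional_class --dataset cub --gpu_ids 0 --batch_size 16 --evaluate --truncation_sigma 1000 --num_workers 8
```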

The table below summarizes the FID scores you can expect when evaluating the pretrained models:
| Dataset | Split | Texture resolution | Conditioning | FID (Full) | Truncation σ |
|:-------|:-------:|:-------:|:-------:|:-------:|:-------:|
| CUB Birds | testval | 512 x 512 | Class | ~33 | 0.25 |
| CUB Birds | testval | 512 x 512 | Text | ~18 | 0.5 |
| CUB Birds | testval | 512 x 512 | None | ~41 | 1.0 |
| Pascal3D+ Cars | train | 512 x 512 | Class | ~27 | 0.75 |
| Pascal3D+ Cars | train | 512 x 512 | Class+Color | ~31 | 0.5 |
| Pascal3D+ Cars | train | 512 x 512 | None | ~43 | 1.0 |


To evaluate the Mesh FID and Texture FID (described in the paper), you need to set up the pseudo-ground-truth data as described in the next section.

## Generating pseudo-ground-truth data
This step requires a trained mesh estimation model. You can use the pretrained one we provide or train it from scratch (as described in the next section).
The pseudo-ground-truth for CUB can be generated as follows:
```
python run_reconstruction.py --name pretrained_reconstruction_cub --dataset cub --batch_size 10 --generate_pseudogt
```
The command for Pascal3D+ is:
```
python run_reconstruction.py --name pretrained_reconstruction_p3d --dataset p3d --optimize_z0 --batch_size 10 --generate_pseudogt
```

This will create (or replace) a directory `cache`, which contains precomputed statistics for FID evaluation, pose/image metadata, and the pseudo-ground-truth for each image.


## GAN training
To train the mesh generator from scratch, you first need to set up the pseudo-ground-truth data as described in the section above. Then, you can train a new model as follows:
```
python run_generation.py --name cub_512x512_class --conditional_class --dataset cub --gpu_ids 0,1,2,3 --batch_size 32 --epochs 600 --tensorboard
```
This command will train a CUB model conditioned on class labels for 600 epochs. By default, FID evaluations are carried out every 20 epochs, but you can change this value using the flag `--evaluate_freq`. If you specify `--tensorboard`, training curves, FID curves, and generated results will be exported to the Tensorboard log directory `tensorboard_gan`. Note that using a different batch size or number of GPUs might result in slightly different results than those reported in the paper.
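
For example, to evaluate FID every 10 epochs instead of 20 (an illustrative value), you would run:
```
python run_generation.py --name cub_512x512_class --conditional_class --dataset cub --gpu_ids 0,1,2,3 --batch_size 32 --epochs 600 --tensorboard --evaluate_freq 10
```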

Once the training process has finished, you can find the best checkpoint (in terms of FID score) by running:
```
python run_generation.py --name cub_512x512_class --conditional_class --dataset cub --gpu_ids 0,1,2,3 --batch_size 64 --evaluate --which_epoch best
```

For Pascal3D+ and the other settings, the commands to train new models are the same as the evaluation ones, but without the `--evaluate` flag (and with a batch size of 32).
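
For instance, the class-conditional Pascal3D+ model would be trained with the following command (choosing a new experiment name rather than the pretrained one; the `--tensorboard` flag is optional):
```
python run_generation.py --name p3d_512x512_class --conditional_class --dataset p3d --gpu_ids 0,1,2,3 --batch_size 32 --tensorboard
```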

## Mesh estimation model training
To train the mesh estimation model from scratch (the very first step of the pipeline), you can use the following two commands:
```
python run_reconstruction.py --name pretrained_reconstruction_cub --dataset cub --batch_size 50 --tensorboard
python run_reconstruction.py --name pretrained_reconstruction_p3d --dataset p3d --optimize_z0 --batch_size 50 --tensorboard
```
Tensorboard logs are saved in `tensorboard_recon`.
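
To monitor a run, you can point Tensorboard at this directory, for example:
```
tensorboard --logdir tensorboard_recon
```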

## Citation
If you use this work in your research, please consider citing our paper:
```
@inproceedings{pavllo2020convmesh,
title={Convolutional Generation of Textured 3D Meshes},
author={Pavllo, Dario and Spinks, Graham and Hofmann, Thomas and Moens, Marie-Francine and Lucchi, Aurelien},
booktitle={Neural Information Processing Systems (NeurIPS)},
year={2020}
}
```

## License and Acknowledgments
Our work is licensed under the MIT license. For more details, see [LICENSE](LICENSE).
This repository includes third-party libraries which may be subject to their respective licenses: [Synchronized-BatchNorm-PyTorch](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch), the data loader from [CMR](https://github.com/akanazawa/cmr), some text processing scripts from [AttnGAN](https://github.com/taoxugit/AttnGAN), and FID evaluation code from [pytorch-fid](https://github.com/mseitzer/pytorch-fid).
38 changes: 38 additions & 0 deletions SETUP.md
@@ -0,0 +1,38 @@
# Code and data setup

## Requirements
- [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) (tested on commit [e7e5131](https://github.com/NVIDIAGameWorks/kaolin/tree/e7e513173bd4159ae45be6b3e156a3ad156a3eb9))
- Python >= 3.6
- PyTorch >= 1.2
- CUDA >= 10.0 (you won't be able to build Kaolin with CUDA 9)

To run the code, you also need to install the following packages (you can easily do so via pip): `packaging`, `nltk` (for models conditioned on captions), and `tensorboard` (if you want to use this feature).
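
For example, assuming a working pip environment:
```
pip install packaging nltk tensorboard
```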

Note that, although Kaolin only officially supports PyTorch versions between 1.2 and 1.4, our code dynamically patches some functions to make them work with newer PyTorch versions. Currently, inference code has been successfully tested with PyTorch 1.6.


## Minimal setup (evaluating pretrained models)
This step involves setting up the pretrained models, precomputed statistics for FID evaluation, and precomputed pose metadata. No dataset setup is involved.
With this setup, you will be able to evaluate our pretrained models (FID scores and mesh export), but you will not be able to train a new model from scratch.

You can download the pretrained models and cache directory from the [Releases](https://github.com/dariopavllo/convmesh/releases) section of this repository. It suffices to extract the archives to the root directory of this repo.

## Full dataset setup (training from scratch)
If you have not already done so, set up the precomputed metadata as described in the step above.

Then, to set up the CUB dataset, download [CUB images](http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz) and [segmentations](http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/segmentations.tgz) ([source](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html)) and extract them so that your directory tree looks like this:
```
datasets/cub/CUB_200_2011/
datasets/cub/CUB_200_2011/segmentations/
datasets/cub/data/
datasets/cub/sfm/
```
If you already have a copy of the dataset elsewhere, creating symbolic links instead of copying the data is also a good option.
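
For example, assuming an existing copy of the dataset (the source path below is a placeholder):
```
mkdir -p datasets/cub
ln -s /path/to/CUB_200_2011 datasets/cub/CUB_200_2011
```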

For Pascal3D+ Cars, download [PASCAL3D+_release1.1.zip](ftp://cs.stanford.edu/cs/cvgl/PASCAL3D+_release1.1.zip) ([source](https://cvgl.stanford.edu/projects/pascal3d.html)) and set up your directory tree like this:
```
datasets/p3d/PASCAL3D+_release1.1/
datasets/p3d/data/
datasets/p3d/sfm/
datasets/p3d/p3d_labels.csv
```
210 changes: 210 additions & 0 deletions cmr_data/base.py
@@ -0,0 +1,210 @@
# Adapted from:
# https://github.com/akanazawa/cmr/blob/c24cab6aececa1cb8416ccb2d3ee470915726937/data/base.py

"""
Base data loading class.
Should output:
- img: B X 3 X H X W
- kp: B X nKp X 2
- mask: B X H X W
- sfm_pose: B X 7 (s, tr, q)
(kp, sfm_pose) correspond to image coordinates in [-1, 1]
"""

import os.path as osp
import numpy as np
import copy

import scipy.linalg
import scipy.ndimage.interpolation
from skimage.io import imread

import torch
from torch.utils.data import Dataset

from . import image_utils
from . import transformations


# -------------- Dataset ------------- #
# ------------------------------------ #
class BaseDataset(Dataset):
'''
img, mask, kp, pose data loader
'''

def __init__(self, is_train, img_size):
# Child class should define/load:
# self.kp_perm
# self.img_dir
# self.anno
# self.anno_sfm
if not isinstance(img_size, list):
self.img_sizes = [img_size]
else:
self.img_sizes = img_size
self.jitter_frac = 0
self.padding_frac = 0.05
self.is_train = is_train

def get_paths(self):
paths = []
for index, data in enumerate(self.anno):
img_path_rel = str(data.rel_path).replace('\\', '/')
paths.append(img_path_rel)
return paths

def forward_img(self, index):
data = self.anno[index]
data_sfm = self.anno_sfm[index]

# sfm_pose = (sfm_c, sfm_t, sfm_r)
sfm_pose = [np.copy(data_sfm.scale), np.copy(data_sfm.trans), np.copy(data_sfm.rot)]

sfm_rot = np.pad(sfm_pose[2], (0,1), 'constant')
sfm_rot[3, 3] = 1
sfm_pose[2] = transformations.quaternion_from_matrix(sfm_rot, isprecise=True)

img_path = osp.join(self.img_dir, str(data.rel_path)).replace('\\', '/')
img_path_rel = str(data.rel_path).replace('\\', '/')
img = imread(img_path) / 255.0
# Some are grayscale:
if len(img.shape) == 2:
img = np.repeat(np.expand_dims(img, 2), 3, axis=2)
mask = np.expand_dims(data.mask, 2)

# Adjust to 0 indexing
bbox = np.array(
[data.bbox.x1, data.bbox.y1, data.bbox.x2, data.bbox.y2],
float) - 1

parts = data.parts.T.astype(float)
kp = np.copy(parts)
vis = kp[:, 2] > 0
kp[vis, :2] -= 1

# Perturb bbox
if self.is_train:
bbox = image_utils.peturb_bbox(
bbox, pf=self.padding_frac, jf=self.jitter_frac)
else:
bbox = image_utils.peturb_bbox(
bbox, pf=self.padding_frac, jf=0)
bbox = image_utils.square_bbox(bbox)
true_resolution = bbox[2] - bbox[0] + 1

# crop image around bbox, translate kps
img, mask, kp, sfm_pose = self.crop_image(img, mask, bbox, kp, vis, sfm_pose)

mirrored = self.is_train and (torch.randint(0, 2, size=(1,)).item() == 1)

# scale image, and mask. And scale kps.
img_ref, mask_ref, kp_ref, sfm_pose_ref = self.scale_image(img.copy(), mask.copy(),
kp.copy(), vis.copy(),
copy.deepcopy(sfm_pose),
self.img_sizes[0])
if mirrored:
img_ref, mask_ref, kp_ref, sfm_pose_ref = self.mirror_image(img_ref, mask_ref, kp_ref, sfm_pose_ref)

# Normalize kp to be [-1, 1]
img_h, img_w = img_ref.shape[:2]
kp_norm, sfm_pose_ref = self.normalize_kp(kp_ref, sfm_pose_ref, img_h, img_w)

# Finally transpose the image to 3xHxW
img_ref = np.transpose(img_ref, (2, 0, 1))

# Compute other resolutions (if requested)
extra_res = {}
for res in self.img_sizes[1:]:
img2, mask2, kp2, sfm_pose2 = self.scale_image(img.copy(), mask.copy(),
kp.copy(), vis.copy(),
copy.deepcopy(sfm_pose),
res)
if mirrored:
img2, mask2, kp2, sfm_pose2 = self.mirror_image(img2, mask2, kp2, sfm_pose2)

img2 = np.transpose(img2, (2, 0, 1))
extra_res[res] = (img2, mask2)

return img_ref, kp_norm, mask_ref, sfm_pose_ref, mirrored, img_path_rel, extra_res

def normalize_kp(self, kp, sfm_pose, img_h, img_w):
vis = kp[:, 2, None] > 0
new_kp = np.stack([2 * (kp[:, 0] / img_w) - 1,
2 * (kp[:, 1] / img_h) - 1,
kp[:, 2]]).T
sfm_pose[0] *= (1.0/img_w + 1.0/img_h)
sfm_pose[1][0] = 2.0 * (sfm_pose[1][0] / img_w) - 1
sfm_pose[1][1] = 2.0 * (sfm_pose[1][1] / img_h) - 1
new_kp = vis * new_kp

return new_kp, sfm_pose

def crop_image(self, img, mask, bbox, kp, vis, sfm_pose):
# crop image and mask and translate kps
img = image_utils.crop(img, bbox, bgval=1)
mask = image_utils.crop(mask, bbox, bgval=0)
kp[vis, 0] -= bbox[0]
kp[vis, 1] -= bbox[1]
sfm_pose[1][0] -= bbox[0]
sfm_pose[1][1] -= bbox[1]
return img, mask, kp, sfm_pose

def scale_image(self, img, mask, kp, vis, sfm_pose, img_size):
# Scale image so largest bbox size is img_size
bwidth = np.shape(img)[0]
bheight = np.shape(img)[1]
scale = img_size / float(max(bwidth, bheight))
img_scale, _ = image_utils.resize_img(img, scale)
# if img_scale.shape[0] != img_size:
# print('bad!')
# import ipdb; ipdb.set_trace()
mask_scale, _ = image_utils.resize_img(mask, scale)
kp[vis, :2] *= scale
sfm_pose[0] *= scale
sfm_pose[1] *= scale

return img_scale, mask_scale, kp, sfm_pose

def mirror_image(self, img, mask, kp, sfm_pose):
kp_perm = self.kp_perm

# Need a copy because torch's collate doesn't like negative strides
img_flip = img[:, ::-1, :].copy()
mask_flip = mask[:, ::-1].copy()

# Flip kps.
new_x = img.shape[1] - kp[:, 0] - 1
kp_flip = np.hstack((new_x[:, None], kp[:, 1:]))
kp_flip = kp_flip[kp_perm, :]
# Flip sfm_pose Rot.
R = transformations.quaternion_matrix(sfm_pose[2])
flip_R = np.diag([-1, 1, 1, 1]).dot(R.dot(np.diag([-1, 1, 1, 1])))
sfm_pose[2] = transformations.quaternion_from_matrix(flip_R, isprecise=True)
# Flip tx
tx = img.shape[1] - sfm_pose[1][0] - 1
sfm_pose[1][0] = tx
return img_flip, mask_flip, kp_flip, sfm_pose

def __len__(self):
return self.num_imgs

def __getitem__(self, index):
img, kp, mask, sfm_pose, mirrored, path, extra_res = self.forward_img(index)
sfm_pose[0].shape = 1

elem = {
'img': img,
'kp': kp,
'mask': mask,
'sfm_pose': np.concatenate(sfm_pose),
'mirrored': mirrored,
'inds': index,
'path': path,
}

for res, img2 in extra_res.items():
elem[f'img_{res}'] = img2

return elem