Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Decentralized Distributed PPO (#245)
Add DD-PPO to habitat-baselines
- Loading branch information
1 parent
44c8be1
commit 85b7907
Showing
24 changed files
with
1,671 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
BASE_TASK_CONFIG_PATH: "configs/tasks/pointnav_gibson.yaml" | ||
TRAINER_NAME: "ddppo" | ||
ENV_NAME: "NavRLEnv" | ||
SIMULATOR_GPU_ID: 0 | ||
TORCH_GPU_ID: 0 | ||
VIDEO_OPTION: [] | ||
TENSORBOARD_DIR: "tb" | ||
VIDEO_DIR: "video_dir" | ||
TEST_EPISODE_COUNT: 994 | ||
EVAL_CKPT_PATH_DIR: "data/new_checkpoints" | ||
NUM_PROCESSES: 4 | ||
SENSORS: ["DEPTH_SENSOR"] | ||
CHECKPOINT_FOLDER: "data/new_checkpoints" | ||
NUM_UPDATES: 10000 | ||
LOG_INTERVAL: 10 | ||
CHECKPOINT_INTERVAL: 50 | ||
|
||
RL: | ||
SUCCESS_REWARD: 2.5 | ||
PPO: | ||
# ppo params | ||
clip_param: 0.2 | ||
ppo_epoch: 2 | ||
num_mini_batch: 2 | ||
value_loss_coef: 0.5 | ||
entropy_coef: 0.01 | ||
lr: 2.5e-4 | ||
eps: 1e-5 | ||
max_grad_norm: 0.2 | ||
num_steps: 128 | ||
use_gae: True | ||
gamma: 0.99 | ||
tau: 0.95 | ||
use_linear_clip_decay: False | ||
use_linear_lr_decay: False | ||
reward_window_size: 50 | ||
|
||
use_normalized_advantage: False | ||
|
||
hidden_size: 512 | ||
|
||
DDPPO: | ||
sync_frac: 0.6 | ||
# The PyTorch distributed backend to use | ||
distrib_backend: GLOO | ||
# Visual encoder backbone | ||
pretrained_weights: data/ddppo-models/gibson-2plus-resnet50.pth | ||
# Initialize with pretrained weights | ||
pretrained: False | ||
# Initialize just the visual encoder backbone with pretrained weights | ||
pretrained_encoder: False | ||
# Whether or not the visual encoder backbone will be trained. | ||
train_encoder: True | ||
# Whether or not to reset the critic linear layer | ||
reset_critic: True | ||
|
||
# Model parameters | ||
backbone: resnet50 | ||
rnn_type: LSTM | ||
num_recurrent_layers: 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
BASE_TASK_CONFIG_PATH: "configs/tasks/pointnav.yaml" | ||
TRAINER_NAME: "ddppo" | ||
ENV_NAME: "NavRLEnv" | ||
SIMULATOR_GPU_ID: 0 | ||
TORCH_GPU_ID: 0 | ||
VIDEO_OPTION: [] | ||
TENSORBOARD_DIR: "" | ||
EVAL_CKPT_PATH_DIR: "data/test_checkpoints/ddppo/pointnav/ckpt.0.pth" | ||
NUM_PROCESSES: 1 | ||
CHECKPOINT_FOLDER: "data/test_checkpoints/ddppo/pointnav/" | ||
NUM_UPDATES: 2 | ||
LOG_INTERVAL: 100 | ||
CHECKPOINT_INTERVAL: 1 | ||
|
||
RL: | ||
SUCCESS_REWARD: 2.5 | ||
PPO: | ||
# ppo params | ||
clip_param: 0.2 | ||
ppo_epoch: 2 | ||
num_mini_batch: 1 | ||
value_loss_coef: 0.5 | ||
entropy_coef: 0.01 | ||
lr: 2.5e-4 | ||
eps: 1e-5 | ||
max_grad_norm: 0.2 | ||
num_steps: 16 | ||
use_gae: True | ||
gamma: 0.99 | ||
tau: 0.95 | ||
use_linear_clip_decay: False | ||
use_linear_lr_decay: False | ||
reward_window_size: 50 | ||
|
||
use_normalized_advantage: False | ||
|
||
hidden_size: 512 | ||
|
||
DDPPO: | ||
sync_frac: 0.6 | ||
# The PyTorch distributed backend to use | ||
distrib_backend: GLOO | ||
# Visual encoder backbone | ||
pretrained_weights: data/ddppo-models/gibson-2plus-resnet50.pth | ||
# Initialize with pretrained weights | ||
pretrained: False | ||
# Initialize just the visual encoder backbone with pretrained weights | ||
pretrained_encoder: False | ||
# Whether or not the visual encoder backbone will be trained. | ||
train_encoder: True | ||
# Whether or not to reset the critic linear layer | ||
reset_critic: True | ||
|
||
# Model parameters | ||
backbone: resnet50 | ||
rnn_type: LSTM | ||
num_recurrent_layers: 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Decentralized Distributed PPO | ||
|
||
Provides changes to the core baseline ppo algorithm and training script to implemented Decentralized Distributed PPO (DD-PPO). | ||
DD-PPO leverages distributed data parallelism to seamlessly scale PPO to hundreds of GPUs with no centralized server. | ||
|
||
See the [paper](https://arxiv.org/abs/1911.00357) for more detail. | ||
|
||
## Running | ||
|
||
There are two example scripts to run provided. A single node script that leverages `torch.distributed.launch` to create multiple workers: | ||
`single_node.sh`, and a multi-node script that leverages [SLURM](https://slurm.schedmd.com/documentation.html) to create all the works on multiple nodes: `multi_node_slurm.sh`. | ||
|
||
The two recommended backends are GLOO and NCCL. Use NCCL if your system has it, and GLOO if otherwise. | ||
|
||
See [pytorch's distributed docs](https://pytorch.org/docs/stable/distributed.html#backends-that-come-with-pytorch) | ||
and [pytorch's distributed tutorial](https://pytorch.org/tutorials/intermediate/dist_tuto.html) for more information. | ||
|
||
## Pretrained Models (PointGoal Navigation with GPS+Compass) | ||
|
||
|
||
All weights available as a zip [here](https://drive.google.com/open?id=1ueXuIqP2HZ0oxhpDytpc3hpciXSd8H16). | ||
|
||
### Depth models | ||
|
||
| Architecture | Training Data | Val SPL | Test SPL | URL | | ||
| ------------ | ------------- | ------- | -------- | --- | | ||
| ResNet50 + LSTM512 | Gibson 4+ | 0.922 | 0.917 | | | ||
| ResNet50 + LSTM512 | Gibson 4+ and MP3D(train/val/test)<br/> **Caution:** Trained on MP3D val and test | 0.956 | 0.941 | | ||
| ResNet50 + LSTM512 | Gibson 2+ | 0.956 | 0.944 | | | ||
| SE-ResNeXt50 + LSTM512 | Gibson 2+ | 0.959 | 0.943 | | | ||
| SE-ResNeXt101 + LSTM1024 | Gibson 2+ | 0.969 | 0.948 | | | ||
|
||
### RGB models | ||
|
||
| Architecture | Training Data | Val SPL | Test SPL | URL | | ||
| ------------ | ------------- | ------- | -------- | --- | | ||
| ResNet50 + LSTM512 | Gibson 2+ and MP3D(train/val/test)<br/> **Caution:** Trained on MP3D val and test | | | | ||
| SE-ResNeXt50 + LSTM512 | Gibson 2+ and MP3D(train/val/test)<br/> **Caution:** Trained on MP3D val and test | 0.933 | 0.920 | | ||
|
||
|
||
### Blind Models | ||
|
||
| Architecture | Training Data | Val SPL | Test SPL | URL | | ||
| ------------ | ------------- | ------- | -------- | --- | | ||
| LSTM512 | Gibson 0+ and MP3D(train/val/test)<br/> **Caution:** Trained on MP3D val and test | 0.729 | 0.676 | | ||
|
||
|
||
|
||
|
||
**Note:** Evaluation was done with *sampled* actions. | ||
|
||
All model weights are subject to [Matterport3D Terms-of-Use](http://dovahkiin.stanford.edu/matterport/public/MP_TOS.pdf). | ||
|
||
|
||
## Citing | ||
|
||
If you use DD-PPO or the model-weights in your research, please cite the following [paper](https://arxiv.org/abs/1911.00357): | ||
|
||
@article{wijmans2020ddppo, | ||
title = {{D}ecentralized {D}istributed {PPO}: {S}olving {P}oint{G}oal {N}avigation}, | ||
author = {Erik Wijmans and Abhishek Kadian and Ari Morcos and Stefan Lee and Irfan Essa and Devi Parikh and Manolis Savva and Dhruv Batra}, | ||
journal = {International Conference on Learning Representations (ICLR)}, | ||
year = {2020} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from habitat_baselines.rl.ddppo.algo import DDPPOTrainer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from habitat_baselines.rl.ddppo.algo.ddppo_trainer import DDPPOTrainer |
Oops, something went wrong.