Commit

Merge branch 'master' of github.com:kentsommer/pytorch-value-iteration-networks
Kent Sommer committed Apr 21, 2017
2 parents 2b2bc5f + 88dd40d commit ffe6c3c
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion README.md
@@ -46,7 +46,7 @@ python train.py --datafile dataset/gridworld_28x28.npz --imsize 28 --lr 0.002 --
- `l_q`: Number of channels in q layer (~actions) in VI-module. Default: 10, described in paper.
- `batch_size`: Batch size. Default: 128

-## How to visualize / test paths (requires training first)
+## How to test / visualize paths (requires training first)
#### 8x8 gridworld
```bash
python test.py --weights trained/vin_8x8.pth --imsize 8 --k 10
@@ -59,10 +59,15 @@ python test.py --weights trained/vin_16x16.pth --imsize 16 --k 20
```bash
python test.py --weights trained/vin_28x28.pth --imsize 28 --k 36
```
To visualize the optimal and predicted paths simply pass:
```bash
--plot
```

**Flags**:
- `weights`: Path to trained weights.
- `imsize`: The size of input images. One of: [8, 16, 28]
- `plot`: If supplied, the optimal and predicted paths will be plotted
- `k`: Number of Value Iterations. Recommended: [10 for 8x8, 20 for 16x16, 36 for 28x28]
- `l_i`: Number of channels in input layer. Default: 2, i.e. obstacles image and goal image.
- `l_h`: Number of channels in first convolutional layer. Default: 150, described in paper.
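As an illustration of how the flags listed above could be wired up, here is a minimal `argparse` sketch. It is not the repository's actual `test.py` parser: the `build_parser` helper is hypothetical, and the defaults for `--imsize` and `--k` are assumptions chosen to match the 8x8 example. Only the flag names and the `l_i` / `l_h` defaults come from the README.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags documented in the README."""
    parser = argparse.ArgumentParser(description="Test a trained VIN on gridworld data")
    parser.add_argument("--weights", type=str, required=True,
                        help="Path to trained weights")
    # Default of 8 is an assumption for this sketch, not taken from the README.
    parser.add_argument("--imsize", type=int, choices=[8, 16, 28], default=8,
                        help="Size of input images")
    # store_true: the flag takes no value; passing --plot switches plotting on.
    parser.add_argument("--plot", action="store_true",
                        help="Plot the optimal and predicted paths")
    # Default of 10 is an assumption; the README only recommends 10/20/36 per size.
    parser.add_argument("--k", type=int, default=10,
                        help="Number of value iterations")
    parser.add_argument("--l_i", type=int, default=2,
                        help="Channels in input layer (obstacles image, goal image)")
    parser.add_argument("--l_h", type=int, default=150,
                        help="Channels in first convolutional layer")
    return parser


# Equivalent to: python test.py --weights trained/vin_8x8.pth --imsize 8 --k 10 --plot
args = build_parser().parse_args(
    ["--weights", "trained/vin_8x8.pth", "--imsize", "8", "--k", "10", "--plot"]
)
```

Because `--plot` is a boolean switch, appending it to any of the three test commands is enough to enable plotting; omitting it leaves `args.plot` as `False`.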
@@ -85,6 +90,13 @@ Test set | 13846 | 77203 | 251755

The datasets (8x8, 16x16, and 28x28) included in this repository can be reproduced using the ```dataset/make_training_data.py``` script. Note that this script is not optimized and runs rather slowly (also uses a lot of memory :D)

## Performance: Success Rate
This is the success rate from rollouts of the learned policy in the environment (taken over 5000 randomly generated domains).

Success Rate | 8x8 | 16x16 | 28x28
-- | -- | -- | --
PyTorch | 99.69% | 96.99% | 91.07%
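
A success rate of this kind can be sketched as follows. This is an illustration only, not the repository's evaluation code: it assumes a rollout counts as a success when the policy reaches the goal within a step budget without leaving the grid or entering an obstacle. The names `rollout_succeeds` and `success_rate` and the `(row, col)` action encoding are hypothetical.

```python
from typing import Callable, Iterable, Set, Tuple

State = Tuple[int, int]   # (row, col) position on the grid
Action = Tuple[int, int]  # (d_row, d_col) step chosen by the policy


def rollout_succeeds(policy: Callable[[State], Action], start: State, goal: State,
                     obstacles: Set[State], size: int, max_steps: int) -> bool:
    """Follow the policy from start; succeed only if the goal is reached."""
    state = start
    for _ in range(max_steps):
        if state == goal:
            return True
        d_row, d_col = policy(state)
        nxt = (state[0] + d_row, state[1] + d_col)
        in_bounds = 0 <= nxt[0] < size and 0 <= nxt[1] < size
        if not in_bounds or nxt in obstacles:
            return False  # left the grid or hit an obstacle
        state = nxt
    return state == goal


def success_rate(outcomes: Iterable[bool]) -> float:
    """Percentage of successful rollouts over all evaluated domains."""
    results = list(outcomes)
    if not results:
        raise ValueError("need at least one rollout outcome")
    return 100.0 * sum(results) / len(results)
```

Under these assumptions, the figures in the table would correspond to calling `success_rate` on the outcomes of one rollout per generated domain.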

## Performance: Test Accuracy

**NOTE**: This is the **accuracy on test set**. It is different from the table in the paper, which indicates the **success rate** from rollouts of the learned policy in the environment.
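
To make the distinction concrete, test accuracy can be read as a per-sample comparison of predicted and labelled actions, with no rollout involved. The sketch below assumes each test sample carries a single optimal-action label; the function name `action_accuracy` is hypothetical and not taken from the repository.

```python
from typing import Sequence


def action_accuracy(predicted: Sequence[int], optimal: Sequence[int]) -> float:
    """Percentage of test samples where the predicted action equals the label."""
    if len(predicted) != len(optimal):
        raise ValueError("predicted and optimal must have the same length")
    if not predicted:
        raise ValueError("need at least one sample")
    matches = sum(p == o for p, o in zip(predicted, optimal))
    return 100.0 * matches / len(predicted)
```

A policy can score below 100% here and still reach the goal in a rollout, for example when it picks a different but equally short route, which is one reason the two metrics are not directly comparable.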