PyTorch CapsNet: Capsule Network for PyTorch
A CUDA-enabled PyTorch implementation of CapsNet (Capsule Network) based on this paper: Sara Sabour, Nicholas Frosst, Geoffrey E Hinton. Dynamic Routing Between Capsules. NIPS 2017
The current test error is 0.21% and the best test error is 0.20%. The current test accuracy is 99.31% and the best test accuracy is 99.32%.
What is a Capsule
A Capsule is a group of neurons whose activity vector represents the instantiation parameters of a specific type of entity such as an object or object part.
You can learn more about Capsule Networks here.
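In the paper, the length of a capsule's activity vector is read as the probability that the entity is present, so the vector is "squashed" to a length below 1 while its direction is kept. The following is a minimal sketch of that non-linearity in plain Python. The names `squash` and `length` and the `eps` guard are illustrative choices, not necessarily what this repository's code uses.

```python
import math


def squash(s, eps=1e-8):
    """Scale vector s so its length lies in [0, 1) while keeping its direction.

    Follows v = (|s|^2 / (1 + |s|^2)) * (s / |s|) from the CapsNet paper.
    """
    sq_norm = sum(x * x for x in s)
    norm = math.sqrt(sq_norm)
    scale = sq_norm / (1.0 + sq_norm) / (norm + eps)  # eps avoids division by zero
    return [scale * x for x in s]


def length(v):
    """Euclidean length of a vector."""
    return math.sqrt(sum(x * x for x in v))


short = squash([0.1, 0.0])   # short vectors shrink towards length 0
long_ = squash([10.0, 0.0])  # long vectors approach, but never reach, length 1
print(length(short), length(long_))
```

Short inputs end up with a length near 0 and long inputs with a length near 1, which is what lets the length act like a probability.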
Why another CapsNet implementation?
I wanted a decent PyTorch implementation of CapsNet and couldn't find one when I started. The goal of this implementation is to help newcomers learn and understand the CapsNet architecture and the idea of Capsules. The implementation does NOT focus on rigorous correctness of the results. In addition, the code is not optimized for speed. To make the code easier to read and understand, it comes with ample comments, and the Python classes and functions are documented with docstrings.
I will try my best to check and fix reported issues. Contributions are very welcome. If you find any bugs or errors in the code, please do not hesitate to open an issue or a pull request. Thank you.
Status and Latest Updates:
See the CHANGELOG
The model was trained on the standard MNIST data.
Note: you don't have to manually download, preprocess, and load the MNIST dataset as TorchVision will take care of this step for you.
I have tried using other datasets. See the Other Datasets section below for more details.
- Python 3
  - Tested with version 3.6.4
- PyTorch
  - Tested with version 0.3.0.post4
  - Migrate existing code to work in version 0.4.0. [Work-In-Progress]
  - Code will not run with version 0.1.2 because keepdim is not available in that version.
  - Code will not run with version 0.2.0 because the softmax function does not take a dimension argument in that version.
- CUDA 8 and above
  - Tested with CUDA 8 and CUDA 9.
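The version notes above could be turned into a start-up guard. The sketch below is a hypothetical helper (`parse_version`, `is_supported` and `MIN_PYTORCH` are names made up for illustration, not part of this repository). It compares a version string against the oldest PyTorch version listed as tested. In a real script the string would come from `torch.__version__`.

```python
import re

MIN_PYTORCH = (0, 3, 0)  # oldest version listed as tested above


def parse_version(version):
    """Return the leading numeric components, e.g. '0.3.0.post4' -> (0, 3, 0)."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError("unrecognised version string: %r" % version)
    return tuple(int(part) for part in match.groups())


def is_supported(version):
    """Check only the lower bound; newer versions may still need code changes."""
    return parse_version(version) >= MIN_PYTORCH


for v in ("0.1.2", "0.2.0", "0.3.0.post4"):
    print(v, "supported" if is_supported(v) else "too old")
```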
Training and Evaluation
Step 1. Clone this repository with git and install project dependencies.
$ git clone https://github.com/cedrickchee/capsule-net-pytorch.git
$ cd capsule-net-pytorch
$ pip install -r requirements.txt
Step 2. Start the CapsNet on MNIST training and evaluation:
- Training with default settings:
$ python main.py
- Training on 8 GPUs with 30 epochs and 1 routing iteration:
$ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py --epochs 30 --num-routing 1 --threads 16 --batch-size 128 --test-batch-size 128
Step 3. Test a pre-trained model:
If you have trained a model in Step 2 above, the weights for the trained model will be saved to results/trained_model/model_epoch_10.pth. [WIP] Now run the following command to get test results.
$ python main.py --is-training 0 --weights results/trained_model/model_epoch_10.pth
You can download the weights for the pre-trained model from my Google Drive. We saved the weights (model state dict) and the optimizer state for the model at the end of every training epoch.
Uncompress and put the weights (.pth files) into
Note: the model was last trained on 2017-11-26 and the weights last updated on 2017-11-28.
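Because one checkpoint is saved per epoch, with names like model_epoch_10.pth, choosing the most recent file to pass to --weights comes down to comparing epoch numbers. The sketch below assumes that file-name pattern holds for every epoch. The helper name `latest_checkpoint` is hypothetical.

```python
import re

CHECKPOINT_PATTERN = re.compile(r"model_epoch_(\d+)\.pth$")


def latest_checkpoint(filenames):
    """Return the file name with the highest epoch number, or None if none match."""
    best_name, best_epoch = None, -1
    for name in filenames:
        match = CHECKPOINT_PATTERN.search(name)
        if match and int(match.group(1)) > best_epoch:
            best_name, best_epoch = name, int(match.group(1))
    return best_name


names = ["model_epoch_2.pth", "model_epoch_10.pth", "model_epoch_9.pth", "notes.txt"]
print(latest_checkpoint(names))
```

Comparing the parsed integers avoids the trap of a plain string sort, which would rank epoch 9 above epoch 10.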
The Default Hyper Parameters
| Parameter | Default value | CLI argument |
|---|---|---|
| Training epochs | 10 | --epochs 10 |
| Learning rate | 0.01 | --lr 0.01 |
| Training batch size | 128 | --batch-size 128 |
| Testing batch size | 128 | --test-batch-size 128 |
| Log interval | 10 | --log-interval 10 |
| Disables CUDA training | false | --no-cuda |
| Num. of channels produced by the convolution | 256 | --num-conv-out-channel 256 |
| Num. of input channels to the convolution | 1 | --num-conv-in-channel 1 |
| Num. of primary units | 8 | --num-primary-unit 8 |
| Primary unit size | 1152 | --primary-unit-size 1152 |
| Num. of digit classes | 10 | --num-classes 10 |
| Output unit size | 16 | --output-unit-size 16 |
| Num. of routing iterations | 3 | --num-routing 3 |
| Use reconstruction loss | true | --use-reconstruction-loss |
| Regularization coefficient for reconstruction loss | 0.0005 | --regularization-scale 0.0005 |
| Dataset name (mnist, cifar10) | mnist | --dataset mnist |
| Input image width to the convolution | 28 | --input-width 28 |
| Input image height to the convolution | 28 | --input-height 28 |
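For readers new to command-line options in Python, the sketch below shows how a subset of the options above could be declared with the standard argparse module. It is an illustration under assumptions, not a copy of this repository's main.py. The flag names and defaults are taken from the table, the function name `build_parser` is made up, and the boolean options are left out because their exact form is not clear from this README.

```python
import argparse


def build_parser():
    """Declare a subset of the training options with the defaults listed above."""
    parser = argparse.ArgumentParser(description="CapsNet training options (sketch)")
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--test-batch-size", type=int, default=128)
    parser.add_argument("--log-interval", type=int, default=10)
    parser.add_argument("--num-routing", type=int, default=3)
    parser.add_argument("--regularization-scale", type=float, default=0.0005)
    parser.add_argument("--dataset", choices=["mnist", "cifar10"], default="mnist")
    return parser


# Dashes in flag names become underscores in the parsed namespace.
args = build_parser().parse_args(["--epochs", "30", "--num-routing", "1"])
print(args.epochs, args.num_routing, args.batch_size)
```

Any option not given on the command line keeps its default, so `--batch-size` stays at 128 in this example.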
CapsNet classification test error on MNIST. The MNIST average and standard deviation results are reported from 3 trials.
The results can be reproduced by running the following commands.
python main.py --epochs 50 --num-routing 1 --use-reconstruction-loss no --regularization-scale 0.0 #CapsNet-v1
python main.py --epochs 50 --num-routing 1 --use-reconstruction-loss yes --regularization-scale 0.0005 #CapsNet-v2
python main.py --epochs 50 --num-routing 3 --use-reconstruction-loss no --regularization-scale 0.0 #CapsNet-v3
python main.py --epochs 50 --num-routing 3 --use-reconstruction-loss yes --regularization-scale 0.0005 #CapsNet-v4
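The four variants differ in the number of routing iterations and in whether a reconstruction term, weighted by --regularization-scale, is added to the classification loss. The sketch below shows how such a total loss could be assembled for one sample. It uses the margin loss from the paper (m+ = 0.9, m- = 0.1, lambda = 0.5) and a sum-of-squares reconstruction error. The function names and the exact reduction are assumptions, and the repository's code may differ.

```python
def margin_loss(lengths, target, m_plus=0.9, m_minus=0.1, lam=0.5):
    """Margin loss from the CapsNet paper for one sample.

    lengths: length of each class capsule's output vector (values in [0, 1)).
    target:  index of the true class.
    """
    total = 0.0
    for k, length in enumerate(lengths):
        if k == target:
            # The true class is penalised when its capsule is shorter than m_plus.
            total += max(0.0, m_plus - length) ** 2
        else:
            # Other classes are penalised (down-weighted) when longer than m_minus.
            total += lam * max(0.0, length - m_minus) ** 2
    return total


def total_loss(lengths, target, image, reconstruction, regularization_scale=0.0005):
    """Margin loss plus the scaled sum-of-squares reconstruction error."""
    recon = sum((a - b) ** 2 for a, b in zip(image, reconstruction))
    return margin_loss(lengths, target) + regularization_scale * recon


lengths = [0.05, 0.95, 0.2]  # capsule 1 is confident, capsule 2 is slightly active
print(margin_loss(lengths, target=1))
print(total_loss(lengths, 1, image=[0.0, 1.0], reconstruction=[0.1, 0.8]))
```

Setting `regularization_scale` to 0.0 reduces the total to the margin loss alone, which matches the intent of the variants run without reconstruction.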
Training Loss and Accuracy
The training losses and accuracies for CapsNet-v4 (50 epochs, 3 routing iterations, using reconstruction, regularization scale of 0.0005):
Training accuracy. Highest training accuracy: 100%
Training loss. Lowest training error: 0.1938%
Test Loss and Accuracy
The test losses and accuracies for CapsNet-v4 (50 epochs, 3 routing iterations, using reconstruction, regularization scale of 0.0005):
Test accuracy. Highest test accuracy: 99.32%
Test loss. Lowest test error: 0.2002%
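For context on how an accuracy figure like this can be obtained from capsule outputs: in the paper, the predicted class is the output capsule with the longest vector. The sketch below illustrates that rule in plain Python. The function names `predict` and `accuracy` are illustrative and are not taken from this repository.

```python
import math


def predict(capsule_outputs):
    """Return the index of the class capsule with the longest output vector."""
    lengths = [math.sqrt(sum(x * x for x in v)) for v in capsule_outputs]
    return max(range(len(lengths)), key=lengths.__getitem__)


def accuracy(batch_outputs, labels):
    """Percentage of samples whose predicted class matches the label."""
    correct = sum(predict(out) == label for out, label in zip(batch_outputs, labels))
    return 100.0 * correct / len(labels)


batch = [
    [[0.1, 0.0], [0.6, 0.7]],  # class 1 has the longest vector
    [[0.5, 0.5], [0.1, 0.1]],  # class 0 has the longest vector
]
print(accuracy(batch, labels=[1, 1]))
```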
5.97s / batch or 8min / epoch on a single Tesla K80 GPU with batch size of 704.
3.25s / batch or 25min / epoch on a single Tesla K80 GPU with batch size of 128.
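The per-batch and per-epoch figures can be cross-checked with a little arithmetic. The sketch below assumes the standard MNIST training split of 60,000 images and counts training batches only, ignoring evaluation and logging time. The function name `minutes_per_epoch` is made up for illustration.

```python
import math

MNIST_TRAIN_SIZE = 60000  # size of the standard MNIST training split


def minutes_per_epoch(seconds_per_batch, batch_size, dataset_size=MNIST_TRAIN_SIZE):
    """Estimate the epoch time in minutes from the measured per-batch time."""
    batches = math.ceil(dataset_size / batch_size)  # the last batch may be partial
    return batches * seconds_per_batch / 60.0


print(round(minutes_per_epoch(3.25, 128), 1))
print(round(minutes_per_epoch(5.97, 704), 1))
```

Both estimates land close to the reported 25 and 8 minutes per epoch.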
In my case, these are the hyperparameters I used for the training setup:
- batch size: 128
- Epochs: 50
- Num. of routing: 3
- Use reconstruction loss: yes
- Regularization scale for reconstruction loss: 0.0005
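To make the "number of routing iterations" setting concrete, here is a plain-Python sketch of routing-by-agreement as described in the paper. It is written for readability under stated assumptions: nested lists stand in for tensors, `u_hat[i][j]` is the prediction of input capsule i for output capsule j, and the function names are illustrative. The repository's PyTorch implementation will differ in shapes and details.

```python
import math


def squash(s):
    """Shrink vector s to a length in [0, 1) while keeping its direction."""
    sq_norm = sum(x * x for x in s)
    if sq_norm == 0.0:
        return [0.0 for _ in s]
    scale = sq_norm / (1.0 + sq_norm) / math.sqrt(sq_norm)
    return [scale * x for x in s]


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def dynamic_routing(u_hat, num_routing=3):
    """Return one squashed output vector per output capsule."""
    num_in, num_out, dim = len(u_hat), len(u_hat[0]), len(u_hat[0][0])
    b = [[0.0] * num_out for _ in range(num_in)]  # routing logits start at zero
    v = []
    for _ in range(num_routing):
        # Coupling coefficients: softmax over the output capsules for each input.
        c = [softmax(row) for row in b]
        # Weighted sum of the predictions for each output capsule, then squash.
        v = []
        for j in range(num_out):
            s_j = [sum(c[i][j] * u_hat[i][j][d] for i in range(num_in)) for d in range(dim)]
            v.append(squash(s_j))
        # Agreement (dot product) raises the logit used in the next iteration.
        for i in range(num_in):
            for j in range(num_out):
                b[i][j] += sum(u_hat[i][j][d] * v[j][d] for d in range(dim))
    return v


# Two input capsules: they agree about output 0 and disagree about output 1.
u_hat = [
    [[1.0, 0.0], [1.0, 0.0]],
    [[1.0, 0.0], [-1.0, 0.0]],
]
for v_j in dynamic_routing(u_hat, num_routing=3):
    print(math.sqrt(sum(x * x for x in v_j)))
```

In this toy input the output capsule whose predictions agree ends up with the longer vector, and additional iterations route more of the input towards it.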
The results of CapsNet-v4.
Digits at left are reconstructed images.
[WIP] Ground truth image from dataset
Model architecture:
------------------
Net (
  (conv1): ConvLayer (
    (conv0): Conv2d(1, 256, kernel_size=(9, 9), stride=(1, 1))
    (relu): ReLU (inplace)
  )
  (primary): CapsuleLayer (
    (conv_units): ModuleList (
      (0): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (1): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (2): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (3): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (4): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (5): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (6): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (7): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
    )
  )
  (digits): CapsuleLayer (
  )
  (decoder): Decoder (
    (fc1): Linear (160 -> 512)
    (fc2): Linear (512 -> 1024)
    (fc3): Linear (1024 -> 784)
    (relu): ReLU (inplace)
    (sigmoid): Sigmoid ()
  )
)

Parameters and size:
-------------------
conv1.conv0.weight: [256, 1, 9, 9]
conv1.conv0.bias:
primary.conv_units.0.weight: [32, 256, 9, 9]
primary.conv_units.0.bias:
primary.conv_units.1.weight: [32, 256, 9, 9]
primary.conv_units.1.bias:
primary.conv_units.2.weight: [32, 256, 9, 9]
primary.conv_units.2.bias:
primary.conv_units.3.weight: [32, 256, 9, 9]
primary.conv_units.3.bias:
primary.conv_units.4.weight: [32, 256, 9, 9]
primary.conv_units.4.bias:
primary.conv_units.5.weight: [32, 256, 9, 9]
primary.conv_units.5.bias:
primary.conv_units.6.weight: [32, 256, 9, 9]
primary.conv_units.6.bias:
primary.conv_units.7.weight: [32, 256, 9, 9]
primary.conv_units.7.bias:
digits.weight: [1, 1152, 10, 16, 8]
decoder.fc1.weight: [512, 160]
decoder.fc1.bias:
decoder.fc2.weight: [1024, 512]
decoder.fc2.bias:
decoder.fc3.weight: [784, 1024]
decoder.fc3.bias:

Total number of parameters on (with reconstruction network): 8227088 (8 million)
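The element count of each weight tensor is the product of its dimensions. The sketch below applies that to the weight shapes printed above to show where most of the parameters sit. It covers only the weight tensors whose shapes appear in the listing and leaves out the bias terms, so its sum is expected to fall short of the reported total. The helper name `numel` is illustrative.

```python
def numel(shape):
    """Number of elements in a tensor of the given shape."""
    count = 1
    for dim in shape:
        count *= dim
    return count


# Weight shapes copied from the parameter listing above.
weight_shapes = [("conv1.conv0.weight", [256, 1, 9, 9])]
weight_shapes += [("primary.conv_units.%d.weight" % i, [32, 256, 9, 9]) for i in range(8)]
weight_shapes += [
    ("digits.weight", [1, 1152, 10, 16, 8]),
    ("decoder.fc1.weight", [512, 160]),
    ("decoder.fc2.weight", [1024, 512]),
    ("decoder.fc3.weight", [784, 1024]),
]

counts = {name: numel(shape) for name, shape in weight_shapes}
for name, count in counts.items():
    print("%-30s %9d" % (name, count))
print("sum of listed weights: %d" % sum(counts.values()))
```

The eight primary-capsule convolutions dominate the count, followed by the digit capsule weights.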
We logged the training and test losses and accuracies using tensorboardX. TensorBoard helps us visualize how the model learns over time. We can visualize statistics such as how the objective function, the weights, or the accuracy change during training.
TensorBoard operates by reading TensorFlow data (events files).
How to Use TensorBoard
- Download a copy of the events files for the latest run from my Google Drive.
- Uncompress the file and put it into
- Check to ensure you have installed tensorflow (CPU version). We need this for TensorBoard server and dashboard.
- Start TensorBoard.
$ tensorboard --logdir runs
- Open TensorBoard dashboard in your web browser using this URL: http://localhost:6006
In the spirit of experimentation, I have tried using other datasets. I have updated the implementation so that it supports and works with CIFAR10. Note that I have not thoroughly tested the capsule model on CIFAR10.
We can train and test the model on CIFAR10 by running the following command.
python main.py --dataset cifar10 --num-conv-in-channel 3 --input-width 32 --input-height 32 --primary-unit-size 2048 --epochs 80 --num-routing 1 --use-reconstruction-loss yes --regularization-scale 0.0005
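The --primary-unit-size value changes with the input resolution because it is the flattened output of one primary unit. The sketch below derives it from the layer settings shown in the model architecture printout: a 9x9 convolution with stride 1, then a 9x9 convolution with stride 2 and 32 output channels. It assumes square inputs and no padding, and the function names are made up for illustration.

```python
def conv_output_size(size, kernel_size, stride):
    """Output width/height of a convolution without padding."""
    return (size - kernel_size) // stride + 1


def primary_unit_size(input_size, channels_per_unit=32):
    """Flattened size of one primary unit for a square input image."""
    after_conv1 = conv_output_size(input_size, kernel_size=9, stride=1)
    after_primary = conv_output_size(after_conv1, kernel_size=9, stride=2)
    return channels_per_unit * after_primary * after_primary


print(primary_unit_size(28))  # MNIST: 28 -> 20 -> 6, so 32 * 6 * 6
print(primary_unit_size(32))  # CIFAR10: 32 -> 24 -> 8, so 32 * 8 * 8
```

This reproduces the default of 1152 for 28x28 MNIST images and the 2048 passed in the CIFAR10 command for 32x32 images.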
Training Loss and Accuracy
The training losses and accuracies for CapsNet-v4 (80 epochs, 3 routing iterations, using reconstruction, regularization scale of 0.0005):
- Highest training accuracy: 100%
- Lowest training error: 0.3589%
Test Loss and Accuracy
The test losses and accuracies for CapsNet-v4 (80 epochs, 3 routing iterations, using reconstruction, regularization scale of 0.0005):
- Highest test accuracy: 71%
- Lowest test error: 0.5735%
- Publish results.
- More testing.
- Inference mode - command to test a pre-trained model.
- Jupyter Notebook version.
- Create a sample to show how we can apply CapsNet to real-world application.
- Experiment with CapsNet:
- Try using another dataset.
- Come up with a more creative model structure.
- Pre-trained model and weights.
- Add visualization for training and evaluation metrics.
- Implement reconstruction loss.
- Check algorithm for correctness.
- Update results from TensorBoard after making improvements and bug fixes.
- Publish updated pre-trained model weights.
- Log the original and reconstructed images using TensorBoard.
- Update results with reconstructed image and original image.
- Resume training by loading model checkpoint.
- Migrate existing code to work in PyTorch 0.4.0.
WIP is an acronym for Work-In-Progress
I referenced these implementations mainly as a sanity check:
Here are some resources that we think will be helpful if you want to learn more about Capsule Networks:
- Articles and blog posts:
- The first author of the paper, Sara Sabour has released the code.
Real-world Application of CapsNet
The following are a few samples in the wild that show how we can apply CapsNet to real-world use cases.