A PyTorch implementation of Parameter-sharing Capsule Network based on the paper Evaluating Generalization Ability of Convolutional Neural Networks and Capsule Networks for Image Classification via Top-2 Classification.
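The Top-2 criterion from the paper's title counts a prediction as correct when the true class appears among the two highest-scoring classes. A minimal sketch of that check in PyTorch (the function name and tensor layout are illustrative assumptions, not this repository's API):

```python
import torch

def top_k_accuracy(logits, targets, k=2):
    # logits: [batch, num_classes] raw scores; targets: [batch] class indices
    topk = logits.topk(k, dim=1).indices            # indices of the k highest scores
    correct = (topk == targets.unsqueeze(1)).any(dim=1)  # true class in top-k?
    return correct.float().mean().item()
```

With `k=1` this reduces to ordinary classification accuracy, so the same helper covers both metrics.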
- Anaconda
- PyTorch
```
conda install pytorch torchvision -c pytorch
```
- PyTorchNet
```
pip install git+https://github.com/pytorch/tnt.git@master
```
- capsule-layer
```
pip install git+https://github.com/leftthomas/CapsuleLayer.git@master
```
- opencv
```
conda install opencv
```
We have uploaded the datasets to BaiduYun (access code: tr4u). You don't need to download them yourself; the code downloads them automatically.
```
python -m visdom.server -logging_level WARNING & python main.py --data_type CIFAR10 --net_mode CNN --num_epochs 300
```
optional arguments:
```
--data_type       dataset type [default value is 'MNIST'](choices:['MNIST', 'FashionMNIST', 'SVHN', 'CIFAR10', 'STL10'])
--net_mode        network mode [default value is 'Capsule'](choices:['Capsule', 'CNN'])
--capsule_type    capsule network type [default value is 'ps'](choices:['ps', 'fc'])
--routing_type    routing type [default value is 'k_means'](choices:['k_means', 'dynamic'])
--num_iterations  routing iterations number [default value is 3]
--batch_size      train batch size [default value is 64]
--num_epochs      train epochs number [default value is 100]
--use_da          use data augmentation or not [default value is False]
```
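`--routing_type` and `--num_iterations` control how lower-level capsule votes are aggregated into output capsules. A minimal sketch of k-means-style routing with cosine-similarity agreement (the tensor layout and function name are assumptions for illustration, not the CapsuleLayer API):

```python
import torch
import torch.nn.functional as F

def k_means_routing(votes, num_iterations=3):
    # votes: [batch, out_caps, in_caps, out_dim] -- predictions from lower capsules
    centers = votes.mean(dim=-2, keepdim=True)   # start each center at the vote mean
    for _ in range(num_iterations):
        # cosine similarity between every vote and its current center
        logits = (F.normalize(votes, dim=-1) * F.normalize(centers, dim=-1)).sum(dim=-1, keepdim=True)
        weights = F.softmax(logits, dim=-2)      # re-weight votes by agreement
        centers = (weights * votes).sum(dim=-2, keepdim=True)
    return centers.squeeze(-2)                   # [batch, out_caps, out_dim]
```

Each iteration pulls the output capsule toward the votes that already agree with it, which is why a small `--num_iterations` (the default is 3) is usually enough.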
Visdom can now be accessed by going to `127.0.0.1:8097/env/$data_type` in your browser, where `$data_type` is the dataset type you are training on. If you want to interrupt this process, run `ps aux | grep visdom` to find the PID, then `kill PID`.
```
python vis.py --data_type CIFAR10
```
optional arguments:
```
--data_type       dataset type [default value is 'STL10'](choices:['MNIST', 'FashionMNIST', 'SVHN', 'CIFAR10', 'STL10'])
--data_mode       visualized data mode [default value is 'test_single'](choices:['test_single', 'test_multi'])
--num_iterations  routing iterations number [default value is 3]
```
Generated results are saved in the same directory as `README.md`.
```
python utils.py
```
Generated figures are saved in the same directory as `README.md`.
The train loss, train accuracy, test loss, test accuracy, and confusion matrix are shown on visdom. The pretrained models and more results can be found in BaiduYun (access code: kvmj).
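The confusion matrix shown on visdom tabulates predictions against true labels. A minimal sketch of how one is built in plain PyTorch (this helper is an illustration, not the repository's own meter code):

```python
import torch

def confusion_matrix(preds, targets, num_classes):
    # preds, targets: 1-D tensors of class indices
    # rows index the true class, columns the predicted class
    cm = torch.zeros(num_classes, num_classes, dtype=torch.long)
    for p, t in zip(preds.tolist(), targets.tolist()):
        cm[t, p] += 1
    return cm
```

The diagonal counts correct predictions per class; off-diagonal entries show which classes the model confuses with each other.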
- MNIST
  - PS-Capsule
  - FC-Capsule
  - CNN
- FashionMNIST
  - PS-Capsule
  - FC-Capsule
  - CNN
- SVHN
  - PS-Capsule
  - FC-Capsule
  - CNN
- CIFAR10
  - PS-Capsule
  - FC-Capsule
  - CNN
- STL10
  - PS-Capsule
  - FC-Capsule
  - CNN