This repository is a TensorFlow implementation of a capsule network (CapsNet). For more details about CapsNet, you can check my blog post. You can also find here a detailed report about this topic.
dataset.py: loads the MNIST dataset using the Keras API
capsnet.py: defines the CapsNet architecture and contains the function that builds it
capsule_layer.py: defines the CapsNet layers, mainly the PrimaryCaps and DigitCaps layers
engine.py: extracts parameters from the config, sets up the training and testing configuration, and implements both
config.py: holds a dict of parameters with a getter function
main.py: the main entry point
utils.py: contains some helper functions
run_kaggle.py: a script to run the Digit Recognizer competition from Kaggle
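As a rough illustration of what the capsule layers compute, here is a minimal NumPy sketch of the squash nonlinearity described in the CapsNet paper. It is not taken from capsule_layer.py; the function name, the `axis` argument, and the `eps` stabiliser are assumptions made for this example.

```python
import numpy as np


def squash(s, axis=-1, eps=1e-8):
    """Squash nonlinearity: keeps each vector's direction and maps its length into [0, 1)."""
    # Squared length of each capsule vector along the chosen axis.
    sq_norm = np.sum(np.square(s), axis=axis, keepdims=True)
    # Scale factor ||s||^2 / (1 + ||s||^2): short vectors shrink toward 0, long ones approach 1.
    scale = sq_norm / (1.0 + sq_norm)
    # eps (an assumption of this sketch) avoids division by zero for all-zero vectors.
    return scale * s / np.sqrt(sq_norm + eps)


capsules = np.array([[3.0, 4.0], [0.0, 0.0]])
lengths = np.linalg.norm(squash(capsules), axis=-1)
```

The length of a squashed vector can then be read as the probability that the entity represented by the capsule is present.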
For training, run python3 main.py --train. In config.py, you can specify your hyperparameters. The checkpoint_path parameter is the location where models/checkpoints are saved. The log parameter is the location where TensorFlow summaries are saved so they can be viewed later, for example in TensorBoard.
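A parameter dict with a getter function could look roughly like the sketch below. Only the checkpoint_path and log keys come from this README; every value, the other keys, and the name `get` are hypothetical and may not match the real config.py.

```python
# Hypothetical layout for a config module; not the repository's actual config.py.
_config = {
    "checkpoint_path": "./checkpoints",  # assumed default: where checkpoints are written
    "log": "./logs",                     # assumed default: where TensorFlow summaries go
    "batch_size": 128,                   # hypothetical hyperparameter
    "epochs": 10,                        # hypothetical hyperparameter
    "learning_rate": 1e-3,               # hypothetical hyperparameter
}


def get(key):
    """Return a parameter by name, failing loudly on unknown keys."""
    try:
        return _config[key]
    except KeyError:
        raise KeyError(f"Unknown config parameter: {key!r}") from None


checkpoint_dir = get("checkpoint_path")
```

Raising on unknown keys, rather than returning None, makes a misspelled parameter name surface immediately instead of in the middle of training.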
For testing, run python3 main.py --test. This loads the model from the latest saved checkpoint.
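The two command-line modes could be dispatched along the lines of the sketch below, which uses argparse from the standard library. Only the --train and --test flags come from this README; the function names, the returned strings, and the choice to make the flags mutually exclusive are assumptions, and the real main.py may be organised differently.

```python
import argparse


def parse_args(argv=None):
    """Parse the command line; exactly one of --train / --test must be given (an assumption)."""
    parser = argparse.ArgumentParser(description="CapsNet on MNIST")
    mode = parser.add_mutually_exclusive_group(required=True)
    mode.add_argument("--train", action="store_true", help="train the model")
    mode.add_argument("--test", action="store_true", help="evaluate the latest checkpoint")
    return parser.parse_args(argv)


def main(argv=None):
    args = parse_args(argv)
    if args.train:
        # Placeholder: a real entry point would start the training routine here.
        return "train"
    # Placeholder: a real entry point would restore the latest checkpoint and evaluate.
    return "test"


selected_mode = main(["--train"])
```

Passing `argv` explicitly keeps the dispatch logic easy to exercise without touching `sys.argv`.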
In addition, the code was tested on the test data provided in the Digit Recognizer competition from Kaggle, which is also MNIST data. The score achieved was 0.99500, which is 99.5% accuracy.
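For reference, a Digit Recognizer submission is a CSV file with an ImageId column (starting at 1) and a Label column. The sketch below writes such a file from a list of predicted digits; the function name and default path are hypothetical, and this is not the code in run_kaggle.py.

```python
import csv


def write_submission(labels, path="submission.csv"):
    """Write predicted digits to a CSV with ImageId,Label columns (ImageId starts at 1)."""
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["ImageId", "Label"])
        for image_id, label in enumerate(labels, start=1):
            writer.writerow([image_id, int(label)])
```

A call such as `write_submission([7, 2, 1])` would produce a header row followed by one row per test image, in the same order as the competition's test set.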