
mmz33/CapsNet


Capsule Network

This repository contains a TensorFlow implementation of the capsule network (CapsNet). For more details about CapsNet, see my blog post. You can also find a detailed report on this topic here.
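For orientation, the original CapsNet paper (Sabour, Frosst and Hinton, "Dynamic Routing Between Capsules", 2017) applies a "squash" nonlinearity to each capsule. The squash keeps the capsule's direction and maps its length into [0, 1), so the length can be read as the probability that an entity is present. The pure-Python sketch below illustrates that formula on plain lists. It is not this repository's TensorFlow code, and the `eps` term is an assumption added to avoid division by zero.

```python
import math


def squash(s, eps=1e-9):
    """Squash capsule vector s: v = (|s|^2 / (1 + |s|^2)) * (s / |s|)."""
    sq_norm = sum(x * x for x in s)
    norm = math.sqrt(sq_norm + eps)  # eps (assumed) guards against s == 0
    scale = sq_norm / (1.0 + sq_norm) / norm
    return [scale * x for x in s]


v = squash([3.0, 4.0])  # |s| = 5, so |v| should be 25/26
print(round(math.sqrt(sum(x * x for x in v)), 4))  # prints 0.9615
```

Short vectors shrink toward zero length and long vectors approach unit length. This property is what lets the output capsule lengths act as class scores.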

Files

  • dataset.py: loads the MNIST dataset using the Keras API
  • capsnet.py: defines the CapsNet architecture and contains a function to build it
  • capsule_layer.py: implements the CapsNet layers, mainly the PrimaryCaps and DigitCaps layers
  • engine.py: extracts parameters from the config, sets up the training and testing configurations, and runs them
  • config.py: defines a dict of parameters with a getter function
  • main.py: the main entry point
  • utils.py: contains some helper functions
  • run_kaggle.py: a script to run the model on the Kaggle Digit Recognizer competition data
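capsule_layer.py implements the PrimaryCaps and DigitCaps layers. In the original paper, DigitCaps receives its inputs through "routing by agreement". Under this scheme, coupling coefficients are a softmax over routing logits, and a logit grows when an input capsule's prediction agrees with the resulting output capsule. This README does not specify how capsule_layer.py implements routing. Assuming it follows that algorithm, the pure-Python sketch below runs it on tiny lists. The names `dynamic_routing`, `u_hat` (where u_hat[i][j] is input capsule i's prediction for output capsule j) and `num_iters` are hypothetical. The default of 3 iterations is the paper's setting.

```python
import math


def squash(s, eps=1e-9):
    """Scale vector s to length |s|^2 / (1 + |s|^2), keeping its direction."""
    sq = sum(x * x for x in s)
    return [x * sq / (1.0 + sq) / math.sqrt(sq + eps) for x in s]


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def dynamic_routing(u_hat, num_iters=3):
    """Route prediction vectors u_hat[i][j] to output capsules (hypothetical sketch)."""
    num_in, num_out, dim = len(u_hat), len(u_hat[0]), len(u_hat[0][0])
    b = [[0.0] * num_out for _ in range(num_in)]  # routing logits, start uniform
    v = []
    for _ in range(num_iters):
        # coupling coefficients: softmax over output capsules for each input capsule
        c = [softmax(b_i) for b_i in b]
        # each output capsule is the squashed, coupling-weighted sum of predictions
        s = [[sum(c[i][j] * u_hat[i][j][d] for i in range(num_in)) for d in range(dim)]
             for j in range(num_out)]
        v = [squash(s_j) for s_j in s]
        # agreement: raise logits where a prediction points the same way as the output
        for i in range(num_in):
            for j in range(num_out):
                b[i][j] += sum(u_hat[i][j][d] * v[j][d] for d in range(dim))
    return v


# Two input capsules strongly agree on output capsule 0 and disagree on capsule 1.
u_hat = [
    [[2.0, 0.0], [0.0, 0.1]],
    [[2.0, 0.0], [0.1, 0.0]],
]
lengths = [math.sqrt(sum(x * x for x in v_j)) for v_j in dynamic_routing(u_hat)]
print(lengths[0] > lengths[1])  # prints True
```

In a real layer, u_hat would come from multiplying each PrimaryCaps output by a learned weight matrix, and the Python loops would become batched tensor operations.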

Training

To train, run python3 main.py --train. Hyperparameters are specified in config.py. There, checkpoint_path is the location where models/checkpoints are saved, and log is the location where TensorFlow summaries are written so they can be viewed later, for example in TensorBoard.
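config.py is described as a dict of parameters with a getter function. A minimal sketch of that pattern follows. This README names only `checkpoint_path` and `log`, so the other keys, all of the values, and the getter name `get_param` are hypothetical.

```python
# Hypothetical config.py layout: only checkpoint_path and log appear in the README;
# every other key and all values are illustrative assumptions.
_CONFIG = {
    "checkpoint_path": "checkpoints/",  # where models/checkpoints are saved
    "log": "logs/",                     # where TensorFlow summaries are written
    "batch_size": 128,                  # assumed hyperparameter
    "num_epochs": 10,                   # assumed hyperparameter
}

_MISSING = object()  # sentinel so that None can still be passed as a default


def get_param(name, default=_MISSING):
    """Return a config value; raise KeyError for unknown names unless a default is given."""
    if name in _CONFIG:
        return _CONFIG[name]
    if default is _MISSING:
        raise KeyError(f"unknown config parameter: {name!r}")
    return default


print(get_param("log"))  # prints logs/
```

Raising on unknown names makes a misspelled parameter fail loudly instead of silently falling back to a default. Once summaries have been written, TensorBoard can be pointed at the log directory with `tensorboard --logdir <log directory>`.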

Testing

To test, run python3 main.py --test. This loads the model from the latest saved checkpoint.
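main.py selects the mode through the --train and --test flags. One plausible way to parse them is sketched below with argparse. Three things are assumptions: that main.py uses argparse at all, that the two flags are mutually exclusive, and that one of them is required. To locate the newest checkpoint, TensorFlow provides `tf.train.latest_checkpoint(checkpoint_dir)`, although this README does not say which call the repository uses.

```python
import argparse


def parse_args(argv=None):
    """Parse --train / --test (hypothetical sketch of main.py's command line)."""
    parser = argparse.ArgumentParser(description="Train or test CapsNet on MNIST.")
    mode = parser.add_mutually_exclusive_group(required=True)  # assumed: exactly one mode
    mode.add_argument("--train", action="store_true", help="train the model")
    mode.add_argument("--test", action="store_true",
                      help="evaluate the model from the latest saved checkpoint")
    return parser.parse_args(argv)


args = parse_args(["--test"])
print(args.train, args.test)  # prints False True
```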

Kaggle Digit Recognizer

In addition, the code was tested on the test data provided in the Kaggle Digit Recognizer competition, which is also MNIST data. The achieved score was 0.99500, i.e. 99.5% accuracy.
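The Kaggle score for this competition is classification accuracy, the fraction of test images labeled correctly. A score of 0.99500 therefore means 99.5% of the images were classified correctly. The sketch below shows the submission step and is not taken from run_kaggle.py. It assumes the competition's sample_submission.csv layout, which has an `ImageId,Label` header and 1-based image IDs. The function names are hypothetical.

```python
import csv
import io


def write_submission(predictions, out):
    """Write predicted digits as CSV rows 'ImageId,Label' with 1-based ImageId (assumed format)."""
    writer = csv.writer(out, lineterminator="\n")
    writer.writerow(["ImageId", "Label"])
    for image_id, label in enumerate(predictions, start=1):
        writer.writerow([image_id, int(label)])


def accuracy(predictions, labels):
    """Fraction of predictions that match the labels (what the Kaggle score measures)."""
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct / len(labels)


buf = io.StringIO()  # in practice this would be an open file, e.g. submission.csv
write_submission([2, 0, 9], buf)
print(buf.getvalue())
```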

TensorBoard

Train

[Screenshot (2020-02-21): TensorBoard training summaries]
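The training summaries logged to TensorBoard typically include the loss. This README does not state which loss the repository uses. Assuming it follows the original paper, the loss is a margin loss on the DigitCaps lengths with m+ = 0.9, m- = 0.1 and a down-weighting factor of 0.5 for absent classes. A sum-of-squares reconstruction loss scaled by 0.0005 is added to it. The pure-Python sketch below uses those paper values; the function and parameter names are hypothetical.

```python
def margin_loss(lengths, target, m_plus=0.9, m_minus=0.1, lam=0.5):
    """Margin loss for one example; lengths[k] is ||v_k||, target is the true class."""
    loss = 0.0
    for k, length in enumerate(lengths):
        if k == target:
            # present class: penalize if its capsule is shorter than m_plus
            loss += max(0.0, m_plus - length) ** 2
        else:
            # absent classes: penalize (down-weighted by lam) if longer than m_minus
            loss += lam * max(0.0, length - m_minus) ** 2
    return loss


def total_loss(lengths, target, image, reconstruction, recon_weight=0.0005):
    """Margin loss plus the down-weighted sum of squared reconstruction errors."""
    recon = sum((x - r) ** 2 for x, r in zip(image, reconstruction))
    return margin_loss(lengths, target) + recon_weight * recon


print(round(margin_loss([0.5, 0.3], target=0), 4))  # prints 0.18
```

The small reconstruction weight keeps the decoder from dominating the margin loss during training.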

Reconstructed Images

During training

[Images: reconstructions train_3, train_9, train_0, train_8]

During validation

[Images: reconstructions valid_1, valid_2, valid_8, valid_9]
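In the original CapsNet design, the decoder reconstructs the input image from the DigitCaps output after masking out every capsule except one. During training the kept capsule is the true class, and otherwise it is the longest capsule, i.e. the predicted class. This README does not say how the training and validation reconstructions are produced. Assuming they use that masking scheme, the sketch below shows the masking step on plain lists; `mask_capsules` is a hypothetical name.

```python
import math


def mask_capsules(capsules, label=None):
    """Keep one DigitCaps vector, zero the rest, and flatten for the decoder.

    If label is given (e.g. the true class during training), that capsule is kept;
    if label is None (e.g. validation), the capsule with the longest vector is kept.
    """
    if label is None:
        lengths = [math.sqrt(sum(x * x for x in v)) for v in capsules]
        label = max(range(len(capsules)), key=lambda k: lengths[k])
    masked = []
    for k, v in enumerate(capsules):
        masked.extend(v if k == label else [0.0] * len(v))
    return masked


caps = [[0.1, 0.0], [0.6, 0.2], [0.0, 0.3]]
print(mask_capsules(caps))  # prints [0.0, 0.0, 0.6, 0.2, 0.0, 0.0]
```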

TF Graph

[Image: TensorFlow graph of CapsNet (capsnet-tf-graph)]