
MNIST benchmarks with Spatial Transformer Networks, Vision Transformers and SpinalNets, with model modifications including CoordConv layers.



This repository implements models described in recent computer vision literature, with a focus on a simple classification task on a classical dataset (MNIST). Three base models are explored: spatial transformer networks, vision transformers, and SpinalNets. We also implement new variations of two of these three models by replacing standard convolutional layers with CoordConv layers.

  • Spatial transformer networks (STN)
  • Spatial transformer networks + CoordConv layers
  • Vision transformers
  • SpinalNet
  • SpinalNet + STN + CoordConv layers
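As a rough illustration of what a CoordConv layer adds, the sketch below appends two coordinate channels (row and column positions scaled to [-1, 1]) to an image before it would be passed to an ordinary convolution. It is deliberately framework-free and is not the code used in this repository; the function names `coord_channels` and `add_coords` are hypothetical.

```python
def coord_channels(height, width):
    """Build row and column coordinate grids scaled to [-1, 1].

    Hypothetical helper for illustration only.
    """
    def scale(i, n):
        # Map index i in [0, n - 1] to [-1, 1]; a single-pixel axis maps to 0.
        return 2.0 * i / (n - 1) - 1.0 if n > 1 else 0.0

    rows = [[scale(r, height) for _ in range(width)] for r in range(height)]
    cols = [[scale(c, width) for c in range(width)] for _ in range(height)]
    return rows, cols


def add_coords(image):
    """Append the two coordinate channels to a channels-first image.

    `image` is a list of channels, each a height x width nested list.
    """
    height, width = len(image[0]), len(image[0][0])
    rows, cols = coord_channels(height, width)
    return image + [rows, cols]


# A blank single-channel 28 x 28 image, the size of an MNIST digit.
mnist_like = [[[0.0] * 28 for _ in range(28)]]
with_coords = add_coords(mnist_like)
print(len(with_coords))  # number of channels after adding coordinates
```

A convolution applied to the result sees three input channels instead of one, which gives its filters access to absolute position in the image.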

A complete run of the experiments, together with results, comments and references, is available in MNIST_benchmarks.ipynb. The experiments can also be reproduced in the following Colab notebook:

Open In Colab

A standalone script is also provided to reproduce the experiments. The required dependencies are listed in requirements.txt.

usage: [-h] [--device {gpu,cpu}] [--workers WORKERS]
                         [--bs BS] [--maxepochs MAX_EPOCHS]
                         [--patience PATIENCE] [--mindelta MIN_DELTA]
                         [--model {stn,stncoordconv,vit,spinal,spinalstn}]
                         [--localization] [--lr LR] [--logs LOGPATH]


optional arguments:
  -h, --help            show this help message and exit
  --device {gpu,cpu}    Device on which to run the experiments. (default: cpu)
  --workers WORKERS     Number of workers for dataloaders. (default: 2)
  --bs BS               Batch size. (default: 64)
  --maxepochs MAX_EPOCHS
                        Maximum number of epochs to run the experiment for.
                        (default: 20)
  --patience PATIENCE   Number of epochs with no improvement before triggering
                        early stopping. (default: 5)
  --mindelta MIN_DELTA  Required improvement in the validation loss for early
                        stopping. (default: 0.005)
  --model {stn,stncoordconv,vit,spinal,spinalstn}
                        Type of model to train. (default: stn)
  --localization        Whether to use CoordConv in the localization network.
                        (default: False)
  --lr LR               Learning rate for SGD. (default: 0.01)
  --logs LOGPATH        Directory to store tensorboard logs. (default: logs/)

TensorBoard is used to record the training and validation logs and metrics. By default, the logs are saved in logs/. To launch TensorBoard, use the following command (see the TensorBoard documentation for more details):

tensorboard --logdir=logs/ --port <port> --host <host>

