
Improving Inference for Neural Image Compression

(Figure: Example SGA optimization landscape)

This repository contains implementations of the various methods considered in Improving Inference for Neural Image Compression, accepted at NeurIPS 2020:

Yibo Yang, Robert Bamler, and Stephan Mandt. Improving Inference for Neural Image Compression. arXiv preprint arXiv:2006.04240.


We propose various methods to improve the compression performance of a popular and competitive neural image compression baseline (the mean-scale hyperprior model proposed by Minnen et al., 2018) at inference/compression time. The methods draw on ideas from iterative variational inference, stochastic discrete optimization, and bits-back coding, and aim to close the approximation gaps between current neural compression methods and rate-distortion optimality.

Two training scripts, a non-bits-back version and a bits-back version, train the baseline models, with encoder networks learned through amortized inference. Given trained models, a separate set of scripts runs the various iterative inference methods considered in the paper and evaluates the resulting compression performance. The non-bits-back methods require a model pre-trained with the non-bits-back training script, and the bits-back methods require a model pre-trained with the bits-back training script.


To install requirements:

pip install -r requirements.txt

The key dependencies are Python 3.6 (tested on 3.6.9), tensorflow-compression 1.3 (which requires TensorFlow 1.15), and tensorflow-probability 0.7.0, which provides the Gumbel-softmax trick used in Stochastic Gumbel Annealing (SGA).
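As a rough illustration of why the Gumbel-softmax trick is needed: SGA relaxes the rounding of each latent into a random, differentiable choice between its floor and its ceiling, and anneals a temperature so the relaxation approaches hard rounding. The NumPy sketch below (not the repository's tensorflow-probability code) shows that idea. The identity distance-to-logit map is an assumption made for illustration; the paper's exact parameterization may differ, and `soft_round_gumbel` is a hypothetical name.

```python
import numpy as np


def soft_round_gumbel(y, tau, rng):
    """Relaxed rounding of y between floor(y) and ceil(y) via Gumbel-softmax.

    Hypothetical simplification: the logit for each candidate integer is
    -distance / tau (identity distance-to-logit map), then the usual
    Gumbel-softmax with temperature tau is applied.
    """
    y = np.asarray(y, dtype=np.float64)
    y_floor = np.floor(y)
    y_ceil = y_floor + 1.0
    # The closer integer gets the larger logit; 1/tau sharpens this as tau -> 0.
    logits = np.stack([-(y - y_floor) / tau, -(y_ceil - y) / tau], axis=-1)
    # Standard Gumbel(0, 1) noise via the inverse CDF; clip u to avoid log(0).
    u = np.clip(rng.uniform(size=logits.shape), 1e-12, 1.0 - 1e-12)
    gumbel = -np.log(-np.log(u))
    # Numerically stable softmax over the two candidates.
    z = (logits + gumbel) / tau
    z = z - z.max(axis=-1, keepdims=True)
    w = np.exp(z)
    w = w / w.sum(axis=-1, keepdims=True)
    # Relaxed sample: a convex combination of floor and ceil.
    return w[..., 0] * y_floor + w[..., 1] * y_ceil
```

At a high temperature the outputs are soft mixtures of floor and ceiling; as `tau` shrinks toward 0 they approach `np.round(y)` (except at exact .5 ties), which is the annealing behavior SGA relies on.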


The following command can be used to train the models in the paper:

python <train_script> --checkpoint_dir <checkpoint_dir> --num_filters <num_filters> train --train_glob=<train_glob> --batchsize 8 --patchsize 256 --save_summary_secs 600 --lambda <lambda> --last_step <last_step>

  • <train_script> is the training script for the Base Hyperprior model in the paper, or the one for the version modified for lossy bits-back coding;
  • <checkpoint_dir> is the parent folder for all model checkpoints (this is /.checkpoints by default);
  • <num_filters> is the number of (de)convolutional filters; following Minnen et al., 2018, we set this to 192 for most of our models, except that we found it necessary to increase it to 256, following Ballé et al., 2018, to match the published performance of the mean-scale model at higher rates (lambda=0.04 and 0.08);
  • <train_glob> is a glob pattern string such as "imgs/*.png" or "imgs/*.npy" (float32 .npy files are supported to reduce CPU load during training); in our experiments we used the CLIC-2018 images, combining all images from professional_train, professional_valid, professional_test, mobile_valid, and mobile_test with no pre-processing;
  • <lambda> is the coefficient on the reconstruction loss (MSE in all our experiments) and controls the rate-distortion tradeoff; see the Pre-trained Models section below;
  • <last_step> is the total number of training steps; we typically used 2 million steps to reproduce the mean-scale (base hyperprior) results of Minnen et al., 2018;
  • batchsize and patchsize are set following Ballé et al., 2018; for miscellaneous other options, see
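To make the role of <lambda> concrete: it weights the distortion term against the rate term, so a larger value yields a model with lower distortion at a higher bit rate. The sketch below assumes the convention used in the tensorflow-compression example scripts (loss = lambda * MSE + bits per pixel, with MSE measured on the 0-255 pixel scale); the exact scaling inside this repository's training scripts is not stated here, and `rate_distortion_loss` is a hypothetical helper.

```python
import numpy as np


def rate_distortion_loss(x, x_hat, likelihoods, lmbda):
    """Hypothetical sketch of the training objective lmbda * MSE + BPP.

    x, x_hat: float arrays of shape (batch, H, W, 3) with values in [0, 1].
    likelihoods: probabilities that the entropy model assigns to the
    quantized latents (any shape).
    Returns (loss, bpp, mse).
    """
    x = np.asarray(x, dtype=np.float64)
    x_hat = np.asarray(x_hat, dtype=np.float64)
    batch, height, width, _ = x.shape
    num_pixels = batch * height * width
    # Rate: total information content in bits, divided by the number of pixels.
    bpp = np.sum(-np.log2(likelihoods)) / num_pixels
    # Distortion: mean squared error on the 0-255 pixel scale.
    mse = np.mean((x - x_hat) ** 2) * 255.0 ** 2
    return lmbda * mse + bpp, bpp, mse
```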


Given a pre-trained model and some input images, the following command runs one of the (improved) inference methods to compress the input and evaluates the reconstruction results (BPP, PSNR, MS-SSIM, etc.):

python <script> --num_filters <num_filters> --verbose --checkpoint_dir <checkpoint_dir> compress <run_name> <eval_imgs> --results_dir <results_dir>

where <script> is one of the compression/evaluation scripts listed below; <num_filters> (e.g., 192) and <run_name> (e.g., mbt2018-num_filters=192-lmbda=0.001) identify the pre-trained model, whose checkpoint folder should be located in <checkpoint_dir>; and <eval_imgs> is either a single input image or a NumPy array holding a batch of images with shape (num_imgs, H, W, 3) and dtype uint8.
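Since a batch passed as <eval_imgs> must be a single uint8 array of shape (num_imgs, H, W, 3), all images in one batch need the same height and width. A minimal sketch for building such an array, assuming it is then handed to the script as a .npy file written with `np.save`; the helper name `make_eval_batch` and the file name `eval_imgs.npy` are hypothetical.

```python
import numpy as np


def make_eval_batch(images):
    """Stack equally sized RGB images into a (num_imgs, H, W, 3) uint8 array."""
    arrays = [np.asarray(img, dtype=np.float64) for img in images]
    first_shape = arrays[0].shape
    if len(first_shape) != 3 or first_shape[-1] != 3:
        raise ValueError(f"expected (H, W, 3) images, got {first_shape}")
    for a in arrays:
        if a.shape != first_shape:
            raise ValueError("all images in a batch must share the same (H, W)")
    # Round to the nearest integer and clip to [0, 255] before casting to uint8.
    return np.stack([np.clip(np.rint(a), 0, 255) for a in arrays]).astype(np.uint8)
```

Usage: `np.save("eval_imgs.npy", make_eval_batch(images))`, then pass `eval_imgs.npy` as <eval_imgs>.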

Below we list the inference methods evaluated in the paper, with their entries in Table 1 of the paper:

Method                                   Entry in Table 1
Base Hyperprior                          M3
SGA                                      M1
SGA + BB                                 M2
MAP                                      A1
STE                                      A2
Uniform Noise                            A3
Deterministic Annealing                  A4
BB without SGA                           A5
BB without any iterative inference       A6

Rate-distortion results on Kodak and Tecnick (averaged over all images for each lambda setting) can be found in the results folder.
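For reference, BPP is the number of compressed bits divided by the number of pixels, and PSNR for 8-bit images is 10 * log10(255^2 / MSE). The sketch below averages per-image BPP and per-image PSNR over a set of images for one lambda setting; whether the released results average PSNR per image or compute it from an averaged MSE is not stated here, so treat this convention as an assumption. All function names are hypothetical.

```python
import numpy as np


def psnr(x, x_hat, max_val=255.0):
    """PSNR in dB between two images; infinite if they are identical."""
    diff = np.asarray(x, dtype=np.float64) - np.asarray(x_hat, dtype=np.float64)
    mse = np.mean(diff ** 2)
    return 10.0 * np.log10(max_val ** 2 / mse)


def bits_per_pixel(num_bits, height, width):
    return num_bits / (height * width)


def average_rd(records):
    """records: iterable of (num_bits, original, reconstruction) for one lambda.

    Images are (H, W, 3) arrays; returns (mean BPP, mean PSNR) over images.
    """
    bpps, psnrs = [], []
    for num_bits, x, x_hat in records:
        height, width = x.shape[:2]
        bpps.append(bits_per_pixel(num_bits, height, width))
        psnrs.append(psnr(x, x_hat))
    return float(np.mean(bpps)), float(np.mean(psnrs))
```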

Pre-trained Models

Our pre-trained models can be found here. Download and untar them into <checkpoint_dir>, so that each sub-folder corresponds to one model <run_name>.
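A minimal sketch of the untar step using Python's standard `tarfile` module, which also lists the resulting sub-folders (the run names to pass as <run_name>). The archive path is whatever file you downloaded; `extract_models` is a hypothetical helper, and it uses the safer "data" extraction filter only where the running Python provides it.

```python
import os
import tarfile


def extract_models(archive_path, checkpoint_dir):
    """Extract a model archive into checkpoint_dir and return the run names found."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    # Mode "r" (the default) transparently handles gzip/bz2/xz compression.
    with tarfile.open(archive_path) as tar:
        if hasattr(tarfile, "data_filter"):
            # Python 3.12+ (and some backports): reject unsafe archive members.
            tar.extractall(checkpoint_dir, filter="data")
        else:
            tar.extractall(checkpoint_dir)
    # Each sub-folder of checkpoint_dir should correspond to one model run.
    return sorted(
        name
        for name in os.listdir(checkpoint_dir)
        if os.path.isdir(os.path.join(checkpoint_dir, name))
    )
```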

The lmbda=0.001 models were trained for 1 million steps, lmbda=0.08 models were trained for 3 million steps, and all the other models were trained for 2 million steps.
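The compression command needs a <num_filters> value that matches the pre-trained model, and the run name already encodes it. Assuming every run name follows the pattern of the example mbt2018-num_filters=192-lmbda=0.001, a small helper can recover num_filters and lmbda, look up the training length stated above, and assemble the argument list for the compression command. The helper names are hypothetical, and the script path is a placeholder for one of the compression/evaluation scripts.

```python
import math
import re

# Assumed run-name pattern, e.g. "mbt2018-num_filters=192-lmbda=0.001".
RUN_NAME_RE = re.compile(r"num_filters=(\d+)-lmbda=([0-9.]+)")


def parse_run_name(run_name):
    """Return (num_filters, lmbda) parsed from a run name."""
    m = RUN_NAME_RE.search(run_name)
    if m is None:
        raise ValueError(f"unrecognized run name: {run_name}")
    return int(m.group(1)), float(m.group(2))


def training_steps(lmbda):
    """Training steps of the released models, as stated in this README."""
    if math.isclose(lmbda, 0.001):
        return 1_000_000
    if math.isclose(lmbda, 0.08):
        return 3_000_000
    return 2_000_000


def compress_argv(script, run_name, eval_imgs, checkpoint_dir, results_dir):
    """Build the argument list for the compression/evaluation command."""
    num_filters, _ = parse_run_name(run_name)
    return [
        "python", script, "--num_filters", str(num_filters), "--verbose",
        "--checkpoint_dir", checkpoint_dir, "compress", run_name, eval_imgs,
        "--results_dir", results_dir,
    ]
```

The returned list can be passed to `subprocess.run` to launch the evaluation.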


Official code repo for NeurIPS 2020 paper "Improving Inference for Neural Image Compression"






