
Official repo for the ICLR22 paper "Towards Empirical Sandwich Bounds on the Rate-Distortion Function"


mandt-lab/RD-sandwich


About

This repo contains the source code and data for the ICLR 2022 paper, Towards Empirical Sandwich Bounds on the Rate-Distortion Function. For an introduction to this topic, see this blog post (part 1, part 2).

In this work, we make a first attempt at a numerical method for estimating sandwich bounds on the rate-distortion (R-D) function of a general data source using i.i.d. samples. Unlike the classic Blahut-Arimoto (BA) algorithm, which only works when the source and reproduction alphabets are finite and the source probabilities known, here we consider general (e.g., continuous) alphabets, without knowledge of the source distribution other than access to its samples, a common setting in applications like image compression.

We estimate an upper bound on the R-D function using an SGD version of the BA algorithm that scales to high-resolution image datasets, based on a close theoretical connection to beta-VAEs and neural compression autoencoders. To verify the tightness of our upper bounds, we also develop a lower bound algorithm based on Csiszár's dual characterization of the R-D function. Due to the associated computational challenges, we restrict ourselves to squared-error distortion and a continuous reproduction alphabet, and obtain non-trivial lower bound estimates on data with relatively low intrinsic dimension, ranging from particle physics measurements to realistic images generated by a BigGAN. Our estimated R-D upper bound on natural images indicates that there is theoretical room for improving state-of-the-art image compression methods by at least 1 dB in PSNR at various bitrates.
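For orientation, here is one standard way to state the dual bound behind the lower bound algorithm, for a distortion measure ρ (squared error in this work) and a fixed multiplier λ ≥ 0. This is our paraphrase of Csiszár's result, not an excerpt from the paper; the symbols u and c(u) are our choice, picked to match the "log u" terminology used for rdlb.py. The dual says that R(D) ≥ −λD + E[log u(X)] for every λ ≥ 0 and every nonnegative u satisfying E[u(X) e^{−λρ(X,y)}] ≤ 1 for all y, and that the bound becomes tight when maximized over λ and u (under mild conditions). Dividing an arbitrary positive u by the worst-case value of that constraint removes the constraint entirely:

```latex
\[
R(D) \;\ge\; -\lambda D \;+\; \mathbb{E}_{X}\!\left[\log u(X)\right] \;-\; \log c(u),
\qquad
c(u) := \sup_{y}\, \mathbb{E}_{X}\!\left[u(X)\, e^{-\lambda \rho(X, y)}\right],
\qquad
\rho(x, y) = \lVert x - y \rVert^{2}.
\]
```

Each pair (λ, u) therefore gives a line of slope −λ lying below the R-D curve, with intercept E[log u(X)] − log c(u). This holds because u/c(u) satisfies the original constraint for every y.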

If you find this work helpful, please consider citing the paper

@inproceedings{yang2022towards,
  title={Towards Empirical Sandwich Bounds on the Rate-Distortion Function},
  author={Yang, Yibo and Mandt, Stephan},
  booktitle={International Conference on Learning Representations},
  year={2022}
}

Overview

To get started, you may want to check out the demo notebook that estimates the proposed R-D sandwich bounds on a 2D Gaussian source, and compares with the BA algorithm and the analytical R-D function.
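The demo's reference curve is the analytical R-D function of a Gaussian source under squared error. As a sketch, that function can be computed by reverse water-filling over the source's variances. This is not the notebook's code, and `gaussian_rd` is a hypothetical name. The sketch assumes independent components (for a correlated Gaussian, pass the covariance eigenvalues), distortion measured as total squared error summed over dimensions, and rate in nats:

```python
import math


def gaussian_rd(variances, D):
    """R(D) in nats of a Gaussian source under squared error (illustrative sketch).

    variances: per-component variances (in general, eigenvalues of the covariance)
    D: total squared-error distortion, summed over components (must be > 0)
    """
    if D <= 0:
        raise ValueError("D must be positive")
    if D >= sum(variances):
        return 0.0  # zero rate suffices: reproduce the source mean
    # Reverse water-filling: find the water level theta with sum_i min(theta, var_i) = D.
    lo, hi = 0.0, max(variances)
    for _ in range(200):
        theta = (lo + hi) / 2
        if sum(min(theta, v) for v in variances) > D:
            hi = theta
        else:
            lo = theta
    theta = (lo + hi) / 2
    # Only components whose variance exceeds the water level receive positive rate.
    return sum(0.5 * math.log(v / theta) for v in variances if v > theta)
```

For example, two unit-variance components at total distortion 0.5 give a water level of 0.25 and a rate of ln 4 nats. Bisection is used because the total distortion is a continuous, non-decreasing function of the water level.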

The main scripts implementing the various algorithms are:

  1. ba.py: this contains a vectorized implementation of the Blahut-Arimoto algorithm. It takes in samples from a continuous source, discretizes/estimates the source probabilities by a histogram, and then runs the BA algorithm on the discretized source.
  2. rdub_mlp.py: this implements the proposed upper bound algorithm, by training a beta-VAE with an MLP encoder and (optionally) decoder. It also implements the non-linear transform coding method from Ballé et al., 2021, by placing shape restrictions on the variational posterior and prior distributions to simulate quantization.
  3. rdlb.py: this implements the proposed lower bound algorithm, by training a neural-network function (referred to as "log u" in the paper) to maximize an unconstrained version of Csiszár's dual characterization of the R-D function.
  4. resnet_vae.py: this implements the proposed upper bound algorithm specialized to images, using a hierarchical convolutional VAE inspired by the ResNet-VAE.
  5. mbt2018.py: this implements the mean-scale hyperprior neural compression method from Minnen et al., 2018, based on the tfc implementation of the scale-only hyperprior model.
  6. msms2020.py: this implements the channelwise autoregressive neural compression method from Minnen et al., 2020, adapted from the tfc implementation.
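To make the BA step concrete, consider ba.py. It is described as a vectorized implementation that runs on a histogram of (possibly multi-dimensional) samples. The sketch below is a deliberately minimal, unvectorized version for a 1-D discrete source. It assumes squared-error distortion and the rate + λ·distortion Lagrangian convention. The function name and signature are ours, not ba.py's:

```python
import math


def blahut_arimoto(px, xs, ys, lamb, steps=2000, tol=1e-5):
    """Minimal Blahut-Arimoto iteration for a discrete 1-D source (illustrative sketch).

    px:   probabilities of the source points xs (e.g. a normalized histogram)
    ys:   reproduction points; distortion is squared error (x - y) ** 2
    lamb: multiplier on distortion, i.e. we minimize rate + lamb * distortion
    Returns (rate in nats, expected distortion).
    """
    dist = [[(x - y) ** 2 for y in ys] for x in xs]
    q = [1.0 / len(ys)] * len(ys)  # output marginal q(y), initialized uniform
    for _ in range(steps):
        # Conditional update: Q(y|x) is proportional to q(y) * exp(-lamb * d(x, y)).
        cond = []
        for row in dist:
            dmin = min(row)  # shift exponents for numerical stability
            w = [qy * math.exp(-lamb * (dxy - dmin)) for qy, dxy in zip(q, row)]
            z = sum(w)
            cond.append([wy / z for wy in w])
        # Marginal update: q(y) = sum_x p(x) Q(y|x).
        new_q = [sum(p * row_q[j] for p, row_q in zip(px, cond)) for j in range(len(ys))]
        delta = max(abs(a - b) for a, b in zip(new_q, q))
        q = new_q
        if delta < tol:
            break
    # q is the marginal of cond, so this is the mutual information I(X; Y) of the final channel.
    rate = sum(p * qyx * math.log(qyx / qy)
               for p, row_q in zip(px, cond) if p > 0
               for qyx, qy in zip(row_q, q) if qyx > 0)
    distortion = sum(p * qyx * dxy
                     for p, row_q, row_d in zip(px, cond, dist)
                     for qyx, dxy in zip(row_q, row_d))
    return rate, distortion
```

With lamb=0 the channel ignores the source (zero rate). As lamb grows, the channel concentrates on the nearest reproduction point, tracing out the R-D curve one lambda at a time.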

During training, all of these scripts write a .jsonl log file to disk containing the train and val/test objectives, evaluated after each training epoch.
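A minimal sketch for reading such a log. It assumes each line is one JSON object and uses the field names mentioned in this README (loss, rate, mse, val_loss, val_rate, val_mse). Other keys may also be present, and both function names are hypothetical:

```python
import json


def read_jsonl_log(lines):
    """Parse a .jsonl training log: one JSON object per non-empty line."""
    return [json.loads(line) for line in lines if line.strip()]


def latest_metrics(records, keys=("loss", "rate", "mse", "val_loss", "val_rate", "val_mse")):
    """Return the requested fields from the last (most recent epoch's) record."""
    last = records[-1]
    return {k: last[k] for k in keys if k in last}


# Hypothetical usage; the actual log file name depends on the script and its flags:
# with open("path/to/log.jsonl") as f:
#     print(latest_metrics(read_jsonl_log(f)))
```

Passing an open file works because iterating over a file yields its lines.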

Software

The code was developed in a Python 3.7 environment on Linux. The main dependencies are TensorFlow 2.5 and tensorflow-compression 2.2, which can be installed with: pip install tensorflow-compression==2.2

I also use the syntax of the GNU parallel tool to concisely capture experiment commands that are repeated with different hyperparameters (most commonly, experiments are rerun with different lambda values to sweep out an R-D bound) and are embarrassingly parallelizable. The parallel tool is not required; I simply use its syntax, e.g., {1} {2} ::: a b ::: c d e, to specify all tuples in the Cartesian product of {a, b} and {c, d, e}, i.e., a c, a d, a e, b c, b d, and b e.

Depending on the compute resources available, it may not make sense to run the parallel commands directly as written; you may instead want to run them sequentially, one job at a time, by adding the -j 1 flag (e.g., parallel does not allocate jobs to different GPUs, so running more than one job on a single GPU may result in out-of-memory (OOM) crashes).
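If GNU parallel is unavailable, the job list can be expanded in Python. The helper below is hypothetical, not part of this repo. It only handles the placeholders used in this README ({1}, {2}, ... and a bare {}) and follows the Cartesian-product semantics described above:

```python
import itertools


def expand_parallel(template, *arg_lists):
    """Expand a GNU-parallel-style template over the Cartesian product of arg_lists."""
    commands = []
    for combo in itertools.product(*arg_lists):
        cmd = template
        # Replace higher indices first so that "{1}" never clobbers "{10}".
        for i in range(len(combo), 0, -1):
            cmd = cmd.replace("{%d}" % i, combo[i - 1])
        # A bare "{}" stands for the job's arguments, space-separated
        # (for a single input source, simply that value).
        cmd = cmd.replace("{}", " ".join(combo))
        commands.append(cmd)
    return commands
```

Running the resulting strings one after another, e.g. with subprocess.run(cmd, shell=True, check=True), has the same effect as parallel -j 1.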

Data preparation

  • Gaussian:

    The parameters of the randomly generated n=1000-dimensional Gaussian source are already saved in ./data for your convenience. They were generated with the command python gen_gaussian_params.py --save_dir ./data --dim 1000. You can change the --dim argument (this is n in the paper) to experiment with other dimensions.

  • Speech:

    1. Clone the FSDD repo, to say /path/to/free-spoken-digit-dataset
    2. cd into data/speech, update the data_dir variable in create_data.py to point to /path/to/free-spoken-digit-dataset, then run create_data.py.
  • High-resolution images for training the hierarchical VAE:

    1. Download the COCO training data:

      wget http://images.cocodataset.org/zips/train2017.zip

    2. Unzip the images to, say /tmp/train2017, then run

      python prepare_imgs.py '/tmp/train2017/*.jpg' data/my_coco_train2017/

      to perform preprocessing. This selects images larger than 512x512 and randomly downsamples them to remove potential compression artifacts.

  • Kodak test set: Download the images from http://r0k.us/graphics/kodak and put them under ./data/kodak.

  • Tecnick test set: Download from https://sourceforge.net/projects/testimages/files/OLD/OLD_SAMPLING/testimages.zip, and unzip to data/Tecnick_TESTIMAGES/.

All other data is self-contained.

Running the experiments

1. Gaussian

UB experiments

  • The upper bound beta-VAE models with varying latent dimensions can be trained by running

    n=1000; parallel python rdub_mlp.py -V --dataset gaussian --gparams_path data/gaussian_params-dim=$n.npz --data_dim $n --latent_dim {1} --checkpoint_dir checkpoints/gaussian --prior_type gmm_1 --posterior_type gaussian --encoder_activation none --decoder_activation leaky_relu --decoder_units $n --lambda {2} --rpd train --epochs 80 --steps_per_epoch 1000 --lr 5e-4 --batchsize 64 --max_validation_steps 1 ::: 400 500 600 800 ::: 0.3 1 3 10 30 100 300

  • The upper bound beta-VAE models without a decoder (latent space equals reproduction space) can be trained by

    n=1000; parallel python rdub_mlp.py -V --dataset gaussian --gparams_path data/gaussian_params-dim=$n.npz --data_dim $n --latent_dim $n --checkpoint_dir checkpoints/gaussian --prior_type gmm_1 --posterior_type gaussian --encoder_activation none --decoder_activation none --decoder_units 0 --lambda {1} --rpd train --epochs 80 --steps_per_epoch 1000 --lr 5e-4 --batchsize 64 --max_validation_steps 1 ::: 0.3 1 3 10 30 100 300

Since we have access to the source and compute the training loss on true i.i.d. samples generated on the fly, there's no need for separate validation/test data (--max_validation_steps is set to a small dummy value of 1 here), and the training objectives (loss/rate/mse in the jsonl log) are unbiased estimates of the true Lagrangian/rate/distortion.
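Each trained model yields an achievable (distortion, rate) point, for example the final mse/rate values in its jsonl log. Because R(D) is convex and non-increasing, the lower convex envelope of these points is itself an upper bound on R(D), so a lambda sweep can be summarized this way. A minimal sketch with a hypothetical function name:

```python
def lower_convex_envelope(points):
    """Lower convex envelope of achievable (distortion, rate) points.

    Since R(D) is convex and non-increasing, every point on the returned
    piecewise-linear curve is still an upper bound on R(D).
    """
    pts = sorted(set(points))  # sort by distortion, then rate
    hull = []
    for p in pts:
        # Drop the last hull point while it lies on or above the segment hull[-2] -> p.
        while len(hull) >= 2:
            (x1, y1), (x2, y2) = hull[-2], hull[-1]
            cross = (x2 - x1) * (p[1] - y1) - (y2 - y1) * (p[0] - x1)
            if cross <= 0:
                hull.pop()
            else:
                break
        hull.append(p)
    if not hull:
        return []
    # R(D) is non-increasing, so discard the rising tail after the lowest-rate point.
    k = min(range(len(hull)), key=lambda i: hull[i][1])
    return hull[: k + 1]
```

This is Andrew's monotone-chain construction restricted to the lower hull. A point dropped from the envelope was dominated by a mixture of its neighbors.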

LB experiments

The main training command is

nn_size_to_n_ratio=100;
nn_size=$(( n * nn_size_to_n_ratio ));
parallel python rdlb.py --dataset gaussian --data_dim $n --checkpoint_dir checkpoints/gaussian --model mlp --units $nn_size,$nn_size,$nn_size --lamb {1} train --num_Ck_samples 2 --batchsize 1024 --last_step 3000 --checkpoint_interval 3000 --y_init quick --y_quick_topn 10 --lr 5e-4 --lr_schedule ::: $lambs_to_use

which trains an MLP log u model with 3 hidden layers, where the number of units per layer is set based on the source dimension n (nn_size = n * nn_size_to_n_ratio).

The bash variables $n and $lambs_to_use for each run are set as follows:

######################
n=2
lambs_to_use="1.0 3.0 10.0 30.0 100.0"

######################
n=4
lambs_to_use="2.0 6.0 20.0 60.0 200.0"

######################
n=8
lambs_to_use="4.0 12.0 40.0 120.0 400.0"

######################
n=16
lambs_to_use="8.0 24.0 80.0 240.0 800.0"
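The per-n settings follow a simple pattern: nn_size = 100·n, and the lambda sweep for dimension n is n/2 times the n=2 sweep. The sketch below reproduces that pattern. This is our reading of the table, not necessarily how the values were originally chosen, and the function and constant names are hypothetical:

```python
NN_SIZE_TO_N_RATIO = 100                    # nn_size_to_n_ratio in the training command
BASE_LAMBS = [1.0, 3.0, 10.0, 30.0, 100.0]  # the n=2 sweep


def gaussian_lb_settings(n):
    """Return (nn_size, lambs_to_use) following the pattern in the table above."""
    nn_size = n * NN_SIZE_TO_N_RATIO
    lambs = [b * n / 2 for b in BASE_LAMBS]
    return nn_size, " ".join(str(lamb) for lamb in lambs)
```

For instance, gaussian_lb_settings(16) yields the nn_size and lambs_to_use string used for n=16, ready to paste into the shell variables.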

The evaluation command is

parallel python rdlb.py --dataset gaussian --data_dim $n --checkpoint_dir checkpoints/gaussian --model mlp --units $nn_size,$nn_size,$nn_size --lamb {1} eval --batchsize 1024 --y_init exhaustive --num_Ck_samples 30 ::: $lambs_to_use

where the bash variables are set the same way as before. The evaluation code loads the trained log u model and draws Ck samples to form the LB estimator xi defined in Appendix A.6, saving the resulting Ck and xi samples in a .npz file. The sample mean of the xi samples then gives a natural estimator of the LB intercept (denoted R_ in code), and the xi samples can also be used to estimate confidence intervals.
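Given the xi samples (e.g., loaded from the saved .npz with numpy.load; the array names inside that file are not documented here), a point estimate and a rough interval can be computed as follows. The normal (CLT) approximation and z=1.96 are our assumptions, and the paper's confidence intervals may be constructed differently. The function name is hypothetical:

```python
import math
import statistics


def lb_intercept_estimate(xi_samples, z=1.96):
    """Sample-mean estimate of the LB intercept with a normal-approximation interval.

    z=1.96 gives a nominal 95% interval under a CLT approximation (an assumption).
    """
    k = len(xi_samples)
    if k < 2:
        raise ValueError("need at least two xi samples for an interval")
    mean = statistics.mean(xi_samples)
    sem = statistics.stdev(xi_samples) / math.sqrt(k)  # standard error of the mean
    return mean, (mean - z * sem, mean + z * sem)
```

With only a few dozen xi samples the interval is approximate; drawing more Ck samples (e.g., over more seeds) tightens it.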

2. Particle physics

UB experiments

  • Upper bound beta-VAE:

    parallel -j1 python rdub_mlp.py -V --checkpoint_dir checkpoints/physics --dataset data/physics/ppzee-split=train.npy --data_dim 16 --latent_dim 16 --prior_type maf --maf_stacks 3 --posterior_type gaussian --encoder_units 500,500 --decoder_units 500,500 --encoder_activation softplus --decoder_activation softplus --ar_hidden_units 50,50 --ar_activation softplus --nats --lambda {} train --epochs 100 --steps_per_epoch 1000 --lr 5e-4 --batchsize 512 --max_validation_steps 20 ::: 100 300 1000 3000 10000 30000

  • Nonlinear transform coding (Ballé et al., 2021):

    parallel -j1 python rdub_mlp.py -V --checkpoint_dir checkpoints/physics --dataset data/physics/ppzee-split=train.npy --data_dim 16 --latent_dim 16 --prior_type deep --posterior_type uniform --encoder_units 500,500 --decoder_units 500,500 --encoder_activation softplus --decoder_activation softplus --nats --lambda {} train --epochs 100 --steps_per_epoch 1000 --lr 5e-4 --batchsize 512 --max_validation_steps 20 ::: 100 300 1000 3000 10000 30000

  • On the 2D marginal, the experiment commands are identical to the above, except

    --dataset data/physics/ppzee-split=train.npy --data_dim 16 --latent_dim 16

    is replaced with

    --dataset data/physics/n=2-ppzee-split=train.npy --data_dim 2 --latent_dim 2

  • The code will automatically look for a test dataset in the same directory as the training .npy file, so the test performance is also continually evaluated (on a subset) during training (in the val_loss, val_rate, and val_mse fields of the jsonl log).

  • The BA algorithm:

    parallel python ba.py -V --samples data/physics/n=2-ppzee-split=test.npy --save_dir checkpoints/physics --bins 60 --steps 2000 --tol 1e-5 --lamb {} ::: 100 300 1000 3000 10000 30000

LB experiments

  • Train:

    parallel python rdlb.py -V --dataset data/physics/ppzee-split=train.npy --data_dim 16 --checkpoint_dir checkpoints/physics --model mlp --units 200,200,200 --lamb {} train --batchsize 2048 --num_Ck_samples 2 --last_step 4000 --checkpoint_interval 2000 --y_init quick --y_quick_topn 3 --lr 5e-4 --lr_schedule ::: 100 300 1000 3000 10000 30000

  • Eval (here 4*5=20 samples of Ck are drawn in total, parallelized over four different processes with different seeds for efficiency):

    parallel python rdlb.py -V --seed {1} --dataset data/physics/ppzee-split=test.npy --data_dim 16 --model mlp --units 200,200,200 eval --batchsize 2048 --num_Ck_samples 5 --y_init exhaustive --ckpt {2} ::: 0 1 2 3 ::: checkpoints/physics/rdlb-data_dim=16-model=mlp-units=200_200_200-lamb=*-batchsize=2048

    Here the --lamb arg is not set; instead the code will use the lambda value from the name of the checkpoint that is being loaded. The results from different lambdas and different seeds can be combined to produce the final xi samples, using the convenience method utils.aggregate_lb_results; e.g.,

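    import glob
    from utils import aggregate_lb_results  # convenience function in this repo's utils module; run from the repo root
    res_files = glob.glob('checkpoints/physics/rdlb-data_dim=16-model=mlp-units=200_200_200-lamb=*/rd-*.npz')
    results = aggregate_lb_results(res_files)  # dict mapping each lamb to its xi samples
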
  • The commands on the 2D marginal source are identical to the above, except ppzee-split={train/test}.npy needs to be replaced with n=2-ppzee-split={train/test}.npy, and --data_dim 16 with --data_dim 2.
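Since the lambda of each eval run is taken from the checkpoint name, result files can also be grouped by lambda directly from their paths. The helpers below are a hypothetical illustration, not the repo's own parsing logic. They assume the path contains a field like lamb=1000.0, as in the checkpoint names used above:

```python
import re
from collections import defaultdict

# Matches e.g. "lamb=1000.0" or "lamb=1e-05" inside a checkpoint path.
_LAMB_RE = re.compile(r"lamb=([0-9]+(?:\.[0-9]*)?(?:[eE][-+]?[0-9]+)?)")


def lamb_from_ckpt_name(path):
    """Parse the lambda value embedded in a checkpoint directory name."""
    m = _LAMB_RE.search(path)
    if m is None:
        raise ValueError("no 'lamb=' field in %r" % path)
    return float(m.group(1))


def group_by_lamb(paths):
    """Group result files (e.g. from different seeds) by the lambda in their path."""
    groups = defaultdict(list)
    for path in paths:
        groups[lamb_from_ckpt_name(path)].append(path)
    return dict(groups)
```

For example, applying group_by_lamb to the output of the glob.glob call above collects the per-seed result files of each lambda.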

3. Speech

The commands are almost identical to the previous ones on physics data.

UB experiments

  • UB beta-VAE:

    parallel -j1 python rdub_mlp.py -V --checkpoint_dir checkpoints/speech --dataset data/speech/stft-split=train.npy --data_dim 33 --latent_dim 33 --prior_type maf --maf_stacks 3 --posterior_type gaussian --encoder_units 500,500 --decoder_units 500,500 --encoder_activation softplus --decoder_activation softplus --ar_hidden_units 50,50 --ar_activation softplus --nats --lambda {} train --epochs 100 --steps_per_epoch 1000 --lr 5e-4 --batchsize 512 ::: 0.3 1 3 10 30 50 100

  • NTC:

    parallel -j1 python rdub_mlp.py -V --checkpoint_dir checkpoints/speech --dataset data/speech/stft-split=train.npy --data_dim 33 --latent_dim 33 --prior_type deep --posterior_type uniform --encoder_units 500,500 --decoder_units 500,500 --encoder_activation softplus --decoder_activation softplus --nats --lambda {} train --epochs 100 --steps_per_epoch 1000 --lr 5e-4 --batchsize 512 ::: 0.3 1 3 10 30 50 100

  • BA:

    parallel -j1 python ba.py -V --samples data/speech/frame_length=63/n=2-dig=all-speaker=theo-split=test-stft.npy --lamb {} --save_dir checkpoints/speech/BA --bins 50 --steps 2000 --tol 1e-5 ::: 0.1 0.3 1 3 10 30 100

LB experiments

  • Train:

    parallel python rdlb.py -V --dataset data/speech/stft-split=train.npy --data_dim 33 --checkpoint_dir checkpoints/speech --model mlp --units 200,200,200 --lamb {} train --batchsize 2048 --num_Ck_samples 2 --last_step 4000 --checkpoint_interval 2000 --y_init quick --y_quick_topn 3 --lr 5e-4 --lr_schedule ::: 0.3 1 3 10 30 50

  • Eval:

    parallel python rdlb.py -V --seed {1} --dataset data/speech/stft-split=test.npy --data_dim 33 --model mlp --units 200,200,200 eval --batchsize 2048 --num_Ck_samples 5 --y_init exhaustive --ckpt {2} ::: 0 1 2 3 ::: checkpoints/speech/rdlb-data_dim=33-model=mlp-units=200_200_200-lamb=*-batchsize=2048

  • The commands on the 2D marginal source are identical to the above, except stft-split={train/test}.npy needs to be replaced with n=2-stft-split={train/test}.npy, and --data_dim 33 with --data_dim 2.

4. Banana-shaped source

UB experiments

  • UB beta-VAE:

    parallel python rdub_mlp.py -V --checkpoint_dir checkpoints/banana --dataset banana --data_dim 2 --latent_dim 2 --prior_type maf --maf_stacks 3 --posterior_type gaussian --encoder_units 100,100 --decoder_units 100,100 --encoder_activation softplus --decoder_activation softplus --ar_hidden_units 10,10 --ar_activation softplus --nats --lambda {} train --epochs 100 --steps_per_epoch 1000 --lr 5e-4 --batchsize 1024 ::: 0.1 0.3 1 3 10 30

  • NTC (this reproduces Ballé et al.'s experiment):

    parallel python rdub_mlp.py -V --checkpoint_dir checkpoints/banana --dataset banana --data_dim 2 --latent_dim 2 --prior_type deep --posterior_type uniform --encoder_units 100,100 --decoder_units 100,100 --encoder_activation softplus --decoder_activation softplus --nats --lambda {} train --epochs 100 --steps_per_epoch 1000 --lr 5e-4 --batchsize 1024 ::: 0.1 0.3 1 3 10 30

  • BA:

    parallel -j1 python ba.py -V --samples data/banana-dim=2-samples=100000.npy --save_dir checkpoints/banana/BA --bins 80 --steps 2000 --tol 1e-5 --lamb {} ::: 0.1 0.3 1.0 3.0 10.0 30.0 100.0

  • UB beta-VAE on higher-dimension embeddings of the source:

    parallel python rdub_mlp.py -V --checkpoint_dir checkpoints/banana --dataset banana --data_dim $n --latent_dim $n --prior_type maf --maf_stacks 3 --posterior_type gaussian --encoder_units $nn_size,$nn_size --decoder_units $nn_size,$nn_size --encoder_activation softplus --decoder_activation softplus --ar_hidden_units 20,20 --ar_activation softplus --nats --lambda {} train --epochs 100 --steps_per_epoch 1000 --lr 5e-4 --batchsize 1024 ::: 0.1 0.3 1 3 10 30 100 300 1000

The bash variables $n and $nn_size are set as follows:

######################
n=4
nn_size=400

######################
n=16
nn_size=1600

######################
n=100
nn_size=2000

######################
n=500
nn_size=2000

LB experiments

  • Train:

    parallel python rdlb.py --dataset banana --data_dim $n --checkpoint_dir checkpoints/banana --model mlp --units $nn_size,$nn_size,$nn_size --lamb {1} train --batchsize 1024 --num_Ck_samples 2 --last_step 4000 --checkpoint_interval 4000 --y_init quick --y_quick_topn 20 --lr 1e-3 --lr_schedule ::: 0.1 0.3 1 3 10 30 100 300

  • Eval:

    parallel python rdlb.py --seed {1} --dataset banana --data_dim $n --checkpoint_dir checkpoints/banana --model mlp --units $nn_size,$nn_size,$nn_size --lamb {2} eval --batchsize 1024 --y_init exhaustive --num_Ck_samples 5 ::: 0 1 2 3 ::: 0.1 0.3 1 3 10 30 100 300

The bash variables $n and $nn_size are set as follows (somewhat arbitrarily, as I didn't do much architecture tuning):

######################
n=2
nn_size=200

######################
n=4
nn_size=400

######################
n=16
nn_size=1000

######################
n=100
nn_size=1000

######################
n=500
nn_size=1000

5. GAN-generated images

UB experiments

In the following commands, $d stands for the intrinsic dimensionality of the images and is set to 2 or 4; $lamb is the coefficient in front of the distortion term in the Lagrangian, and takes values in 5e-4, 3e-4, 1e-4, 3e-5, 1e-5, 3e-6, 1e-6.

In the training commands, we set the amount of training based on how long it takes the model with the highest lambda to converge. The models with lower lambda values typically require less training.

In the evaluation commands, --batchsize sets the number of GAN image samples to evaluate on. We use --no_cast_xhat so that the reconstruction is treated as floating-point valued (rather than uint8) when computing the MSE, since we take the alphabets of the GAN source to be continuous (more precisely, [0, 1]^n). Because the neural compression autoencoders assume inputs and outputs in [0, 255] (as do resnet_vae.py and the evaluation code, for consistency), the resulting MSE must be divided by 255^2 to match the results in the paper.
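The rescaling by 255^2 described above, together with the standard PSNR formula 10·log10(MAX²/MSE), can be written as below (function names are hypothetical). PSNR comes out the same on either scale, as long as MAX is chosen consistently (255 for [0, 255] values, 1 for [0, 1] values):

```python
import math


def rescale_mse(mse_255):
    """Convert an MSE computed on [0, 255] pixel values to the [0, 1] scale."""
    return mse_255 / 255.0 ** 2


def psnr_from_mse(mse, max_val=255.0):
    """Standard PSNR in dB: 10 * log10(max_val**2 / mse)."""
    return 10.0 * math.log10(max_val ** 2 / mse)
```

For example, an MSE of 65.025 on the [0, 255] scale becomes 0.001 on the [0, 1] scale, and both correspond to 30 dB PSNR.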

  • Upper bound ResNet-VAE:

    Train:

    python resnet_vae.py -V --latent_channels 4,8,16,32,64,128 --num_filters 256 --lambda $lamb --flat_z0 --img_dim 128 --maf_units 32,16 --maf_stacks 3 --checkpoint_dir checkpoints/gan/dataset=basenji-data_dim=$d train --dataset basenji --data_dim $d --batchsize 8 --epochs 600 --steps_per_epoch 1000 --warmup 400 --patience 10

    Evaluation (for measurements of variability):

    python resnet_vae.py -V --latent_channels 4,8,16,32,64,128 --num_filters 256 --lambda $lamb --flat_z0 --img_dim 128 --maf_units 32,16 --maf_stacks 3 --checkpoint_dir checkpoints/gan/dataset=basenji-data_dim=$d eval --dataset basenji --data_dim $d --no_cast_xhat --batchsize 100 --results_dir results/gan/

  • Mean-scale hyperprior model (Minnen et al., 2018):

    Train (using uniform noise approximation to quantization):

    python mbt2018.py -V --checkpoint_dir checkpoints/gan/dataset=basenji-data_dim=$d --lambda $lamb train --dataset basenji --data_dim $d --batchsize 8 --epochs 600 --steps_per_epoch 1000 --warmup 400 --patience 10

    Evaluation (using actual quantization):

    python mbt2018.py -V --checkpoint_dir checkpoints/gan/dataset=basenji-data_dim=$d --lambda $lamb eval --dataset basenji --data_dim $d --no_cast_xhat --batchsize 100 --results_dir results/gan/

  • Channel-wise autoregressive model (Minnen et al., 2020):

    Identical to the above, except the script name mbt2018.py is replaced with ms2020.py.

LB experiments

  • Train:

    python rdlb.py -V --dataset basenji --data_dim $d --model cnn --units 4,8,16,1024 --kernel_dims 9,5,3 --lamb 4000 --checkpoint_dir checkpoints/gan/dataset=basenji-data_dim=$d --chunksize 16 train --batchsize 2048 --num_Ck_samples 2 --last_step 4000 --checkpoint_interval 1000 --y_init quick --y_quick_topn 3 --lr 1e-4 --anneal_lamb --target_lambs 50,100,200,300,500,1000,2000,3000,4000

Note that rdlb.py here trains the lower bound model with a gradually increasing lambda, and saves model checkpoints along the way when lambda reaches a value in --target_lambs. Thus, one training run suffices to sweep out an entire R-D lower bound curve, although the performance can likely still be improved by further training/fine-tuning the intermediate checkpoints saved along the way.

  • Eval:

    parallel -j1 python rdlb.py -V --seed {1} --dataset basenji --data_dim $d --model cnn --units 4,8,16,1024 --kernel_dims 9,5,3 --chunksize 16 eval --batchsize 2048 --num_Ck_samples 5 --y_init exhaustive --ckpt {2} ::: 0 1 2 3 4 5 6 7 8 9 ::: checkpoints/gan/basenji/rdlb-data_dim=$d-model=cnn-units=4_8_16_1024-lamb=4000.0-k=2048/step=*-kerasckpt.index

Again, the results can be collected using the convenience method utils.aggregate_lb_results.
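The --anneal_lamb training run saves a checkpoint each time lambda reaches a value in --target_lambs. The bookkeeping for this can be sketched as follows. The schedule here is purely an assumption for illustration: a geometric ramp starting at the first target. rdlb.py's actual annealing schedule may differ (e.g., linear or step-wise), and both function names are hypothetical:

```python
def annealed_lambs(lamb_start, lamb_end, num_steps):
    """Hypothetical geometric lambda schedule from lamb_start to lamb_end over num_steps steps."""
    if num_steps < 2:
        return [float(lamb_end)]
    return [lamb_start * (lamb_end / lamb_start) ** (t / (num_steps - 1))
            for t in range(num_steps)]


def checkpoint_steps(schedule, target_lambs, rtol=1e-9):
    """Map each target lambda to the first step at which the schedule reaches it."""
    steps = {}
    targets = sorted(target_lambs)
    i = 0
    for step, lamb in enumerate(schedule):
        # A single step may pass several targets if the schedule jumps quickly.
        while i < len(targets) and lamb >= targets[i] * (1 - rtol):
            steps[targets[i]] = step
            i += 1
    return steps
```

The small relative tolerance guards against the final schedule value landing a rounding error below the last target.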

6. Natural images

You can train the proposed hierarchical ResNet VAE with:

python resnet_vae.py -V --latent_channels 4,8,16,32,64,128 --ar_prior_levels 4 --ar_slices 8 --num_filters 256 --lambda $lamb --checkpoint_dir checkpoints/img_compression train --dataset cocotrain --patchsize 256 --batchsize 8 --epochs 600 --steps_per_epoch 10000 --warmup 400 --patience 20

and evaluate on Kodak and Tecnick with

parallel -j1 python resnet_vae.py -V --latent_channels 4,8,16,32,64,128 --ar_prior_levels 4 --ar_slices 8 --num_filters 256 --lambda $lamb --checkpoint_dir checkpoints/img_compression eval --dataset {1} --results_dir results/img_compression --no_sub_dir ::: kodak tecnick

Above, $lamb takes values in 0.08, 0.04, 0.02, 0.01, 0.005, 0.0025, 0.001.
