Practicing JAX by implementing Iterative Amortized Inference (Marino et al., ICML 2018).
Feeding the input image into the refinement network seems to help enormously. Without conditioning on the input image, we had difficulty getting the model to work.
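The sketch below shows what "conditioning the refinement network on the input image" means in an iterative amortized inference loop. The approximate posterior parameters (mu, logvar) start at zero and are repeatedly updated by a refinement network. That network sees the ELBO gradients and the current parameters, and optionally the image x. It is a minimal illustration, not the code in main.py. Everything here is a hypothetical assumption: the sizes (784-pixel binarized images, 8-d latent), the linear Bernoulli decoder, the one-hidden-layer refinement MLP, the plain additive update, and all function names. Parameters are random and untrained, so the ELBO values it produces carry no meaning.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes, not taken from this repo.
X_DIM, Z_DIM, HID = 784, 8, 128


def init_params(key, condition_on_x=True):
    k_dec, k1, k2 = jax.random.split(key, 3)
    # Refinement input: ELBO grads + current (mu, logvar) = 4 * Z_DIM, optionally plus the image.
    in_dim = 4 * Z_DIM + (X_DIM if condition_on_x else 0)
    return {
        "dec_w": 0.01 * jax.random.normal(k_dec, (Z_DIM, X_DIM)),
        "dec_b": jnp.zeros(X_DIM),
        "ref_w1": 0.01 * jax.random.normal(k1, (in_dim, HID)),
        "ref_b1": jnp.zeros(HID),
        "ref_w2": 0.01 * jax.random.normal(k2, (HID, 2 * Z_DIM)),
        "ref_b2": jnp.zeros(2 * Z_DIM),
    }


def elbo(mu, logvar, x, params, key):
    # Single-sample reparameterized ELBO: Bernoulli decoder, N(0, I) prior, diagonal Gaussian q.
    eps = jax.random.normal(key, mu.shape)
    z = mu + jnp.exp(0.5 * logvar) * eps
    logits = z @ params["dec_w"] + params["dec_b"]
    log_lik = jnp.sum(x * jax.nn.log_sigmoid(logits) + (1 - x) * jax.nn.log_sigmoid(-logits))
    kl = 0.5 * jnp.sum(jnp.exp(logvar) + mu**2 - 1.0 - logvar)
    return log_lik - kl


def refine_step(mu, logvar, x, params, key, condition_on_x=True):
    # Gradient of the ELBO w.r.t. the approximate-posterior parameters only.
    g_mu, g_logvar = jax.grad(elbo, argnums=(0, 1))(mu, logvar, x, params, key)
    feats = [g_mu, g_logvar, mu, logvar]
    if condition_on_x:
        feats.append(x)  # the input-image conditioning discussed above
    h = jnp.tanh(jnp.concatenate(feats) @ params["ref_w1"] + params["ref_b1"])
    delta = h @ params["ref_w2"] + params["ref_b2"]
    # Plain additive update, chosen for brevity; other update forms (e.g. gated) are possible.
    return mu + delta[:Z_DIM], logvar + delta[Z_DIM:]


def iterative_inference(x, params, key, n_steps=5, condition_on_x=True):
    # Start from a fixed initial posterior and refine it n_steps times,
    # recording the ELBO before each refinement.
    mu, logvar = jnp.zeros(Z_DIM), jnp.zeros(Z_DIM)
    elbos = []
    for k in jax.random.split(key, n_steps):
        elbos.append(elbo(mu, logvar, x, params, k))
        mu, logvar = refine_step(mu, logvar, x, params, k, condition_on_x)
    return mu, logvar, jnp.stack(elbos)
```

Running `iterative_inference` with `condition_on_x=True` and then with `condition_on_x=False` mirrors the comparison described above. To reproduce that comparison for real, the refinement weights would first have to be trained, for example by taking `jax.grad` of the summed ELBOs with respect to the `ref_*` entries of `params`.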
To run: python main.py
To visualize: python plot_gif.py