Tutorial: Mixture Density Networks with JAX
Tutorial Notebook: mixture_density_networks_jax.ipynb
Reference paper: Mixture Density Networks (Bishop, 1994)
This tutorial is based on kylemcdonald's recent PyTorch notebook, with many improvements added.
Note: This notebook uses a slightly different loss formulation than the previous tutorials. It is much more numerically stable, and I use it in most of my other recent projects that need MDNs.
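The notebook's exact loss is not shown here, but the kind of numerically stable formulation meant is computing the mixture negative log-likelihood entirely in log space with `logsumexp`, rather than summing raw component probabilities (which can underflow). A minimal sketch, assuming a 1-D Gaussian mixture; the function name `mdn_nll` and its parameterization are illustrative, not the notebook's actual API:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def mdn_nll(log_pi, mu, log_sigma, y):
    """Negative log-likelihood of y under a 1-D Gaussian mixture (a sketch).

    log_pi, mu, log_sigma: (batch, k) mixture parameters, with log_pi
    already normalized (e.g. via jax.nn.log_softmax).
    y: (batch,) regression targets.
    """
    y = y[:, None]                      # broadcast against the k components
    sigma = jnp.exp(log_sigma)
    # log N(y | mu, sigma) for each component, computed directly in log space
    log_prob = (-0.5 * ((y - mu) / sigma) ** 2
                - log_sigma
                - 0.5 * jnp.log(2.0 * jnp.pi))
    # log sum_k pi_k * N_k(y), reduced stably with logsumexp
    return -jnp.mean(logsumexp(log_pi + log_prob, axis=1))
```

Producing `log_pi` with `jax.nn.log_softmax` (rather than exponentiating and re-taking logs) keeps the whole pipeline in log space.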
JAX is a minimal framework that automatically computes gradients of native Python and NumPy / SciPy functions. It is a nice tool to have in the machine learning research toolbox.
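To illustrate, here is a small sketch of JAX's core transformation: `jax.grad` takes an ordinary numerical Python function and returns a new function that computes its gradient (the function `f` below is an arbitrary example, not from the tutorial):

```python
import jax
import jax.numpy as jnp

# An ordinary NumPy-style function written with jax.numpy.
def f(x):
    return jnp.sum(jnp.tanh(x) ** 2)

# jax.grad transforms f into a function that returns df/dx,
# evaluated at the same input shape as x.
df = jax.grad(f)

x = jnp.array([0.0, 1.0])
print(df(x))  # elementwise: 2 * tanh(x) * (1 - tanh(x)**2)
```

Because `df` is itself a plain function, it composes with the rest of JAX: it can be `jit`-compiled, vectorized with `vmap`, or differentiated again.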