hardmaru/mdn_jax_tutorial

Tutorial: Mixture Density Networks with JAX

April 2020

Tutorial Notebook: mixture_density_networks_jax.ipynb

Reference paper: Mixture Density Networks (Bishop, 1994)
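For reference, this is the model from Bishop (1994) in its standard form. A neural network maps each input x to the parameters of a K-component Gaussian mixture over the target y. The network is trained by minimizing the average negative log-likelihood over a dataset of M pairs. The notation here (K, π_k, μ_k, σ_k, M) was chosen for illustration and may not match the notebook's.

$$
p(y \mid x) = \sum_{k=1}^{K} \pi_k(x)\, \mathcal{N}\big(y \mid \mu_k(x), \sigma_k^2(x)\big),
\qquad \pi_k(x) \ge 0, \quad \sum_{k=1}^{K} \pi_k(x) = 1
$$

$$
\mathcal{L} = -\frac{1}{M} \sum_{m=1}^{M} \log p(y_m \mid x_m)
$$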

Related posts:

This tutorial is based on the recent PyTorch version of the notebook, which includes many improvements contributed by kylemcdonald.

Note: This notebook uses a slightly different loss formulation from the previous tutorials. This formulation is much more numerically stable, and it is the one used in most of my other recent projects that needed MDNs.
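The README does not spell out the new loss. One common way to make an MDN loss numerically stable, and a plausible reading of the note above, is to stay in log space throughout. The log mixture weights come from a log-softmax, the Gaussian log-density is written out directly, and the components are combined with logsumexp. The sketch below assumes 1-D targets and a network that outputs per-component `logits`, `mu`, and `log_sigma` arrays. These names and the function `mdn_nll` are hypothetical, and this is not necessarily the exact formulation used in the notebook.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def mdn_nll(logits, mu, log_sigma, y):
    """Mean negative log-likelihood of 1-D targets under a Gaussian mixture.

    Hypothetical interface (an assumption, not the notebook's API):
      logits, mu, log_sigma: arrays of shape (batch, K) from the network
      y: array of shape (batch,) or (batch, 1)
    """
    y = jnp.reshape(y, (-1, 1))  # broadcast each target against the K components
    # log(pi_k) via log-softmax instead of log(softmax(...)),
    # so a vanishing mixture weight never becomes log(0)
    log_pi = jax.nn.log_softmax(logits, axis=-1)
    # Gaussian log-density computed directly in log space
    z = (y - mu) * jnp.exp(-log_sigma)
    log_normal = -0.5 * z**2 - log_sigma - 0.5 * jnp.log(2.0 * jnp.pi)
    # log sum_k pi_k N(y | mu_k, sigma_k^2), using the max-subtraction trick
    log_prob = logsumexp(log_pi + log_normal, axis=-1)
    return -jnp.mean(log_prob)


# Usage: a target far from both components still gives a finite loss.
loss = mdn_nll(jnp.zeros((1, 2)), jnp.zeros((1, 2)), jnp.zeros((1, 2)), jnp.array([1000.0]))
```

The naive form, `-log(sum(pi * normal_pdf))`, underflows to `log(0) = -inf` in float32 when a target lies far from every component. `logsumexp` subtracts the largest term before exponentiating, so both the loss and its gradients stay finite.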

JAX is a minimal framework for automatically computing the gradients of native Python functions written with its NumPy / SciPy-compatible API (jax.numpy and jax.scipy). It is a useful tool in the machine learning research toolbox.
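As a minimal illustration of the idea above, `jax.grad` takes an ordinary Python function built from `jax.numpy` operations and returns a new function that computes its gradient. The function `f` below is an arbitrary example chosen for illustration.

```python
import jax
import jax.numpy as jnp


def f(x):
    # an ordinary NumPy-style function, written with jax.numpy
    return jnp.sum(jnp.tanh(x) ** 2)


df = jax.grad(f)  # a new function that computes df/dx

x = jnp.array([0.0, 0.5, 1.0])
print(df(x))  # elementwise 2 * tanh(x) * (1 - tanh(x)**2)
```

The same call applied to an MDN loss function gives the parameter gradients needed for training, with no hand-derived backpropagation.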

Recommended JAX tutorials: "Getting started with JAX" and "You don't know JAX".

License

MIT
