| title | JAX Implementation of Hindsight Experience Replay (HER) |
|---|---|
| categories | |
| tags | |
| github | jeertmans/HER-with-JAX |
| website | https://colab.research.google.com/github/jeertmans/HER-with-JAX/blob/main/bit_flipping.ipynb |
| image | |
| permalink | /posts/her-with-jax/ |
| description | Implementation of the Hindsight Experience Replay (HER) method in JAX. |

I recently discovered the Hindsight Experience Replay (HER) paper and noticed that the official implementation is based on PyTorch and is not very well-structured. I also couldn't find a non-PyTorch implementation. Since I primarily work with JAX, I decided to reimplement the classic bit-flipping experiment to better understand HER.
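
For readers new to HER, here is a minimal sketch of the bit-flipping environment and of goal relabeling, written in plain JAX. It assumes the setup described in the HER paper (a state of `n` bits, an action flips one bit, the reward is 0 when the state matches the goal and -1 otherwise) and uses the simplest "final" relabeling strategy; the paper also describes "future", "episode" and "random" strategies. The names `step`, `reward`, `random_episode`, `relabel_final` and `N_BITS` are hypothetical and are not taken from the repository.

```python
import jax
import jax.numpy as jnp

N_BITS = 8  # hypothetical problem size, chosen for illustration only


def step(state, action):
    """Flip bit `action` of a binary state vector."""
    return state.at[action].set(1 - state[action])


def reward(next_state, goal):
    """Sparse reward: 0 when every bit matches the goal, -1 otherwise."""
    return jnp.where(jnp.all(next_state == goal), 0.0, -1.0)


def random_episode(key, n_bits=N_BITS):
    """Roll out n_bits random flips from a random start state toward a random goal."""
    k_state, k_goal, k_actions = jax.random.split(key, 3)
    state = jax.random.bernoulli(k_state, shape=(n_bits,)).astype(jnp.int32)
    goal = jax.random.bernoulli(k_goal, shape=(n_bits,)).astype(jnp.int32)
    actions = jax.random.randint(k_actions, (n_bits,), 0, n_bits)

    def body(s, a):
        s_next = step(s, a)
        return s_next, (s, s_next)  # carry the new state, record the transition

    _, (states, next_states) = jax.lax.scan(body, state, actions)
    return states, actions, next_states, goal


def relabel_final(states, next_states):
    """HER 'final' strategy: replay the episode with its last achieved state as the goal."""
    new_goal = next_states[-1]
    goals = jnp.broadcast_to(new_goal, states.shape)
    rewards = jax.vmap(reward, in_axes=(0, None))(next_states, new_goal)
    return goals, rewards


key = jax.random.PRNGKey(0)
states, actions, next_states, goal = random_episode(key)
orig_rewards = jax.vmap(reward, in_axes=(0, None))(next_states, goal)
her_goals, her_rewards = relabel_final(states, next_states)
```

With a random goal, `orig_rewards` is almost always all -1, so the agent gets no learning signal. After relabeling, `her_rewards[-1]` is always 0.0, because the last transition reaches the substituted goal by construction; this is the extra signal HER adds to the replay buffer.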
This implementation uses Equinox for model definitions and Optax for optimization. The repository provides:
- A minimal and clean implementation of HER in JAX;
- Reproducible scripts and results;
- A Colab notebook for direct experimentation.
Don't hesitate to check out the code: https://github.com/jeertmans/HER-with-JAX.
Let me know if you have any questions, feedback, or recommendations!