JAX-RL

JAX implementations of various deep reinforcement learning algorithms.

Main libraries used:

  • JAX - main framework
  • Haiku - neural networks
  • Optax - gradient-based optimisation

Algorithms implemented

| Algorithm | Paper |
| --- | --- |
| Proximal Policy Optimization (PPO) | https://arxiv.org/abs/1707.06347 |
| Deep Q-Network (DQN) | https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf |
| Double Deep Q-Network (DDQN) | https://arxiv.org/abs/1509.06461 |
| Deep Recurrent Q-Network (DRQN) | https://arxiv.org/abs/1507.06527 |
| Deep Deterministic Policy Gradient (DDPG) | https://arxiv.org/abs/1509.02971 |
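
To make the difference between DQN and DDQN concrete, here is a minimal sketch of the two bootstrap targets as described in the papers listed above. The function names, argument layout, and toy numbers are assumptions for illustration and may not match how this repository computes them.

```python
import jax.numpy as jnp


def dqn_target(rewards, dones, q_next_target, gamma=0.99):
    # DQN: the target network both selects and evaluates the next action.
    return rewards + gamma * (1.0 - dones) * jnp.max(q_next_target, axis=1)


def double_dqn_target(rewards, dones, q_next_online, q_next_target, gamma=0.99):
    # Double DQN: the online network selects the next action,
    # the target network evaluates it.
    best = jnp.argmax(q_next_online, axis=1)
    evaluated = jnp.take_along_axis(q_next_target, best[:, None], axis=1)[:, 0]
    return rewards + gamma * (1.0 - dones) * evaluated


# Toy batch of two transitions; the second one is terminal.
rewards = jnp.array([1.0, 0.0])
dones = jnp.array([0.0, 1.0])
q_next_online = jnp.array([[1.0, 2.0], [0.5, 0.1]])
q_next_target = jnp.array([[3.0, 0.5], [4.0, 4.0]])

y_dqn = dqn_target(rewards, dones, q_next_target, gamma=0.5)
y_ddqn = double_dqn_target(rewards, dones, q_next_online, q_next_target, gamma=0.5)
```

In the first transition the online network prefers action 1, which the target network values at only 0.5, so the Double DQN target is lower than the DQN target.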

Tabular algorithms

  • Q-learning
  • Double Q-learning
  • SARSA
  • Expected SARSA
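
The four tabular methods differ mainly in the target they bootstrap from. The sketch below writes out the textbook targets in plain Python, assuming a Q-table stored as a list of rows and an epsilon-greedy policy for Expected SARSA. All names are hypothetical and terminal-state handling is left out for brevity; the repository's own code may be organised differently.

```python
def epsilon_greedy_probs(q_row, epsilon):
    # Action probabilities of an epsilon-greedy policy for one state.
    n = len(q_row)
    best = max(range(n), key=lambda a: q_row[a])
    probs = [epsilon / n] * n
    probs[best] += 1.0 - epsilon
    return probs


def q_learning_target(q, r, s_next, gamma):
    # Bootstrap from the greedy value of the next state.
    return r + gamma * max(q[s_next])


def sarsa_target(q, r, s_next, a_next, gamma):
    # Bootstrap from the action actually taken next.
    return r + gamma * q[s_next][a_next]


def expected_sarsa_target(q, r, s_next, gamma, epsilon):
    # Bootstrap from the expected value under the epsilon-greedy policy.
    probs = epsilon_greedy_probs(q[s_next], epsilon)
    return r + gamma * sum(p * v for p, v in zip(probs, q[s_next]))


def double_q_target(q_a, q_b, r, s_next, gamma):
    # Table A selects the next action, table B evaluates it.
    best = max(range(len(q_a[s_next])), key=lambda a: q_a[s_next][a])
    return r + gamma * q_b[s_next][best]


def td_update(q, s, a, target, alpha):
    # Move Q(s, a) a step of size alpha towards the target.
    q[s][a] += alpha * (target - q[s][a])


q = [[0.0, 0.0], [1.0, 3.0]]    # two states, two actions
q_b = [[0.0, 0.0], [2.0, 0.5]]  # second table for Double Q-learning

t_q = q_learning_target(q, r=1.0, s_next=1, gamma=0.5)
t_sarsa = sarsa_target(q, r=1.0, s_next=1, a_next=0, gamma=0.5)
t_exp = expected_sarsa_target(q, r=1.0, s_next=1, gamma=0.5, epsilon=0.5)
t_double = double_q_target(q, q_b, r=1.0, s_next=1, gamma=0.5)
td_update(q, s=0, a=1, target=t_q, alpha=0.5)
```

In Double Q-learning the roles of the two tables are normally swapped at random on each step, so both tables get updated over time.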

Installation

$ pip install git+https://github.com/hamishs/JAX-RL
