notchia/jax-for-hamiltonian
jax-for-hamiltonian

This notebook contains a brief introduction to JAX and a demonstration of how it can be used to define and solve the equations of motion of a simple mass-spring system. I wrote this up in the process of learning to create wave propagation simulations from scratch.

Relevant features of JAX, as described in the JAX repository:

  • "automatically differentiate native Python and NumPy code" with the grad function (I used this to automatically generate the equations of motion without having to compute them by hand)
  • "automatic vectorization" with the vmap function (I used this to simplify the function definitions and overall code structure)
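
As a rough illustration of how these two features fit together, here is a minimal sketch for a single mass on a spring. It is not the code from the notebook: the names `hamiltonian`, `step`, and `batched_step`, the parameter values, and the choice of a symplectic Euler update are all assumptions made for this example. `jax.grad` supplies the partial derivatives in Hamilton's equations (dq/dt = ∂H/∂p, dp/dt = -∂H/∂q), and `jax.vmap` applies the scalar update to a batch of independent oscillators.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters for illustration (not taken from the notebook).
MASS = 1.0
STIFFNESS = 4.0

def hamiltonian(q, p):
    """Total energy of one mass on a spring: kinetic plus potential."""
    return p**2 / (2.0 * MASS) + 0.5 * STIFFNESS * q**2

# Hamilton's equations need dH/dq and dH/dp; grad builds both automatically.
dH_dq = jax.grad(hamiltonian, argnums=0)
dH_dp = jax.grad(hamiltonian, argnums=1)

def step(q, p, dt):
    """One symplectic Euler step for a single oscillator (assumed integrator)."""
    p_new = p - dt * dH_dq(q, p)      # dp/dt = -dH/dq
    q_new = q + dt * dH_dp(q, p_new)  # dq/dt =  dH/dp
    return q_new, p_new

# vmap maps the scalar step over a batch of oscillators; dt is shared.
batched_step = jax.vmap(step, in_axes=(0, 0, None))

q0 = jnp.array([1.0, 0.5, -0.25])  # initial displacements
p0 = jnp.zeros(3)                  # all masses start at rest
q1, p1 = batched_step(q0, p0, 0.01)
```

Because `step` is written for scalars, the physics stays readable, and the batching is added afterwards in one line.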

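I also looked into using the jit ("just-in-time") decorator to speed up function calls, but inside a jit-compiled function, ordinary Python if-else statements cannot branch on the values of traced arrays, which is how I wanted to write the boundary conditions. JAX's structured alternatives, such as jax.lax.cond and jnp.where, do work under jit.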
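
For reference, a minimal sketch of how value-dependent boundary logic can be expressed in a jit-compatible way. These functions are hypothetical and do not appear in the notebook; the fixed-end chain and the reflecting wall at q = 0 are assumed boundary conditions chosen only to show the pattern.

```python
import jax
import jax.numpy as jnp

@jax.jit
def apply_fixed_ends(q):
    """Clamp the first and last masses of a chain to zero displacement."""
    idx = jnp.arange(q.shape[0])
    is_boundary = (idx == 0) | (idx == q.shape[0] - 1)
    # jnp.where selects elementwise, so no Python branching is needed.
    return jnp.where(is_boundary, 0.0, q)

@jax.jit
def reflect_at_wall(q, p):
    """Hypothetical wall at q = 0: mirror position and momentum on crossing."""
    return jax.lax.cond(
        q < 0.0,
        lambda q, p: (-q, -p),  # branch taken when the mass is past the wall
        lambda q, p: (q, p),    # otherwise leave the state unchanged
        q, p,
    )

clamped = apply_fixed_ends(jnp.array([1.0, 2.0, 3.0, 4.0]))
bounced = reflect_at_wall(-0.1, -1.0)
```

`jnp.where` evaluates both alternatives and selects per element, which suits array-wide masks; `jax.lax.cond` chooses one branch at run time and requires both branches to return outputs of the same structure and dtype.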

View the Jupyter notebook on nbviewer

About

Brief demo using JAX to solve equations of motion