Replace deprecated jax.tree_*
functions with jax.tree.*
#308
test_633773781% was force-pushed and no longer has any new commits.
Pushing new commits will allow the pull request to be re-opened.