
Upgrade to JAX 0.7.0 #31

@leakec

JAX 0.7.0 requires the keyword arguments passed to H to be hashable. This means we can no longer pass NumPy arrays; these could be refactored into tuples.
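A minimal sketch of the array-to-tuple refactor, assuming the keyword arguments to H are plain numeric NumPy arrays (possibly multi-dimensional). The helper name `as_hashable` is hypothetical and not part of TFC; it only shows why tuples work where arrays do not.

```python
import numpy as np


def as_hashable(value):
    """Recursively convert NumPy arrays and lists into tuples of Python scalars.

    Hypothetical helper (not part of TFC): tuples of Python scalars are
    hashable, so they can be used wherever JAX needs to hash an argument.
    """
    if isinstance(value, np.ndarray):
        # tolist() yields nested lists of Python scalars (or a bare scalar for 0-d arrays).
        value = value.tolist()
    if isinstance(value, (list, tuple)):
        return tuple(as_hashable(v) for v in value)
    return value


# A NumPy array cannot be hashed ...
d = np.array([0, 1, 2])
try:
    hash(d)
except TypeError as err:
    print("array:", err)  # unhashable type: 'numpy.ndarray'

# ... but the converted tuple can.
d_t = as_hashable(d)
print("tuple:", d_t, hash(d_t) == hash((0, 1, 2)))
```

Doing the conversion at the Python call site keeps the C++ side simple: with pybind11's `pybind11/stl.h` included, a Python tuple of numbers converts to a `std::vector` argument automatically.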

To do this, we need to make changes to both the Python and C++ code. I'm curious whether this is the time to take on switching from SWIG to pybind11. pybind11 has better typing support, and since we only bind C++ to Python (not C++ to other languages), the switch would make sense.

To Dos:

  • Cap JAX at version 0.6 for the current release, so that something at least works when installing via PyPI.
  • Switch from SWIG to pybind11.
  • Use tuples rather than NumPy arrays for the keyword arguments to H.
  • Make sure the docs build. They could also be updated to build TFC, since CMake should be supported.
  • Check that the full test suite (not just the normal one) passes after this change.
  • Release.

Metadata

Labels: bug (Something isn't working)
Projects: none
Milestone: none
Relationships: none yet
Development: no branches or pull requests