Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Operators written purely in Jax #968

Closed
wants to merge 8 commits into from

Conversation

inailuig
Copy link
Collaborator

@inailuig inailuig commented Nov 3, 2021

Currently the operators are written using numba for the very valid reason that the array shapes (number of connected elements) are in general data-dependent which is something not (yet?) supported by jax.
These operators can only be evaluated on the cpu, and are called either outside of a jax jit block or via numba4jax which has some overhead.

This PR adds a Ising operator with padding written in jax and a branchless, jax-jittable MetropolisSampler (I wanted some way to test the operator...) that runs fully on the gpu.

For Paulistrings the same can also be done easily (apart from the constructor).

This rule only works with operators which are written in jax.
"""

def transition(self, _0, _1, _2, _3, key, x):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what are these 4 dummy arguments?

This rule only works with operators which are written in jax.
"""

def transition(self, _0, _1, _2, _3, key, x):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def transition(self, _0, _1, _2, _3, key, x):
def transition(self, sampler, machine, parameters, state, key, x):

PhilipVinc added a commit that referenced this pull request Jul 2, 2023
This is also a remanding of the original PR by @wdphy16 , following the
new infrastructure for Jax-compatible operators.

resuscitates #968

@wdphy16 as you are the original author of the code, could you please go
back through it and comment it throughout? It's very complicated and
it's very hard to follow what the generation of masks do...

---------

authored-by: Clemens Giuliani <clemens@inailuig.it>
inailuig added a commit that referenced this pull request Jul 4, 2023
This is what is left of #968 rebased on master.
---------

Co-authored-by: Filippo Vicentini <filippovicentini@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants