Register TensorNetwork and subclasses as JAX-compatible types #148
Comments
If you change the unpack code to […]
In general I'm happy to add the pytree register; the only downside is that one probably has to try importing `jax`.
I've been researching a little and apparently JAX only allows mapping over JAX arrays, so we cannot `vmap` over TNs directly. But […]
This could be solved using import hooks injected into […]
I do have a […]
I see, I'm not familiar with this function/submodule!
So I made […] Also, I need to fix some lines in […] Here is an example:

```python
import jax
import quimb.tensor as qtn

L = 10
batch_size = 32

psi = qtn.MPS_rand_state(L, 4, normalize=True)
phis = [qtn.MPS_rand_state(L, 4, normalize=True) for _ in range(batch_size)]

# stack the i-th site array of every batch state along a new leading axis
phis_arrays = [
    jax.numpy.asarray([phi.tensors[i].data for phi in phis])
    for i in range(L)
]

def overlap(*arrays):
    phi = qtn.MatrixProductState(arrays)
    return phi.H @ psi

jax.vmap(overlap)(*phis_arrays)
jax.vmap(jax.grad(overlap, argnums=list(range(L))))(*phis_arrays)
```

In this sense, the solution would be similar to the […]

I have tried the same but with vectorized TNs and I'm running into problems. Specifically, the vectorizer is returning the same vector for different TNs. Here is the code I've tried. Any idea what I'm doing wrong?

```python
batch_size = 32

psi = qtn.MPS_rand_state(10, 4, normalize=True)
phis = [qtn.MPS_rand_state(10, 4, normalize=True) for _ in range(batch_size)]

vectorizer = qtn.optimize.Vectorizer(phis[0].arrays)
vect_phis = jax.numpy.asarray([vectorizer.pack(phi.arrays) for phi in phis])

def vect_overlap(vect_arrays):
    arrays = vectorizer.unpack(vect_arrays)
    phi = qtn.MatrixProductState(arrays)
    return phi.H @ psi

jax.vmap(vect_overlap)(vect_phis)
```
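One possible explanation (purely a guess about `Vectorizer` internals, not verified here): if `pack` writes into and returns a single reused internal buffer, every element of the list above would alias the same data, which would produce exactly the "same vector for different TNs" symptom. Copying each packed vector would rule that out, e.g. by replacing the `vect_phis` line with:

```python
# hypothesis only: snapshot each packed vector in case `pack` reuses its buffer
vect_phis = jax.numpy.asarray(
    [jax.numpy.array(vectorizer.pack(phi.arrays)) for phi in phis]
)
```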
Yeah, it's one of those hacky, hidden Python modules. Here is an example of how to use import hooks: https://stackoverflow.com/a/54456931
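For concreteness, here is a rough sketch of that import-hook idea (not quimb's actual code, and the callback name is made up): a meta path finder waits for the first `import jax` and then runs a registration callback, so quimb never has to import jax eagerly itself.

```python
import importlib.abc
import importlib.util
import sys


class _CallbackLoader(importlib.abc.Loader):
    """Wrap the real loader so a callback runs right after the module executes."""

    def __init__(self, loader, callback, finder):
        self._loader = loader
        self._callback = callback
        self._finder = finder

    def create_module(self, spec):
        create = getattr(self._loader, "create_module", None)
        return create(spec) if create is not None else None

    def exec_module(self, module):
        self._loader.exec_module(module)
        # fire only once: retire the finder, then run the callback
        if self._finder in sys.meta_path:
            sys.meta_path.remove(self._finder)
        self._callback(module)


class _OnImport(importlib.abc.MetaPathFinder):
    """Meta path finder that intercepts the first import of `name`."""

    def __init__(self, name, callback):
        self._name = name
        self._callback = callback

    def find_spec(self, fullname, path, target=None):
        if fullname != self._name:
            return None
        # step aside temporarily so the normal machinery locates the real spec
        sys.meta_path.remove(self)
        try:
            spec = importlib.util.find_spec(fullname)
        finally:
            sys.meta_path.insert(0, self)
        if spec is None or spec.loader is None:
            return None
        spec.loader = _CallbackLoader(spec.loader, self._callback, self)
        return spec


def _register_quimb_pytrees(jax_module):
    # hypothetical callback: the register_pytree_node calls would go here
    print("jax imported, registering quimb TNs as pytrees")


sys.meta_path.insert(0, _OnImport("jax", _register_quimb_pytrees))
```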
Nice, yes exposing a 'raw array' function interface might be generally useful - that's basically what the tn compiler decorator does too. Regarding […]: a class that simply goes from arrays to flattened vector form should actually be simpler than this, and you'd ignore all the dtype stuff and simply use […]
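As a rough illustration of that suggestion (the class name and interface below are invented for the example, not part of quimb), such a packer could just concatenate raveled arrays and split them back:

```python
import itertools
import math

import jax.numpy as jnp


class ArrayFlattener:
    """Minimal arrays <-> flat vector packer, with no dtype handling."""

    def __init__(self, arrays):
        self.shapes = [a.shape for a in arrays]
        sizes = [math.prod(s) for s in self.shapes]
        # indices at which to split the flat vector back into pieces
        self.split_points = list(itertools.accumulate(sizes))[:-1]

    def pack(self, arrays):
        # builds a fresh vector every call, so results never share a buffer
        return jnp.concatenate([jnp.ravel(a) for a in arrays])

    def unpack(self, vector):
        chunks = jnp.split(vector, self.split_points)
        return [c.reshape(s) for c, s in zip(chunks, self.shapes)]
```

Using it in the batched example above would just mean swapping `qtn.optimize.Vectorizer` for this class.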
Is your feature request related to a problem?
I've been running into problems when trying to make JAX and Quimb cooperate. JAX only supports its own native formats (e.g. `jax.numpy.array`). Quimb overcomes this obstacle manually, converting the TN arrays to `jax.numpy.array` among other steps; these steps are performed by the `JaxHandler` class, which is used by the `TNOptimizer` class, and it seems to work well for the purposes of `TNOptimizer`, but I'm running into problems when I want to do some more stuff.

Mainly, if you directly pass a TN to a function transformed by JAX (i.e. `jax.grad`, `jax.jit`, `jax.vmap`) it crashes completely, because JAX does not recognize any TN class as compatible. The solution is to register `TensorNetwork` and its subclasses as JAX-compatible. Fortunately, we can do it by using the `jax.tree_util.register_pytree_node` method.

Describe the solution you'd like
This code is working for me when calling `jax.jit` and `jax.grad`, but not when calling `jax.vmap`.

I'm thinking about how to generalize this for `TensorNetwork` and its subclasses, because `register_pytree_node` does not ascend through the class hierarchy. One solution is to call `register_pytree_node` inside `TensorNetwork.__init_subclass__`, such that it is called every time we inherit from it.
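For reference, a minimal sketch of this kind of registration for a single subclass (assuming `TensorNetwork.copy()`, the `.arrays` property and `Tensor.modify(data=...)` behave as used here, and using a plain TN copy as the static aux data just for illustration):

```python
import jax
import quimb.tensor as qtn


def _flatten_tn(tn):
    # children: the raw site arrays JAX should trace/map over
    # aux data: a structural skeleton used to rebuild the network
    return list(tn.arrays), tn.copy()


def _unflatten_tn(skeleton, arrays):
    new = skeleton.copy()
    for tensor, data in zip(new.tensors, arrays):
        tensor.modify(data=data)
    return new


jax.tree_util.register_pytree_node(
    qtn.MatrixProductState, _flatten_tn, _unflatten_tn
)
```

Hooking the same call into `TensorNetwork.__init_subclass__`, as suggested above, would extend this to every subclass automatically.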
Describe alternatives you've considered
No response
Additional context
These are the examples I have tested against: