Implement np.mgrid, np.ogrid, & np.r_, np.c_ #5850
np.mgrid missing in JAX
I'd love to solve this. Does anyone have a suggestion for a good direction to take? @jakevdp @Conchylicultor
To add to the list,
The way to implement these would be to create a private class and override its __getitem__ method, for example:

```python
import jax.numpy as jnp

class _Mgrid:
    def __getitem__(self, key):
        # Only a single slice is handled here, e.g. mgrid[:5] or mgrid[2:10:2].
        if isinstance(key, slice):
            return jnp.arange(key.start or 0, key.stop, key.step or 1)
        else:
            # Tuples of slices, complex steps, etc. are not supported yet.
            raise NotImplementedError()

mgrid = _Mgrid()
print(mgrid[:5])
# [0 1 2 3 4]
```

The challenge will be to make certain it handles all the input cases that
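As a hedged illustration of what "all the input cases" involves, the one-slice sketch above can be extended to tuples of slices and to complex steps (where `5j` means "5 points, stop included"). The class name `_MgridSketch` and helper `_slice_to_array` are hypothetical, and this is not the code that was eventually merged; it simply builds each axis with `jnp.arange` or `jnp.linspace` and combines them with `jnp.meshgrid(..., indexing="ij")`:

```python
import jax.numpy as jnp


class _MgridSketch:
    """Hypothetical sketch of mgrid-style indexing; not JAX's actual implementation."""

    @staticmethod
    def _slice_to_array(s):
        # Turn one slice into a 1-D array, following np.mgrid's slice rules.
        start = 0 if s.start is None else s.start
        step = 1 if s.step is None else s.step
        if isinstance(step, complex):
            # A complex step such as 5j means "5 points", with the stop included.
            return jnp.linspace(start, s.stop, int(abs(step)))
        return jnp.arange(start, s.stop, step)

    def __getitem__(self, key):
        if isinstance(key, slice):
            return self._slice_to_array(key)
        if isinstance(key, tuple) and all(isinstance(k, slice) for k in key):
            axes = [self._slice_to_array(k) for k in key]
            # "ij" indexing gives the matrix-style layout: output[i] varies along axis i.
            return jnp.stack(jnp.meshgrid(*axes, indexing="ij"))
        raise NotImplementedError(f"unsupported index: {key!r}")


mgrid = _MgridSketch()
print(mgrid[0:3, 0:2].shape)  # (2, 3, 2)
print(mgrid[-1:1:5j])
```

Input validation (for example, a slice with no stop value) is deliberately left out of this sketch.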
You could also look at NumPy's implementation to see how it is done.
But note that if you copy from NumPy's implementation rather than writing your own version, the code should go in https://github.com/google/jax/tree/master/jax/_src/third_party for licensing reasons.
And in this case, where there's probably no need to look at NumPy's implementation to implement the API faithfully and well, it's much better to implement from scratch so we don't have to split up the code based on license. In general:
Is there any update on this issue? If not, I would like to work on it :)
I didn't find time for this, @minoring.
Thanks for your response! @VikasHanumegowda
Related to google#5850
Hi
mgrid has been implemented in #6248, but I don't know of any other work on this.
Hi!
Can I take up
I don't know of anyone working on
Looks like this issue can be closed, since all the functions have been implemented 🚀
Thanks!
NumPy has a very convenient np.mgrid, which makes it easy to create coordinate grids: https://numpy.org/doc/stable/reference/generated/numpy.mgrid.html. I was expecting JAX to have the same, but it seems to be missing:
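For reference, a small comparison of the coordinate grids np.mgrid produces against `jnp.mgrid`. This assumes a JAX release that already includes `jnp.mgrid` (tracked as #6248 in the list below); the shapes and values are what both libraries are expected to agree on:

```python
import numpy as np
import jax.numpy as jnp

# Dense coordinate grid: one array per dimension, stacked along axis 0.
np_grid = np.mgrid[0:3, 0:4]
jnp_grid = jnp.mgrid[0:3, 0:4]
print(np_grid.shape, jnp_grid.shape)  # (2, 3, 4) (2, 3, 4)

# A complex step means "number of points", with the stop value included.
np_pts = np.mgrid[0:1:5j]
jnp_pts = jnp.mgrid[0:1:5j]
print(np.allclose(np_pts, np.asarray(jnp_pts)))  # True
```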
Tracking implementation of these:
- jnp.mgrid: Implement jnp.mgrid #6248
- jnp.ogrid: Implement jnp.ogrid #6342
- jnp.r_: Implement jnp.r_ and jnp.c_ #6593
- jnp.c_: Implement jnp.r_ and jnp.c_ #6593
- jnp.ix_
- jnp.s_
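A quick, illustrative tour of the other names in the list, assuming a recent JAX release that exposes all of them under `jax.numpy` with NumPy-compatible behavior. The arrays used are arbitrary examples:

```python
import jax.numpy as jnp

# ogrid: an "open" (sparse) grid; each output varies along one axis only.
x, y = jnp.ogrid[0:3, 0:4]
print(x.shape, y.shape)  # (3, 1) (1, 4)

# r_: concatenate along the first axis; slices are expanded to ranges.
print(jnp.r_[1:4, jnp.array([7, 8])].tolist())  # [1, 2, 3, 7, 8]

# c_: stack 1-D inputs as the columns of a 2-D array.
print(jnp.c_[jnp.array([1, 2, 3]), jnp.array([4, 5, 6])].tolist())  # [[1, 4], [2, 5], [3, 6]]

# ix_: open mesh from index sequences, for cross-product indexing.
a = jnp.arange(12).reshape(3, 4)
rows, cols = jnp.ix_(jnp.array([0, 2]), jnp.array([1, 3]))
print(a[rows, cols].tolist())  # [[1, 3], [9, 11]]

# s_: a convenience for building slice objects to index with later.
print(jnp.s_[1:3, ::2])  # (slice(1, 3, None), slice(None, None, 2))
```

Note that ogrid and ix_ return broadcastable pieces rather than a full dense grid, which is why they are usually cheaper than mgrid when the grid is only used for indexing or broadcasting.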