defn roadmap #77
Does it make sense to have
To be honest, I don’t know yet, probably the latter. I just put it here so we don’t forget to track it, but this is most likely a device/exla concern. I will move it. :)
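@seanmor5 `vmap` may require changes to the underlying code being compiled, since it adds a new dimension to computations to make them batchable. So it is definitely a `defn` concern. For example, in JAX:

```python
import jax.numpy as jnp
from jax import make_jaxpr, vmap

def f(x, y):
    a = jnp.dot(x, y)
    b = jnp.tanh(a)
    return b

# A batch of 8 examples for each argument
xs = jnp.ones((8, 2, 3))
ys = jnp.ones((8, 3, 4))

# Trace f on a single example...
print("f jaxpr")
print(make_jaxpr(f)(xs[0], ys[0]))

# ...and the vmapped f on the whole batch
print("vmap(f) jaxpr")
print(make_jaxpr(vmap(f))(xs, ys))
```

Comparing the two printed jaxprs shows that `vmap(f)` is traced on inputs with an extra leading batch dimension of size 8.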
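To make the "new dimension" point concrete, here is a minimal JAX sketch. It assumes a recent JAX release is installed, and the input values are arbitrary. `vmap(f)` gives the same result as calling `f` once per example and stacking the outputs, and also the same result as rewriting the body as a batched matmul. That last form is why `vmap` has to transform the traced computation itself rather than only moving data around.

```python
import jax.numpy as jnp
from jax import vmap

def f(x, y):
    # Unbatched computation: (2, 3) @ (3, 4) -> (2, 4), then tanh
    return jnp.tanh(jnp.dot(x, y))

# Arbitrary small inputs: a batch of 8 examples per argument
xs = jnp.arange(8 * 2 * 3, dtype=jnp.float32).reshape(8, 2, 3) / 100.0
ys = jnp.arange(8 * 3 * 4, dtype=jnp.float32).reshape(8, 3, 4) / 100.0

# vmap maps f over the leading axis of both arguments
batched = vmap(f)(xs, ys)

# Equivalent manual batching: call f once per example and stack the results
looped = jnp.stack([f(xs[i], ys[i]) for i in range(xs.shape[0])])

# The batch dimension turns the dot into a batched matmul inside the computation
by_hand = jnp.tanh(jnp.einsum("bij,bjk->bik", xs, ys))

print(batched.shape)
```

The loop version traces `f` eight times, while `vmap` traces it once and emits a single batched operation, which is the kind of code rewrite a `defn`-level `vmap` would need.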
`pmap`, on the other hand, is about devices. We want to move the data to separate devices when sending it in, and then read it back from all of those devices into a single binary, or keep it as multiple references. So it is definitely a device-based operation. I will create a separate issue to track the device roadmap.
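By contrast, here is a minimal `pmap` sketch in JAX. It is an illustration of the device-side idea, not the planned Nx API. It assumes JAX is installed and sizes the batch to however many local devices exist, so it also runs on a single CPU. The function body is unchanged. `pmap` only splits the leading axis across devices, and `jax.device_get` gathers the per-device results back into one host array.

```python
import jax
import jax.numpy as jnp
import numpy as np

# One row per local device (1 on a plain CPU machine)
n = jax.local_device_count()
xs = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)

# pmap sends row i of xs to device i and runs the same function there
out = jax.pmap(lambda x: jnp.tanh(x) * 2.0)(xs)

# Read the per-device results back into a single host array
host = np.asarray(jax.device_get(out))
print(host.shape)
```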
I have broken all remaining tasks out into separate issues. The `pmap` discussion is tied to #127.
Constructs:
Passes: