defn roadmap #77

Closed
8 of 12 tasks
josevalim opened this issue Dec 1, 2020 · 4 comments
Labels
kind:feature New feature or request

Comments

@josevalim
Collaborator

josevalim commented Dec 1, 2020

Constructs:

  • Math operators
  • Bit operators
  • Logical operators
  • Slices (access + put_in/update_in)
  • Conditionals (if/cond)
  • Tuples (with pattern matching)
  • Loops (for? while?)
  • Support random functions
  • Default arguments
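For comparison with JAX, which this thread uses as its reference point further down, here is a minimal sketch of how several of these constructs look when a function is traced and compiled with `jax.jit`. It is an analogue from another library, not the planned defn syntax. The function name `constructs_demo` is hypothetical, and default arguments are not shown.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def constructs_demo(x, key):
    # Math, bit and logical operators work elementwise on traced arrays.
    y = x * 2.0 + 1.0
    mask = jnp.logical_and(x > 0, x < 3)
    bits = jnp.bitwise_xor(jnp.arange(4), 1)

    # Slices: a functional update (similar in spirit to put_in/update_in),
    # since traced arrays cannot be mutated in place.
    y = y.at[0].set(-1.0)

    # Conditionals: lax.cond traces both branches and picks one at runtime
    # based on a scalar predicate.
    z = lax.cond(x.sum() > 0, lambda v: v * 10.0, lambda v: v, y)

    # Loops: a traced loop whose carry keeps the same shape and type.
    total = lax.fori_loop(0, 4, lambda i, acc: acc + i, jnp.int32(0))

    # Random functions take an explicit key instead of hidden global state.
    noise = jax.random.uniform(key, x.shape)

    # Tuples: several results are returned together.
    return z, mask, bits, total, noise


z, mask, bits, total, noise = constructs_demo(
    jnp.array([1.0, 2.0, 3.0, 4.0]), jax.random.PRNGKey(0)
)
```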

Passes:

  • Autograd
  • vmap
  • pmap
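In JAX, the Autograd pass corresponds to `jax.grad`: a transformation that takes a function and returns a new function computing its gradient. Here is a minimal sketch, assuming JAX as the reference point; `loss` is a hypothetical example function.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # A tiny scalar-valued function: sum of tanh(x @ w).
    return jnp.sum(jnp.tanh(x @ w))


# jax.grad differentiates with respect to the first argument (w) by default.
grad_loss = jax.grad(loss)

w = jnp.zeros((3,))
x = jnp.ones((2, 3))
# At w = 0, tanh'(0) = 1, so the gradient is x.T @ ones(2) = [2, 2, 2].
g = grad_loss(w, x)
```

Because each pass is a function-to-function transformation, passes compose in JAX; for example, `jax.vmap(jax.grad(loss), in_axes=(None, 0))` computes per-example gradients.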
@seanmor5
Collaborator

seanmor5 commented Dec 1, 2020

Does it make sense to have pmap as its own pass/function/whatever, or does it make more sense to provide simple constructs for people to accomplish the same thing?

@josevalim
Collaborator Author

To be honest, I don’t know yet; probably the latter. I just put it here so we don’t forget to track it, but this is most likely a device/EXLA concern. I will move it. :)

@josevalim
Collaborator Author

@seanmor5 so, vmap may require changes to the code being compiled, as it aims to add a new (batch) dimension to computations, making them batchable. So it is definitely a defn pass. It is hard to assess right now how big those changes are, since all of our operations are elementwise so far. But with JAX, this code:

import jax.numpy as jnp
from jax import make_jaxpr, vmap

def f(x, y):
  a = jnp.dot(x, y)
  b = jnp.tanh(a)
  return b

xs = jnp.ones((8, 2, 3))
ys = jnp.ones((8, 3, 4))

print("f jaxpr")
print(make_jaxpr(f)(xs[0], ys[0]))

print("vmap(f) jaxpr")
print(make_jaxpr(vmap(f))(xs, ys))

prints:

f jaxpr
{ lambda  ; a b.
  let c = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
                       precision=None ] a b
      d = tanh c
  in (d,) }
vmap(f) jaxpr
{ lambda  ; a b.
  let c = dot_general[ dimension_numbers=(((2,), (1,)), ((0,), (0,)))
                       precision=None ] a b
      d = tanh c
  in (d,) }
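To make "batchable" concrete, here is a small JAX sketch. The random inputs and the names `looped`, `batched` and `rewritten` are illustrative. It checks that `vmap(f)` computes the same result as applying `f` to each slice of the leading axis. It also checks that this result equals a single `lax.dot_general` call using the batched `dimension_numbers` from the jaxpr above, followed by `tanh`.

```python
import jax
import jax.numpy as jnp
from jax import lax, vmap


def f(x, y):
    return jnp.tanh(jnp.dot(x, y))


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(key_x, (8, 2, 3))
ys = jax.random.normal(key_y, (8, 3, 4))

# What vmap means: f applied independently to each slice along axis 0.
looped = jnp.stack([f(xs[i], ys[i]) for i in range(xs.shape[0])])

# What vmap produces: one batched computation instead of a Python loop.
batched = vmap(f)(xs, ys)

# The rewrite visible in the jaxpr: contract axis 2 of xs with axis 1 of ys,
# and treat axis 0 of both operands as a batch dimension.
rewritten = jnp.tanh(
    lax.dot_general(xs, ys, dimension_numbers=(((2,), (1,)), ((0,), (0,))))
)
```

The elementwise `tanh` needs no change under vmap. Only `dot` has to be rewritten with new dimension numbers, which is why the size of this pass depends on how many non-elementwise operations exist.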

pmap, on the other hand, is about devices. We want to split the data across separate devices when sending it in, and then either read the results back from all devices into a single binary or keep them as multiple references. So it is definitely a device-based operation. I will create a separate issue to track the device roadmap.
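Here is a minimal JAX sketch of that device-oriented behaviour, assuming `jax.pmap` as the analogue (this is not the Nx design). The leading axis is split so that each device receives one slice, and the result stays distributed across devices. `np.asarray` then gathers it back into one host array. The array sizes are hypothetical. The leading axis is set to `jax.local_device_count()` so the sketch also runs on a single CPU.

```python
import jax
import jax.numpy as jnp
import numpy as np


def f(x, y):
    return jnp.tanh(jnp.dot(x, y))


# pmap maps over the leading axis with one slice per device, so that axis
# must not be larger than the number of local devices.
n = jax.local_device_count()
xs = jnp.ones((n, 2, 3))
ys = jnp.ones((n, 3, 4))

# The output remains on the devices, one slice per device.
out = jax.pmap(f)(xs, ys)

# Converting to NumPy copies every device's slice back into one host array.
host = np.asarray(out)
```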

josevalim added the kind:feature (New feature or request) label on Jan 23, 2021
@josevalim
Collaborator Author

I have broken all remaining tasks out into separate issues. The pmap discussion is tied to #127.
