Named tensors #5048

Open · juliuskunze opened this issue Nov 30, 2020 · 8 comments
juliuskunze (Contributor) commented Nov 30, 2020

PyTorch has experimental support for named tensors, achieving some compelling design goals while keeping existing code compatible. For example, binop broadcasting is still based on dimension order (unlike in xarray), consistent with standard NumPy/JAX semantics, but the names of aligned dimensions are checked for consistency.
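For illustration, a minimal PyTorch sketch of that behavior (shapes and names here are made up):

import torch

x = torch.randn(3, 4, names=('N', 'C'))
y = torch.randn(4, names=('C',))

z = x + y          # broadcasting is positional, as usual
print(z.names)     # ('N', 'C'); names are checked and propagated

w = torch.randn(4, names=('H',))
# x + w            # error: aligned dimensions are named 'C' and 'H', which do not match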

It would be great to have named tensors that work both in op-by-op mode and under function transformations in JAX.

@shoyer In #1565 you mentioned that this could be done by wrapping JAX based on #611. According to my current understanding, this means:

  • Add name rules for lax primitives, returning the output dimension names for given input dimension names.
  • Add a corresponding eval_names transform.
  • Add a NamedDeviceArray subtype of DeviceArray that adds a names property.
  • Make jitted functions propagate names when applied to NamedDeviceArrays, so that names are also propagated in op-by-op mode (a rough sketch of these pieces follows this list).
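For concreteness, here is a minimal, hypothetical sketch of the first and third bullets (add_name_rule, name_rules, and NamedArray are illustrative names, not existing JAX APIs):

from jax import lax

# A name rule per primitive: map input dimension names to output dimension names.
def add_name_rule(x_names, y_names):
    # Assume equal rank for simplicity. Broadcasting stays positional,
    # but the names of aligned dimensions must agree (None acts as a wildcard).
    out = []
    for xn, yn in zip(x_names, y_names):
        if xn is not None and yn is not None and xn != yn:
            raise ValueError(f"dimension name mismatch: {xn} vs {yn}")
        out.append(xn if xn is not None else yn)
    return tuple(out)

name_rules = {lax.add_p: add_name_rule}

# A thin wrapper that carries dimension names alongside the underlying array,
# so names can be propagated op-by-op.
class NamedArray:
    def __init__(self, data, names):
        assert data.ndim == len(names)
        self.data, self.names = data, names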

Is this plan sound? @shoyer @mattjj Would you update (and merge, if successful) #611 just for this application? In that case, I'd be interested in prototyping a named tensor library for JAX, with a good amount of passion, in accordance with #1565. (:

hawkinsp added the enhancement (New feature or request) label on Nov 30, 2020
Jeevesh8 commented:

Have you started working on this, @juliuskunze?

apaszke (Member) commented Dec 21, 2020

We are actually working on something that will pretty much realize the plan @juliuskunze has outlined here, with some additional benefits (e.g. making it very easy to shard programs with named axes over multiple accelerators).

juliuskunze (Contributor, Author) commented:

@Jeevesh8 No, and now I won't anymore. (: @apaszke That's great to hear! Will this go into the JAX repo?

degregat commented:

Am I correct in assuming that this evolved into named axes, or is there another module I did not find?

froystig (Member) commented:

That's correct.

juliuskunze (Contributor, Author) commented Jun 28, 2021

@apaszke @froystig That looks awesome! Rad choice not to take the order of named axes into account and to broadcast by name! That's semantically cleaner and probably more future-proof than I expected. (: The thing I thought would make this impractical is that it's hard to optimize misaligned axes for dot products and similar ops, where implicit transposes are needed on device. I guess the performance hit is not so bad, or axis-order optimization could/should be automated in the future anyway? Curious about your thoughts on this.

+1 for allowing arrays and operations with named axes outside of xmap, i.e. making named-axis arrays first-class in JAX, as suggested above.
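For reference, this is roughly what broadcasting and contraction by name look like with the experimental xmap API (the axis names here are just for illustration):

import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap

def named_matmul(x, y):
    # x carries named axes {'rows', 'contract'}, y carries {'contract', 'cols'}.
    # Elementwise ops broadcast by name; psum contracts over a named axis.
    return jax.lax.psum(x * y, 'contract')

matmul = xmap(named_matmul,
              in_axes=(['rows', 'contract'], ['contract', 'cols']),
              out_axes=['rows', 'cols'])

a = jnp.ones((4, 3))
b = jnp.ones((3, 5))
print(matmul(a, b).shape)  # (4, 5)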

Bit0r commented Aug 6, 2023

A more powerful implementation would be to use first-class dimensions; torchdim uses objects as dimension "variables".
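For context, a small sketch of matrix multiplication with torchdim's first-class dimensions (following the functorch.dim API):

import torch
from functorch.dim import dims

def mm(A, B):
    i, j, k = dims(3)                  # first-class dimension objects
    r = (A[i, k] * B[k, j]).sum(k)     # bind dims by object and reduce over k by identity, not position
    return r.order(i, j)               # lay the result out with an explicit dimension order

print(mm(torch.randn(4, 3), torch.randn(3, 5)).shape)  # torch.Size([4, 5])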

Bit0r commented Aug 8, 2023

@apaszke Perhaps this could go further and be independent of axis position? With named tensors, operations that do not depend on axis position become possible.

  1. For example, a loop could directly specify which axis to iterate over, and the framework would automatically move that axis to the first dimension. The whole operation would be transparent, with no extra code from the user.
  2. For example, we could index a tensor directly by axis name and the index along that axis.
# Hypothetical API sketch (not an existing JAX feature):
# tensor.named_shape == {'batch': 32, 'time': 100, 'hidden': 200}

# Select the slice at time 0 and hidden 0, and set it to 1000 (with broadcasting).
tensor[{'time': 0, 'hidden': 0}] = 1000

# Iterate over the 'time' axis; JAX would automatically permute dimensions
# for the operation: (batch, time, hidden) -> (time, batch, hidden).
for t in tensor['time']:
    # t.named_shape == {'batch': 32, 'hidden': 200}
    ...
