Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Op] Enhanced Einsum #283

Closed
wants to merge 20 commits into from
Closed

[Op] Enhanced Einsum #283

wants to merge 20 commits into from

Conversation

LeshengJin
Copy link
Contributor

@LeshengJin LeshengJin commented Aug 13, 2023

Typically, Einsum only performs element-wise multiplication and summation across indices. This pr expands Einsum's capabilities:

  1. Customize element-wise computation and index combination.
  2. Einsum can now produce Tuple(tensor) outputs.

This enhanced Einsum could represents more complex computations within a few lines of code. For example,

  • sum(x), sum(x^2)
einsum("ij -> i, i", x, fcompute=lambda x_ij: (x_ij, x_ij * x_ij))
  • sum(x), prod(x)
einsum(
    "ij -> i, i",
    x,
    fcombine=lambda x, y: (x[0] + y[0], x[1] * y[1]),
    fidentity=lambda dtype1, dtype2: (tvm.tir.const(0, dtype1), tvm.tir.const(1, dtype2)),
)
  • Online Softmax
def fcombine(tensor1, tensor2):
    mi = tensor1[0]
    di = tensor1[1]
    mj = tensor2[0]
    dj = tensor2[1]
    r0 = tvm.tir.max(mi, mj)
    r1 = di * tvm.tir.exp(mi - r0) + dj * tvm.tir.exp(mj - r0)
    return r0, r1

def fidentity(dtype1, dtype2):
    return tvm.te.min_value(dtype1), tvm.tir.const(0, dtype2)

mv, dv = einsum(
    "ij -> i, i",
    x,
    fcompute=lambda x_ij: (x_ij, 1.0),
    fcombine=fcombine,
    fidentity=fidentity,
)

softmax_x = einsum(
    "ij, i, i -> ij",
    (x, mv, dv),
    fcompute=lambda x_ij, mv_i, dv_i: (tvm.tir.exp(x_ij) - mv_i) / dv_i,
)

@LeshengJin LeshengJin changed the title Einsum [Op] Einsum with customized compute function and combine function Aug 13, 2023
@LeshengJin LeshengJin changed the title [Op] Einsum with customized compute function and combine function [Op] Enhanced Einsum Aug 13, 2023
@junrushao
Copy link
Member

You may take a look at this: https://einops.rocks/

@tqchen tqchen closed this May 24, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants