Dynamic Shape Support #36

@curtischong

I am currently using einx in a graph neural network, where the number of edges varies from graph to graph. The following benchmark compares einx against plain torch when every call sees a new edge count:

import time, torch, einx

def run_einx():
    for E in range(1000, 1050):  # 50 novel shapes
        d, c = torch.randn(E, 8), torch.rand(E)
        einx.divide("e d, e -> e d", d, c)

def run_torch():
    for E in range(1000, 1050):
        d, c = torch.randn(E, 8), torch.rand(E)
        d / c.unsqueeze(-1)  # shape-agnostic: no per-shape cost

t0 = time.perf_counter(); run_einx(); print(f"einx, novel shapes: {time.perf_counter()-t0:.3f}s")
t0 = time.perf_counter(); run_einx(); print(f"einx, cached:       {time.perf_counter()-t0:.3f}s")
t0 = time.perf_counter(); run_torch(); print(f"raw torch:          {time.perf_counter()-t0:.3f}s")

This script produces this output:

einx, novel shapes: 0.481s
einx, cached:       0.004s
raw torch:          0.002s

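I think einx is slower on the first pass because it has to generate a new expression for every call whose input shapes it has not seen before, and here the shapes of d and c change on every iteration. The second pass repeats the same shapes, hits the cache, and is roughly as fast as raw torch.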
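To illustrate the effect I am describing, here is a minimal pure-Python sketch of a cache keyed on input shapes. It is not einx's actual implementation; `build_op`, `divide`, and `trace_count` are hypothetical names, and nested lists stand in for tensors. It only shows why 50 novel shapes cause 50 cache misses while a repeat pass causes none.

```python
from functools import lru_cache

trace_count = 0  # how many times the (hypothetical) expensive build step ran


@lru_cache(maxsize=None)
def build_op(expr, shape_a, shape_b):
    """Toy stand-in for tracing: returns a function specialized to the given shapes."""
    global trace_count
    trace_count += 1
    rows, _ = shape_a

    def op(a, b):
        # The row count is baked in, as it would be in a shape-specialized graph.
        assert len(a) == rows and len(b) == rows
        return [[x / b[i] for x in a[i]] for i in range(rows)]

    return op


def divide(a, b):
    # The cache key includes the full shapes, so every new edge count is a miss.
    shape_a = (len(a), len(a[0]))
    shape_b = (len(b),)
    return build_op("e d, e -> e d", shape_a, shape_b)(a, b)


def run():
    for e in range(1000, 1050):  # 50 distinct edge counts
        divide([[1.0] * 8 for _ in range(e)], [2.0] * e)


run()
first = trace_count   # builds after the pass with novel shapes
run()
second = trace_count  # builds after repeating the same shapes
print(first, second)
```

The second pass adds no builds because every shape is already in the cache, which matches the gap between the "novel shapes" and "cached" timings above.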

To see the code that einx generates, I ran:

import torch, einx
dist = torch.rand(4096); centers = torch.linspace(0, 6, 32)
print(einx.subtract("e, r -> e r", dist, centers, graph=True))

This generates the following code:

import torch
def op(a, b):
    a = torch.reshape(a, (4096, 1))
    b = torch.reshape(b, (1, 32))
    c = torch.subtract(a, b)
    return c

I believe these reshapes are generated separately for each unique combination of input shapes, because the dimensions (such as 4096) are hardcoded into the generated code. As a result, graphs with different numbers of edges cannot reuse the same generated code, and each new edge count pays the cost of generating code with the matching shapes.
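For comparison, here is a hand-written sketch of what a shape-agnostic version of the generated function could look like for this particular expression. This is my own illustration, not something einx emits; `op_dynamic` is a hypothetical name. It relies on `torch.reshape` accepting `-1` to infer one axis at call time, which only works when at most one axis per reshape is unknown.

```python
import torch


def op_dynamic(a, b):
    # -1 lets torch infer the axis length at call time instead of baking in 4096 or 32.
    a = torch.reshape(a, (-1, 1))
    b = torch.reshape(b, (1, -1))
    return torch.subtract(a, b)


centers = torch.linspace(0, 6, 32)
for num_edges in (1000, 1001, 4096):
    dist = torch.rand(num_edges)
    out = op_dynamic(dist, centers)
    # The same function handles every edge count.
    assert out.shape == (num_edges, 32)
```

A single function like this would serve all edge counts, so nothing would need to be regenerated when the number of edges changes.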

I think we need some way to mark certain axes as dynamic, so that einx can generate a single shape-agnostic function for them instead of one per shape.
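As a rough sketch of the idea, assuming a hypothetical `dynamic_axes` argument that does not exist in einx today: the cache key could replace the size of every axis marked as dynamic with a wildcard, so that all edge counts map to the same cached function. The names `cache_key`, `build_op`, and `divide` are made up for illustration, and nested lists stand in for tensors.

```python
from functools import lru_cache

trace_count = 0  # how many times the (hypothetical) expensive build step ran


def cache_key(shape, dynamic_axes):
    # Sizes of axes declared dynamic become None, so they do not affect the key.
    return tuple(None if i in dynamic_axes else n for i, n in enumerate(shape))


@lru_cache(maxsize=None)
def build_op(expr, key_a, key_b):
    global trace_count
    trace_count += 1

    def op(a, b):
        # Sizes are read from the inputs at call time, not from the cache key.
        return [[x / b[i] for x in row] for i, row in enumerate(a)]

    return op


def divide(a, b, dynamic_axes=(0,)):
    key_a = cache_key((len(a), len(a[0])), dynamic_axes)
    key_b = cache_key((len(b),), dynamic_axes)
    return build_op("e d, e -> e d", key_a, key_b)(a, b)


for e in range(10, 60):  # 50 distinct edge counts
    divide([[1.0] * 8 for _ in range(e)], [2.0] * e)

print(trace_count)  # 1: every edge count shares the key ((None, 8), (None,))
```

The trade-off is that the generated function can no longer hardcode the dynamic sizes, so it would have to infer them at call time.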
