I am currently using einx in a graph neural network where the number of edges varies from graph to graph.
import time, torch, einx
def run_einx():
    for E in range(1000, 1050):  # 50 novel shapes
        d, c = torch.randn(E, 8), torch.rand(E)
        einx.divide("e d, e -> e d", d, c)

def run_torch():
    for E in range(1000, 1050):
        d, c = torch.randn(E, 8), torch.rand(E)
        d / c.unsqueeze(-1)  # shape-agnostic: no per-shape cost

t0 = time.perf_counter(); run_einx(); print(f"einx, novel shapes: {time.perf_counter()-t0:.3f}s")
t0 = time.perf_counter(); run_einx(); print(f"einx, cached: {time.perf_counter()-t0:.3f}s")
t0 = time.perf_counter(); run_torch(); print(f"raw torch: {time.perf_counter()-t0:.3f}s")
This script produces this output:
einx, novel shapes: 0.481s
einx, cached: 0.004s
raw torch: 0.002s
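I think einx is slower here because it has to generate a new expression on every call: the shapes of the tensors d and c differ each time, so nothing from a previous call can be reused. That works out to roughly 10 ms per novel shape (0.481 s over 50 shapes), compared with under 0.1 ms per call once a shape has been seen.
To see the code that einx generates, I ran: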
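To illustrate what I mean, here is a toy model of a per-shape cache. This is only a sketch of the mechanism I suspect, not einx's actual implementation: `build_op` and `divide` are hypothetical names, and `functools.lru_cache` stands in for whatever caching einx does internally.

```python
import torch
from functools import lru_cache

@lru_cache(maxsize=None)
def build_op(shape_a, shape_b):
    # Hypothetical stand-in for the expensive "generate code" step.
    # The concrete sizes are baked into the returned function.
    e, _ = shape_a
    def op(a, b):
        return torch.divide(a, torch.reshape(b, (e, 1)))
    return op

def divide(a, b):
    # The cache key includes the full shapes, so every new E is a miss.
    return build_op(tuple(a.shape), tuple(b.shape))(a, b)

for E in range(1000, 1050):
    divide(torch.randn(E, 8), torch.rand(E))

info = build_op.cache_info()
print(info.misses, info.hits)
```

In this model every one of the 50 edge counts is a cache miss, which would match the gap between the "novel shapes" and "cached" timings above.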
import torch, einx
dist = torch.rand(4096); centers = torch.linspace(0, 6, 32)
print(einx.subtract("e, r -> e r", dist, centers, graph=True))
This generates the following code:
import torch
def op(a, b):
    a = torch.reshape(a, (4096, 1))
    b = torch.reshape(b, (1, 32))
    c = torch.subtract(a, b)
    return c
I believe these reshapes are generated separately for each unique set of input shapes, because dimensions such as 4096 are hardcoded into the generated code. This means that graphs with different numbers of edges cannot share the same generated code, and every new edge count pays the cost of generating it again.
I think we need some way to mark certain axes as dynamic, so that einx can generate code that is reused across different sizes. Is something like this possible?
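For illustration, this is roughly the kind of shape-agnostic code I have in mind for the subtract example. It is hand-written, not einx output, and `op_dynamic` is a hypothetical name; it just replaces the hardcoded sizes with -1 so that `torch.reshape` infers them at call time.

```python
import torch

def op_dynamic(a, b):
    # -1 lets torch infer the size, so one function serves any number of edges.
    a = torch.reshape(a, (-1, 1))
    b = torch.reshape(b, (1, -1))
    return torch.subtract(a, b)

centers = torch.linspace(0, 6, 32)
for E in (4096, 1000, 1001):
    print(op_dynamic(torch.rand(E), centers).shape)
```

I realize this only works directly for simple cases like this one, where each input has a single axis, but it shows that the generated code does not always need the concrete sizes.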