In [1]:
# install einsum semirings
!pip install torch-semiring-einsum



In [4]:
import torch
import torch_semiring_einsum

# Einsum in the log semiring: logsumexp takes the place of addition and
# ordinary addition takes the place of multiplication.

# Compile the batched matrix-product equation once so it can be reused.
EQUATION = torch_semiring_einsum.compile_equation('bik,bkj->bij')
# Log-space operands built from uniform random samples.
# Note: A and B are outputs of log(), so they are non-leaf tensors; the
# leaves that receive gradients are the temporary rand() tensors.
A = torch.rand(10, 3, 5, requires_grad=True).log()
B = torch.rand(10, 5, 7, requires_grad=True).log()
# Contract over the shared index k in log space.
C = torch_semiring_einsum.log_einsum(EQUATION, A, B)
# C is part of the autograd graph, so a scalar reduction of it can be
# backpropagated.
torch.sum(C).backward()

In [9]:
# Compare log_einsum against a reference computed in probability space
# (exp -> einsum -> log) for a plain (M, N) x (N, P) matrix product.
M, N, P = 10, 3, 8
EQUATION_STR = 'mn,np->mp'
EQUATION = torch_semiring_einsum.compile_equation(EQUATION_STR)
logA, logB = torch.randn(M, N), torch.randn(N, P)
# Use the operands defined in this cell (logA/logB) rather than the 3-D A/B
# from the earlier batched example, whose rank does not match 'mn,np->mp'.
res_1 = torch.einsum(EQUATION_STR, logA.exp(), logB.exp()).log()
res_2 = torch_semiring_einsum.log_einsum(EQUATION, logA, logB)


In [13]:
# Element-wise comparison of the two results using allclose's default tolerances.
res_1.allclose(res_2)

True

In [None]:
R = 8
A1 = torch.randn(R, R)
A2 = torch.randn(R, R)
A3 = torch.randn(R, R)

# Lift every R x R matrix to a 6-mode tensor: matrix k keeps its two modes at
# positions (2k, 2k+1) and all remaining modes are singletons.
A1_lg = A1[:, :, None, None, None, None]
A2_lg = A2[None, None, :, :, None, None]
A3_lg = A3[None, None, None, None, :, :]

# Storage size backing each singleton-padded view (before expanding).
print("\n".join(
    f"{t.shape} {t.untyped_storage().nbytes()}"
    for t in (A1_lg, A2_lg, A3_lg)
))

# Broadcast every singleton mode up to R. expand() returns views, so the
# backing storage reported below is the same size as before.
A1_lg = A1_lg.expand(R, R, R, R, R, R)
A2_lg = A2_lg.expand(R, R, R, R, R, R)
A3_lg = A3_lg.expand(R, R, R, R, R, R)

# Storage size after expanding.
print("\n".join(
    f"{t.shape} {t.untyped_storage().nbytes()}"
    for t in (A1_lg, A2_lg, A3_lg)
))

# Stacking copies the expanded views into one dense tensor, which is where
# the storage size jumps.
tens = torch.stack((A1_lg, A2_lg, A3_lg))
print(f"tens.shape {tens.shape} {tens.untyped_storage().nbytes()}") # <- blows up

torch.Size([8, 8, 1, 1, 1, 1]) 256
torch.Size([1, 1, 8, 8, 1, 1]) 256
torch.Size([1, 1, 1, 1, 8, 8]) 256
torch.Size([8, 8, 8, 8, 8, 8]) 256
torch.Size([8, 8, 8, 8, 8, 8]) 256
torch.Size([8, 8, 8, 8, 8, 8]) 256
tens.shape torch.Size([3, 8, 8, 8, 8, 8, 8]) 3145728


In [None]:
# Try lse on the flattened tensors

# torch.logsumexp needs both an input tensor and a reduction dim.
# NOTE(review): reducing each stacked tensor over all of its flattened
# entries (one value per tensor) is an assumption about the intent — confirm.
lse = torch.logsumexp(tens.flatten(start_dim=1), dim=1)