# Einsum. Einstein summation in deep learning

[Original video](https://youtu.be/pkVwUVEHmfI)

  * [Einsum is all you need - Einstein summation in Deep Learning](https://rockt.github.io/2018/04/30/einsum)
  * [Einstein Summation in Numpy](https://obilaniu6266h16.wordpress.com/2016/02/04/einstein-summation-in-numpy/)
  * [A basic introduction to NumPy's einsum](https://ajcr.net/Basic-guide-to-einsum/)

Einsum:
  * torch.einsum
  * numpy.einsum
  * tensorflow.einsum

In [92]:
import torch

print(torch.__version__)

1.8.0+cu101


In [93]:
# matrix-matrix multiplication
a = torch.rand(3, 5)
b = torch.rand(5, 2)
m1 = torch.empty((3, 2))

for i in range(a.shape[0]):
    for j in range(b.shape[1]):
        total = 0
        for k in range(a.shape[1]):
            total += a[i, k] * b[k, j]
        m1[i, j] = total


m2 = torch.matmul(a, b)
m3 = torch.einsum('ik,kj->ij', a, b)

print(torch.allclose(m1, m2))
print(torch.allclose(m2, m3))

True
True


In [94]:
# matrix-vector multiplication
a = torch.rand(2, 3)
b = torch.rand(3)

m1 = a.mv(b)
m2 = torch.einsum('ij,j->i', a, b)
print(torch.allclose(m1, m2))
print(m1)
print(m2)

True
tensor([1.0594, 0.5829])
tensor([1.0594, 0.5829])


In [95]:
# outer product
a = torch.arange(5)
b = torch.arange(3)
m1 = torch.empty((5, 3)).float()

for i in range(a.shape[0]):
    for j in range(b.shape[0]):
        total = 0
        # no sum loop index
        total += a[i] * b[j]
        m1[i, j] = total


m2 = torch.outer(a, b).float()
m3 = torch.einsum('i,j->ij', a, b).float()

print(torch.allclose(m1, m2))
print(torch.allclose(m2, m3))
print(m3)

True
True
tensor([[0., 0., 0.],
        [0., 1., 2.],
        [0., 2., 4.],
        [0., 3., 6.],
        [0., 4., 8.]])


In [96]:
# sum tensor
a = torch.arange(1, 10).reshape(3, 3)
m1 = torch.empty(1).float()

sum = 0
for i in range(a.shape[0]):
    for j in range(a.shape[1]):
        m1[0] += a[i, j]


m2 = torch.sum(a).float()
m3 = torch.einsum('ij->', a).float()

print(torch.allclose(m1, m2))
print(torch.allclose(m2, m3))
print(m3)

False
True
tensor(45.)


In [97]:
# column sum
a = torch.arange(1, 10).reshape(3, 3)

m1 = torch.sum(a, axis=0)
m2 = torch.einsum('ij->j', a)
print(torch.allclose(m1, m2))

# row sum
m1 = torch.sum(a, axis=1)
m2 = torch.einsum('ij->i', a)
print(torch.allclose(m1, m2))

True
True


In [98]:
# permute axis
a = torch.arange(30).reshape(5, 3, 2)

m1 = a.permute(2, 1, 0)
m2 = torch.einsum('ijk->kji', a)

print(torch.allclose(m1, m2))

True


In [99]:
# transpose tensor
a = torch.arange(6).reshape(3, 2)

m1 = a.transpose(0, 1)
m2 = torch.einsum('ij->ji', a)

print(torch.allclose(m1, m2))

True


In [100]:
# dot product 2nd row with 2nd row of matrix
a = torch.arange(1, 7).reshape(3, 2)

m1 = a[1].dot(a[1])
m2 = torch.einsum('i,i->', a[1], a[1])

print(torch.allclose(m1, m2))
print(m2)
print(a)

True
tensor(25)
tensor([[1, 2],
        [3, 4],
        [5, 6]])


In [101]:
# dot product with matrix
a = torch.arange(1, 7).reshape(3, 2)

m1 = torch.tensordot(a, a)
m2 = torch.einsum('ij,ij->', a, a)

print(m1)
print(m2)

tensor(91)
tensor(91)


In [102]:
# element-wise multiplication
a = torch.arange(1, 7).reshape(3, 2)

m1 = a * a
m2 = torch.einsum('ij,ij->ij', a, a)

print(m1)
print(m2)

tensor([[ 1,  4],
        [ 9, 16],
        [25, 36]])
tensor([[ 1,  4],
        [ 9, 16],
        [25, 36]])


In [103]:
# batch matrix multiplication, matmul 2 last axis
a = torch.rand((3, 2, 5))
b = torch.rand((5, 3))

m1 = torch.matmul(a, b)
m2 = torch.einsum('ijk,kl->ijl', a, b)

print(torch.allclose(m1, m2))

True


In [104]:
# matrix diagonal
a = torch.arange(1, 10).reshape(3, 3)

m0 = a.diagonal()
m1 = a.diag()
m2 = torch.einsum('ii->i', a)

print(m0)
print(m1)
print(m2)
print(a)

tensor([1, 5, 9])
tensor([1, 5, 9])
tensor([1, 5, 9])
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])


In [105]:
# matrix trace
a = torch.arange(1, 10).reshape(3, 3)

m1 = a.diagonal().sum()
m2 = torch.einsum('ii->', a)

print(m1)
print(m2)

tensor(15)
tensor(15)
