In [2]:
import numpy as np
import torch

`torch.einsum(equation, *operands) → Tensor`

Sums the product of the elements of the input operands along dimensions specified using a notation based on the Einstein summation convention.

Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them in a short-hand format based on the Einstein summation convention, given by `equation`. The cells below work through this format by example, but the general idea is to label every dimension of the input operands with some subscript and define which subscripts are part of the output. The output is then computed by summing the product of the elements of the operands along the dimensions whose subscripts are not part of the output. For example, matrix multiplication can be computed using einsum as `torch.einsum("ij,jk->ik", A, B)`. Here, `j` is the summation subscript, and `i` and `k` are the output subscripts.

In [37]:
# Two small 2x3 integer tensors shared by the two-operand einsum examples.
t1, t2 = (
    torch.tensor([[0, 1, 2], [1, 1, 1]]),
    torch.tensor([[1, 1, 1], [1, 2, 3]]),
)

In [38]:
# The operands share no subscript, so b and d are summed out independently:
# x1[a, e] = (sum_b t1[a, b]) * (sum_d t2[d, e]).
x1 = torch.einsum("ab,de->ae", t1, t2)

In [39]:
# Display x1, the result of "ab,de->ae" (rows indexed by a, columns by e).
x1

tensor([[ 6,  9, 12],
        [ 6,  9, 12]])

In [40]:
# Keep b (column of t1) and d (row of t2); a and e are summed out:
# x2[b, d] = (sum_a t1[a, b]) * (sum_e t2[d, e]).
x2 = torch.einsum("ab,de->bd", t1, t2)

In [41]:
# Display x2, the result of "ab,de->bd" (rows indexed by b, columns by d).
x2

tensor([[ 3,  6],
        [ 6, 12],
        [ 9, 18]])

In [42]:
# Keep a and d; b and e are summed out:
# x3[a, d] = (sum_b t1[a, b]) * (sum_e t2[d, e]).
x3 = torch.einsum("ab,de->ad", t1, t2)

In [43]:
# Display x3, the result of "ab,de->ad" (rows indexed by a, columns by d).
x3

tensor([[ 9, 18],
        [ 9, 18]])

In [45]:
# Same contraction as "ab,de->ae", but the output subscripts are listed in
# the opposite order ("ea"), which yields the transposed result.
x1T = torch.einsum("ab,de->ea", t1, t2)

In [46]:
# Display x1T, the result of "ab,de->ea" (rows indexed by e, columns by a).
x1T

tensor([[ 6,  6],
        [ 9,  9],
        [12, 12]])

In [47]:
# NOTE: despite its name this is not a diagonal. Only `a` survives, so b, d
# and e are all summed: diag[a] = (sum_b t1[a, b]) * (sum of every t2 entry).
diag = torch.einsum("ab,de->a", t1, t2)

In [48]:
# Display diag, the result of "ab,de->a" (one value per index a).
diag

tensor([27, 27])

In [49]:
# NOTE: neither a diagonal nor a transpose of "ab,de->a". Only `e` survives:
# diagT[e] = (sum of every t1 entry) * (sum_d t2[d, e]).
diagT = torch.einsum("ab,de->e", t1, t2)

In [51]:
# Display diagT, the result of "ab,de->e" (one value per index e).
diagT

tensor([12, 18, 24])

In [54]:
# Axis permutation: "abc->cba" reverses the axis order, i.e. it is the same
# as o1.permute(2, 1, 0). This is NOT a reshape: reshape keeps the elements
# in their row-major order, while a permutation moves them around.
# Distinct values (instead of all ones) make that difference observable.
o1 = torch.arange(5 * 4 * 3, dtype=torch.float64).reshape(5, 4, 3)
oT = torch.einsum("abc->cba", o1)
assert torch.equal(oT, o1.permute(2, 1, 0))
assert not torch.equal(oT, o1.reshape(3, 4, 5))
oT.shape  # 5,4,3 --> 3,4,5

torch.Size([3, 4, 5])

In [63]:
# Row summation: j is dropped from the output, so it is summed within each row.
a1 = torch.tensor([[1, 2, 3], [1, 2, 3]])  # 2x3
row_sum = torch.einsum("ij->i", a1)  # same as a1.sum(dim=1)
row_sum

tensor([6, 6])

In [65]:
# Column summation: i is dropped from the output, so it is summed down each column.
a1 = torch.tensor([[1, 2, 3], [1, 2, 3]])  # 2x3
col_sum = torch.einsum("ij->j", a1)  # same as a1.sum(dim=0)
col_sum

tensor([2, 4, 6])

In [66]:
# Full summation: an empty output spec sums over every index, giving a 0-d tensor.
a1 = torch.tensor([[1, 2, 3], [1, 2, 3]])  # 2x3
a_sum = torch.einsum("ij->", a1)  # same as a1.sum()
a_sum

tensor(12)

In [80]:
# Matrix-vector multiplication. v1 is stored as a 1x3 *row matrix* (not a 1-D
# vector), so this computes m1 @ v1.T: j is contracted and k has size 1.
m1 = torch.tensor([[1, 2, 3], [1, 2, 3]])  # 2x3 matrix
v1 = torch.tensor([[1, 2, 3]])  # 1x3 row matrix
mv = torch.einsum("ij,kj->ik", m1, v1)  # (2x3, 1x3) -> 2x1
mv

tensor([[14],
        [14]])

In [81]:
# A matrix times its own transpose (Gram matrix of the rows).
# No explicit .T is needed: labelling the second operand "kj" (sharing j
# with "ij") contracts over the columns of both copies.
m1 = torch.tensor([[1, 2, 3], [1, 2, 3]])  # 2x3 matrix
mm = torch.einsum("ij,kj->ik", m1, m1)  # (2x3, 2x3) -> 2x2, i.e. m1 @ m1.T
mm

tensor([[14, 14],
        [14, 14]])

In [82]:
# Dot product of the first row with itself: the shared index i is summed
# and nothing is kept in the output.
m1 = torch.tensor([[1, 2, 3], [1, 2, 3]])  # 2x3 matrix
first_row = m1[0]
d1 = torch.einsum("i,i->", first_row, first_row)
print(d1)

tensor(14)


In [83]:
# "Dot product" of two matrices, i.e. the Frobenius inner product:
# sum over all i, j of m1[i, j] * m1[i, j] (here, the sum of squares).
m1 = torch.tensor([[1, 2, 3], [1, 2, 3]])  # 2x3 matrix
d2 = torch.einsum("ij,ij->", m1, m1)
print(d2)

tensor(28)


In [85]:
# Element-wise (Hadamard) product: both subscripts are kept, so nothing is summed.
m1 = torch.tensor([[1, 2, 3], [1, 2, 3]])  # 2x3 matrix
h1 = torch.einsum("ij,ij->ij", m1, m1)  # same as m1 * m1
print(h1)


tensor([[1, 4, 9],
        [1, 4, 9]])


In [88]:
# Outer product: i and j are both kept and the operands share no index, so
# every element of va is multiplied by every element of vb.
va = torch.tensor([1, 2, 3])  # 3-element vector
vb = torch.tensor([1, 2, 3, 4, 5])  # 5-element vector
# A dedicated name, so the matrix-vector result `mv` computed earlier in the
# notebook is not silently overwritten by an unrelated 3x5 matrix.
outer_prod = torch.einsum("i,j->ij", va, vb)  # (3,), (5,) -> 3x5 outer product matrix
print(outer_prod)

tensor([[ 1,  2,  3,  4,  5],
        [ 2,  4,  6,  8, 10],
        [ 3,  6,  9, 12, 15]])


In [91]:
# Batch matrix multiplication: i is the batch index (kept), k is contracted,
# so each of the 3 batch entries is a (2x5) @ (5x3) matrix product.
# A dedicated seeded generator makes the random inputs reproducible on
# Restart & Run All without touching the global RNG state.
gen = torch.Generator().manual_seed(0)
a = torch.rand((3, 2, 5), generator=gen)
b = torch.rand((3, 5, 3), generator=gen)
c = torch.einsum("ijk,ikl->ijl", a, b)
assert torch.allclose(c, torch.bmm(a, b))  # cross-check against the built-in
print(c.shape)

torch.Size([3, 2, 3])


In [93]:
# Matrix diagonal: repeating a subscript ("ii") selects the entries x[i, i].
# Caveat: for the identity used here the diagonal also equals the row sums,
# so this input alone does not distinguish "ii->i" from "ij->i".
x = torch.eye(3, dtype=torch.int64)
xdiag = torch.einsum("ii->i", x)
print(xdiag)

tensor([1, 1, 1])


In [95]:
#matrix trace
x = torch.tensor([[1,0,0],[0,1,0],[0,0,1]])
xtrace = torch.einsum("ii->",x) #just sum all values
print(xtrace)

tensor(3)
