# Einstein summation convention

+ TORCH.EINSUM  
https://pytorch.org/docs/stable/generated/torch.einsum.html

+ A brief summary of Einsum
https://ita9naiwa.github.io/numeric%20calculation/2018/11/10/Einsum.html  
  
+ EINSUM IS ALL YOU NEED - EINSTEIN SUMMATION IN DEEP LEARNING  
https://rockt.github.io/2018/04/30/einsum  

+ Einstein Summation in Numpy  
https://obilaniu6266h16.wordpress.com/2016/02/04/einstein-summation-in-numpy/

```
Vector inner product: "a,a->" (Assumes two vectors of same length)
Vector element-wise product: "a,a->a" (Assumes two vectors of same length)
Vector outer product: "a,b->ab" (Vectors not necessarily same length.)
Matrix transposition: "ab->ba"
Matrix diagonal: "ii->i"
Matrix trace: "ii->"
1-D Sum: "a->"
2-D Sum: "ab->"
3-D Sum: "abc->"
Matrix inner product: "ab,ab->" (Passing the same matrix twice gives the sum of its squared entries, i.e. the squared Frobenius (L2) norm)
Left-multiplication Matrix-Vector: "ab,b->a"
Right-multiplication Vector-Matrix: "a,ab->b"
Matrix Multiply: "ab,bc->ac"
Batch Matrix Multiply: "Yab,Ybc->Yac"
Quadratic form / squared Mahalanobis distance: "a,ab,b->" (Mahalanobis when both vectors are x - mu and the matrix is the inverse covariance)
```
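The note on "ab,ab->" can be checked directly. Passing the same matrix twice multiplies every entry by itself and sums the results, which gives the squared Frobenius norm. This is a minimal sketch; the matrix values are arbitrary.

```python
import torch

M = torch.arange(6.0).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]

# "ab,ab->" with the same matrix twice: sum of squared entries
sq_fro = torch.einsum("ab,ab->", M, M)
print(sq_fro)   # tensor(55.)

# With no ord/dim, torch.linalg.norm flattens the input and returns its 2-norm,
# which for a matrix is the Frobenius norm.
print(torch.allclose(sq_fro, torch.linalg.norm(M) ** 2))   # True
```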

In [1]:
import torch



+ Transpose

In [2]:
x = torch.rand(2,2)
print(x)
y =  torch.einsum("ij->ji",x)
print(y)

tensor([[0.9050, 0.0506],
        [0.5148, 0.7061]])
tensor([[0.9050, 0.5148],
        [0.0506, 0.7061]])
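"ij->ji" matches the built-in transpose. The same notation permutes any number of axes by listing the output letters in the new order. This sketch uses an arbitrary `arange` tensor for readability.

```python
import torch

x = torch.arange(6).reshape(2, 3)
y = torch.einsum("ij->ji", x)
print(torch.equal(y, x.T))                   # True

# "ijk->kji" reverses all three axes, like permute(2, 1, 0)
t = torch.arange(24).reshape(2, 3, 4)
z = torch.einsum("ijk->kji", t)
print(z.shape)                               # torch.Size([4, 3, 2])
print(torch.equal(z, t.permute(2, 1, 0)))    # True
```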


+ Diagonal

In [3]:
x = torch.rand(2,2)
print(x)
y =  torch.einsum("ii->i",x)
print(y)

tensor([[0.3799, 0.6717],
        [0.7511, 0.1676]])
tensor([0.3799, 0.1676])
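Repeating an index in one operand selects its diagonal, so "ii->i" agrees with `torch.diagonal`. With a leading batch letter, the same pattern takes the diagonal of every matrix in a stack. The example values below are arbitrary.

```python
import torch

x = torch.arange(9).reshape(3, 3)    # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
diag = torch.einsum("ii->i", x)
print(diag)                                   # tensor([0, 4, 8])
print(torch.equal(diag, torch.diagonal(x)))   # True

# "bii->bi": one diagonal per matrix in the batch
xb = torch.arange(18).reshape(2, 3, 3)
d = torch.einsum("bii->bi", xb)
print(torch.equal(d, torch.diagonal(xb, dim1=-2, dim2=-1)))   # True
```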


+ Trace

In [4]:
x = torch.rand(2,2)
print(x)
y =  torch.einsum("ii->",x)
print(y)

tensor([[0.6582, 0.7292],
        [0.4265, 0.0500]])
tensor(0.7082)
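Dropping the repeated index from the output sums the diagonal, so "ii->" is the trace. `torch.trace` is documented for a single 2-D matrix. Adding a batch letter ("bii->b") gives one trace per matrix without a loop. Values are arbitrary.

```python
import torch

x = torch.arange(9.0).reshape(3, 3)
tr = torch.einsum("ii->", x)
print(tr, torch.trace(x))    # tensor(12.) tensor(12.)

# one trace per matrix in a batch of two 3x3 matrices
xb = torch.arange(18.0).reshape(2, 3, 3)
tb = torch.einsum("bii->b", xb)
print(tb)                    # tensor([12., 39.])
```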


+ Matrix sum to scalar

In [6]:
x = torch.rand(2,2)
print(x)
y =  torch.einsum("ab->",x)
print(y)

tensor([[0.4424, 0.4060],
        [0.5191, 0.4222]])
tensor(1.7895)
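Any letter missing from the output is summed over. "ab->" is therefore `x.sum()`, and for higher-rank tensors any subset of axes can be reduced at once. This is a small sketch with arbitrary values.

```python
import torch

x = torch.arange(4.0).reshape(2, 2)            # [[0, 1], [2, 3]]
print(torch.einsum("ab->", x), x.sum())        # tensor(6.) tensor(6.)

# "abc->b" keeps only the middle axis, summing over a and c
t = torch.arange(24.0).reshape(2, 3, 4)
s = torch.einsum("abc->b", t)
print(torch.equal(s, t.sum(dim=(0, 2))))       # True
```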


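+ Matrix row sums or column sums (to vector). In the run below both results are [5, 7] only because the random x happened to be symmetric.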

In [14]:
x = torch.randint(10, (2, 2))
print(x)
y = torch.einsum("ab->a",x)  # sum over b: one sum per row
print(y)
y = torch.einsum("ab->b",x)  # sum over a: one sum per column
print(y)

tensor([[2, 3],
        [3, 4]])
tensor([5, 7])
tensor([5, 7])
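With a non-symmetric matrix the two reductions differ. The kept letter decides which axis survives. The example matrix is arbitrary.

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
row_sums = torch.einsum("ab->a", x)   # sum over b (columns): one value per row
col_sums = torch.einsum("ab->b", x)   # sum over a (rows): one value per column
print(row_sums)   # tensor([ 6, 15])
print(col_sums)   # tensor([5, 7, 9])
```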


+ Dot product and outer product of two vectors

In [21]:
a = torch.randint(10, (2,))
b = torch.randint(10, (2,))
print(a)
print(b)
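y = torch.einsum("i,i->",a,b)    # dot product: the shared index i is multiplied and summed
print(y)
y = torch.einsum("a,b->ab",a,b)  # outer product: distinct letters keep both axes
print(y)

tensor([2, 7])
tensor([6, 6])
tensor(54)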
tensor([[12, 12],
        [42, 42]])
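A common slip is writing the dot product with two different letters. "a,b->" first forms the full outer product and then sums every entry, so it equals sum(a) * sum(b) rather than the dot product. The vectors below reuse the values printed above, stored as floats.

```python
import torch

a = torch.tensor([2.0, 7.0])
b = torch.tensor([6.0, 6.0])

dot = torch.einsum("i,i->", a, b)       # same letter: element-wise multiply, then sum
print(dot, torch.dot(a, b))             # tensor(54.) tensor(54.)

total = torch.einsum("a,b->", a, b)     # different letters: sum of the whole outer product
print(total, a.sum() * b.sum())         # tensor(108.) tensor(108.)

outer = torch.einsum("a,b->ab", a, b)
print(torch.equal(outer, torch.outer(a, b)))   # True
```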


+ Hadamard (element-wise) product of vectors, and the element-wise product of two matrices summed to a scalar (matrix inner product)

In [23]:
a = torch.randint(10, (2,))
b = torch.randint(10, (2,))
print(a)
print(b)
y =  torch.einsum("i,i->i",a,b)
print(y)
print("-----------------")
a = torch.randint(10, (2,2))
b = torch.randint(10, (2,2))
print(a)
print(b)
y = torch.einsum("ij,ij->",a,b)  # element-wise product, then sum over both i and j
print(y)

tensor([8, 6])
tensor([0, 4])
tensor([ 0, 24])
-----------------
tensor([[3, 3],
        [1, 4]])
tensor([[2, 2],
        [6, 1]])
tensor(22)
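To keep the matrix Hadamard product instead of its sum, list both letters in the output. "ij,ij->ij" is then the same as `a * b`. The matrices below reuse the values printed above.

```python
import torch

a = torch.tensor([[3, 3], [1, 4]])
b = torch.tensor([[2, 2], [6, 1]])

had = torch.einsum("ij,ij->ij", a, b)    # keep i and j: element-wise product
print(torch.equal(had, a * b))           # True

inner = torch.einsum("ij,ij->", a, b)    # drop i and j: sum of the element-wise product
print(inner, (a * b).sum())              # tensor(22) tensor(22)
```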


+ Matrix-Vector multiplication

In [25]:
a = torch.randint(10, (2,2))
b = torch.randint(10, (2,))
print(a)
print(b)
y =  torch.einsum("ij,j->i",a,b)
print(y)

tensor([[5, 9],
        [8, 0]])
tensor([2, 7])
tensor([73, 16])
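"ij,j->i" is the usual matrix-vector product (`torch.mv` or `@`). Contracting over the first index instead, "i,ij->j", gives the vector-matrix product v^T A. The values reuse the run above, stored as floats.

```python
import torch

A = torch.tensor([[5.0, 9.0], [8.0, 0.0]])
v = torch.tensor([2.0, 7.0])

Av = torch.einsum("ij,j->i", A, v)
print(Av, torch.mv(A, v))     # tensor([73., 16.]) tensor([73., 16.])

vA = torch.einsum("i,ij->j", v, A)
print(vA, v @ A)              # tensor([66., 18.]) tensor([66., 18.])
```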


+ Matrix-matrix multiplication and batched matrix multiplication

In [29]:
a = torch.randint(10, (3,2))
b = torch.randint(10, (2,3))
print(a)
print(b)
y =  torch.einsum("ik,kj->ij",a,b)
print(y)

print("-------------------")

a = torch.randint(10, (2,3,2))
b = torch.randint(10, (2,2,3))

print(a)
print(b)
y =  torch.einsum("bik,bkj->bij",a,b)
print(y)

tensor([[1, 6],
        [0, 3],
        [7, 4]])
tensor([[3, 5, 6],
        [3, 6, 5]])
tensor([[21, 41, 36],
        [ 9, 18, 15],
        [33, 59, 62]])
-------------------
tensor([[[4, 6],
         [5, 8],
         [5, 2]],

        [[9, 1],
         [2, 2],
         [7, 8]]])
tensor([[[2, 3, 7],
         [2, 9, 7]],

        [[9, 4, 1],
         [2, 8, 0]]])
tensor([[[20, 66, 70],
         [26, 87, 91],
         [14, 33, 49]],

        [[83, 44,  9],
         [22, 24,  2],
         [79, 92,  7]]])
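The batched pattern "bik,bkj->bij" matches `torch.bmm`. If the batch letter is left out of one operand, einsum reuses that single matrix for every batch element. This sketch uses small integer values stored as floats, so every product and sum is exact and `torch.equal` can be used.

```python
import torch

torch.manual_seed(0)
a = torch.randint(10, (2, 3, 4)).float()   # batch of 2 matrices, each 3x4
b = torch.randint(10, (2, 4, 5)).float()   # batch of 2 matrices, each 4x5
w = torch.randint(10, (4, 5)).float()      # one 4x5 matrix shared across the batch

y = torch.einsum("bik,bkj->bij", a, b)
print(torch.equal(y, torch.bmm(a, b)))     # True

# no b in the second operand: w is applied to every batch element
z = torch.einsum("bik,kj->bij", a, w)
print(z.shape)                             # torch.Size([2, 3, 5])
print(torch.equal(z, a @ w))               # True
```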


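+ Quadratic (bilinear) form a^T X b. With a = b and a positive-definite X, it is the squared norm induced by X. With a = b = x - mu and X the inverse covariance matrix, it is the squared Mahalanobis distance.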

In [35]:
a = torch.randint(10, (2,))
b = torch.randint(10, (2,))
X = torch.randint(10, (2,2))
print(a)
print(b)
print(X)
y = torch.einsum("i,ij,j->",a,X,b)   # a^T X b (scalar)
print(y)
y = torch.einsum("i,ij,j->i",a,X,b)  # a_i * (X b)_i, i is kept, not summed
print(y)
y = torch.einsum("i,ij,j->j",a,X,b)  # (a^T X)_j * b_j, j is kept, not summed
print(y)

tensor([7, 4])
tensor([8, 4])
tensor([[8, 3],
        [4, 5]])
tensor(740)
tensor([532, 208])
tensor([576, 164])
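A minimal sketch of the squared Mahalanobis distance. It assumes a made-up mean `mu` and a diagonal covariance `cov`, so the inverse is easy to check by hand. The distance itself is the square root of the quadratic form with the inverse covariance. Adding a batch letter evaluates many points at once.

```python
import torch

mu = torch.tensor([0.0, 0.0])             # assumed mean
cov = torch.tensor([[4.0, 0.0],
                    [0.0, 1.0]])          # assumed covariance
S_inv = torch.linalg.inv(cov)

x = torch.tensor([3.0, 4.0])
d = x - mu
m2 = torch.einsum("i,ij,j->", d, S_inv, d)   # 3**2/4 + 4**2/1 = 18.25
print(torch.sqrt(m2))                        # Mahalanobis distance

# identity covariance reduces to the squared Euclidean distance (25 here)
e2 = torch.einsum("i,ij,j->", d, torch.eye(2), d)
print(torch.sqrt(e2))                        # tensor(5.)

# one squared distance per row of pts
pts = torch.tensor([[3.0, 4.0], [2.0, 0.0]])
m2_batch = torch.einsum("bi,ij,bj->b", pts - mu, S_inv, pts - mu)
print(m2_batch)                              # squared distances 18.25 and 1.0
```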


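+ Matrix-vector product for every f, broadcast over c: out[c,f,t] = sum over y of a[f,t,y] * b[c,f,y]

In [40]: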
a = torch.randint(10, (4,3,3))
b = -torch.randint(10, (2,4,3))
print(a)
print(b)
y = torch.einsum('fty,cfy->cft', a, b)  # a[f] (t x y) times b[c,f] (length y), summed over y
print(y)

tensor([[[9, 8, 2],
         [9, 4, 3],
         [7, 8, 9]],

        [[7, 1, 2],
         [5, 6, 2],
         [4, 8, 6]],

        [[9, 4, 0],
         [9, 3, 5],
         [0, 5, 0]],

        [[1, 8, 2],
         [3, 6, 0],
         [6, 5, 0]]])
tensor([[[ 0, -4,  0],
         [-6, -3, -5],
         [-4, -1, -8],
         [-5, -1, -2]],

        [[-4, -6, -7],
         [-6, -7, -6],
         [-1, -4, -1],
         [-6, -7, -1]]])
tensor([[[ -32,  -16,  -32],
         [ -55,  -58,  -78],
         [ -40,  -79,   -5],
         [ -17,  -21,  -35]],

        [[ -98,  -81, -139],
         [ -61,  -84, -116],
         [ -25,  -26,  -20],
         [ -64,  -60,  -71]]])
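The same contraction can be written as a broadcasting `torch.matmul`. Each b[c, f] is treated as a (y, 1) column and multiplied by a[f] of shape (t, y). The leading c axis broadcasts against a. This sketch reuses the shapes above with integer values stored as floats, so both paths give exactly the same numbers.

```python
import torch

torch.manual_seed(0)
a = torch.randint(10, (4, 3, 3)).float()    # shape (f, t, y)
b = -torch.randint(10, (2, 4, 3)).float()   # shape (c, f, y)

y = torch.einsum("fty,cfy->cft", a, b)

# (4, 3, 3) @ (2, 4, 3, 1) broadcasts to (2, 4, 3, 1); drop the trailing 1
y_mm = torch.matmul(a, b.unsqueeze(-1)).squeeze(-1)
print(y.shape)                 # torch.Size([2, 4, 3])
print(torch.equal(y, y_mm))    # True
```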


## Batched filtering

In [2]:
import torch



In [14]:
B = 2
C = 4
F = 5
T = 6

X = torch.rand(B,C,F,T)
h = torch.rand(B,C,F,1)

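# sum over c (channels) and l (size-1 axis): Y[b,f,t] = sum_c X[b,c,f,t] * h[b,c,f,0]
Y = torch.einsum('bcft,bcfl->bft', X, h)
print(X.shape)
print(h.shape)
print(Y.shape)

print(X[0,:,0,1])  # all channels of X at b=0, f=0, t=1
print(h[0,:,0,0])  # all channel weights at b=0, f=0
print(Y[0,0,1])    # einsum result at b=0, f=0, t=1

# spot check: dot product over channels reproduces Y[0,0,1]
print(torch.matmul(X[0,:,0,1],h[0,:,0,0]))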

torch.Size([2, 4, 5, 6])
torch.Size([2, 4, 5, 1])
torch.Size([2, 5, 6])
tensor([0.7124, 0.1091, 0.6574, 0.8030])
tensor([0.0092, 0.4089, 0.9227, 0.3968])
tensor(0.9763)
tensor(0.9763)
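The filtering above is a weighted sum across channels, with one weight per (batch, channel, frequency). Reading the axes as batch, channel, frequency, time is an assumption based on the names B, C, F, T. Because h has a size-1 last axis, the same result also comes from broadcasting followed by a sum over the channel dimension.

```python
import torch

B, C, F, T = 2, 4, 5, 6
X = torch.rand(B, C, F, T)
h = torch.rand(B, C, F, 1)   # one weight per (b, c, f), broadcast over t

Y = torch.einsum("bcft,bcfl->bft", X, h)

# broadcast multiply, then sum over the channel axis (dim=1)
Y_bc = (X * h).sum(dim=1)
print(Y.shape)                    # torch.Size([2, 5, 6])
print(torch.allclose(Y, Y_bc))    # True
```

The einsum form spells out which axes are contracted. The broadcast form depends on h keeping its size-1 time axis.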
