## Contents

* [Definition](#definition)
* [1D: Vectors](#1d)
    * [Identity: *(i)*](#1d_identity)
    * [Binary operations](#1d_binary)
        * [Element-wise product, *(i,i -> i)*](#vec_hadamard)
        * [Inner product, *(i,i -> )*](#vec_inner)
        * [Outer product, *(i,j->ij)*](#vec_outer)
* [2D: Matrices](#2d)
    * [Unary operations](#2d_unary)
        * [Identity: *(ij)*](#2d_identity)
        * [Total sum *('ij -> ')*](#2d_sum)
        * [Sum along axis. *('ij -> i')*](#2d_axis_sum)
        * [Transposition. *('ij -> ji')*](#2d_transpose)
        * [Diagonal. *('ii -> i')*](#2d_diagonal)
        * [Trace. *('ii')*](#2d_trace)
    * [Binary operations](#2d_binary)
        * [Element-wise product. *('ij, ij -> ij')*](#2d_hadamard)
        * [Matrix product, *('ij, jk -> ik')*](#2d_inner)
        * [Expansion. *('ij, kl -> ijk')*](#2d_expansion)
        * [Other composite operations. *('ij, jk -> ij')*](#2d_other)

* [3D](#3d)
    * [Tensor contraction](#3d_contraction)
    * [Batch matrix multiplication](#3d_batch_mat_prod)

#### <div id='refs'>References:</div>
* [Torch einsum](https://pytorch.org/docs/stable/torch.html#torch.einsum)
* [Numpy einsum](https://docs.scipy.org/doc/numpy/reference/generated/numpy.einsum.html)
* [Basic guide to einsum](https://ajcr.net/Basic-guide-to-einsum/)
* [Einstein Summation in Numpy](https://obilaniu6266h16.wordpress.com/page/3/)
* [Einstein Summation, Wolfram MathWorld](https://mathworld.wolfram.com/EinsteinSummation.html)

### <div id='intro'>Brief intro</div>

*Einstein notation*, AKA *Einstein summation* or *einsum*, is a concise way to express matrix and, more generally, tensor operations. It is particularly well suited to tensor algebra, which is why it is widely used in physics. Einstein introduced it to indicate clearly and compactly which tensor indices are summed over: an index repeated in a product implies a sum over that index, so no explicit summation sign is needed.  
[Numpy einsum](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html) was, as far as I know, the first Python implementation of this notation. You can read more [here](https://ajcr.net/Basic-guide-to-einsum/#historical-notes-and-links).  
Note that **Numpy einsum is not a perfect copy of the Einstein notation**: it has its pros and cons. For instance, in explicit mode einsum sums over every index missing from the output, even one that appears only once (as in *'ij->i'*), which the original convention does not do.
[Torch einsum](https://pytorch.org/docs/stable/generated/torch.einsum.html#torch-einsum) is the PyTorch counterpart of the Numpy implementation. It came with perks such as GPU parallelization, which initially made it faster than the Numpy version; as far as I know, Numpy has since optimized its einsum. Torch einsum is not a perfect copy of Numpy einsum either: for example, I noticed that it lacks broadcasting of a scalar over a vector.

For more info, have a look at the [References](#refs) or search online.


### <div id='definition'>Definition</div>

```torch.einsum(equation, *operands) → Tensor```

> This function provides a way of computing multilinear expressions (i.e. sums of products) using the Einstein summation convention.

> The equation is given in terms of lower case letters (indices) to be associated with each dimension of the operands and result. The left hand side lists the operands' dimensions, separated by commas.

* There should be one index letter per tensor dimension.
* The right hand side follows after -> and gives the indices for the output. If the -> and right hand side are omitted, it is implicitly defined as the alphabetically sorted list of all indices appearing exactly once in the left hand side.
* The indices not appearing in the output are summed over after multiplying the operands' entries.
* If an index appears several times for the same operand, a diagonal is taken.
* Ellipses … represent a fixed number of dimensions. If the right hand side is inferred, the ellipsis dimensions are at the beginning of the output.
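
The rules above can be made concrete with a small loop-based sketch in plain Python. `naive_einsum` is a hypothetical helper written for this note, not part of PyTorch or NumPy. It accepts nested lists only, ignores ellipses and broadcasting, and is meant for reading rather than speed.

```python
from itertools import product


def naive_einsum(equation, *operands):
    """Hypothetical, loop-based einsum for nested Python lists.

    Supports explicit ('ij->i') and implicit ('ij') equations and repeated
    indices (diagonals); ellipses and broadcasting are not supported.
    """
    equation = equation.replace(" ", "")
    if "->" in equation:
        lhs, out = equation.split("->")
    else:
        lhs = equation
        letters = lhs.replace(",", "")
        # Implicit output: letters appearing exactly once, sorted alphabetically.
        out = "".join(sorted(c for c in set(letters) if letters.count(c) == 1))
    inputs = lhs.split(",")

    # One index letter per dimension; a repeated letter must have equal sizes.
    sizes = {}
    for labels, op in zip(inputs, operands):
        t = op
        for c in labels:
            if sizes.setdefault(c, len(t)) != len(t):
                raise ValueError(f"size mismatch for index {c!r}")
            t = t[0]

    def entry(op, labels, idx):
        # Read one scalar; a repeated letter reuses the same position (diagonal).
        for c in labels:
            op = op[idx[c]]
        return op

    summed = [c for c in sizes if c not in out]
    values = {}
    for out_pos in product(*(range(sizes[c]) for c in out)):
        idx = dict(zip(out, out_pos))
        total = 0
        # Every index missing from the output is summed over.
        for sum_pos in product(*(range(sizes[c]) for c in summed)):
            idx.update(zip(summed, sum_pos))
            term = 1
            for labels, op in zip(inputs, operands):
                term *= entry(op, labels, idx)  # product of the operands' entries
            total += term
        values[out_pos] = total

    def build(prefix):
        # Rebuild a nested list with one level per output index.
        if len(prefix) == len(out):
            return values[prefix]
        return [build(prefix + (n,)) for n in range(sizes[out[len(prefix)]])]

    return build(())
```

For example, `naive_einsum('i,i', [0, 1, 2, 3], [4, 5, 6, 7])` gives `38`, the same inner product computed with torch further below.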

```python
import torch
```

Quick example:

```python
a = torch.Tensor([0, 1, 2])

B = torch.Tensor([[ 0,  1,  2,  3],
                  [ 4,  5,  6,  7],
                  [ 8,  9, 10, 11]])

print(f"a:\n{a}\n a.shape:{a.shape}")
print(f"B:\n{B}\n B.shape:{B.shape}")
```

```
a:
tensor([0., 1., 2.])
 a.shape:torch.Size([3])
B:
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
 B.shape:torch.Size([3, 4])
```

```python
C = torch.einsum('i,ij->i', a, B)
print(f"\nC:\n{C}\n C.shape: {C.shape}")
```

```

C:
tensor([ 0., 22., 76.])
 C.shape: torch.Size([3])
```

The same result without einsum, as a broadcast product followed by a sum:

```python
C = (a[:, None] * B).sum(dim=1)  # product + sum over j
print(f"\nC:\n{C}\n C.shape:\n{C.shape}")
```

```

C:
tensor([ 0., 22., 76.])
 C.shape:
torch.Size([3])
```

Einsum is essentially a product followed by a sum. The operands are multiplied entry by entry, and every index missing from the output (here *j*) is summed over.
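
To see the "product + sum" concretely, here is the quick example written as explicit Python loops over *i* and *j*. It is a sketch for illustration only: `C_loop` is a name introduced here, and loops like this are far slower than einsum.

```python
import torch

a = torch.Tensor([0, 1, 2])
B = torch.Tensor([[0, 1, 2, 3],
                  [4, 5, 6, 7],
                  [8, 9, 10, 11]])

# 'i,ij->i': i is kept in the output, j is summed over
C_loop = torch.zeros(a.shape[0])
for i in range(a.shape[0]):        # free index: one output entry per i
    for j in range(B.shape[1]):    # summed index: accumulate over j
        C_loop[i] += a[i] * B[i, j]

C = torch.einsum('i,ij->i', a, B)
assert torch.equal(C_loop, C)
```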

_____

### <div id='1d'>1D Tensor: Vector</div>

```python
u = torch.Tensor([0, 1, 2, 3])
```

##### <div id='1d_identity'>Identity: *(i)*</div>

```python
torch.einsum('i', u)
```

```
tensor([0., 1., 2., 3.])
```

#### <div id='1d_binary'>Binary operations</div>

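```python
v = torch.Tensor([4, 5, 6, 7])
```

##### <div id='vec_inner'>Inner product</div> *('i,i -> ')*, assumes two vectors of the same length

Index *i* is repeated and absent from the output, so it is summed over and a scalar is returned.

```python
torch.einsum('i,i', u, v)  # implicit output: i appears twice, so it is summed
```

```
tensor(38.)
```

```python
(u*v).sum()  # what einsum computes: element-wise product, then sum
```

```
tensor(38.)
```

```python
u@v
```

```
tensor(38.)
```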

##### <div id='vec_hadamard'>Element-wise product</div> *('i,i -> i')*, assumes two vectors of the same length

Keeping *i* in the output suppresses the sum.

```python
torch.einsum('i,i->i', u, v)
```

```
tensor([ 0.,  5., 12., 21.])
```

```python
u*v
```

```
tensor([ 0.,  5., 12., 21.])
```

##### <div id='vec_outer'>Outer product</div> *('i,j -> ij')*, the two vectors may have different lengths

```python
torch.einsum('i,j -> ij', u, v)
```

```
tensor([[ 0.,  0.,  0.,  0.],
        [ 4.,  5.,  6.,  7.],
        [ 8., 10., 12., 14.],
        [12., 15., 18., 21.]])
```

```python
u[:, None]@v[None, :]  # u as (i, 1) times v as (1, j)
```

```
tensor([[ 0.,  0.,  0.,  0.],
        [ 4.,  5.,  6.,  7.],
        [ 8., 10., 12., 14.],
        [12., 15., 18., 21.]])
```
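
The three 1D binary operations differ only in which indices survive in the output. This sketch checks each against a non-einsum spelling. The names `inner`, `hadamard` and `outer` are introduced here, and `torch.outer` requires a reasonably recent PyTorch (1.8 or later).

```python
import torch

u = torch.Tensor([0, 1, 2, 3])
v = torch.Tensor([4, 5, 6, 7])

inner = torch.einsum('i,i->', u, v)      # i dropped from the output: summed
hadamard = torch.einsum('i,i->i', u, v)  # i kept: no sum
outer = torch.einsum('i,j->ij', u, v)    # two free indices: no sum

# Equivalent spellings without einsum
assert torch.equal(inner, torch.dot(u, v))
assert torch.equal(hadamard, u * v)
assert torch.equal(outer, torch.outer(u, v))
assert torch.equal(outer, u[:, None] * v[None, :])  # broadcasting
```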

### <div id='2d'>2D Tensor: Matrix</div>

```python
A = torch.Tensor([
        [ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]
])
print(f"A.shape:{A.shape}")
```

```
A.shape:torch.Size([3, 4])
```

In what follows, *i* indexes the rows and *j* the columns.

#### <div id='2d_unary'>Unary Operations</div>

##### <div id='2d_identity'>Identity. *('ij')* </div>

```python
torch.einsum('ij', A)
```

```
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
```

```python
torch.einsum('ij -> ij', A)
```

```
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
```

##### <div id='2d_sum'>Total sum *('ij -> ')*</div>

```python
torch.einsum('ij->', A)  # empty output: every index is summed, a scalar is returned
```

```
tensor(66.)
```

```python
A.sum()
```

```
tensor(66.)
```

##### <div id='2d_axis_sum'>Sum along axis. *('ij -> i')*</div>

```python
torch.einsum('ij->j', A)  # i is summed over: one value per column (j) survives
```

```
tensor([12., 15., 18., 21.])
```

```python
A.sum(dim=0)  # sum over the rows (dim 0): one value per column, 4 in total
```

```
tensor([12., 15., 18., 21.])
```

```python
A.sum(dim=0).shape
```

```
torch.Size([4])
```

Now the other axis: summing over *j* leaves one value per row.

```python
torch.einsum('ij->i', A)
```

```
tensor([ 6., 22., 38.])
```

```python
A.sum(dim=1)
```

```
tensor([ 6., 22., 38.])
```
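
In summary, every index dropped from the output is summed and every kept index survives. The sketch below checks the three reductions against `torch.sum`. It assumes `torch.arange(12, ...)` reproduces the values of A used above (it does: 0 to 11 in a 3x4 layout); `cases` is a name introduced here.

```python
import torch

A = torch.arange(12, dtype=torch.float32).reshape(3, 4)  # same values as A above

# Every index dropped from the output is summed; kept indices survive.
cases = {
    'ij->': A.sum(),          # drop i and j: total sum
    'ij->i': A.sum(dim=1),    # drop j: one value per row
    'ij->j': A.sum(dim=0),    # drop i: one value per column
}
for equation, expected in cases.items():
    assert torch.equal(torch.einsum(equation, A), expected), equation
```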

##### <div id='2d_transpose'>Transposition. *('ij -> ji')*</div>

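Without an explicit output, the output indices are the letters appearing once, sorted alphabetically: here *ij*. Labelling the axes of A as *ji* therefore swaps them.

```python
torch.einsum('ji', A)  # inferred output 'ij': the axes are swapped
```

```
tensor([[ 0.,  4.,  8.],
        [ 1.,  5.,  9.],
        [ 2.,  6., 10.],
        [ 3.,  7., 11.]])
```

```python
torch.einsum('ij->ji', A)
```

```
tensor([[ 0.,  4.,  8.],
        [ 1.,  5.,  9.],
        [ 2.,  6., 10.],
        [ 3.,  7., 11.]])
```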
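
The letters themselves are arbitrary. In implicit mode, only their alphabetical order decides the output layout. This short sketch checks that against `A.T`, again rebuilding A with `torch.arange`.

```python
import torch

A = torch.arange(12, dtype=torch.float32).reshape(3, 4)

# 'ji' labels dim 0 as j and dim 1 as i; the inferred output is 'ij',
# so dim 1 of A comes first in the result: a transpose.
assert torch.equal(torch.einsum('ji', A), A.T)

# Any letters work; only their alphabetical order matters.
assert torch.equal(torch.einsum('ba', A), A.T)
assert torch.equal(torch.einsum('ab', A), A)
```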

##### <div id='2d_diagonal'>Diagonal. *('ii -> i')*</div>

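A diagonal requires a square matrix, so we first truncate A to its first three columns.

```python
try:
    torch.einsum('ii->i', A)
except Exception as e:
    print(e)
```

```
size of dimension does not match previous size, operand 0, dim 1
```

```python
torch.einsum('ii->i', A[:, :3])
```

```
tensor([ 0.,  5., 10.])
```

```python
torch.diag(A)  # accepts the non-square A and returns its main diagonal
```

```
tensor([ 0.,  5., 10.])
```

```python
torch.einsum('jj->j', A[:, :3])  # index names are arbitrary
```

```
tensor([ 0.,  5., 10.])
```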

##### <div id='2d_trace'>Trace. *('ii')*</div>

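```python
torch.einsum('ii', A[:, :3])  # sum of the diagonal
```

```
tensor(15.)
```

```python
torch.trace(A)
```

```
tensor(15.)
```

Note the repeated index in ```ii```. Unlike ```ij```, which is the identity for rank-2 tensors, a repeated index takes the diagonal. Since in implicit mode *i* does not appear in the output, the diagonal is also summed.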
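
The two diagonal equations differ only in whether the repeated index is kept. The sketch compares them with `torch.diagonal` and `torch.trace` on the same square 3x3 slice; the names `S`, `diag` and `trace` are introduced here.

```python
import torch

A = torch.arange(12, dtype=torch.float32).reshape(3, 4)
S = A[:, :3]  # square 3x3 slice, as above

diag = torch.einsum('ii->i', S)   # repeated index, kept: the diagonal
trace = torch.einsum('ii->', S)   # repeated index, dropped: diagonal, then sum

assert torch.equal(diag, torch.diagonal(S))
assert torch.equal(trace, torch.trace(S))
assert torch.equal(trace, diag.sum())
```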

#### <div id='2d_binary'>Binary operations</div>

```python
B = torch.Tensor([
        [ 12,  13,  14,  15],
        [ 16,  17,  18,  19],
        [ 20,  21,  22,  23]
])
print(f"B.shape:{B.shape}")
```

```
B.shape:torch.Size([3, 4])
```

##### <div id='2d_hadamard'>Element-wise product. *('ij, ij -> ij')*</div>

Element-wise multiplication of A and B: each entry of A is multiplied by the entry of B in the same position, so both operands must have the same shape.

```python
A.shape, B.shape
```

```
(torch.Size([3, 4]), torch.Size([3, 4]))
```

```python
torch.einsum('ij, ij -> ij', A, B)
```

```
tensor([[  0.,  13.,  28.,  45.],
        [ 64.,  85., 108., 133.],
        [160., 189., 220., 253.]])
```

Explanation.  
Apply the same operation to the first rows only (indexing with *None* keeps them 2D) and compare with the output above: the result equals its first row.

```python
torch.einsum('ij, ij -> ij', A[0, None], B[0, None])
```

```
tensor([[ 0., 13., 28., 45.]])
```

```python
A[0] * B[0]
```

```
tensor([ 0., 13., 28., 45.])
```

Swapping the output indices returns the transposed result.

```python
torch.einsum('ij, ij -> ji', A, B)  # (A * B).T
```

```
tensor([[  0.,  64., 160.],
        [ 13.,  85., 189.],
        [ 28., 108., 220.],
        [ 45., 133., 253.]])
```
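
Going one step further than the cells above, dropping both indices sums the element-wise product into a scalar. That scalar equals the trace of A @ B.T, because the sum over *i* and *j* of A_ij B_ij is the sum of the diagonal entries of A B^T. The sketch checks all three forms, rebuilding A and B with `torch.arange` under the same values.

```python
import torch

A = torch.arange(12, dtype=torch.float32).reshape(3, 4)
B = torch.arange(12, 24, dtype=torch.float32).reshape(3, 4)

assert torch.equal(torch.einsum('ij,ij->ij', A, B), A * B)
assert torch.equal(torch.einsum('ij,ij->ji', A, B), (A * B).T)

# Dropping both indices sums the element-wise product (a scalar) ...
total = torch.einsum('ij,ij->', A, B)
assert torch.equal(total, (A * B).sum())
# ... which equals the trace of A @ B.T
assert torch.equal(total, torch.trace(A @ B.T))
print(total)  # tensor(1298.)
```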

##### <div id='2d_inner'>Matrix product, dot. *('ij, jk -> ik')*</div>

The shared index *j* must have the same size in both operands, so we multiply A by the transpose of B.

```python
print(f"A.shape:{tuple(A.shape)}, B.shape:{tuple(B.shape)}")
```

```
A.shape:(3, 4), B.shape:(3, 4)
```

```python
torch.einsum('ij, jk -> ik', A, B.T)
```

```
tensor([[ 86., 110., 134.],
        [302., 390., 478.],
        [518., 670., 822.]])
```

```python
torch.matmul(A, B.T)  # same as above
```

```
tensor([[ 86., 110., 134.],
        [302., 390., 478.],
        [518., 670., 822.]])
```
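
The matrix product is the textbook case of "multiply, then sum the shared index". The sketch spells *'ij,jk->ik'* out as three nested loops and compares it with einsum and `@`. The names `Bt` and `out` are introduced here, and the loops are for illustration only.

```python
import torch

A = torch.arange(12, dtype=torch.float32).reshape(3, 4)
B = torch.arange(12, 24, dtype=torch.float32).reshape(3, 4)
Bt = B.T  # shape (4, 3)

# 'ij,jk->ik': i and k are free, j is shared and summed over
out = torch.zeros(A.shape[0], Bt.shape[1])
for i in range(A.shape[0]):
    for k in range(Bt.shape[1]):
        for j in range(A.shape[1]):
            out[i, k] += A[i, j] * Bt[j, k]

assert torch.equal(out, torch.einsum('ij,jk->ik', A, Bt))
assert torch.equal(out, A @ Bt)
```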

Generic product *('ij, kl')*: the operands share no index, so their product is an outer product, and every index left out of the output is summed over.

```python
torch.einsum('ij, kl -> ik', A, B).shape
```

```
torch.Size([3, 3])
```

```python
torch.einsum('ij, kl -> il', A, B).shape
```

```
torch.Size([3, 4])
```

```python
torch.einsum('ij, kl -> ij', A, B).shape
```

```
torch.Size([3, 4])
```

Each of the results above is the full [expansion](#2d_expansion) (outer product) of the two matrices, summed over the indices dropped from the output. For example:

```python
a = torch.einsum('ij, kl -> ijkl', A, B)  # expansion, shape (3, 4, 3, 4)
torch.einsum('ijkl -> ij', a)  # sum along k and l
```

```
tensor([[   0.,  210.,  420.,  630.],
        [ 840., 1050., 1260., 1470.],
        [1680., 1890., 2100., 2310.]])
```

```python
torch.einsum('ij, kl -> ij', A, B)
```

```
tensor([[   0.,  210.,  420.,  630.],
        [ 840., 1050., 1260., 1470.],
        [1680., 1890., 2100., 2310.]])
```
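
Because the operands share no index, each dropped index can be summed inside its own operand first. The generic products therefore factor into row or column sums. The sketch checks this with `torch.outer` and broadcasting; it is an illustration of the algebra, not a claim about how torch evaluates einsum internally.

```python
import torch

A = torch.arange(12, dtype=torch.float32).reshape(3, 4)
B = torch.arange(12, 24, dtype=torch.float32).reshape(3, 4)

# With no shared index, the dropped indices can be summed independently first.
assert torch.equal(torch.einsum('ij,kl->ik', A, B),
                   torch.outer(A.sum(dim=1), B.sum(dim=1)))   # drop j and l
assert torch.equal(torch.einsum('ij,kl->il', A, B),
                   torch.outer(A.sum(dim=1), B.sum(dim=0)))   # drop j and k
assert torch.equal(torch.einsum('ij,kl->ij', A, B),
                   A * B.sum())                               # drop k and l
```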

##### <div id='2d_expansion'>Expansion. *('ij, kl -> ijk')*</div>

It broadcasts the product of A with B contracted along *l*, the only index missing from the output.

```python
torch.einsum('ij, kl -> ijk', A, B)
```

```
tensor([[[  0.,   0.,   0.],
         [ 54.,  70.,  86.],
         [108., 140., 172.],
         [162., 210., 258.]],

        [[216., 280., 344.],
         [270., 350., 430.],
         [324., 420., 516.],
         [378., 490., 602.]],

        [[432., 560., 688.],
         [486., 630., 774.],
         [540., 700., 860.],
         [594., 770., 946.]]])
```

```python
print(torch.einsum('ij, kl -> ijk', A, B).shape)
```

```
torch.Size([3, 4, 3])
```

Let's see how it is generated:

```python
# 1. contraction of B along l
B_k = torch.einsum('kl -> k', B)
B_k
```

```
tensor([54., 70., 86.])
```

```python
# 2. broadcast (outer) product between A and B_k
torch.einsum('ij, k -> ijk', A, B_k)
```

```
tensor([[[  0.,   0.,   0.],
         [ 54.,  70.,  86.],
         [108., 140., 172.],
         [162., 210., 258.]],

        [[216., 280., 344.],
         [270., 350., 430.],
         [324., 420., 516.],
         [378., 490., 602.]],

        [[432., 560., 688.],
         [486., 630., 774.],
         [540., 700., 860.],
         [594., 770., 946.]]])
```
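
The same expansion can be written with plain broadcasting. A gets a trailing axis of size 1, and the contracted B gets two leading ones. The sketch below checks that equivalence; `expanded` and `manual` are names introduced here.

```python
import torch

A = torch.arange(12, dtype=torch.float32).reshape(3, 4)
B = torch.arange(12, 24, dtype=torch.float32).reshape(3, 4)

expanded = torch.einsum('ij,kl->ijk', A, B)

# Same thing with explicit broadcasting: (3, 4, 1) * (1, 1, 3) -> (3, 4, 3)
B_k = B.sum(dim=1)  # contraction of B along l
manual = A[:, :, None] * B_k[None, None, :]

assert expanded.shape == (3, 4, 3)
assert torch.equal(expanded, manual)
```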

##### <div id='2d_other'>Other composite operations. *('ij, jk -> ij')*</div>

Notice the surviving axes: the result has the same shape as A.

```python
torch.einsum('ij, jk -> ij', A, B.T)
```

```
tensor([[  0.,  51., 108., 171.],
        [192., 255., 324., 399.],
        [384., 459., 540., 627.]])
```

Unlike the matrix product, here *j* is kept and *k* is summed. Each column *j* of A is therefore scaled by the sum of row *j* of B.T, i.e. by the vector obtained by contracting B.T along *k*, its second index.

```python
torch.einsum('jk', B.T)  # original matrix B.T
```

```
tensor([[12., 16., 20.],
        [13., 17., 21.],
        [14., 18., 22.],
        [15., 19., 23.]])
```

```python
torch.einsum('jk -> j', B.T)  # sum over k: j survives
```

```
tensor([48., 51., 54., 57.])
```

```python
B_T_j = torch.einsum('jk -> j', B.T)
torch.einsum('ij, j -> ij', A, B_T_j)  # column j of A scaled by B_T_j[j]
```

```
tensor([[  0.,  51., 108., 171.],
        [192., 255., 324., 399.],
        [384., 459., 540., 627.]])
```
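
The only difference from the matrix product is which index is dropped. The sketch puts the two equations side by side and checks the composite one against broadcasting. Row sums of B.T are the column sums of B, so both spellings below agree.

```python
import torch

A = torch.arange(12, dtype=torch.float32).reshape(3, 4)
B = torch.arange(12, 24, dtype=torch.float32).reshape(3, 4)

composite = torch.einsum('ij,jk->ij', A, B.T)  # k dropped, j kept
matmul = torch.einsum('ij,jk->ik', A, B.T)     # j dropped, k kept

# Broadcasting a length-4 vector over A's last axis scales its columns.
assert torch.equal(composite, A * B.T.sum(dim=1))
assert torch.equal(composite, A * B.sum(dim=0))
assert composite.shape == A.shape and matmul.shape == (3, 3)
```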

### <div id='3d'>3D Tensor</div>

The same rules extend to higher-rank tensors.

```python
A = torch.Tensor([
    [
        [ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]
    ],
    [
        [ 12,  13,  14,  15],
        [ 16,  17,  18,  19],
        [ 20,  21,  22,  23]
    ]
])
print(f"A.shape:{A.shape}")
```

```
A.shape:torch.Size([2, 3, 4])
```


Remember that in implicit mode (no ->) the output indices are sorted alphabetically, so the order of the letters matters: *'ijk'* is the identity, while *'jki'* returns a permutation of the axes.

##### <div id='3d_identity'>Identity</div>

```python
torch.einsum('ijk', A)
```

```
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])
```

```python
torch.einsum('ijk -> ijk', A)
```

```
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])
```

##### Transpositions

Swap axes *j* and *k*. Without an explicit output, the letters *i*, *k*, *j* are sorted to *ijk*, so the last two axes are transposed.

```python
torch.einsum('ikj', A)
```

```
tensor([[[ 0.,  4.,  8.],
         [ 1.,  5.,  9.],
         [ 2.,  6., 10.],
         [ 3.,  7., 11.]],

        [[12., 16., 20.],
         [13., 17., 21.],
         [14., 18., 22.],
         [15., 19., 23.]]])
```
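
The sketch below checks the implicit and explicit spellings against `transpose` and `permute`. It also shows why *'jki'* is not the identity: its inferred output *'ijk'* reads the input axes in the order (2, 0, 1).

```python
import torch

A = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)  # same values as A above

# implicit: labels i, k, j -> inferred output 'ijk' -> axes 1 and 2 swapped
assert torch.equal(torch.einsum('ikj', A), A.transpose(1, 2))
# explicit spelling of the same permutation
assert torch.equal(torch.einsum('ijk->ikj', A), A.permute(0, 2, 1))
# 'jki': j = axis 0, k = axis 1, i = axis 2; output 'ijk' -> axes (2, 0, 1)
assert torch.einsum('jki', A).shape == (4, 2, 3)
assert torch.equal(torch.einsum('jki', A), A.permute(2, 0, 1))
```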

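##### Sum over axes. *('ijk -> i')*

Summing over *j* and *k* leaves one value per slice along *i*: the sum of the first 3x4 slice (66) and of the second (210).

```python
torch.einsum('ijk->i', A)
```

```
tensor([ 66., 210.])
```

```python
torch.einsum('ijk ->i', A)  # the extra space does not change the result
```

```
tensor([ 66., 210.])
```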

##### <div id='3d_contraction'>Tensor contraction</div>

A product followed by a sum over the shared axes: here *i* and *j* appear in both operands but not in the output, so they are summed, while *k* and *l* survive.

```python
B = torch.Tensor([
    [
        [ 10,  11,  12,  13],
        [ 14,  15,  16,  17],
        [ 18,  19,  20,  21]
    ],
    [
        [ 32,  33,  34,  35],
        [ 36,  37,  38,  39],
        [ 40,  41,  42,  43]
    ]
])
print(f"B.shape:{B.shape}")
```

```
B.shape:torch.Size([2, 3, 4])
```

```python
print(f"A.shape:{list(A.shape)}, B.shape:{list(B.shape)}")
```

```
A.shape:[2, 3, 4], B.shape:[2, 3, 4]
```

```python
torch.einsum('ijk, ijl->kl', A, B)
```

```
tensor([[1960., 2020., 2080., 2140.],
        [2110., 2176., 2242., 2308.],
        [2260., 2332., 2404., 2476.],
        [2410., 2488., 2566., 2644.]])
```
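
The same contraction can be expressed with `torch.tensordot`, by pairing axes (0, 1) of A with axes (0, 1) of B. It can also be written as a plain matrix product after flattening *i* and *j* into one axis of length 6. The sketch checks both against einsum, rebuilding A and B with the values used above.

```python
import torch

A = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
B = torch.tensor([[[10, 11, 12, 13], [14, 15, 16, 17], [18, 19, 20, 21]],
                  [[32, 33, 34, 35], [36, 37, 38, 39], [40, 41, 42, 43]]],
                 dtype=torch.float32)

C = torch.einsum('ijk,ijl->kl', A, B)  # sum over the shared i and j

# Same contraction with tensordot: pair A's axes (0, 1) with B's axes (0, 1)
assert torch.equal(C, torch.tensordot(A, B, dims=([0, 1], [0, 1])))
# ... or by flattening (i, j) into one axis and using a matrix product
assert torch.equal(C, A.reshape(6, 4).T @ B.reshape(6, 4))
```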

##### <div id='3d_batch_mat_prod'>Batch matrix multiplication</div>

```python
A = torch.arange(8*5).reshape(8, 5)
print(f"A.shape:{list(A.shape)}")
A
```

```
A.shape:[8, 5]

tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]])
```

```python
B = torch.arange(5*3).reshape(5, 3)
print(f"B.shape:{list(B.shape)}")
B
```

```
B.shape:[5, 3]

tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11],
        [12, 13, 14]])
```

First, a plain matrix multiplication between A, of shape (8, 5), and B, of shape (5, 3).

```python
A@B
```

```
tensor([[  90,  100,  110],
        [ 240,  275,  310],
        [ 390,  450,  510],
        [ 540,  625,  710],
        [ 690,  800,  910],
        [ 840,  975, 1110],
        [ 990, 1150, 1310],
        [1140, 1325, 1510]])
```

Now split A into 4 chunks of 2 rows each and compute the same multiplication, batched.

```python
A_batched = A.reshape(4, 2, 5)  # 4 batches of (2, 5) matrices
A_batched
```

```
tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9]],

        [[10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]],

        [[20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29]],

        [[30, 31, 32, 33, 34],
         [35, 36, 37, 38, 39]]])
```

```python
print(f"A_batched.shape:{list(A_batched.shape)}, B.shape:{list(B.shape)}")
```

```
A_batched.shape:[4, 2, 5], B.shape:[5, 3]
```

Here *i* indexes the batch, *j* the row within it, and the shared *k* is summed. The same B is applied to every batch, so the rows of the result match A@B above.

```python
torch.einsum('ijk, ku->iju', A_batched, B)
```

```
tensor([[[  90,  100,  110],
         [ 240,  275,  310]],

        [[ 390,  450,  510],
         [ 540,  625,  710]],

        [[ 690,  800,  910],
         [ 840,  975, 1110]],

        [[ 990, 1150, 1310],
         [1140, 1325, 1510]]])
```
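
In the example above a single B is shared by all batches. When each batch element has its own right-hand matrix, the batch index appears in both operands, as in *'bij,bjk->bik'*, which is what `torch.bmm` computes. The sketch checks both cases. It uses float tensors so every backend supports the products, and `B_batched`, `shared` and `per_batch` are names introduced here.

```python
import torch

A = torch.arange(8 * 5, dtype=torch.float32).reshape(8, 5)
B = torch.arange(5 * 3, dtype=torch.float32).reshape(5, 3)
A_batched = A.reshape(4, 2, 5)

# One B shared by every batch: einsum, broadcasting matmul and the plain product agree
shared = torch.einsum('ijk,ku->iju', A_batched, B)
assert torch.equal(shared, A_batched @ B)
assert torch.equal(shared.reshape(8, 3), A @ B)

# A different right-hand matrix per batch: b appears in both operands and in the output
B_batched = torch.arange(4 * 5 * 3, dtype=torch.float32).reshape(4, 5, 3)
per_batch = torch.einsum('bij,bjk->bik', A_batched, B_batched)
assert torch.equal(per_batch, torch.bmm(A_batched, B_batched))
```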