$\color{brown}{Preamble}$: At the time of this writing, I'm using **PyTorch** **`v1.7.1`** built with **`cuda11.0`** and **`cudnn8.0`**.

In [None]:
import numpy as np
import torch

In [None]:
print("version: ", torch.__version__)
mydevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device : ", mydevice)

## Einstein Summation (einsum)

Einsum is a powerful notation for expressing tensor operations in very succinct code. The reasons to adopt **_einsum_** are:

  - the code is usually a one-liner
  - it is often memory-efficient
  - it is less error-prone
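Here is a small sketch of the one-liner point: three common operations written with einsum and checked against their usual PyTorch counterparts. The matrices `a`, `b`, and `m` are made-up illustrations and are not part of the inputs used later.

```python
import torch

# made-up example matrices (not the `aten`/`bten` inputs defined below)
a = torch.arange(6).reshape(2, 3)   # shape (2, 3)
b = torch.arange(12).reshape(4, 3)  # shape (4, 3)

# a @ b^T without an explicit transpose: the shared label `j` is summed over
c = torch.einsum('ij,kj->ik', a, b)
assert torch.equal(c, a @ b.t())

# trace of a square matrix: a repeated label on one operand selects the diagonal,
# and leaving it out of the output sums those diagonal entries
m = torch.arange(9).reshape(3, 3)
tr = torch.einsum('ii->', m)
assert tr.item() == torch.trace(m).item()

# row sums: dropping `j` from the output sums it away
row_sums = torch.einsum('ij->i', a)
assert torch.equal(row_sums, a.sum(dim=1))
```

The rule used in all three lines is the same: any label that does not appear after `->` is summed over.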

Let us now work through some sample problems to see the power of einsum in practice.

#### **inputs**

In [None]:
# some input tensors to work with

vec = torch.tensor([0, 1, 2, 3])

aten = torch.tensor([[11, 12, 13, 14],
                     [21, 22, 23, 24],
                     [31, 32, 33, 34],
                     [41, 42, 43, 44]])

bten = torch.tensor([[1, 1, 1, 1],
                     [2, 2, 2, 2],
                     [3, 3, 3, 3],
                     [4, 4, 4, 4]])

-------------
### matrix multiplication

  $\textbf{C}_{ik} = \sum_{j} \textbf{A}_{i\color{green}{j}} \, \textbf{B}_{\color{green}{j}k}$

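For matrix multiplication to be defined, the number of columns of the first matrix (e.g., $A$) must equal the number of rows of the second matrix (e.g., $B$). In einsum terms, the shared index $j$ must have the same size in both operands; because it does not appear in the output, it is summed over.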
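To see what the subscript string encodes, the summation above can be spelled out with explicit loops. The helper `matmul_loops` below is a hypothetical, deliberately slow reference written only for illustration; it redefines `aten` and `bten` with the same values as the inputs above so the cell runs on its own.

```python
import torch

def matmul_loops(A, B):
    """Naive triple loop spelling out C_ik = sum_j A_ij * B_jk (illustration only, slow)."""
    n_i, n_j = A.shape
    n_j2, n_k = B.shape
    # the summed label `j` must have the same size in both operands
    assert n_j == n_j2, f"size mismatch for j: {n_j} vs {n_j2}"
    C = torch.zeros(n_i, n_k, dtype=A.dtype)
    for i in range(n_i):          # free index: kept in the output
        for k in range(n_k):      # free index: kept in the output
            for j in range(n_j):  # repeated index absent from the output: summed
                C[i, k] += A[i, j] * B[j, k]
    return C

# same values as the `aten` / `bten` inputs defined earlier in this notebook
aten = torch.tensor([[11, 12, 13, 14], [21, 22, 23, 24], [31, 32, 33, 34], [41, 42, 43, 44]])
bten = torch.tensor([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]])

C = matmul_loops(aten, bten)
assert torch.equal(C, torch.einsum('ij, jk -> ik', aten, bten))
print(C)
```

Since every column of `bten` is $(1, 2, 3, 4)^T$, each row of the result is constant, e.g. $C_{1k} = 11 \cdot 1 + 12 \cdot 2 + 13 \cdot 3 + 14 \cdot 4 = 130$.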

In [None]:
c = torch.einsum('ij, jk -> ik', aten, bten)
print("einsum matmul: \n", c)

# sanity check
c = torch.matmul(aten, bten)   # or: aten.mm(bten)
print("torch matmul: \n", c)

------------------
### Hadamard product (*i.e.,* element-wise product of tensors)
$\textbf{C}_{ij} = \textbf{A1}_{ij} \cdot \textbf{A2}_{ij} \cdot \ldots \cdot \textbf{AN}_{ij}$

In [None]:
hp = torch.einsum('ij, ij, i -> ij', aten, bten, vec)     # note: `vec` is treated as a column vector
print("einsum hadamard product: \n", hp)

# sanity check
ep = aten * bten * vec[:, None]
print("element-wise product: \n", ep)
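Whether `vec` acts as a column or a row vector depends only on the label it carries in the subscript string. The sketch below, which reuses the same values for `aten` and `vec`, contrasts the two choices and checks each against explicit broadcasting.

```python
import torch

vec = torch.tensor([0, 1, 2, 3])
aten = torch.tensor([[11, 12, 13, 14], [21, 22, 23, 24], [31, 32, 33, 34], [41, 42, 43, 44]])

# label `i` on vec: vec[i] multiplies every entry of row i (column-vector behaviour)
by_rows = torch.einsum('ij, i -> ij', aten, vec)
assert torch.equal(by_rows, aten * vec[:, None])

# label `j` on vec: vec[j] multiplies every entry of column j (row-vector behaviour)
by_cols = torch.einsum('ij, j -> ij', aten, vec)
assert torch.equal(by_cols, aten * vec[None, :])
```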

$\color{brown}{Note}$: we can raise each element of a tensor to the power `n` by passing the tensor `n` times as an operand, with the same subscripts each time. For instance, a tensor can be *cubed* by passing it three times.

In [None]:
hp = torch.einsum('ij, ij, ij -> ij', bten, bten, bten)
print("einsum hadamard product: \n", hp)

# sanity check
ep = bten * bten * bten
print("element-wise product: \n", ep)
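The same idea extends to any positive integer power by building the subscript string programmatically. `elementwise_power` is a hypothetical helper introduced only for this illustration, and it assumes a 2-D input.

```python
import torch

def elementwise_power(t, n):
    """Raise every element of a 2-D tensor `t` to the integer power n >= 1 via einsum.

    Hypothetical helper: builds the subscript string 'ij,ij,...,ij->ij' with n operands.
    """
    if n < 1:
        raise ValueError("n must be a positive integer")
    equation = ','.join(['ij'] * n) + '->ij'
    return torch.einsum(equation, *([t] * n))

bten = torch.tensor([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]])
cubed = elementwise_power(bten, 3)
assert torch.equal(cubed, bten ** 3)
```

For large `n`, `t ** n` (i.e., `torch.pow`) is the more direct choice; the einsum form mainly demonstrates how repeated operands combine.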