
In [None]:
#%tensorflow_version 2.x
import tensorflow as tf

TensorFlow 2.x selected.


In [None]:
import torch
import numpy as np

# EINSUM IS ALL YOU NEED - EINSTEIN SUMMATION IN DEEP LEARNING

– Tim Rocktäschel, 30/04/2018 – updated 02/05/2018

https://rockt.github.io/2018/04/30/einsum

When talking to colleagues I realized that not everyone knows about einsum, my favorite function for developing deep learning models. This post is trying to change that once and for all! :) Einstein summation (einsum) is implemented in numpy, as well as in deep learning libraries such as TensorFlow and, thanks to Thomas Viehmann, recently also PyTorch. For background reading on einsum, I recommend the excellent blog posts by Olexa Bilaniuk and Alex Riley. While their posts discuss einsum in the context of numpy, I am going to illustrate how einsum is extremely useful for writing elegant PyTorch/TensorFlow models.<sup>1</sup>

## 1 EINSUM NOTATION

If you are anything like me, you find it difficult to remember the names and signatures of all the different functions in PyTorch/TensorFlow for calculating dot products, outer products, transposes and matrix-vector or matrix-matrix multiplications. Einsum notation is an elegant way to express all of these, as well as complex operations on tensors, using essentially a domain-specific language. This has benefits beyond not having to memorize or regularly look up specific library functions. Once you understand and make use of einsum, you will be able to write more concise and efficient code more quickly. Without einsum, it is easy to introduce unnecessary reshaping and transposing of tensors, as well as intermediate tensors that could be omitted. Furthermore, domain-specific languages like einsum can sometimes be compiled to high-performing code, and an einsum-like domain-specific language is in fact the basis for the recently introduced Tensor Comprehensions<sup>3</sup> in PyTorch, which automatically generate GPU code and auto-tune that code for specific input sizes. In addition, projects like `opt_einsum` and `tf_einsum_opt` can be used to optimize the tensor contraction order of einsum expressions.<sup>4</sup>
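
As a hedged illustration of why contraction order matters (this uses numpy's own built-in path optimizer, not the `opt_einsum` or `tf_einsum_opt` APIs), `np.einsum_path` reports the pairwise contraction order it would choose for a three-matrix chain, and `optimize=True` lets `np.einsum` use such an order. The shapes are arbitrary assumptions picked so that the order matters a lot: $(AB)C$ needs about 25,000 multiplications here, while $A(BC)$ needs about 900,000.

```python
import numpy as np

rng = np.random.default_rng(0)
# Arbitrary shapes chosen so that the order matters:
# (A B) C costs 10*300*5 + 10*5*200 multiplications, A (B C) costs 300*5*200 + 10*300*200.
A = rng.standard_normal((10, 300))
B = rng.standard_normal((300, 5))
C = rng.standard_normal((5, 200))

# Ask numpy which pairwise contraction order it would use for A @ B @ C.
path, report = np.einsum_path('ij,jk,kl->il', A, B, C, optimize='greedy')
print(path)    # a list starting with 'einsum_path', followed by pairs of operand indices
print(report)  # human-readable summary, including naive vs. optimized FLOP counts

# The same expression evaluated in a single naive pass and with an optimized order.
naive = np.einsum('ij,jk,kl->il', A, B, C)
fast = np.einsum('ij,jk,kl->il', A, B, C, optimize=True)
print(np.allclose(naive, fast))  # True
```

Both calls compute the same tensor; only the amount of intermediate work differs.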

Let's say we want to multiply two matrices ${\color{red}\mathbf{A}}\in\mathbb{R}^{I\,\times\,K}$ and ${\color{blue}\mathbf{B}}\in\mathbb{R}^{K\,\times\,J}$ and then calculate the sum of each column, resulting in a vector ${\color{green}\mathbf{c}}\in\mathbb{R}^{J}$. Using Einstein summation notation, we can write this as

${\color{green}c_j} = \sum_i\sum_k {\color{red}A_{ik}}{\color{blue}B_{kj}} = {\color{red}A_{ik}}{\color{blue}B_{kj}}$

which specifies how each individual element $c_j$ of $\mathbf{c}$ is calculated: take the dot product of each row vector $\mathbf{A}_{i:}$ with the column vector $\mathbf{B}_{:j}$ and sum the results over $i$. Note that in Einstein notation the summation signs ($\Sigma$) can be dropped, as we implicitly sum over repeated indices ($k$ in this example) and over indices not mentioned in the output specification ($i$ in this example). So far so good, but we can also express more basic operations using einsum. For instance, calculating the dot product of two vectors $\mathbf{a},\mathbf{b}\in\mathbb{R}^{I}$ can be written as

${\color{green}c} = \sum_i {\color{red}a_i}{\color{blue}b_i} = {\color{red}a_i}{\color{blue}b_i}.$
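
To make the implicit summation in the two equations above concrete, here is a small numpy sketch (the matrices are arbitrary examples) that computes $c_j$ once with explicit loops over $i$ and $k$ and once with einsum, followed by the dot product. The loops are only there to spell out what the notation means.

```python
import numpy as np

A = np.arange(6).reshape(2, 3)    # I=2, K=3
B = np.arange(12).reshape(3, 4)   # K=3, J=4

# c_j = sum_i sum_k A_ik B_kj, written out with explicit loops.
I, K = A.shape
J = B.shape[1]
c_loop = np.zeros(J, dtype=A.dtype)
for j in range(J):
    for i in range(I):
        for k in range(K):
            c_loop[j] += A[i, k] * B[k, j]

# k is repeated and i is absent from the output, so both are summed over.
c_einsum = np.einsum('ik,kj->j', A, B)
print(c_loop, c_einsum)          # both give the column sums 76, 91, 106, 121
print((A @ B).sum(axis=0))       # matrix product followed by column sums: same values

# Dot product: c = sum_i a_i b_i.
a = np.arange(3)
b = np.arange(3, 6)
print(np.einsum('i,i->', a, b), a @ b)   # 14 14
```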

A problem that I encounter often in deep learning is applying a transformation to vectors in a higher-order tensor. For example, I might have a tensor that contains $T$-long sequences of $K$-dimensional word vectors for $N$ training examples in a batch, and I want to project the word vectors to a different dimension $Q$. Let ${\color{red}\mathcal{T}}\in\mathbb{R}^{N\times T\times K}$ be an order-3 tensor where the first dimension corresponds to the batch, the second dimension to the sequence length, and the last dimension to the word vectors. In addition, let ${\color{blue}\mathbf{W}}\in\mathbb{R}^{K\times Q}$ be a projection matrix. The desired computation can be expressed using einsum:

${\color{green}C_{ntq}} = \sum_k {\color{red}\mathcal{T}_{ntk}}{\color{blue}W_{kq}} = {\color{red}\mathcal{T}_{ntk}}{\color{blue}W_{kq}}.$
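
A minimal numpy sketch of this projection, assuming arbitrary sizes for $N$, $T$, $K$, $Q$ and random data standing in for real word vectors. The einsum string mirrors the index expression directly, while the non-einsum route needs the reshape round trip mentioned earlier.

```python
import numpy as np

rng = np.random.default_rng(0)
N, T, K, Q = 4, 7, 16, 8   # batch, sequence length, word-vector size, target size (arbitrary)
T_tensor = rng.standard_normal((N, T, K))   # stands in for the order-3 tensor of word vectors
W = rng.standard_normal((K, Q))             # projection matrix

# C_ntq = sum_k T_ntk W_kq: k is repeated, so it is summed over.
C = np.einsum('ntk,kq->ntq', T_tensor, W)
print(C.shape)   # (4, 7, 8)

# The same projection without einsum: flatten (n, t), multiply, restore the shape.
C_reshaped = (T_tensor.reshape(N * T, K) @ W).reshape(N, T, Q)
print(np.allclose(C, C_reshaped))   # True
```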

As a final example, say you are given an order-4 tensor ${\color{red}\mathcal{T}}\in\mathbb{R}^{N\times T\times K\times M}$ and you are supposed to project vectors in the 3rd dimension to $Q$ using the projection matrix from before. Suppose I also ask you to sum over the 2nd dimension and to transpose the first and last dimensions in the result, yielding a tensor ${\color{green}\mathbf{C}}\in\mathbb{R}^{M\times Q\times N}$.<sup>5</sup> Einsum to the rescue!

${\color{green}C_{mqn}} = \sum_t\sum_k {\color{red}\mathcal{T}_{ntkm}}{\color{blue}W_{kq}} = {\color{red}\mathcal{T}_{ntkm}}{\color{blue}W_{kq}}.$

Note that transposing the result of the tensor contraction is achieved by swapping $n$ with $m$ ($C_{mqn}$ instead of $C_{nqm}$).
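
The following numpy sketch (arbitrary sizes, random data) checks the single einsum call against the same computation spelled out step by step: sum over $t$, contract $k$ with `np.tensordot`, then reorder the axes. Summing over $t$ first is valid here because $\mathbf{W}$ does not depend on $t$.

```python
import numpy as np

rng = np.random.default_rng(0)
N, T, K, M, Q = 2, 3, 4, 5, 6   # arbitrary sizes
T_tensor = rng.standard_normal((N, T, K, M))
W = rng.standard_normal((K, Q))

# C_mqn = sum_t sum_k T_ntkm W_kq: project k -> q, sum over t, output in (m, q, n) order.
C = np.einsum('ntkm,kq->mqn', T_tensor, W)
print(C.shape)   # (5, 6, 2)

# The same result without einsum.
summed = T_tensor.sum(axis=1)                          # sum over t   -> (N, K, M)
projected = np.tensordot(summed, W, axes=([1], [0]))   # contract k   -> (N, M, Q)
C_manual = projected.transpose(1, 2, 0)                # reorder axes -> (M, Q, N)
print(np.allclose(C, C_manual))   # True
```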

## 2 ALL YOU NEED: EINSUM IN NUMPY, PYTORCH, AND TENSORFLOW

Einsum is implemented in numpy via np.einsum, in PyTorch via torch.einsum, and in TensorFlow via tf.einsum.<sup>6</sup> All three einsum functions share the same signature einsum(equation, operands), where equation is a string representing the Einstein summation and operands is a sequence of tensors.<sup>7</sup> (The PyTorch cells below pass the operands as a single list, as in the original post; numpy and TensorFlow take them as separate arguments, and recent PyTorch versions accept either form.) The examples above can all be written using an equation string. For instance, our first example $c_j = \sum_i\sum_k A_{ik}B_{kj}$ can be written as the equation string "ik,kj->j". Note that the naming of the indices (i, j, k) is arbitrary, but the names need to be used consistently.
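
A short numpy sketch of the two points in the paragraph above, using arbitrary example matrices: the first example maps directly to the string "ik,kj->j", renaming the index letters leaves the result unchanged, and using a letter inconsistently (here `k` would need to be both size 3 and size 4) makes numpy raise a `ValueError`.

```python
import numpy as np

A = np.arange(6).reshape(2, 3)    # I=2, K=3
B = np.arange(12).reshape(3, 4)   # K=3, J=4

# First example from section 1: c_j = sum_i sum_k A_ik B_kj.
c = np.einsum('ik,kj->j', A, B)

# Renaming the indices does not change the computation.
c_renamed = np.einsum('xy,yz->z', A, B)
print(c, c_renamed)   # both hold 76, 91, 106, 121

# Using a letter inconsistently does: 'k' would have to be
# both A's second axis (size 3) and B's second axis (size 4).
raised = False
try:
    np.einsum('ik,jk->j', A, B)
except ValueError as err:
    raised = True
    print("ValueError:", err)
```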

What's great about having einsum not only in numpy but also in PyTorch and TensorFlow is that it can be used in arbitrary computation graphs for neural network architectures and that we can backpropagate through it. A typical call to einsum has the following form

${\color{green}\textbf{result}} = \text{einsum}("{\color{red}\square\square},{\color{purple}\square\square\square},{\color{blue}\square\square}\,\text{->}\,{\color{green}\square\square}", {\color{red}\text{arg1}}, {\color{purple}\text{arg2}}, {\color{blue}\text{arg3}})$

where □ is a placeholder for a character identifying a tensor dimension. From this equation string we can infer that arg1 and arg3 are matrices, arg2 is an order-3 tensor, and the result of this einsum operation is a matrix. Note that einsum works with a variable number of inputs. In the example above, einsum specifies an operation on three arguments, but it can also be used for operations involving one, two, or more than three arguments. Einsum is best learned by studying examples, so let's go through some examples of einsum in numpy, PyTorch, and TensorFlow that correspond to library functions used in many deep learning models.
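
A small numpy sketch of this flexibility (shapes and operand names are arbitrary, chosen only to fit the template): the same function handles one, two, three, and four operands, and the three-operand call has exactly the □□,□□□,□□->□□ structure, with two matrices and an order-3 tensor producing a matrix.

```python
import numpy as np

rng = np.random.default_rng(0)

# One operand: the trace, sum_i x_ii (i is repeated within a single operand).
x = rng.standard_normal((3, 3))
tr = np.einsum('ii->', x)

# Two operands: an ordinary matrix product.
y = rng.standard_normal((3, 4))
xy = np.einsum('ij,jk->ik', x, y)

# Three operands in the shape of the template above:
# arg1 is a matrix, arg2 an order-3 tensor, arg3 a matrix, and the result a matrix.
arg1 = rng.standard_normal((2, 3))
arg2 = rng.standard_normal((3, 4, 5))
arg3 = rng.standard_normal((5, 6))
result = np.einsum('ab,bcd,de->ce', arg1, arg2, arg3)
print(result.shape)   # (4, 6)

# Four operands: a chain of matrix products.
z = rng.standard_normal((4, 2))
w = rng.standard_normal((2, 5))
chain = np.einsum('ij,jk,kl,lm->im', x, y, z, w)
print(np.allclose(chain, x @ y @ z @ w))   # True
```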

### 2.1 MATRIX TRANSPOSE
${\color{green}B_{ji}} = {\color{red}A_{ij}}$

In [None]:
## Numpy 
a = np.arange(6).reshape(2, 3)
b = np.einsum('ij->ji', a)
print ("original", a)
print ("transpose", b)

original [[0 1 2]
 [3 4 5]]
transpose [[0 3]
 [1 4]
 [2 5]]


In [None]:
## PyTorch 
a = torch.arange(6).reshape(2, 3)
b = torch.einsum('ij->ji', [a])
print ("original", a)
print ("transpose", b)

original tensor([[0, 1, 2],
        [3, 4, 5]])
transpose tensor([[0, 3],
        [1, 4],
        [2, 5]])


In [None]:
## Tensorflow 
a = tf.reshape(tf.range(6),(2, 3))
b = tf.einsum('ij->ji', a)
print ("original", a)
print ("transpose", b)

original tf.Tensor(
[[0 1 2]
 [3 4 5]], shape=(2, 3), dtype=int32)
transpose tf.Tensor(
[[0 3]
 [1 4]
 [2 5]], shape=(3, 2), dtype=int32)


### 2.2 SUM
${\color{green}b} = \sum_i\sum_j {\color{red}A_{ij}} = {\color{red}A_{ij}}$

In [None]:
## Numpy
a = np.arange(6).reshape(2, 3)
b = np.einsum('ij->', a)

print ("original", a)
print ("sum", b)

original [[0 1 2]
 [3 4 5]]
sum 15


In [None]:
## PyTorch
a = torch.arange(6).reshape(2, 3)
b = torch.einsum('ij->', [a])

print ("original", a)
print ("sum", b)

original tensor([[0, 1, 2],
        [3, 4, 5]])
sum tensor(15)


In [None]:
## Tensorflow 
a = tf.range(6)
a = tf.reshape(a, (2, 3))
b = tf.einsum('ij->', a)

print ("original", a)
print ("sum", b)

original tf.Tensor(
[[0 1 2]
 [3 4 5]], shape=(2, 3), dtype=int32)
sum tf.Tensor(15, shape=(), dtype=int32)


### 2.3 COLUMN SUM
${\color{green}b_j} = \sum_i {\color{red}A_{ij}} = {\color{red}A_{ij}}$

In [None]:
## Numpy
a = np.arange(6).reshape(2, 3)
b = np.einsum('ij->j', a)

print ("original", a)
print ("column sum", b)

original [[0 1 2]
 [3 4 5]]
column sum [3 5 7]


In [None]:
## PyTorch
a = torch.arange(6).reshape(2, 3)
b = torch.einsum('ij->j', [a])

print ("original", a)
print ("column sum", b)

original tensor([[0, 1, 2],
        [3, 4, 5]])
column sum tensor([3, 5, 7])


In [None]:
## Tensorflow 
a = tf.range(6)
a = tf.reshape(a,(2, 3))
b = tf.einsum('ij->j', a)

print ("original", a)
print ("column sum", b)

original tf.Tensor(
[[0 1 2]
 [3 4 5]], shape=(2, 3), dtype=int32)
column sum tf.Tensor([3 5 7], shape=(3,), dtype=int32)


### 2.4 ROW SUM
${\color{green}b_i} = \sum_j {\color{red}A_{ij}} = {\color{red}A_{ij}}$

In [None]:
## Numpy
a = np.arange(6).reshape(2, 3)
b = np.einsum('ij->i', a)

print ("original", a)
print ("row sum", b)

original [[0 1 2]
 [3 4 5]]
row sum [ 3 12]


In [None]:
## PyTorch
a = torch.arange(6).reshape(2, 3)
b = torch.einsum('ij->i', [a])

print ("original", a)
print ("row sum", b)

original tensor([[0, 1, 2],
        [3, 4, 5]])
row sum tensor([ 3, 12])


In [None]:
## Tensorflow 
a = tf.range(6)
a = tf.reshape(a, (2, 3))
b = tf.einsum('ij->i', a)

print ("original", a)
print ("row sum", b)

original tf.Tensor(
[[0 1 2]
 [3 4 5]], shape=(2, 3), dtype=int32)
row sum tf.Tensor([ 3 12], shape=(2,), dtype=int32)
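
Because each of these einsum strings stands in for a dedicated library call, a quick numpy sketch on the same 2×3 matrix can confirm the correspondence for 2.1 to 2.4. In PyTorch or TensorFlow the comparison would be against the analogous calls, such as torch.sum(a, dim=0) or tf.reduce_sum(a, axis=0).

```python
import numpy as np

a = np.arange(6).reshape(2, 3)

# Each einsum string from 2.1-2.4 next to the dedicated numpy call it replaces.
checks = {
    "2.1 transpose":  np.array_equal(np.einsum('ij->ji', a), a.T),
    "2.2 sum":        np.einsum('ij->', a) == a.sum(),
    "2.3 column sum": np.array_equal(np.einsum('ij->j', a), a.sum(axis=0)),
    "2.4 row sum":    np.array_equal(np.einsum('ij->i', a), a.sum(axis=1)),
}
for name, ok in checks.items():
    print(f"{name}: {ok}")
```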


### 2.5 MATRIX-VECTOR MULTIPLICATION
${\color{green}c_i} = \sum_k {\color{red}A_{ik}}{\color{blue}b_k} = {\color{red}A_{ik}}{\color{blue}b_k}$

In [None]:
## Numpy
a = np.arange(6).reshape(2, 3)
b = np.arange(3)
c = np.einsum('ik,k->i', a, b)
# expected values: [5, 14]

print (a)
print (b)
print (c)

[[0 1 2]
 [3 4 5]]
[0 1 2]
[ 5 14]


In [None]:
## PyTorch
a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
c = torch.einsum('ik,k->i', [a, b])
# expected values: [5, 14]

print (a)
print (b)
print (c)

tensor([[0, 1, 2],
        [3, 4, 5]])
tensor([0, 1, 2])
tensor([ 5, 14])


In [None]:
## Tensorflow 
a = tf.range(6)
a = tf.reshape(a, (2, 3))
b = tf.range(3)
c = tf.einsum('ik,k->i', a, b)
# expected values: [5, 14]

print (a)
print (b)
print (c)

tf.Tensor(
[[0 1 2]
 [3 4 5]], shape=(2, 3), dtype=int32)
tf.Tensor([0 1 2], shape=(3,), dtype=int32)
tf.Tensor([ 5 14], shape=(2,), dtype=int32)


### 2.6 MATRIX-MATRIX MULTIPLICATION
${\color{green}C_{ij}} = \sum_k {\color{red}A_{ik}}{\color{blue}B_{kj}} =  {\color{red}A_{ik}}{\color{blue}B_{kj}}$

In [None]:
## Numpy
a = np.arange(6).reshape(2, 3)
b = np.arange(15).reshape(3, 5)
c = np.einsum('ik,kj->ij', a, b)
# expected values: [[25, 28, 31, 34, 37], [70, 82, 94, 106, 118]]
print (a)
print (b)
print (c)

[[0 1 2]
 [3 4 5]]
[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]]
[[ 25  28  31  34  37]
 [ 70  82  94 106 118]]


In [None]:
## PyTorch
a = torch.arange(6).reshape(2, 3)
b = torch.arange(15).reshape(3, 5)
c = torch.einsum('ik,kj->ij', [a, b])
# expected values: [[25, 28, 31, 34, 37], [70, 82, 94, 106, 118]]
print (a)
print (b)
print (c)

tensor([[0, 1, 2],
        [3, 4, 5]])
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14]])
tensor([[ 25,  28,  31,  34,  37],
        [ 70,  82,  94, 106, 118]])


In [None]:
## Tensorflow 
a = tf.range(6)
a = tf.reshape(a, (2, 3))
b = tf.range(15)
b = tf.reshape(b, (3, 5))
c = tf.einsum('ik,kj->ij', a, b)
# expected values: [[25, 28, 31, 34, 37], [70, 82, 94, 106, 118]]
print (a)
print (b)
print (c)

tf.Tensor(
[[0 1 2]
 [3 4 5]], shape=(2, 3), dtype=int32)
tf.Tensor(
[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]], shape=(3, 5), dtype=int32)
tf.Tensor(
[[ 25  28  31  34  37]
 [ 70  82  94 106 118]], shape=(2, 5), dtype=int32)


### 2.7 DOT PRODUCT

Vector:

${\color{green}c} = \sum_i {\color{red}a_i\color{blue}b_i} = {\color{red}a_i\color{blue}b_i}$

In [None]:
## Numpy
a = np.arange(3)
b = np.arange(3,6)  # -- a vector of length 3 containing [3, 4, 5]
c = np.einsum('i,i->', a, b)  # 0*3 + 1*4 + 2*5 = 14

print (a)
print (b)
print (c)

[0 1 2]
[3 4 5]
14


In [None]:
## PyTorch
a = torch.arange(3)
b = torch.arange(3,6)  # -- a vector of length 3 containing [3, 4, 5]
c = torch.einsum('i,i->', [a, b])  # 0*3 + 1*4 + 2*5 = 14

print (a)
print (b)
print (c)

tensor([0, 1, 2])
tensor([3, 4, 5])
tensor(14)


In [None]:
## Tensorflow 
a = tf.range(3)
b = tf.range(3,6)  # -- a vector of length 3 containing [3, 4, 5]
c = tf.einsum('i,i->', a, b)  # 0*3 + 1*4 + 2*5 = 14

print (a)
print (b)
print (c)

tf.Tensor([0 1 2], shape=(3,), dtype=int32)
tf.Tensor([3 4 5], shape=(3,), dtype=int32)
tf.Tensor(14, shape=(), dtype=int32)


Matrix:

${\color{green}c} = \sum_i\sum_j {\color{red}A_{ij}\color{blue}B_{ij}} = {\color{red}A_{ij}\color{blue}B_{ij}}$

In [None]:
## Numpy
a = np.arange(6).reshape(2, 3)
b = np.arange(6,12).reshape(2, 3)
c = np.einsum('ij,ij->', a, b)
# expected: 145
print (a)
print (b)
print (c)

[[0 1 2]
 [3 4 5]]
[[ 6  7  8]
 [ 9 10 11]]
145


In [None]:
## PyTorch
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
c = torch.einsum('ij,ij->', [a, b])
# expected: 145
print (a)
print (b)
print (c)

tensor([[0, 1, 2],
        [3, 4, 5]])
tensor([[ 6,  7,  8],
        [ 9, 10, 11]])
tensor(145)


In [None]:
## Tensorflow 
a = tf.range(6)
a = tf.reshape(a, (2, 3))
b = tf.range(6,12)
b = tf.reshape(b, (2, 3))
c = tf.einsum('ij,ij->', a, b)
# expected: 145
print (a)
print (b)
print (c)

tf.Tensor(
[[0 1 2]
 [3 4 5]], shape=(2, 3), dtype=int32)
tf.Tensor(
[[ 6  7  8]
 [ 9 10 11]], shape=(2, 3), dtype=int32)
tf.Tensor(145, shape=(), dtype=int32)


### 2.8 HADAMARD PRODUCT
${\color{green}C_{ij}} = {\color{red}A_{ij}\color{blue}B_{ij}}$

In [None]:
## Numpy
a = np.arange(6).reshape(2, 3)
b = np.arange(6,12).reshape(2, 3)
c = np.einsum('ij,ij->ij', a, b)
# expected values: [[0, 7, 16], [27, 40, 55]]
print (a)
print (b)
print (c)

[[0 1 2]
 [3 4 5]]
[[ 6  7  8]
 [ 9 10 11]]
[[ 0  7 16]
 [27 40 55]]


In [None]:
## PyTorch
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
c = torch.einsum('ij,ij->ij', [a, b])
# expected values: [[0, 7, 16], [27, 40, 55]]
print (a)
print (b)
print (c)

tensor([[0, 1, 2],
        [3, 4, 5]])
tensor([[ 6,  7,  8],
        [ 9, 10, 11]])
tensor([[ 0,  7, 16],
        [27, 40, 55]])


In [None]:
## Tensorflow 
a = tf.range(6)
a = tf.reshape(a, (2, 3))
b = tf.range(6,12)
b = tf.reshape(b, (2, 3))
c = tf.einsum('ij,ij->ij', a, b)
# expected values: [[0, 7, 16], [27, 40, 55]]
print (a)
print (b)
print (c)

tf.Tensor(
[[0 1 2]
 [3 4 5]], shape=(2, 3), dtype=int32)
tf.Tensor(
[[ 6  7  8]
 [ 9 10 11]], shape=(2, 3), dtype=int32)
tf.Tensor(
[[ 0  7 16]
 [27 40 55]], shape=(2, 3), dtype=int32)


### 2.9 OUTER PRODUCT

${\color{green}C_{ij}} = {\color{red}a_i\color{blue}b_j}$

In [None]:
## Numpy
a = np.arange(3)
b = np.arange(3,7)  # -- a vector of length 4 containing [3, 4, 5, 6]
c = np.einsum('i,j->ij', a, b)

# expected values: [[0, 0, 0, 0], [3, 4, 5, 6], [6, 8, 10, 12]]
print (a)
print (b)
print (c)

[0 1 2]
[3 4 5 6]
[[ 0  0  0  0]
 [ 3  4  5  6]
 [ 6  8 10 12]]


In [None]:
## PyTorch
a = torch.arange(3)
b = torch.arange(3,7)  # -- a vector of length 4 containing [3, 4, 5, 6]
c = torch.einsum('i,j->ij', [a, b])

# expected values: [[0, 0, 0, 0], [3, 4, 5, 6], [6, 8, 10, 12]]
print (a)
print (b)
print (c)

tensor([0, 1, 2])
tensor([3, 4, 5, 6])
tensor([[ 0,  0,  0,  0],
        [ 3,  4,  5,  6],
        [ 6,  8, 10, 12]])


In [None]:
## Tensorflow 
a = tf.range(3)
b = tf.range(3,7)  # -- a vector of length 4 containing [3, 4, 5, 6]
c = tf.einsum('i,j->ij', a, b)

# expected values: [[0, 0, 0, 0], [3, 4, 5, 6], [6, 8, 10, 12]]
print (a)
print (b)
print (c)

tf.Tensor([0 1 2], shape=(3,), dtype=int32)
tf.Tensor([3 4 5 6], shape=(4,), dtype=int32)
tf.Tensor(
[[ 0  0  0  0]
 [ 3  4  5  6]
 [ 6  8 10 12]], shape=(3, 4), dtype=int32)
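
As with the reductions earlier, a numpy sketch using the same small integer inputs as sections 2.5 to 2.9 can confirm that each einsum call matches the dedicated operation it replaces (matrix product, np.dot, elementwise multiplication, np.outer).

```python
import numpy as np

a = np.arange(6).reshape(2, 3)
b = np.arange(6, 12).reshape(2, 3)
m = np.arange(15).reshape(3, 5)
v = np.arange(3)
w = np.arange(3, 6)
y = np.arange(3, 7)

# Each einsum call from 2.5-2.9 next to the dedicated numpy operation it replaces.
checks = {
    "2.5 matrix-vector": np.array_equal(np.einsum('ik,k->i', a, v), a @ v),
    "2.6 matrix-matrix": np.array_equal(np.einsum('ik,kj->ij', a, m), a @ m),
    "2.7 vector dot":    np.einsum('i,i->', v, w) == np.dot(v, w),
    "2.7 matrix dot":    np.einsum('ij,ij->', a, b) == np.sum(a * b),
    "2.8 Hadamard":      np.array_equal(np.einsum('ij,ij->ij', a, b), a * b),
    "2.9 outer":         np.array_equal(np.einsum('i,j->ij', v, y), np.outer(v, y)),
}
for name, ok in checks.items():
    print(f"{name}: {ok}")
```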


### 2.10 BATCH MATRIX MULTIPLICATION

${\color{green}C_{ijl}} = \sum_k{\color{red}A_{ijk}\color{blue}B_{ikl}} = {\color{red}A_{ijk}\color{blue}B_{ikl}}$

In [None]:
## Numpy
a = np.random.randn(3,2,5)
b = np.random.randn(3,5,3)
c = np.einsum('ijk,ikl->ijl', a, b)

print (a.shape)
print (b.shape)
print (c)

(3, 2, 5)
(3, 5, 3)
[[[-0.95550252  2.93340872 -2.66304665]
  [-1.24365154 -2.71546213 -2.1199949 ]]

 [[-0.55838726 -1.43049071 -0.86375113]
  [ 0.68469286 -2.76500144  1.6191076 ]]

 [[-4.58631627 -0.1202802   2.91707242]
  [ 0.46530767 -3.80597532 -1.13911806]]]


In [None]:
## PyTorch
a = torch.randn(3,2,5)
b = torch.randn(3,5,3)
c = torch.einsum('ijk,ikl->ijl', [a, b])

print (a.shape)
print (b.shape)
print (c)

torch.Size([3, 2, 5])
torch.Size([3, 5, 3])
tensor([[[ 1.0160,  0.5631,  0.1172],
         [-0.2825, -2.3514, -0.1016]],

        [[ 3.1685, -2.4267,  1.8415],
         [ 0.8695,  1.9025, -0.6790]],

        [[ 0.1533, -8.8327, -2.9347],
         [-0.3107,  1.7491,  1.3890]]])


In [None]:
## Tensorflow 
a = tf.random.normal((3,2,5))
b = tf.random.normal((3,5,3))
c = tf.einsum('ijk,ikl->ijl', a, b)

print (a.shape)
print (b.shape)
print (c)

(3, 2, 5)
(3, 5, 3)
tf.Tensor(
[[[-2.4720335   6.1345625   1.0750296 ]
  [ 0.9897116   0.3400598   1.6981529 ]]

 [[ 1.9287645  -1.9831697   0.4655583 ]
  [-2.8058667   0.9074414  -1.1197346 ]]

 [[-1.1438993  -0.971941   -0.22273944]
  [-0.29590404 -0.22368309 -0.9517663 ]]], shape=(3, 2, 3), dtype=float32)
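
Since the inputs here are random, the printed values are hard to check by eye. A numpy sketch (arbitrary seed) compares the einsum result with np.matmul, which also treats the leading axis as a batch axis; in PyTorch and TensorFlow the corresponding comparison would be against torch.bmm or tf.matmul.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((3, 2, 5))
b = rng.standard_normal((3, 5, 3))

# 'ijk,ikl->ijl': i is a shared batch index (kept in the output), k is contracted.
c = np.einsum('ijk,ikl->ijl', a, b)
print(c.shape)                 # (3, 2, 3)
print(np.allclose(c, a @ b))   # True

# Per-batch view of the same thing: one ordinary matrix product per value of i.
print(all(np.allclose(c[i], a[i] @ b[i]) for i in range(a.shape[0])))   # True
```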


### 2.11 TENSOR CONTRACTION

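Batch matrix multiplication is a special case of a tensor contraction. Let's say we have two tensors, an order-$n$ tensor ${\color{red}\mathbf{A}}\in\mathbb{R}^{I_1\times\cdots\times I_n}$ and an order-$m$ tensor ${\color{blue}\mathbf{B}}\in\mathbb{R}^{J_1\times\cdots\times J_m}$. As an example, take $n=4$, $m=5$ and assume that $I_2=J_3$ and $I_3=J_5$. We can multiply the two tensors in these two dimensions (2 and 3 for $\mathbf{A}$, and 3 and 5 for $\mathbf{B}$), resulting in a new tensor ${\color{green}\mathbf{C}}\in\mathbb{R}^{I_1\times I_4\times J_1\times J_2\times J_4}$ as follows: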

${\color{green}C_{pstuv}} = \sum_q\sum_r{\color{red}A_{pqrs}\color{blue}B_{tuqvr}} = {\color{red}A_{pqrs}\color{blue}B_{tuqvr}}$

In [None]:
## Numpy
a = np.random.randn(2,3,5,7)
b = np.random.randn(11,13,3,17,5)
c = np.einsum('pqrs,tuqvr->pstuv', a, b).shape

print (a.shape)
print (b.shape)
print (c)

(2, 3, 5, 7)
(11, 13, 3, 17, 5)
(2, 7, 11, 13, 17)


In [None]:
## PyTorch
a = torch.randn(2,3,5,7)
b = torch.randn(11,13,3,17,5)
c = torch.einsum('pqrs,tuqvr->pstuv', [a, b]).shape

print (a.shape)
print (b.shape)
print (c)

torch.Size([2, 3, 5, 7])
torch.Size([11, 13, 3, 17, 5])
torch.Size([2, 7, 11, 13, 17])


In [None]:
## Tensorflow 
a = tf.random.normal((2,3,5,7))
b = tf.random.normal((11,13,3,17,5))
c = tf.einsum('pqrs,tuqvr->pstuv', a, b).shape

print (a.shape)
print (b.shape)
print (c)

(2, 3, 5, 7)
(11, 13, 3, 17, 5)
(2, 7, 11, 13, 17)
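
The same contraction can be cross-checked with np.tensordot, which contracts explicitly listed axis pairs. In this numpy sketch (random data, arbitrary seed), A's axes 1 and 2 ($q$, $r$) are paired with B's axes 2 and 4; tensordot keeps A's remaining axes followed by B's, which happens to be exactly the pstuv order requested above.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 5, 7))         # A_pqrs
b = rng.standard_normal((11, 13, 3, 17, 5))   # B_tuqvr

c = np.einsum('pqrs,tuqvr->pstuv', a, b)

# Contract A's axes (1, 2) = (q, r) with B's axes (2, 4) = (q, r).
# The output keeps A's remaining axes (p, s) followed by B's (t, u, v).
c_td = np.tensordot(a, b, axes=([1, 2], [2, 4]))

print(c.shape, c_td.shape)    # (2, 7, 11, 13, 17) (2, 7, 11, 13, 17)
print(np.allclose(c, c_td))   # True
```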


### 2.12 BILINEAR TRANSFORMATION

As mentioned earlier, einsum can operate on more than two tensors. One example where this is used is a bilinear transformation.

${\color{green}D_{ij}} = \sum_k\sum_l{\color{red}A_{ik}}{\color{purple}B_{jkl}}{\color{blue}C_{il}} = {\color{red}A_{ik}}{\color{purple}B_{jkl}}{\color{blue}C_{il}}$


In [None]:
## Numpy
a = np.random.randn(2,3)
b = np.random.randn(5,3,7)
c = np.random.randn(2,7)
d = np.einsum('ik,jkl,il->ij', a, b, c)

print (a.shape)
print (b.shape)
print (c.shape)
print (d)

(2, 3)
(5, 3, 7)
(2, 7)
[[10.95754188 17.36015475 -2.14860697 13.87924475 -2.19096659]
 [-1.10782111  3.95644448 -0.66453269  0.40203265 -2.73663174]]


In [None]:
## PyTorch
a = torch.randn(2,3)
b = torch.randn(5,3,7)
c = torch.randn(2,7)
d = torch.einsum('ik,jkl,il->ij', [a, b, c])

print (a.shape)
print (b.shape)
print (c.shape)
print (d)

torch.Size([2, 3])
torch.Size([5, 3, 7])
torch.Size([2, 7])
tensor([[-6.7528e+00, -8.6984e-01, -4.3319e-03, -5.8195e-02,  6.3494e+00],
        [-6.5004e-01,  1.3124e+00,  1.4814e+00,  1.7223e+00,  3.2651e+00]])


In [None]:
## Tensorflow 
a = tf.random.normal((2,3))
b = tf.random.normal((5,3,7))
c = tf.random.normal((2,7))
d = tf.einsum('ik,jkl,il->ij', a, b, c)

print (a.shape)
print (b.shape)
print (c.shape)
print (d)

(2, 3)
(5, 3, 7)
(2, 7)
tf.Tensor(
[[-0.08772419  2.6937442   3.023434    0.54483443 -3.9494355 ]
 [ 1.1517725   0.30645236 -1.321343    2.4934564  -0.03937294]], shape=(2, 5), dtype=float32)
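
To see what the three-operand string computes, this numpy sketch (random data, arbitrary seed) compares it with an explicit loop: every entry $D_{ij}$ is the bilinear form $\mathbf{A}_{i:}^\top \mathbf{B}_{j::}\, \mathbf{C}_{i:}$. As far as I can tell, this is the same quantity a bias-free bilinear layer such as torch.nn.functional.bilinear(input1, input2, weight) computes, with weight playing the role of $\mathbf{B}$; treat that correspondence as an assumption to verify.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3))      # A_ik
b = rng.standard_normal((5, 3, 7))   # B_jkl
c = rng.standard_normal((2, 7))      # C_il

d = np.einsum('ik,jkl,il->ij', a, b, c)

# Loop version: one bilinear form a_i^T B_j c_i per (i, j) pair.
d_loop = np.empty((a.shape[0], b.shape[0]))
for i in range(a.shape[0]):
    for j in range(b.shape[0]):
        d_loop[i, j] = a[i] @ b[j] @ c[i]

print(d.shape)                  # (2, 5)
print(np.allclose(d, d_loop))   # True
```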
