# Einstein summation notation

Rules:

- two parts, inputs and outputs separated by `->`
- letters represent axis - hence `ijk` is a 3-tensor
- sum over repeated letters in inputs
- sum over letters not in output

The result of an `einsum` can  generally be understood as a element-wise multiplication of two tensors (with broadcasting) followed by a sum over zero or more axes. It turns out that this is a surprisingly flexible way of executing many common array operations.

It is easier to understand with examples.

In [1]:
import numpy as np

## Set up a vector and a matrix

In [2]:
x = np.arange(1, 4)
x

array([1, 2, 3])

In [3]:
M = np.arange(1, 10).reshape(3,3)
M

array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])

## Working with vectors

In [73]:
np.einsum('i->i', x)

array([1, 2, 3])

In [5]:
np.einsum('i->', x)

6

In [6]:
np.sum(x)

6

In [7]:
np.einsum('i,j->ij', x, x)

array([[1, 2, 3],
       [2, 4, 6],
       [3, 6, 9]])

In [8]:
np.outer(x, x)

array([[1, 2, 3],
       [2, 4, 6],
       [3, 6, 9]])

In [9]:
x[:,None] * x[None, :]

array([[1, 2, 3],
       [2, 4, 6],
       [3, 6, 9]])

In [87]:
import time
start = time.time()
np.einsum('i,j->ji', x, x)
time.time()-start

9.1552734375e-05

In [88]:
import time
start = time.time()
np.outer(x, x).T
time.time()-start

0.00010943412780761719

In [12]:
(x[:,None] * x[None, :]).T

array([[1, 2, 3],
       [2, 4, 6],
       [3, 6, 9]])

In [13]:
np.einsum('i,i->', x, x)

14

In [14]:
np.inner(x, x)

14

In [15]:
(x * x).sum()

14

In [16]:
np.einsum('i,i->i', x, x)

array([1, 4, 9])

In [17]:
x * x

array([1, 4, 9])

## Working with matrices

In [18]:
np.einsum('ii->i', M)

array([1, 5, 9])

In [19]:
np.diag(M)

array([1, 5, 9])

In [20]:
np.einsum('ij->ji', M)

array([[1, 4, 7],
       [2, 5, 8],
       [3, 6, 9]])

In [21]:
np.transpose(M)

array([[1, 4, 7],
       [2, 5, 8],
       [3, 6, 9]])

In [22]:
np.dot(M, M)

array([[ 30,  36,  42],
       [ 66,  81,  96],
       [102, 126, 150]])

In [23]:
np.einsum('ij,jk->ik', M, M)

array([[ 30,  36,  42],
       [ 66,  81,  96],
       [102, 126, 150]])

In [24]:
np.einsum('ij,ji->', M, M)

261

In [25]:
(M*M.T).sum()

261

In [26]:
np.einsum('ij,ij->', M, M)

285

In [27]:
(M*M).sum()

285

In [28]:
np.einsum('ij,jk->ijk', M, M)

array([[[ 1,  2,  3],
        [ 8, 10, 12],
        [21, 24, 27]],

       [[ 4,  8, 12],
        [20, 25, 30],
        [42, 48, 54]],

       [[ 7, 14, 21],
        [32, 40, 48],
        [63, 72, 81]]])

In [29]:
M[...,None] * M

array([[[ 1,  2,  3],
        [ 8, 10, 12],
        [21, 24, 27]],

       [[ 4,  8, 12],
        [20, 25, 30],
        [42, 48, 54]],

       [[ 7, 14, 21],
        [32, 40, 48],
        [63, 72, 81]]])

In [30]:
np.einsum('ij,kl->ijkl', M, M)

array([[[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[ 2,  4,  6],
         [ 8, 10, 12],
         [14, 16, 18]],

        [[ 3,  6,  9],
         [12, 15, 18],
         [21, 24, 27]]],


       [[[ 4,  8, 12],
         [16, 20, 24],
         [28, 32, 36]],

        [[ 5, 10, 15],
         [20, 25, 30],
         [35, 40, 45]],

        [[ 6, 12, 18],
         [24, 30, 36],
         [42, 48, 54]]],


       [[[ 7, 14, 21],
         [28, 35, 42],
         [49, 56, 63]],

        [[ 8, 16, 24],
         [32, 40, 48],
         [56, 64, 72]],

        [[ 9, 18, 27],
         [36, 45, 54],
         [63, 72, 81]]]])

In [31]:
M[...,None,None] * M

array([[[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[ 2,  4,  6],
         [ 8, 10, 12],
         [14, 16, 18]],

        [[ 3,  6,  9],
         [12, 15, 18],
         [21, 24, 27]]],


       [[[ 4,  8, 12],
         [16, 20, 24],
         [28, 32, 36]],

        [[ 5, 10, 15],
         [20, 25, 30],
         [35, 40, 45]],

        [[ 6, 12, 18],
         [24, 30, 36],
         [42, 48, 54]]],


       [[[ 7, 14, 21],
         [28, 35, 42],
         [49, 56, 63]],

        [[ 8, 16, 24],
         [32, 40, 48],
         [56, 64, 72]],

        [[ 9, 18, 27],
         [36, 45, 54],
         [63, 72, 81]]]])

In [32]:
np.einsum('ij,kl->jkl', M, M)

array([[[ 12,  24,  36],
        [ 48,  60,  72],
        [ 84,  96, 108]],

       [[ 15,  30,  45],
        [ 60,  75,  90],
        [105, 120, 135]],

       [[ 18,  36,  54],
        [ 72,  90, 108],
        [126, 144, 162]]])

In [33]:
(M[:,:,None,None] * M).sum(0)

array([[[ 12,  24,  36],
        [ 48,  60,  72],
        [ 84,  96, 108]],

       [[ 15,  30,  45],
        [ 60,  75,  90],
        [105, 120, 135]],

       [[ 18,  36,  54],
        [ 72,  90, 108],
        [126, 144, 162]]])

In [34]:
np.einsum('ij,jk->ijk', M, M)

array([[[ 1,  2,  3],
        [ 8, 10, 12],
        [21, 24, 27]],

       [[ 4,  8, 12],
        [20, 25, 30],
        [42, 48, 54]],

       [[ 7, 14, 21],
        [32, 40, 48],
        [63, 72, 81]]])

In [35]:
M[...,None] * M

array([[[ 1,  2,  3],
        [ 8, 10, 12],
        [21, 24, 27]],

       [[ 4,  8, 12],
        [20, 25, 30],
        [42, 48, 54]],

       [[ 7, 14, 21],
        [32, 40, 48],
        [63, 72, 81]]])

In [36]:
np.einsum('ij,jk->ik', M, M)

array([[ 30,  36,  42],
       [ 66,  81,  96],
       [102, 126, 150]])

In [37]:
M@M

array([[ 30,  36,  42],
       [ 66,  81,  96],
       [102, 126, 150]])

In [38]:
(M[...,None] * M).sum(1)

array([[ 30,  36,  42],
       [ 66,  81,  96],
       [102, 126, 150]])

### Matrices and vectors

In [39]:
np.einsum('ij,j->i', M, x)

array([14, 32, 50])

In [40]:
np.dot(M, x)

array([14, 32, 50])

In [41]:
(M * x[None,:]).sum(1)

array([14, 32, 50])

In [42]:
np.einsum('ij,i->ij', M, x)

array([[ 1,  2,  3],
       [ 8, 10, 12],
       [21, 24, 27]])

In [43]:
M*x[:, None]

array([[ 1,  2,  3],
       [ 8, 10, 12],
       [21, 24, 27]])

In [44]:
np.einsum('ij,i->i', M, x)

array([ 6, 30, 72])

In [45]:
(M*x[:, None]).sum(1)

array([ 6, 30, 72])

In [46]:
np.einsum('ij,i->j', M, x)

array([30, 36, 42])

In [47]:
(M*x[:, None]).sum(0)

array([30, 36, 42])

## Example: Pairwise distances

In [48]:
A = np.arange(12).reshape(3,4)
B = np.arange(16).reshape(4,4)

In [49]:
A

array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])

In [50]:
B

array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]])

`scipy.pdist` can calculate pairwise distance for a set of row vectors, but not between 2 different sets

In [51]:
from scipy.spatial.distance import pdist, squareform

In [52]:
squareform(pdist(A))

array([[ 0.,  8., 16.],
       [ 8.,  0.,  8.],
       [16.,  8.,  0.]])

In plain Python

In [53]:
res = np.zeros([A.shape[0], A.shape[0]])
for i in range(A.shape[0]):
    for j in range(A.shape[0]):
        res[i,j] = ((A[i] - A[j]) @ (A[i] - A[j]))**0.5
res

array([[ 0.,  8., 16.],
       [ 8.,  0.,  8.],
       [16.,  8.,  0.]])

We can calculate the Euclidean distance between our row vectors by broadcasting. Since we want all pairwise results, this is basically a generalized outer product where the operator is subtraction rather than multiplication. We recall that the regular outer product for a matrix `M` is given using broadcasting by `M[:,None] * M[None,:]`

In [54]:
M = np.arange(3)
M

array([0, 1, 2])

In [55]:
M[:,None] * M[None, :]

array([[0, 0, 0],
       [0, 1, 2],
       [0, 2, 4]])

The second dimension expansion is actually unnecssary since broadcasting rules take care of it.

In [56]:
M[:, None] * M

array([[0, 0, 0],
       [0, 1, 2],
       [0, 2, 4]])

We basically do the same thing for A

In [57]:
A

array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])

In [58]:
A[:,None,:]

array([[[ 0,  1,  2,  3]],

       [[ 4,  5,  6,  7]],

       [[ 8,  9, 10, 11]]])

In [59]:
A.shape

(3, 4)

In [60]:
A[:,None,:].shape

(3, 1, 4)

In [61]:
A_ = A[:,None,:] - A
A_

array([[[ 0,  0,  0,  0],
        [-4, -4, -4, -4],
        [-8, -8, -8, -8]],

       [[ 4,  4,  4,  4],
        [ 0,  0,  0,  0],
        [-4, -4, -4, -4]],

       [[ 8,  8,  8,  8],
        [ 4,  4,  4,  4],
        [ 0,  0,  0,  0]]])

With redundant dimension expansion, to make everything explicit. This makes it clear that the `k` components of each vector are left as they are, which makes sense since we treat each row vector as a unit.

In [62]:
A[:, None, :] - A[None, :, :]

array([[[ 0,  0,  0,  0],
        [-4, -4, -4, -4],
        [-8, -8, -8, -8]],

       [[ 4,  4,  4,  4],
        [ 0,  0,  0,  0],
        [-4, -4, -4, -4]],

       [[ 8,  8,  8,  8],
        [ 4,  4,  4,  4],
        [ 0,  0,  0,  0]]])

In [63]:
A_ * A_

array([[[ 0,  0,  0,  0],
        [16, 16, 16, 16],
        [64, 64, 64, 64]],

       [[16, 16, 16, 16],
        [ 0,  0,  0,  0],
        [16, 16, 16, 16]],

       [[64, 64, 64, 64],
        [16, 16, 16, 16],
        [ 0,  0,  0,  0]]])

In [64]:
(A_ * A_).sum(axis=2)**0.5

array([[ 0.,  8., 16.],
       [ 8.,  0.,  8.],
       [16.,  8.,  0.]])

By now, we know that the pattern (X * X).sum(axis=k) can be written more efficiently as an einsum.

In [65]:
np.einsum('ijk,ijk->ij', A_, A_)**0.5

array([[ 0.,  8., 16.],
       [ 8.,  0.,  8.],
       [16.,  8.,  0.]])

Now we can do the same for two different sets of same length vectors

In [66]:
C_ = A[:,None,:] - B
C_

array([[[  0,   0,   0,   0],
        [ -4,  -4,  -4,  -4],
        [ -8,  -8,  -8,  -8],
        [-12, -12, -12, -12]],

       [[  4,   4,   4,   4],
        [  0,   0,   0,   0],
        [ -4,  -4,  -4,  -4],
        [ -8,  -8,  -8,  -8]],

       [[  8,   8,   8,   8],
        [  4,   4,   4,   4],
        [  0,   0,   0,   0],
        [ -4,  -4,  -4,  -4]]])

In [67]:
np.einsum('ijk,ijk->ij', C_, C_)**0.5

array([[ 0.,  8., 16., 24.],
       [ 8.,  0.,  8., 16.],
       [16.,  8.,  0.,  8.]])

In [68]:
res = np.zeros((A.shape[0], B.shape[0]))
for i in range(A.shape[0]):
    for j in range(B.shape[0]):
        res[i,j] = ((A[i] - B[j]) @ (A[i] - B[j]))**0.5
res

array([[ 0.,  8., 16., 24.],
       [ 8.,  0.,  8., 16.],
       [16.,  8.,  0.,  8.]])