# Einstein summation notation

Rules:

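- two parts, inputs and outputs, separated by `->`; the subscripts of different inputs are separated by commas
- each letter labels an axis, hence `ijk` is a 3-tensor
- a letter repeated in the inputs refers to the same axis: across inputs the entries are multiplied element-wise, and within one input (e.g. `ii`) the diagonal is taken
- letters that do not appear in the output are summed over; a repeated letter that is kept in the output is not summed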
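A quick check of these rules, using the same `x` and `M` that are set up below. It also shows NumPy's implicit mode (no `->`), in which repeated letters are summed and the output axes are the remaining letters in alphabetical order, as described in the `numpy.einsum` documentation.

```python
import numpy as np

x = np.arange(1, 4)                    # [1, 2, 3]
M = np.arange(1, 10).reshape(3, 3)     # [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

# 'i' is repeated and kept in the output: aligned and multiplied, not summed.
elementwise = np.einsum('i,i->i', x, x)   # [1, 4, 9]
# Same inputs, but 'i' is missing from the output, so it is summed.
dot = np.einsum('i,i->', x, x)            # 1 + 4 + 9 = 14

# Implicit mode (no '->'): repeated letters are summed ...
trace = np.einsum('ii', M)                # 1 + 5 + 9 = 15
# ... and free letters are output alphabetically, so 'ji' means 'ji->ij'.
transposed = np.einsum('ji', M)

print(elementwise, dot, trace)
print(np.array_equal(transposed, M.T))    # True
```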

The result of an `einsum` can generally be understood as an element-wise multiplication of the input tensors (with broadcasting) followed by a sum over zero or more axes. It turns out that this is a surprisingly flexible way of expressing many common array operations.
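To see the multiply-then-sum reading concretely, here is `'ij,jk->ik'` spelled out in two steps. The operands `P` and `Q` are made up for illustration and deliberately non-square, so each letter maps to a distinct axis length.

```python
import numpy as np

P = np.arange(6).reshape(2, 3)     # axes (i, j) with lengths (2, 3)
Q = np.arange(12).reshape(3, 4)    # axes (j, k) with lengths (3, 4)

# Step 1: broadcast both operands onto a common (i, j, k) grid and multiply.
prod = P[:, :, None] * Q[None, :, :]   # shape (2, 3, 4)
# Step 2: sum over 'j', the only letter absent from the output 'ik'.
manual = prod.sum(axis=1)              # shape (2, 4)

print(prod.shape, manual.shape)                               # (2, 3, 4) (2, 4)
print(np.array_equal(manual, np.einsum('ij,jk->ik', P, Q)))   # True
print(np.array_equal(manual, P @ Q))                           # True
```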

It is easier to understand with examples.

In [None]:
import numpy as np

## Set up a vector and a matrix

In [None]:
x = np.arange(1, 4)
x

In [None]:
M = np.arange(1, 10).reshape(3,3)
M

## Working with vectors

In [None]:
np.einsum('i->i', x)

In [None]:
np.einsum('i->', x)

In [None]:
np.sum(x)

In [None]:
np.einsum('i,j->ij', x, x)

In [None]:
np.outer(x, x)

In [None]:
x[:,None] * x[None, :]

In [None]:
np.einsum('i,j->ji', x, x)

In [None]:
np.outer(x, x).T

In [None]:
(x[:,None] * x[None, :]).T

In [None]:
np.einsum('i,i->', x, x)

In [None]:
np.inner(x, x)

In [None]:
(x * x).sum()

In [None]:
np.einsum('i,i->i', x, x)

In [None]:
x * x

## Working with matrices

In [None]:
np.einsum('ii->i', M)

In [None]:
np.diag(M)

In [None]:
np.einsum('ij->ji', M)

In [None]:
np.transpose(M)

In [None]:
np.dot(M, M)

In [None]:
np.einsum('ij,jk->ik', M, M)

In [None]:
np.einsum('ij,ji->', M, M)

In [None]:
(M*M.T).sum()

In [None]:
np.einsum('ij,ij->', M, M)

In [None]:
(M*M).sum()

In [None]:
np.einsum('ij,jk->ijk', M, M)

In [None]:
M[...,None] * M

In [None]:
np.einsum('ij,kl->ijkl', M, M)

In [None]:
M[...,None,None] * M

In [None]:
np.einsum('ij,kl->jkl', M, M)

In [None]:
(M[:,:,None,None] * M).sum(0)

In [None]:
np.einsum('ij,jk->ijk', M, M)

In [None]:
M[...,None] * M

In [None]:
np.einsum('ij,jk->ik', M, M)

In [None]:
M@M

In [None]:
(M[...,None] * M).sum(1)

### Matrices and vectors

In [None]:
np.einsum('ij,j->i', M, x)

In [None]:
np.dot(M, x)

In [None]:
(M * x[None,:]).sum(1)

In [None]:
np.einsum('ij,i->ij', M, x)

In [None]:
M*x[:, None]

In [None]:
np.einsum('ij,i->i', M, x)

In [None]:
(M*x[:, None]).sum(1)

In [None]:
np.einsum('ij,i->j', M, x)

In [None]:
(M*x[:, None]).sum(0)

## Example: Pairwise distances

In [None]:
A = np.arange(12).reshape(3,4)
B = np.arange(16).reshape(4,4)

In [None]:
A

In [None]:
B

`scipy.spatial.distance.pdist` can calculate pairwise distances among the row vectors of a single array, but not between two different sets of vectors.

In [None]:
from scipy.spatial.distance import pdist, squareform

In [None]:
squareform(pdist(A))
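`pdist` returns a condensed 1-D array holding one distance per unordered pair of rows (`n*(n-1)/2` values, in the order `(0, 1), (0, 2), (1, 2), ...`), and `squareform` expands it into the full symmetric matrix. A small check with the same `A`:

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform

A = np.arange(12).reshape(3, 4)

condensed = pdist(A)          # distances for row pairs (0, 1), (0, 2), (1, 2)
D = squareform(condensed)     # full 3 x 3 symmetric matrix with a zero diagonal

print(condensed)              # [ 8. 16.  8.]
print(D.shape)                # (3, 3)
```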

The same result with explicit Python loops over pairs of rows:

In [None]:
res = np.zeros([A.shape[0], A.shape[0]])
for i in range(A.shape[0]):
    for j in range(A.shape[0]):
        res[i,j] = ((A[i] - A[j]) @ (A[i] - A[j]))**0.5
res

We can calculate the Euclidean distances between our row vectors by broadcasting. Since we want the result for every pair of rows, this is essentially a generalized outer product in which the operator is subtraction rather than multiplication. Recall that, with broadcasting, the ordinary outer product of a vector `M` with itself is `M[:,None] * M[None,:]` (here `M` is redefined as a small vector):

In [None]:
M = np.arange(3)
M

In [None]:
M[:,None] * M[None, :]
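NumPy ufuncs expose this generalized outer product directly as `ufunc.outer`, so the broadcasting expressions can be cross-checked against `np.multiply.outer` and `np.subtract.outer`:

```python
import numpy as np

M = np.arange(3)

# Ordinary outer product: multiplication applied to every pair (M[i], M[j]).
assert np.array_equal(M[:, None] * M[None, :], np.multiply.outer(M, M))

# Generalized outer product with subtraction: D[i, j] = M[i] - M[j].
D = np.subtract.outer(M, M)
assert np.array_equal(D, M[:, None] - M[None, :])
print(D)
# [[ 0 -1 -2]
#  [ 1  0 -1]
#  [ 2  1  0]]
```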

The second axis insertion is actually unnecessary, because broadcasting prepends missing leading axes automatically: `M`, with shape `(3,)`, is treated as shape `(1, 3)`.

In [None]:
M[:, None] * M
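Broadcasting compares shapes from the trailing axis backwards and treats missing leading axes as length 1. `np.broadcast_shapes` (available in NumPy 1.20 and later) reports the resulting shape without doing any arithmetic:

```python
import numpy as np

M = np.arange(3)

# (3, 1) against (3,): the (3,) shape is treated as (1, 3), giving (3, 3).
implicit_shape = np.broadcast_shapes(M[:, None].shape, M.shape)
# Adding the leading axis by hand yields the same result shape.
explicit_shape = np.broadcast_shapes(M[:, None].shape, M[None, :].shape)
print(implicit_shape, explicit_shape)   # (3, 3) (3, 3)

# The values agree as well.
assert np.array_equal(M[:, None] * M, M[:, None] * M[None, :])
```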

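We do the same for `A`, except that the new axis goes in the middle, so that every row of `A` is paired with every row while the last axis (the vector components) stays aligned.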

In [None]:
A

In [None]:
A[:,None,:]

In [None]:
A.shape

In [None]:
A[:,None,:].shape
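Tracking the shapes explains the result: `A[:, None, :]` has shape `(3, 1, 4)`, and `A` with shape `(3, 4)` is treated as `(1, 3, 4)`, so the difference has shape `(3, 3, 4)`, one 4-component difference vector per pair of rows. For contrast, `np.subtract.outer` on 2-D input pairs every element with every element rather than every row with every row:

```python
import numpy as np

A = np.arange(12).reshape(3, 4)

# Shape bookkeeping for the row-wise "outer difference".
row_pairs = np.broadcast_shapes(A[:, None, :].shape, A.shape)
print(A[:, None, :].shape, row_pairs)     # (3, 1, 4) (3, 3, 4)

# ufunc.outer does not know about rows: it combines all elements pairwise.
element_pairs = np.subtract.outer(A, A).shape
print(element_pairs)                      # (3, 4, 3, 4)
```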

In [None]:
A_ = A[:,None,:] - A
A_

The same computation with the redundant axis insertion written out, to make everything explicit. This makes it clear that the last axis (`k`, the components of each vector) is left as it is, which makes sense since we treat each row vector as a unit.

In [None]:
A[:, None, :] - A[None, :, :]
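A few sanity checks on the difference tensor: the explicit and implicit forms agree, entry `[i, j]` is the vector `A[i] - A[j]`, and swapping the first two axes flips the sign, so the diagonal holds zero vectors.

```python
import numpy as np

A = np.arange(12).reshape(3, 4)
A_ = A[:, None, :] - A

# Explicit and implicit axis insertion give the same array.
assert np.array_equal(A_, A[:, None, :] - A[None, :, :])
# Entry [i, j] is the difference of row i and row j.
assert np.array_equal(A_[2, 0], A[2] - A[0])
print(A_[2, 0])                                  # [8 8 8 8]
# Antisymmetry: A_[j, i] == -A_[i, j].
assert np.array_equal(A_.transpose(1, 0, 2), -A_)
```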

In [None]:
A_ * A_

In [None]:
(A_ * A_).sum(axis=2)**0.5

By now, we know that the pattern `(X * X).sum(axis=k)` can be written as a single `einsum`, which also avoids allocating the element-wise product as a separate array.

In [None]:
np.einsum('ijk,ijk->ij', A_, A_)**0.5
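The `einsum` form can be cross-checked against the broadcasting version and against `np.linalg.norm` with `axis=2`, which computes the Euclidean norm of each difference vector:

```python
import numpy as np

A = np.arange(12).reshape(3, 4)
A_ = A[:, None, :] - A

sq = np.einsum('ijk,ijk->ij', A_, A_)   # squared distances, summed over k
print(sq)
# [[  0  64 256]
#  [ 64   0  64]
#  [256  64   0]]

dist = sq ** 0.5
assert np.array_equal(sq, (A_ * A_).sum(axis=2))
assert np.allclose(dist, np.linalg.norm(A_, axis=2))
```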

Now we can do the same for two different sets of vectors, as long as all vectors have the same length (the rows of `A` and `B` both have 4 entries).

In [None]:
C_ = A[:,None,:] - B
C_

In [None]:
np.einsum('ijk,ijk->ij', C_, C_)**0.5
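Putting the steps together gives a small helper for two sets of row vectors. This is a minimal sketch: the name `pairwise_euclidean` is hypothetical, and it assumes both inputs are 2-D with the same number of columns.

```python
import numpy as np

def pairwise_euclidean(X, Y):
    """Hypothetical helper: D[i, j] = ||X[i] - Y[j]|| for X of shape (m, d), Y of shape (n, d)."""
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    if X.ndim != 2 or Y.ndim != 2 or X.shape[1] != Y.shape[1]:
        raise ValueError("expected 2-D arrays with the same number of columns")
    diff = X[:, None, :] - Y                        # shape (m, n, d)
    return np.sqrt(np.einsum('ijk,ijk->ij', diff, diff))

A = np.arange(12).reshape(3, 4)
B = np.arange(16).reshape(4, 4)
D = pairwise_euclidean(A, B)
print(D.shape)    # (3, 4)
```

Because the rows of `A` are exactly the first three rows of `B`, each entry `D[i, j]` here works out to `8 * |i - j|`.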

In [None]:
res = np.zeros((A.shape[0], B.shape[0]))
for i in range(A.shape[0]):
    for j in range(B.shape[0]):
        res[i,j] = ((A[i] - B[j]) @ (A[i] - B[j]))**0.5
res
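For two different sets, SciPy does provide `scipy.spatial.distance.cdist`, the two-input counterpart of `pdist`; it makes a convenient cross-check for both the `einsum` result and the loop above.

```python
import numpy as np
from scipy.spatial.distance import cdist

A = np.arange(12).reshape(3, 4)
B = np.arange(16).reshape(4, 4)

C_ = A[:, None, :] - B
einsum_dist = np.einsum('ijk,ijk->ij', C_, C_) ** 0.5

# cdist(XA, XB) returns a len(XA) x len(XB) matrix; the default metric is Euclidean.
scipy_dist = cdist(A, B)
print(scipy_dist.shape)                        # (3, 4)
print(np.allclose(einsum_dist, scipy_dist))    # True
```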