### Expanding into a full loop

In [None]:
# What does einsum do?
# We can think of it as a big loop over all indices.
# Suppose we have:
# np.einsum("nd,dk->nk", A, B)
# The code for this is:


import numpy as np


# Problem sizes: A is (N, D), B is (D, K), and the result is (N, K).
N, D, K = 10, 20, 30

A = np.random.randn(N, D)
B = np.random.randn(D, K)

# Explicit version of "nd,dk->nk": visit every (n, d, k) index triple and
# accumulate the product into the output slot named by the output indices.
# d is missing from the output, so its contributions are added together.
out = np.zeros(shape=(N, K))
for n, d, k in np.ndindex(N, D, K):
    out[n, k] += A[n, d] * B[d, k]

# Largest absolute deviation from what einsum itself computes.
print(np.abs(out - np.einsum("nd,dk->nk", A, B)).max())

0.0


### Dot Products

In [2]:
# Dot product of two vectors, written two ways with einsum.
A, B = np.arange(4), np.arange(4)

# Way 1: "n,n->n" keeps the index, giving the elementwise products,
# which are then added up outside of einsum.
res1 = np.einsum("n,n->n", A, B).sum(axis=0)
print(res1)

# Way 2: leaving n out of the output makes einsum do the summation itself.
res2 = np.einsum("n,n->", A, B)
res1 - res2

14


np.int64(0)

In [3]:
# The same pattern extends beyond vectors: for two equally shaped matrices,
# summing the elementwise products matches an einsum with an empty output.
A, B = np.random.randn(N, K), np.random.randn(N, K)

print((A * B).sum())
print(np.einsum("nk,nk->", A, B))

-22.57579805941183
-22.57579805941183


### Einsum Interpretation as Cartesian Product

In [4]:
# One way to read an einsum expression:
#   * indices that occur in only one operand are combined like a cartesian product,
#   * indices shared between operands are multiplied pairwise,
#   * every index missing from the output is summed away at the end.
# In other words: broadcast the operands against each other, multiply, then
# sum over the axes that are not listed in the output.
x, w = np.random.randn(N, D), np.random.randn(D, K)

out = np.einsum("nd,dk->nk", x, w)

# Same computation by hand: insert length-1 axes so the operands broadcast
# to (N, D, K), multiply, and reduce over the d axis (axis 1).
x = x[:, :, np.newaxis]  # (N, D, 1)
w = w[np.newaxis, :, :]  # (1, D, K)
out2 = (x * w).sum(axis=1)

print(np.abs(out - out2).max())

0.0


In [5]:
# Axes that are shared by the inputs but left out of the output are summed.
# Here n is such an axis, so reducing it by hand after an einsum that keeps it
# agrees (up to floating-point rounding) with dropping it from the output.
A, B = np.random.randn(4, 3, 7), np.random.randn(4, 7, 2)
res1 = np.einsum("nij,njk->nik", A, B).sum(axis=0)
res2 = np.einsum("nij,njk->ik", A, B)
res1 - res2

array([[-8.8817842e-16, -8.8817842e-16],
       [-8.8817842e-16, -8.8817842e-16],
       [ 0.0000000e+00,  0.0000000e+00]])

In [6]:
# An axis is also summed when it is missing from the output and occurs in
# just one of the inputs. Here n belongs only to A, and omitting it from the
# output matches summing over it by hand (up to floating-point rounding).
A, B = np.random.randn(4, 3, 7), np.random.randn(7, 2)
res1 = np.einsum("nij,jk->nik", A, B).sum(axis=0)
res2 = np.einsum("nij,jk->ik", A, B)
res1 - res2

array([[ 1.11022302e-16,  3.55271368e-15],
       [-5.55111512e-17,  4.44089210e-16],
       [ 2.22044605e-16,  8.88178420e-16]])