In [111]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import math

# Einsum Rules

1. Indices that appear in the inputs but not in the output are summed over, e.g., "ik,kj->ij" means you sum over k. (A repeated index that is kept in the output, like b in "bik,bkj->bij", is not summed.)
2. Elements of the operands are multiplied together before summing, i.e., each term is A_ik * B_kj.
3. The indices after -> define the shape of the output, e.g., "ik,kj->ij" gives an output indexed by i and j, i.e., (num rows in A) x (num cols in B). If -> is omitted, the output uses the indices that appear exactly once, in alphabetical order.
4. Leaving the output empty after -> (e.g., "i,i->") sums over all dims and produces a scalar.
5. Ellipsis (...) can represent multiple dims, e.g., "...ik,...kj->...ij" allows batch matrix multiplication with any number of batch dims.

# MHSA (multi-head self-attention): scaled dot-product attention two ways

In [131]:
class Attn(nn.Module):
    """Multi-head scaled dot-product attention built from matrix products.

    Inputs use the (N, T, H, feature) layout, with H the head axis:
        Q: (N, T, H, E)
        K: (N, T, H, E)
        V: (N, T, H, D)

    ``forward`` returns softmax(Q @ K^T / sqrt(E)) @ V, computed independently
    for every (N, H) slice, in the same (N, T, H, D) layout as the inputs.
    """

    def forward(self, Q, K, V):
        # Unpacking K's shape also enforces that K is 4-D; only E is used.
        _, _, _, key_dim = K.shape

        # Swap sequence and head axes: (N, T, H, *) -> (N, H, T, *).
        q, k, v = (x.transpose(1, 2) for x in (Q, K, V))

        # (N, H, T, E) @ (N, H, E, T) -> (N, H, T, T)
        inv_sqrt_dim = 1 / math.sqrt(key_dim)
        logits = (q @ k.transpose(-2, -1)) * inv_sqrt_dim
        weights = logits.softmax(dim=-1)

        # (N, H, T, T) @ (N, H, T, D) -> (N, H, T, D), then back to (N, T, H, D).
        return (weights @ v).transpose(1, 2)



class EinsumAttn(nn.Module):
    """Multi-head scaled dot-product attention expressed entirely with ``torch.einsum``.

    The head/sequence axis swaps are folded into the einsum subscripts, so the
    inputs are never permuted explicitly and the output comes out directly in
    the input layout.
    """

    def forward(self, Q, K, V):
        """
        Q: (N, T, H, E)
        K: (N, T, H, E)
        V: (N, T, H, D)

        Computes the multi-head scaled dot product attention operation on the input

        softmax((Q @ K.T) / sqrt(E)) @ V

        and returns it in the input layout, (N, T, H, D).

        Einsum subscripts: n = N axis, h = head, q = query position,
        k = key/value position, e = query/key feature, d = value feature.
        """
        # Unpacking K's shape also enforces that K is 4-D; only E is needed.
        _, _, _, E = K.shape
        scale = 1.0 / math.sqrt(E)

        # Q @ K.T per head, contracting over e:
        # (N, T_q, H, E) x (N, T_k, H, E) -> (N, H, T_q, T_k)
        logits = torch.einsum("nqhe,nkhe->nhqk", Q, K) * scale
        scores = F.softmax(logits, dim=-1)

        # Weighted sum of values, contracting over k and emitting the input layout:
        # (N, H, T_q, T_k) x (N, T_k, H, D) -> (N, T_q, H, D)
        return torch.einsum("nhqk,nkhd->nqhd", scores, V)

In [133]:
# Seed the global RNG so the random inputs (and everything computed from them)
# are identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Dimensions in the (N, T, H, *) layout used by Attn / EinsumAttn:
# H = number of heads, E = query/key feature size, D = value feature size.
N = 2
T = 3
H = 4
E = 5
D = 6

Q = torch.rand(N, T, H, E)
K = torch.rand(N, T, H, E)
V = torch.rand(N, T, H, D)

In [134]:
# SDPA works on (N, H, T, *), so swap the sequence/head axes going in and coming out.
torch_out = F.scaled_dot_product_attention(
    *(x.transpose(1, 2) for x in (Q, K, V))
).transpose(1, 2)

In [135]:
ein_att = EinsumAttn()
# nn.Module.__call__ forwards keyword arguments straight to forward().
ein_out = ein_att(Q=Q, K=K, V=V)

In [136]:
attn = Attn()
# nn.Module.__call__ forwards keyword arguments straight to forward().
attn_out = attn(Q=Q, K=K, V=V)

In [137]:
# torch.allclose broadcasts its inputs, so pin the shapes explicitly first;
# the asserts make Restart & Run All fail loudly on a mismatch.
assert ein_out.shape == torch_out.shape, (ein_out.shape, torch_out.shape)
ein_matches = torch.allclose(torch_out, ein_out)
assert ein_matches, "EinsumAttn disagrees with F.scaled_dot_product_attention"
ein_matches

True

In [138]:
# torch.allclose broadcasts its inputs, so pin the shapes explicitly first;
# the asserts make Restart & Run All fail loudly on a mismatch.
assert attn_out.shape == torch_out.shape, (attn_out.shape, torch_out.shape)
attn_matches = torch.allclose(attn_out, torch_out)
assert attn_matches, "Attn disagrees with F.scaled_dot_product_attention"
attn_matches

True

# Practice

In [72]:
# Define A in this cell so it runs on a fresh kernel; previously A only existed
# because a cell further down the notebook had been executed earlier.
A = torch.rand(3, 4)
A.T

tensor([[0.7139, 0.4191, 0.1453],
        [0.2913, 0.5299, 0.3287],
        [0.9360, 0.7446, 0.2974],
        [0.6496, 0.7781, 0.2169]])

In [74]:
A @ A.T

tensor([[1.8925, 1.6559, 0.6187],
        [1.6559, 1.6163, 0.6253],
        [0.6187, 0.6253, 0.2647]])

In [78]:
torch.einsum("ij,kj->ik", A, A)

tensor([[1.8925, 1.6559, 0.6187],
        [1.6559, 1.6163, 0.6253],
        [0.6187, 0.6253, 0.2647]])

### Matrix multiplication

In [9]:
A = torch.rand((3, 4))

In [11]:
B = torch.rand((4, 3))

In [12]:
A @ B

tensor([[0.7445, 0.6853, 0.3380],
        [0.7750, 0.9510, 0.8061],
        [1.4233, 0.9981, 0.8949]])

In [13]:
torch.einsum("ik,kj->ij", A, B)

tensor([[0.7445, 0.6853, 0.3380],
        [0.7750, 0.9510, 0.8061],
        [1.4233, 0.9981, 0.8949]])

### Batch matrix multiplication

In [15]:
A = torch.rand((10, 3, 4))

In [16]:
B = torch.rand((10, 4, 5))

In [18]:
Cmm = A @ B

In [19]:
Cein = torch.einsum("bik,bkj->bij", A, B)

In [23]:
# einsum and matmul may take different kernels and round differently, so compare
# with a tolerance; allclose broadcasts, hence the explicit shape check.
Cein.shape == Cmm.shape and torch.allclose(Cein, Cmm)

True

### Dot product

In [24]:
a = torch.rand((10,))

In [25]:
b = torch.rand((10,))

In [26]:
a.dot(b)

tensor(2.5617)

In [27]:
torch.einsum("i,i->", a, b)

tensor(2.5617)

### Outer Product

In [31]:
outer = a.outer(b)

In [32]:
outer_ein = torch.einsum("i,j->ij", a, b)

In [33]:
outer.equal(outer_ein)

True

### Transpose

In [34]:
A = torch.rand((3, 4))

In [36]:
A_transpose = A.t()

In [41]:
A_transpose_ein = torch.einsum("ij->ji", A)

In [42]:
A_transpose.equal(A_transpose_ein)

True