In [1]:
import torch
from einops import rearrange, reduce, repeat
from torch import einsum, nn

# einops

## Re-arrange

This is the same as `permute`: it reorders (moves) axes.

In [2]:
# Channels-first (b, c, h, w) -> channels-last (b, h, w, c);
# the same reordering as x.permute(0, 2, 3, 1).
x = torch.rand((4, 3, 1080, 1920))
print(x.size())
x = rearrange(x, 'b c h w -> b h w c')
print(x.size())

torch.Size([4, 3, 1080, 1920])
torch.Size([4, 1080, 1920, 3])


This is the same as `view`/`reshape`: I can merge (collapse) several axes into one, or split one axis into several.

In [3]:
# Flatten height and width into a single spatial axis of length h * w
# (the same result as x.reshape(4, 3, -1) here).
x = torch.rand((4, 3, 1080, 1920))
print(x.size())
x = rearrange(x, 'b c h w -> b c (h w)')
print(x.size())

torch.Size([4, 3, 1080, 1920])
torch.Size([4, 3, 2073600])


In [4]:
# Split the height axis into (h1, h2) with h2 = 8; einops infers h1 = 1080 // 8 = 135.
x = torch.rand((4, 3, 1080, 1920))
print(x.size())
x = rearrange(x, 'b c (h1 h2) w -> b c h1 h2 w', h2=8)
print(x.size())

torch.Size([4, 3, 1080, 1920])
torch.Size([4, 3, 135, 8, 1920])


In [5]:
x = torch.rand((16, 16))
print(x.size())
# Factor each length-16 axis into 4 x 4; only h1 and h2 are given, so einops
# infers w1 = 16 // 4 and w2 = 16 // 4.
x = rearrange(x, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=4, h2=4)
print(x.size())

torch.Size([16, 16])
torch.Size([4, 4, 4, 4])


## Repeat

`repeat` is like `torch.expand` (or `Tensor.repeat`): it adds new axes or tiles existing ones along an axis.

In [6]:
# Insert a brand-new axis of length 2 (named `repeat` in the pattern) after the batch
# axis, and move channels last in the same call.
x = torch.rand((4, 3, 1080, 1920))
print(x.size())
x = repeat(x, 'b c h w -> b repeat h w c', repeat=2)
print(x.size())

torch.Size([4, 3, 1080, 1920])
torch.Size([4, 2, 1080, 1920, 3])


In [7]:
# `(2 h)` places the new length-2 axis *outside* h, so each image is tiled 2 x 2;
# writing `(h 2)` instead would duplicate every pixel (nearest-neighbour upsampling).
x = torch.rand((4, 3, 1080, 1920))
print(x.size())
x = repeat(x, 'b c h w -> b (2 h) (2 w) c')
print(x.size())

torch.Size([4, 3, 1080, 1920])
torch.Size([4, 2160, 3840, 3])


## Reduce

`reduce` is the opposite of `repeat`: it removes axes by aggregating over them (`mean`, `max`, `min`, ...).

In [8]:
# Average over both spatial axes (same as x.mean(dim=(2, 3))). torch.rand samples
# uniformly from [0, 1), so every entry lands close to 0.5.
x = torch.rand((4, 3, 1080, 1920))
print(x.size())
x = reduce(x, 'b c h w -> b c', 'mean')
print(x.size())
x

torch.Size([4, 3, 1080, 1920])
torch.Size([4, 3])


tensor([[0.5002, 0.5004, 0.4998],
        [0.5001, 0.5000, 0.5000],
        [0.5000, 0.4997, 0.5000],
        [0.4998, 0.4997, 0.5000]])

In [9]:
# Maximum over both spatial axes. With ~2M uniform samples per (b, c), the max is
# within ~1e-6 of 1, so it displays as 1.0000.
x = torch.rand((4, 3, 1080, 1920))
print(x.size())
x = reduce(x, 'b c h w -> b c', 'max')
print(x.size())
x

torch.Size([4, 3, 1080, 1920])
torch.Size([4, 3])


tensor([[1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000]])

In [10]:
# Minimum over both spatial axes; for uniform [0, 1) samples this is essentially 0.
x = torch.rand((4, 3, 1080, 1920))
print(x.size())
x = reduce(x, 'b c h w -> b c', 'min')
print(x.size())
x

torch.Size([4, 3, 1080, 1920])
torch.Size([4, 3])


tensor([[1.2517e-06, 1.9670e-06, 1.1921e-07],
        [1.1921e-07, 2.3842e-07, 2.3842e-07],
        [0.0000e+00, 1.7881e-07, 3.5763e-07],
        [1.0133e-06, 2.3842e-07, 0.0000e+00]])

In [11]:
# Reduce over the channel axis instead (same as x.mean(dim=1)), keeping the spatial grid.
x = torch.rand((4, 3, 1080, 1920))
print(x.size())
x = reduce(x, 'b c h w -> b h w', 'mean')
print(x.size())

torch.Size([4, 3, 1080, 1920])
torch.Size([4, 1080, 1920])


## Concrete Example with Attention Head

Suppose I have an 8 × 8 grid of patches, each with an embedding vector. I will project them into multi-head Q, K and V, and then group the patches into non-overlapping 4 × 4 windows.

In [12]:
embed_dim = 32
head_dim = 64
num_heads = 8
window_size = 4

# Patch tensor with shape (batch, height, width, embedding dim). The last axis must match
# the projection's input size, so it is tied to embed_dim rather than a literal 32.
x = torch.rand(4, 8, 8, embed_dim)
# One projection produces Q, K and V for all heads at once: 3 * num_heads * head_dim features.
to_qkv = nn.Linear(embed_dim, 3 * num_heads * head_dim)

B, H, W, _ = x.shape
# The `(num_win_h win_h)` split below only works if the patch grid tiles exactly into windows.
assert H % window_size == 0 and W % window_size == 0, 'patch grid must be divisible by window_size'

# Split the projected features into three equal chunks: Q, K, V.
qkv = to_qkv(x).chunk(3, dim=-1)
print([t.shape for t in qkv])

# Number of windows along each spatial axis (each window is window_size x window_size patches).
num_windows_h = H // window_size
num_windows_w = W // window_size

for t in qkv:
    # In the pattern, `h` is the *head* axis (not the image height H): `(h d)` splits the
    # feature axis into num_heads heads of head_dim features each.
    t_prime = rearrange(t, 'b (num_win_h win_h) (num_win_w win_w) (h d) -> b h (num_win_h num_win_w) (win_h win_w) d',
                        h=num_heads, d=head_dim, win_h=window_size, win_w=window_size)
    assert t_prime.shape == (B, num_heads, num_windows_h * num_windows_w, window_size * window_size, head_dim)
    print('(batch_size, num_heads, num_windows, num_patches, head_dim)', t_prime.shape)

[torch.Size([4, 8, 8, 512]), torch.Size([4, 8, 8, 512]), torch.Size([4, 8, 8, 512])]
(batch_size, num_heads, num_windows, num_patches, head_dim) torch.Size([4, 8, 4, 16, 64])
(batch_size, num_heads, num_windows, num_patches, head_dim) torch.Size([4, 8, 4, 16, 64])
(batch_size, num_heads, num_windows, num_patches, head_dim) torch.Size([4, 8, 4, 16, 64])


# einsum

Reuse the shapes from the attention example above. I will use a dot product (`einsum`) to compute the attention scores between every pair of patches within a window.

In [13]:
q = torch.randn(4, 8, 4, 16, 64)
k = torch.randn(4, 8, 4, 16, 64)
v = torch.randn(4, 8, 4, 16, 64)  # not used in this cell

# Axes: b = batch, h = heads, w = windows, i / j = query / key patch index, d = head_dim.
# `d` appears in both inputs but not in the output, so einsum sums over it (64 terms):
# every query patch i is dotted with every key patch j, giving a 16 x 16 score matrix
# per (batch, head, window).
dot_product = einsum('b h w i d, b h w j d -> b h w i j', q, k)
# Same result as a batched matmul against the transposed keys.
assert torch.allclose(dot_product, q @ k.transpose(-2, -1), atol=1e-5)
print(dot_product.shape)

torch.Size([4, 8, 4, 16, 16])


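More dot-product examples. The first one sums the outer product of each pair of rows, i.e. one scalar per row `n` (here 64 × 64 = 4096 for all-ones inputs):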

In [14]:
# For each row n: sum over all (i, j) of A[n, i] * B[n, j], which equals A[n].sum() * B[n].sum().
A, B = torch.ones(16, 64), torch.ones(16, 64)
einsum('n i, n j -> n', A, B)

tensor([4096., 4096., 4096., 4096., 4096., 4096., 4096., 4096., 4096., 4096.,
        4096., 4096., 4096., 4096., 4096., 4096.])

In [15]:
# Batched A @ B^T: the shared feature axis d is contracted, while i and j are kept -> (n, i, j).
A, B = torch.ones(4, 16, 64), torch.ones(4, 16, 64)
einsum('n i d, n j d -> n i j', A, B).shape

torch.Size([4, 16, 16])

Batched matrix multiplication (same as `torch.bmm`)

In [16]:
# (b, i, j) x (b, j, k) -> (b, i, k): j is contracted per batch element, like torch.bmm(A, B).
A = torch.rand((4, 2, 5))
B = torch.rand((4, 5, 4))
einsum('b i j, b j k -> b i k', A, B).shape

torch.Size([4, 2, 4])