## References

- [einops documentation](https://einops.rocks)
- [einsum is all you need: Aladdin Persson video](https://www.youtube.com/watch?v=pkVwUVEHmfI)
- [einsum is all you need: Tim Rocktäschel blog](https://rockt.ai/2018/04/30/einsum)

## Einsum notation

In einsum notation, each element of the output matrix $C$ (shape [i, k]), obtained by multiplying matrix $A$ (shape [i, j]) by matrix $B$ (shape [j, k]), is defined as:

$$
C_{i,k} = \sum_{j} A_{i,j} B_{j,k}
$$

In code (einops), the entire matrix multiplication yielding the full matrix $C$ can be written as:

```python
c = einops.einsum(a, b, "i j, j k -> i k")
```

Here, the matrices to operate on are the first arguments (`a, b`), and the einsum string (`"i j, j k -> i k"`) is the last argument. The einsum string lists each input's dimensions as space-separated names (`"i j"` for `a`), one name per dimension (`"i"` for `a`'s rows and `"j"` for `a`'s columns). The dimension lists of the inputs are comma-separated, and the arrow `->` is followed by the space-separated dimensions of the output matrix (`"i k"`).

Key points to remember about the einsum string:

- Shared dimensions in the input matrices (`"j"` in the example above) are matched up: the corresponding elements along these dimensions are multiplied together, so a shared dimension must have the same size in every input where it appears.

- Any dimensions not specified in the output matrix (in the example above, `"j"`) are summed over.

- The output axes can be returned in any order (e.g. in the example above, we could have done `"k i"` instead of `"i k"`).
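
The three rules above can be spelled out with plain Python loops. The sketch below only illustrates what the pattern `"i j, j k -> i k"` means; it is not how einops computes the result, and the helper name `matmul_by_hand` is hypothetical.

```python
def matmul_by_hand(a, b, transpose_output=False):
    """Spell out "i j, j k -> i k" (or "-> k i") with explicit loops."""
    n_i, n_j, n_k = len(a), len(b), len(b[0])
    # Rule 1: the shared index j must have the same size in both inputs.
    assert all(len(row) == n_j for row in a)

    out = [[0] * n_k for _ in range(n_i)]
    for i in range(n_i):
        for k in range(n_k):
            # Rule 2: j is absent from the output, so it is summed over.
            out[i][k] = sum(a[i][j] * b[j][k] for j in range(n_j))

    if transpose_output:
        # Rule 3: "-> k i" returns the same values with the axes swapped.
        out = [[out[i][k] for i in range(n_i)] for k in range(n_k)]
    return out


a = [[1, 2, 3], [4, 5, 6]]
b = [[1, 2], [3, 4], [5, 6]]
c = matmul_by_hand(a, b)                           # "i j, j k -> i k"
c_t = matmul_by_hand(a, b, transpose_output=True)  # "i j, j k -> k i"
print(c)    # [[22, 28], [49, 64]]
print(c_t)  # [[22, 49], [28, 64]]
```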

In [1]:
"""Imports."""

import torch
from einops import asnumpy, einsum, rearrange, reduce, repeat, pack, parse_shape, unpack
from einops.layers.torch import Rearrange, Reduce

In [2]:
"""Recreate above example in code."""

# Following the example above, let i=2, j=3, k=2
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([[1, 2], [3, 4], [5, 6]])
c = einsum(a, b, "i j, j k -> i k")
print(c)

tensor([[22, 28],
        [49, 64]])


## Matrix operations using `einsum`, `rearrange`, `reduce`, and `repeat`

### Permutations and rearrangements

In [3]:
"""Create a dataset representing images."""

x = torch.randn(32, 3, 224, 224)  # assume this is a batch of 32 RGB images of size 224x224
x -= x.min()  # shift so that all pixel values are non-negative
print(x.shape)

torch.Size([32, 3, 224, 224])


In [4]:
"""Simple permutations."""

# Move the channel axis to the end
print(rearrange(x, "b c x y -> b x y c").shape)

# We can also do this with einsum
print(einsum(x, "b c x y -> b x y c").shape)


torch.Size([32, 224, 224, 3])
torch.Size([32, 224, 224, 3])


In [5]:
"""Do more complex rearrangements (e.g. flattens or splits)."""

# Flatten the image dimensions
print(rearrange(x, "b c x y -> b (c x y)").shape)  

# Split each image into 4 quadrants and reconstruct the batch (should be 4x larger)
print(rearrange(x, "b c (x2 x) (y2 y) -> (b x2 y2) c x y", x2=2, y2=2).shape)

torch.Size([32, 150528])
torch.Size([128, 3, 112, 112])
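
To see what the quadrant pattern does, it can help to write the same split with explicit reshapes and transposes. This is a sketch in NumPy (used here instead of torch/einops, on a small hypothetical array) of what `"b c (x2 x) (y2 y) -> (b x2 y2) c x y"` expresses.

```python
import numpy as np

b, c, h, w = 2, 3, 4, 6  # small stand-in for (32, 3, 224, 224)
x = np.arange(b * c * h * w).reshape(b, c, h, w)

# "(x2 x)" splits dim 2 into x2=2 blocks of size h // 2; "(y2 y)" does the same
# for dim 3. The first name inside each parenthesis is the outer (block) index.
split = x.reshape(b, c, 2, h // 2, 2, w // 2)  # b c x2 x y2 y

# "-> (b x2 y2) c x y": move the block indices next to b, then merge them.
quads = split.transpose(0, 2, 4, 1, 3, 5)      # b x2 y2 c x y
quads = quads.reshape(b * 2 * 2, c, h // 2, w // 2)

print(quads.shape)  # (8, 3, 2, 3)

# The quadrants of image 0 come first, ordered (x2, y2) = (0,0), (0,1), (1,0), (1,1).
first_ok = np.array_equal(quads[0], x[0, :, : h // 2, : w // 2])
second_ok = np.array_equal(quads[1], x[0, :, : h // 2, w // 2 :])
print(first_ok, second_ok)  # True True
```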


### Matrix multiplications

Einsum handles the dimension matching for us, as long as the pattern names the dimensions correctly.

In [6]:
"""Standard matrix multiplications."""

a = torch.tensor([[0, 1, 2], [3, 4, 5]])
b = torch.tensor([[1, 2], [3, 4], [5, 6]])
print(a)
print(b)
print()

# Rows of `a` by columns of `b` (same as the initial example).
# 'j' is shared and represents dim2 of `a` and dim1 of `b`, so we take the values along dim2 of
# `a` (i.e. along each row) and along dim1 of `b` (i.e. down each column) and multiply them
# together; since 'j' is omitted from the output, we then sum over it.
print(einsum(a, b, "i j, j k -> i k"), "\n")

# Similarly, we can multiply the columns of `a` by the rows of `b` (without explicitly transposing)
print(einsum(a, b, "i j, k i -> j k"), "\n")

tensor([[0, 1, 2],
        [3, 4, 5]])
tensor([[1, 2],
        [3, 4],
        [5, 6]])

tensor([[13, 16],
        [40, 52]]) 

tensor([[ 6, 12, 18],
        [ 9, 19, 29],
        [12, 26, 40]]) 
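
As a sanity check, both patterns can be compared against ordinary matrix products. The sketch below uses NumPy's `np.einsum` (which takes compact subscripts such as `"ij,jk->ik"` rather than einops' space-separated names) as a stand-in for the einops calls above.

```python
import numpy as np

a = np.array([[0, 1, 2], [3, 4, 5]])
b = np.array([[1, 2], [3, 4], [5, 6]])

# "i j, j k -> i k" is the plain matrix product a @ b.
rows_by_cols = np.einsum("ij,jk->ik", a, b)
assert np.array_equal(rows_by_cols, a @ b)

# "i j, k i -> j k" sums over i, the first axis of `a` and the second axis of `b`,
# so it equals a.T @ b.T without writing either transpose by hand.
cols_by_rows = np.einsum("ij,ki->jk", a, b)
assert np.array_equal(cols_by_rows, a.T @ b.T)

print(rows_by_cols.tolist())  # [[13, 16], [40, 52]]
print(cols_by_rows.tolist())  # [[6, 12, 18], [9, 19, 29], [12, 26, 40]]
```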



In [7]:
"""Element-wise multiplications."""

# We can take the Hadamard (element-wise) product of `a` with the transpose of `b` and return
# the output in either the shape of `a` or `b`
print(einsum(a, b, "i j, j i -> i j"), "\n")  # shape of `a`
print(einsum(a, b, "i j, j i -> j i"), "\n")  # shape of `b`

# We can take the Hadamard product of the flattened tensors
print(einsum(a.flatten(), b.flatten(), "n, n -> n"), "\n")

# We can decide to multiply each element in `a` by each element in `b` (by writing all input
# dimensions independently), and choose which dimensions to sum over.

# Sum over all dimensions
print(einsum(a, b, "i j, k l -> "), "\n")  

# Sum over no dimensions (each element-wise multiplication is a separate element in the output)
print(einsum(a, b, "i j, k l -> i j k l").shape, "\n")
# print(einsum(a, b, "i j, k l -> i j k l"), "\n")

# Sum over all but the first dimension of the second matrix
print(einsum(a, b, "i j, k l -> k"), "\n")

# Sum over the two dimensions of the first matrix
print(einsum(a, b, "i j, k l -> k l"), "\n")

# Sum over the two dimensions of the second matrix
print(einsum(a, b, "i j, k l -> i j"), "\n")

tensor([[ 0,  3, 10],
        [ 6, 16, 30]]) 

tensor([[ 0,  6],
        [ 3, 16],
        [10, 30]]) 

tensor([ 0,  2,  6, 12, 20, 30]) 

tensor(315) 

torch.Size([2, 3, 3, 2]) 

tensor([ 45, 105, 165]) 

tensor([[15, 30],
        [45, 60],
        [75, 90]]) 

tensor([[  0,  21,  42],
        [ 63,  84, 105]]) 
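
When the two inputs share no index, every result is a partial sum of the full outer product, so the sums factorise. A short NumPy check of the values printed above (again with `np.einsum` standing in for einops):

```python
import numpy as np

a = np.array([[0, 1, 2], [3, 4, 5]])
b = np.array([[1, 2], [3, 4], [5, 6]])

# Full outer product: one entry per (i, j, k, l) combination.
outer = np.einsum("ij,kl->ijkl", a, b)
assert outer.shape == (2, 3, 3, 2)

# Summing over everything factorises into sum(a) * sum(b) = 15 * 21.
total = int(np.einsum("ij,kl->", a, b))
assert total == int(a.sum()) * int(b.sum()) == 315

# Keeping only b's axes scales b by sum(a); keeping only a's axes scales a by sum(b).
assert np.array_equal(np.einsum("ij,kl->kl", a, b), a.sum() * b)
assert np.array_equal(np.einsum("ij,kl->ij", a, b), b.sum() * a)

# Keeping only k gives sum(a) times the row sums of b.
per_k = np.einsum("ij,kl->k", a, b)
print(per_k.tolist())  # [45, 105, 165]
```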



In [8]:
"""Multiplications of matrices with 3+ dimensions and 2+ shared dimensions."""

# Imagine we have a toy nn model object that holds 2 model instances and that we believe can
# represent 4 features in just 3 neurons; we feed into it a batch of feature values.

batch_sz, n_model_instances, n_feat, n_hidden = 2, 2, 4, 3

model_weights = torch.tensor(
    [
        [
            [0, 1, 2],
            [1, 2, 3],
            [2, 3, 4],
            [3, 4, 5]
        ],
        [
            [3, 4, 5],
            [2, 3, 4],
            [1, 2, 3],
            [0, 1, 2]
        ]
    ]
)

feat_vals = torch.tensor(
    [
        [
            [0, 1, 2, 3],
            [1, 2, 3, 4]
        ],
        [
            [1, 2, 3, 4],
            [0, 1, 2, 3]
        ]
    ]
)

print(f"model_weights: {parse_shape(model_weights, 'n_model_instances n_feat n_hidden')}")
print(f"feat_vals: {parse_shape(feat_vals, 'batch_sz n_model_instances n_feat')}")

# We want to get the activations for each of the 3 neurons, for each of the 2 model instances,
# for each of the 2 examples in the batch: we multiply the feature values by the model weights,
# matching on 'n_model_instances' and 'n_feat', and summing over 'n_feat' (each neuron combines 
# info from all features).
acts = einsum(
    feat_vals, 
    model_weights, 
    "batch model_i feat, model_i feat hidden -> batch model_i hidden"
)
print(f"acts: {parse_shape(acts, 'batch_sz n_model_instances n_hidden')} \n")
print(acts)

model_weights: {'n_model_instances': 2, 'n_feat': 4, 'n_hidden': 3}
feat_vals: {'batch_sz': 2, 'n_model_instances': 2, 'n_feat': 4}
acts: {'batch_sz': 2, 'n_model_instances': 2, 'n_hidden': 3} 

tensor([[[14, 20, 26],
         [10, 20, 30]],

        [[20, 30, 40],
         [ 4, 10, 16]]])
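
The pattern above is a batched matrix product: `model_i` appears in both inputs and in the output, so it is matched but not summed, whereas `feat` is matched and then summed. A NumPy sketch of the same computation as an explicit loop over model instances (values copied from the cell above; `acts_loop` and `acts_einsum` are illustrative names):

```python
import numpy as np

model_weights = np.array([
    [[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5]],
    [[3, 4, 5], [2, 3, 4], [1, 2, 3], [0, 1, 2]],
])  # model_i feat hidden
feat_vals = np.array([
    [[0, 1, 2, 3], [1, 2, 3, 4]],
    [[1, 2, 3, 4], [0, 1, 2, 3]],
])  # batch model_i feat

acts_loop = np.empty((2, 2, 3), dtype=model_weights.dtype)
for m in range(2):
    # For one model instance: (batch, feat) @ (feat, hidden) -> (batch, hidden).
    acts_loop[:, m, :] = feat_vals[:, m, :] @ model_weights[m]

acts_einsum = np.einsum("bmf,mfh->bmh", feat_vals, model_weights)
assert np.array_equal(acts_loop, acts_einsum)
print(acts_einsum.tolist())
# [[[14, 20, 26], [10, 20, 30]], [[20, 30, 40], [4, 10, 16]]]
```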


### Operations (reductions) over dimensions

In [9]:
"""Create a dataset representing images."""

x = torch.randn(32, 3, 224, 224)  # assume this is a batch of 32 RGB images of size 224x224
x -= x.min()  # shift so that all pixel values are non-negative
print(x.shape)

torch.Size([32, 3, 224, 224])


In [10]:
"""Sum over dimensions."""

# Sum over all dimensions
print(reduce(x, "b c x y ->", "sum"))
# Sum for each channel
print(reduce(x, "b c x y -> c", "sum"), "\n")

# We can also do these with einsum
print(einsum(x, "b c x y ->"))
print(einsum(x, "b c x y -> c"), "\n")


tensor(25579796.)
tensor([8527487., 8525645., 8526663.]) 

tensor(25579796.)
tensor([8527487., 8525645., 8526663.]) 



In [11]:
"""Perform more complex operations (e.g. mean, min, max, prod) over dimensions."""

# Mean for each channel
print(reduce(x, "b c x y -> c", "mean"))

# Variance for each channel (the reduction can also be a callable)
print(reduce(x, "b c x y -> c", torch.var))

# Max at each of the 224x224 pixel positions (taken over batch and channels)
print(reduce(x, "b c x y -> x y", "max").shape)

tensor([5.3110, 5.3098, 5.3105])
tensor([0.9999, 1.0011, 0.9993])
torch.Size([224, 224])
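
Each `reduce` pattern names the axes to keep; every axis left out of the output is reduced. A NumPy sketch of the equivalent `axis` arguments, on a small hypothetical array rather than the image batch above:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3, 8, 8))  # b c x y

# "b c x y -> c": keep axis 1, reduce axes 0, 2 and 3.
per_channel_mean = x.mean(axis=(0, 2, 3))
per_channel_var = x.var(axis=(0, 2, 3), ddof=1)  # ddof=1 matches torch.var's unbiased default

# "b c x y -> x y": keep the two spatial axes, reduce over batch and channel.
per_pixel_max = x.max(axis=(0, 1))

print(per_channel_mean.shape, per_channel_var.shape, per_pixel_max.shape)
# (3,) (3,) (8, 8)
```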


### Other common matrix operations

In [17]:
"""Working with matrix diagonals."""

m = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# Get the diagonal.
print(einsum(m, "i i -> i"))

# Get the trace.
print(einsum(m, "i i -> "))

tensor([1, 5, 9])
tensor(15)
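
A repeated index on the input side selects only the entries where both positions agree, which is why `"i i -> i"` yields the diagonal and `"i i ->"` its sum. A quick NumPy cross-check against the dedicated functions (with `np.einsum` standing in for einops):

```python
import numpy as np

m = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

diag = np.einsum("ii->i", m)   # entries m[i, i]
trace = np.einsum("ii->", m)   # sum of m[i, i]
assert np.array_equal(diag, np.diag(m))
assert trace == np.trace(m)
print(diag.tolist(), int(trace))  # [1, 5, 9] 15
```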


### Repeat a tensor along a new axis or existing axis

In [None]:
"""Along a new axis: add different noise to copies of an image"""

x = torch.randn(3, 224, 224)  # assume this is an RGB image of size 224x224
x -= x.min()  # shift so that all pixel values are non-negative

# Add different noise to each copy of the image
x_b = repeat(x, "c x y -> b c x y", b=32)
noise = torch.randn_like(x_b) * 0.1
x_b = x_b + noise
print(x_b.shape)

torch.Size([32, 3, 224, 224])


In [30]:
"""Along an existing axis: elongate width or height"""

x_wide = repeat(x, "c x y -> c (w x) y", w=2)
print(x_wide.shape)

x_tall = repeat(x, "c x y -> c x (2 y)")  # a literal number can also go directly in the pattern
print(x_tall.shape)

torch.Size([3, 448, 224])
torch.Size([3, 224, 448])
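
The order of names inside the parentheses matters: the first name is the outer index. A NumPy sketch of the difference on a small hypothetical array (`np.broadcast_to`, `np.tile` and `np.repeat` are used here as rough equivalents of the einops patterns, not as what einops calls internally):

```python
import numpy as np

x = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # c x y

# "c x y -> b c x y": a new leading axis holding b identical copies.
x_b = np.broadcast_to(x, (5, *x.shape))
assert x_b.shape == (5, 2, 3, 4)

# "c x y -> c (w x) y" with w=2: the copy index is outermost, so whole copies
# are laid end to end along the axis (like np.tile).
tiled = np.tile(x, (1, 2, 1))
assert np.array_equal(tiled[:, :3], x) and np.array_equal(tiled[:, 3:], x)

# "c x y -> c (x w) y" would instead put the copy index innermost, duplicating
# each slice in place (like np.repeat).
interleaved = np.repeat(x, 2, axis=1)
assert np.array_equal(interleaved[:, 0::2], x)
assert np.array_equal(interleaved[:, 1::2], x)

print(tiled.shape, interleaved.shape)  # (2, 6, 4) (2, 6, 4)
```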
