# Toy example

In [10]:
import torch
from einops import rearrange, einsum
import einx
import numpy as np

In [None]:
# Problem sizes shared by the toy examples below.
batch, sequence = 4, 100  # number of sequences, tokens per sequence
d_in = d_out = 64         # input / output feature width of the linear map

In [7]:
# Inputs for a linear map: D holds activations laid out as (batch, sequence, d_in),
# A is the weight matrix laid out as (d_out, d_in).
D = torch.randn(batch, sequence, d_in)
A = torch.randn(d_out, d_in)

## Basic implementation
Y1 = D @ A.T
# Hard to tell the input and output shapes and what they mean.
# What shapes can D and A have, and do any of these have unexpected behavior?

## Einsum is self-documenting and robust
# D A -> Y
Y2 = einsum(D, A, "batch sequence d_in, d_out d_in -> batch sequence d_out")
## Or, a batched version where D can have any leading dimensions but A is constrained.
Y3 = einsum(D, A, "... d_in, d_out d_in -> ... d_out")

# All three expressions contract over d_in, but they are free to dispatch to
# different kernels with different float32 accumulation order. Compare with
# torch.testing.assert_close rather than `assert torch.allclose(...)`: its
# default tolerances are chosen per dtype (the allclose default atol=1e-8 is
# too tight for float32 outputs near zero), it also checks shape and dtype, it
# reports where the tensors differ, and it is not stripped when Python runs
# with -O.
torch.testing.assert_close(Y2, Y1)
torch.testing.assert_close(Y3, Y1)

In [9]:
# Scale every image in a batch by each of ten brightness factors, producing a
# result laid out as (batch, dim_value, height, width, channel).
dim_by = torch.linspace(0.0, 1.0, 10)  # ten evenly spaced factors, 0.0 through 1.0
images = torch.randn(64, 128, 128, 3)  # (batch, height, width, channel)

## Approach 1: insert the singleton axes by hand and let broadcasting do the rest
images_rearr = rearrange(images, "b height width channel -> b 1 height width channel")
dim_value = rearrange(dim_by, "dim_value -> 1 dim_value 1 1 1")
dimmed_images1 = dim_value * images_rearr

## Approach 2: a single einsum whose output pattern names the new axis directly
dimmed_images2 = einsum(
    images,
    dim_by,
    "batch height width channel, dim_value -> batch dim_value height width channel",
)

# Both approaches must yield the same tensor.
assert torch.allclose(dimmed_images1, dimmed_images2)

In [None]:
# Suppose we have a batch of images represented as a tensor of shape (batch, height, width,
# channel), and we want to perform a linear transformation across all pixels of the image, but this
# transformation should happen independently for each channel. Our linear transformation is
# represented as a matrix B of shape (height × width, height × width).
#
# The same computation is written three ways below (plain torch, einops, einx). Each result is
# kept under its own name and the three are checked against each other, so a mistake in any one
# formulation is caught instead of being silently overwritten by the next.
height = width = 32  # defined once; used by the tensor constructors and by both einops/einx calls
channels_last = torch.randn(64, height, width, 3)  # (batch, height, width, channel)
B = torch.randn(height * width, height * width)  # (pixel_out, pixel_in)

## Plain torch: rearrange an image tensor for mixing across all pixels
channels_last_flat = channels_last.view(
    -1, channels_last.size(1) * channels_last.size(2), channels_last.size(3)
)  # (batch, height * width, channel)
channels_first_flat = channels_last_flat.transpose(1, 2)  # (batch, channel, pixel_in)
channels_first_flat_transformed = channels_first_flat @ B.T  # (batch, channel, pixel_out)
channels_last_flat_transformed = channels_first_flat_transformed.transpose(1, 2)
# The transposed tensor is not contiguous. reshape returns a view when the strides allow it
# and copies otherwise, so unlike view it cannot fail on the memory layout.
channels_last_transformed_torch = channels_last_flat_transformed.reshape(channels_last.shape)

# Instead, using einops:
## Rearrange replaces clunky torch view + transpose
channels_first = rearrange(
    channels_last,
    "batch height width channel -> batch channel (height width)"
)
channels_first_transformed = einsum(
    channels_first, B,
    "batch channel pixel_in, pixel_out pixel_in -> batch channel pixel_out"
)
channels_last_transformed_einops = rearrange(
    channels_first_transformed,
    "batch channel (height width) -> batch height width channel",
    height=height, width=width
)

# Or, if you’re feeling crazy: all in one go using einx.dot (einx equivalent of einops.einsum)
channels_last_transformed_einx = einx.dot(
    "batch row_in col_in channel, (row_out col_out) (row_in col_in)"
    "-> batch row_out col_out channel",
    channels_last, B,
    col_in=width, col_out=width
)

# Each output element is a float32 sum over height * width products, and the three versions
# may accumulate in different orders, so compare with tolerances sized for that reduction
# rather than demanding bit-exact equality.
torch.testing.assert_close(
    channels_last_transformed_einops, channels_last_transformed_torch, rtol=1e-4, atol=1e-3
)
torch.testing.assert_close(
    channels_last_transformed_einx, channels_last_transformed_torch, rtol=1e-4, atol=1e-3
)

# Keep the original name bound to the last (einx) result, as before.
channels_last_transformed = channels_last_transformed_einx