In [2]:
import sys, subprocess, textwrap, math, time

def pip_install(pkg: str, module: str = ""):
    """Install ``pkg`` with pip for the running interpreter, unless it is already importable.

    Re-running the notebook should not shell out to pip (slow, and it fails in
    offline or pip-less kernels) when the dependency is already present, so the
    module is probed first and pip is only invoked when the probe fails.

    Args:
        pkg: Requirement string passed verbatim to ``pip install``.
        module: Importable module name to probe for. Defaults to ``pkg`` itself;
            pass it explicitly when the import name differs from the
            distribution name. A ``pkg`` that is not a plain importable name
            (for example a version-pinned requirement) never matches the probe
            and is therefore always handed to pip.

    Raises:
        subprocess.CalledProcessError: if pip is invoked and exits non-zero.
    """
    import importlib.util

    probe = module or pkg.strip()
    try:
        already_importable = bool(probe) and importlib.util.find_spec(probe) is not None
    except (ImportError, ValueError):
        # find_spec raises for dotted names whose parent is missing and for
        # malformed names; both simply mean "not importable yet".
        already_importable = False

    if already_importable:
        return
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", pkg])

# Make sure the third-party packages imported right below are installed first.
for _requirement in ("einops", "torch"):
    pip_install(_requirement)

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import EinopsError, rearrange, reduce, repeat, einsum, pack, unpack
from einops.layers.torch import Rearrange, Reduce

# Seed PyTorch's global RNG so the random tensors drawn later are repeatable
# from a fresh kernel.
torch.manual_seed(0)

# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)

def section(title: str):
    """Print ``title`` framed by two 90-character ``=`` rules, preceded by a blank line."""
    rule = "=" * 90
    print(f"\n{rule}\n{title}\n{rule}")

def show_shape(name, x):
    """Print a one-line summary of ``x``: its shape (as a tuple), dtype and device.

    ``name`` is right-aligned in an 18-character column so consecutive calls
    line up. ``x`` only needs ``shape``, ``dtype`` and ``device`` attributes.
    """
    dims = tuple(x.shape)
    line = "{:>18} shape = {}  dtype={}  device={}".format(name, dims, x.dtype, x.device)
    print(line)

Device: cpu


In [3]:
section("1) rearrange")
x = torch.randn(2, 3, 4, 5, device=device)
show_shape("x", x)

# Pure axis permutation: channels-first -> channels-last.
x_bhwc = rearrange(x, "b c h w -> b h w c")
show_shape("x_bhwc", x_bhwc)

# Split the channel axis into g groups of cg channels (3 channels -> 3 groups of 1).
x_split = rearrange(x, "b (g cg) h w -> b g cg h w", g=3)
show_shape("x_split", x_split)

# Flatten the spatial grid into a token axis and move channels last.
x_tokens = rearrange(x, "b c h w -> b (h w) c")
show_shape("x_tokens", x_tokens)

# "..." stands for any number of middle axes; only the last axis is moved.
y = torch.randn(2, 7, 11, 13, 17, device=device)
y2 = rearrange(y, "b ... c -> b c ...")
show_shape("y", y)
show_shape("y2", y2)

try:
    # An axis of length 10 cannot be factored as (h w) with h=3, so einops rejects it.
    _ = rearrange(torch.randn(2, 10, device=device), "b (h w) -> b h w", h=3)
except EinopsError as e:
    # Catch only einops' own error type: any other exception is a real failure
    # and must not be reported as the expected shape mismatch.
    print("Expected error (shape mismatch):", type(e).__name__, "-", str(e)[:140])


1) rearrange
                 x shape = (2, 3, 4, 5)  dtype=torch.float32  device=cpu
            x_bhwc shape = (2, 4, 5, 3)  dtype=torch.float32  device=cpu
           x_split shape = (2, 3, 1, 4, 5)  dtype=torch.float32  device=cpu
          x_tokens shape = (2, 20, 3)  dtype=torch.float32  device=cpu
                 y shape = (2, 7, 11, 13, 17)  dtype=torch.float32  device=cpu
                y2 shape = (2, 17, 7, 11, 13)  dtype=torch.float32  device=cpu
Expected error (shape mismatch): EinopsError -  Error while processing rearrange-reduction pattern "b (h w) -> b h w".
 Input tensor shape: torch.Size([2, 10]). Additional info: {'h': 3}.


In [4]:
section("2) reduce")
imgs = torch.randn(8, 3, 64, 64, device=device)

# Global average pooling: both spatial axes are averaged away.
gap = reduce(imgs, "b c h w -> b c", reduction="mean")
# 2x2 average pooling: every (ph, pw) window collapses to one output pixel.
pooled = reduce(imgs, "b c (h ph) (w pw) -> b c h w", reduction="mean", ph=2, pw=2)
# Maximum over all spatial positions, kept separately per batch item and channel.
chmax = reduce(imgs, "b c h w -> b c", reduction="max")

for _label, _tensor in (("imgs", imgs), ("gap", gap), ("pooled", pooled), ("chmax", chmax)):
    show_shape(_label, _tensor)

section("3) repeat")
vec = torch.randn(5, device=device)
q = torch.randn(2, 32, device=device)

# Tile the vector along a new leading axis of length 4.
vec_batched = repeat(vec, pattern="d -> b d", b=4)
# Insert a heads axis of length 8 that repeats each row of q.
q_heads = repeat(q, pattern="b d -> b heads d", heads=8)

for _label, _tensor in (("vec", vec), ("vec_batched", vec_batched), ("q_heads", q_heads)):
    show_shape(_label, _tensor)


2) reduce
              imgs shape = (8, 3, 64, 64)  dtype=torch.float32  device=cpu
               gap shape = (8, 3)  dtype=torch.float32  device=cpu
            pooled shape = (8, 3, 32, 32)  dtype=torch.float32  device=cpu
             chmax shape = (8, 3)  dtype=torch.float32  device=cpu

3) repeat
               vec shape = (5,)  dtype=torch.float32  device=cpu
       vec_batched shape = (4, 5)  dtype=torch.float32  device=cpu
           q_heads shape = (2, 8, 32)  dtype=torch.float32  device=cpu


In [5]:
section("4) patchify")
B, C, H, W = 4, 3, 32, 32
P = 8
img = torch.randn(B, C, H, W, device=device)
show_shape("img", img)

# Cut the image into a grid of P x P tiles and flatten each tile into one token.
patches = rearrange(img, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=P, p2=P)
show_shape("patches", patches)

# The inverse pattern needs every factor of the merged axes spelled out.
_inverse_axes = {"h": H // P, "w": W // P, "p1": P, "p2": P, "c": C}
img_rec = rearrange(patches, "b (h w) (p1 p2 c) -> b c (h p1) (w p2)", **_inverse_axes)
show_shape("img_rec", img_rec)

# rearrange only moves elements, so the round trip should reproduce img.
max_err = torch.abs(img - img_rec).max().item()
print("Reconstruction max abs error:", max_err)
assert max_err < 1e-6

section("5) attention")
B, T, D = 2, 64, 256
Hh = 8
Dh = D // Hh  # width of a single head
x = torch.randn(B, T, D, device=device)
show_shape("x", x)

# One bias-free projection produces queries, keys and values in a single matmul.
proj = nn.Linear(D, 3 * D, bias=False)
proj = proj.to(device)
qkv = proj(x)
show_shape("qkv", qkv)

# Un-merge the fused feature axis; unpacking the leading axis of size 3 yields q, k, v.
q, k, v = rearrange(qkv, "b t (three heads dh) -> three b heads t dh", three=3, heads=Hh, dh=Dh)
for _label, _tensor in (("q", q), ("k", k), ("v", v)):
    show_shape(_label, _tensor)

# Scaled dot-product scores between every query position t and key position s.
scale = Dh ** -0.5
attn_logits = einsum(q, k, "b h t dh, b h s dh -> b h t s") * scale
show_shape("attn_logits", attn_logits)

# Normalize over the key axis so each query's weights sum to 1.
attn = F.softmax(attn_logits, dim=-1)
show_shape("attn", attn)

# Weighted sum of values per head, then concatenate the heads back into D features.
out = einsum(attn, v, "b h t s, b h s dh -> b h t dh")
show_shape("out (per-head)", out)

out_merged = rearrange(out, "b h t dh -> b t (h dh)")
show_shape("out_merged", out_merged)


4) patchify
               img shape = (4, 3, 32, 32)  dtype=torch.float32  device=cpu
           patches shape = (4, 16, 192)  dtype=torch.float32  device=cpu
           img_rec shape = (4, 3, 32, 32)  dtype=torch.float32  device=cpu
Reconstruction max abs error: 0.0

5) attention
                 x shape = (2, 64, 256)  dtype=torch.float32  device=cpu
               qkv shape = (2, 64, 768)  dtype=torch.float32  device=cpu
                 q shape = (2, 8, 64, 32)  dtype=torch.float32  device=cpu
                 k shape = (2, 8, 64, 32)  dtype=torch.float32  device=cpu
                 v shape = (2, 8, 64, 32)  dtype=torch.float32  device=cpu
       attn_logits shape = (2, 8, 64, 64)  dtype=torch.float32  device=cpu
              attn shape = (2, 8, 64, 64)  dtype=torch.float32  device=cpu
    out (per-head) shape = (2, 8, 64, 32)  dtype=torch.float32  device=cpu
        out_merged shape = (2, 64, 256)  dtype=torch.float32  device=cpu


In [6]:
section("6) pack unpack")
B, Cemb = 2, 128

# Three token streams that share batch and channel sizes but differ in length.
class_token = torch.randn(B, 1, Cemb, device=device)
image_tokens = torch.randn(B, 196, Cemb, device=device)
text_tokens = torch.randn(B, 32, Cemb, device=device)
for _label, _tensor in (
    ("class_token", class_token),
    ("image_tokens", image_tokens),
    ("text_tokens", text_tokens),
):
    show_shape(_label, _tensor)

# Concatenate along the "*" axis; ps records each input's extent on that axis.
packed, ps = pack([class_token, image_tokens, text_tokens], "b * c")
show_shape("packed", packed)
print("packed_shapes (ps):", ps)

# Token-wise MLP applied to the packed sequence in a single call.
_hidden = 4 * Cemb
mixer = nn.Sequential(
    nn.LayerNorm(Cemb),
    nn.Linear(Cemb, _hidden),
    nn.GELU(),
    nn.Linear(_hidden, Cemb),
).to(device)

mixed = mixer(packed)
show_shape("mixed", mixed)

# Split the result back into the three streams using the recorded extents.
class_out, image_out, text_out = unpack(mixed, ps, "b * c")
for _label, _tensor in (("class_out", class_out), ("image_out", image_out), ("text_out", text_out)):
    show_shape(_label, _tensor)
for _restored, _source in (
    (class_out, class_token),
    (image_out, image_tokens),
    (text_out, text_tokens),
):
    assert _restored.shape == _source.shape

# Banner for the next part; the layer classes are defined right after it.
section("7) layers")
class PatchEmbed(nn.Module):
    """Turn an image batch into a sequence of linearly projected patch tokens.

    The einops ``Rearrange`` layer cuts a ``(b, c, H, W)`` input into
    non-overlapping ``patch x patch`` tiles and flattens each tile to
    ``patch * patch * in_channels`` features; a linear layer then maps every
    flattened tile to ``emb_dim``.
    """

    def __init__(self, in_channels=3, emb_dim=192, patch=8):
        """Build the tiling layer and the projection for the given patch size."""
        super().__init__()
        self.patch = patch
        tiling_pattern = "b c (h p1) (w p2) -> b (h w) (p1 p2 c)"
        self.to_patches = Rearrange(tiling_pattern, p1=patch, p2=patch)
        flat_patch_dim = in_channels * patch * patch
        self.proj = nn.Linear(flat_patch_dim, emb_dim)

    def forward(self, x):
        """Return patch embeddings of shape ``(b, num_patches, emb_dim)`` for ``x``."""
        flat_patches = self.to_patches(x)
        return self.proj(flat_patches)

class SimpleVisionHead(nn.Module):
    """Classification head: mean-pool the token axis, then apply a linear classifier."""

    def __init__(self, emb_dim=192, num_classes=10):
        """Create the einops mean-pooling layer and the ``emb_dim -> num_classes`` linear map."""
        super().__init__()
        self.pool = Reduce("b t c -> b c", reduction="mean")
        self.classifier = nn.Linear(emb_dim, num_classes)

    def forward(self, tokens):
        """Map ``(b, t, c)`` tokens to ``(b, num_classes)`` logits."""
        return self.classifier(self.pool(tokens))

# Instantiate the two modules and push one random batch through them end to end.
patch_embed = PatchEmbed(in_channels=3, emb_dim=192, patch=8)
patch_embed = patch_embed.to(device)
head = SimpleVisionHead(emb_dim=192, num_classes=10)
head = head.to(device)

imgs = torch.randn(4, 3, 32, 32, device=device)
tokens = patch_embed(imgs)
logits = head(tokens)
for _label, _tensor in (("tokens", tokens), ("logits", logits)):
    show_shape(_label, _tensor)

section("8) practical")
x = torch.randn(2, 32, 16, 16, device=device)
g = 8
# Fold the group index into the batch axis so each row of xg is one (sample, group) pair.
xg = rearrange(x, "b (g cg) h w -> (b g) cg h w", g=g)
show_shape("x", x)
show_shape("xg", xg)

# Per-(sample, group) statistics, kept as size-1 axes so they broadcast against xg.
_stat_pattern = "bg cg h w -> bg 1 1 1"
mean = reduce(xg, _stat_pattern, "mean")
_centered = xg - mean
var = reduce(_centered ** 2, _stat_pattern, "mean")
xg_norm = _centered / (var + 1e-5).sqrt()
# Undo the folding to recover the original (b, c, h, w) layout.
x_norm = rearrange(xg_norm, "(b g) cg h w -> b (g cg) h w", b=2, g=g)
show_shape("x_norm", x_norm)

# Flatten the spatial axes and restore them; the round trip must reproduce z.
z = torch.randn(3, 64, 20, 30, device=device)
z_flat = rearrange(z, "b c h w -> b c (h w)")
z_unflat = rearrange(z_flat, "b c (h w) -> b c h w", h=20, w=30)
_round_trip_err = torch.abs(z - z_unflat).max().item()
assert _round_trip_err < 1e-6
show_shape("z_flat", z_flat)

section("9) views")
a = torch.randn(2, 3, 4, 5, device=device)
b = rearrange(a, "b c h w -> b h w c")
# Report memory contiguity of the source tensor and of the permuted result.
for _label, _tensor in (("a", a), ("b", b)):
    print(f"{_label}.is_contiguous():", _tensor.is_contiguous())
# True here means b is a view whose base tensor is a.
_shares_base = getattr(b, "_base", None) is a
print("b._base is a:", _shares_base)

section("Done ✅ You now have reusable einops patterns for vision, attention, and multimodal token packing")


6) pack unpack
       class_token shape = (2, 1, 128)  dtype=torch.float32  device=cpu
      image_tokens shape = (2, 196, 128)  dtype=torch.float32  device=cpu
       text_tokens shape = (2, 32, 128)  dtype=torch.float32  device=cpu
            packed shape = (2, 229, 128)  dtype=torch.float32  device=cpu
packed_shapes (ps): [torch.Size([1]), torch.Size([196]), torch.Size([32])]
             mixed shape = (2, 229, 128)  dtype=torch.float32  device=cpu
         class_out shape = (2, 1, 128)  dtype=torch.float32  device=cpu
         image_out shape = (2, 196, 128)  dtype=torch.float32  device=cpu
          text_out shape = (2, 32, 128)  dtype=torch.float32  device=cpu

7) layers
            tokens shape = (4, 16, 192)  dtype=torch.float32  device=cpu
            logits shape = (4, 10)  dtype=torch.float32  device=cpu

8) practical
                 x shape = (2, 32, 16, 16)  dtype=torch.float32  device=cpu
                xg shape = (16, 4, 16, 16)  dtype=torch.float32  device=cpu
     