# Toy example

In [1]:
import torch
from torch import nn
from einops import rearrange, einsum
import einx
import numpy as np

In [None]:
# Sizes shared by the toy tensors built in the next cell:
# batch size, sequence length, input feature width, output feature width.
batch, sequence, d_in, d_out = 4, 100, 64, 64

In [7]:
# Toy data: an input tensor D and a weight matrix A.
D = torch.randn(batch, sequence, d_in)
A = torch.randn(d_out, d_in)

## Basic implementation: a plain matrix multiply
# Nothing in this line says what the input and output shapes are or what each
# axis means. Which shapes may D and A have, and do any of them silently do
# something unexpected?
Y1 = torch.matmul(D, A.T)

## Einsum is self-documenting and robust: every axis is named
# D A -> Y
Y2 = einsum(
    D, A,
    "batch sequence d_in, d_out d_in -> batch sequence d_out",
)
## Batched variant: D may carry any leading dimensions, A stays constrained.
Y3 = einsum(
    D, A,
    "... d_in, d_out d_in -> ... d_out",
)

# All three formulations must agree.
assert torch.allclose(Y1, Y2)
assert torch.allclose(Y1, Y3)

In [9]:
# A random batch of images laid out as (batch, height, width, channel).
images = torch.randn(64, 128, 128, 3)
# Ten dimming factors evenly spaced from 0.0 to 1.0 (inclusive).
dim_by = torch.linspace(0.0, 1.0, 10)

## Reshape and multiply
# Insert singleton axes so the factors and the images broadcast against each
# other, producing one dimmed copy of every image per factor.
images_rearr = rearrange(
    images,
    "b height width channel -> b 1 height width channel",
)
dim_value = rearrange(dim_by, "dim_value -> 1 dim_value 1 1 1")
dimmed_images1 = dim_value * images_rearr

## Or in one go: einsum builds the outer product directly
dimmed_images2 = einsum(
    images,
    dim_by,
    "batch height width channel, dim_value -> batch dim_value height width channel",
)

# Both routes must give the same result.
assert torch.allclose(dimmed_images1, dimmed_images2)

In [None]:
# Setting: a batch of images stored as a tensor of shape (batch, height, width, channel).
# We want one linear transformation that mixes all pixels of an image, applied
# independently to every channel. The transformation is a matrix B of shape
# (height × width, height × width).
channels_last = torch.randn(64, 32, 32, 3)  # (batch, height, width, channel)
B = torch.randn(32*32, 32*32)

## Plain torch: rearrange the image tensor so pixels can be mixed, then undo it
# 1. Flatten (height, width) into a single pixel axis.
channels_last_flat = channels_last.view(
    -1,
    channels_last.shape[1] * channels_last.shape[2],
    channels_last.shape[3],
)
# 2. Move channels ahead of pixels so the matmul acts on the pixel axis.
channels_first_flat = channels_last_flat.permute(0, 2, 1)
# 3. Mix the pixels.
channels_first_flat_transformed = torch.matmul(channels_first_flat, B.T)
# 4. Put channels last again and restore the original image shape.
channels_last_flat_transformed = channels_first_flat_transformed.permute(0, 2, 1)
channels_last_transformed = channels_last_flat_transformed.view(channels_last.shape)

# Instead, using einops:
height, width = 32, 32
## Rearrange replaces clunky torch view + transpose
channels_first = rearrange(
    channels_last,
    "batch height width channel -> batch channel (height width)",
)
channels_first_transformed = einsum(
    channels_first,
    B,
    "batch channel pixel_in, pixel_out pixel_in -> batch channel pixel_out",
)
channels_last_transformed = rearrange(
    channels_first_transformed,
    "batch channel (height width) -> batch height width channel",
    height=height,
    width=width,
)

# Or, if you're feeling crazy: all in one go using einx.dot (einx equivalent of einops.einsum).
# `width` is reused from the einops version above.
channels_last_transformed = einx.dot(
    "batch row_in col_in channel, (row_out col_out) (row_in col_in)"
    "-> batch row_out col_out channel",
    channels_last,
    B,
    col_in=width,
    col_out=width,
)

In [None]:
class Linear(nn.Module):
    """Bias-free linear transformation over the last dimension of the input.

    The weight has shape ``(out_features, in_features)`` and is initialised
    from a standard normal distribution truncated to ``[-3, 3]``.
    """

    def __init__(self, in_features: int, out_features: int, device=None, dtype=None):
        """Construct a linear transformation module.

        Args:
            in_features: int, final dimension of the input.
            out_features: int, final dimension of the output.
            device: torch.device | None, device to store the parameters on.
            dtype: torch.dtype | None, data type of the parameters.
        """
        # nn.Module must be initialised before any nn.Parameter is assigned;
        # otherwise Module.__setattr__ raises
        # "cannot assign parameters before Module.__init__() call".
        super().__init__()
        W = torch.empty(out_features, in_features, device=device, dtype=dtype)
        self.weight = nn.Parameter(nn.init.trunc_normal_(W, a=-3, b=3))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear transformation to the input.

        Args:
            x: tensor whose last dimension is ``in_features``; any number of
                leading dimensions is accepted.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``out_features``.
        """
        return einsum(self.weight, x, "d_out d_in, ... d_in -> ... d_out")


In [None]:
class Embedding(nn.Module):
    """Lookup table mapping integer token ids to embedding vectors.

    The table has shape ``(num_embeddings, embedding_dim)`` and is initialised
    from a standard normal distribution truncated to ``[-3, 3]``.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, device=None, dtype=None):
        """Construct an embedding module.

        Args:
            num_embeddings: int, size of the vocabulary.
            embedding_dim: int, dimension of the embedding vectors, i.e. d_model.
            device: torch.device | None, device to store the parameters on.
            dtype: torch.dtype | None, data type of the parameters.
        """
        # nn.Module must be initialised before any nn.Parameter is assigned;
        # otherwise Module.__setattr__ raises
        # "cannot assign parameters before Module.__init__() call".
        super().__init__()
        E = torch.empty(num_embeddings, embedding_dim, device=device, dtype=dtype)
        self.embeddings = nn.Parameter(nn.init.trunc_normal_(E, a=-3, b=3))

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Look up the embedding vectors for the given token ids.

        Args:
            token_ids: integer tensor of indices into the embedding table,
                of any shape.

        Returns:
            Tensor of shape ``token_ids.shape + (embedding_dim,)``.
        """
        # Indexing the first axis with an integer tensor gathers one row per id.
        return self.embeddings[token_ids]

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a learnable per-feature gain.

    Computes ``x / sqrt(mean(x**2, dim=-1) + eps) * gain`` over the last
    dimension of the input.
    """

    def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
        """Construct the RMSNorm module.

        Args:
            d_model: int, hidden dimension of the model.
            eps: float, epsilon value for numerical stability.
            device: torch.device | None, device to store the parameters on.
            dtype: torch.dtype | None, data type of the parameters.
        """
        # nn.Module must be initialised before any nn.Parameter is assigned;
        # otherwise Module.__setattr__ raises
        # "cannot assign parameters before Module.__init__() call".
        super().__init__()
        self.d_model = d_model
        self.eps = eps
        self.gain = nn.Parameter(torch.ones(d_model, device=device, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` by its root mean square over the last dimension.

        Args:
            x: tensor of shape ``(..., d_model)``, typically
                ``(batch_size, sequence_length, d_model)``.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        in_dtype = x.dtype
        # Upcast so that squaring does not overflow in low-precision dtypes.
        x = x.to(torch.float32)
        # Mean of squares over the feature axis; keepdim so it broadcasts
        # against x. eps goes inside the square root so an all-zero row
        # yields zeros rather than NaN.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        res = x * inv_rms * self.gain
        return res.to(in_dtype)

In [None]:
class SwiGLU:
    