In [1]:
import re
from typing import *

import datasets
import einops
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tqdm import tqdm

import torch
from torch import nn, Tensor, tensor
import torch.nn.functional as F

import unit_scaling as uu
import unit_scaling.functional as U

# Config & helpers
# Opt in to TF32 for CUDA matmuls and for cuDNN ops.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Model
vocab_size = 256  # one id per byte value; the training data is a uint8 byte stream
depth = 4  # number of transformer layers stacked in each decoder
head_size = 64  # per-head dimension `d` used when splitting q/k/v into heads
mlp_expansion = 2  # MLP hidden size = mlp_expansion * width

# Training
n_steps = int(5e3)  # number of batches yielded by batches()
# NOTE(review): not referenced in the cells shown here; presumably the LR warmup length, confirm.
warmup_steps = int(1e3)
batch_size = 16  # sequences per batch
sequence_length = 256  # consecutive bytes per sequence
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # CUDA when available, else CPU
# NOTE(review): shadows the builtin `compile`; its use is not visible here; presumably toggles torch.compile, confirm.
compile = True

In [4]:
# Load the "train" split of wikitext-103-raw-v1, concatenate every "text" entry
# (no separator is inserted between entries), and expose the UTF-8 encoding of
# the result as one flat uint8 tensor of byte values.
dataset = datasets.load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split="train")
data = torch.frombuffer(bytearray("".join(dataset["text"]), encoding="utf8"), dtype=torch.uint8)
def batches() -> Iterable[Tensor]:
    """Yield `n_steps` random training batches of byte ids.

    Each batch stacks `batch_size` windows of `sequence_length` consecutive
    entries of `data`. Window start offsets are drawn by `torch.randint` from
    `[0, len(data) - sequence_length)`.

    Yields:
        A `torch.long` tensor of shape `(batch_size, sequence_length)` on `device`.
    """
    for _ in range(n_steps):
        starts = torch.randint(0, len(data) - sequence_length, size=(batch_size,))
        # Assemble the batch on the host first so that a single transfer/cast to
        # `device` is issued per batch, rather than one per sequence.
        windows = torch.stack([data[i:i + sequence_length] for i in starts.tolist()])
        yield windows.to(device=device, dtype=torch.long)

Downloading readme: 100%|████████████████████████████████████████████████████████████████████| 10.5k/10.5k [00:00<00:00, 45.0MB/s]
Downloading data: 100%|████████████████████████████████████████████████████████████████████████| 733k/733k [00:00<00:00, 21.2MB/s]
Downloading data: 100%|█████████████████████████████████████████████████████████████████████████| 157M/157M [00:00<00:00, 376MB/s]
Downloading data: 100%|█████████████████████████████████████████████████████████████████████████| 157M/157M [00:00<00:00, 320MB/s]
Downloading data: 100%|████████████████████████████████████████████████████████████████████████| 657k/657k [00:00<00:00, 26.8MB/s]
Generating test split: 100%|████████████████████████████████████████████████████████| 4358/4358 [00:00<00:00, 64378.44 examples/s]
Generating train split: 100%|████████████████████████████████████████████████| 1801350/1801350 [00:01<00:00, 975090.72 examples/s]
Generating validation split: 100%|█████████████████████████████████████████████████

In [5]:
class SpTransformerLayer(nn.Module):
    """One pre-norm decoder layer built from plain `torch.nn` modules.

    A causal self-attention branch is followed by a SiLU-gated MLP branch. Each
    branch reads a layer-normalised copy of the stream and its result is added
    back onto the un-normalised stream. All linear maps are bias-free and the
    layer norms carry no learnable affine parameters.
    """

    def __init__(self, width: int) -> None:
        """Build the sub-modules for a stream of size `width`.

        The MLP hidden size is `mlp_expansion * width`. Because the fused q/k/v
        projection is later split into heads of size `head_size`, `width` has
        to be a multiple of `head_size`.
        """
        super().__init__()
        hidden = mlp_expansion * width

        # Attention branch.
        self.attn_norm = nn.LayerNorm(width, elementwise_affine=False)
        self.attn_qkv = nn.Linear(width, 3 * width, bias=False)
        self.attn_out = nn.Linear(width, width, bias=False)

        # Gated MLP branch.
        self.mlp_norm = nn.LayerNorm(width, elementwise_affine=False)
        self.mlp_up = nn.Linear(width, hidden, bias=False)
        self.mlp_gate = nn.Linear(width, hidden, bias=False)
        self.mlp_down = nn.Linear(hidden, width, bias=False)

    def forward(self, input: Tensor) -> Tensor:
        """Apply the attention branch, then the MLP branch, each with a skip connection.

        Args:
            input: rank-3 tensor laid out as `(b, s, width)` (the layout the
                rearrange pattern below expects).

        Returns:
            A tensor with the same shape as `input`.
        """
        normed = self.attn_norm(input)
        # Split the fused projection into q, k, v, each laid out as (b, h, s, d).
        heads = einops.rearrange(
            self.attn_qkv(normed), "b s (z h d) -> z b h s d", d=head_size, z=3
        )
        q, k, v = heads
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        merged = einops.rearrange(attended, "b h s d -> b s (h d)")
        after_attn = input + self.attn_out(merged)

        normed = self.mlp_norm(after_attn)
        gated = self.mlp_up(normed) * F.silu(self.mlp_gate(normed))
        return after_attn + self.mlp_down(gated)


class SpTransformerDecoder(nn.Sequential):
    """Byte-level decoder: embedding, `depth` layers, final norm, vocab projection.

    The sub-modules are registered as attributes on an initially empty
    `nn.Sequential`, so calling the model applies them in the order they are
    assigned in `__init__`.
    """

    def __init__(self, width: int) -> None:
        """Build the decoder for a hidden size of `width`."""
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, width)
        blocks = [SpTransformerLayer(width) for _ in range(depth)]
        self.layers = nn.Sequential(*blocks)
        self.final_norm = nn.LayerNorm(width, elementwise_affine=False)
        self.projection = nn.Linear(width, vocab_size, bias=False)

    def loss(self, input_ids: Tensor) -> Tensor:
        """Return the next-token cross-entropy of the model on `input_ids`.

        The logits at every position except the last are scored against the
        ids at every position except the first (i.e. shifted by one).
        """
        logits = self(input_ids).float()
        predictions = logits[..., :-1, :].flatten(end_dim=-2)
        targets = input_ids[..., 1:].flatten()
        return F.cross_entropy(predictions, targets)

In [6]:
def show_layer_stats(layer: nn.Module, input_shape: Tuple[int, ...]) -> None:
    """Print the standard deviation of a layer's activations, weights and gradients.

    The layer is run once on standard-normal input of shape `input_shape` and
    back-propagated with a standard-normal output gradient. The report lists,
    in order: the output, the input gradient, every named parameter, and the
    gradient of every named parameter that received one.

    Args:
        layer: module to probe; called as `layer(input)` with a single tensor.
        input_shape: shape of the random input tensor.
    """
    inputs = torch.randn(*input_shape, requires_grad=True)
    output = layer(inputs)
    output.backward(torch.randn_like(output))

    print(f"# {type(layer).__name__}:")
    params = dict(layer.named_parameters())
    stats = {"output": output.std(), "input.grad": inputs.grad.std()}
    stats.update((name, param.std()) for name, param in params.items())
    # Parameters that did not take part in the forward pass have `grad is None`;
    # leave them out instead of failing on `None.std()`.
    stats.update(
        (f"{name}.grad", param.grad.std())
        for name, param in params.items()
        if param.grad is not None
    )
    for key, value in stats.items():
        print(f"{key:>20}.std = {value.item():.2f}")

# Print output / gradient / weight standard deviations for one `SpTransformerLayer` of width 128.
show_layer_stats(SpTransformerLayer(128), (batch_size, sequence_length, 128))

# SpTransformerLayer:
              output.std = 1.01
          input.grad.std = 1.01
     attn_qkv.weight.std = 0.05
     attn_out.weight.std = 0.05
       mlp_up.weight.std = 0.05
     mlp_gate.weight.std = 0.05
     mlp_down.weight.std = 0.04
attn_qkv.weight.grad.std = 3.88
attn_out.weight.grad.std = 6.06
  mlp_up.weight.grad.std = 8.27
mlp_gate.weight.grad.std = 8.51
mlp_down.weight.grad.std = 11.79


In [8]:
class UmupTransformerLayer(nn.Module):
    """One decoder layer built from `unit_scaling` modules and functions.

    A causal self-attention branch is followed by a SiLU-gated MLP branch. Each
    branch is wrapped in `U.residual_split` / `U.residual_add`, with a
    per-branch `tau` obtained from `uu.transformer_residual_scaling_rule()`.
    """

    def __init__(self, width: int, layer_idx: int) -> None:
        """Build the sub-modules and residual coefficients.

        Args:
            width: size of the stream; the MLP hidden size is
                `mlp_expansion * width`.
            layer_idx: position of this layer in the stack. The attention
                branch is counted as residual branch `2 * layer_idx` and the MLP
                branch as `2 * layer_idx + 1`, out of `2 * depth` branches.
        """
        super().__init__()
        hidden = mlp_expansion * width
        n_branches = 2 * depth

        # Attention branch.
        self.attn_norm = uu.LayerNorm(width)
        self.attn_qkv = uu.Linear(width, 3 * width)
        self.attn_out = uu.Linear(width, width)

        # Gated MLP branch.
        self.mlp_norm = uu.LayerNorm(width)
        self.mlp_up = uu.Linear(width, hidden)
        self.mlp_gate = uu.Linear(width, hidden)
        self.mlp_down = uu.Linear(hidden, width)

        # One residual coefficient per branch.
        tau_rule = uu.transformer_residual_scaling_rule()
        self.attn_tau = tau_rule(2 * layer_idx, n_branches)
        self.mlp_tau = tau_rule(2 * layer_idx + 1, n_branches)

    def forward(self, input: Tensor) -> Tensor:
        """Apply the attention branch, then the MLP branch, each recombined with its skip path.

        Args:
            input: rank-3 tensor laid out as `(b, s, width)` (the layout the
                rearrange pattern below expects).
        """
        branch, skip = U.residual_split(input, self.attn_tau)
        branch = self.attn_norm(branch)
        # Split the fused projection into q, k, v, each laid out as (b, h, s, d).
        heads = einops.rearrange(
            self.attn_qkv(branch), "b s (z h d) -> z b h s d", d=head_size, z=3
        )
        q, k, v = heads
        attended = U.scaled_dot_product_attention(q, k, v, is_causal=True)
        merged = einops.rearrange(attended, "b h s d -> b s (h d)")
        after_attn = U.residual_add(self.attn_out(merged), skip, self.attn_tau)

        branch, skip = U.residual_split(after_attn, self.mlp_tau)
        branch = self.mlp_norm(branch)
        gated = U.silu_glu(self.mlp_up(branch), self.mlp_gate(branch))
        return U.residual_add(self.mlp_down(gated), skip, self.mlp_tau)


In [9]:
class UmupTransformerDecoder(nn.Sequential):
    """Byte-level decoder assembled from `unit_scaling` modules.

    The sub-modules are registered as attributes on an initially empty
    `nn.Sequential`, so calling the model applies them in the order they are
    assigned in `__init__`: embedding, `depth` layers, final norm, readout.
    """

    def __init__(self, width: int) -> None:
        """Build the decoder for a hidden size of `width`."""
        super().__init__()
        self.embedding = uu.Embedding(vocab_size, width)
        blocks = [UmupTransformerLayer(width, idx) for idx in range(depth)]
        self.layers = uu.DepthSequential(*blocks)
        self.final_norm = uu.LayerNorm(width)
        self.projection = uu.LinearReadout(width, vocab_size)

    def loss(self, input_ids: Tensor) -> Tensor:
        """Return `U.cross_entropy` of the model's next-token predictions on `input_ids`.

        The logits at every position except the last are scored against the
        ids at every position except the first (i.e. shifted by one).
        """
        logits = self(input_ids).float()
        predictions = logits[..., :-1, :].flatten(end_dim=-2)
        targets = input_ids[..., 1:].flatten()
        return U.cross_entropy(predictions, targets)

In [14]:
# Print output / gradient / weight standard deviations for one `UmupTransformerLayer` (width 128, layer index 0).
show_layer_stats(UmupTransformerLayer(128, 0), (batch_size, sequence_length, 128))


# UmupTransformerLayer:
              output.std = 1.01
          input.grad.std = 1.10
     attn_qkv.weight.std = 1.00
     attn_out.weight.std = 1.00
       mlp_up.weight.std = 1.00
     mlp_gate.weight.std = 1.00
     mlp_down.weight.std = 1.00
attn_qkv.weight.grad.std = 0.63
attn_out.weight.grad.std = 1.07
  mlp_up.weight.grad.std = 0.70
mlp_gate.weight.grad.std = 0.73
mlp_down.weight.grad.std = 1.00
