```python
# Standard library
import functools
import time
from typing import Any

# Third-party
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import tqdm
from einops import rearrange
```

# The Transformer

<!-- - History
- Transformer Structure
- Input Heads
- Output Heads -->

A ubiquitous neural network architecture today is the **transformer**. Originating in language modelling, the transformer has shown powerful empirical scaling properties and makes almost no domain-specific assumptions. Transformers are now used across the board, even in image and robotic-control domains.

## Transformer Architectural Diagram

The transformer is a residual network whose residual blocks consist of dense layers and self-attention layers. The transformer is a **set operator**: the activations at each stage form a *set* of feature vectors. Positional information within this set is represented via positional encodings applied to each vector. We will cover the various input/output heads of the transformer later.
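To make the set-operator claim concrete, here is a minimal single-head self-attention sketch in JAX. The function `self_attention`, the random weights, and the fixed permutation are hypothetical and chosen only for illustration. Without positional information, reordering the input vectors just reorders the outputs (the layer is permutation-equivariant). Adding a per-position vector (a random placeholder here, standing in for whatever positional encoding a real model uses) ties each vector to its slot and breaks that symmetry.

```python
import jax
import jax.numpy as jnp

def self_attention(x, wq, wk, wv):
    """Hypothetical single-head self-attention over a set x of shape (n, d)."""
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / jnp.sqrt(q.shape[-1])   # (n, n) pairwise similarities
    weights = jax.nn.softmax(scores, axis=-1)  # each row sums to 1
    return weights @ v                         # (n, d)

n, d = 5, 8
kx, kq, kk, kv, kpos = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(kx, (n, d))
wq, wk, wv = [jax.random.normal(k_, (d, d)) / jnp.sqrt(d) for k_ in (kq, kk, kv)]
perm = jnp.array([3, 0, 4, 1, 2])  # a fixed reordering of the 5 vectors

# No positional encodings: attention(x[perm]) equals attention(x)[perm].
out = self_attention(x, wq, wk, wv)
out_perm = self_attention(x[perm], wq, wk, wv)
equivariant_without_pe = jnp.allclose(out[perm], out_perm, atol=1e-5)

# Placeholder positional encoding: one random vector per slot index.
pos = jax.random.normal(kpos, (n, d))
out_pe = self_attention(x + pos, wq, wk, wv)
out_pe_perm = self_attention(x[perm] + pos, wq, wk, wv)
equivariant_with_pe = jnp.allclose(out_pe[perm], out_pe_perm, atol=1e-5)

print(equivariant_without_pe, equivariant_with_pe)
```

Equivariance (rather than invariance) is what lets the transformer keep one output vector per input vector, so the same architecture can produce per-token predictions.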

**Residual blocks**. The specific details of residual blocks vary between models. We will describe the **GPT-2** architecture here. In GPT-2, each residual block consists of:
- LayerNorm on the residual stream vectors.
- Multi-headed self-attention.
- A residual connection, plus a second LayerNorm.
- Two dense layers with a GELU activation between them, followed by a second residual connection.

Each attention/dense layer is applied in *parallel* across the entire set of feature vectors. This is why transformers can make very efficient use of GPU time: even if the batch size is small, the true batch size of each operation is z