# The Annotated Discrete Diffusion Models

In this tutorial, we'll explore how discrete diffusion models can be applied to text generation by building a character-level text diffusion model from scratch.

---

![Intro GIF](https://raw.githubusercontent.com/ash80/diffusion-gpt/master/assets/intro_text.gif)

---

Most modern chatbots, including ChatGPT, generate text sequentially: one token at a time, left to right. Diffusion models, the main driver behind the recent success of image and video generators, take a very different approach: they corrupt data with noise and then learn to denoise it.

Extending diffusion models to text, however, is not straightforward. Images live in a continuous space, where noise can be added and removed in small increments. Text is made of discrete symbols, so "adding noise" means flipping characters or tokens until the sequence becomes gibberish, and teaching a model to undo those flips is harder.

To tackle this challenge, we'll begin with Andrej Karpathy's character-level baby GPT, a minimal yet mighty model for sequence modeling, and transform it into a character-level discrete diffusion model. Our implementation will closely follow the ideas presented in the paper Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution ([arXiv:2310.16834](https://arxiv.org/abs/2310.16834)).

## Shakespeare Dataset

As a first step, let's clone Andrej Karpathy's nanoGPT GitHub repository, which contains the scripts to prepare the character-level dataset from Shakespeare's work.

In [1]:
!git clone https://github.com/karpathy/nanoGPT

Cloning into 'nanoGPT'...
remote: Enumerating objects: 686, done.
remote: Total 686 (delta 0), reused 0 (delta 0), pack-reused 686 (from 1)
Receiving objects: 100% (686/686), 974.05 KiB | 22.65 MiB/s, done.
Resolving deltas: 100% (380/380), done.


`data/shakespeare_char/` in this repo provides a `prepare.py` script to prepare the dataset. We'll copy this to our notebook's working directory and run the script.

In [2]:
import shutil as sh
# Copy to our project's working directory
if not sh.os.path.exists('shakespeare_char'):
    sh.copytree('nanoGPT/data/shakespeare_char', 'shakespeare_char')

In [3]:
# Prepare the character-level dataset
%run shakespeare_char/prepare.py

length of dataset in characters: 1,115,394
all the unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab size: 65
train has 1,003,854 tokens
val has 111,540 tokens


Let's examine the contents of the `shakespeare_char` directory.

In [4]:
%ls shakespeare_char

input.txt  meta.pkl  prepare.py  readme.md  train.bin  val.bin


The directory contains a `meta.pkl` file with all the metadata we'll need for character-level modeling. It stores a dictionary with these three keys:

  - `vocab_size`: the total number of unique characters in the Shakespeare dataset.
  - `stoi`: a dictionary mapping each character to a unique index in the range `[0, vocab_size)`.
  - `itos`: maps indices back to their corresponding characters.

`stoi` and `itos` act as a pair of bilingual dictionaries between characters and numbers, while `vocab_size` tells us how large the "alphabet" is.

Now let's extract the mappings (`stoi`, `itos`) along with the vocabulary size, and print them out for inspection.

In [5]:
import os
import pickle

# Path to Shakespeare metadata
data_dir = './shakespeare_char/'

# Load the metadata dictionary
meta_path = os.path.join(data_dir, 'meta.pkl')
vocab_size = None
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)

# Character and index mappings
itos = meta['itos'] # index to string (character)
stoi = meta['stoi'] # string (character) to index

# Display the vocabulary
print(f"vocabulary: {repr(''.join(stoi.keys()))}")

# Total number of unique characters
vocab_size = meta['vocab_size']
print(f'vocabulary size: {vocab_size}')

vocabulary: "\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
vocabulary size: 65


### Define Dataset Class

With our vocabulary in place, we're ready to work with the actual text data. The `shakespeare_char` directory also contains two files, `train.bin` and `val.bin`, which store the Shakespeare corpus in a compact format with characters already mapped to their numerical indices.

Our next step is to wrap this data in a PyTorch `Dataset` class. This will:

  1. Load the encoded Shakespeare text from disk.
  2. Extract sub-sequences of a given context length.
  3. Return these sub-sequences as tensors, ready for training.

Each training example essentially serves as a "window" into the Shakespeare text, represented as numbers.

In [6]:
import os
import numpy as np
import torch
import torch.utils.data as data

class ShakespeareDataset(data.Dataset):
    """
    Memory-mapped dataset for character-level sequences.

    Each item is a 1D tensor of indices (torch.long) of length `context_len`
    from a rolling window over the encoded Shakespeare corpus.

    Notes
    -----
    - Uses np.memmap to avoid loading the entire file into RAM.
    - Returns only `x` (the context window).
      This will serve as the clean target for denoising.
      Noising will be applied on-the-fly during the training.
    """
    def __init__(
        self,
        data_dir: str,
        split: str = "train",
        context_len: int = 256,
        dtype: np.dtype = np.uint16,
    ) -> None:
        if split not in {"train", "val"}:
            raise ValueError(f"split must be 'train' or 'val', got: {split!r}")
        if context_len <= 0:
            raise ValueError(f"context_len must be positive, got: {context_len}")

        self.split = split
        self.context_len = int(context_len)

        bin_path = os.path.join(data_dir, f"{split}.bin")
        if not os.path.isfile(bin_path):
            raise FileNotFoundError(f"Could not find {bin_path}")

        # Memory-map the encoded corpus. uint16 matches the preprocessing.
        self.data = np.memmap(bin_path, dtype=dtype, mode="r")

        # Number of valid starting positions for a full context window
        self._n = max(0, len(self.data) - self.context_len)

    def __len__(self) -> int:
        return self._n

    def __getitem__(self, index: int) -> torch.Tensor:
        if index < 0 or index >= self._n:
            raise IndexError(f"Index {index} out of range for dataset of length {self._n}.")
        # Slice a contiguous window and convert to torch.long (int64)
        x_np = self.data[index : index + self.context_len].astype(np.int64)
        x = torch.from_numpy(x_np)  # shape: [context_len], dtype: torch.long
        return x
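
To see how the rolling window and the `len(data) - context_len` formula behave, here is a minimal sketch that uses a plain Python list in place of the memory-mapped file. The names `toy_data` and `toy_context_len` are illustrative assumptions, not part of the notebook.

```python
toy_data = list(range(10))   # stand-in for the encoded corpus (assumption)
toy_context_len = 4

# Same formula as ShakespeareDataset._n
n_windows = max(0, len(toy_data) - toy_context_len)

# Same slicing as ShakespeareDataset.__getitem__
windows = [toy_data[i : i + toy_context_len] for i in range(n_windows)]

print(n_windows)     # 6
print(windows[0])    # [0, 1, 2, 3]
print(windows[-1])   # [5, 6, 7, 8]
```

Note that this formula leaves out the very last full window (`[6, 7, 8, 9]` in the toy case), which costs a single example on the real corpus.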


### Define and Initialise Dataloaders

We'll train on batches of 64 sequences, each with a **context length of 256** characters: the maximum number of characters the model sees per training window. Larger contexts capture longer-range structure but increase memory and compute; **256** is a practical middle ground for character-level diffusion.


In [7]:
from torch.utils import data

def get_data_loader(data_dir, split, batch_size, context_len=256):
    dataset = ShakespeareDataset(data_dir, split, context_len)
    return data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialise
batch_size = 64
context_length = 256

train_dataloader = get_data_loader(data_dir, 'train', batch_size, context_length)
val_dataloader   = get_data_loader(data_dir, 'val', batch_size, context_length)

# Peek at one batch to confirm shapes/types
batch = next(iter(train_dataloader))
print(batch.shape)
print(batch[0]) # A tensor of indices of length `context_length`

torch.Size([64, 256])
tensor([42,  1, 61, 47, 50, 50,  0, 35, 46, 47, 41, 46,  1, 58, 47, 56, 43, 42,
         1, 51, 39, 48, 43, 57, 58, 63,  1, 42, 47, 42,  1, 51, 39, 49, 43,  1,
        58, 46, 43, 43,  1, 53, 44, 44, 43, 56,  6,  0, 32, 46, 43,  1, 56, 43,
        57, 47, 45, 52, 39, 58, 47, 53, 52,  1, 53, 44,  1, 58, 46, 63,  1, 57,
        58, 39, 58, 43,  1, 39, 52, 42,  1, 41, 56, 53, 61, 52,  0, 32, 53,  1,
        20, 43, 52, 56, 63,  1, 14, 53, 50, 47, 52, 45, 40, 56, 53, 49, 43,  8,
         0,  0, 23, 21, 26, 19,  1, 30, 21, 15, 20, 13, 30, 16,  1, 21, 21, 10,
         0, 19, 47, 60, 43,  1, 51, 43,  1, 58, 46, 43,  1, 41, 56, 53, 61, 52,
         8,  1, 20, 43, 56, 43,  6,  1, 41, 53, 59, 57, 47, 52,  6,  1, 57, 43,
        47, 64, 43,  1, 58, 46, 43,  1, 41, 56, 53, 61, 52, 11,  0, 20, 43, 56,
        43,  1, 41, 53, 59, 57, 47, 52, 10,  0, 27, 52,  1, 58, 46, 47, 57,  1,
        57, 47, 42, 43,  1, 51, 63,  1, 46, 39, 52, 42,  6,  1, 39, 52, 42,  1,
        53, 52,  1

### Decoding Indices to Text

To decode a tensor of indices back into text, we map each index to its character with the help of `itos` created above. Let's use it to define a `decode()` function.

In [8]:
def decode(indices_tensor: torch.Tensor):
    '''Decodes a 1D tensor of indices to text'''
    indices = indices_tensor.cpu().numpy()
    return ''.join([itos[i] for i in indices])

# Check what the model is "seeing"
print(decode(batch[0]))

d will
Which tired majesty did make thee offer,
The resignation of thy state and crown
To Henry Bolingbroke.

KING RICHARD II:
Give me the crown. Here, cousin, seize the crown;
Here cousin:
On this side my hand, and on that side yours.
Now is this golden c
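
The opposite direction, turning text into indices, follows the same pattern with `stoi`. Below is a minimal sketch, assuming a hypothetical `encode()` helper and a toy five-character vocabulary that stands in for the mappings loaded from `meta.pkl`.

```python
# Toy stand-ins for the `stoi` / `itos` dictionaries from meta.pkl
# (assumption: the real mappings have the same char <-> int structure).
toy_chars = sorted(set("to be"))                      # [' ', 'b', 'e', 'o', 't']
toy_stoi = {ch: i for i, ch in enumerate(toy_chars)}
toy_itos = {i: ch for ch, i in toy_stoi.items()}

def encode(text, stoi):
    """Hypothetical helper: map each character to its index."""
    return [stoi[ch] for ch in text]

def decode_list(indices, itos):
    """List-based counterpart of decode() for this toy example."""
    return ''.join(itos[i] for i in indices)

ids = encode("to be", toy_stoi)
print(ids)                          # [4, 3, 0, 1, 2]
print(decode_list(ids, toy_itos))   # to be
```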



## Diffusion in Discrete Space

Most large language models today, including ChatGPT, are trained using a method called **autoregressive modeling**, or *next-token prediction*. Let's unpack what that means. Suppose we have a dataset that follows some probability distribution $p_{\text{data}}$, and a sample from it is a sequence of tokens:
$$
x_1, x_2, \dots, x_d
$$
where $d$ is the sequence length (for example, the number of words or characters in a sentence).

In autoregressive modeling, we train a neural network to predict the next token based on all the tokens that came before it:
$$
p_{\theta}(x_i \mid x_1, x_2, \dots, x_{i-1}) \approx p_{\text{data}}(x_i \mid x_1, x_2, \dots, x_{i-1})
$$
Here, $\theta$ represents the model's parameters.

For text data, each token $x_i$ is **discrete**; it comes from a fixed vocabulary $\mathcal{X} = \{1, 2, \dots, N\}$, where $N$ is the number of unique tokens (like all possible words or characters).

### From Continuous to Discrete Diffusion

In image or video generation, the most powerful models today are **diffusion models**, like *Stable Diffusion* or *Sora*.
These models work in **continuous space**, meaning data (like pixel values) can smoothly vary. During training, they learn to *denoise*, that is, to reverse a process that gradually adds random noise to images.

However, for text, where each token is discrete, we can't simply "add a bit of noise." A character can't be nudged slightly; it must *jump* to another token in the vocabulary.

So the question becomes: How do we define "adding noise" when our data is made up of discrete symbols?

This leads us to **discrete diffusion models**, which describe how probability distributions over discrete tokens evolve over time.

### Defining a Discrete Diffusion Process

To build an intuition, let's focus on a single token, for example, one character.
At any moment in time $t$, we can describe our uncertainty about which token it is using a probability vector:
$$
p_t \in \mathbb{R}^N, \quad p_t^i \ge 0, \quad \sum_i p_t^i = 1
$$
Each element $p_t^i$ tells us how likely the token is to be the $i$-th vocabulary element.

We now define a **continuous-time Markov process** that describes how this distribution changes:
$$
\frac{d p_t}{d t} = Q_t p_t, \quad p_0 \approx p_{\text{data}} \tag{1}
$$
Here:

* $Q_t \in \mathbb{R}^{N \times N}$ is called the **rate matrix** (or **diffusion matrix**),
* the off-diagonal entries of $Q_t$ are nonnegative,
* and each column of $Q_t$ sums to zero (so the total probability remains 1).

Often, we make $Q_t$ simple by writing it as:
$$
Q_t = \sigma(t) Q^{\text{tok}}
$$

where $\sigma(t)$ controls how much noise we add over time, and $Q^{\text{tok}}$ defines the basic structure of the transitions.

### The Uniform Rate Matrix

We will use a **uniform rate matrix**, where any token is equally likely to change into any other token:
$$
Q^{\text{tok}} = \frac{1}{N}
\begin{pmatrix}
1 - N & 1 & \dots & 1 \\
1 & 1 - N & \dots & 1 \\
\vdots & \vdots & \ddots & \vdots \\
1 & 1 & \dots & 1 - N
\end{pmatrix}
$$

This can also be written compactly as:
$$
Q^{\text{tok}} = \frac{1}{N}J - I = P - I
$$

where:

* $J$ is an all-ones matrix,
* $I$ is the identity matrix,
* and $P = \frac{1}{N}J$ is a projection matrix that projects any probability vector onto the uniform distribution.
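
These properties are easy to check on a small example. The following is a minimal sketch, assuming a toy vocabulary of $N = 4$ and plain Python lists; the helper `matmul` is introduced only for this illustration.

```python
N = 4  # toy vocabulary size (assumption, chosen for readability)

# Q_tok = (1/N) J - I
Q_tok = [[1.0 / N - (1.0 if i == j else 0.0) for j in range(N)] for i in range(N)]
# P = (1/N) J
P = [[1.0 / N] * N for _ in range(N)]

def matmul(A, B):
    """Plain-Python product of two square matrices."""
    n = len(A)
    return [[sum(A[i][k] * B[k][j] for k in range(n)) for j in range(n)] for i in range(n)]

# Each column of a rate matrix must sum to zero
col_sums = [sum(Q_tok[i][j] for i in range(N)) for j in range(N)]
# P is a projection, so P @ P should equal P
P2 = matmul(P, P)

print(Q_tok[0])   # [-0.75, 0.25, 0.25, 0.25]
print(col_sums)   # all zeros
```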

### Solving the Diffusion Equation

For this uniform rate matrix, we can solve the differential equation in Eq. (1):
$$
p_t = e^{\bar{\sigma}(t) Q^{\text{tok}}} p_0 = \left[P + e^{-\bar{\sigma}(t)}(I - P)\right] p_0 \tag{2}
$$

where $\bar{\sigma}(t) = \int_0^t \sigma(\tau)\, d\tau$ is the total noise accumulated up to time $t$.

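Derivation: since $Q^{\text{tok}} = P - I$ and $P$ commutes with $I$, we have $e^{\bar{\sigma} Q^{\text{tok}}} = e^{-\bar{\sigma}} e^{\bar{\sigma} P}$. Expanding $e^{\bar{\sigma} P}$ as a power series and using $P^2 = P$ gives $e^{\bar{\sigma} P} = I + (e^{\bar{\sigma}} - 1)P$, which yields the right-hand side of Eq. (2). This solution has the following desirable properties: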

* When $\bar{\sigma}(t) = 0$, $p_t = p_0 \approx p_{\text{data}}$: no noise has been added.
* As $\bar{\sigma}(t) \to \infty$, the distribution becomes **uniform**: $p_t \to p_{\text{base}} = P p_0 = \frac{1}{N}\mathbf{1}$ meaning all tokens are equally likely.
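
As a sanity check, Eq. (2) can be tested numerically. The sketch below assumes a toy distribution `p0` over four tokens and treats $s = \bar\sigma(t)$ as the time variable, so the ODE becomes $dp/ds = Q^{\text{tok}} p$; it then compares a finite-difference derivative of the closed form against $Q^{\text{tok}} p$.

```python
import math

N = 4
p0 = [0.7, 0.2, 0.1, 0.0]  # toy starting distribution (assumption)

def p_closed_form(s):
    """Eq. (2): p = P p0 + exp(-s) (I - P) p0, with s standing in for sigma_bar."""
    uniform = sum(p0) / N                     # every entry of P p0
    return [uniform + math.exp(-s) * (p - uniform) for p in p0]

def apply_Q_tok(p):
    """Q_tok p = P p - p for the uniform rate matrix."""
    uniform = sum(p) / N
    return [uniform - pi for pi in p]

s, h = 0.5, 1e-5
# Central finite difference of the closed form with respect to s
finite_diff = [(a - b) / (2 * h) for a, b in zip(p_closed_form(s + h), p_closed_form(s - h))]
rhs = apply_Q_tok(p_closed_form(s))
max_err = max(abs(a - b) for a, b in zip(finite_diff, rhs))

print(max_err)               # tiny: the closed form satisfies the ODE
print(p_closed_form(0.0))    # recovers p0 (no noise)
print(p_closed_form(50.0))   # very close to the uniform distribution
```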

### What Does This Mean for a Character?

Suppose we start with a character $x_0 \in \mathcal{X}$. After diffusing for some time $t$, the probability that it remains the same or changes to another token is:
$$\Pr\{y_t \mid x_0\} =
\begin{cases}
    e^{-\bar \sigma (t)}+\dfrac{1-e^{- \bar \sigma (t)}}{N} & y_t=x_0, \\[6pt]
    \dfrac{1-e^{- \bar \sigma (t) }}{N} & y_t \neq x_0,
\end{cases} \tag{3}$$

So over time, the character "forgets" what it was, smoothly transitioning from its original identity toward a uniform distribution over all possible characters.

In practice, we can apply this diffusion process **independently to every character** in a text sequence to simulate adding noise to an entire sequence.


## Perturbing the batch with noise

**Goal.** We want to "noisify" a batch of tokenised text by *independently* disturbing each token. In the discrete diffusion view, each token either

1. **stays the same** (with probability given by the first line of Eq. (3)), or
2. **jumps to a different token** (uniformly among the other $N-1$ choices).

From Eq. (3), the total probability of *changing to a different token* is
$$
\underbrace{1 - \Big(e^{-\bar\sigma(t)} + \frac{1 - e^{-\bar\sigma(t)}}{N}\Big)}_{\text{not staying the same}}
= \big(1 - e^{-\bar\sigma(t)}\big)\big(1 - \tfrac{1}{N}\big).
$$
We'll call this the **move probability**.
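
To get a feel for the numbers, the sketch below evaluates the move probability and the stay probability from Eq. (3) for the 65-character Shakespeare vocabulary. The function names are illustrative assumptions; the formulas are the ones above.

```python
import math

def stay_probability(sigma_bar, N):
    """First line of Eq. (3): the token is unchanged after diffusion."""
    return math.exp(-sigma_bar) + (1.0 - math.exp(-sigma_bar)) / N

def move_probability(sigma_bar, N):
    """Probability that the token ends up on a different index."""
    return (1.0 - math.exp(-sigma_bar)) * (1.0 - 1.0 / N)

N = 65  # vocabulary size of the Shakespeare character set
for sigma_bar in (0.01, 0.1, 1.0):
    move = move_probability(sigma_bar, N)
    stay = stay_probability(sigma_bar, N)
    # The two probabilities always add up to one
    print(f"sigma_bar={sigma_bar}: move={move:.4f}, stay={stay:.4f}, total={move + stay:.4f}")
```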

**Implementation detail.** When a token "moves," it must land on a *different* index with **uniform** probability over the $N-1$ alternatives, and not return to the original token. The code below guarantees that.


In [9]:
def perturb_batch(batch: torch.Tensor, sigma_bar: torch.Tensor) -> torch.Tensor:
    """
    Diffuse each token independently according to Eq. (3).

      - With probability e^{-sigma_bar} + (1 - e^{-sigma_bar})/N, a token stays the same.
      - Otherwise, it jumps uniformly to one of the other N-1 tokens.
    Args:
        batch: LongTensor of shape [B, L], each entry in [0, vocab_size-1]
        sigma_bar: scalar tensor
    Returns:
        batch_pert: perturbed batch of LongTensor
    """
    B, L = batch.shape

    # 1) Compute move probability: (1 - e^{-sigma}) * (1 - 1/N)
    stay_base = torch.exp(-sigma_bar)
    move_prob = (1 - stay_base) * (1 - 1 / vocab_size)

    # 2) Bernoulli: should this token move?
    move_mask = torch.rand(B, L, device=batch.device) < move_prob

    # 3) For tokens that move, sample a *different* id uniformly from the other N-1 ids.
    #    Sample r in [0, N-2], then map to [0..N-1]\{orig} by skipping the original.
    r = torch.randint(low=0, high=vocab_size - 1, size=(B, L), device=batch.device)
    # shift up by 1 wherever r >= original id, covering {0, .., k-1, k+1, .., N-1}
    new_ids = r + (r >= batch)

    # 4) Apply moves; else keep original
    batch_pert = torch.where(move_mask, new_ids, batch)
    return batch_pert
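
The "skip the original" step in `perturb_batch()` is compact, so here is a minimal sketch of the same trick on plain integers, assuming a toy vocabulary of five tokens. The helper name `skip_original` is hypothetical.

```python
def skip_original(r, orig):
    """Map r in [0, N-2] onto [0, N-1] without ever landing on `orig`."""
    # In Python, True counts as 1, which mirrors `r + (r >= batch)` on tensors.
    return r + (r >= orig)

N = 5  # toy vocabulary size (assumption)
for orig in range(N):
    # Enumerate every possible draw of r and see where it lands
    targets = [skip_original(r, orig) for r in range(N - 1)]
    print(orig, targets)
```

Each of the $N-1$ values of `r` lands on a distinct index other than `orig`, so a uniform draw of `r` gives a uniform draw over the alternatives.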

### Visualising the perturbed text

As the string is perturbed, newline characters (`\n`) appear and disappear, so the printed text flips rapidly between many short lines and a few very long ones, which is hard to follow without wrapping. The helper function below prints each paragraph wrapped to a target width (default 80 characters).

In [10]:
import textwrap

def print_wrapped(long_text, width=80, **kwargs):
    """
    Print text wrapped to a maximum line width, preserving paragraph breaks.
    """
    paragraphs = long_text.splitlines()
    wrapped = [textwrap.fill(p, width=width) if p else '' for p in paragraphs]
    final_text = "\n".join(wrapped)
    print(final_text, **kwargs)

We'll now sweep $\bar\sigma(t)$ from "no noise" to "a lot of noise," and watch the decoded text degrade towards uniform randomness. Early on as $\bar \sigma \approx 0$, characters mostly stay the same; later, more of them start to jump as the noise level keeps on increasing. In the limit of very large $\bar\sigma$, each character is essentially an independent uniform draw from the vocabulary.


In [11]:
import time
from IPython.display import clear_output

# A smooth schedule: start with tiny steps near 0 (to see subtle changes),
# then larger steps as we approach heavy noise.
sigmas = torch.cat([torch.linspace(0, 0.1, 51), torch.linspace(0.11, 1.0, 51)])

for i in range(-1, sigmas.shape[0]):
    if i == -1:
        print('Unperturbed text:', end='\n\n')
        print_wrapped(decode(batch[0]), end='\n\n')
        time.sleep(2.0)
        continue

    sigma_bar = sigmas[i]
    batch_pert = perturb_batch(batch, sigma_bar)

    clear_output(wait=True)
    print(f'Perturbed text at noise {sigma_bar:.3f}:', end='\n\n', flush=True)
    print_wrapped(decode(batch_pert[0]), end='\n\n', flush=True)
    time.sleep(0.1)

Perturbed text at noise 1.000:

dypitLHWh$kD'wWred.eTjfstpidXdwmDQe t!e,,ovferr;ym:;V'aoIn$wicn,ofq.ky
-Satt&aTyRc.pwu
To oa.rX BoNi!3 rok;.
EKHNL Reqp'V-:EImUSiyWPfklTKo MaoPnqyHH?e, rX'Vir,zCWUydYQheycLoDPP
CiqeUmJuKOY:
WcktKHsSumdeORW efnVKSSbd!oJ thaS s,dE yXu E.VtUw ssLZhd, g&K-Sn ;




## Denoising

**Goal.** We just learned how to *add* noise to discrete tokens. Now we want to learn how to *undo* that noise.

For diffusion models in continuous domains (images, audio), we do this by [estimating the gradients of the probability distribution](https://yang-song.net/blog/2021/score/), that is, by learning what is called a **score function** ([Song & Ermon, 2019](https://arxiv.org/abs/1907.05600)):
$$
\nabla_x \log p(x),
$$
i.e., the direction in data space that most increases the log-probability. A neural network $s_\theta(x)$ is trained so that
$$
s_\theta(x) \approx \nabla_x \log p(x),
$$
often via **denoising score matching**. Intuitively, given a noisy sample, the model predicts the direction back to the data manifold.

### What’s the discrete analogue of a “score”?

In discrete space, we cannot take derivatives with respect to $x$ in the usual sense. Instead, we work with the **continuous-time Markov process** from Eq. (1) and look at its **time reversal**. If the forward process evolves as
$$
\frac{d p_t}{dt} = Q_t p_t,
$$
then its (finite-horizon) reverse process, running it backwards from time $T$ down to $0$, evolves as
$$
\frac{d p_{T-t}}{dt} = \bar Q_{T-t} p_{T-t} \tag{4}
$$
Here, $\bar Q_t$ is the **reverse rate matrix**. The key relationship tying forward and reverse dynamics is
$$
\bar Q_t(y,x) = \frac{p_t(y)}{p_t(x)} Q_t(x,y) \tag{5}
$$
The above equation ensures that the rate at which probability "flows" from $x$ to $y$ in forward time matches the rate it flows from $y$ to $x$ in reverse time. As usual, diagonal entries satisfy $\bar Q_t(x,x) = -\sum_{y\neq x}\bar Q_t(y,x)$ to conserve total probability.

Equation (5) highlights the **ratios**
$$
\frac{p_t(y)}{p_t(x)} \quad \text{for } y\neq x,
$$
which are called **concrete scores**. They play the role that the gradient plays in the continuous case: the log of a ratio is a difference of log-probabilities, which acts like a directional slope between two symbols:
$$
\log \frac{p_t(y)}{p_t(x)} = \log p_t(y) - \log p_t(x)
$$
If we can estimate these ratios, we can assemble $\bar Q_t$ and hence run the reverse diffusion to denoise.
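
To make the assembly step concrete, here is a minimal sketch for a single token. It assumes a toy marginal `p_t` that is known exactly (in practice the ratios come from the model) and the uniform rate matrix with $\sigma(t) = 1$. It builds $\bar Q_t$ from Eq. (5) and checks that $\bar Q_t p_t = -Q_t p_t$, i.e. the reverse process undoes the forward flow.

```python
N = 4
p_t = [0.4, 0.3, 0.2, 0.1]   # toy marginal at time t (assumption; all entries > 0)

# Forward rate matrix: uniform Q_tok, taking sigma(t) = 1 for simplicity
Q = [[1.0 / N - (1.0 if i == j else 0.0) for j in range(N)] for i in range(N)]

# Concrete scores: ratio[y][x] = p_t(y) / p_t(x)
ratio = [[p_t[y] / p_t[x] for x in range(N)] for y in range(N)]

# Eq. (5) for the off-diagonal entries of the reverse rate matrix
Q_rev = [[ratio[y][x] * Q[x][y] if y != x else 0.0 for x in range(N)] for y in range(N)]
# Diagonal entries make every column sum to zero
for x in range(N):
    Q_rev[x][x] = -sum(Q_rev[y][x] for y in range(N) if y != x)

forward = [sum(Q[y][x] * p_t[x] for x in range(N)) for y in range(N)]
reverse = [sum(Q_rev[y][x] * p_t[x] for x in range(N)) for y in range(N)]

print(forward)   # probability flow under the forward process
print(reverse)   # same magnitudes with opposite signs
```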

So, our learning target becomes:
$$
s_\theta(x,\bar\sigma_t) \approx \left[\frac{p_t(y)}{p_t(x)}\right]_{y\neq x} \tag{6}
$$

Because we denoise an entire sequence of characters rather than a single character, the state space is the set of all sequences. To keep the same character-level rate matrix, we only consider probability ratios between sequences that differ in exactly one position (Hamming distance 1).
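
A short sketch shows why this restriction keeps things tractable. The helper `hamming1_neighbours` is hypothetical and only enumerates neighbours of a toy sequence: there are $L \cdot (N-1)$ of them, compared with $N^L$ possible sequences overall, so one value per position and per vocabulary entry is enough to describe every ratio we need.

```python
def hamming1_neighbours(seq, vocab_size):
    """Yield (position, new_token, neighbour) for each sequence differing from `seq` in one position."""
    for pos, tok in enumerate(seq):
        for new_tok in range(vocab_size):
            if new_tok != tok:
                neighbour = list(seq)
                neighbour[pos] = new_tok
                yield pos, new_tok, neighbour

seq = [0, 2, 1]   # toy sequence of length 3 over a vocabulary of size 3 (assumption)
neighbours = list(hamming1_neighbours(seq, vocab_size=3))

print(len(neighbours))   # 6 = len(seq) * (vocab_size - 1)
for pos, new_tok, neighbour in neighbours:
    print(pos, new_tok, neighbour)
```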

Let's implement a character-level transformer model that takes the input sequence and the noise level and produces these probability ratios.

## Discrete Diffusion Model

Instead of building a model from scratch, we are going to modify the character-level nanoGPT from Andrej Karpathy's GitHub repo.

### Multi-layer Perceptron

This is the same as in the nanoGPT repo.

In [12]:
import math
import torch.nn as nn
from torch.nn import functional as F

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu    = nn.GELU()
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

### Self-attention

The self-attention block in nanoGPT implements causal self-attention with a triangular mask for autoregressive training. Our model instead needs to see both the past and the future tokens of a noisy sequence, so the causal mask is removed.

In [13]:
class SelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=False)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y


### Discrete Diffusion Transformer Block
This is where things start to diverge a bit. In nanoGPT, a Block module mainly contains the self-attention and MLP layers along with some layer norms. Because we need to also process the noise and mix it with the input, the `forward()` function takes in both the input and the noise level (or the time-step of the noise schedule).

For mixing, we will also be implementing two functions: `modulate()` and `bias_add_scale()`. These functions and discrete diffusion transformer blocks defined in the following cells up to TimestepEmbedder module are mostly the same as in [Score-Entropy-Discrete-Diffusion](https://github.com/louaaron/Score-Entropy-Discrete-Diffusion) GitHub Repo.

In [14]:
from typing import Optional

def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return x * (1 + scale) + shift

def bias_add_scale(
    x: torch.Tensor, bias: Optional[torch.Tensor], scale: torch.Tensor, residual: Optional[torch.Tensor]) -> torch.Tensor:
    if bias is not None:
        out = scale * (x + bias)
    else:
        out = scale * x

    if residual is not None:
        out = residual + out
    return out
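
A quick way to read these two helpers is to look at what happens when the conditioning terms are zero. The sketch below repeats the two functions in compact form so it runs on its own, and uses toy tensor shapes that are assumptions for illustration: zero shift and scale leave `x` unchanged, and a zero gate switches a branch off so only the residual survives.

```python
import torch

def modulate(x, shift, scale):
    return x * (1 + scale) + shift

def bias_add_scale(x, bias, scale, residual):
    out = scale * (x + bias) if bias is not None else scale * x
    return residual + out if residual is not None else out

x = torch.randn(2, 5, 8)         # (batch, sequence, n_embd), toy sizes
zeros = torch.zeros(2, 1, 8)     # one vector per batch item, broadcast over the sequence

# Zero shift and scale: modulation is the identity
same = modulate(x, shift=zeros, scale=zeros)
# Zero gate: the branch output is ignored and the residual passes through
branch_out = torch.randn(2, 5, 8)
skipped = bias_add_scale(branch_out, None, zeros, x)

print(torch.equal(same, x), torch.equal(skipped, x))   # True True
```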


Our Transformer block also defines an `adaLN_modulation` layer that creates the shift, scale, and gate terms from the encoded noise `c`, and uses the `modulate()` and `bias_add_scale()` functions defined above to mix them with the input `x`.

In [15]:
class DDiTBlock(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.attn = SelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

        self.adaLN_modulation = nn.Linear(config.cond_dim, 6 * config.n_embd)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

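    def forward(self, x, c):
        # one (shift, scale, gate) triple each for the attention and MLP branches
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
        x_skip = x
        # attention branch: normalise, modulate with the noise embedding, attend once,
        # then gate the result and add the residual
        x = modulate(self.ln_1(x), shift_msa, scale_msa)
        x = bias_add_scale(self.attn(x), None, gate_msa, x_skip)
        # MLP branch: same pattern with its own shift, scale, and gate
        x = bias_add_scale(self.mlp(modulate(self.ln_2(x), shift_mlp, scale_mlp)), None, gate_mlp, x)
        return x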


### Final layer

Responsible for mapping the input and encoded noise to the vocabulary size.

In [16]:
class DDitFinalLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.norm_final = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.linear = nn.Linear(config.n_embd, config.vocab_size)
        self.linear.weight.data.zero_()
        self.linear.bias.data.zero_()

        self.adaLN_modulation = nn.Linear(config.cond_dim, 2 * config.n_embd)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()


    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


### TimestepEmbedder

Responsible for encoding the noise.

In [17]:
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256, silu=True):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
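
To see what the sinusoidal features look like, here is a scalar, pure-Python sketch of the same formula used in `timestep_embedding()`. The function name `sinusoidal_embedding` is hypothetical and it only handles even `dim`. The frequencies decay geometrically from 1 towards `1 / max_period`, and the output is the cosines followed by the sines.

```python
import math

def sinusoidal_embedding(t, dim, max_period=10000):
    """Scalar version of TimestepEmbedder.timestep_embedding (even `dim` only)."""
    half = dim // 2
    # Geometric frequency ladder: 1, max_period^(-1/half), max_period^(-2/half), ...
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    return [math.cos(t * f) for f in freqs] + [math.sin(t * f) for f in freqs]

emb = sinusoidal_embedding(t=0.0, dim=8)
print(emb)                                   # [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
print(len(sinusoidal_embedding(0.3, 256)))   # 256
```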


### Discrete Diffusion GPT Model

With all the basic building blocks in place, we are now able to define a discrete diffusion GPT model that instantiates these blocks and defines the forward method. This class mostly follows the `GPT` class defined in the nanoGPT repo, with the transformer `Block` and the final `lm_head` projection replaced by the `DDiTBlock` and `DDitFinalLayer` modules implemented above.


In [18]:
class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.sigma_map = TimestepEmbedder(config.cond_dim)
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([DDiTBlock(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = DDitFinalLayer(config)

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For the non-embedding count (default), the position embeddings get subtracted.
        The token embeddings are still included in the count.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, sigma):
        sigma = sigma.reshape(-1)
        device = idx.device
        b, t = idx.size()
        c = F.silu(self.sigma_map(sigma))
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x, c)
        x = self.transformer.ln_f(x)

        # project to per-token log-scores of shape (b, t, vocab_size)
        x = self.lm_head(x, c)
        # the ratio at y = x_t is 1 by definition, so pin its log-score to 0
        x = torch.scatter(x, -1, idx[..., None], torch.zeros_like(x[..., :1]))

        return x


### Model Config

Finally, let's define a model configuration that will be used to instantiate the GPT model.

In [19]:
from dataclasses import dataclass

@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    cond_dim: int = 64
    dropout: float = 0.0
    bias: bool = False # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster


## Noise Schedule

Let's define a noise schedule that lets us create noise at different levels. We will use a geometric noise schedule with minimum and maximum noise levels. The module takes a time step $t$ as input and returns the integrated noise $\bar \sigma (t) = \int_0^t \sigma(\tau)\, d\tau$ together with the noise rate $\sigma(t)$.

In [20]:
class GeometricNoise:
    def __init__(self, sigma_min=1e-4, sigma_max=20):
        self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])

    def rate_noise(self, t):
        return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (self.sigmas[1].log() - self.sigmas[0].log())

    def total_noise(self, t):
        return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t

    def __call__(self, t):
        # raw docstring so that "\b" in "\bar" is not parsed as a backspace escape
        r"""
        Returns:
            \bar \sigma(t) and \sigma(t)
        """
        return self.total_noise(t), self.rate_noise(t)


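
To get a feel for the schedule, the short sketch below re-implements the two formulas in plain Python (standard library only, so it does not depend on `torch`) and checks numerically that `rate_noise` is the time derivative of `total_noise`. The function names mirror the class methods; `finite_difference`, its step `h`, and the probe time `t = 0.3` are arbitrary choices made for illustration.

```python
import math

SIGMA_MIN, SIGMA_MAX = 1e-4, 20.0  # same defaults as GeometricNoise


def total_noise(t: float) -> float:
    # Geometric interpolation between sigma_min (t = 0) and sigma_max (t = 1)
    return SIGMA_MIN ** (1 - t) * SIGMA_MAX ** t


def rate_noise(t: float) -> float:
    # Analytic time derivative of total_noise
    return total_noise(t) * (math.log(SIGMA_MAX) - math.log(SIGMA_MIN))


def finite_difference(t: float, h: float = 1e-6) -> float:
    # Central-difference approximation of d(total_noise)/dt (illustrative helper)
    return (total_noise(t + h) - total_noise(t - h)) / (2 * h)


for t in (0.0, 0.5, 1.0):
    print(f"t={t:.1f}  total={total_noise(t):.6f}  rate={rate_noise(t):.6f}")

# The analytic rate and the numerical derivative should agree closely
print(math.isclose(finite_difference(0.3), rate_noise(0.3), rel_tol=1e-6))
```

Because the interpolation is geometric, most of the noise is added close to $t = 1$: at $t = 0.5$ the total noise is still below $0.05$, while at $t = 1$ it reaches $20$.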


## Initialisation



### Model Initialisation

We will use the configuration for the character-level babyGPT defined in the nanoGPT repo to instantiate our discrete diffusion GPT model.

In [21]:
# A character-level baby GPT model :)
n_layer = 6
n_head = 6
n_embd = 384
cond_dim = 64
block_size = context_length
dropout = 0.2
bias = False # do we use bias inside LayerNorm and Linear layers?

model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, cond_dim=cond_dim,
                  bias=bias, vocab_size=vocab_size, block_size=block_size, dropout=dropout)

config = GPTConfig(**model_args)
model = GPT(config)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

number of parameters: 11.64M


GPT(
  (sigma_map): TimestepEmbedder(
    (mlp): Sequential(
      (0): Linear(in_features=256, out_features=64, bias=True)
      (1): SiLU()
      (2): Linear(in_features=64, out_features=64, bias=True)
    )
  )
  (transformer): ModuleDict(
    (wte): Embedding(65, 384)
    (wpe): Embedding(256, 384)
    (drop): Dropout(p=0.2, inplace=False)
    (h): ModuleList(
      (0-5): 6 x DDiTBlock(
        (ln_1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): SelfAttention(
          (c_attn): Linear(in_features=384, out_features=1152, bias=False)
          (c_proj): Linear(in_features=384, out_features=384, bias=False)
          (attn_dropout): Dropout(p=0.2, inplace=False)
          (resid_dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Linear(in_features=384, out_features=1536, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear

### Noise Initialisation

In [22]:
sigma_min, sigma_max = 1e-4, 20
noise = GeometricNoise(sigma_min=sigma_min, sigma_max=sigma_max)


## Training (Optional)

If you just want to see a pretrained discrete diffusion GPT model in action, you could skip to the **Inference (Sampling)** section.

### Training Objective: Score Entropy Loss

We still need to define a training objective: what should our model minimise so that it learns the **probability ratios** in Eq. (6)?

Here is what we have so far:

  * We start with a batch of clean sequences, where each sequence is sampled from the data as $x_0 \sim p_{\text{data}}$.
  * For each sequence in the batch, we randomly sample a noise time step $t$ and compute the time-integrated noise $\bar\sigma(t)$ with our `GeometricNoise` class.
  * We use the `perturb_batch()` function, which creates a noisy version of our batch by independently diffusing each character in a sequence as $x_t^i \sim p_{t|0}(\cdot \mid x_0^i)$, as defined by Eq. (3).
  * Our discrete diffusion GPT model outputs an estimate of the **ratio** for every possible token $y \neq x_t^i$:
  $$
  s_\theta(x_t^i,\bar\sigma_t)_y \approx \frac{p_{t|0}(y\mid x_0^i)}{p_{t|0}(x_t^i\mid x_0^i)}.
  $$
  In practice we have the model predict $\log s_\theta$ for numerical stability and exponentiate it when needed.

As a loss function, a first idea is to use an $\ell^2$ (squared error) loss between the predicted ratios and the true ratios:
$$
\sum_{i=1}^d \sum_{y\neq x_t^i}\Big(s_\theta(x_t^i,\bar\sigma_t)_y - \frac{p_{t|0}(y\mid x_0^i)}{p_{t|0}(x_t^i\mid x_0^i)}\Big)^2.
$$
This resembles Fisher divergence in the continuous case.

However, it has an issue: the model output representing the probability ratios needs to be non-negative, but an $\ell^2$ loss does not discourage the model from producing negative values.

To bake positivity into the objective, we use **Bregman divergence**, a general way to measure mismatch derived from a convex function $F$:
$$
D_F(u, v) = F(u) - F(v) - \nabla F(v)^\top (u - v).
$$

If we choose the following convex function
$$
F(u) = \sum_j \big[u_j \log u_j - u_j\big],
$$
where $u_1, u_2, \dots, u_V$ are, in our case, the probability ratios across the vocabulary, then $D_F$ reduces to a sum of **generalized KL** terms with two nice properties:

  1. $D_F(u,v) \ge 0$ with equality iff $u=v$ (so it can serve as a loss function).
  2. Because $F(u)$ involves $\log u$, it is only defined for strictly positive $u$, which constrains the probability ratios to be positive.

Applying this to our targets $u = a = \frac{p_{t|0}(y\mid x_0^i)}{p_{t|0}(x_t^i\mid x_0^i)}$ and predictions $v = s_\theta = s_\theta(x_t^i,\bar\sigma_t)_y$, we obtain the **Score Entropy Loss**, also referred to as **diffusion-weighted denoising score entropy (DWDSE)** in ([A. Lou et al., 2024](https://arxiv.org/abs/2310.16834)):
$$
\boxed{
  \mathcal L_{\text{DWDSE}}
  = \sum_{i=1}^d \sum_{y \neq x_t^i} \sigma_t
    \left[
      s_\theta(x_t^i,\bar\sigma_t)_y - a \log s_\theta(x_t^i,\bar\sigma_t)_y + K(a)
    \right]
}\tag{7}
$$

where $K(a)=a(\log a - 1)$ is the part that does *not* depend on the model. We also weight each term with $\sigma_t$ to emphasise harder/noisier examples. The $-a \log s_\theta$ term pulls $\log s_\theta$ toward $\log a$ (i.e., toward the truth), while the $s_\theta$ term keeps the predicted ratios from growing without bound. The $K(a)$ term is constant w.r.t. the model parameters. It keeps the divergence non-negative and the algebra clean, but it can optionally be dropped during training if we only care about the gradients.
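
As a quick sanity check of Eq. (7), the sketch below evaluates a single $(i, y)$ entry of the loss, without the $\sigma_t$ weight, for a fixed target ratio. The function name `score_entropy_term` and the value `a = 0.25` are illustrative assumptions, not part of the tutorial's code. The entry should vanish when the prediction equals the target and be positive otherwise.

```python
import math


def score_entropy_term(s: float, a: float) -> float:
    # One (i, y) entry of Eq. (7) without the sigma_t weight:
    # s - a * log(s) + K(a), with K(a) = a * (log(a) - 1)
    return s - a * math.log(s) + a * (math.log(a) - 1.0)


a = 0.25  # illustrative target ratio (arbitrary value, not taken from the model)
for s in (0.05, 0.25, 1.0):
    print(f"s={s:.2f}  loss={score_entropy_term(s, a):.4f}")
```

Note that the term is undefined for $s \le 0$, which is one reason to let the network predict $\log s_\theta$ and exponentiate.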

**Efficient implementation trick.**
Eq. (7) sums over all $y \neq x_t^i$. We can compute it efficiently by starting with a sum (or mean) over all vocabulary entries and subtracting the $y=x_t^i$ contribution. We'll also handle two special cases:

* **No-move**: when $x_t=x_0$, i.e. the token survived the noise step.
* **Move**: $x_t\neq x_0$. We will build it from the two parts: with $y=x_0$ and $y\notin \left\{ x_t, x_0 \right\}$.
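
The target ratios $a$ in the two cases can be worked out by hand for the uniform forward kernel. The sketch below assumes the kernel implied by the code in this tutorial (each specific other token is reached with probability $(1 - e^{-\bar\sigma})/V$) and checks that the resulting move/no-move ratio equals the closed form `esigm1 / (esigm1 + V)` used in `score_entropy()`. The value `sigma_bar = 0.7` is arbitrary.

```python
import math

V = 65           # vocabulary size of the character-level Shakespeare dataset
sigma_bar = 0.7  # illustrative integrated noise level (arbitrary)

# Uniform forward kernel: probability of landing on one specific other token,
# and probability of staying on the original token
p_move = (1.0 - math.exp(-sigma_bar)) / V
p_stay = 1.0 - (V - 1) * p_move

ratio_direct = p_move / p_stay
esigm1 = math.expm1(sigma_bar)
ratio_closed = esigm1 / (esigm1 + V)  # expression used in score_entropy()

print(f"no-move: a_y = {ratio_direct:.6f} for every y != x_t")
print(f"move:    a_y = {1.0 / ratio_direct:.3f} for y = x0, and 1 for the other {V - 2} tokens")
```

So in the no-move case all $V - 1$ targets are the same small number, while in the move case exactly one target (the clean token) is large and the remaining $V - 2$ are equal to 1.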


In [23]:
def score_entropy(
    score_log: torch.Tensor,
    sigma_bar: torch.Tensor,
    x_t: torch.Tensor,
    x0: torch.Tensor,
    clamp_exp: float = 30.0,
    eps: float = 1e-12,
):
    r"""
    Compute the Score Entropy Loss (Eq. 7) *without* the outer sigma_t multiplier.
    The sum over y != x_t is divided by (V - 1), i.e. it is returned as a mean.

    Args:
        score_log:  (B, L, V) tensor of model outputs = log s_theta(x_t, bar{sigma}_t)
                    for each position and vocabulary element.
        sigma_bar:  (B, 1) tensor for \bar\sigma_t (integrated noise).
        x_t:        (B, L) int tensor with current noised tokens.
        x0:         (B, L) int tensor with original clean tokens.
        clamp_exp:  float, clamp for exponent to keep exp(score_log) stable.
        eps:        float, small constant for numerical stability in logs/divides.

    Returns:
        loss:       (B, L) tensor containing Eq. (7) per token position (no sigma_t).
    """
    B, L, vocab_size = score_log.shape
    # 1) Precompute helpers
    # stably compute exp(bar_sigma) - 1
    esigm1 = torch.where(
        sigma_bar < 0.5,
        torch.expm1(sigma_bar),
        torch.exp(sigma_bar) - 1
    )

    # ratio = non-diagonal terms (move) / diagonal terms (no-move) in Eq. (3)
    ratio = esigm1 / (esigm1 + vocab_size)
    # Clamp ratio away from 0 to avoid divide by zero and log(0)
    ratio = torch.clamp(ratio, min=eps)

    # We need both model predicted log s_theta and s = exp(log s_theta)
    # Clamp the exponent to prevent overflow (safe since the loss uses first-order terms)
    score_log = torch.clamp(score_log, max=clamp_exp)
    s = torch.exp(score_log)
    # We'll often need to take the values at a particular token indices
    def take_at(logits: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        # logits: (B, L, V), idx: (B, L) -> (B, L)
        return torch.gather(logits, dim=-1, index=idx[..., None]).squeeze(-1)

    # 2) Build positive term in Eq. (7)
    # Mean over all y, and then subtract out the y = x_t contribution
    s_scaled   = s / (vocab_size - 1)  # scaled scores
    s_mean_all = s_scaled.sum(dim=-1)    # (B, L)
    s_at_xt    = take_at(s_scaled, x_t)  # (B, L)
    pos_term   = s_mean_all - s_at_xt    # averages over y != x_t

    # 3) Build negative term in Eq. (7)
    # We need to consider a total of (V - 1) terms, split into two mutually exclusive cases:
    # Case 1: x_t == x0 (no move); all y != x_t have the same a_y = ratio
    #   there are a total of V - 1 such terms
    # Case 2: x_t != x0 (move):
    #   this can be split into two parts:
    #   Part a: y == x0; with a_y = 1 / ratio
    #     there is exactly 1 such term
    #   Part b: y != x0 and y != x_t with a_y = 1
    #     there are (V - 2) such terms
    log_s_mean  = score_log.sum(dim=-1) / (vocab_size - 1)   # (B, L)
    log_s_at_xt = take_at(score_log, x_t) / (vocab_size - 1) # (B, L)
    base_neg    = log_s_mean - log_s_at_xt # averages over y != x_t

    # Case split: no-move (x_t == x0) vs move (x_t != x0).
    no_move = (x_t == x0)

    # Case 1: no move (x_t == x0):
    #   a_y = p(y|x0)/p(x_t|x0) = move / no-move = ratio
    neg_term_no_move = ratio * base_neg

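    # Case 2: move (x_t != x0):
    #   a_y = p(y|x0)/p(x_t|x0) = 1 / ratio when y = x0
    #   a_y = p(y|x0)/p(x_t|x0) = 1 otherwise
    # base_neg already weights every y != x_t by 1, so we only need to raise
    # the weight of y = x0 from 1 to 1 / ratio.
    log_s_at_x0   = take_at(score_log, x0)  # (B, L)
    neg_term_move = base_neg + (1.0 / ratio - 1.0) * log_s_at_x0 / (vocab_size - 1)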

    neg_term = torch.where(no_move, neg_term_no_move, neg_term_move)

    # 4) Build constant term K(a) summed over y != x_t.
    # Again split into two mutually exclusive cases

    # Case 1: no move (x_t == x0)
    # y can be != x_t in V - 1 ways, each with a_y = ratio
    const_no_move = ratio * (torch.log(ratio) - 1.0)

    # Case 2: move (x_t != x0)
    # a_y = p(y|x0)/p(x_t|x0) = 1 / ratio when y = x0
    # a_y = p(y|x0)/p(x_t|x0) = 1 otherwise
    const_move = ((-torch.log(ratio) - 1.0) / ratio - (vocab_size - 2)) / (vocab_size - 1)

    const_term = torch.where(no_move, const_no_move, const_move)

    # Final per-position loss (without the outer sigma_t multiplier):
    loss = pos_term - neg_term + const_term  # (B, L)

    return loss

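
When testing a vectorised loss like this, a slow reference that loops over the vocabulary is handy. The following is a minimal sketch of such a reference for a single token position in plain Python. The function name `brute_force_score_entropy` and the toy sizes are hypothetical, and it follows the same convention of averaging over the $V - 1$ tokens $y \neq x_t$.

```python
import math


def brute_force_score_entropy(log_s, sigma_bar, x_t, x0):
    """Eq. (7) at one position (no sigma_t weight), averaged over the V - 1 tokens y != x_t."""
    V = len(log_s)
    esigm1 = math.expm1(sigma_bar)
    ratio = esigm1 / (esigm1 + V)
    total = 0.0
    for y in range(V):
        if y == x_t:
            continue
        if x_t == x0:
            a = ratio          # no-move: every other token has the same target
        elif y == x0:
            a = 1.0 / ratio    # move: y is the clean token
        else:
            a = 1.0            # move: any other token
        total += math.exp(log_s[y]) - a * log_s[y] + a * (math.log(a) - 1.0)
    return total / (V - 1)


V, sigma_bar = 4, 1.0  # toy sizes chosen for illustration
ratio = math.expm1(sigma_bar) / (math.expm1(sigma_bar) + V)

# Move case: x_t = 2 was noised from x0 = 0. The ideal log-scores are
# log(1 / ratio) at y = x0 and 0 elsewhere (the entry at y = x_t is ignored).
ideal_move = [math.log(1.0 / ratio), 0.0, 0.0, 0.0]
print(brute_force_score_entropy(ideal_move, sigma_bar, x_t=2, x0=0))

# A flat guess of log s = 0 everywhere is penalised in the same situation
print(brute_force_score_entropy([0.0] * V, sigma_bar, x_t=2, x0=0))
```

The first value is zero up to floating-point error, since the predictions match the targets exactly; the second is strictly positive.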


### Loss function

Let's define a `loss_function` that takes the diffusion model and a batch of data and does the following:
  - samples a random noise time step for each sequence in the batch,
  - generates the perturbed version of the batch using `perturb_batch()`,
  - computes the log-score from our discrete diffusion model,
  - computes the per-token loss with `score_entropy()` defined above, weights it by $\sigma_t$, and returns the batch mean.

In [24]:
from typing import Optional

def loss_function(
        model: GPT,
        x0: torch.Tensor,
        noise: GeometricNoise,
        t: Optional[torch.Tensor]=None,
        x_t: Optional[torch.Tensor]=None,
        sampling_eps=1e-3
    ) -> torch.Tensor:
    """
    Computes the loss for a batch of data.
    Args:
        model:          discrete diffusion model
        x0:             (B, L) LongTensor of original clean tokens.
        noise:          a GeometricNoise instance
        t:              (B,) float tensor with time steps in [0, 1]. If None, sampled uniformly.
        x_t:            (B, L) int tensor with perturbed tokens. If None, generated on-the-fly.
        sampling_eps:   float, small epsilon to avoid 0 or 1 time steps.
    Returns:
        loss:           scalar tensor with the loss.
    """

    if t is None: # time step
        t = (1 - sampling_eps) * torch.rand(x0.shape[0], device=x0.device) + sampling_eps

    sigma_bar, sigma = noise(t)

    if x_t is None:
        x_t = perturb_batch(x0, sigma_bar[:, None])

    log_score = model(x_t, sigma_bar)
    loss = score_entropy(log_score, sigma_bar[:, None], x_t, x0)

    loss = (sigma[:, None] * loss).mean(dim=-1).mean()

    return loss


### Optimiser

We will use the `AdamW` optimiser with a constant learning rate and no schedule to keep things simple.

In [25]:
import torch
import torch.optim as optim

optimizer = optim.AdamW(model.parameters(), lr=1e-4)

### Training Loop

In [None]:
model.train()
n_epochs = 100

for epoch in range(n_epochs):
    for i, batch in enumerate(train_dataloader):
        batch = batch.to(device)
        loss = loss_function(model, batch, noise, sampling_eps=sigma_min)
        print(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch} loss: {loss.item()}")
    if (epoch + 1) % 5 == 0:
        torch.save(model.state_dict(), f'model_epoch_{epoch+1}.pth')


Streaming output truncated; only the last 5000 lines are shown.
3.8353185653686523
3.1361093521118164
3.364614963531494
2.1834425926208496
3.1762630939483643
2.526975631713867
2.8909218311309814
3.453484058380127
2.526362180709839
4.442913055419922
2.718602418899536
3.26265549659729
3.039811849594116
2.828403949737549
2.2596628665924072
4.204106330871582
2.4698238372802734
3.2788143157958984
3.1604743003845215
3.0115413665771484
2.554454803466797
4.344721794128418
4.308427810668945
3.1700778007507324
3.307945728302002
3.3754055500030518
3.103593349456787
3.7841949462890625
3.156662702560425
2.503150701522827
3.4559974670410156
3.41841197013855
3.9634945392608643
2.713381290435791
3.34964919090271
3.2767086029052734
3.2307796478271484
2.8771331310272217
2.512269973754883
4.8147501945495605
3.2190752029418945
3.412522077560425
2.8848767280578613
3.378751277923584
2.6561195850372314
3.626124858856201
2.6835525035858154
4.256575584411621
1.9376718997955322
2.6405117511749268
3.553544044494629
3.200443744659424
2

## Inference (Sampling)

Once our score model $s_\theta$ is trained, we can use it to generate new sequences by running the diffusion process *backwards in time*.
This corresponds to constructing the **reverse diffusion matrix** $\bar Q_t$ from Eq. (5), which governs how we move from noisy tokens back toward clean data.

### 1. Reverse-time simulation

In principle, we can approximate the reverse process in Eq. (4) with a small **Euler step**:
$$
p(x_{t+\Delta t} = y \mid x_t = x) = \delta_{xy} + \bar Q_t(y, x) \Delta t + O(\Delta t^2).
$$
Here, $\bar Q_t(y, x)$ gives the instantaneous probability flow from token $x$ to token $y$. But this process is extremely slow: with the way we defined $Q_t$, each step can update only one token in the text (a Hamming distance of 1), so it takes many steps to denoise the text completely.

### 2. $\tau$-leaping: parallel updates

To speed things up, we could use $\tau$-leaping (Gillespie, 2001). Instead of advancing one token at a time, $\tau$-leaping updates *all* positions simultaneously over a small time step $\Delta t$. For each token $x_t^i$ in the sequence $x_t$, we sample its next state independently:
$$
\Pr(x_{t-\Delta t}^i = y) = \delta_{x_t^i}(y) + \Delta t\, Q_t(x_t^i, y)\, s_\theta(\mathbf{x}_t, t)_{i, y}
$$

Intuitively, under $\tau$-leaping each token "jumps" to a new symbol with a rate determined by both:

  - the **forward rate matrix** $Q_t$, and
  - our **score model** $s_\theta$, which encodes ratios between symbol probabilities.

While $\tau$-leaping is much faster than single-event simulation, it still uses $s_\theta$ in a fairly crude way: it just modulates the rate of a random walk. Also, the time step $\Delta t$ needs to be kept small to keep the error small. We can do better.

### 3. Tweedie denoiser: optimal reverse step

In the continuous world (e.g., image diffusion), a celebrated result called **Tweedie’s formula** tells us how to get the *optimal* denoised estimate from noisy data, given the score function. It gives you a direct formula to get a good estimate of the original clean image $x_0$ from the noisy image $x_t$, not just $x_{t-\Delta t}$ over a small time-step $\Delta t$. *Lou et al.* discretise it and build a Tweedie denoiser analogue for our token diffusion:

$$p^{\text {tweedie}} (x^i_{t-\Delta t} \mid x^i_t) \approx(\exp (-\sigma_t^{\Delta t} Q^{\text {tok}}) s_{\theta} (x_t, t)_i)_{x^i_{t-\Delta t}} \cdot \exp (\sigma_t^{\Delta t} Q^{\text {tok}}) (x_t^i,x_{t-\Delta t}^i)\tag{8}$$

where $\sigma_t^{\Delta t} = \bar \sigma(t) - \bar \sigma(t-\Delta t)$. The matrix exponential $\exp(\bar \sigma_t Q^{\text {tok}})$ is essentially a *finite-time evolution operator*; it tells you how the whole system changes after a finite amount of time $t$.

Think of Eq. (8) as an instance of **Bayes' rule**
$$P(A \mid B) = \frac {P(B \mid A) P(A)} {P(B)}$$

with $A = x^i_{t-\Delta t}$ and $B=x^i_t$, where:

  - The reverse process: $P(A \mid B) = p (x^i_{t-\Delta t} \mid x^i_t)$,
  - The forward process: $P(B \mid A) = \exp (\sigma_t^{\Delta t} Q^{\text {tok}}) (x_t^i,x_{t-\Delta t}^i)$,
  - The prior $P(A)$ and evidence $P(B)$: $\frac {P(A)} {P(B)}  = (\exp (-\sigma_t^{\Delta t} Q^{\text {tok}}) s_{\theta} (x_t, t)_i)_{x^i_{t-\Delta t}}$.

Thus, the model reuses forward dynamics **and** its learned score ratios to create sharper, more accurate denoising transitions.
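
To make the two matrix exponentials in Eq. (8) concrete, here is a small plain-Python sketch. It assumes the uniform token rate matrix $Q^{\text{tok}} = \frac{1}{V}\mathbf{1}\mathbf{1}^\top - I$, for which $\exp(\sigma Q^{\text{tok}}) = e^{-\sigma} I + \frac{1 - e^{-\sigma}}{V}\mathbf{1}\mathbf{1}^\top$, and checks that the forward kernel is a valid transition matrix and that $\exp(-\sigma Q^{\text{tok}})$ is its inverse. The helper names `uniform_kernel` and `matmul` and the toy values are hypothetical.

```python
import math


def uniform_kernel(sigma: float, V: int):
    """Closed form of exp(sigma * Q_tok), assuming Q_tok = (1/V) * ones - identity."""
    off_diag = (1.0 - math.exp(-sigma)) / V
    diag = math.exp(-sigma) + off_diag
    return [[diag if i == j else off_diag for j in range(V)] for i in range(V)]


def matmul(A, B):
    # Plain nested-list product of two square matrices
    n = len(A)
    return [[sum(A[i][k] * B[k][j] for k in range(n)) for j in range(n)] for i in range(n)]


V, delta_sigma = 4, 0.3  # toy vocabulary size and noise increment
forward = uniform_kernel(delta_sigma, V)   # exp(+delta_sigma * Q_tok)
inverse = uniform_kernel(-delta_sigma, V)  # exp(-delta_sigma * Q_tok)

# Each row of the forward kernel is a probability distribution
print([round(sum(row), 12) for row in forward])

# Forward followed by inverse is (approximately) the identity
product = matmul(forward, inverse)
print([[round(v, 12) for v in row] for row in product])
```

The inverse operator has negative off-diagonal entries, so it is not a transition matrix itself; it only becomes meaningful once applied to the score, as in Eq. (8).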

### 4. Implementation detail

Let's code up Eq. (8) now. We'll implement three helper functions:

  1. `transition()`: computes the forward diffusion kernel $\exp(\sigma_t^{\Delta t} Q^{\text{tok}})$.
  2. `staggered_score()`: applies the inverse operator $\exp(-\sigma_t^{\Delta t} Q^{\text{tok}})$ to the model's score output.
  3. `sample_categorical()`: draws discrete samples from the resulting probabilities using a numerically stable Gumbel-based method.


In [None]:
def transition(x_t: torch.Tensor, delta_sigma: torch.Tensor) -> torch.Tensor:
    """
    Forward transition kernel:
        exp(σ_t^Δt Q^{tok})(x_t, y)

    Approximates the finite-time forward diffusion probability of moving from token x_t to y
    after a noise increment of Δσ = σ_t^{Δt}.

    Args:
        x_t:          (B, L) integer tensor of current tokens.
        delta_sigma:  scalar or (B, 1) tensor representing σ_t^{Δt}.

    Returns:
        trans_probs:  (B, L, V) tensor of categorical probabilities over next tokens.
    """
    # Uniform mixing term from exp(delta_sigma * Q^{tok})
    # with the help of Eq. (3), this translates to:
    base_prob = (1 - torch.exp(-delta_sigma[..., None])) / vocab_size
    trans = torch.ones(*x_t.shape, vocab_size, device=x_t.device) * base_prob

    # Remove the uniform contribution for the current token
    trans = trans.scatter(-1, x_t[..., None], torch.zeros_like(trans))

    # Ensure that probabilities across the vocabulary sum to 1
    diag_fill = 1 - trans.sum(dim=-1, keepdim=True)
    trans = trans.scatter(-1, x_t[..., None], diag_fill)
    return trans


def staggered_score(score, delta_sigma):
    """
    Applies the inverse exponential operator:
        exp(-σ_t^Δt Q^{tok}) s_θ(x_t, t)

    This "staggered" score correction accounts for the finite time-step Δt.

    Args:
        score:        (B, L, V) tensor, model output s_θ(x_t, t)
        delta_sigma:  scalar or (B, 1) tensor representing σ_t^{Δt}

    Returns:
        adjusted_score: (B, L, V) tensor, transformed score
    """
    vocab_size = score.shape[-1]
    exp_factor = torch.exp(-delta_sigma)[..., None]  # (B, 1, 1) for a (B, 1) input; broadcasts over L and V
    correction = ((exp_factor - 1) / (vocab_size * exp_factor)) * score.sum(dim=-1, keepdim=True)
    return correction + score / exp_factor


def sample_categorical(probs: torch.Tensor) -> torch.Tensor:
    """
    Sample from a batch of categorical distributions using the Gumbel-max trick.

    Args:
        probs: (B, L, V) tensor of probabilities that sum to 1 along dim=-1.

    Returns:
        samples: (B, L) tensor of sampled token indices.
    """
    # Add a small epsilon for numerical stability
    eps = 1e-10
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(probs) + eps) + eps)
    return torch.argmax(torch.log(probs + eps) + gumbel_noise, dim=-1)
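
The Gumbel-max trick used in `sample_categorical()` can be checked empirically. The sketch below is a plain-Python version for a single categorical distribution; the function name `gumbel_max_sample`, the seed, the toy probabilities, and the number of draws are all arbitrary choices for illustration. Over many draws, the argmax of `log p + Gumbel noise` should land on each index with frequency close to `p`.

```python
import math
import random
from collections import Counter


def gumbel_max_sample(probs, rng):
    """Draw one index from a categorical distribution via argmax(log p + Gumbel noise)."""
    best_index, best_value = 0, -math.inf
    for index, p in enumerate(probs):
        # Keep u strictly inside (0, 1) so both logarithms are defined
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        gumbel = -math.log(-math.log(u))
        value = math.log(p + 1e-10) + gumbel
        if value > best_value:
            best_index, best_value = index, value
    return best_index


rng = random.Random(0)  # fixed seed, chosen arbitrarily for reproducibility
probs = [0.5, 0.3, 0.2]
n_draws = 20_000
counts = Counter(gumbel_max_sample(probs, rng) for _ in range(n_draws))
frequencies = [counts[i] / n_draws for i in range(len(probs))]
print(frequencies)  # should be close to probs
```

One practical property of this sampler is that it only depends on the ordering of `log p + noise`, so the probabilities do not need to be normalised before sampling.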


### Load pretrained model

In [None]:
model.load_state_dict(
    torch.hub.load_state_dict_from_url(
        'https://raw.githubusercontent.com/ash80/diffusion-gpt/master/pretrained_model/model_epoch_25.pth',
        map_location=device
    )
)
model.eval()


### Generate random samples

Now let's test our sampling components. We'll start with a random sequence of tokens $x_t$.

In [None]:
x = torch.randint(0, vocab_size, (1, context_length)).to(device)
print_wrapped(decode(x[0]))

### Sampling config

In [None]:
steps = 128
eps = 1e-5
timesteps = torch.linspace(1, eps, steps + 1, device=device)
step_size = (1 - eps) / steps
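
To see what this configuration means in terms of noise, the sketch below recomputes the time grid in plain Python and lists how much integrated noise each denoising step removes under the geometric schedule. It re-implements `total_noise` with the `sigma_min`/`sigma_max` values used earlier; the list comprehension standing in for `torch.linspace` is an assumption made to keep the example dependency-free.

```python
import math

SIGMA_MIN, SIGMA_MAX = 1e-4, 20.0  # values used for GeometricNoise in this tutorial
steps, eps = 128, 1e-5
step_size = (1 - eps) / steps


def total_noise(t: float) -> float:
    return SIGMA_MIN ** (1 - t) * SIGMA_MAX ** t


# Mirrors torch.linspace(1, eps, steps + 1): evenly spaced from 1 down to eps
timesteps = [1 - i * step_size for i in range(steps + 1)]

# Integrated noise removed by each regular step (delta_sigma in the sampler)
deltas = [total_noise(t) - total_noise(t - step_size) for t in timesteps[:-1]]
print(f"first step removes {deltas[0]:.4f}, last regular step removes {deltas[-1]:.2e}")

# The decrements telescope: together with the final step, which removes
# total_noise(eps), they account for all of the noise present at t = 1.
print(sum(deltas) + total_noise(timesteps[-1]))
```

With a geometric schedule the early steps remove by far the most noise in absolute terms, while the late steps make very small adjustments.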

### Denoising

In [None]:
# Start with a fresh random sample
x = torch.randint(0, vocab_size, (1, context_length), device=device)

with torch.no_grad():
    for i in range(steps + 1):
        t = timesteps[i] * torch.ones(x.shape[0], 1, device=device)
        curr_sigma_bar = noise(t)[0]
        if i < steps:
            next_sigma_bar = noise(t - step_size)[0]
            delta_sigma = curr_sigma_bar - next_sigma_bar

            log_score = model(x, curr_sigma_bar)
            score = torch.exp(log_score)

            stag_score = staggered_score(score, delta_sigma)
            probs = stag_score * transition(x, delta_sigma)
            x = sample_categorical(probs)

        else:
            # last denoising step: remove all remaining noise,
            # i.e. delta_sigma = curr_sigma_bar - 0
            delta_sigma = curr_sigma_bar

            log_score = model(x, curr_sigma_bar)
            score = torch.exp(log_score)

            stag_score = staggered_score(score, delta_sigma)
            probs = stag_score * transition(x, delta_sigma)

            x = sample_categorical(probs)

        clear_output(wait=True)
        print(f'Decoded Text at step {i}:', flush=True, end='\n\n')
        print_wrapped(decode(x[0]), end='\n\n', flush=True)
        # time.sleep(0.02)

## Conclusion

In this tutorial, we built a discrete diffusion GPT model for character-level text generation, illustrating how discrete diffusion can serve as a powerful alternative to autoregressive language modeling. We introduced the mathematical framework of discrete diffusion, using a continuous-time Markov chain to define how noise is added and removed from discrete tokens.

Unlike autoregressive models, our diffusion model can denoise all tokens in parallel, offering potential speed advantages during inference. However, it also limits optimizations such as KV caching, since the entire sequence evolves simultaneously.

We used a uniform rate matrix for diffusing tokens, though other rate matrices are possible. Lou et al. also used an absorbing rate matrix, where tokens transition from a masked state to the correct token during denoising. Overall, discrete diffusion models offer a compelling new direction for text generation, nicely blending mathematical elegance with practical promise.

## Acknowledgement

This notebook builds on top of Andrej Karpathy's [nanoGPT](https://github.com/karpathy/nanoGPT) and A. Lou's [Score-Entropy-Discrete-Diffusion](https://github.com/louaaron/Score-Entropy-Discrete-Diffusion) repositories and relies upon the mathematical framework presented in the paper A. Lou et al., "Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution", [arXiv:2310.16834](https://arxiv.org/abs/2310.16834) (2024).
