# SEDD Explainer

## Text Prep/Tokenization

In [1]:
import torch
from collections import Counter

In [2]:
class SimpleBPETokenizer:
    def __init__(self, num_merges=5, eot_token='<|endoftext|>', mask_token='<|mask|>'):
        self.num_merges = num_merges
        self.eot_token = eot_token
        self.mask_token = mask_token
        self.eot_id = None
        self.mask_token_id = None
        self.merges = []
        self.pair_ranks = {}
        self.vocab = {}
        self.id_to_token = {}

    def _add_token(self, tok):
        if tok in self.vocab:
            return self.vocab[tok]
        i = len(self.vocab)
        self.vocab[tok] = i
        self.id_to_token[i] = tok
        return i

    def _get_bigrams(self, seq):
        for i in range(len(seq) - 1):
            yield (seq[i], seq[i + 1])

    def _merge_once(self, seq, pair):
        a, b = pair
        out = []
        i = 0
        while i < len(seq):
            if i < len(seq) - 1 and seq[i] == a and seq[i + 1] == b:
                out.append(a + b)
                i += 2
            else:
                out.append(seq[i])
                i += 1
        return out

    def train(self, corpus):
        # corpus: list[str]
        text = ''.join(corpus).lower()
        seq = list(text)
        merges = []
        for _ in range(self.num_merges):
            counts = Counter(self._get_bigrams(seq))
            if not counts: break
            best_pair, _ = counts.most_common(1)[0]
            merges.append(best_pair)
            seq = self._merge_once(seq, best_pair)
        self.merges = merges
        self.pair_ranks = {p: i for i, p in enumerate(self.merges)}

        self.vocab = {}
        self.id_to_token = {}
        for ch in sorted(set(text)):
            self._add_token(ch)
        for a, b in self.merges:
            self._add_token(a + b)
        self.eot_id = self._add_token(self.eot_token)
        self.mask_token_id = self._add_token(self.mask_token)

    def encode(self, text, force_last_eot=True):
        # treat literal eot marker as special; remove it from content
        if self.eot_token in text:
            text = text.replace(self.eot_token, '')
        seq = list(text)

        # make sure all seen base chars exist
        for ch in set(seq):
            if ch not in self.vocab:
                self._add_token(ch)

        # greedy BPE using learned pair ranks
        if self.merges:
            while True:
                best_pair, best_rank = None, None
                for p in self._get_bigrams(seq):
                    r = self.pair_ranks.get(p)
                    if r is not None and (best_rank is None or r < best_rank):
                        best_pair, best_rank = p, r
                if best_pair is None:
                    break
                seq = self._merge_once(seq, best_pair)

        # ensure all tokens in seq exist in vocab (e.g., if new chars appeared)
        for tok in seq:
            if tok not in self.vocab:
                self._add_token(tok)

        ids = [self.vocab[tok] for tok in seq]

        # FORCE: append EOT id if not already last
        if force_last_eot:
            if not ids or ids[-1] != self.eot_id:
                ids.append(self.eot_id)

        return ids

    def decode(self, ids):
        # drop trailing EOT if present
        if ids and self.eot_id is not None and ids[-1] == self.eot_id:
            ids = ids[:-1]
        toks = [self.id_to_token[i] for i in ids]
        return ''.join(toks)


In [3]:
raw_example_1 = r'''Linear algebra is central to almost all areas of mathematics. For instance, linear algebra is fundamental in modern presentations of geometry, including for defining basic objects such as lines, planes and rotations. Also, functional analysis, a branch of mathematical analysis, may be viewed as the application of linear algebra to function spaces.'''
raw_example_2 = r'''In numerical analysis and linear algebra, lower–upper (LU) decomposition or factorization factors a matrix as the product of a lower triangular matrix and an upper triangular matrix (see matrix multiplication and matrix decomposition).'''


In [4]:
tok = SimpleBPETokenizer(num_merges=5)
tok.train([raw_example_1,raw_example_2])
tok.merges

[(' ', 'a'), ('a', 't'), ('i', 'n'), (' ', 'm'), ('i', 'o')]

In [5]:
tok.vocab

{' ': 0,
 '(': 1,
 ')': 2,
 ',': 3,
 '.': 4,
 'a': 5,
 'b': 6,
 'c': 7,
 'd': 8,
 'e': 9,
 'f': 10,
 'g': 11,
 'h': 12,
 'i': 13,
 'j': 14,
 'l': 15,
 'm': 16,
 'n': 17,
 'o': 18,
 'p': 19,
 'r': 20,
 's': 21,
 't': 22,
 'u': 23,
 'v': 24,
 'w': 25,
 'x': 26,
 'y': 27,
 'z': 28,
 '–': 29,
 ' a': 30,
 'at': 31,
 'in': 32,
 ' m': 33,
 'io': 34,
 '<|endoftext|>': 35,
 '<|mask|>': 36}

In [6]:
vocab_size = len(tok.vocab)
vocab_size

37

In [7]:
eot = tok.eot_id
tokens = []
for example in [raw_example_1, raw_example_2]:
    tokens.extend([eot])
    tokens.extend(tok.encode(example.lower()))
all_tokens = torch.tensor(tokens, dtype=torch.long)
all_tokens

tensor([35, 15, 32,  9,  5, 20, 30, 15, 11,  9,  6, 20,  5,  0, 13, 21,  0,  7,
         9, 17, 22, 20,  5, 15,  0, 22, 18, 30, 15, 16, 18, 21, 22, 30, 15, 15,
        30, 20,  9,  5, 21,  0, 18, 10, 33, 31, 12,  9, 16, 31, 13,  7, 21,  4,
         0, 10, 18, 20,  0, 32, 21, 22,  5, 17,  7,  9,  3,  0, 15, 32,  9,  5,
        20, 30, 15, 11,  9,  6, 20,  5,  0, 13, 21,  0, 10, 23, 17,  8,  5, 16,
         9, 17, 22,  5, 15,  0, 32, 33, 18,  8,  9, 20, 17,  0, 19, 20,  9, 21,
         9, 17, 22, 31, 34, 17, 21,  0, 18, 10,  0, 11,  9, 18, 16,  9, 22, 20,
        27,  3,  0, 32,  7, 15, 23,  8, 32, 11,  0, 10, 18, 20,  0,  8,  9, 10,
        32, 32, 11,  0,  6,  5, 21, 13,  7,  0, 18,  6, 14,  9,  7, 22, 21,  0,
        21, 23,  7, 12, 30, 21,  0, 15, 32,  9, 21,  3,  0, 19, 15,  5, 17,  9,
        21, 30, 17,  8,  0, 20, 18, 22, 31, 34, 17, 21,  4, 30, 15, 21, 18,  3,
         0, 10, 23, 17,  7, 22, 34, 17,  5, 15, 30, 17,  5, 15, 27, 21, 13, 21,
         3, 30,  0,  6, 20,  5, 17,  7, 

# Modeling

## Data Loading

In [8]:
B_batch = 2 # Batch
T_context = 8 # context length

In [9]:
current_position = 0
tok_for_training = all_tokens[current_position:current_position + B_batch*T_context]
tok_for_training

tensor([35, 15, 32,  9,  5, 20, 30, 15, 11,  9,  6, 20,  5,  0, 13, 21])

In [10]:
x=tok_for_training[:].view(B_batch, T_context)
x.shape, x

(torch.Size([2, 8]),
 tensor([[35, 15, 32,  9,  5, 20, 30, 15],
         [11,  9,  6, 20,  5,  0, 13, 21]]))

In [11]:
# 1. Sample a random time 't' for each sequence (0 = Clean, 1 = All Noise)
# We use 't' as the probability a token is masked.
t = torch.rand(B_batch) 
t

tensor([0.6567, 0.1740])

In [12]:
# 2. Create the Noisy Input (The Forward Diffusion Process)
# Generate a mask where probability of masking is 't'
# We expand t to match sequence shape (Batch, Seq_Len)
mask_probs = t.unsqueeze(-1).expand(-1, T_context) 
mask_probs

tensor([[0.6567, 0.6567, 0.6567, 0.6567, 0.6567, 0.6567, 0.6567, 0.6567],
        [0.1740, 0.1740, 0.1740, 0.1740, 0.1740, 0.1740, 0.1740, 0.1740]])

In [13]:
mask_decision = torch.bernoulli(mask_probs).bool() # True = replace with [MASK]
mask_decision

tensor([[False,  True,  True, False,  True, False,  True,  True],
        [False, False, False, False, False, False, False, False]])

In [14]:
noisy_x = x.clone()
noisy_x

tensor([[35, 15, 32,  9,  5, 20, 30, 15],
        [11,  9,  6, 20,  5,  0, 13, 21]])

In [15]:
noisy_x[mask_decision] = tok.mask_token_id
noisy_x

tensor([[35, 36, 36,  9, 36, 20, 36, 36],
        [11,  9,  6, 20,  5,  0, 13, 21]])

## Forward pass

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [17]:
n_embd = 4

In [18]:
class SimpleDiffusionTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        # Embeddings + Positional encoding (learned)
        self.embed = nn.Embedding(vocab_size, n_embd)
        self.pos_embed = nn.Embedding(T_context, n_embd)
        
        # A small Transformer Encoder
        layer = nn.TransformerEncoderLayer(d_model=n_embd, nhead=2, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=2)
        
        # Final projection to vocab logits
        self.head = nn.Linear(n_embd, vocab_size)

    def forward(self, x):
        # x shape: (Batch, Seq_Len)
        b, t = x.shape
        positions = torch.arange(t).unsqueeze(0)
        
        # Combine token + pos embeddings
        x = self.embed(x) + self.pos_embed(positions)
        
        # Pass through transformer
        x = self.transformer(x)
        
        # Project to logits (Batch, Seq_Len, Vocab_Size)
        return self.head(x)

model = SimpleDiffusionTransformer()

In [19]:
# 3. Model Prediction
logits = model(noisy_x)
logits

tensor([[[-2.0422e-01,  5.4760e-01,  9.4382e-01,  1.3031e+00,  5.6503e-01,
           6.2514e-01,  1.3584e-01, -3.7753e-01,  1.1947e+00, -1.0565e+00,
          -4.7768e-01,  1.1230e-01, -3.5358e-01,  1.5026e-01, -3.3397e-01,
           2.9179e-01, -4.7736e-01, -5.7663e-01, -6.0222e-01, -3.0065e-01,
           2.7537e-01,  1.1681e+00, -4.7315e-01, -7.4860e-01,  3.1484e-03,
          -5.3555e-01,  7.3039e-01, -1.5316e+00,  1.0235e+00,  1.4146e+00,
           7.8472e-02,  4.6130e-01, -6.5214e-01, -2.6468e-02, -6.9362e-01,
          -3.2895e-01, -1.4786e-01],
         [-1.3454e-01,  5.4893e-01,  7.4787e-01,  5.5085e-01,  7.5383e-01,
           2.7464e-01,  3.3744e-01, -4.1641e-01,  1.1454e+00, -1.2173e+00,
           2.2461e-01,  6.6758e-01, -8.7597e-02, -1.9475e-01, -3.0284e-02,
          -7.6360e-04, -3.6457e-03, -2.5837e-01, -9.7978e-01, -3.4568e-01,
           2.4589e-01,  1.4267e+00, -1.8985e-01, -7.5750e-01,  4.0571e-02,
           9.9616e-02,  7.5312e-01, -6.5597e-01,  6.9997e-01,  

In [20]:
# 4. Calculate Loss
# We only care about the loss on tokens that were masked!
# Permute logits for CrossEntropyLoss: (Batch, Vocab, Seq_Len)
loss = F.cross_entropy(logits.permute(0, 2, 1), x, reduction='none')
loss

tensor([[4.2143, 3.8568, 3.5980, 4.7500, 2.7382, 3.7335, 3.4719, 3.0357],
        [3.0575, 5.0523, 4.0670, 4.1661, 3.3373, 4.0875, 3.9311, 2.9142]],
       grad_fn=<ViewBackward0>)

In [21]:
# Apply the mask so we only learn to reconstruct the missing parts
# (Technically SEDD re-weights this loss, but this is the core logic)
masked_loss = (loss * mask_decision.float()).sum() / (mask_decision.sum() + 1e-6)
masked_loss

tensor(3.3401, grad_fn=<DivBackward0>)

now let's run it 10 more times

In [22]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [23]:
def train_step(x):

    t = torch.rand(B_batch)  
    mask_probs = t.unsqueeze(-1).expand(-1, T_context) 
    mask_decision = torch.bernoulli(mask_probs).bool() 
    
    noisy_x = x.clone()
    noisy_x[mask_decision] = tok.mask_token_id
    
    # 3. Model Prediction
    logits = model(noisy_x)
    
    # 4. Calculate Loss
    loss = F.cross_entropy(logits.permute(0, 2, 1), x, reduction='none')
    masked_loss = (loss * mask_decision.float()).sum() / (mask_decision.sum() + 1e-6)
    
    # Optimization
    optimizer.zero_grad()
    masked_loss.backward()
    optimizer.step()
    
    return masked_loss.item()

In [24]:
for i in range(10000):
    x=tok_for_training[:].view(B_batch, T_context)
    loss = train_step(x)
    if i%100 == 0: print(f'Step {i}, Loss: {loss:.4f}')
    

Step 0, Loss: 3.6228
Step 100, Loss: 2.8183
Step 200, Loss: 2.5582
Step 300, Loss: 1.2451
Step 400, Loss: 1.9005
Step 500, Loss: 1.5958
Step 600, Loss: 1.4626
Step 700, Loss: 1.3678
Step 800, Loss: 1.1796
Step 900, Loss: 1.1642
Step 1000, Loss: 0.9767
Step 1100, Loss: 0.7753
Step 1200, Loss: 0.8605
Step 1300, Loss: 1.0260
Step 1400, Loss: 0.9460
Step 1500, Loss: 0.7378
Step 1600, Loss: 0.9238
Step 1700, Loss: 1.1322
Step 1800, Loss: 0.7960
Step 1900, Loss: 0.6674
Step 2000, Loss: 0.8853
Step 2100, Loss: 1.7954
Step 2200, Loss: 0.7822
Step 2300, Loss: 0.6726
Step 2400, Loss: 0.7045
Step 2500, Loss: 0.6718
Step 2600, Loss: 0.5824
Step 2700, Loss: 0.8240
Step 2800, Loss: 0.6450
Step 2900, Loss: 0.7452
Step 3000, Loss: 0.9402
Step 3100, Loss: 2.7907
Step 3200, Loss: 0.2200
Step 3300, Loss: 1.0769
Step 3400, Loss: 0.7048
Step 3500, Loss: 0.4319
Step 3600, Loss: 0.9372
Step 3700, Loss: 0.8418
Step 3800, Loss: 0.9120
Step 3900, Loss: 0.5576
Step 4000, Loss: 0.4323
Step 4100, Loss: 1.0334
Step

## Generate

In [25]:
@torch.no_grad()
def sample_sedd(steps=3):
    model.eval()
    print("\n--- Starting Generation (Reverse Diffusion) ---")
    
    # 1. Start with pure noise (All [MASK] tokens)
    #x = torch.full((1, T_context), tok.mask_token_id)
    x = torch.cat((tok_for_training[:4].unsqueeze(0), torch.full((1, 4), tok.mask_token_id)), dim=1)
    
    # We will unmask roughly k tokens per step
    tokens_per_step = T_context // steps
    if tokens_per_step < 1: tokens_per_step = 1

    current_tokens = x[0].tolist()
    pretty_tokens = [tok.decode([t]) if t != tok.mask_token_id else "MM" for t in current_tokens]
    print(f"Step {0}/{steps}: {pretty_tokens}")
    
    # 2. Iterative Denoising Loop
    for step in range(steps):
        # A. Predict likely tokens for the CURRENT masks
        logits = model(x)
        probs = F.softmax(logits, dim=-1)
        
        # Get the most likely token and its confidence score
        max_probs, predicted_ids = torch.max(probs, dim=-1)
        
        # B. Decide which tokens to "commit" to (Unmask)
        # We look at currently masked positions
        is_currently_masked = (x == tok.mask_token_id)
        
        # If nothing is masked, we are done
        if not is_currently_masked.any():
            break
            
        # Add some noise to confidence to prevent getting stuck (Gumbel noise)
        # This represents the "stochastic" part of diffusion sampling
        gumbel_noise = -torch.log(-torch.log(torch.rand_like(max_probs)))
        noisy_confidence = max_probs + gumbel_noise
        
        # Set confidence of already unmasked tokens to -infinity so we don't pick them again
        noisy_confidence[~is_currently_masked] = -float('inf')
        
        # Select the top-k most confident tokens to unmask this step
        # (In real SEDD, this is handled by solving an SDE, but this is the discrete equivalent)
        num_to_unmask = min(tokens_per_step, is_currently_masked.sum().item())
        _, indices_to_unmask = torch.topk(noisy_confidence, k=num_to_unmask)
        
        # C. Update the sequence
        # We replace the [MASK] at the chosen indices with the model's prediction
        # Note: We are doing this in parallel for 'num_to_unmask' tokens!
        x[0, indices_to_unmask] = predicted_ids[0, indices_to_unmask]
        
        # Visualization
        current_tokens = x[0].tolist()
        pretty_tokens = [tok.decode([t]) if t != tok.mask_token_id else "MM" for t in current_tokens]
        print(f"Step {step+1}/{steps}: {pretty_tokens}")

    return x

# Run generation
final_seq = sample_sedd(steps=4)




--- Starting Generation (Reverse Diffusion) ---
Step 0/4: ['', 'l', 'in', 'e', 'MM', 'MM', 'MM', 'MM']
Step 1/4: ['', 'l', 'in', 'e', 'a', 'MM', 'MM', 'l']
Step 2/4: ['', 'l', 'in', 'e', 'a', 'r', ' a', 'l']


In [26]:
tok.decode([i.item() for i in tok_for_training[:8]])

'<|endoftext|>linear al'