In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random

# ======================
# 1. Basic Setup
# ======================

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Toy corpus: six sentences of the form "<subject> love <topic> .",
# every one exactly four tokens long.
sentences = [
    [subject, "love", topic, "."]
    for topic in ("AI", "NLP")
    for subject in ("I", "You", "We")
]

# Vocabulary: id 0 is the padding token, followed by every corpus token.
vocab = ["<pad>", "I", "You", "We", "love", "AI", "NLP", "."]
token2id = dict(zip(vocab, range(len(vocab))))   # token -> integer id
id2token = dict(enumerate(vocab))                # integer id -> token

vocab_size = len(vocab)
seq_len = 4

def encode_sentence(tokens):
    """
    Map a list of tokens to a 1-D LongTensor of vocabulary ids.

    tokens: sequence of exactly `seq_len` strings, each a key of `token2id`
    Returns: tensor of shape (seq_len,) with dtype torch.long
    """
    assert len(tokens) == seq_len
    ids = []
    for tok in tokens:
        ids.append(token2id[tok])   # KeyError for out-of-vocabulary tokens
    return torch.tensor(ids, dtype=torch.long)

# Encode the whole corpus into one (N, L) LongTensor living on `device`.
encoded_rows = list(map(encode_sentence, sentences))
data_x0 = torch.stack(encoded_rows).to(device)
num_samples = data_x0.shape[0]

# =======================
# 2. Discrete Forward Diffusion Process q(x_t | x_0)
#    Simplified version: categorical noise (keep or random token)
# =======================

# Diffusion schedule (kept tiny for demonstration purposes).
T = 5                                                       # number of diffusion steps
betas = torch.linspace(start=0.1, end=0.3, steps=T).to(device)  # per-step noise strength
alphas = 1.0 - betas                                        # per-step keep probability
alpha_bars = alphas.cumprod(dim=0)                          # cumulative keep probability up to step t

def q_sample(x0, t, *, num_classes=None, schedule=None):
    """
    Sample x_t ~ q(x_t | x_0) for the simplified categorical forward process.

    x0: (B, L) integer (long) tokens
    t:  (B,) integer timesteps in [1..T]
    num_classes: vocabulary size K; defaults to the module-level `vocab_size`
    schedule:    1-D sequence/tensor of cumulative keep probabilities
                 ᾱ_1..ᾱ_T; defaults to the module-level `alpha_bars`
    Returns x_t: (B, L) long tensor on the same device as x0

    Raises ValueError if any timestep lies outside [1..T].

    Simplified discrete forward process:
      q(x_t | x_0) = ᾱ_t * one_hot(x0) + (1 - ᾱ_t) * Uniform
    """
    B, L = x0.shape
    K = vocab_size if num_classes is None else num_classes

    # Work on x0's device instead of a global one so the function also works
    # for inputs that live somewhere else.
    table = torch.as_tensor(alpha_bars if schedule is None else schedule)
    table = table.to(x0.device)
    if not table.is_floating_point():
        table = table.float()
    t = t.to(x0.device)

    # Without this check t = 0 would silently index alpha_bars[-1]
    # (maximum noise) because negative indices wrap around.
    # Note: the bool() conversion forces a host sync on GPU.
    n_steps = table.shape[0]
    if t.numel() > 0 and bool(((t < 1) | (t > n_steps)).any()):
        raise ValueError(f"timesteps must lie in [1..{n_steps}], got {t.tolist()}")

    # Nothing to sample for an empty batch / empty sequences.
    if x0.numel() == 0:
        return x0.clone()

    # Look up alpha_bar_t (t ranges 1..T, so subtract 1 for indexing)
    alpha_bar = table[t - 1].view(B, 1, 1)               # (B,1,1)

    # p = alpha_bar * OneHot(x0) + (1 - alpha_bar) * Uniform
    one_hot = F.one_hot(x0, num_classes=K).to(alpha_bar.dtype)   # (B,L,K)
    p = alpha_bar * one_hot + (1.0 - alpha_bar) / K               # (B,L,K)

    # Sample x_t from the per-position categorical distribution
    xt_flat = torch.multinomial(p.view(-1, K), num_samples=1)    # (B*L,1)
    return xt_flat.view(B, L)                                     # (B,L)


# =======================
# 3. Simple SEDD-style Model
#    Input:  x_t, t
#    Output: Predicted distribution over x_0 at each position
# =======================

class SimpleSEDDModel(nn.Module):
    """
    Per-position denoiser: predicts logits over x_0 from (x_t, t).

    Token, position and timestep embeddings are summed and fed through a
    two-layer MLP that acts on each position independently (there is no
    mixing of information between positions).
    """

    def __init__(self, vocab_size, d_model, T, seq_len):
        """
        vocab_size: number of token ids (also the size of the output logits)
        d_model:    embedding / hidden width
        T:          number of diffusion steps; timesteps are indexed 1..T
        seq_len:    maximum sequence length covered by the position table
        """
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb   = nn.Embedding(seq_len, d_model)
        self.time_emb  = nn.Embedding(T + 1, d_model)  # t in [1..T], index 0 unused

        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, vocab_size),
        )

        # Precompute positional indices 0..seq_len-1; registered as a buffer
        # so they follow the module across .to(device) calls.
        self.register_buffer("positions", torch.arange(seq_len).long())

    def forward(self, x_t, t):
        """
        x_t: (B, L) token ids, with L <= seq_len
        t:   (B,) timesteps in [1..T]
        Returns logits: (B, L, V)
        """
        _, L = x_t.shape
        tok_emb = self.token_emb(x_t)                            # (B,L,D)
        # Slice to the actual input length so sequences shorter than
        # seq_len broadcast correctly instead of failing on the addition.
        pos_emb = self.pos_emb(self.positions[:L])[None, :, :]   # (1,L,D)
        time_emb = self.time_emb(t).unsqueeze(1)                 # (B,1,D)

        h = tok_emb + pos_emb + time_emb                         # (B,L,D)
        return self.mlp(h)                                       # (B,L,V)

# Build the denoiser with 64-dimensional embeddings, then move it to `device`.
model = SimpleSEDDModel(vocab_size, 64, T, seq_len)
model = model.to(device)

# =======================
# 4. Score Entropy Loss
#    Here: cross-entropy over x_0
#    This corresponds to Score Entropy under a delta posterior
# =======================

def score_entropy_loss(logits, x0):
    """
    Mean cross-entropy between predicted logits and the clean tokens.

    logits: (B,L,V) unnormalised scores over the vocabulary
    x0:     (B,L)   target token ids (long)
    Returns: scalar tensor (mean over all B*L positions)

    Target is one-hot δ(x0).
    Loss = - E_{t, x0, x_t} [ log p_theta(x0 | x_t, t) ]
    """
    V = logits.size(-1)
    # reshape (rather than view) also accepts non-contiguous inputs such as
    # transposed or sliced tensors; it only copies when it has to.
    return F.cross_entropy(logits.reshape(-1, V), x0.reshape(-1))


# =======================
# 5. Training Loop
# =======================

# Training hyper-parameters: 2000 Adam updates (lr = 1e-3) on batches of 8.
num_steps = 2000
batch_size = 8
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

def sample_batch():
    """
    Draw one random training batch.

    Returns (x0, x_t, t):
      x0:  (B,L) clean sentences picked with replacement from data_x0
      x_t: (B,L) the same sentences after forward diffusion to step t
      t:   (B,)  timesteps, one per sentence
    where B = batch_size.
    """
    shape = (batch_size,)

    # Which corpus rows to use (sampled with replacement)
    rows = torch.randint(low=0, high=num_samples, size=shape, device=device)
    clean = data_x0.index_select(0, rows)        # (B,L)

    # t ~ Uniform{1..T}; randint's upper bound is exclusive, hence T + 1
    steps = torch.randint(low=1, high=T + 1, size=shape, device=device)

    # Corrupt the clean tokens with the forward process
    return clean, q_sample(clean, steps), steps

def decode(x):
    """Turn a 1-D tensor of token ids into a space-separated string."""
    return " ".join(id2token[int(i)] for i in x)


# One optimisation step per iteration; every 200 steps print the current
# batch loss plus a single denoising example.
for step in range(1, num_steps + 1):
    model.train()
    x0, x_t, t = sample_batch()
    logits = model(x_t, t)
    loss = score_entropy_loss(logits, x0)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 200 != 0:
        continue

    print(f"[Step {step:4d}] loss = {loss.item():.4f}")

    # Visualize denoising from strong noise (t = T)
    model.eval()
    with torch.no_grad():
        demo_idx = random.randrange(num_samples)
        clean = data_x0[demo_idx].unsqueeze(0)                        # (1,L)
        t_demo = torch.full((1,), T, dtype=torch.long, device=device)  # maximum noise level
        noised = q_sample(clean, t_demo)                               # (1,L)
        logits_demo = model(noised, t_demo)                            # (1,L,V)
        pred = logits_demo.argmax(dim=-1)                              # (1,L)

        for label, row in (("  clean: ", clean), ("  noised:", noised), ("  pred:  ", pred)):
            print(label, decode(row[0]))
        print("-" * 40)

# =======================
# 6. Generation Demo (one-step denoising from pure noise)
#    Sample uniform x_T and directly predict x_0
# =======================

def generate_from_noise(n=5):
    """
    Print n one-step "generations".

    Each sample starts from uniformly random token ids, is fed to the model
    at the maximum timestep T, and the per-position argmax is printed as the
    predicted clean sentence. Puts the model into eval mode; returns None.
    """
    model.eval()
    with torch.no_grad():
        # Pure noise: every position is an independent uniform draw
        noise = torch.randint(0, vocab_size, (n, seq_len), device=device)
        steps = torch.full((n,), T, device=device, dtype=torch.long)
        guess = model(noise, steps).argmax(dim=-1)

        def to_text(ids):
            return " ".join(id2token[int(tok)] for tok in ids)

        for i, (noise_row, guess_row) in enumerate(zip(noise, guess)):
            print(f"Sample {i}:")
            print("  x_T   :", to_text(noise_row))
            print("  x0hat :", to_text(guess_row))
            print("-" * 40)

# Run the demo: five random sequences, each denoised in a single model call.
print("\n===== Generation from pure noise (one-step denoise) =====")
generate_from_noise(5)


Using device: cpu
[Step  200] loss = 0.3905
  clean:  We love NLP .
  noised: We AI love You
  pred:   We love NLP .
----------------------------------------
[Step  400] loss = 0.4113
  clean:  We love NLP .
  noised: We love You .
  pred:   We love NLP .
----------------------------------------
[Step  600] loss = 0.3782
  clean:  You love NLP .
  noised: You . NLP .
  pred:   You love NLP .
----------------------------------------
[Step  800] loss = 0.1961
  clean:  I love AI .
  noised: I love AI You
  pred:   I love AI .
----------------------------------------
[Step 1000] loss = 0.2742
  clean:  You love NLP .
  noised: <pad> love AI love
  pred:   I love AI .
----------------------------------------
[Step 1200] loss = 0.1769
  clean:  I love NLP .
  noised: We <pad> I .
  pred:   We love NLP .
----------------------------------------
[Step 1400] loss = 0.3846
  clean:  I love AI .
  noised: We love <pad> We
  pred:   We love AI .
----------------------------------------
[Step 1600