In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random

# =========================
# 0. Basic Setup & Data
# =========================

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# Seed both RNGs used in this script (Python's and PyTorch's).
random.seed(42)
torch.manual_seed(42)

# Toy sentence dataset: every sentence is "<subject> love <object> ." and
# therefore has a fixed length of 4 tokens. Rows are ordered by object
# first ("AI", then "NLP") and by subject within each object.
sentences = [
    [subject, "love", obj, "."]
    for obj in ("AI", "NLP")
    for subject in ("I", "You", "We")
]

seq_len = 4

# The vocabulary explicitly contains the [MASK] token used as the
# absorbing state of the forward process.
vocab = ["<pad>", "I", "You", "We", "love", "AI", "NLP", ".", "[MASK]"]
token2id = dict(zip(vocab, range(len(vocab))))   # token string -> integer id
id2token = dict(enumerate(vocab))                # integer id  -> token string
MASK_ID = vocab.index("[MASK]")
vocab_size = len(vocab)

print("vocab_size =", vocab_size)
print("MASK_ID    =", MASK_ID)

def encode_sentence(tokens):
    """Convert a list of token strings into a 1-D tensor of token ids.

    Args:
        tokens: sequence of token strings; its length must equal the
            module-level ``seq_len``.

    Returns:
        A ``torch.long`` tensor of shape ``(seq_len,)`` holding the ids
        looked up in the module-level ``token2id`` mapping.

    Raises:
        ValueError: if ``tokens`` does not contain exactly ``seq_len`` items.
        KeyError: if a token is not present in ``token2id``.
    """
    # Explicit check instead of `assert`: assertions are stripped when
    # Python runs with -O, which would let malformed rows through silently.
    if len(tokens) != seq_len:
        raise ValueError(
            f"expected {seq_len} tokens, got {len(tokens)}: {tokens!r}"
        )
    return torch.tensor([token2id[tok] for tok in tokens], dtype=torch.long)

# Encode every sentence and stack the rows into one (N, L) tensor on `device`.
data_x0 = torch.stack(list(map(encode_sentence, sentences))).to(device)  # (N, L)
num_samples = data_x0.shape[0]

# =========================
# 1. Discrete Forward Process q(x_t | x_0)
#    SEDD-style mask diffusion
# =========================

T = 5  # Total diffusion steps

# Per-step masking rates beta_t grow linearly from 0.1 to 0.4.
# alpha_bar_t is the running product of (1 - beta_s) for s <= t, i.e. the
# probability that a token is still unmasked after t steps.
betas = torch.linspace(0.1, 0.4, T).to(device)
alphas = 1.0 - betas
alpha_bars = alphas.cumprod(dim=0)   # (T,)

def q_sample_mask(x0, t):
    """Sample x_t ~ q(x_t | x_0) for the simplified mask-diffusion process.

    Independently for each position:
      - with probability alpha_bar_t the original token of x0 is kept
      - with probability (1 - alpha_bar_t) it is replaced by MASK_ID

    Reads the module-level ``alpha_bars`` schedule and ``MASK_ID``.

    Args:
        x0: (B, L) integer tensor of clean token ids.
        t:  (B,) integer tensor of timesteps, each in [1 .. len(alpha_bars)].

    Returns:
        x_t: (B, L) tensor with the same dtype and device as ``x0``.

    Raises:
        IndexError: if any timestep lies outside [1 .. len(alpha_bars)].
    """
    B, L = x0.shape

    # Timesteps are 1-based; the schedule tensor is 0-based.
    idx = (t - 1).long()

    # Without this check t = 0 would become index -1 and silently pick
    # alpha_bar_T instead of failing.
    num_timesteps = alpha_bars.numel()
    if ((idx < 0) | (idx >= num_timesteps)).any():
        raise IndexError(
            f"timestep t must be in [1, {num_timesteps}], got {t.tolist()}"
        )

    # Look up on the schedule's device, then follow x0 so the function does
    # not depend on the global `device` matching the input's device.
    alpha_bar_t = alpha_bars[idx.to(alpha_bars.device)].to(x0.device)  # (B,)
    alpha_bar_t = alpha_bar_t.view(B, 1)                               # (B,1)

    # One uniform draw per token decides whether it survives.
    u = torch.rand(B, L, device=x0.device)     # (B,L)
    keep = u < alpha_bar_t                     # True -> keep x0

    return torch.where(keep, x0, torch.full_like(x0, MASK_ID))

# =========================
# 2. Simple SEDD Model
#    Input:  x_t, t
#    Output: logits over x_0
# =========================

class SimpleSEDDModel(nn.Module):
    """Minimal denoiser: maps (x_t, t) to per-position logits over x_0.

    Each position is embedded as the sum of a token, a position and a
    timestep embedding and then passed through a two-layer MLP. The MLP
    acts on every position independently, so there is no mixing of
    information across positions.

    Args:
        vocab_size: number of token ids (size of the output distribution).
        d_model:    embedding / hidden width.
        T:          number of diffusion steps; timesteps are 1-based, so the
                    time embedding table has T + 1 rows and row 0 is unused.
        seq_len:    maximum sequence length supported by the position table.
    """

    def __init__(self, vocab_size, d_model, T, seq_len):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb   = nn.Embedding(seq_len, d_model)
        self.time_emb  = nn.Embedding(T + 1, d_model)  # t ∈ [1..T], index 0 unused

        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, vocab_size),
        )

        # Position ids are stored as a buffer so they move with the module
        # when it is transferred between devices.
        self.register_buffer("positions", torch.arange(seq_len).long())

    def forward(self, x_t, t):
        """Compute logits over the clean tokens.

        Args:
            x_t: (B, L) integer tensor of (partially masked) token ids,
                 with L <= seq_len.
            t:   (B,) integer tensor of timesteps.

        Returns:
            logits: (B, L, V) float tensor.

        Raises:
            ValueError: if L exceeds the ``seq_len`` given at construction.
        """
        L = x_t.size(1)
        max_len = self.positions.numel()
        if L > max_len:
            raise ValueError(
                f"sequence length {L} exceeds the model's maximum of {max_len}"
            )

        tok_emb = self.token_emb(x_t)                            # (B,L,D)
        # Only embed the positions actually present so that inputs shorter
        # than seq_len broadcast correctly.
        pos_emb = self.pos_emb(self.positions[:L]).unsqueeze(0)  # (1,L,D)
        time_emb = self.time_emb(t).unsqueeze(1)                 # (B,1,D)

        h = tok_emb + pos_emb + time_emb                         # (B,L,D)
        return self.mlp(h)                                       # (B,L,V)

# Build the denoiser (arguments: vocab_size, d_model, T, seq_len) and then
# move its parameters and buffers onto the selected device.
model = SimpleSEDDModel(vocab_size, 64, T, seq_len)
model = model.to(device)

# =========================
# 3. Score Entropy Loss
#    Cross-entropy under a delta posterior
# =========================

def score_entropy_loss(logits, x0):
    """Mean token-level cross-entropy between predicted logits and x0.

    Loss = -E[ log p_theta(x0 | x_t, t) ], averaged over every position of
    every sequence in the batch.

    Args:
        logits: (..., V) float tensor, typically (B, L, V).
        x0:     integer tensor of target ids whose shape matches the leading
                dimensions of ``logits``, typically (B, L).

    Returns:
        A scalar tensor holding the mean loss.
    """
    V = logits.size(-1)
    # reshape (not view) so that non-contiguous inputs, e.g. the result of a
    # transpose or a slice, are accepted instead of raising a RuntimeError.
    return F.cross_entropy(
        logits.reshape(-1, V),
        x0.reshape(-1),
    )

# =========================
# 4. Training Loop
#    Explicitly includes "full MASK" samples
# =========================

# Training hyperparameters.
num_steps = 3000
batch_size = 8

# Each sample has 30% chance of being replaced with full MASK at t = T
full_mask_prob = 0.3

# Adam over all model parameters.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

def sample_batch():
    """Draw one training batch of (x0, x_t, t).

    Rows are sampled with replacement from ``data_x0``, a timestep is drawn
    uniformly from {1..T} for each row, and the forward process produces the
    noised input. Afterwards each row is, with probability
    ``full_mask_prob``, overwritten by a fully masked sequence at t = T.

    Returns:
        x0:  (B, L) clean token ids.
        x_t: (B, L) noised token ids.
        t:   (B,)   timesteps matching x_t.
    """
    # Pick the clean sequences.
    row_ids = torch.randint(0, num_samples, (batch_size,), device=device)
    x0 = data_x0[row_ids]                                   # (B,L)

    # t ~ Uniform{1..T}, then run the forward process with those timesteps.
    t = torch.randint(1, T + 1, (batch_size,), device=device)
    x_t = q_sample_mask(x0, t)                              # (B,L)

    # Rows selected here become "all [MASK] at t = T" examples.
    force_full = torch.rand(batch_size, device=device) < full_mask_prob
    x_t = x_t.masked_fill(force_full.unsqueeze(1), MASK_ID)
    t = t.masked_fill(force_full, T)

    return x0, x_t, t

def decode(x):
    """Turn an iterable of token ids into a space-separated string.

    Each element is converted with ``int`` and looked up in the module-level
    ``id2token`` mapping.
    """
    words = []
    for token_id in x:
        words.append(id2token[int(token_id)])
    return " ".join(words)

for step in range(1, num_steps + 1):
    # ---- one optimisation step ----
    # train() is re-applied every iteration because the logging branch
    # below switches the model to eval().
    model.train()
    x0, x_t, t = sample_batch()
    logits = model(x_t, t)
    loss = score_entropy_loss(logits, x0)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Everything below is periodic logging only.
    if step % 300 != 0:
        continue

    print(f"[Step {step:4d}] loss = {loss.item():.4f}")

    model.eval()
    with torch.no_grad():
        # Demo 1: noise a random training sentence at t = T and let the
        # model predict the clean tokens in a single forward pass.
        demo_idx = random.randrange(num_samples)
        clean = data_x0[demo_idx].unsqueeze(0)                 # (1,L)
        t_demo = torch.full((1,), T, dtype=torch.long, device=device)
        noised = q_sample_mask(clean, t_demo)                  # (1,L)
        logits_demo = model(noised, t_demo)
        pred = torch.argmax(logits_demo, dim=-1)

        print("  [t=T noise] clean : ", decode(clean[0]))
        print("             noised: ", decode(noised[0]))
        print("             pred  : ", decode(pred[0]))
        print("  ----------------------------")

        # Demo 2: feed a sequence made only of [MASK] tokens at t = T.
        full_mask = torch.full_like(clean, MASK_ID)
        logits_mask = model(full_mask, t_demo)
        pred_mask = torch.argmax(logits_mask, dim=-1)
        print("  [FULL MASK]   x_t: ", decode(full_mask[0]))
        print("             pred:   ", decode(pred_mask[0]))
        print("====================================")

# =========================
# 5. Generation from Full MASK (Testing Function)
# =========================

def generate_from_full_mask(n=5):
    """Print the model's one-shot prediction for ``n`` fully masked inputs.

    Builds ``n`` sequences consisting only of MASK_ID, feeds them to the
    model at t = T, takes the argmax token at every position and prints the
    input and the prediction for each sample. Every row receives the same
    input and decoding is greedy, so no sampling is involved.

    Args:
        n: number of samples to print.
    """
    model.eval()
    with torch.no_grad():
        x_mask = torch.full((n, seq_len), MASK_ID, dtype=torch.long, device=device)
        t = torch.full((n,), T, dtype=torch.long, device=device)
        pred = model(x_mask, t).argmax(dim=-1)

    for i, (masked_row, predicted_row) in enumerate(zip(x_mask, pred)):
        print(f"Sample {i}:")
        print("  x_T([MASK]):", decode(masked_row))
        print("  predicted x0:", decode(predicted_row))
        print("-" * 40)

# Final check after training: decode five all-[MASK] sequences.
print("\n===== Generation from FULL [MASK] tokens =====")
generate_from_full_mask(5)


Using device: cpu
vocab_size = 9
MASK_ID    = 8
[Step  300] loss = 0.4220
  [t=T noise] clean :  We love NLP .
             noised:  [MASK] [MASK] [MASK] [MASK]
             pred  :  I love NLP .
  ----------------------------
  [FULL MASK]   x_t:  [MASK] [MASK] [MASK] [MASK]
             pred:    I love NLP .
[Step  600] loss = 0.3681
  [t=T noise] clean :  I love AI .
             noised:  [MASK] [MASK] [MASK] [MASK]
             pred  :  We love NLP .
  ----------------------------
  [FULL MASK]   x_t:  [MASK] [MASK] [MASK] [MASK]
             pred:    We love NLP .
[Step  900] loss = 0.3224
  [t=T noise] clean :  I love AI .
             noised:  [MASK] [MASK] [MASK] [MASK]
             pred  :  You love NLP .
  ----------------------------
  [FULL MASK]   x_t:  [MASK] [MASK] [MASK] [MASK]
             pred:    You love NLP .
[Step 1200] loss = 0.2537
  [t=T noise] clean :  We love NLP .
             noised:  We love [MASK] [MASK]
             pred  :  We love NLP .
  -------------