In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
import os
from tqdm import tqdm
from pathlib import Path
from my_transformer.utils import save_model

In [None]:
# --- Configuration ---
# Batch geometry
BATCH_SIZE = 64
BLOCK_SIZE = 128  # Max sequence length for training
# Model size
EMBEDDING_DIM = 64
N_HEADS = 4
N_LAYERS = 2
# Optimisation and diffusion schedule
LEARNING_RATE = 1e-4
NUM_DIFFUSION_STEPS = 100
TRAIN_STEPS = 50000  # More steps needed for better results (re-assigned in the training cell)
GRADIENT_ACCUMULATION_STEPS = 1  # For larger effective batch size

# Device setup
# torch.backends.mps.is_available() is used instead of torch.mps.is_available()
# because the latter is only present in recent PyTorch releases.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

# --- Data Loading and Preparation (from NanoGPT) ---
# The corpus is cached on disk and only fetched when the cached file is missing.
data_file = Path("../dataset") / "shakespeare.txt"
if not data_file.exists():
    print(f"Downloading {data_file}...")
    import requests

    # NOTE(review): switched from the makemore URL, which does not appear to host this
    # file, to the char-rnn copy of Tiny Shakespeare.
    url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    r = requests.get(url, timeout=30)
    r.raise_for_status()  # never cache an HTTP error page as the training corpus
    data_file.parent.mkdir(parents=True, exist_ok=True)
    # Write with the same encoding that is used to read the file back below.
    with open(data_file, "w", encoding="utf-8") as f:
        f.write(r.text)
    print(f"Downloaded {data_file}")
else:
    print(f"{data_file} already exists.")

with open(data_file, "r", encoding="utf-8") as f:
    text = f.read()

# Character-level vocabulary: each distinct character of the corpus gets an integer id,
# assigned in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
itos = dict(enumerate(chars))  # Int to string
stoi = {ch: idx for idx, ch in itos.items()}  # String to int


def encode(s):
    """Map a string to the list of its character ids."""
    return [stoi[c] for c in s]


def decode(l):
    """Map an iterable of character ids back to a string."""
    return "".join(itos[i] for i in l)


print(f"Vocabulary size: {vocab_size}")
print(f"Characters: {''.join(chars)}")

# Train/validation split: the first 90% of the token stream trains, the rest validates
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]


# Data batching function
def get_batch(split):
    """Draw BATCH_SIZE random windows of BLOCK_SIZE consecutive tokens.

    `split` equal to "train" reads from `train_data`; any other value reads from
    `val_data`. The stacked windows are moved to `device` before being returned.
    """
    source = val_data if split != "train" else train_data
    starts = torch.randint(len(source) - BLOCK_SIZE, (BATCH_SIZE,))
    windows = [source[s : s + BLOCK_SIZE] for s in starts]
    return torch.stack(windows).to(device)


In [None]:
# Inspect the character -> id table and the vocabulary size. Jupyter renders only the
# final expression of a cell, so both values are shown together as one tuple.
stoi, vocab_size

In [None]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale and optional shift.

    When `bias` is falsy no shift parameter is created and `self.bias` is None.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        eps = 1e-5
        return F.layer_norm(input, self.weight.shape, weight=self.weight, bias=self.bias, eps=eps)


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which every position attends to every position.

    The class keeps its NanoGPT name, but no causal mask is applied in either code
    path. `config` must provide `n_embd`, `n_head`, `bias`, `dropout` and `flash`.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # A single fused projection produces queries, keys and values.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = nn.Dropout(config.dropout)
        # Use the fused kernel only when this PyTorch build provides it and the config allows it.
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") and config.flash

    def forward(self, x):
        """Apply attention to `x` of shape (B, T, C) and return a tensor of the same shape."""
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        head_dim = C // self.n_head
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)  # (B, nh, T, hs)

        if self.flash:
            # The functional API applies dropout whenever dropout_p > 0, regardless of
            # module mode, so it has to be switched off explicitly outside training.
            y = torch.nn.functional.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=None,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=False,
            )
        else:
            att = (q @ k.transpose(-2, -1)) * (1.0 / (k.size(-1) ** 0.5))
            att = F.softmax(att, dim=-1)  # No causal mask for diffusion
            att = self.dropout(att)
            y = att @ v  # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # merge the heads again
        y = self.dropout(self.c_proj(y))
        return y


class MLP(nn.Module):
    """Position-wise feed-forward net: expand to 4 * n_embd, GELU, project back, dropout."""

    def __init__(self, config):
        super().__init__()
        width, hidden = config.n_embd, 4 * config.n_embd
        self.c_fc = nn.Linear(width, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        expanded = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(expanded))


class Block(nn.Module):
    """Pre-norm transformer layer: attention and MLP sub-layers, each inside a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)  # applies no causal mask despite the name
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        attended = x + self.attn(self.ln_1(x))
        return attended + self.mlp(self.ln_2(attended))


class DiffusionConfig:
    """Plain attribute container for the denoiser's hyper-parameters."""

    def __init__(self, vocab_size, block_size, n_layer, n_head, n_embd, dropout, bias, flash=True):
        # Vocabulary size and maximum sequence length
        self.vocab_size, self.block_size = vocab_size, block_size
        # Transformer shape
        self.n_layer, self.n_head, self.n_embd = n_layer, n_head, n_embd
        # Regularisation and layer options
        self.dropout, self.bias = dropout, bias
        # Whether attention may use the fused scaled-dot-product kernel
        self.flash = flash


class DenoisingUnet(nn.Module):
    """Transformer denoiser that maps noisy embeddings and a timestep to vocabulary logits.

    Despite the name, the body is a plain stack of transformer `Block`s with no
    down/up-sampling path. The output projection shares its weight with the token
    embedding table.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.block_size, config.n_embd)
        self.timestep_mlp = nn.Sequential(
            nn.Linear(config.n_embd, config.n_embd * 4), nn.SiLU(), nn.Linear(config.n_embd * 4, config.n_embd)
        )
        self.drop = nn.Dropout(config.dropout)
        self.h = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd, bias=config.bias)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: embedding and un-embedding share one parameter tensor
        self.lm_head.weight = self.wte.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02^2) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x_noisy_embed, timesteps):
        """Return logits of shape (B, T, vocab_size).

        x_noisy_embed: (B, T, C) noisy token embeddings, with T <= config.block_size.
        timesteps: (B,) diffusion step of each batch element.
        """
        _, T, _ = x_noisy_embed.size()
        assert (
            T <= self.config.block_size
        ), f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"

        # Positional embeddings
        pos = torch.arange(0, T, dtype=torch.long, device=timesteps.device)
        pos_emb = self.wpe(pos)  # (T, C)

        # Sinusoidal timestep features, then an MLP
        half_dim = self.config.n_embd // 2
        # max(..., 1) keeps the frequency step finite when half_dim == 1
        freq_step = np.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=timesteps.device) * -freq_step)
        angles = timesteps.float()[:, None] * freqs[None, :]  # (B, half_dim)
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # (B, 2 * half_dim)
        if emb.size(-1) < self.config.n_embd:
            # Odd n_embd: pad one zero feature so the width matches timestep_mlp's input
            emb = F.pad(emb, (0, 1))
        timestep_emb = self.timestep_mlp(emb).unsqueeze(1)  # (B, 1, C)

        # The timestep embedding is added once, to every position, before the blocks
        x = x_noisy_embed + pos_emb + timestep_emb

        x = self.drop(x)
        for block in self.h:
            x = block(x)
        x = self.ln_f(x)

        logits = self.lm_head(x)  # (B, T, vocab_size)
        return logits

In [None]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale and optional shift.

    When `bias` is falsy no shift parameter is created and `self.bias` is None.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        eps = 1e-5
        return F.layer_norm(input, self.weight.shape, weight=self.weight, bias=self.bias, eps=eps)


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which every position attends to every position.

    The class keeps its NanoGPT name, but no causal mask is applied in either code
    path. `config` must provide `n_embd`, `n_head`, `bias`, `dropout` and `flash`.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # A single fused projection produces queries, keys and values.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = nn.Dropout(config.dropout)
        # Use the fused kernel only when this PyTorch build provides it and the config allows it.
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") and config.flash

    def forward(self, x):
        """Apply attention to `x` of shape (B, T, C) and return a tensor of the same shape."""
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        head_dim = C // self.n_head
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)  # (B, nh, T, hs)

        if self.flash:
            # The functional API applies dropout whenever dropout_p > 0, regardless of
            # module mode, so it has to be switched off explicitly outside training.
            y = torch.nn.functional.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=None,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=False,
            )
        else:
            att = (q @ k.transpose(-2, -1)) * (1.0 / (k.size(-1) ** 0.5))
            att = F.softmax(att, dim=-1)  # No causal mask for diffusion
            att = self.dropout(att)
            y = att @ v  # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # merge the heads again
        y = self.dropout(self.c_proj(y))
        return y


class MLP(nn.Module):
    """Position-wise feed-forward net: expand to 4 * n_embd, GELU, project back, dropout."""

    def __init__(self, config):
        super().__init__()
        width, hidden = config.n_embd, 4 * config.n_embd
        self.c_fc = nn.Linear(width, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        expanded = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(expanded))


class Block(nn.Module):
    """Pre-norm transformer layer: attention and MLP sub-layers, each inside a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)  # applies no causal mask despite the name
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        attended = x + self.attn(self.ln_1(x))
        return attended + self.mlp(self.ln_2(attended))


class DiffusionConfig:
    """Plain attribute container for the denoiser's hyper-parameters."""

    def __init__(self, vocab_size, block_size, n_layer, n_head, n_embd, dropout, bias, flash=True):
        # Vocabulary size and maximum sequence length
        self.vocab_size, self.block_size = vocab_size, block_size
        # Transformer shape
        self.n_layer, self.n_head, self.n_embd = n_layer, n_head, n_embd
        # Regularisation and layer options
        self.dropout, self.bias = dropout, bias
        # Whether attention may use the fused scaled-dot-product kernel
        self.flash = flash


class DenoisingUnet(nn.Module):
    """Transformer denoiser that maps noisy embeddings and a timestep to vocabulary logits.

    Despite the name, the body is a plain stack of transformer `Block`s with no
    down/up-sampling path. The output projection shares its weight with the token
    embedding table.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.block_size, config.n_embd)
        self.timestep_mlp = nn.Sequential(
            nn.Linear(config.n_embd, config.n_embd * 4), nn.SiLU(), nn.Linear(config.n_embd * 4, config.n_embd)
        )
        self.drop = nn.Dropout(config.dropout)
        self.h = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd, bias=config.bias)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: embedding and un-embedding share one parameter tensor
        self.lm_head.weight = self.wte.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02^2) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x_noisy_embed, timesteps):
        """Return logits of shape (B, T, vocab_size).

        x_noisy_embed: (B, T, C) noisy token embeddings, with T <= config.block_size.
        timesteps: (B,) diffusion step of each batch element.
        """
        _, T, _ = x_noisy_embed.size()
        assert (
            T <= self.config.block_size
        ), f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"

        # Positional embeddings
        pos = torch.arange(0, T, dtype=torch.long, device=timesteps.device)
        pos_emb = self.wpe(pos)  # (T, C)

        # Sinusoidal timestep features, then an MLP
        half_dim = self.config.n_embd // 2
        # max(..., 1) keeps the frequency step finite when half_dim == 1
        freq_step = np.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=timesteps.device) * -freq_step)
        angles = timesteps.float()[:, None] * freqs[None, :]  # (B, half_dim)
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # (B, 2 * half_dim)
        if emb.size(-1) < self.config.n_embd:
            # Odd n_embd: pad one zero feature so the width matches timestep_mlp's input
            emb = F.pad(emb, (0, 1))
        timestep_emb = self.timestep_mlp(emb).unsqueeze(1)  # (B, 1, C)

        # The timestep embedding is added once, to every position, before the blocks
        x = x_noisy_embed + pos_emb + timestep_emb

        x = self.drop(x)
        for block in self.h:
            x = block(x)
        x = self.ln_f(x)

        logits = self.lm_head(x)  # (B, T, vocab_size)
        return logits

In [None]:
class GaussianDiffusion:
    """DDPM-style Gaussian diffusion over continuous token embeddings.

    Holds a linear beta schedule and the quantities derived from it, and provides
    the forward noising step (`q_sample`) and ancestral sampling (`p_sample`,
    `p_sample_loop`, `sample`). The reverse process assumes that `model(x_t, t)`
    returns the predicted noise with the same shape as `x_t`.
    """

    def __init__(self, num_timesteps, embed_dim, device):
        self.num_timesteps = num_timesteps
        self.embed_dim = embed_dim
        self.device = device

        # Linear noise schedule
        self.betas = torch.linspace(0.0001, 0.02, num_timesteps, device=device)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Cumulative product shifted by one step, with 1.0 at t == 0
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

        # Variance of the posterior q(x_{t-1} | x_t, x_0); it is exactly 0 at t == 0
        self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)

    def q_sample(self, x_start_embed, t, noise=None):
        """
        Forward diffusion (adding noise).

        x_start_embed: (B, T, C) - ground truth embeddings
        t: (B,) - timesteps
        noise: optional tensor shaped like x_start_embed; drawn from N(0, I) when omitted
        Returns (x_noisy, noise).
        """
        if noise is None:
            noise = torch.randn_like(x_start_embed)

        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].view(-1, 1, 1)
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1)

        x_noisy = sqrt_alphas_cumprod_t * x_start_embed + sqrt_one_minus_alphas_cumprod_t * noise
        return x_noisy, noise

    def p_mean_variance(self, model, x_t_embed, t, x_start_embed_uncond=None):
        """
        Mean and variance of the reverse step p(x_{t-1} | x_t) under noise prediction.

        model: callable returning the predicted noise for (x_t_embed, t)
        x_t_embed: (B, T, C) current noisy embeddings
        t: (B,) timesteps
        x_start_embed_uncond: unused; kept so the signature stays unchanged
        Returns (model_mean, posterior_variance_t) with shapes (B, T, C) and (B, 1, 1).
        """
        predicted_noise = model(x_t_embed, t)

        beta_t = self.betas[t].view(-1, 1, 1)
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1)
        sqrt_alpha_t = torch.sqrt(self.alphas[t]).view(-1, 1, 1)

        # DDPM mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
        model_mean = (x_t_embed - beta_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t) / sqrt_alpha_t
        posterior_variance_t = self.posterior_variance[t].view(-1, 1, 1)
        return model_mean, posterior_variance_t

    @torch.no_grad()
    def p_sample(self, model, x_t_embed, t):
        """
        Reverse diffusion (one denoising step for sampling).

        Noise is added per batch element, and only where that element's timestep is
        non-zero, so the final step returns the mean itself.
        """
        model_mean, posterior_variance_t = self.p_mean_variance(model, x_t_embed, t)
        noise = torch.randn_like(x_t_embed)
        nonzero_mask = (t != 0).to(x_t_embed.dtype).view(-1, 1, 1)
        return model_mean + nonzero_mask * torch.sqrt(posterior_variance_t) * noise

    @torch.no_grad()
    def p_sample_loop(self, model, shape):
        """
        Complete sampling loop.
        shape: (B, T, C) - desired shape of the output embeddings
        """
        B, _, _ = shape
        img = torch.randn(shape, device=self.device)  # Start with pure noise

        # Walk the chain from the last timestep down to 0
        for i in tqdm(reversed(range(0, self.num_timesteps)), desc="sampling loop time step", total=self.num_timesteps):
            t = torch.full((B,), i, device=self.device, dtype=torch.long)
            img = self.p_sample(model, img, t)
        return img

    @torch.no_grad()
    def sample(self, model, x_start_embed, num_samples=1):
        """
        Simplified sampling for text by starting from noise and denoising.
        Only the shape of x_start_embed is used; num_samples is currently unused.
        """
        return self.p_sample_loop(model, x_start_embed.shape)

In [None]:
# --- Modify DenoisingUnet to predict noise (epsilon) ---
class DenoisingUnet(nn.Module):
    """Transformer denoiser that predicts the noise present in noisy token embeddings.

    Despite the name, the body is a plain stack of transformer `Block`s with no
    down/up-sampling path. The output has the same (B, T, C) shape as the input.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.block_size, config.n_embd)
        self.timestep_mlp = nn.Sequential(
            nn.Linear(config.n_embd, config.n_embd * 4), nn.SiLU(), nn.Linear(config.n_embd * 4, config.n_embd)
        )
        self.drop = nn.Dropout(config.dropout)
        self.h = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd, bias=config.bias)

        # Predict noise in embedding space (output dim is n_embd, not vocab_size)
        self.noise_head = nn.Linear(config.n_embd, config.n_embd, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02^2) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x_noisy_embed, timesteps):
        """Return the predicted noise, shape (B, T, C).

        x_noisy_embed: (B, T, C) noisy token embeddings, with T <= config.block_size.
        timesteps: (B,) diffusion step of each batch element.
        """
        _, T, _ = x_noisy_embed.size()
        assert (
            T <= self.config.block_size
        ), f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"

        pos = torch.arange(0, T, dtype=torch.long, device=timesteps.device)
        pos_emb = self.wpe(pos)  # (T, C)

        # Sinusoidal timestep features, then an MLP
        half_dim = self.config.n_embd // 2
        # max(..., 1) keeps the frequency step finite when half_dim == 1
        freq_step = np.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=timesteps.device) * -freq_step)
        angles = timesteps.float()[:, None] * freqs[None, :]  # (B, half_dim)
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # (B, 2 * half_dim)
        if emb.size(-1) < self.config.n_embd:
            # Odd n_embd: pad one zero feature so the width matches timestep_mlp's input
            emb = F.pad(emb, (0, 1))
        timestep_emb = self.timestep_mlp(emb).unsqueeze(1)  # (B, 1, C)

        # The timestep embedding is added once, to every position, before the blocks
        x = x_noisy_embed + pos_emb + timestep_emb

        x = self.drop(x)
        for block in self.h:
            x = block(x)
        x = self.ln_f(x)

        # Predict noise
        predicted_noise = self.noise_head(x)  # (B, T, C)
        return predicted_noise


In [None]:
# --- Model Instantiation ---
config = DiffusionConfig(
    vocab_size=vocab_size,
    block_size=BLOCK_SIZE,
    n_layer=N_LAYERS,
    n_head=N_HEADS,
    n_embd=EMBEDDING_DIM,
    dropout=0.1,
    bias=False,  # NanoGPT prefers no bias
)

model = DenoisingUnet(config).to(device)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
diffusion = GaussianDiffusion(NUM_DIFFUSION_STEPS, EMBEDDING_DIM, device)

# --- Training Loop ---
print("Starting training...")
model.train()
optimizer.zero_grad()  # Initialize gradients

TRAIN_STEPS = 1200  # NOTE: replaces the value set in the configuration cell
LOG_INTERVAL = 1000  # steps between loss reports / checkpoints
for i in tqdm(range(1, TRAIN_STEPS + 1), desc="Training"):
    is_last_step = i == TRAIN_STEPS

    # Get a batch of original discrete tokens
    batch_tokens = get_batch("train")  # (B, T)

    # Convert discrete tokens to embeddings (x_0)
    # NOTE(review): wte is trained only through the noise-prediction loss below;
    # confirm that jointly learning the embedding this way is intended.
    x_start_embed = model.wte(batch_tokens)  # (B, T, C)

    # Sample one random timestep per sequence in the batch
    t = torch.randint(0, NUM_DIFFUSION_STEPS, (batch_tokens.size(0),), device=device)

    # Add noise to x_start_embed
    x_noisy_embed, noise = diffusion.q_sample(x_start_embed, t)

    # Predict noise using the model
    predicted_noise = model(x_noisy_embed, t)

    # MSE between predicted and actual noise, scaled so accumulated gradients average
    loss = F.mse_loss(predicted_noise, noise) / GRADIENT_ACCUMULATION_STEPS

    loss.backward()

    # Also step on the final iteration so a trailing partial accumulation window is applied
    if i % GRADIENT_ACCUMULATION_STEPS == 0 or is_last_step:
        optimizer.step()
        optimizer.zero_grad()

    # Report and checkpoint periodically, and once more at the end so the final
    # weights are saved even when TRAIN_STEPS is not a multiple of LOG_INTERVAL
    if i % LOG_INTERVAL == 0 or is_last_step:
        print(f"Step {i}: Loss = {loss.item() * GRADIENT_ACCUMULATION_STEPS:.4f}")
        save_model(model, "txt_diffusion", "1_0", i)


In [None]:
# --- Sampling ---
print("\nStarting sampling...")
model.eval()
with torch.no_grad():
    # Only the shape of this tensor is used: one sequence of BLOCK_SIZE positions,
    # each with EMBEDDING_DIM features, for the reverse process to fill in.
    dummy_input_for_shape = torch.zeros(1, BLOCK_SIZE, EMBEDDING_DIM, device=device)
    sampled_embeddings = diffusion.sample(model, dummy_input_for_shape)  # (B, T, C)

    # Decode by nearest neighbour: each sampled vector is mapped to the vocabulary
    # entry whose embedding is closest in Euclidean distance.
    all_vocab_embeddings = model.wte.weight.data  # (vocab_size, EMBEDDING_DIM)
    n_seq, seq_len = sampled_embeddings.size(0), sampled_embeddings.size(1)

    sampled_embeddings_flat = sampled_embeddings.view(-1, EMBEDDING_DIM)  # (B*T, C)
    distances = torch.cdist(sampled_embeddings_flat, all_vocab_embeddings)  # (B*T, vocab_size)

    # Closest vocabulary index per position, folded back to (B, T)
    predicted_indices = distances.argmin(dim=1).view(n_seq, seq_len)

    # Decode and print
    print("\n--- Generated Sample (first batch item) ---")
    generated_text = decode(predicted_indices[0].tolist())
    print(generated_text)