In [None]:
# Build the vocabulary and padding length from the corpus.
# NOTE(review): `lines` is not defined in this cell; it is presumably the list of
# sentences loaded by an earlier cell -- confirm on a fresh kernel.
# A toy corpus for quick experiments:
# lines = ["the cat sits", "a dog runs", "birds fly high", "fish swim fast", "the sun shines"]

# Build a simple vocabulary
words = set(" ".join(lines).split())
# Enumerate the words in sorted order: iterating a raw set of strings follows hash
# order, which changes between interpreter sessions, so token ids (and the embedding
# rows of any saved checkpoint) would not be reproducible across kernel restarts.
word_to_idx = {word: idx + 1 for idx, word in enumerate(sorted(words))}  # +1: id 0 is reserved for padding
word_to_idx["<pad>"] = 0
idx_to_word = {idx: word for word, idx in word_to_idx.items()}
vocab_size = len(word_to_idx)
max_len = max(len(s.split()) for s in lines)  # length of the longest sentence, in tokens


# Convert sentences to token sequences
def tokenize_sentence(sentence):
    """Map a whitespace-separated sentence to a padded list of token ids.

    Every word is looked up in the module-level ``word_to_idx`` (a word that is
    missing from it raises ``KeyError``). The resulting list is right-padded with
    the id 0 until it holds ``max_len`` entries; a sentence that already has
    ``max_len`` words or more is returned unpadded and untruncated.
    """
    ids = []
    for word in sentence.split():
        ids.append(word_to_idx[word])
    n_missing = max_len - len(ids)
    ids.extend([0] * n_missing)  # a non-positive count appends nothing
    return ids


# Encode every sentence of the corpus into a padded row of token ids
tokenized_data = list(map(tokenize_sentence, lines))


In [None]:
# Sanity check: (number of sentences, padded sentence length)
np.shape(tokenized_data)


(100, 13)

In [None]:
# Custom Dataset
class TextDataset(Dataset):
    """Dataset wrapper around pre-tokenized sentences.

    The rows passed in are converted once, up front, into a single ``torch.long``
    tensor stored as ``self.data``; indexing the dataset indexes that tensor.
    """

    def __init__(self, data):
        # Materialise all rows as one integer tensor
        token_ids = torch.tensor(data, dtype=torch.long)
        self.data = token_ids

    def __len__(self):
        # Number of entries along the first dimension of the stored tensor
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return item


# Let's deliberately overfit on the first 100 sentences
dataset = TextDataset(tokenized_data)
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=64)


In [None]:
# D3PM (Discrete Denoising Diffusion Probabilistic Model)
class D3PM:
    """Discrete denoising diffusion over token ids with uniform-replacement noise.

    Forward process (``q_sample``): at step ``t`` every token is independently
    replaced by a uniformly drawn vocabulary id with probability
    ``1 - alphas_cumprod[t]``; step 0 leaves the input untouched.

    Reverse process (``sample``): starting from uniformly random tokens, the model
    is queried once per step (from ``num_steps - 1`` down to 0) and every token is
    re-drawn from the predicted categorical distribution.

    NOTE(review): with the default schedule (betas from 1e-4 to 0.02 over 50
    steps) the cumulative keep-probability at the last step is still roughly 0.6,
    whereas ``sample`` starts from fully random tokens. The noise levels seen in
    training and at sampling time therefore look mismatched -- consider a more
    aggressive schedule.
    """

    def __init__(self, num_steps=50, vocab_size=vocab_size):
        """Precompute the noise schedule.

        Args:
            num_steps: Number of diffusion steps.
            vocab_size: Number of distinct token ids; noise is drawn from
                ``[0, vocab_size)``.
        """
        self.num_steps = num_steps
        self.vocab_size = vocab_size

        self.betas = torch.linspace(0.0001, 0.02, num_steps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Same quantity under its other name; kept so existing attribute access keeps working.
        self.alpha_bars = self.alphas_cumprod

    def q_sample(self, x0, t):
        """Corrupt clean tokens ``x0`` to their noised version at step ``t``.

        Args:
            x0: Integer tensor of token ids whose first dimension is the batch.
            t: Integer tensor of shape ``[batch]`` with one step index per sample,
                each in ``[0, num_steps)``.

        Returns:
            A tensor shaped like ``x0``. Rows with ``t == 0`` are returned
            unchanged; in the other rows each token is replaced by a uniform
            random id with probability ``1 - alphas_cumprod[t]``.
        """
        device = x0.device
        b = x0.shape[0]
        x0_flat = x0.reshape(b, -1)
        t = t.to(device=device, dtype=torch.long)

        # Per-sample corruption probability, computed for the whole batch at once
        # (no Python loop and no per-sample `.item()` host sync). Multiplying by
        # `t > 0` forces probability 0 at step 0 so those rows stay identical to x0.
        corrupt_prob = (1.0 - self.alphas_cumprod.to(device)[t]) * (t > 0)
        mask = torch.rand(x0_flat.shape, device=device) < corrupt_prob.unsqueeze(1)
        noise = torch.randint(0, self.vocab_size, x0_flat.shape, device=device, dtype=x0_flat.dtype)

        xt = x0_flat.clone()
        xt[mask] = noise[mask]
        return xt.view_as(x0)

    @staticmethod
    def _prepare_context(context, batch_size, seq_len, device):
        """Normalise ``context`` to a ``[batch_size, context_len]`` long tensor, or ``None`` if empty."""
        if context is None or context.numel() == 0:
            # An empty prompt carries no conditioning; treat it as unconditional sampling.
            return None
        context = context.to(device=device, dtype=torch.long)
        if context.dim() == 1:
            context = context.unsqueeze(0)
        if context.dim() != 2:
            raise ValueError(f"context must be 1-D or 2-D, got {context.dim()} dimensions")
        if context.shape[0] == 1:
            # Repeat context to match batch size
            context = context.repeat(batch_size, 1)
        elif context.shape[0] != batch_size:
            raise ValueError(
                f"context has {context.shape[0]} rows but batch_size is {batch_size}"
            )
        if context.shape[1] > seq_len:
            raise ValueError(
                f"context length {context.shape[1]} exceeds sequence length {seq_len}"
            )
        return context

    @torch.no_grad()
    def sample(self, model, batch_size, device, context=None, guidance_scale=1.0, seq_len=None):
        """Generate token sequences with classifier-free guidance.

        Args:
            model: Callable ``model(x_t, t, context=...)`` returning logits that can
                be viewed as ``[batch_size * seq_len, vocab_size]``.
            batch_size: Number of sequences to generate.
            device: Device on which sampling tensors are created.
            context: Optional token ids that are pinned at the start of every
                sequence and passed to the model as conditioning. Accepts shape
                ``[context_len]``, ``[1, context_len]`` (both repeated across the
                batch) or ``[batch_size, context_len]``. ``None`` or an empty
                tensor means unconditional sampling.
            guidance_scale: Weight ``g`` in ``uncond + g * (cond - uncond)``.
            seq_len: Length of the generated sequences; defaults to the
                module-level ``max_len``.

        Returns:
            Long tensor of shape ``[batch_size, seq_len]``.

        Raises:
            ValueError: If ``context`` has an unsupported shape or is longer than
                ``seq_len``.
        """
        if seq_len is None:
            seq_len = max_len
        context = self._prepare_context(context, batch_size, seq_len, device)

        x_t = torch.randint(0, self.vocab_size, (batch_size, seq_len), device=device)
        if context is not None:
            x_t[:, : context.shape[1]] = context  # Fix context tokens
        for t in reversed(range(self.num_steps)):
            t_tensor = torch.full((batch_size,), t, device=device, dtype=torch.long)
            logits = model(x_t, t_tensor, context=None)
            if context is not None:
                # Only pay for the conditional pass when there is something to condition on;
                # without context the guided logits reduce to the unconditional ones.
                logits_cond = model(x_t, t_tensor, context=context)
                logits = logits + guidance_scale * (logits_cond - logits)
            probs = F.softmax(logits, dim=-1)
            x_t = torch.multinomial(probs.reshape(-1, self.vocab_size), 1).view(batch_size, seq_len)
            if context is not None:
                x_t[:, : context.shape[1]] = context  # Preserve context
        return x_t


# Sanity check of the forward process on the first two sentences:
# the first row is noised at step 0, the second at step 49.
diffusion = D3PM(num_steps=50, vocab_size=vocab_size)
ic(tokenized_data[:2])
diffusion.q_sample(
    torch.tensor(tokenized_data[:2], dtype=torch.long),
    torch.tensor([0, 49], dtype=torch.long),
)


ic| tokenized_data[:2]: [[216, 106, 50, 263, 306, 66, 0, 0, 0, 0, 0, 0, 0],
                         [100, 349, 236, 292, 349, 258, 108, 92, 310, 53, 0, 0, 0]]


tensor([[216, 106,  50, 263, 306,  66,   0,   0,   0,   0,   0,   0,   0],
        [ 67, 311, 236,  24, 349, 258, 273,  92, 310,  53,   0, 340,   0]])

In [None]:
# NanoGPT-like Model with Context
class NanoGPT(nn.Module):
    """Small Transformer encoder that predicts per-position token logits from noised tokens.

    The input tokens are embedded and summed with a learned position embedding and
    a learned embedding of the diffusion step. Optional context tokens are embedded
    with their own table, prepended to the sequence so every position can attend to
    them, and removed again before the output head, so the returned logits always
    line up one-to-one with the positions of ``x``.
    """

    def __init__(self, vocab_size, n_embd=64, n_head=4, n_layer=2, max_len=max_len, num_steps=50):
        """Build the embeddings, encoder stack and output head.

        Args:
            vocab_size: Number of token ids (embedding rows and output logits).
            n_embd: Embedding / model width.
            n_head: Attention heads per encoder layer.
            n_layer: Number of encoder layers.
            max_len: Largest sequence (and context) length the position table supports.
            num_steps: Number of diffusion steps the time embedding supports.
        """
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = nn.Embedding(max_len, n_embd)
        self.context_embedding = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList(
            [
                # batch_first=True: inputs are (batch, seq, feature). With the default
                # (seq, batch, feature) layout attention would mix different samples of
                # the batch instead of the tokens of one sentence.
                nn.TransformerEncoderLayer(
                    d_model=n_embd, nhead=n_head, dim_feedforward=n_embd * 4, batch_first=True
                )
                for _ in range(n_layer)
            ]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size)
        self.max_len = max_len
        self.time_embedding = nn.Embedding(num_steps, n_embd)

    def forward(self, x, t, context=None):
        """Compute logits for every position of ``x``.

        Args:
            x: Long tensor ``[B, T]`` of (noised) token ids, ``T <= max_len``.
            t: Long tensor ``[B]`` of diffusion step indices.
            context: Optional long tensor ``[B, C]`` of conditioning token ids,
                ``C <= max_len``.

        Returns:
            Float tensor ``[B, T, vocab_size]``.

        Raises:
            ValueError: If ``x`` or ``context`` is longer than ``max_len``.
        """
        B, T = x.shape
        if T > self.max_len:
            raise ValueError(f"sequence length {T} exceeds max_len {self.max_len}")
        device = x.device
        tok_emb = self.token_embedding(x)
        pos_emb = self.position_embedding(torch.arange(T, device=device))
        t_emb = self.time_embedding(t).unsqueeze(1)
        x = tok_emb + pos_emb + t_emb

        ctx_len = 0
        if context is not None:
            ctx_len = context.shape[1]
            if ctx_len > self.max_len:
                raise ValueError(f"context length {ctx_len} exceeds max_len {self.max_len}")
            ctx_emb = self.context_embedding(context)
            ctx_pos_emb = self.position_embedding(torch.arange(ctx_len, device=device))
            # Prepend the context and keep the whole sequence for attention. Truncating
            # to the first T positions here would discard the last tokens of `x` and
            # shift every logit by ctx_len relative to its target position.
            x = torch.cat([ctx_emb + ctx_pos_emb, x], dim=1)

        for block in self.blocks:
            x = block(x)

        # Drop the context positions so the output aligns with the input tokens.
        x = x[:, ctx_len:]
        x = self.ln_f(x)
        logits = self.head(x)
        return logits


# Pick the best available accelerator: Apple MPS, then CUDA, otherwise the CPU.
# `getattr` guards against PyTorch builds that do not ship `torch.backends.mps`.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: mps


In [None]:
# Instantiate the denoiser, the diffusion process and the optimizer; the model's
# time embedding and the noise schedule share the same number of steps.
num_steps = 50
model = NanoGPT(num_steps=num_steps, vocab_size=vocab_size).to(device)
diffusion = D3PM(vocab_size=vocab_size, num_steps=num_steps)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Training

import os  # local to this cell: only needed to create the checkpoint directory

num_epochs = 10
dropout_prob = 0.1  # Probability of dropping context for classifier-free guidance
checkpoint_every = 50  # Save a checkpoint every N epochs
checkpoint_dir = "../models"
global_step = 0  # Track global step for TensorBoard

# Set up TensorBoard logging
timestamp = datetime.datetime.now().strftime("%d-%m-%Y_%H:%M:%S")
log_dir = f"../runs/d3pm/{timestamp}"
writer = SummaryWriter(log_dir=log_dir)
print(f"TensorBoard logs will be saved to: {log_dir}")

# try/finally: make sure the event file is flushed and closed even when training
# raises or is interrupted from the notebook.
try:
    epoch_bar = tqdm(range(num_epochs), desc="🚀 Training", position=0, leave=True)
    for epoch in epoch_bar:
        # Re-enable dropout etc.: the generation cell switches the model to eval mode,
        # so re-running this cell afterwards would otherwise train in eval mode.
        model.train()

        total_loss = 0
        running_avg_loss = 0
        batch_count = 0

        batch_nb = len(dataloader)
        # Iterate the dataloader itself so each epoch visits every batch exactly once.
        # Calling `next(iter(dataloader))` per step would rebuild the iterator every
        # time and only ever use the first batch of a fresh shuffle.
        inner_pbar = tqdm(
            dataloader,
            total=batch_nb,
            desc=f"  ⚙️ Inner Task {epoch + 1}",
            position=1,
            leave=False,
            colour="green",
        )

        for batch_idx, batch in enumerate(inner_pbar):
            batch = batch.to(device)
            optimizer.zero_grad()
            # One random diffusion step per sample
            t = torch.randint(0, num_steps, (batch.shape[0],), device=device, dtype=torch.long)
            x_t = diffusion.q_sample(batch, t)
            # Condition on the first token, except for a `dropout_prob` share of the steps
            context = batch[:, :1] if np.random.rand() > dropout_prob else None
            logits = model(x_t, t, context=context)
            # The model predicts the clean tokens at every position
            loss = F.cross_entropy(logits.reshape(-1, vocab_size), batch.reshape(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            batch_count += 1
            running_avg_loss = total_loss / batch_count
            global_step += 1

            # Log batch loss to TensorBoard
            writer.add_scalar("Loss/Batch", loss.item(), global_step)
            writer.add_scalar("Loss/Running_Average", running_avg_loss, global_step)

            # Log learning rate
            writer.add_scalar("Learning_Rate", optimizer.param_groups[0]["lr"], global_step)

            inner_pbar.set_postfix({"Step": f"{batch_idx + 1}/{batch_nb}", "Status": "Processing..."})

        # Guard against an empty dataloader
        avg_loss = total_loss / max(batch_count, 1)
        epoch_bar.set_postfix({"Epoch": f"{epoch + 1}/{num_epochs}", "Loss": f"{avg_loss:.4f}"})

        # Log epoch loss to TensorBoard
        writer.add_scalar("Loss/Epoch", avg_loss, epoch + 1)

        if (epoch + 1) % checkpoint_every == 0:
            checkpoint = {
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": avg_loss,
            }
            # torch.save does not create missing directories; without this the first
            # checkpoint would fail after many epochs of training.
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(checkpoint, f"{checkpoint_dir}/d3pm_epoch_{epoch + 1}.pth")
            print(f"Saved checkpoint: d3pm_epoch_{epoch + 1}.pth")
finally:
    # Close the TensorBoard writer
    writer.close()

print("✅ Training completed!")
print(f"View logs with: tensorboard --logdir={log_dir}")

TensorBoard logs will be saved to: ../runs/d3pm/29-05-2025_17:10:50


🚀 Training:   0%|          | 0/10 [00:00<?, ?it/s]

  ⚙️ Inner Task 1:   0%|          | 0/2 [00:00<?, ?it/s]

  ⚙️ Inner Task 2:   0%|          | 0/2 [00:00<?, ?it/s]

  ⚙️ Inner Task 3:   0%|          | 0/2 [00:00<?, ?it/s]

  ⚙️ Inner Task 4:   0%|          | 0/2 [00:00<?, ?it/s]

  ⚙️ Inner Task 5:   0%|          | 0/2 [00:00<?, ?it/s]

  ⚙️ Inner Task 6:   0%|          | 0/2 [00:00<?, ?it/s]

  ⚙️ Inner Task 7:   0%|          | 0/2 [00:00<?, ?it/s]

  ⚙️ Inner Task 8:   0%|          | 0/2 [00:00<?, ?it/s]

  ⚙️ Inner Task 9:   0%|          | 0/2 [00:00<?, ?it/s]

  ⚙️ Inner Task 10:   0%|          | 0/2 [00:00<?, ?it/s]

✅ Training completed!
View logs with: tensorboard --logdir=../runs/d3pm/29-05-2025_17:10:50


In [None]:
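# Optional: gradient-norm and parameter-histogram logging.
# The snippet below is meant to be pasted into the training loop's batch body (after
# `loss.backward()`); uncommenting it in this standalone cell would run it at most
# once, after training has already finished.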

# Log model parameters and gradients every 10 batches
# if global_step % 10 == 0:
#     for name, param in model.named_parameters():
#         if param.grad is not None:
#             writer.add_histogram(f'Gradients/{name}', param.grad, global_step)
#         writer.add_histogram(f'Parameters/{name}', param, global_step)

#     # Log gradient norm
#     total_norm = 0
#     for p in model.parameters():
#         if p.grad is not None:
#             param_norm = p.grad.data.norm(2)
#             total_norm += param_norm.item() ** 2
#     total_norm = total_norm ** (1. / 2)
#     writer.add_scalar('Gradient_Norm', total_norm, global_step)

In [None]:
# Generate samples with context
start_sentence = "Come the days of"
model.eval()
with torch.no_grad():
    start_words = start_sentence.split()

    # Words outside the vocabulary cannot be encoded; report them instead of
    # dropping them silently so the effective prompt is visible.
    unknown_words = [word for word in start_words if word not in word_to_idx]
    if unknown_words:
        print("Ignoring out-of-vocabulary context words:", ", ".join(unknown_words))

    context_ids = [word_to_idx[word] for word in start_words if word in word_to_idx]
    # With no usable word, fall back to unconditional sampling: an empty list would
    # become a float tensor, which cannot be used as embedding indices.
    context = (
        torch.tensor(context_ids, dtype=torch.long, device=device) if context_ids else None
    )  # Shape: [context_len]

    samples = diffusion.sample(model, batch_size=3, device=device, context=context, guidance_scale=2.0)

    pad_idx = word_to_idx["<pad>"]
    for sample in samples:
        # Skip padding ids up front; blanking "<pad>" after joining leaves runs of spaces.
        text = [idx_to_word[idx] for idx in sample.tolist() if idx != pad_idx and idx in idx_to_word]
        print("Generated:", " ".join(text))

Generated: Come the of him a-meowin’ city And it’s to   place man,
Generated: Come the of merrier” big I’m Gate tryin’   was  on
Generated: Come the of That very Cops   How    
