In [1]:
# | default_exp attention

%load_ext autoreload
%autoreload 2

%env TOKENIZERS_PARALLELISM=false

env: TOKENIZERS_PARALLELISM=false


In [85]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
import uuid
import pandas as pd
from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter
import datetime
from icecream import ic
import math

from transformers import AutoTokenizer

In [3]:
df = pd.read_csv("../dataset/bob_dylan_lyrics.csv")
lines = []
nb_rows = 100
row_id = 0
for r in df.iterrows():
    # todo: one line is one sentence.
    lines.append(r[1]["title"])
    # sentences.append(r[1]["title"] + "\n" + r[1]["lyrics"])
    lyrics = r[1]["lyrics"].split("\n")
    for line in lyrics:
        if len(line.strip()) > 0:
            lines.append(line.strip())
        row_id += 1
        if len(lines) >= nb_rows:
            break
    if len(lines) >= nb_rows:
        break


In [68]:
class LyricsDataset(Dataset):
    def __init__(self, texts, tokenizer, seq_len=128):
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.examples = []
        for line in texts:
            tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line))
            self.examples.append(tokens)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        inp = self.examples[idx]
        pad_id = self.tokenizer.pad_token_id
        inp = inp + [pad_id] * (self.seq_len - len(inp))
        return torch.tensor(inp)


In [69]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
dataset = LyricsDataset(lines, tokenizer, seq_len=32)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [70]:
next(iter(dataloader))
len(tokenizer)

30522

In [71]:
class DiscreteDiffusion:
    def __init__(self, num_tokens, timesteps, beta_start=0.0001, beta_end=0.02):
        self.num_tokens = num_tokens
        self.timesteps = timesteps
        self.betas = torch.linspace(beta_start, beta_end, timesteps)
        self.alphas = 1.0 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def q_sample(self, x0, t):
        B, L = x0.shape
        out = torch.zeros_like(x0)
        for i in range(B):
            a_bar = self.alpha_bars[t[i]].item()
            mask = torch.rand((L,), device=x0.device) >= a_bar
            out[i] = x0[i].clone()
            noise = torch.randint(0, self.num_tokens, (L,), device=x0.device)
            out[i][mask] = noise[mask]
        return out

In [72]:
diffusion = DiscreteDiffusion(num_tokens=len(tokenizer), timesteps=100)

In [73]:
inp = next(iter(dataloader)).to(device)
inp

tensor([[ 2000, 16696,  8232,  ...,     0,     0,     0],
        [ 2129,  2002,  2351,  ...,     0,     0,     0],
        [ 2092,  1010,  1996,  ...,     0,     0,     0],
        ...,
        [ 2027,  2179,  2032,  ...,     0,     0,     0],
        [ 2092,  1010,  2057,  ...,     0,     0,     0],
        [ 2416,  4595,  2111,  ...,     0,     0,     0]], device='mps:0')

In [79]:
def show_noise(line_nb, step):
    src_line = tokenizer.decode(inp[line_nb].cpu().numpy())
    noisy_inp = diffusion.q_sample(inp[line_nb : line_nb + 1], torch.tensor([step]).to(device))
    noisy_line = tokenizer.decode(noisy_inp[0].cpu().numpy())
    return src_line, noisy_line


sent_nb = 4
ic(show_noise(sent_nb, 0)[1])
ic(show_noise(sent_nb, 12)[1])
ic(show_noise(sent_nb, 50)[1])
ic(show_noise(sent_nb, 99)[1]);


ic| show_noise(sent_nb, 0)[1]: ('there on the sidewalk he did lay [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] '
                                '[PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] '
                                '[PAD] [PAD] [PAD] [PAD] [PAD] [PAD]')
ic| show_noise(sent_nb, 12)[1]: ('there on the sidewalk he did lay [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] '
                                 'mountains [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] '
                                 '[PAD] [PAD] [PAD] [PAD] [PAD] [PAD]')
ic| show_noise(sent_nb, 50)[1]: ('there on the sidewalk he did lay rhythms [PAD] [PAD] [PAD]ɣ [PAD] [PAD] '
                                 '[PAD] [PAD] [PAD] [PAD] makers [PAD] [PAD]feit [PAD] everyday [PAD] [PAD] ᴬ '
                                 '[PAD] [PAD] ridley pause [PAD]')
ic| show_noise(sent_nb, 99)[1]: ('allies eliminating the 184 facility did knox determines [PAD] announced '
                      

In [83]:
class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.lin = nn.Linear(dim, dim)

    def forward(self, t):
        half = self.lin.in_features // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(half, dtype=torch.float32) / half).to(t.device)
        args = t[:, None].float() * freqs[None]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        return self.lin(emb)


class DiffusionTransformer(nn.Module):
    def __init__(self, vocab_size, seq_len, dim=512, heads=8, layers=6):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, dim)
        self.pos_emb = nn.Embedding(seq_len, dim)
        self.time_emb = TimeEmbedding(dim)
        enc_layer = nn.TransformerEncoderLayer(dim, heads, dim * 4)
        self.transformer = nn.TransformerEncoder(enc_layer, layers)
        self.to_logits = nn.Linear(dim, vocab_size)
        self.seq_len = seq_len

    def forward(self, x, t):
        B, L = x.shape
        tok = self.token_emb(x)
        pos = self.pos_emb(torch.arange(L, device=x.device))
        temb = self.time_emb(t).unsqueeze(1)
        h = tok + pos + temb
        h = self.transformer(h.transpose(0, 1)).transpose(0, 1)
        return self.to_logits(h)

In [92]:
seq_len = 32
batch_size = 32
dataset = LyricsDataset(lines, tokenizer, seq_len=seq_len)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

diffusion = DiscreteDiffusion(num_tokens=len(tokenizer), timesteps=100)
model = DiffusionTransformer(vocab_size=len(tokenizer), seq_len=seq_len)
model.to(device)
epochs = 100
lr = 1e-4

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
for epoch in range(1, epochs + 1):
    model.train()
    total_loss = 0.0
    for inp in dataloader:
        inp = inp.to(device)
        B = inp.size(0)
        t = torch.randint(0, diffusion.timesteps, (B,), device=device)
        noised = diffusion.q_sample(inp, t)
        logits = model(noised, t)
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            inp.view(-1),
            ignore_index=model.token_emb.padding_idx
            if (hasattr(model.token_emb, "padding_idx") and model.token_emb.padding_idx is not None)
            else -100,
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * B
    print(f"Epoch {epoch}/{epochs} | Loss: {total_loss / len(dataloader.dataset):.4f}")



Epoch 1/100 | Loss: 7.3190
Epoch 2/100 | Loss: 3.8459
Epoch 2/100 | Loss: 3.8459
Epoch 3/100 | Loss: 3.6677
Epoch 3/100 | Loss: 3.6677
Epoch 4/100 | Loss: 3.5243
Epoch 4/100 | Loss: 3.5243
Epoch 5/100 | Loss: 3.3978
Epoch 5/100 | Loss: 3.3978
Epoch 6/100 | Loss: 3.2776
Epoch 6/100 | Loss: 3.2776
Epoch 7/100 | Loss: 3.1534
Epoch 7/100 | Loss: 3.1534
Epoch 8/100 | Loss: 2.9913
Epoch 8/100 | Loss: 2.9913
Epoch 9/100 | Loss: 2.8114
Epoch 9/100 | Loss: 2.8114
Epoch 10/100 | Loss: 2.6487
Epoch 10/100 | Loss: 2.6487
Epoch 11/100 | Loss: 2.5297
Epoch 11/100 | Loss: 2.5297
Epoch 12/100 | Loss: 2.3970
Epoch 12/100 | Loss: 2.3970
Epoch 13/100 | Loss: 2.3057
Epoch 13/100 | Loss: 2.3057
Epoch 14/100 | Loss: 2.1907
Epoch 14/100 | Loss: 2.1907
Epoch 15/100 | Loss: 2.1217
Epoch 15/100 | Loss: 2.1217
Epoch 16/100 | Loss: 2.0426
Epoch 16/100 | Loss: 2.0426
Epoch 17/100 | Loss: 1.9895
Epoch 17/100 | Loss: 1.9895
Epoch 18/100 | Loss: 1.9133
Epoch 18/100 | Loss: 1.9133
Epoch 19/100 | Loss: 1.8548
Epoch 19/

In [97]:
def generate(model, diffusion, tokenizer, prompt, length=50, device="cpu"):
    model.to(device).eval()
    # tokenize prompt
    p_tok = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(prompt))
    seq_len = model.seq_len
    # ensure prompt + generation fits
    assert len(p_tok) + length <= seq_len, f"prompt+length ({len(p_tok) + length}) exceeds model.seq_len ({seq_len})"
    # initialize full sequence: prompt + random tokens
    x = torch.randint(0, diffusion.num_tokens, (1, seq_len), device=device)
    x[0, : len(p_tok)] = torch.tensor(p_tok, device=device)
    fixed = torch.zeros(seq_len, dtype=torch.bool, device=device)
    fixed[: len(p_tok)] = True

    with torch.no_grad():
        for t in reversed(range(diffusion.timesteps)):
            t_batch = torch.tensor([t], device=device)
            logits = model(x, t_batch)
            probs = F.softmax(logits, dim=-1)
            x = torch.multinomial(probs.view(-1, diffusion.num_tokens), num_samples=1).view(1, seq_len)
            # restore prompt positions
            x[0, fixed] = torch.tensor(p_tok, device=device)

    generated = x[0, len(p_tok) :].tolist()
    return prompt + tokenizer.decode(generated)


start_text = "You will be "
start_toks = tokenizer.tokenize(start_text)
start_len = len(start_toks)
ic(start_len)
print(generate(model, diffusion, tokenizer, start_text, length=seq_len - start_len, device=device))

ic| start_len: 3
| start_len: 3


You will be down [PAD] [PAD] [PAD] the from [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]


In [91]:
len(tokenizer)

30522

In [32]:
# Toy dataset: simple sentences
# sentences = ["the cat sits", "a dog runs", "birds fly high", "fish swim fast", "the sun shines"]

# Build a simple vocabulary
words = set(" ".join(lines).split())
word_to_idx = {word: idx + 1 for idx, word in enumerate(words)}  # +1 for padding
word_to_idx["<pad>"] = 0
idx_to_word = {idx: word for word, idx in word_to_idx.items()}
vocab_size = len(word_to_idx)
max_len = max(len(s.split()) for s in lines)


# Convert sentences to token sequences
def tokenize_sentence(sentence):
    tokens = [word_to_idx[word] for word in sentence.split()]
    return tokens + [0] * (max_len - len(tokens))  # Pad to max_len


tokenized_data = [tokenize_sentence(s) for s in lines]


In [33]:
np.array(tokenized_data).shape


(100, 13)

In [34]:
# Custom Dataset
class TextDataset(Dataset):
    def __init__(self, data):
        self.data = torch.tensor(data, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


# lets overfit on the first 100 sentences
dataset = TextDataset(tokenized_data)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)


In [35]:
next(iter(dataloader)).shape


torch.Size([64, 13])

In [49]:
# D3PM (Discrete Denoising Diffusion Probabilistic Model)
class D3PM:
    def __init__(self, num_steps=50, vocab_size=vocab_size):
        self.num_steps = num_steps
        self.vocab_size = vocab_size

        self.betas = torch.linspace(0.0001, 0.02, num_steps)
        self.alphas = 1.0 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    def q_sample(self, x0, t):
        # Get batch size and shape
        device = x0.device
        b = x0.shape[0]
        x0_flat = x0.view(b, -1)
        # At t=0, no noise: x_t = x0
        xt = x0_flat.clone()
        for i in range(b):
            step = t[i].item()
            if step == 0:
                continue
            # For step > 0, corrupt x0
            mask = torch.rand_like(x0_flat[i].float()) < (1 - self.alphas_cumprod[step])
            noise = torch.randint(0, self.vocab_size, x0_flat[i].shape, device=device)
            xt[i][mask] = noise[mask]
        return xt.view_as(x0)

    def sample(self, model, batch_size, device, context=None, guidance_scale=1.0):
        x_t = torch.randint(0, self.vocab_size, (batch_size, max_len), device=device)
        if context is not None:
            # Repeat context to match batch size
            context = context.repeat(batch_size, 1)  # Shape: [batch_size, context_len]
            x_t[:, : context.shape[1]] = context  # Fix context tokens
        for t in reversed(range(self.num_steps)):
            t_tensor = torch.full((batch_size,), t, device=device, dtype=torch.long)
            logits_uncond = model(x_t, t_tensor, context=None)
            logits_cond = model(x_t, t_tensor, context=context)
            logits = logits_uncond + guidance_scale * (logits_cond - logits_uncond)
            probs = F.softmax(logits, dim=-1)
            x_t = torch.multinomial(probs.view(-1, self.vocab_size), 1).view(batch_size, max_len)
            if context is not None:
                x_t[:, : context.shape[1]] = context  # Preserve context
        return x_t


diffusion = D3PM(num_steps=50, vocab_size=vocab_size)
ic(tokenized_data[:2])
diffusion.q_sample(torch.tensor(tokenized_data[:2], dtype=torch.long), torch.tensor([0, 49], dtype=torch.long))


ic| tokenized_data[:2]: [[216, 106, 50, 263, 306, 66, 0, 0, 0, 0, 0, 0, 0],
                         [100, 349, 236, 292, 349, 258, 108, 92, 310, 53, 0, 0, 0]]


tensor([[216, 106,  50, 263, 306,  66,   0,   0,   0,   0,   0,   0,   0],
        [ 67, 311, 236,  24, 349, 258, 273,  92, 310,  53,   0, 340,   0]])

torch.Size([50, 356, 356])

In [50]:
# NanoGPT-like Model with Context
class NanoGPT(nn.Module):
    def __init__(self, vocab_size, n_embd=64, n_head=4, n_layer=2, max_len=max_len, num_steps=50):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = nn.Embedding(max_len, n_embd)
        self.context_embedding = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(d_model=n_embd, nhead=n_head, dim_feedforward=n_embd * 4)
                for _ in range(n_layer)
            ]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size)
        self.max_len = max_len
        self.time_embedding = nn.Embedding(num_steps, n_embd)

    def forward(self, x, t, context=None):
        B, T = x.shape
        device = x.device
        tok_emb = self.token_embedding(x)
        pos_emb = self.position_embedding(torch.arange(T, device=device))
        t_emb = self.time_embedding(t).unsqueeze(1)
        x = tok_emb + pos_emb + t_emb
        if context is not None:
            ctx_emb = self.context_embedding(context)
            ctx_pos_emb = self.position_embedding(torch.arange(context.shape[1], device=device))
            ctx = ctx_emb + ctx_pos_emb
            x = torch.cat([ctx, x], dim=1)
            x = x[:, :T] + pos_emb
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits


# Check for MPS availability
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

Using device: mps


In [51]:
num_steps = 50
model = NanoGPT(vocab_size=vocab_size, num_steps=num_steps).to(device)
diffusion = D3PM(num_steps=num_steps, vocab_size=vocab_size)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [53]:
# Training

num_epochs = 10
dropout_prob = 0.1  # Probability of dropping context for classifier-free guidance
global_step = 0  # Track global step for TensorBoard

# Set up TensorBoard logging
timestamp = datetime.datetime.now().strftime("%d-%m-%Y_%H:%M:%S")
log_dir = f"../runs/d3pm/{timestamp}"
writer = SummaryWriter(log_dir=log_dir)
print(f"TensorBoard logs will be saved to: {log_dir}")

epoch_bar = tqdm(range(num_epochs), desc="🚀 Training", position=0, leave=True)
for epoch in epoch_bar:
    total_loss = 0
    running_avg_loss = 0
    batch_count = 0

    batch_nb = len(dataloader)
    inner_pbar = tqdm(range(batch_nb), desc=f"  ⚙️ Inner Task {epoch + 1}", position=1, leave=False, colour="green")

    for batch_idx in inner_pbar:
        batch = next(iter(dataloader))
        batch = batch.to(device)
        optimizer.zero_grad()
        t = torch.randint(0, num_steps, (batch.shape[0],), device=device, dtype=torch.long)
        x_t = diffusion.q_sample(batch, t)
        context = batch[:, :1] if np.random.rand() > dropout_prob else None
        logits = model(x_t, t, context=context)
        loss = F.cross_entropy(logits.view(-1, vocab_size), batch.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        batch_count += 1
        running_avg_loss = total_loss / batch_count
        global_step += 1

        # Log batch loss to TensorBoard
        writer.add_scalar("Loss/Batch", loss.item(), global_step)
        writer.add_scalar("Loss/Running_Average", running_avg_loss, global_step)

        # Log learning rate
        writer.add_scalar("Learning_Rate", optimizer.param_groups[0]["lr"], global_step)

        inner_pbar.set_postfix({"Step": f"{batch_idx + 1}/{batch_nb}", "Status": "Processing..."})

    avg_loss = total_loss / len(dataloader)
    epoch_bar.set_postfix({"Epoch": f"{epoch + 1}/{num_epochs}", "Loss": f"{avg_loss:.4f}"})

    # Log epoch loss to TensorBoard
    writer.add_scalar("Loss/Epoch", avg_loss, epoch + 1)

    if (epoch + 1) % 50 == 0:
        checkpoint = {
            "epoch": epoch + 1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": avg_loss,
        }
        torch.save(checkpoint, f"../models/d3pm_epoch_{epoch + 1}.pth")
        print(f"Saved checkpoint: d3pm_epoch_{epoch + 1}.pth")

# Close the TensorBoard writer
writer.close()
print("✅ Training completed!")
print(f"View logs with: tensorboard --logdir={log_dir}")

TensorBoard logs will be saved to: ../runs/d3pm/29-05-2025_17:10:50


🚀 Training:   0%|          | 0/10 [00:00<?, ?it/s]

  ⚙️ Inner Task 1:   0%|          | 0/2 [00:00<?, ?it/s]

  ⚙️ Inner Task 2:   0%|          | 0/2 [00:00<?, ?it/s]

  ⚙️ Inner Task 3:   0%|          | 0/2 [00:00<?, ?it/s]

  ⚙️ Inner Task 4:   0%|          | 0/2 [00:00<?, ?it/s]

  ⚙️ Inner Task 5:   0%|          | 0/2 [00:00<?, ?it/s]

  ⚙️ Inner Task 6:   0%|          | 0/2 [00:00<?, ?it/s]

  ⚙️ Inner Task 7:   0%|          | 0/2 [00:00<?, ?it/s]

  ⚙️ Inner Task 8:   0%|          | 0/2 [00:00<?, ?it/s]

  ⚙️ Inner Task 9:   0%|          | 0/2 [00:00<?, ?it/s]

  ⚙️ Inner Task 10:   0%|          | 0/2 [00:00<?, ?it/s]

✅ Training completed!
View logs with: tensorboard --logdir=../runs/d3pm/29-05-2025_17:10:50


In [None]:
# Optional: Add gradient norm logging and model parameter histograms
# Uncomment the lines below if you want more detailed logging

# Log model parameters and gradients every 10 batches
# if global_step % 10 == 0:
#     for name, param in model.named_parameters():
#         if param.grad is not None:
#             writer.add_histogram(f'Gradients/{name}', param.grad, global_step)
#         writer.add_histogram(f'Parameters/{name}', param, global_step)

#     # Log gradient norm
#     total_norm = 0
#     for p in model.parameters():
#         if p.grad is not None:
#             param_norm = p.grad.data.norm(2)
#             total_norm += param_norm.item() ** 2
#     total_norm = total_norm ** (1. / 2)
#     writer.add_scalar('Gradient_Norm', total_norm, global_step)

In [54]:
# Generate samples with context
start_sentence = "Come the days of"
model.eval()
with torch.no_grad():
    start_words = start_sentence.split()
    context = torch.tensor([word_to_idx[word] for word in start_words if word in word_to_idx], device=device)  # Sh
    samples = diffusion.sample(model, batch_size=3, device=device, context=context, guidance_scale=2.0)
    for sample in samples:
        text = [idx_to_word[idx.item()] for idx in sample if idx.item() in idx_to_word]
        print("Generated:", " ".join(text).replace("<pad>", ""))

Generated: Come the of him a-meowin’ city And it’s to   place man,
Generated: Come the of merrier” big I’m Gate tryin’   was  on
Generated: Come the of That very Cops   How    
