In [None]:
# | default_exp attention

%load_ext autoreload
%autoreload 2

%env TOKENIZERS_PARALLELISM=false

In [None]:
import os
import uuid

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


In [None]:
# Load the lyrics table and flatten it into a list of sentences: every song
# contributes its title followed by each non-blank line of its lyrics.
df = pd.read_csv("../dataset/bob_dylan_lyrics.csv")
sentences = []
for title, lyrics in zip(df["title"], df["lyrics"]):
    # todo: one line is one sentence.
    # pandas represents an empty CSV field as NaN (a float). Skip such cells
    # instead of crashing on .split() here or on " ".join(sentences) below.
    if isinstance(title, str):
        sentences.append(title)
    if isinstance(lyrics, str):
        stripped_lines = (line.strip() for line in lyrics.split("\n"))
        sentences.extend(line for line in stripped_lines if line)

# Distinct whitespace-separated words across all collected sentences.
words = set(" ".join(sentences).split())


In [None]:
# Sanity check: number of sentences collected (song titles plus non-empty lyric lines).
len(sentences)


In [None]:
# Toy dataset: simple sentences
# sentences = ["the cat sits", "a dog runs", "birds fly high", "fish swim fast", "the sun shines"]

# Build a simple vocabulary.
words = set(" ".join(sentences).split())
# Ids are assigned over the *sorted* words. Iterating the set directly yields a
# different order in every Python process (string hashing is randomised), which
# would change every token id after a kernel restart and make saved checkpoints
# undecodable. "<pad>" is excluded so that a literal "<pad>" in the corpus cannot
# leave a gap in the id range (which would push the largest id past vocab_size - 1).
word_to_idx = {word: idx for idx, word in enumerate(sorted(words - {"<pad>"}), start=1)}
word_to_idx["<pad>"] = 0  # id 0 is reserved for padding
idx_to_word = {idx: word for word, idx in word_to_idx.items()}
vocab_size = len(word_to_idx)
# Length, in words, of the longest sentence; every sequence is padded to this.
max_len = max(len(s.split()) for s in sentences)


# Convert sentences to token sequences
def tokenize_sentence(sentence):
    """Turn a whitespace-separated sentence into a list of token ids.

    Each word is looked up in ``word_to_idx`` (a missing word raises ``KeyError``).
    The result is right-padded with 0 up to ``max_len`` ids; a sentence that
    already has ``max_len`` or more words is returned without padding or truncation.
    """
    token_ids = list(map(word_to_idx.__getitem__, sentence.split()))
    n_padding = max_len - len(token_ids)
    # A non-positive count extends by nothing, so long inputs pass through as-is.
    token_ids.extend([0] * n_padding)
    return token_ids


# One padded id list per sentence, in the same order as `sentences`.
tokenized_data = list(map(tokenize_sentence, sentences))


In [None]:
# Shape of the padded token-id matrix; each row is one tokenized sentence.
np.array(tokenized_data).shape


In [None]:
# Custom Dataset
class TextDataset(Dataset):
    """Map-style dataset that serves rows of a token-id tensor.

    Args:
        data: nested sequence of token ids that ``torch.tensor`` can convert.
    """

    def __init__(self, data):
        # Convert once up front so that indexing only slices an int64 tensor.
        self.data = torch.tensor(data, dtype=torch.long)

    def __len__(self):
        """Number of rows, i.e. the size of the tensor's first dimension."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return the row of token ids stored at ``idx``."""
        row = self.data[idx]
        return row


# Wrap the padded id lists in a dataset and serve them in shuffled batches of 64.
dataset = TextDataset(data=tokenized_data)
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=64)


In [None]:
# Draw a single batch from the dataloader and display its shape.
next(iter(dataloader)).shape

In [None]:
# D3PM (Discrete Denoising Diffusion Probabilistic Model)
class D3PM:
    """Discrete diffusion over token ids with a uniform transition kernel.

    The single-step transition matrix at step ``t`` is
    ``Q_t = (1 - beta_t) * I + (beta_t / K) * ones(K, K)`` with ``K = vocab_size``.

    Attributes:
        num_steps: number of diffusion steps.
        vocab_size: number of token ids ``K``.
        use_absorbing: stored as given; no method of this class reads it, so only
            the uniform kernel is implemented.
        betas: 1-D tensor of ``num_steps`` values, linearly spaced from 1e-4 to 0.02.
        alphas: ``1 - betas``.
        alpha_bars: cumulative product of ``alphas``.
        Q_t: stacked single-step matrices, shape ``[num_steps, K, K]`` (built lazily).
    """

    def __init__(self, num_steps=50, vocab_size=vocab_size, use_absorbing=False):
        self.num_steps = num_steps
        self.vocab_size = vocab_size
        self.use_absorbing = use_absorbing
        self.betas = torch.linspace(0.0001, 0.02, num_steps)
        self.alphas = 1.0 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
        # Filled in on first access to ``Q_t``. The full stack holds
        # num_steps * vocab_size**2 floats, which is prohibitive for a real
        # vocabulary, and nothing in this class needs it any more.
        self._Q_t = None

    @property
    def Q_t(self):
        """Single-step transition matrices, shape ``[num_steps, K, K]``; built on first use."""
        if self._Q_t is None:
            eye = torch.eye(self.vocab_size)
            ones = torch.ones(self.vocab_size, self.vocab_size)
            self._Q_t = torch.stack(
                [(1 - beta) * eye + (beta / self.vocab_size) * ones for beta in self.betas]
            )
        return self._Q_t

    @Q_t.setter
    def Q_t(self, value):
        self._Q_t = value

    def q_sample(self, x_0, t):
        """Sample ``x_t ~ q(x_t | x_0)`` for a batch of clean sequences.

        The marginal after steps ``0..t`` uses the *product* of the single-step
        matrices. For the uniform kernel that product has the closed form
        ``alpha_bar_t * I + ((1 - alpha_bar_t) / K) * ones(K, K)``: each token is
        kept with probability ``alpha_bar_t`` and otherwise redrawn uniformly from
        the vocabulary. (Indexing the single-step matrix ``Q_t[t]`` instead would
        apply only one step's worth of noise, however large ``t`` is.)

        Args:
            x_0: integer tensor of token ids, indexed as ``[batch, position]``.
            t: integer tensor with one step index in ``[0, num_steps)`` per batch row.

        Returns:
            Long tensor with the same shape and device as ``x_0``.
        """
        device = x_0.device
        t = t.to(device)
        keep_prob = self.alpha_bars.to(device)[t].unsqueeze(1)  # [batch, 1], broadcast over positions
        keep = torch.rand(x_0.shape, device=device) < keep_prob
        uniform_tokens = torch.randint(0, self.vocab_size, x_0.shape, device=device)
        return torch.where(keep, x_0.long(), uniform_tokens)

    def sample(self, model, batch_size, device, context=None, guidance_scale=1.0, seq_len=None):
        """Generate token sequences by iterating the model from step ``num_steps - 1`` down to 0.

        Starts from uniformly random tokens. At every step the sequence is replaced
        by a draw from the softmax of the model's logits. When ``context`` is given,
        the logits are ``uncond + guidance_scale * (cond - uncond)`` and the first
        ``context.shape[1]`` positions are overwritten with the context tokens
        before the loop and after every step.

        Args:
            model: callable ``model(x_t, t, context=...)`` returning logits of shape
                ``[batch_size, seq_len, vocab_size]``.
            batch_size: number of sequences to generate.
            device: device for the generated tensors.
            context: optional tensor of token ids; it is tiled ``batch_size`` times
                along dim 0, so pass a single row of shape ``[1, context_len]``.
            guidance_scale: weight of the conditional-minus-unconditional term.
            seq_len: length of the generated sequences; defaults to the module-level
                ``max_len``.

        Returns:
            Long tensor of shape ``[batch_size, seq_len]``.
        """
        if seq_len is None:
            seq_len = max_len
        x_t = torch.randint(0, self.vocab_size, (batch_size, seq_len), device=device)
        if context is not None:
            # Repeat context to match batch size
            context = context.to(device).repeat(batch_size, 1)  # Shape: [batch_size, context_len]
            x_t[:, : context.shape[1]] = context  # Fix context tokens
        # NOTE(review): each step feeds a sample of the model's prediction straight
        # back in as the next input, without re-noising it to step t - 1 — confirm
        # this is the intended sampler.
        for t in reversed(range(self.num_steps)):
            t_tensor = torch.full((batch_size,), t, device=device, dtype=torch.long)
            logits = model(x_t, t_tensor, context=None)
            if context is not None:
                # Without context the guidance term is zero, so the second forward
                # pass is only run when there is something to condition on.
                logits_cond = model(x_t, t_tensor, context=context)
                logits = logits + guidance_scale * (logits_cond - logits)
            probs = F.softmax(logits, dim=-1)
            x_t = torch.multinomial(probs.reshape(-1, self.vocab_size), 1).reshape(batch_size, seq_len)
            if context is not None:
                x_t[:, : context.shape[1]] = context  # Preserve context
        return x_t


In [None]:
# NanoGPT-like Model with Context
class NanoGPT(nn.Module):
    """Small Transformer-encoder network that predicts token logits per position.

    Args:
        vocab_size: number of token ids; size of the output logits.
        n_embd: embedding / model width.
        n_head: attention heads per layer.
        n_layer: number of ``nn.TransformerEncoderLayer`` blocks.
        max_len: number of learned positions; sequences and contexts must not be longer.
        num_steps: number of diffusion steps the time embedding can represent.
    """

    def __init__(self, vocab_size, n_embd=64, n_head=4, n_layer=2, max_len=max_len, num_steps=50):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = nn.Embedding(max_len, n_embd)
        self.context_embedding = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList(
            [
                # batch_first=True: forward() feeds [batch, seq, feature]. With the
                # default (False) the layer treats dim 0 as the sequence, so
                # attention would mix different samples of the batch instead of the
                # positions of one sequence.
                nn.TransformerEncoderLayer(
                    d_model=n_embd,
                    nhead=n_head,
                    dim_feedforward=n_embd * 4,
                    batch_first=True,
                )
                for _ in range(n_layer)
            ]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size)
        self.max_len = max_len
        self.time_embedding = nn.Embedding(num_steps, n_embd)

    def forward(self, x, t, context=None):
        """Compute logits for every position of ``x``.

        Args:
            x: token ids, shape ``[B, T]``.
            t: step index per batch row, shape ``[B]``.
            context: optional token ids, shape ``[B, context_len]``. They are
                embedded with ``context_embedding`` and prepended to the sequence
                so that every position of ``x`` can attend to them.

        Returns:
            Logits of shape ``[B, T, vocab_size]``, aligned with the positions of ``x``.
        """
        _, T = x.shape
        device = x.device
        tok_emb = self.token_embedding(x)
        pos_emb = self.position_embedding(torch.arange(T, device=device))
        t_emb = self.time_embedding(t).unsqueeze(1)  # [B, 1, n_embd], broadcast over positions
        h = tok_emb + pos_emb + t_emb
        if context is not None:
            ctx_emb = self.context_embedding(context)
            ctx_pos_emb = self.position_embedding(torch.arange(context.shape[1], device=device))
            h = torch.cat([ctx_emb + ctx_pos_emb, h], dim=1)
        for block in self.blocks:
            h = block(h)
        # Drop the context prefix (if any) and keep the T positions that belong to x.
        # Slicing the *first* T positions instead would discard the tail of x and
        # shift every output relative to its target.
        h = h[:, -T:]
        h = self.ln_f(h)
        logits = self.head(h)
        return logits


# Pick the best available device: CUDA first, then Apple MPS, otherwise CPU.
def _pick_device():
    """Return the ``torch.device`` to run on, preferring CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps is missing on older torch builds; treat that as "no MPS".
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = _pick_device()
print(f"Using device: {device}")

In [None]:
# Training
num_steps = 50
model = NanoGPT(vocab_size=vocab_size, num_steps=num_steps).to(device)
diffusion = D3PM(num_steps=num_steps, vocab_size=vocab_size)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 100
dropout_prob = 0.1  # Probability of dropping context for classifier-free guidance
checkpoint_dir = "../models"
# Create the output directory up front so a missing folder cannot abort the run
# at the first checkpoint, after 50 epochs of work.
os.makedirs(checkpoint_dir, exist_ok=True)

model.train()
for epoch in range(num_epochs):
    total_loss = 0.0
    for batch in dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()
        # One random diffusion step per sequence in the batch.
        t = torch.randint(0, num_steps, (batch.shape[0],), device=device, dtype=torch.long)
        x_t = diffusion.q_sample(batch, t)
        # The first clean token is the context; it is dropped for the whole batch
        # with probability dropout_prob.
        context = batch[:, :1] if np.random.rand() > dropout_prob else None
        logits = model(x_t, t, context=context)
        # The model is trained to predict the clean tokens from the noised ones.
        loss = F.cross_entropy(logits.reshape(-1, vocab_size), batch.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch + 1}, Loss: {avg_loss:.4f}")

    if (epoch + 1) % 50 == 0:
        checkpoint = {
            "epoch": epoch + 1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": avg_loss,
            # The weights are only meaningful together with the id assignment
            # they were trained on, so store the vocabulary alongside them.
            "word_to_idx": word_to_idx,
        }
        torch.save(checkpoint, os.path.join(checkpoint_dir, f"d3pm_epoch_{epoch + 1}.pth"))
        print(f"Saved checkpoint: d3pm_epoch_{epoch + 1}.pth")

In [None]:
# Generate samples with context
model.eval()
with torch.no_grad():
    # A single context row; diffusion.sample tiles it across the batch.
    context = torch.tensor([[word_to_idx["the"], word_to_idx["dog"]]], device=device)
    samples = diffusion.sample(model, batch_size=3, device=device, context=context, guidance_scale=2.0)

pad_idx = word_to_idx["<pad>"]
# One tolist() moves all ids to Python at once instead of calling .item() per token.
for sample in samples.tolist():
    # Drop padding ids before joining; stripping "<pad>" from the joined string
    # would leave runs of spaces behind.
    text = [idx_to_word[idx] for idx in sample if idx != pad_idx and idx in idx_to_word]
    print("Generated:", " ".join(text))