In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Seed torch's global RNG so the torch.randn / torch.multinomial draws made
# later in this file are reproducible from run to run.
torch.manual_seed(0)

# ----------------------------
# Dummy denoising model
# ----------------------------
class DummyModel(nn.Module):
    """
    Stand-in denoiser that pretends to predict x0.

    It ignores its inputs' contents and simply outputs random logits so the
    D3PM mechanics can be exercised without a trained network.

    Args:
        vocab_size: Size of the last (class) dimension of the logits.
    """
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size

    def forward(self, x_t, t, domain, seq_boundaries=None, max_seqlen=None):
        """
        Return random logits shaped like ``x_t`` plus a vocabulary axis.

        Args:
            x_t: 2-D tensor; only its shape ``(B, L)`` and device are used.
            t, domain, seq_boundaries, max_seqlen: Accepted for interface
                compatibility and ignored.

        Returns:
            ``{"logits": tensor}`` where the tensor has shape
            ``(B, L, vocab_size)`` and lives on ``x_t``'s device.
        """
        B, L = x_t.shape
        # Allocate on the input's device; otherwise the logits land on the
        # default device and cannot be combined with tensors derived from x_t.
        logits = torch.randn(B, L, self.vocab_size, device=x_t.device)
        return {"logits": logits}


# ----------------------------
# D3PM (slightly instrumented)
# ----------------------------
class D3PM:
    """
    Absorbing-state discrete diffusion (D3PM) with optional debug tracing.

    Transition matrices are row-stochastic: entry ``[j, k]`` is the
    probability of moving from token ``j`` to token ``k``. Every token keeps
    its value with probability alpha and is replaced by ``mask_id`` with
    probability ``1 - alpha``.

    Args:
        vocab_size: Number of token classes ``K``.
        T: Number of diffusion steps; betas are spaced linearly from
            1e-4 to 0.02 over ``T`` steps.
        mask_id: Index of the absorbing (mask) token.
        verbose: When True (the default) every step prints its
            intermediate tensors; pass False to silence the tracing.
    """

    def __init__(self, vocab_size, T=5, mask_id=2, verbose=True):
        self.K = vocab_size
        self.T = T
        self.mask_id = mask_id
        self.verbose = verbose

        betas = torch.linspace(1e-4, 0.02, T)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)

        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = alphas_cumprod

        self._log("Initialized D3PM")
        self._log("betas:", betas)
        self._log("alphas:", alphas)
        self._log("alphas_cumprod:", alphas_cumprod)
        self._log("-" * 50)

    def _log(self, *args):
        """Print ``args`` unless tracing has been disabled via ``verbose``."""
        # getattr keeps instances built without __init__ tracing by default.
        if getattr(self, "verbose", True):
            print(*args)

    def _get_Qt_bar(self, t, device):
        """
        Build the cumulative transition matrices Q̄_t for a batch of steps.

        Args:
            t: 1-D long tensor of timesteps in ``[0, T]``. ``t == 0`` yields
                the identity matrix (no corruption).
            device: Device on which the matrices are created.

        Returns:
            Tensor of shape ``(len(t), K, K)``.

        Raises:
            IndexError: If any timestep lies outside ``[0, T]``.
        """
        t = t.to(device)
        if bool(((t < 0) | (t > self.T)).any()):
            raise IndexError(
                f"timestep out of range [0, {self.T}]: {t.tolist()}"
            )

        schedule = self.alphas_cumprod.to(device)
        # Clamp so t == 0 does not wrap around to the last schedule entry;
        # those positions are overwritten with 1.0 just below.
        picked = schedule[(t - 1).clamp(min=0)]
        alpha_bar = torch.where(t > 0, picked, torch.ones_like(picked))

        # Keep-probability on the diagonal, remaining mass on the mask column.
        identity = torch.eye(self.K, device=device).unsqueeze(0)
        Qt_bar = identity * alpha_bar.view(-1, 1, 1)
        Qt_bar[:, :, self.mask_id] += (1 - alpha_bar).view(-1, 1)

        self._log(f"Qt_bar (t={t.tolist()}):\n", Qt_bar)
        return Qt_bar

    def q_sample(self, x0, t):
        """
        Sample ``x_t ~ q(x_t | x0)`` (forward diffusion).

        Args:
            x0: Long tensor of token ids with shape ``(B, L)``.
            t: 1-D long tensor with one timestep per batch row.

        Returns:
            Long tensor of shape ``(B, L)`` with the corrupted tokens.
        """
        B, L = x0.shape
        device = x0.device

        self._log("\n--- FORWARD DIFFUSION q_sample ---")
        self._log("x0:", x0)
        self._log("t:", t)

        Qt_bar = self._get_Qt_bar(t, device)
        x0_onehot = F.one_hot(x0, self.K).float()

        # Select, for each position, the row of Q̄_t belonging to its x0 token.
        qt = torch.einsum('blk,bkj->blj', x0_onehot, Qt_bar)
        self._log("q(x_t | x0):", qt)

        # reshape (not view): einsum output is not guaranteed contiguous.
        x_t = torch.multinomial(qt.reshape(-1, self.K), 1).view(B, L)
        self._log("Sampled x_t:", x_t)
        return x_t

    def _get_Qt(self, t, device):
        """
        Build the single-step transition matrix Q_t.

        Args:
            t: Scalar timestep (1-based index into the alpha schedule).
            device: Device on which the matrix is created.

        Returns:
            Tensor of shape ``(1, K, K)``.
        """
        alpha = self.alphas.to(device)[t - 1]
        Qt = torch.eye(self.K, device=device) * alpha
        Qt[:, self.mask_id] += (1 - alpha)
        return Qt.unsqueeze(0)

    def p_sample(self, model, x_t, t, domain, seq_boundaries, max_seqlen):
        """
        Take one reverse-diffusion step, sampling ``x_{t-1}`` from ``x_t``.

        Args:
            model: Callable returning ``{"logits": (B, L, K) tensor}`` when
                invoked as ``model(x_t, t_batch, domain, seq_boundaries=...,
                max_seqlen=...)``.
            x_t: Long tensor of current tokens with shape ``(B, L)``.
            t: Integer timestep shared by the whole batch.
            domain: Passed through to ``model`` unchanged.
            seq_boundaries: Passed through to ``model`` unchanged.
            max_seqlen: Passed through to ``model`` unchanged.

        Returns:
            Long tensor of shape ``(B, L)``. For ``t <= 1`` it is drawn
            directly from the model's predicted x0 distribution; otherwise
            from the normalized product of ``q(x_t | x_{t-1})`` and
            ``sum_k p(x0=k | x_t) * Q̄_{t-1}[k, :]``.
        """
        device = x_t.device
        B, L = x_t.shape

        self._log(f"\n--- REVERSE STEP p_sample (t={t}) ---")
        self._log("x_t:", x_t)

        t_batch = torch.full((B,), t, device=device, dtype=torch.long)

        with torch.no_grad():
            # Forward the sequence-layout arguments instead of dropping them.
            output = model(
                x_t,
                t_batch,
                domain,
                seq_boundaries=seq_boundaries,
                max_seqlen=max_seqlen,
            )
            logits = output["logits"]
            p_x0 = F.softmax(logits, dim=-1)

        self._log("Predicted p(x0 | x_t):", p_x0)

        if t <= 1:
            x_tm1 = torch.multinomial(p_x0.reshape(-1, self.K), 1).view(B, L)
            self._log("Final x0 sample:", x_tm1)
            return x_tm1

        Qt = self._get_Qt(t, device)[0]
        Qt_1_bar = self._get_Qt_bar(torch.tensor([t - 1], device=device), device)[0]

        # Column x_t of Q_t: probability of reaching x_t from every x_{t-1}.
        q_xt_given_xt_1 = Qt[:, x_t].permute(1, 2, 0)
        # Predicted x0 distribution pushed forward t-1 steps.
        q_xt_1_given_x0 = torch.einsum('blk,kj->blj', p_x0, Qt_1_bar)

        posterior = q_xt_given_xt_1 * q_xt_1_given_x0
        posterior = posterior / (posterior.sum(dim=-1, keepdim=True) + 1e-8)

        self._log("Posterior p(x_{t-1} | x_t):", posterior)

        # reshape (not view): the product involves a permuted operand, so its
        # memory layout is not guaranteed contiguous.
        x_tm1 = torch.multinomial(posterior.reshape(-1, self.K), 1).view(B, L)
        self._log("Sampled x_{t-1}:", x_tm1)

        return x_tm1


# ----------------------------
# MAIN DEMO
# ----------------------------
def main():
    """
    Run the D3PM demo end to end.

    Corrupts a fixed toy sequence to the final timestep ``T`` with
    ``q_sample``, then walks ``p_sample`` from ``T`` down to 1 and prints the
    original and the resulting sequence. The denoiser is ``DummyModel``,
    which emits random logits, so this only exercises the mechanics.
    """
    vocab_size = 5     # tokens: {0,1,2,3,4}
    mask_id = 2
    T = 95

    d3pm = D3PM(vocab_size=vocab_size, T=T, mask_id=mask_id)
    model = DummyModel(vocab_size)

    # Fake data; batch size and sequence length are read from x0 so they
    # cannot drift out of sync with the literal below.
    x0 = torch.tensor([[0, 1, 3, 4]])
    batch_size, seq_len = x0.shape
    domain = torch.zeros_like(x0)

    # Corrupt all the way to the last timestep (one entry per batch row).
    t = torch.full((batch_size,), T, dtype=torch.long)

    # Forward diffusion
    x_t = d3pm.q_sample(x0, t)

    # Reverse diffusion loop: T, T-1, ..., 1
    for step in range(T, 0, -1):
        x_t = d3pm.p_sample(
            model=model,
            x_t=x_t,
            t=step,
            domain=domain,
            seq_boundaries=None,
            max_seqlen=seq_len,
        )

    print("\n=== FINAL RESULT ===")
    print("Original x0:", x0)
    print("Reconstructed x0_hat:", x_t)

# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()


Initialized D3PM
betas: tensor([1.0000e-04, 3.1170e-04, 5.2340e-04, 7.3511e-04, 9.4681e-04, 1.1585e-03,
        1.3702e-03, 1.5819e-03, 1.7936e-03, 2.0053e-03, 2.2170e-03, 2.4287e-03,
        2.6404e-03, 2.8521e-03, 3.0638e-03, 3.2755e-03, 3.4872e-03, 3.6989e-03,
        3.9106e-03, 4.1223e-03, 4.3340e-03, 4.5457e-03, 4.7574e-03, 4.9691e-03,
        5.1809e-03, 5.3926e-03, 5.6043e-03, 5.8160e-03, 6.0277e-03, 6.2394e-03,
        6.4511e-03, 6.6628e-03, 6.8745e-03, 7.0862e-03, 7.2979e-03, 7.5096e-03,
        7.7213e-03, 7.9330e-03, 8.1447e-03, 8.3564e-03, 8.5681e-03, 8.7798e-03,
        8.9915e-03, 9.2032e-03, 9.4149e-03, 9.6266e-03, 9.8383e-03, 1.0050e-02,
        1.0262e-02, 1.0473e-02, 1.0685e-02, 1.0897e-02, 1.1109e-02, 1.1320e-02,
        1.1532e-02, 1.1744e-02, 1.1955e-02, 1.2167e-02, 1.2379e-02, 1.2590e-02,
        1.2802e-02, 1.3014e-02, 1.3226e-02, 1.3437e-02, 1.3649e-02, 1.3861e-02,
        1.4072e-02, 1.4284e-02, 1.4496e-02, 1.4707e-02, 1.4919e-02, 1.5131e-02,
        1.5343e-