In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional, Union, Tuple, Dict
from enum import Enum
import math


class TransitionType(Enum):
    UNIFORM = "uniform"
    ABSORBING = "absorbing"
    GAUSSIAN = "gaussian"
    EMBEDDING = "embedding"


class NoiseSchedule(Enum):
    LINEAR = "linear"
    COSINE = "cosine"
    MUTUAL_INFO = "mutual_info"


class D3PM(nn.Module):
    """
    Discrete Denoising Diffusion Probabilistic Model (D3PM)

    Based on "Structured Denoising Diffusion Models in Discrete State-Spaces"
    by Austin et al. (2021)
    """

    def __init__(
        self,
        num_classes: int,
        timesteps: int = 1000,
        transition_type: TransitionType = TransitionType.UNIFORM,
        noise_schedule: NoiseSchedule = NoiseSchedule.COSINE,
        hybrid_loss_coeff: float = 0.001,
        mask_token_id: Optional[int] = None,
        embedding_dim: Optional[int] = None,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        super().__init__()

        self.num_classes = num_classes
        self.timesteps = timesteps
        self.transition_type = transition_type
        self.noise_schedule = noise_schedule
        self.hybrid_loss_coeff = hybrid_loss_coeff
        self.device = device

        # Set mask token (for absorbing diffusion)
        if mask_token_id is None:
            self.mask_token_id = num_classes - 1  # Default to last token
        else:
            self.mask_token_id = mask_token_id

        # Initialize transition matrices and noise schedule
        self._setup_noise_schedule()
        self._setup_transition_matrices()

        # For embedding-based transitions
        if transition_type == TransitionType.EMBEDDING:
            if embedding_dim is None:
                raise ValueError("embedding_dim must be specified for embedding transition type")
            self.token_embeddings = nn.Embedding(num_classes, embedding_dim)

    def _setup_noise_schedule(self):
        """Setup the noise schedule βt for each timestep"""
        if self.noise_schedule == NoiseSchedule.LINEAR:
            self.betas = torch.linspace(1e-4, 0.02, self.timesteps, device=self.device)
        elif self.noise_schedule == NoiseSchedule.COSINE:
            self.betas = self._cosine_schedule()
        elif self.noise_schedule == NoiseSchedule.MUTUAL_INFO:
            # Simplified mutual information schedule
            self.betas = 1.0 / torch.arange(self.timesteps, 0, -1, device=self.device, dtype=torch.float32)
        else:
            raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")

    def _cosine_schedule(self, s: float = 0.008) -> torch.Tensor:
        """Cosine noise schedule from Improved DDPM"""
        steps = self.timesteps + 1
        x = torch.linspace(0, self.timesteps, steps, device=self.device)
        alphas_cumprod = torch.cos(((x / self.timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0, 0.999)

    def _setup_transition_matrices(self):
        """Setup transition matrices Qt for each timestep"""
        self.Qt = []
        self.Qt_bar = []  # Cumulative products

        for t in range(self.timesteps):
            if self.transition_type == TransitionType.UNIFORM:
                Qt = self._uniform_transition_matrix(self.betas[t])
            elif self.transition_type == TransitionType.ABSORBING:
                Qt = self._absorbing_transition_matrix(self.betas[t])
            elif self.transition_type == TransitionType.GAUSSIAN:
                Qt = self._gaussian_transition_matrix(self.betas[t])
            else:
                # For embedding-based, we'll compute dynamically
                Qt = torch.eye(self.num_classes, device=self.device)

            self.Qt.append(Qt)

            # Compute cumulative product
            if t == 0:
                Qt_bar = Qt
            else:
                Qt_bar = torch.matmul(self.Qt_bar[-1], Qt)
            self.Qt_bar.append(Qt_bar)

    def _uniform_transition_matrix(self, beta_t: float) -> torch.Tensor:
        """Uniform transition matrix: Qt = (1-βt)I + βt/K * 11^T"""
        K = self.num_classes
        Qt = (1 - beta_t) * torch.eye(K, device=self.device)
        Qt += beta_t / K * torch.ones(K, K, device=self.device)
        return Qt

    def _absorbing_transition_matrix(self, beta_t: float) -> torch.Tensor:
        """Absorbing state transition matrix"""
        K = self.num_classes
        Qt = (1 - beta_t) * torch.eye(K, device=self.device)

        # Transitions to mask token
        Qt[:, self.mask_token_id] += beta_t
        Qt[self.mask_token_id, self.mask_token_id] = 1.0  # Absorbing state

        return Qt

    def _gaussian_transition_matrix(self, beta_t: float) -> torch.Tensor:
        """Discretized Gaussian transition matrix"""
        K = self.num_classes
        Qt = torch.zeros(K, K, device=self.device)

        for i in range(K):
            for j in range(K):
                if i != j:
                    # Gaussian-like transitions based on distance
                    distance = abs(i - j)
                    prob = torch.exp(-4 * distance**2 / ((K - 1) ** 2 * beta_t))
                    Qt[i, j] = prob

        # Normalize rows and set diagonal
        row_sums = Qt.sum(dim=1, keepdim=True)
        Qt = Qt / (row_sums + 1e-8)  # Avoid division by zero

        # Set diagonal to ensure row sums = 1
        diag_vals = 1.0 - Qt.sum(dim=1) + torch.diag(Qt)
        Qt.fill_diagonal_(0)
        Qt += torch.diag(diag_vals)

        return Qt

    def _embedding_transition_matrix(self, beta_t: float, embeddings: torch.Tensor) -> torch.Tensor:
        """Transition matrix based on embedding similarity"""
        # Compute pairwise similarities
        similarities = torch.matmul(embeddings, embeddings.T)
        similarities = F.softmax(similarities / 0.1, dim=-1)  # Temperature scaling

        # Create transition matrix
        Qt = (1 - beta_t) * torch.eye(self.num_classes, device=self.device)
        Qt += beta_t * similarities

        return Qt

    def q_sample(self, x_start: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Sample from q(xt | x0) using the cumulative transition matrix

        Args:
            x_start: Initial data [batch_size, ...]
            t: Timestep [batch_size]

        Returns:
            xt: Noisy data at timestep t
        """
        batch_size = x_start.shape[0]

        # Get cumulative transition matrices for each batch element
        Qt_bar_batch = torch.stack([self.Qt_bar[t_i] for t_i in t])

        # Convert to one-hot if needed
        if x_start.dtype == torch.long:
            x_start_onehot = F.one_hot(x_start, self.num_classes).float()
        else:
            x_start_onehot = x_start

        # Apply transition: p = x_start @ Qt_bar
        original_shape = x_start_onehot.shape
        x_flat = x_start_onehot.view(batch_size, -1, self.num_classes)

        # Batch matrix multiplication
        probs = torch.bmm(
            x_flat,
            Qt_bar_batch.unsqueeze(1)
            .expand(-1, x_flat.shape[1], -1, -1)
            .view(batch_size * x_flat.shape[1], self.num_classes, self.num_classes),
        )
        probs = probs.view(batch_size, x_flat.shape[1], self.num_classes)

        # Sample from categorical distribution
        xt = torch.multinomial(probs.view(-1, self.num_classes), 1).view(batch_size, -1)

        return xt.view(x_start.shape)

    def q_posterior(self, x_start: torch.Tensor, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Compute q(xt-1 | xt, x0)

        Returns:
            Log probabilities of q(xt-1 | xt, x0)
        """
        batch_size = x_start.shape[0]

        # Get transition matrices
        if t[0] == 0:
            return F.one_hot(x_start, self.num_classes).float().log()

        Qt = torch.stack([self.Qt[t_i] for t_i in t])
        Qt_bar_prev = torch.stack([self.Qt_bar[t_i - 1] for t_i in t])

        # Convert to one-hot
        x_start_onehot = F.one_hot(x_start, self.num_classes).float()
        xt_onehot = F.one_hot(xt, self.num_classes).float()

        # Compute posterior: q(xt-1|xt,x0) ∝ q(xt|xt-1) * q(xt-1|x0)
        # This is equation (3) from the paper

        # For numerical stability, work in log space
        log_Qt = torch.log(Qt + 1e-8)
        log_Qt_bar_prev = torch.log(Qt_bar_prev + 1e-8)

        # Compute unnormalized log probabilities
        log_probs = []
        for k in range(self.num_classes):
            k_onehot = torch.zeros_like(x_start_onehot)
            k_onehot[..., k] = 1.0

            # q(xt|xt-1=k) * q(xt-1=k|x0)
            term1 = torch.sum(xt_onehot * torch.matmul(k_onehot, Qt), dim=-1)
            term2 = torch.sum(k_onehot * torch.matmul(x_start_onehot, Qt_bar_prev), dim=-1)

            log_prob_k = torch.log(term1 + 1e-8) + torch.log(term2 + 1e-8)
            log_probs.append(log_prob_k.unsqueeze(-1))

        log_probs = torch.cat(log_probs, dim=-1)

        # Normalize
        log_probs = log_probs - torch.logsumexp(log_probs, dim=-1, keepdim=True)

        return log_probs

    def compute_loss(
        self, model: nn.Module, x_start: torch.Tensor, t: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Compute the D3PM loss

        Args:
            model: Neural network that predicts x0 given xt and t
            x_start: Clean data [batch_size, ...]
            t: Timesteps [batch_size]. If None, sample randomly

        Returns:
            Dictionary with loss components
        """
        batch_size = x_start.shape[0]
        device = x_start.device

        if t is None:
            t = torch.randint(0, self.timesteps, (batch_size,), device=device)

        # Sample xt ~ q(xt | x0)
        xt = self.q_sample(x_start, t)

        # Predict x0
        x0_pred_logits = model(xt, t)

        # Convert predictions to probabilities
        x0_pred_probs = F.softmax(x0_pred_logits, dim=-1)

        # Compute reverse process probabilities using x0 parameterization
        # p_theta(xt-1 | xt) = sum_x0 q(xt-1 | xt, x0) * p_theta(x0 | xt)

        # VLB loss components
        losses = {}

        # L0: Direct reconstruction loss at t=1
        mask_t1 = (t == 1).float()
        if mask_t1.sum() > 0:
            x_start_onehot = F.one_hot(x_start, self.num_classes).float()
            l0 = -torch.sum(x_start_onehot * torch.log(x0_pred_probs + 1e-8), dim=-1)
            losses["l0"] = (mask_t1 * l0.mean(dim=tuple(range(1, len(l0.shape))))).sum() / (mask_t1.sum() + 1e-8)
        else:
            losses["l0"] = torch.tensor(0.0, device=device)

        # Lt-1: KL divergence terms for t > 1
        mask_t_gt1 = (t > 1).float()
        if mask_t_gt1.sum() > 0:
            # Compute q(xt-1 | xt, x0)
            q_posterior_probs = torch.exp(self.q_posterior(x_start, xt, t))

            # Compute p_theta(xt-1 | xt) using x0 parameterization
            p_theta_probs = self._compute_reverse_probs(xt, x0_pred_probs, t)

            # KL divergence
            kl = torch.sum(q_posterior_probs * torch.log(q_posterior_probs / (p_theta_probs + 1e-8) + 1e-8), dim=-1)
            losses["lt"] = (mask_t_gt1 * kl.mean(dim=tuple(range(1, len(kl.shape))))).sum() / (mask_t_gt1.sum() + 1e-8)
        else:
            losses["lt"] = torch.tensor(0.0, device=device)

        # LT: Prior matching (should be small if T is large enough)
        mask_tT = (t == self.timesteps - 1).float()
        if mask_tT.sum() > 0:
            # Assuming uniform prior
            uniform_prior = torch.ones_like(x0_pred_probs) / self.num_classes
            xt_onehot = F.one_hot(xt, self.num_classes).float()
            prior_kl = torch.sum(xt_onehot * torch.log(xt_onehot / uniform_prior + 1e-8), dim=-1)
            losses["lT"] = (mask_tT * prior_kl.mean(dim=tuple(range(1, len(prior_kl.shape))))).sum() / (
                mask_tT.sum() + 1e-8
            )
        else:
            losses["lT"] = torch.tensor(0.0, device=device)

        # Auxiliary loss: direct x0 prediction
        x_start_onehot = F.one_hot(x_start, self.num_classes).float()
        aux_loss = -torch.sum(x_start_onehot * torch.log(x0_pred_probs + 1e-8), dim=-1)
        losses["aux"] = aux_loss.mean()

        # Total VLB loss
        losses["vlb"] = losses["l0"] + losses["lt"] + losses["lT"]

        # Hybrid loss (L_lambda from paper)
        losses["total"] = losses["vlb"] + self.hybrid_loss_coeff * losses["aux"]

        return losses

    def _compute_reverse_probs(self, xt: torch.Tensor, x0_pred_probs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Compute p_theta(xt-1 | xt) using x0 parameterization
        """
        batch_size = xt.shape[0]

        # Sum over all possible x0 values
        reverse_probs = torch.zeros(xt.shape + (self.num_classes,), device=xt.device)

        for x0_val in range(self.num_classes):
            # Create x0 with value x0_val everywhere
            x0_candidate = torch.full_like(xt, x0_val)

            # Get q(xt-1 | xt, x0_candidate)
            if t[0] > 0:
                q_prob = torch.exp(self.q_posterior(x0_candidate, xt, t))

                # Weight by p_theta(x0_candidate | xt)
                x0_prob = x0_pred_probs[..., x0_val : x0_val + 1]

                reverse_probs += q_prob * x0_prob

        return reverse_probs

    @torch.no_grad()
    def p_sample(self, model: nn.Module, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Sample from p_theta(xt-1 | xt)
        """
        # Predict x0
        x0_pred_logits = model(xt, t)
        x0_pred_probs = F.softmax(x0_pred_logits, dim=-1)

        if t[0] == 0:
            # At t=0, return the predicted x0
            return torch.multinomial(x0_pred_probs.view(-1, self.num_classes), 1).view(xt.shape)

        # Compute reverse probabilities
        reverse_probs = self._compute_reverse_probs(xt, x0_pred_probs, t)

        # Sample from the reverse distribution
        xt_prev = torch.multinomial(reverse_probs.view(-1, self.num_classes), 1).view(xt.shape)

        return xt_prev

    @torch.no_grad()
    def p_sample_loop(self, model: nn.Module, shape: Tuple[int, ...], device: str = None) -> torch.Tensor:
        """
        Generate samples by running the reverse diffusion process
        """
        if device is None:
            device = self.device

        batch_size = shape[0]

        # Start from noise (uniform random or mask tokens)
        if self.transition_type == TransitionType.ABSORBING:
            xt = torch.full(shape, self.mask_token_id, device=device, dtype=torch.long)
        else:
            xt = torch.randint(0, self.num_classes, shape, device=device, dtype=torch.long)

        # Reverse diffusion
        for i in reversed(range(self.timesteps)):
            t = torch.full((batch_size,), i, device=device, dtype=torch.long)
            xt = self.p_sample(model, xt, t)

        return xt

    @torch.no_grad()
    def ddim_sample(
        self, model: nn.Module, shape: Tuple[int, ...], eta: float = 0.0, ddim_steps: int = 50, device: str = None
    ) -> torch.Tensor:
        """
        DDIM-style deterministic sampling (adapted for discrete case)
        """
        if device is None:
            device = self.device

        batch_size = shape[0]

        # Create subsequence of timesteps
        step_size = self.timesteps // ddim_steps
        timesteps = list(range(0, self.timesteps, step_size))

        # Start from noise
        if self.transition_type == TransitionType.ABSORBING:
            xt = torch.full(shape, self.mask_token_id, device=device, dtype=torch.long)
        else:
            xt = torch.randint(0, self.num_classes, shape, device=device, dtype=torch.long)

        # Reverse diffusion with larger steps
        for i, t_cur in enumerate(reversed(timesteps)):
            t = torch.full((batch_size,), t_cur, device=device, dtype=torch.long)

            # Predict x0
            x0_pred_logits = model(xt, t)
            x0_pred_probs = F.softmax(x0_pred_logits, dim=-1)

            if i == len(timesteps) - 1:
                # Last step: return predicted x0
                xt = torch.multinomial(x0_pred_probs.view(-1, self.num_classes), 1).view(xt.shape)
            else:
                # Intermediate step: use DDIM-like update
                t_prev = timesteps[len(timesteps) - i - 2] if i < len(timesteps) - 1 else 0

                # For simplicity, use probabilistic sampling
                # A more sophisticated DDIM adaptation could be implemented
                reverse_probs = self._compute_reverse_probs(xt, x0_pred_probs, t)
                xt = torch.multinomial(reverse_probs.view(-1, self.num_classes), 1).view(xt.shape)

        return xt


# Example usage and testing
if __name__ == "__main__":
    # Create a simple model for testing
    class SimpleModel(nn.Module):
        def __init__(self, num_classes: int, hidden_dim: int = 256):
            super().__init__()
            self.time_embed = nn.Embedding(1000, hidden_dim)
            self.input_embed = nn.Embedding(num_classes, hidden_dim)
            self.net = nn.Sequential(
                nn.Linear(hidden_dim * 2, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, num_classes),
            )

        def forward(self, x, t):
            # x: [batch_size, seq_len], t: [batch_size]
            x_embed = self.input_embed(x)  # [batch_size, seq_len, hidden_dim]
            t_embed = self.time_embed(t).unsqueeze(1)  # [batch_size, 1, hidden_dim]

            # Broadcast time embedding
            t_embed = t_embed.expand(-1, x.shape[1], -1)

            # Concatenate and predict
            combined = torch.cat([x_embed, t_embed], dim=-1)
            logits = self.net(combined)

            return logits

    # Test the D3PM implementation
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Initialize D3PM with absorbing diffusion
    d3pm = D3PM(
        num_classes=1000,
        timesteps=100,
        transition_type=TransitionType.ABSORBING,
        noise_schedule=NoiseSchedule.COSINE,
        device=device,
    ).to(device)

    # Create a simple model
    model = SimpleModel(num_classes=1000).to(device)

    # Test data
    batch_size, seq_len = 4, 32
    x_start = torch.randint(0, 1000, (batch_size, seq_len), device=device)

    # Test forward process
    t = torch.randint(0, 100, (batch_size,), device=device)
    xt = d3pm.q_sample(x_start, t)
    print(f"Original shape: {x_start.shape}, Noisy shape: {xt.shape}")

    # Test loss computation
    losses = d3pm.compute_loss(model, x_start, t)
    print(f"Losses: {[(k, v.item()) for k, v in losses.items()]}")

    # Test sampling
    samples = d3pm.p_sample_loop(model, (2, seq_len), device=device)
    print(f"Generated samples shape: {samples.shape}")

    # Test DDIM sampling
    ddim_samples = d3pm.ddim_sample(model, (2, seq_len), ddim_steps=10, device=device)
    print(f"DDIM samples shape: {ddim_samples.shape}")

    print("D3PM implementation test completed successfully!")

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.