In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision.utils import make_grid


# Simple CNN denoiser
class SimpleDenoiser(nn.Module):
    """
    Small convolutional encoder/decoder conditioned on a per-sample time value.

    The time value is broadcast into a second input channel next to the image.
    Two stride-2 convolutions halve the spatial size twice, a fully connected
    bottleneck (sized for 256 x 7 x 7 feature maps) mixes the flattened
    features, and two transposed-convolution stages restore the resolution.
    Encoder activations are added back in as skip connections on the way up.
    """

    def __init__(self):
        super().__init__()

        def same_size(layer_cls, c_in, c_out):
            # 3x3, stride 1, padding 1: changes channels, keeps height/width.
            return layer_cls(c_in, c_out, kernel_size=3, stride=1, padding=1)

        def rescale(layer_cls, channels):
            # 4x4, stride 2, padding 1: halves (Conv2d) or doubles
            # (ConvTranspose2d) height/width at a fixed channel count.
            return layer_cls(channels, channels, kernel_size=4, stride=2, padding=1)

        flat_features = 256 * 7 * 7

        # Input has 2 channels: the image plus the broadcast time plane.
        self.in_layer = nn.Sequential(
            same_size(nn.Conv2d, 2, 32), nn.ReLU(),
            same_size(nn.Conv2d, 32, 64), nn.ReLU(),
        )
        self.down1 = nn.Sequential(
            same_size(nn.Conv2d, 64, 128), nn.ReLU(),
            rescale(nn.Conv2d, 128), nn.ReLU(),
        )
        self.down2 = nn.Sequential(
            same_size(nn.Conv2d, 128, 256), nn.ReLU(),
            rescale(nn.Conv2d, 256), nn.ReLU(),
        )
        self.bottleneck = nn.Sequential(
            nn.Linear(flat_features, 1024), nn.ReLU(),
            nn.Linear(1024, flat_features), nn.ReLU(),
        )
        self.up1 = nn.Sequential(
            same_size(nn.ConvTranspose2d, 256, 128), nn.ReLU(),
            rescale(nn.ConvTranspose2d, 128), nn.ReLU(),
        )
        self.up2 = nn.Sequential(
            same_size(nn.ConvTranspose2d, 128, 64), nn.ReLU(),
            rescale(nn.ConvTranspose2d, 64), nn.ReLU(),
        )
        self.out_layer = nn.Sequential(
            same_size(nn.Conv2d, 64, 1),
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Run the network on a batch of images and their time values.

        Args:
            x: Image batch; indexed as (batch, channel, height, width).
            t: One time value per sample (any shape that flattens to the batch size).

        Returns:
            Single-channel tensor produced by the final convolution.
        """
        height, width = x.shape[2], x.shape[3]

        # Turn each scalar time value into a constant plane and stack it onto x.
        time_plane = t.view(-1, 1, 1, 1).expand(-1, -1, height, width)
        stacked = torch.cat([x, time_plane], dim=1)

        enc_full = self.in_layer(stacked)
        enc_half = self.down1(enc_full)
        enc_quarter = self.down2(enc_half)

        # Flatten for the fully connected bottleneck, then restore the feature map.
        latent = self.bottleneck(enc_quarter.view(enc_quarter.size(0), -1))
        latent = latent.view(-1, 256, 7, 7)

        dec_half = self.up1(latent)
        dec_full = self.up2(dec_half + enc_half)
        return self.out_layer(dec_full + enc_full)


In [2]:
from abc import ABC, abstractmethod
from typing import Generic, TypeVar, Generator
import torch
import torch.nn as nn


def is_notimplemented(func):
    """Return True if `func` carries the `__notimplemented__` marker attribute (whatever its value)."""
    try:
        func.__notimplemented__
    except AttributeError:
        return False
    return True


def notimplemented(func):
    """Decorator that tags `func` with `__notimplemented__ = True` and returns the same object."""
    setattr(func, "__notimplemented__", True)
    return func


# Type variable (bounded by `dict`) used to annotate the `**kwargs` accepted by the
# backward / sampling methods of `DiffusionModel`.
BackwardParams = TypeVar("BackwardParams", bound=dict)


class DiffusionModel(nn.Module, ABC, Generic[BackwardParams]):
    """
    An abstract class for diffusion models.
    Defines the core interface that all diffusion models should implement.

    Timestep convention: `x_0` is clean data and `x_{num_steps}` is the fully noised
    state. Unless a method states otherwise, its `t` argument ranges from 1 to
    num_steps. The exception is `forward_one_step`, which `forward_process` calls
    with the index of its *input* state (0 to num_steps - 1).
    """

    def __init__(self, num_steps: int):
        super().__init__()
        self.num_steps = num_steps
        # Device used for tensors this class creates itself (timesteps, initial noise).
        self.device = torch.device("cpu")

    def to(self, device: torch.device):
        """
        Move the module to `device` and record it in `self.device`.

        Only this single-argument form is tracked.
        NOTE(review): other ways of moving the module do not pass through this
        override, so `self.device` may go stale if they are used — confirm callers.
        """
        self.device = device
        return super().to(device)

    def sample_t(self, batch_size: int) -> torch.Tensor:
        """
        Sample timesteps for training. Uniformly samples integers from
        [1, num_steps] (both ends inclusive) by default.
        Override this method to sample from a different distribution.

        Args:
            batch_size: Number of timesteps to sample

        Returns:
            Long tensor of shape (batch_size,) on `self.device`
        """
        return torch.randint(1, self.num_steps + 1, (batch_size,), device=self.device, dtype=torch.long)

    def get_batch_t(self, t: int, x: torch.Tensor) -> torch.Tensor:
        """
        Get the batch of timesteps with shape (x.shape[0], 1, 1, ..., 1) so x and the returned t can be broadcasted together.
        """
        return torch.full((x.shape[0],) + (1,) * (len(x.shape) - 1), t, device=self.device, dtype=torch.long)

    @abstractmethod
    def sample_x_T(self, shape: tuple[int, ...]) -> torch.Tensor:
        """
        Sample from p(x_T)
        """
        pass

    @abstractmethod
    def forward_one_step(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Forward one step of the diffusion process: given the state `x_t` at
        timestep `t`, return the next (noisier) state.

        `forward_process` calls this with `t` built by `get_batch_t`, taking the
        values 0, 1, ..., num_steps - 1.
        """
        pass

    @abstractmethod
    def forward_from_x0(self, x_0: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Sample from q(x_t | x_0) - the forward diffusion process.

        Args:
            x_0: Initial data
            t: Timesteps

        Returns:
            The noised data x_t
        """
        pass

    def forward_process(self, x_0: torch.Tensor):
        """
        A generator that yields (t, x_t) with t from 0 to num_steps.

        The chain starts at `x_0` and is advanced one step at a time with
        `forward_one_step`, so each yielded state is derived from the previous one.
        """
        x_t = x_0
        for t in range(self.num_steps):
            yield t, x_t
            # Advance the chain by a single step. (Re-applying forward_from_x0 to an
            # already-noised state would not follow the one-step transition.)
            x_t = self.forward_one_step(x_t, self.get_batch_t(t, x_t))

        yield self.num_steps, x_t

    @abstractmethod
    def backward_one_step(self, x_t: torch.Tensor, t: int, **kwargs: BackwardParams) -> tuple[torch.Tensor, dict]:
        """
        Sample from p(x_{t-1} | x_t).

        Returns:
            (x_{t-1}, info) where info is a dictionary of intermediate values.
        """
        pass

    @notimplemented
    def backward(self, x_t: torch.Tensor, t2: int, t1: int, **kwargs: BackwardParams) -> tuple[torch.Tensor, dict]:
        """
        Sample from  p(x_{t1} | x_{t2}).

        Optionally implement this method for a more efficient backward pass across multiple timesteps.
        An override that is not decorated with `@notimplemented` is picked up by `_backward`.
        """
        raise NotImplementedError("Backward is not implemented")

    def _backward(self, x_t: torch.Tensor, t2: int, t1: int, **kwargs: BackwardParams) -> tuple[torch.Tensor, dict]:
        """
        Sample from p(x_{t1} | x_{t2}). If backward is not implemented, this method falls back
        to multiple calls to backward_one_step.

        When `t1 >= t2` the fallback performs no steps and returns `x_t` unchanged
        together with an empty info dictionary.
        """
        if not is_notimplemented(self.backward):
            return self.backward(x_t, t2, t1, **kwargs)

        # Pre-bind info so an empty step range cannot leave it undefined.
        info = {}
        for t in range(t2, t1, -1):
            x_t, info = self.backward_one_step(x_t, t, **kwargs)
        return x_t, info

    def backward_process(
        self, x_t: torch.Tensor, **kwargs: BackwardParams
    ) -> Generator[tuple[int, torch.Tensor, dict], None, None]:
        """
        A generator that yields (t, x_t, info) with t from num_steps down to 0.

        `info` is the dictionary returned by the `backward_one_step` call that
        produced the yielded state (empty for the first yield).
        """
        info = {}
        for t in range(self.num_steps, 0, -1):
            yield t, x_t, info
            x_t, info = self.backward_one_step(x_t, t, **kwargs)
        yield 0, x_t, info

    def sample(
        self, shape: tuple[int, ...], steps_to_return: int | list[int] = 0, **kwargs: BackwardParams
    ) -> torch.Tensor:
        """
        - Sample from p(x_0) if steps_to_return is 0 (default). Return shape is the same as the shape argument.
        - Sample from p(x_t) if steps_to_return is specified to t. Return shape is the same as the shape argument.
        - Return intermediate steps if steps_to_return is a list of timesteps. Return shape is (len(steps_to_return), *shape).

        Args:
            shape: Shape of samples to generate
            steps_to_return: Timestep(s) to return. A list must be non-increasing
                because a single reverse trajectory is walked from num_steps downward.
            **kwargs: Additional sampling arguments

        Returns:
            Generated samples.

        Raises:
            ValueError: If a requested step lies outside [0, num_steps] or the
                steps are not in non-increasing order.
        """
        return_one_step = isinstance(steps_to_return, int)
        steps = [steps_to_return] if return_one_step else list(steps_to_return)

        # Validate before doing any sampling work: the reverse chain can only move
        # towards t = 0, so every requested step must be <= the previous one.
        upper = self.num_steps
        for t in steps:
            if not 0 <= t <= upper:
                raise ValueError(
                    f"steps_to_return must be non-increasing values in [0, {self.num_steps}], got {steps}"
                )
            upper = t

        x_t = self.sample_x_T(shape)
        last_t = self.num_steps

        result = []
        for t in steps:
            if t != last_t:
                x_t, _ = self._backward(x_t, last_t, t, **kwargs)
                last_t = t
            result.append(x_t)

        if return_one_step:
            return result[0]
        return torch.stack(result, dim=0)

    @abstractmethod
    def forward(self, x0: torch.Tensor, **kwargs) -> dict:
        """
        Training forward pass. Return a dictionary containing loss and other metrics.

        Args:
            x0: Input data
            **kwargs: Additional forward pass arguments

        Returns:
            Dictionary with required 'loss' field and optional additional metrics
        """
        pass


In [42]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Generator, Tuple, Dict, TypeVar


class GaussianDDPM(DiffusionModel[Dict]):
    """
    Gaussian diffusion model with a linear beta schedule and a noise-predicting denoiser.

    The schedule tensors are stored with a padding entry at index 0 so that they can
    be indexed directly with 1-based timesteps.
    """

    def __init__(self, num_steps: int, denoiser: nn.Module, beta_start: float = 1e-4, beta_end: float = 0.02):
        """
        Gaussian Diffusion Model without an explicit scheduler.
        Stores alpha, alpha_bar, and beta tensors in the model.

        Args:
            num_steps: Number of diffusion steps.
            denoiser: The neural network model used for denoising. It is called as
                `denoiser(x_t, t / num_steps)` and its output is treated as the predicted noise.
            beta_start: Start value for beta schedule.
            beta_end: End value for beta schedule.
        """
        super().__init__(num_steps)
        # NOTE: self.device was just set by the base constructor, so `denoiser` is
        # moved to that device here. Call `.to(device)` on this model afterwards to
        # relocate the denoiser and the schedule buffers together.
        self.denoiser = denoiser.to(self.device)

        # Define beta schedule (linear schedule)
        beta = torch.linspace(beta_start, beta_end, num_steps)

        # Compute alpha and alpha_bar
        alpha = 1.0 - beta
        alpha_bar = torch.cumprod(alpha, dim=0)

        # please type checker
        self.beta: torch.Tensor
        self.alpha: torch.Tensor
        self.alpha_bar: torch.Tensor

        # pad left so the index is 1-based (index 0 holds beta=0, alpha=1, alpha_bar=1)
        beta = torch.cat([torch.zeros(1), beta], dim=0)
        alpha = torch.cat([torch.ones(1), alpha], dim=0)
        alpha_bar = torch.cat([torch.ones(1), alpha_bar], dim=0)

        # Store tensors as buffers (moved with model but not updated)
        self.register_buffer("beta", beta)
        self.register_buffer("alpha", alpha)
        self.register_buffer("alpha_bar", alpha_bar)

    def extract(self, source: torch.Tensor, t: torch.Tensor, x: torch.Tensor):
        """
        Extract the values from the source tensor at the given timestep.

        The result is reshaped to (t.shape[0], 1, ..., 1) with the same number of
        dimensions as `x`, and moved to `x`'s device, so it broadcasts against `x`.
        """
        return source[t].view([t.shape[0]] + [1] * (len(x.shape) - 1)).to(x.device)

    @torch.inference_mode()
    def sample_x_T(self, shape: Tuple[int, ...]) -> torch.Tensor:
        """Sample from p(x_T), which is standard Gaussian noise."""
        return torch.randn(shape, device=self.device)

    def forward_from_x0(self, x_0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward diffusion process q(x_t | x_0).
        This adds noise to x_0 according to the diffusion schedule.

        Args:
            x_0: Initial data (clean input).
            t: Timesteps, each in [1, num_steps].

        Returns:
            (x_t, noise): the noised data at time t and the Gaussian noise that was
            mixed in to produce it.
        """
        assert (t >= 1).all() and (t <= self.num_steps).all(), f"t must be between 1 and {self.num_steps}, got {t}"
        # One leading batch axis plus singleton axes matching x_0's rank, so the
        # schedule values broadcast against inputs of any dimensionality.
        alpha_bar_t = self.alpha_bar[t].reshape(-1, *([1] * (x_0.dim() - 1)))
        noise = torch.randn_like(x_0)
        x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * noise
        return x_t, noise

    def forward_one_step(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Forward one step in the diffusion process (q(x_{t+1} | x_t)).
        This is usually not explicitly needed in standard DDPM formulations.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Not needed for standard DDPM")

    @torch.inference_mode()
    def backward_one_step(self, x_t: torch.Tensor, t: int, **kwargs) -> Tuple[torch.Tensor, Dict]:
        """
        Reverse diffusion step p(x_{t-1} | x_t) using learned denoiser.

        Args:
            x_t: The current noised sample at timestep t.
            t: The current timestep, in [1, num_steps].

        Returns:
            x_{t-1}: The predicted less noisy sample.
            info: Dictionary containing intermediate values (currently empty).
        """
        assert 1 <= t <= self.num_steps, f"t must be between 1 and {self.num_steps}, got {t}"
        t_tensor = torch.full((x_t.shape[0],), t, device=self.device, dtype=torch.long)

        # Predict the noise contained in x_t; the denoiser sees time scaled by num_steps.
        eps_pred = self.denoiser(x_t, t_tensor.float() / self.num_steps)

        # Retrieve precomputed schedule values, shaped to broadcast against x_t
        alpha_bar = self.extract(self.alpha_bar, t_tensor, x_t)
        beta = self.extract(self.beta, t_tensor, x_t)
        alpha = self.extract(self.alpha, t_tensor, x_t)

        # Compute the mean of the reverse step
        mu_t = (1 / torch.sqrt(alpha)) * (x_t - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * eps_pred)

        # Add noise (variance beta_t) for stochastic sampling on every step except
        # the last one: the step at t == 1 produces x_0 and returns the mean as is.
        if t > 1:
            x_t_prev = mu_t + torch.randn_like(x_t) * torch.sqrt(beta)
        else:
            x_t_prev = mu_t

        return x_t_prev, {}

    def get_eps(self, x0: torch.Tensor, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Recover the noise that maps `x0` to `xt` at timestep `t` by inverting
        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.
        """
        alpha_bar = self.extract(self.alpha_bar, t, xt)
        return (xt - torch.sqrt(alpha_bar) * x0) / torch.sqrt(1 - alpha_bar)

    def forward(self, x0: torch.Tensor, **kwargs) -> Dict:
        """
        Training forward pass.

        1. Sample a timestep `t` per example.
        2. Generate `x_t` using the forward diffusion process.
        3. Predict the noise in `x_t` using the denoiser.
        4. Compute the MSE loss between the true and the predicted noise.

        Args:
            x0: Input data (clean images or signals).

        Returns:
            Dictionary with keys 'loss', 't', 'eps' (target noise) and 'eps_pred'.
        """
        batch_size = x0.shape[0]
        t = self.sample_t(batch_size)

        xt, noise = self.forward_from_x0(x0, t)

        # Target noise reconstructed from (x0, xt); it should match the noise that
        # was actually drawn, which the assertion below sanity-checks.
        eps = self.get_eps(x0, xt, t)

        assert (noise-eps).abs().max() < 1e-5, "noise and eps are not the same"

        eps_pred = self.denoiser(xt, t.float() / self.num_steps)

        loss = F.mse_loss(eps_pred, eps)

        return {"loss": loss, 't': t, 'eps': eps, 'eps_pred': eps_pred}

In [None]:
# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = torchvision.datasets.MNIST(root="./data", train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Initialize model
denoiser = SimpleDenoiser()
diffusion = GaussianDDPM(num_steps=500, denoiser=denoiser)

# Move the whole model (denoiser + schedule buffers) to the training device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
diffusion.to(device)

# The denoiser holds the trainable parameters that are optimized here.
optimizer = optim.Adam(denoiser.parameters(), lr=5e-4)
num_epochs = 20

# Training loop
for epoch in range(num_epochs):
    total_loss = 0
    for i, batch in enumerate(train_loader):
        x0, _ = batch
        x0 = x0.to(device)

        # Rescale pixel values linearly: 0 -> -1 and 1 -> +1
        x0 = x0 * 2 - 1

        optimizer.zero_grad()
        # Call the module itself rather than .forward() so nn.Module's call
        # machinery (e.g. registered hooks) is not bypassed.
        loss_dict = diffusion(x0)
        loss = loss_dict["loss"]
        loss.backward()
        optimizer.step()

        # Report the mean loss over each window of 100 batches
        total_loss += loss.item()
        if i % 100 == 99:
            print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / 100}")
            total_loss = 0

print("Training completed!")

In [44]:
# Rebuild the diffusion wrapper around the already-trained denoiser and copy the
# old state across (presumably so edits to the GaussianDDPM class take effect
# without retraining — confirm).
old_diffusion = diffusion
diffusion = GaussianDDPM(num_steps=old_diffusion.num_steps, denoiser=denoiser)
# A freshly constructed model lives on CPU (and pulls the shared denoiser there),
# so move it back to the training device before sampling from it.
diffusion.to(device)
diffusion.load_state_dict(old_diffusion.state_dict())
del old_diffusion

In [None]:
def show_images(images, title="Generated Images"):
    """
    Display a batch of images as a single grid, 8 images per row.

    Args:
        images: Image batch tensor with values nominally in [-1, 1]; values are
            rescaled from that range for display by `make_grid`.
        title: Title drawn above the grid.
    """
    grid = make_grid(images, nrow=8, normalize=True, value_range=(-1, 1))
    # Channels-first -> channels-last for matplotlib; detach so tensors that
    # still track gradients can be converted to NumPy.
    grid_hwc = grid.detach().permute(1, 2, 0).cpu().numpy()

    plt.figure(figsize=(8, 8))
    # No vmin/vmax: the grid is already normalized by make_grid, and imshow
    # ignores those arguments for RGB input.
    plt.imshow(grid_hwc)
    plt.title(title)
    plt.axis("off")
    plt.show()


# Generate images by denoising from x_T
diffusion.eval()
num_samples = 16
xt = diffusion.sample_x_T((num_samples, 1, 28, 28))

# Reverse diffusion: walk every timestep the model was built with, from
# num_steps down to 1, instead of repeating the step count as a literal.
for t in range(diffusion.num_steps, 0, -1):
    xt, step_info = diffusion.backward_one_step(xt, t)

show_images(xt, "Generated MNIST Samples")

In [None]:
# Record one reverse trajectory at selected timesteps and display the first
# sample of the batch at each of them.
snapshot_steps = [500, 400, 300, 200, 100, 50, 20, 10, 5, 1]
trajectory = diffusion.sample(xt.shape, snapshot_steps)
show_images(trajectory[:, 0])

In [None]:
# Draw a single sample and show its shape as the cell output.
single_sample = diffusion.sample((1, 1, 28, 28))
single_sample.shape