In [3]:
import os
import torch
import certifi
from torchvision import datasets, transforms

# Export certifi's CA bundle path through SSL_CERT_FILE before the download
# below (presumably so the HTTPS fetch can verify certificates -- confirm).
os.environ.update(SSL_CERT_FILE=certifi.where())

# Single-step preprocessing pipeline: ToTensor only.
transform = transforms.Compose([transforms.ToTensor()])

# CIFAR-10 training split, stored under ./data and downloaded on request.
dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)

100%|██████████| 170M/170M [00:11<00:00, 14.2MB/s] 


In [None]:
from typing import Optional
from dataclasses import dataclass


@dataclass
class DiffusionConfig:
    """Settings bundle for the diffusion model; every field has a default."""

    # Number of diffusion timesteps (read as ``T`` in DiffusionModel.forward).
    num_timesteps: int = 1_000
    # Not read by the code visible here; presumably how many timesteps are
    # handled at once -- TODO confirm.
    num_parallel_steps: int = 10
    # Presumably the first and last values of the noise (beta) schedule;
    # neither is read by the code visible here -- TODO confirm.
    beta_start: float = 1e-4
    beta_end: float = 2e-2
    # Schedule name; presumably selects how betas are spaced between
    # beta_start and beta_end -- TODO confirm.
    beta_schedule: str = "linear"


@dataclass
class DiffusionModelOutput:
    """Container for the value returned by ``DiffusionModel.forward``."""

    # Optional loss tensor; defaults to None (forward currently always
    # constructs this output with loss=None).
    loss: Optional[torch.Tensor] = None


class DiffusionModel(torch.nn.Module):
    """Skeleton diffusion model built on ``torch.nn.Module``.

    The constructor stores the config and wraps the two per-sample helpers
    with ``torch.vmap``. Both helpers are still unimplemented placeholders,
    and ``forward`` does not compute a loss yet.
    """

    def __init__(self, config: "DiffusionConfig"):
        """Store ``config`` and build the vmapped diffusion helpers.

        Args:
            config (DiffusionConfig): Model hyperparameters; ``forward`` reads
                ``config.num_timesteps``.
        """
        # nn.Module has to be initialised before the instance is used:
        # without this call its internal registries (parameters, buffers,
        # hooks, ...) do not exist and calling the model or asking for its
        # parameters raises AttributeError.
        super().__init__()
        self.config = config

        # Both helpers are mapped over the leading dimension of each argument.
        self._batch_forward_diffusion = torch.vmap(
            self._forward_diffusion,
            in_dims=(0, 0),
        )
        self._batch_reverse_diffusion = torch.vmap(
            self._reverse_diffusion,
            in_dims=(0, 0),
        )

    @staticmethod
    def _forward_diffusion(x0, t):
        """Forward diffusion process (not implemented yet).

        Intended to add noise to ``x0`` to obtain ``x_t``.

        Args:
            x0 (torch.Tensor): Input images. Shape: (B, C, H, W).
            t (torch.Tensor): Timesteps. Shape: (B,).

        Raises:
            NotImplementedError: Always, until the noising step is written.
        """
        # Fail loudly rather than silently returning None, which only
        # surfaces later as an unrelated error.
        raise NotImplementedError("_forward_diffusion is not implemented yet")

    @staticmethod
    def _reverse_diffusion(xT, t):
        """Decoder (UNet) (not implemented yet).

        Args:
            xT (torch.Tensor): Noisy images. Shape: (T, B, C, H, W).
            t (torch.Tensor): Timesteps matching ``xT``; expected shape is
                not established yet -- TODO confirm.

        Returns:
            torch.Tensor: Predicted images. Shape: (T, B, C, H, W).

        Raises:
            NotImplementedError: Always, until the decoder is written.
        """
        # Fail loudly rather than silently returning None.
        raise NotImplementedError("_reverse_diffusion is not implemented yet")

    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None, num_steps: int = 1000):
        """Forward pass of the diffusion model (work in progress).

        Args:
            x (torch.Tensor): Input images. Shape: (B, C, H, W).
            y (torch.Tensor): Labels. Shape: (B,). Currently unused.
            num_steps (int): Currently unused.

        Returns:
            DiffusionModelOutput: Output whose ``loss`` is always ``None``
            for now.
        """
        # Make inputs
        T = self.config.num_timesteps
        # T stacked copies of x along a new leading axis: (T, B, C, H, W).
        xt = x.unsqueeze(0).repeat(T, 1, 1, 1, 1)
        # One random timestep in [0, T) per batch element: (B,).
        t = torch.randint(0, T, (x.shape[0],), device=x.device)
        # NOTE(review): with in_dims=(0, 0) vmap maps xt over its leading
        # T axis and t over its leading B axis; vmap requires those sizes to
        # match, so this call fails unless B == T. Decide whether t should be
        # shared across the T copies (in_dims=(0, None)) or drawn per copy.
        xt = self._batch_forward_diffusion(xt, t)  # (T, B, C, H, W)

        # Make targets
        # NOTE(review): t - 1 is -1 wherever t == 0 -- confirm the helper is
        # meant to accept that.
        tm1 = t - 1
        # NOTE(review): this noises the already-noised xt rather than x --
        # confirm that is intended.
        xtm1 = self._batch_forward_diffusion(xt, tm1)  # (T, B, C, H, W)

        # Get predictions
        # NOTE(review): self.pos_embed is not defined anywhere in this class;
        # it must be created in __init__ before this line can run.
        xt_hat = self._batch_reverse_diffusion(self.pos_embed(xtm1), tm1)

        # TODO: compute a loss from xt_hat; it is currently unused.
        return DiffusionModelOutput(loss=None)