This repo provides:
- DDPM U-Net (time embedding, ResBlocks, optional attention)
- Gaussian diffusion utilities (q, p, sampling)
- Training script with AMP + EMA
- Sampling script
- Trajectory visualization (forward noising + reverse denoising) to MP4/GIF
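In DDPM notation, q conventionally denotes the forward noising process and p the learned reverse process. The sketch below shows what a forward-noising utility typically computes. It assumes a linear beta schedule (1e-4 to 0.02 over 1000 steps) and uses plain Python lists instead of tensors. The function names are hypothetical and are not this repo's API.

```python
import math
import random


def linear_beta_schedule(num_steps: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> list[float]:
    """Linearly spaced noise variances beta_1..beta_T (assumed schedule)."""
    if num_steps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def cumulative_alphas(betas: list[float]) -> list[float]:
    """alpha_bar_t = product over s <= t of (1 - beta_s)."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


def q_sample(x0: list[float], t: int, alpha_bars: list[float], rng: random.Random) -> list[float]:
    """Draw x_t ~ N(sqrt(alpha_bar_t) * x0, (1 - alpha_bar_t) * I)."""
    signal = math.sqrt(alpha_bars[t])
    noise = math.sqrt(1.0 - alpha_bars[t])
    return [signal * v + noise * rng.gauss(0.0, 1.0) for v in x0]


betas = linear_beta_schedule(1000)
alpha_bars = cumulative_alphas(betas)
rng = random.Random(0)  # fixed seed so the draw is reproducible
x_t = q_sample([0.5, -0.5, 1.0], t=999, alpha_bars=alpha_bars, rng=rng)
```

At the last step alpha_bar is close to zero, so x_t is almost pure Gaussian noise. That is the state the reverse process starts sampling from.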
We intentionally DO NOT list torch, torchvision, or torchaudio in pyproject.toml.
Reason: PyTorch CUDA wheels are specific to the platform and CUDA version, and very new GPUs may need a specific CUDA index and/or nightly wheels. Installing the torch packages explicitly from the correct PyTorch index URL avoids dependency-resolver issues and version mismatches.
uv.lock pins only the non-torch Python dependencies.
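Because torch sits outside uv.lock, it can help to confirm which build actually ended up in the environment once it is installed. The following is a minimal sketch, not part of this repo. The helper name `describe_torch_build` is hypothetical, and it only reads attributes that PyTorch exposes (`torch.__version__`, `torch.version.cuda`, `torch.cuda.is_available()`).

```python
import importlib.util


def describe_torch_build() -> str:
    """Summarize the installed torch build, or report that torch is missing."""
    if importlib.util.find_spec("torch") is None:
        return "torch is not installed in this environment"
    import torch

    cuda_build = torch.version.cuda  # None for CPU-only wheels
    return (
        f"torch {torch.__version__}, built for CUDA {cuda_build}, "
        f"GPU available: {torch.cuda.is_available()}"
    )


print(describe_torch_build())
```

If the reported CUDA build is None, or the GPU is reported as unavailable on a machine with a working driver, the wheel probably came from the wrong index.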
- Python 3.11 recommended (strongest PyTorch wheel coverage)
- NVIDIA driver installed and working (nvidia-smi should run)
- uv installed (see uv docs)
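The prerequisites above can be checked with a short script before creating the environment. This is a rough sketch using only the standard library, and the function name `check_prerequisites` is hypothetical. It only checks that the executables are on PATH, not that the driver actually works, and a Python version other than 3.11 is flagged even though 3.11 is a recommendation rather than a hard requirement.

```python
import shutil
import sys


def check_prerequisites() -> dict[str, bool]:
    """Report which prerequisites appear to be satisfied on this machine."""
    return {
        # Recommended, not strictly required.
        "python_3_11": sys.version_info[:2] == (3, 11),
        # PATH lookup only; run nvidia-smi by hand to confirm the driver works.
        "nvidia_smi_on_path": shutil.which("nvidia-smi") is not None,
        "uv_on_path": shutil.which("uv") is not None,
    }


for name, ok in check_prerequisites().items():
    print(f"{name}: {'ok' if ok else 'missing'}")
```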
From the repo root:
uv venv --python 3.11
UV_HTTP_TIMEOUT=1000 uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
From the repo root:
uv run python tools/train.py --config configs/cifar10.yaml --workdir runs/cifar10