kambshu/diffusion_models


ddpm_pt — PyTorch DDPM (Ho et al., 2020) + trajectory visualization

This repo provides:

  • DDPM U-Net (time embedding, ResBlocks, optional attention)
  • Gaussian diffusion utilities (forward process q, reverse process p, sampling)
  • Training script with AMP + EMA
  • Sampling script
  • Trajectory visualization (forward noising + reverse denoising) to MP4/GIF
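For orientation, the forward process q and one reverse step p from Ho et al. (2020) can be sketched in plain Python on a list of floats. This is not the repo's implementation. The linear β schedule (1e-4 to 0.02 over T = 1000 steps) is the paper's default and is only assumed here, the function names are hypothetical, the reverse step uses the σ_t² = β_t variance choice, and the U-Net's noise prediction is replaced by the true noise ε so the round trip can be checked.

```python
import math
import random


def linear_betas(T=1000, beta_start=1e-4, beta_end=0.02):
    """Linear variance schedule (paper defaults, assumed here). Index 0 is the paper's t=1."""
    return [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]


def cumulative_alphas(betas):
    """alpha_bar_t = prod over s <= t of (1 - beta_s)."""
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out


def q_sample(x0, t, alpha_bars, eps):
    """Closed-form forward noising: x_t = sqrt(ab_t) * x0 + sqrt(1 - ab_t) * eps."""
    ab = alpha_bars[t]
    return [math.sqrt(ab) * x + math.sqrt(1.0 - ab) * e for x, e in zip(x0, eps)]


def p_step(xt, t, betas, alpha_bars, eps_pred, rng=random):
    """One reverse (denoising) step x_t -> x_{t-1}, with sigma_t^2 = beta_t.

    eps_pred stands in for the U-Net's noise prediction (hypothetical here).
    """
    beta, ab = betas[t], alpha_bars[t]
    coef = beta / math.sqrt(1.0 - ab)
    # Posterior mean: (x_t - beta_t / sqrt(1 - ab_t) * eps) / sqrt(alpha_t)
    mean = [(x - coef * e) / math.sqrt(1.0 - beta) for x, e in zip(xt, eps_pred)]
    if t == 0:  # no noise is added on the final step
        return mean
    sigma = math.sqrt(beta)
    return [m + sigma * rng.gauss(0.0, 1.0) for m in mean]


if __name__ == "__main__":
    rng = random.Random(0)
    betas = linear_betas()
    abars = cumulative_alphas(betas)
    x0 = [0.5, -0.25, 1.0]
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    xt = q_sample(x0, 0, abars, eps)
    # With the true noise at the first step, the reverse mean recovers x0.
    print(p_step(xt, 0, betas, abars, eps))
```

A trajectory visualization of this kind would record `q_sample` outputs for increasing t (forward) and repeated `p_step` outputs from t = T−1 down to 0 (reverse).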

Design choice: torch is installed explicitly (not in pyproject.toml)

We intentionally DO NOT list torch, torchvision, or torchaudio in pyproject.toml.

Reason: CUDA wheels for PyTorch are platform- and GPU-sensitive (and for very new GPUs, you may need a specific CUDA index and/or nightly wheels). Installing torch-related packages explicitly from the correct PyTorch index URL avoids dependency resolver issues and version mismatches.

As a result, uv.lock pins only the non-torch Python dependencies; the torch packages are installed separately in step 2.


1) Prerequisites

  • Python 3.11 recommended (broadest PyTorch wheel coverage)
  • NVIDIA driver installed and working (nvidia-smi should run)
  • uv installed (see uv docs)
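A quick way to check these prerequisites before creating the environment. `check_prereqs` is a hypothetical helper, not part of the repo. It only inspects the interpreter version and looks up `nvidia-smi` and `uv` on PATH; it does not run `nvidia-smi` or verify that the driver actually works.

```python
import shutil
import sys


def check_prereqs(recommended=(3, 11)):
    """Report whether the README's prerequisites appear to be met (PATH lookups only)."""
    return {
        "python_version": "{}.{}".format(*sys.version_info[:2]),
        # The README recommends 3.11 specifically, not "3.11 or newer".
        "python_is_recommended": sys.version_info[:2] == tuple(recommended),
        # shutil.which returns None when the executable is not on PATH.
        "nvidia_smi_found": shutil.which("nvidia-smi") is not None,
        "uv_found": shutil.which("uv") is not None,
    }


if __name__ == "__main__":
    for name, value in check_prereqs().items():
        print(f"{name}: {value}")
```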

2) Create the environment (uv venv)

From the repo root:

uv venv --python 3.11
UV_HTTP_TIMEOUT=1000 uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
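After the install, a short script run with `uv run python <script>.py` can confirm that a CUDA build of torch landed in the venv. The script and `torch_report` are hypothetical, not part of the repo. With the `cu128` index, `torch.version.cuda` is expected to report a 12.8 build; `None` indicates a CPU-only wheel.

```python
def torch_report():
    """Summarize the installed torch build; degrades gracefully if torch is missing."""
    try:
        import torch
    except ImportError:
        return {"torch_installed": False, "torch_version": None,
                "cuda_build": None, "cuda_available": False, "device": None}
    available = torch.cuda.is_available()
    return {
        "torch_installed": True,
        "torch_version": torch.__version__,
        "cuda_build": torch.version.cuda,  # None for CPU-only wheels
        "cuda_available": available,       # False if the driver/GPU is not usable
        "device": torch.cuda.get_device_name(0) if available else None,
    }


if __name__ == "__main__":
    for key, value in torch_report().items():
        print(f"{key}: {value}")
```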

3) Train the model (uv run)

From the repo root:

uv run python tools/train.py --config configs/cifar10.yaml --workdir runs/cifar10
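The training script keeps an EMA of the weights. The usual rule updates a shadow copy of each parameter after every optimizer step: shadow ← decay·shadow + (1 − decay)·param. The sketch below applies that rule to a dict of floats; `ema_update` is hypothetical, and the 0.9999 default is the decay reported by Ho et al. (2020), used here only for illustration since the repo's actual value lives in its config.

```python
def ema_update(shadow, params, decay=0.9999):
    """In-place EMA: shadow[k] = decay * shadow[k] + (1 - decay) * params[k]."""
    for name, value in params.items():
        shadow[name] = decay * shadow[name] + (1.0 - decay) * value
    return shadow


if __name__ == "__main__":
    shadow = {"w": 0.0}
    for _ in range(3):
        ema_update(shadow, {"w": 1.0}, decay=0.5)
    print(shadow)  # the shadow weight moves toward the live weight
```

In PyTorch the same update is typically done per tensor under `torch.no_grad()` with `ema_p.mul_(decay).add_(p, alpha=1 - decay)`, and the EMA weights (not the raw ones) are usually the ones used for sampling.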
