
leemengtw/Diffusion-Models-pytorch

Diffusion Models

This is an easy-to-understand implementation of diffusion models within 100 lines of code. Unlike other implementations, this code doesn't use the lower-bound formulation for sampling. Instead, it strictly follows the training and sampling algorithms of the DDPM paper (Algorithm 1 and Algorithm 2), which keeps it short and easy to follow. There are two implementations: conditional and unconditional. The conditional code additionally implements Classifier-Free Guidance (CFG) and an Exponential Moving Average (EMA) of the model weights. Below you can find two explanation videos, one on the theory behind diffusion models and one on the implementation.
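For orientation, here is a minimal NumPy sketch of one training step in the style of Algorithm 1 (training) from the DDPM paper. It assumes a linear noise schedule with the paper's values (T = 1000, betas from 1e-4 to 0.02). The names `noise_images`, `training_loss`, and the all-zero `dummy_model` are illustrative stand-ins, not the repository's code.

```python
import numpy as np

# Assumed schedule: linear betas with the DDPM paper's values (T = 1000, 1e-4 to 0.02).
T = 1000
beta = np.linspace(1e-4, 0.02, T)
alpha = 1.0 - beta
alpha_hat = np.cumprod(alpha)  # alpha_hat[t] = product of alpha[0..t]


def noise_images(x0, t, rng):
    """Forward process: x_t = sqrt(alpha_hat_t) * x0 + sqrt(1 - alpha_hat_t) * eps."""
    eps = rng.standard_normal(x0.shape)
    a = alpha_hat[t].reshape(-1, 1, 1, 1)  # one coefficient per image, broadcast over C, H, W
    return np.sqrt(a) * x0 + np.sqrt(1.0 - a) * eps, eps


def training_loss(predict_noise, x0, rng):
    """One Algorithm-1-style step: sample t, noise x0, regress the added noise with MSE."""
    t = rng.integers(0, T, size=x0.shape[0])  # 0-indexed timesteps, one per image
    x_t, eps = noise_images(x0, t, rng)
    eps_pred = predict_noise(x_t, t)
    return float(np.mean((eps - eps_pred) ** 2))


rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(4, 3, 8, 8))  # fake image batch scaled to [-1, 1]
dummy_model = lambda x_t, t: np.zeros_like(x_t)  # stand-in for the UNet
print(training_loss(dummy_model, x0, rng))
```

With the all-zero predictor the loss is roughly the variance of standard Gaussian noise, about 1.0; a trained UNet drives it toward 0.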

Train a Diffusion Model on your own data:

Unconditional Training

  1. (Optional) Configure the hyperparameters in ddpm.py
  2. Set the path to your dataset in ddpm.py
  3. Run: python ddpm.py
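The sketch below illustrates steps 1 and 2. It assumes the dataset is read with an ImageFolder-style loader (torchvision's `ImageFolder` expects images inside at least one subfolder of the root, even for unconditional training). The `args` names and values and the `check_image_folder` helper are hypothetical; the actual variables in ddpm.py may differ.

```python
from argparse import Namespace
from pathlib import Path

# Hypothetical hyperparameters: names and values are illustrative and may not
# match the variables actually defined in ddpm.py.
args = Namespace(
    run_name="DDPM_unconditional",
    epochs=500,
    batch_size=12,
    image_size=64,
    dataset_path="path/to/your/dataset",  # root folder; the images live in a subfolder
    lr=3e-4,
)

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}


def check_image_folder(root):
    """Hypothetical helper: count images per subfolder of an ImageFolder-style
    layout (root/<subfolder>/<image>) and fail early if none are found."""
    root = Path(root)
    counts = {}
    for sub in sorted(p for p in root.iterdir() if p.is_dir()):
        n = sum(1 for f in sub.iterdir() if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES)
        if n:
            counts[sub.name] = n
    if not counts:
        raise ValueError(f"no images found in subfolders of {root}; "
                         f"move them into a subfolder such as {root / 'images'}")
    return counts
```

Calling `check_image_folder(args.dataset_path)` before starting a long training run catches a wrong path or a flat folder of images immediately.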

Conditional Training

  1. (Optional) Configure the hyperparameters in ddpm_conditional.py
  2. Set the path to your dataset in ddpm_conditional.py
  3. Run: python ddpm_conditional.py
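Conditional training adds two pieces on top of the unconditional loop: an EMA copy of the weights and label dropout, which lets one network learn both the conditional and unconditional predictions needed for classifier-free guidance. Below is a minimal NumPy sketch of both, with parameters represented as a dict of arrays. The defaults (beta = 0.995, a 2000-step warm-up, a 10% drop probability) and the whole-batch dropout are assumptions for illustration, not necessarily what ddpm_conditional.py uses.

```python
import numpy as np


class EMA:
    """Exponential moving average of parameters: ema = beta * ema + (1 - beta) * w.

    Assumption: weights are simply copied during a warm-up phase before averaging starts.
    """

    def __init__(self, beta=0.995, warmup_steps=2000):
        self.beta = beta
        self.warmup_steps = warmup_steps
        self.step = 0
        self.shadow = None  # dict: parameter name -> averaged array

    def update(self, params):
        if self.shadow is None or self.step < self.warmup_steps:
            # Warm-up: copy the current weights instead of averaging.
            self.shadow = {k: v.copy() for k, v in params.items()}
        else:
            for k, v in params.items():
                self.shadow[k] = self.beta * self.shadow[k] + (1.0 - self.beta) * v
        self.step += 1


def drop_labels(labels, rng, p_uncond=0.1):
    """With probability p_uncond, drop the labels for the whole batch (None means
    'unconditional'), so the same network also learns the unconditional model."""
    return None if rng.random() < p_uncond else labels
```

The averaged weights change more smoothly than the raw ones, which is why sampling typically uses the EMA copy (as the conditional example below does with conditional_ema_ckpt.pt).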

Sampling

The following examples show how to sample images using the models trained in the video. You can download the checkpoints for the models here.

Unconditional Model

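    import torch
    # Module layout assumed from the repository: modules.py, utils.py, ddpm.py
    from ddpm import Diffusion
    from modules import UNet
    from utils import plot_images

    device = "cuda"
    model = UNet().to(device)
    ckpt = torch.load("unconditional_ckpt.pt", map_location=device)
    model.load_state_dict(ckpt)
    diffusion = Diffusion(img_size=64, device=device)
    x = diffusion.sample(model, n=16)  # generate 16 images
    plot_images(x)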
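Under the hood, sampling runs the reverse process of Algorithm 2 in the DDPM paper. The NumPy sketch below shows that loop, assuming the same linear schedule as the paper (T = 1000, betas from 1e-4 to 0.02) and the variance choice sigma_t^2 = beta_t. The zero-returning predictor stands in for the trained UNet, and `sample`/`to_uint8` are illustrative names rather than the repository's implementation.

```python
import numpy as np

# Assumed linear schedule with the DDPM paper's values (T = 1000, 1e-4 to 0.02).
T = 1000
beta = np.linspace(1e-4, 0.02, T)
alpha = 1.0 - beta
alpha_hat = np.cumprod(alpha)


def sample(predict_noise, shape, rng):
    """Algorithm-2-style sampling loop with sigma_t^2 = beta_t."""
    x = rng.standard_normal(shape)  # x_T ~ N(0, I)
    for t in reversed(range(T)):  # 0-indexed: t = T-1, ..., 0
        eps = predict_noise(x, np.full(shape[0], t))
        z = rng.standard_normal(shape) if t > 0 else np.zeros(shape)  # no noise on the last step
        x = (x - beta[t] / np.sqrt(1.0 - alpha_hat[t]) * eps) / np.sqrt(alpha[t]) + np.sqrt(beta[t]) * z
    return x


def to_uint8(x):
    """Map model outputs from [-1, 1] to 8-bit pixel values."""
    return ((np.clip(x, -1.0, 1.0) + 1.0) / 2.0 * 255).astype(np.uint8)


rng = np.random.default_rng(0)
imgs = to_uint8(sample(lambda x, t: np.zeros_like(x), (2, 3, 8, 8), rng))
print(imgs.shape, imgs.dtype)  # (2, 3, 8, 8) uint8
```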

Conditional Model

This model was trained on CIFAR-10 at 64x64 resolution with 10 classes: airplane: 0, automobile: 1, bird: 2, cat: 3, deer: 4, dog: 5, frog: 6, horse: 7, ship: 8, truck: 9.

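    import torch
    # Module layout assumed from the repository: the conditional Diffusion class is in ddpm_conditional.py
    from ddpm_conditional import Diffusion
    from modules import UNet_conditional
    from utils import plot_images

    n = 10
    device = "cuda"
    model = UNet_conditional(num_classes=10).to(device)
    ckpt = torch.load("conditional_ema_ckpt.pt", map_location=device)  # EMA weights
    model.load_state_dict(ckpt)
    diffusion = Diffusion(img_size=64, device=device)
    y = torch.full((n,), 6, dtype=torch.long, device=device)  # n labels of class 6 (frog)
    x = diffusion.sample(model, n, y, cfg_scale=3)
    plot_images(x)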
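With classifier-free guidance, each sampling step evaluates the model with and without the class labels and combines the two noise predictions. Assuming `cfg_scale` plays the role of the guidance scale s in that combination, the NumPy sketch below shows the formula, plus a small helper that builds a label batch from a CIFAR-10 class name. `cfg_noise` and `label_batch` are illustrative names, not functions from the repository.

```python
import numpy as np

# CIFAR-10 class names in label order (list index = class label).
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def label_batch(name, n):
    """Return n integer labels for one CIFAR-10 class, e.g. 'frog' -> 6."""
    return np.full(n, CIFAR10_CLASSES.index(name), dtype=np.int64)


def cfg_noise(eps_cond, eps_uncond, cfg_scale):
    """Classifier-free guidance: eps_uncond + s * (eps_cond - eps_uncond).

    s = 0 gives the unconditional prediction, s = 1 the conditional one,
    and s > 1 (e.g. cfg_scale=3) extrapolates further toward the class condition.
    """
    return eps_uncond + cfg_scale * (eps_cond - eps_uncond)
```

In PyTorch the same combination can be written as torch.lerp(eps_uncond, eps_cond, cfg_scale), since lerp computes start + weight * (end - start).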

About

PyTorch implementation of Diffusion Models (https://arxiv.org/pdf/2006.11239.pdf)


Languages

  • Python 100.0%