This repository implements DDPM (Denoising Diffusion Probabilistic Models) from scratch in PyTorch.
[Figure: time step 1000, start of the reverse process]
[Figure: time step 500, middle of the reverse process]
[Figure: time step 1, end of the reverse process]
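The time steps above refer to the reverse process, which starts from pure Gaussian noise at time step 1000 and denoises one step at a time down to time step 1. The following is a minimal NumPy sketch of standard DDPM ancestral sampling. It assumes a linear beta schedule from 1e-4 to 0.02 over T = 1000 steps (the values used in the DDPM paper; this repository's schedule may differ). It also uses a hypothetical `predict_noise` callable in place of the trained network, so it is not this repository's inference code.

```python
import numpy as np

def linear_beta_schedule(T=1000, beta_start=1e-4, beta_end=0.02):
    """Assumed linear noise schedule (DDPM paper values); the repo's may differ."""
    return np.linspace(beta_start, beta_end, T)

def reverse_process(predict_noise, shape, T=1000, seed=0):
    """Ancestral sampling from x_T ~ N(0, I) down to x_0.

    predict_noise(x_t, t) is a hypothetical stand-in for the trained noise
    predictor epsilon_theta. Step t is 0-based here, so with T = 1000 index 999
    corresponds to "time step 1000" and index 0 to "time step 1".
    """
    rng = np.random.default_rng(seed)
    betas = linear_beta_schedule(T)
    alphas = 1.0 - betas
    alpha_bars = np.cumprod(alphas)

    x = rng.standard_normal(shape)  # x_T: pure Gaussian noise
    for t in reversed(range(T)):
        eps = predict_noise(x, t)
        # Mean of p(x_{t-1} | x_t): remove the predicted noise, then rescale.
        mean = (x - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * eps) / np.sqrt(alphas[t])
        if t > 0:
            # sigma_t^2 = beta_t, one of the two variance choices in the DDPM paper.
            x = mean + np.sqrt(betas[t]) * rng.standard_normal(shape)
        else:
            x = mean  # no noise is added at the last step
    return x
```

For example, `reverse_process(model_fn, shape=(16, 1, 28, 28))` would return a batch of 16 MNIST-sized samples, where `model_fn` (hypothetical) wraps the trained network so it accepts a NumPy array and a step index.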
- Prepare the dataset. This implementation uses MNIST as an example.
The dataset directory should look like this:
./data/MNIST/
└── raw
    ├── t10k-images-idx3-ubyte
    ├── t10k-labels-idx1-ubyte
    ├── train-images-idx3-ubyte
    └── train-labels-idx1-ubyte
The corresponding section of the config file ./config/default.yaml is as follows:
dataset_params:
  root_dir: 'data/MNIST/raw'
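The four files under data/MNIST/raw use the raw IDX format. Each file begins with a big-endian header: two zero bytes, a data-type code (0x08 for unsigned bytes), and the number of dimensions. That is followed by one 4-byte size per dimension, and then the data itself. The following is a minimal reader that assumes the files are uncompressed and contain unsigned bytes, which holds for the standard MNIST files. `read_idx` is a hypothetical helper and not necessarily how this repository's dataset code loads them.

```python
import struct
import numpy as np

def read_idx(path):
    """Read an uncompressed IDX file of unsigned bytes into a NumPy array."""
    with open(path, "rb") as f:
        # Header: 2 zero bytes, 1 type-code byte, 1 byte with the number of dims.
        zero, dtype_code, ndim = struct.unpack(">HBB", f.read(4))
        if zero != 0 or dtype_code != 0x08:
            raise ValueError(f"{path}: not an unsigned-byte IDX file")
        # One big-endian uint32 per dimension, e.g. (60000, 28, 28) for training images.
        shape = struct.unpack(">" + "I" * ndim, f.read(4 * ndim))
        data = np.frombuffer(f.read(), dtype=np.uint8)
    return data.reshape(shape)
```

With the layout above, `read_idx("data/MNIST/raw/train-images-idx3-ubyte")` should give an array of shape (60000, 28, 28), and the matching labels file should give shape (60000,).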
- Create the environment
Either install the dependencies with pip:
pip install -r requirements.txt
or create a conda environment from the provided file:
conda env create -f environment.yaml
If you used conda, run conda activate ddpm-pytorch to enter the environment.
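Once the environment is active, a short import check can catch a missing dependency before a long training run. The sketch below uses only the standard library plus an optional CUDA check. The package list is an assumption; requirements.txt and environment.yaml remain the authoritative lists.

```python
import importlib
import importlib.util

# Assumed core dependencies for this repository; compare with requirements.txt.
REQUIRED = ("torch", "torchvision", "numpy", "yaml")

def missing_packages(names=REQUIRED):
    """Return the top-level modules that cannot be found in the active environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]

if __name__ == "__main__":
    missing = missing_packages()
    print("missing packages:", ", ".join(missing) or "none")
    if "torch" not in missing:
        torch = importlib.import_module("torch")
        # Reports whether PyTorch can see a CUDA GPU for training.
        print("CUDA available:", torch.cuda.is_available())
```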
- Training
From the base directory DDPM-pytorch, run python3 ./scripts/train.py to start training. Training-related parameters can be changed in ./config/default.yaml.
Training for 31 epochs takes about 40 minutes on a single NVIDIA RTX 4090 (24 GB). A trained checkpoint file is available.
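Each DDPM training step picks a random time step t for each image. It noises the clean image in closed form with x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, then regresses the network's noise prediction onto eps with a mean-squared error. The following NumPy sketch computes that loss for one batch. It assumes the DDPM paper's linear schedule and a hypothetical `predict_noise` model. The actual train.py uses PyTorch tensors and an optimizer rather than this code.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)     # assumed linear schedule
alpha_bars = np.cumprod(1.0 - betas)   # alpha_bar_t = prod over s <= t of (1 - beta_s)

def add_noise(x0, t, eps):
    """Forward process q(x_t | x_0) in closed form, for a batch of 0-based steps t."""
    ab = alpha_bars[t].reshape(-1, *([1] * (x0.ndim - 1)))  # broadcast over image dims
    return np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * eps

def ddpm_loss(predict_noise, x0, rng):
    """Simplified DDPM objective: MSE between the true and the predicted noise."""
    t = rng.integers(0, T, size=x0.shape[0])  # one random step per image
    eps = rng.standard_normal(x0.shape)
    x_t = add_noise(x0, t, eps)
    return np.mean((predict_noise(x_t, t) - eps) ** 2)
```

A model that always predicts zero noise scores a loss close to 1, the variance of eps. Training pushes the loss below that by learning to recover the added noise.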
- Inference
For inference, first set the checkpoint name in ./config/default.yaml. For example, if the config looks like this:
train_params:
  task_name: 'default'
  ckpt_name: 'model_31.pt'
then the checkpoint path should be:
DDPM-pytorch/experiment/default/model_31.pt
Next, run python3 ./scripts/inference.py to perform inference.
The generated samples are saved in ./experiment/$(task_name)/samples, where $(task_name) is the task_name parameter defined in ./config/default.yaml.
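The mapping from config values to the checkpoint and output locations described above fits in a few lines. The sketch below assumes the config has already been parsed into a dict (for example with PyYAML's yaml.safe_load) and that paths are resolved relative to the base directory DDPM-pytorch. `checkpoint_and_sample_dirs` is a hypothetical helper, not a function from this repository.

```python
from pathlib import Path

def checkpoint_and_sample_dirs(config, base_dir="."):
    """Return (checkpoint_path, samples_dir) derived from a parsed config dict."""
    train = config["train_params"]
    task_dir = Path(base_dir) / "experiment" / train["task_name"]
    return task_dir / train["ckpt_name"], task_dir / "samples"

config = {"train_params": {"task_name": "default", "ckpt_name": "model_31.pt"}}
ckpt, samples = checkpoint_and_sample_dirs(config)
print(ckpt.as_posix())     # experiment/default/model_31.pt
print(samples.as_posix())  # experiment/default/samples
```

Changing task_name in the config therefore moves both the expected checkpoint and the samples folder together.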
- The original DDPM code, written in TensorFlow.
- https://github.com/explainingai-code/DDPM-Pytorch/tree/main