<img src="./images/denoising-diffusion.png" width="500px"></img>
## Denoising Diffusion Probabilistic Model, in PyTorch
Implementation of <a href="https://arxiv.org/abs/2006.11239">Denoising Diffusion Probabilistic Model</a> in PyTorch. It is a new approach to generative modeling that may <a href="https://ajolicoeur.wordpress.com/the-new-contender-to-gans-score-matching-with-langevin-sampling/">have the potential</a> to rival GANs. It uses denoising score matching to estimate the gradient of the data distribution, followed by Langevin sampling to sample from the true distribution.
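For reference, the training objective in the linked DDPM paper can be summarized as follows. The notation ($\beta_t$, $\bar\alpha_t$, $\epsilon_\theta$) follows the paper rather than this repository's code, so treat it as a sketch of the idea and not a description of the exact implementation.

$$
q(x_t \mid x_0) = \mathcal{N}\left(x_t;\ \sqrt{\bar\alpha_t}\,x_0,\ (1 - \bar\alpha_t)\,\mathbf{I}\right), \qquad \bar\alpha_t = \prod_{s=1}^{t} (1 - \beta_s)
$$

$$
L_{\text{simple}} = \mathbb{E}_{t,\,x_0,\,\epsilon}\left[\ \left\lVert \epsilon - \epsilon_\theta\left(\sqrt{\bar\alpha_t}\,x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon,\ t\right) \right\rVert^2\ \right]
$$

Here $\epsilon \sim \mathcal{N}(0, \mathbf{I})$ is the noise added to a clean image $x_0$, $\beta_t$ is the noise variance at step $t$, and $\epsilon_\theta$ is the network trained to predict that noise.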
This implementation was inspired by the official TensorFlow version <a href="https://github.com/hojonathanho/diffusion">here</a>.
Youtube AI Educators - <a href="https://www.youtube.com/watch?v=W-O7AZNzbzQ">Yannic Kilcher</a> | <a href="https://www.youtube.com/watch?v=344w5h24-h8">AI Coffeebreak with Letitia</a> | <a href="https://www.youtube.com/watch?v=HoKDTa5jHvg">Outlier</a>
<a href="https://github.com/yiyixuxu/denoising-diffusion-flax">Flax implementation</a> from <a href="https://github.com/yiyixuxu">YiYi Xu</a>
<a href="https://huggingface.co/blog/annotated-diffusion">Annotated code</a> by Research Scientists / Engineers from <a href="https://huggingface.co/">🤗 Huggingface</a>
Update: Turns out none of the technicalities really matters at all | <a href="https://arxiv.org/abs/2208.09392">"Cold Diffusion" paper</a> | <a href="https://muse-model.github.io/">Muse</a>
<img src="./images/sample.png" width="500px"></img>
[![PyPI version](https://badge.fury.io/py/denoising-diffusion-pytorch.svg)](https://badge.fury.io/py/denoising-diffusion-pytorch)
## Install
```bash
$ pip install denoising_diffusion_pytorch
```
## Usage
```python
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    flash_attn = True
)

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000    # number of steps
)

training_images = torch.rand(8, 3, 128, 128) # images are normalized from 0 to 1
loss = diffusion(training_images)
loss.backward()

# after a lot of training

sampled_images = diffusion.sample(batch_size = 4)
sampled_images.shape # (4, 3, 128, 128)
```
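To give a feel for what `timesteps` controls, here is a minimal, dependency-free sketch of a forward noising process. It assumes the linear beta schedule from the DDPM paper (`1e-4` to `0.02`), which is not necessarily the schedule this library uses by default. The helpers `linear_alpha_bar` and `q_sample` are hypothetical names for illustration and are not part of the package.

```python
import math
import random

def linear_alpha_bar(timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative signal fraction alpha_bar_t for a linear beta schedule.

    Assumption: schedule endpoints follow the DDPM paper, not this library.
    """
    alpha_bar, running = [], 1.0
    for t in range(timesteps):
        # Linearly interpolate beta between its start and end values.
        beta = beta_start + (beta_end - beta_start) * t / max(timesteps - 1, 1)
        running *= 1.0 - beta
        alpha_bar.append(running)
    return alpha_bar

def q_sample(x0, t, alpha_bar, rng=random):
    """Noise a single scalar value x0 to timestep t (0-indexed)."""
    noise = rng.gauss(0.0, 1.0)
    return math.sqrt(alpha_bar[t]) * x0 + math.sqrt(1.0 - alpha_bar[t]) * noise

alpha_bar = linear_alpha_bar()
rng = random.Random(0)
early = q_sample(0.5, 0, alpha_bar, rng)    # still close to the clean value
late = q_sample(0.5, 999, alpha_bar, rng)   # essentially pure Gaussian noise
```

Under these assumptions, `alpha_bar` decays from roughly 1 toward 0, so early timesteps barely perturb the input while the last ones leave almost no signal. That is why sampling has to walk back through many steps.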
Or, if you simply want to pass in a folder name and the desired image dimensions, you can use the `Trainer` class to easily train a model.
```python
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    flash_attn = True
)

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000,           # number of steps
    sampling_timesteps = 250    # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
)

trainer = Trainer(
    diffusion,
    'path/to/your/images',
    train_batch_size = 32,
    train_lr = 8e-5,
    train_num_steps = 700000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    amp = True,                       # turn on mixed precision
    calculate_fid = True              # whether to calculate fid during training
)

trainer.train()
```
Samples and model checkpoints will be logged to `./results` periodically.
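The `ema_decay` argument above refers to an exponential moving average of the model weights. As a rough, plain-Python sketch of the idea (the hypothetical `ema_update` helper is for illustration only; the library's actual EMA handling may include details such as warmup that are omitted here):

```python
def ema_update(ema_params, params, decay=0.995):
    """Return the updated moving average: decay * old + (1 - decay) * new."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, params)]

# Toy "weights": the averaged copy drifts slowly toward the live values.
ema = [0.0, 1.0]
for _ in range(3):
    ema = ema_update(ema, [1.0, 1.0])
```

With a decay of `0.995`, each step moves the averaged weights only 0.5% of the way toward the current weights, which smooths out noise from individual optimizer steps.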
## Multi-GPU Training
The `Trainer` class is now equipped with <a href="https://huggingface.co/docs/accelerate/accelerator">🤗 Accelerator</a>. You can easily do multi-GPU training in two steps using their `accelerate` CLI.
At the project root directory, where the training script is, run
```bash
$ accelerate config
```
Then, in the same directory
```bash
$ accelerate launch train.py
```
## Miscellaneous
### 1D Sequence
By popular request, a 1D Unet + Gaussian Diffusion implementation.
```python
import torch
from denoising_diffusion_pytorch import Unet1D, GaussianDiffusion1D, Trainer1D, Dataset1D

model = Unet1D(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    channels = 32
)

diffusion = GaussianDiffusion1D(
    model,
    seq_length = 128,
    timesteps = 1000,
    objective = 'pred_v'
)

training_seq = torch.rand(64, 32, 128) # features are normalized from 0 to 1
dataset = Dataset1D(training_seq)  # this is just an example, but you can formulate your own Dataset and pass it into the `Trainer1D` below

loss = diffusion(training_seq)
loss.backward()

# Or using trainer

trainer = Trainer1D(
    diffusion,
    dataset = dataset,
    train_batch_size = 32,
    train_lr = 8e-5,
    train_num_steps = 700000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    amp = True,                       # turn on mixed precision
)
trainer.train()

# after a lot of training

sampled_seq = diffusion.sample(batch_size = 4)
sampled_seq.shape # (4, 32, 128)
```
`Trainer1D` does not evaluate the generated samples in any way, since the type of data is not known.
You could add a suitable metric to the training loop yourself after doing an editable install of this package with `pip install -e .`.
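As one possible starting point, the sketch below compares per-channel means of real and sampled sequences. It is a deliberately crude, hypothetical metric (`channel_means` and `mean_gap` are not part of this package) and works on nested Python lists shaped `(batch, channels, seq_length)`, so tensors would need `.tolist()` first. Whether it is meaningful depends entirely on your data.

```python
from statistics import fmean

def channel_means(batch):
    """Mean of each channel over all sequences and positions.

    `batch` is a nested list shaped (batch, channels, seq_length).
    """
    num_channels = len(batch[0])
    return [
        fmean(value for seq in batch for value in seq[c])
        for c in range(num_channels)
    ]

def mean_gap(real, sampled):
    """Largest absolute difference between per-channel means."""
    pairs = zip(channel_means(real), channel_means(sampled))
    return max(abs(r - s) for r, s in pairs)

# Tiny toy batches: 2 sequences, 2 channels, length 2.
real = [[[0.0, 1.0], [0.5, 0.5]], [[1.0, 0.0], [0.5, 0.5]]]
sampled = [[[0.25, 0.25], [0.5, 0.5]], [[0.25, 0.25], [1.0, 1.0]]]
gap = mean_gap(real, sampled)
```

A gap near zero only says the first moments match. A domain-specific metric will usually be more informative.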
## Citations
```bibtex
@inproceedings{NEURIPS2020_4c5bcfec,
author = {Ho, Jonathan and Jain, Ajay and Abbeel, Pieter},
booktitle = {Advances in Neural Information Processing Systems},
editor = {H. Larochelle and M. Ranzato and R. Hadsell and M.F. Balcan and H. Lin},
pages = {6840--6851},
publisher = {Curran Associates, Inc.},
title = {Denoising Diffusion Probabilistic Models},
url = {https://proceedings.neurips.cc/paper/2020/file/4c5bcfec8584af0d967f1ab10179ca4b-Paper.pdf},
volume = {33},
year = {2020}
}
```
```bibtex
@InProceedings{pmlr-v139-nichol21a,
title = {Improved Denoising Diffusion Probabilistic Models},
author = {Nichol, Alexander Quinn and Dhariwal, Prafulla},
booktitle = {Proceedings of the 38th International Conference on Machine Learning},
pages = {8162--8171},
year = {2021},
editor = {Meila, Marina and Zhang, Tong},
volume = {139},
series = {Proceedings of Machine Learning Research},
month = {18--24 Jul},
publisher = {PMLR},
pdf = {http://proceedings.mlr.press/v139/nichol21a/nichol21a.pdf},
url = {https://proceedings.mlr.press/v139/nichol21a.html},
}
```
```bibtex
@inproceedings{kingma2021on,
title = {On Density Estimation with Diffusion Models},
author = {Diederik P Kingma and Tim Salimans and Ben Poole and Jonathan Ho},
booktitle = {Advances in Neural Information Processing Systems},
editor = {A. Beygelzimer and Y. Dauphin and P. Liang and J. Wortman Vaughan},
year = {2021},
url = {https://openreview.net/forum?id=2LdBqxc1Yv}
}
```
```bibtex
@article{Karras2022ElucidatingTD,
title = {Elucidating the Design Space of Diffusion-Based Generative Models},
author = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine},
journal = {ArXiv},
year = {2022},
volume = {abs/2206.00364}
}
```
```bibtex
@article{Song2021DenoisingDI,
title = {Denoising Diffusion Implicit Models},
author = {Jiaming Song and Chenlin Meng and Stefano Ermon},
journal = {ArXiv},
year = {2021},
volume = {abs/2010.02502}
}
```
```bibtex
@misc{chen2022analog,
title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
year = {2022},
eprint = {2208.04202},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@article{Salimans2022ProgressiveDF,
title = {Progressive Distillation for Fast Sampling of Diffusion Models},
author = {Tim Salimans and Jonathan Ho},
journal = {ArXiv},
year = {2022},
volume = {abs/2202.00512}
}
```
```bibtex
@article{Ho2022ClassifierFreeDG,
title = {Classifier-Free Diffusion Guidance},
author = {Jonathan Ho},
journal = {ArXiv},
year = {2022},
volume = {abs/2207.12598}
}
```
```bibtex
@article{Sunkara2022NoMS,
title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
author = {Raja Sunkara and Tie Luo},
journal = {ArXiv},
year = {2022},
volume = {abs/2208.03641}
}
```
```bibtex
@inproceedings{Jabri2022ScalableAC,
title = {Scalable Adaptive Computation for Iterative Generation},
author = {A. Jabri and David J. Fleet and Ting Chen},
year = {2022}
}
```
```bibtex
@article{Cheng2022DPMSolverPlusPlus,
title = {DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models},
author = {Cheng Lu and Yuhao Zhou and Fan Bao and Jianfei Chen and Chongxuan Li and Jun Zhu},
journal = {NeurIPS 2022 Oral},
year = {2022},
volume = {abs/2211.01095}
}
```
```bibtex
@inproceedings{Hoogeboom2023simpleDE,
title = {simple diffusion: End-to-end diffusion for high resolution images},
author = {Emiel Hoogeboom and Jonathan Heek and Tim Salimans},
year = {2023}
}
```
```bibtex
@misc{https://doi.org/10.48550/arxiv.2302.01327,
doi = {10.48550/ARXIV.2302.01327},
url = {https://arxiv.org/abs/2302.01327},
author = {Kumar, Manoj and Dehghani, Mostafa and Houlsby, Neil},
title = {Dual PatchNorm},
publisher = {arXiv},
year = {2023},
copyright = {Creative Commons Attribution 4.0 International}
}
```
```bibtex
@inproceedings{Hang2023EfficientDT,
title = {Efficient Diffusion Training via Min-SNR Weighting Strategy},
author = {Tiankai Hang and Shuyang Gu and Chen Li and Jianmin Bao and Dong Chen and Han Hu and Xin Geng and Baining Guo},
year = {2023}
}
```
```bibtex
@misc{Guttenberg2023,
author = {Nicholas Guttenberg},
url = {https://www.crosslabs.org/blog/diffusion-with-offset-noise}
}
```
```bibtex
@inproceedings{Lin2023CommonDN,
title = {Common Diffusion Noise Schedules and Sample Steps are Flawed},
author = {Shanchuan Lin and Bingchen Liu and Jiashi Li and Xiao Yang},
year = {2023}
}
```
```bibtex
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
```
```bibtex
@article{Bondarenko2023QuantizableTR,
title = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
author = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
journal = {ArXiv},
year = {2023},
volume = {abs/2306.12929},
url = {https://api.semanticscholar.org/CorpusID:259224568}
}
```
```bibtex
@article{Karras2023AnalyzingAI,
title = {Analyzing and Improving the Training Dynamics of Diffusion Models},
author = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine},
journal = {ArXiv},
year = {2023},
volume = {abs/2312.02696},
url = {https://api.semanticscholar.org/CorpusID:265659032}
}
```