Skip to content

Commit d04051e

Browse files
committed
Merge master
2 parents 6292107 + f39020b commit d04051e

File tree

12 files changed

+479
-220
lines changed

12 files changed

+479
-220
lines changed

README.md

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@ import numpy as np
3434

3535
generator = torch.Generator()
3636
generator = generator.manual_seed(6694729458485568)
37+
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
3738

3839
# 1. Load models
3940
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church")
@@ -45,20 +46,20 @@ image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.re
4546
# 3. Denoise
4647
for t in reversed(range(len(scheduler))):
4748
# i) define coefficients for time step t
48-
clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
49-
clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
49+
clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
50+
clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
5051
image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
51-
clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
52+
clipped_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
5253

5354
# ii) predict noise residual
5455
with torch.no_grad():
5556
noise_residual = model(image, t)
5657

5758
# iii) compute predicted image from residual
5859
# See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
59-
pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
60+
pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
6061
pred_mean = torch.clamp(pred_mean, -1, 1)
61-
prev_image = clip_coeff * pred_mean + image_coeff * image
62+
prev_image = clipped_coeff * pred_mean + image_coeff * image
6263

6364
# iv) sample variance
6465
prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)
@@ -83,12 +84,12 @@ image_pil.save("test.png")
8384
Example:
8485

8586
```python
86-
from modeling_ddpm import DDPM
87+
from diffusers import DiffusionPipeline
8788
import PIL.Image
8889
import numpy as np
8990

9091
# load model and scheduler
91-
ddpm = DDPM.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")
92+
ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom")
9293

9394
# run pipeline in inference (sample random noise and denoise)
9495
image = ddpm()

examples/sample_loop.py

Lines changed: 0 additions & 157 deletions
This file was deleted.

models/vision/ddpm/example.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
#!/usr/bin/env python3
2-
import tempfile
3-
import sys
4-
2+
import os
3+
import pathlib
54
from modeling_ddpm import DDPM
5+
import PIL.Image
6+
import numpy as np
67

7-
model_id = sys.argv[1]
8+
model_ids = ["ddpm-lsun-cat", "ddpm-lsun-cat-ema", "ddpm-lsun-church-ema", "ddpm-lsun-church", "ddpm-lsun-bedroom", "ddpm-lsun-bedroom-ema", "ddpm-cifar10-ema", "ddpm-cifar10", "ddpm-celeba-hq", "ddpm-celeba-hq-ema"]
89

9-
ddpm = DDPM.from_pretrained(model_id)
10-
image = ddpm()
10+
for model_id in model_ids:
11+
path = os.path.join("/home/patrick/images/hf", model_id)
12+
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
1113

12-
import PIL.Image
13-
import numpy as np
14-
image_processed = image.cpu().permute(0, 2, 3, 1)
15-
image_processed = (image_processed + 1.0) * 127.5
16-
image_processed = image_processed.numpy().astype(np.uint8)
17-
image_pil = PIL.Image.fromarray(image_processed[0])
18-
image_pil.save("test.png")
14+
ddpm = DDPM.from_pretrained("fusing/" + model_id)
15+
image = ddpm(batch_size=4)
16+
17+
image_processed = image.cpu().permute(0, 2, 3, 1)
18+
image_processed = (image_processed + 1.0) * 127.5
19+
image_processed = image_processed.numpy().astype(np.uint8)
1920

20-
import ipdb; ipdb.set_trace()
21+
for i in range(image_processed.shape[0]):
22+
image_pil = PIL.Image.fromarray(image_processed[i])
23+
image_pil.save(os.path.join(path, f"image_{i}.png"))

models/vision/ddpm/modeling_ddpm.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,32 +23,33 @@ class DDPM(DiffusionPipeline):
2323

2424
modeling_file = "modeling_ddpm.py"
2525

26-
def __init__(self, unet, noise_scheduler, vqvae):
26+
def __init__(self, unet, noise_scheduler):
2727
super().__init__()
2828
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
2929

30-
def __call__(self, generator=None, torch_device=None):
31-
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
30+
def __call__(self, batch_size=1, generator=None, torch_device=None):
31+
if torch_device is None:
32+
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
3233

3334
self.unet.to(torch_device)
3435
# 1. Sample gaussian noise
35-
image = self.noise_scheduler.sample_noise((1, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
36+
image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
3637
for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)):
3738
# i) define coefficients for time step t
38-
clip_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t))
39-
clip_noise_coeff = torch.sqrt(1 / self.noise_scheduler.get_alpha_prod(t) - 1)
39+
clipped_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t))
40+
clipped_noise_coeff = torch.sqrt(1 / self.noise_scheduler.get_alpha_prod(t) - 1)
4041
image_coeff = (1 - self.noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(self.noise_scheduler.get_alpha(t)) / (1 - self.noise_scheduler.get_alpha_prod(t))
41-
clip_coeff = torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1)) * self.noise_scheduler.get_beta(t) / (1 - self.noise_scheduler.get_alpha_prod(t))
42+
clipped_coeff = torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1)) * self.noise_scheduler.get_beta(t) / (1 - self.noise_scheduler.get_alpha_prod(t))
4243

4344
# ii) predict noise residual
4445
with torch.no_grad():
4546
noise_residual = self.unet(image, t)
4647

4748
# iii) compute predicted image from residual
4849
# See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
49-
pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
50+
pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
5051
pred_mean = torch.clamp(pred_mean, -1, 1)
51-
prev_image = clip_coeff * pred_mean + image_coeff * image
52+
prev_image = clipped_coeff * pred_mean + image_coeff * image
5253

5354
# iv) sample variance
5455
prev_variance = self.noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)

models/vision/ddpm/test.png

-102 KB
Binary file not shown.

src/diffusers/configuration_utils.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
8989

9090
self.to_json_file(output_config_file)
9191
logger.info(f"ConfigMixinuration saved in {output_config_file}")
92+
9293

9394
@classmethod
9495
def get_config_dict(
@@ -182,35 +183,42 @@ def get_config_dict(
182183
logger.info(f"loading configuration file {config_file}")
183184
else:
184185
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
186+
187+
return config_dict
185188

189+
@classmethod
190+
def extract_init_dict(cls, config_dict, **kwargs):
186191
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
187192
expected_keys.remove("self")
188-
193+
init_dict = {}
189194
for key in expected_keys:
190195
if key in kwargs:
191196
# overwrite key
192-
config_dict[key] = kwargs.pop(key)
197+
init_dict[key] = kwargs.pop(key)
198+
elif key in config_dict:
199+
# use value from config dict
200+
init_dict[key] = config_dict.pop(key)
193201

194-
passed_keys = set(config_dict.keys())
195-
196-
unused_kwargs = kwargs
197-
for key in passed_keys - expected_keys:
198-
unused_kwargs[key] = config_dict.pop(key)
199202

203+
unused_kwargs = config_dict.update(kwargs)
204+
205+
passed_keys = set(init_dict.keys())
200206
if len(expected_keys - passed_keys) > 0:
201207
logger.warn(
202208
f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
203209
)
204210

205-
return config_dict, unused_kwargs
211+
return init_dict, unused_kwargs
206212

207213
@classmethod
208214
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
209-
config_dict, unused_kwargs = cls.get_config_dict(
215+
config_dict = cls.get_config_dict(
210216
pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
211217
)
212218

213-
model = cls(**config_dict)
219+
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
220+
221+
model = cls(**init_dict)
214222

215223
if return_unused_kwargs:
216224
return model, unused_kwargs

0 commit comments

Comments
 (0)