-
Notifications
You must be signed in to change notification settings - Fork 0
/
deisabode_pipeline.py
132 lines (108 loc) · 5.36 KB
/
deisabode_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from typing import List, Optional, Tuple, Union
import torch
from diffusers.utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
class DEISABODEPipeline(DiffusionPipeline):
    r"""
    Unconditional image-generation pipeline driven by a DEIS-AB ODE scheduler.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
            [`DDPMScheduler`], or [`DDIMScheduler`].
    """

    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    # NOTE(review): the usual `@torch.no_grad()` decorator is deliberately left commented out —
    # presumably so the scheduler step can run with autograd enabled (the UNet forward below is
    # still guarded by an inner `torch.no_grad()`). Confirm before re-enabling.
    # @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 1000,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        return_intermediates: bool = False,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            num_inference_steps (`int`, *optional*, defaults to 1000):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
            return_intermediates (`bool`, *optional*, defaults to `False`):
                Whether or not to return all steps of the reverse diffusion process, i.e. the intermediates.
                Only honored when `return_dict` is `False` (the intermediates are the second tuple element);
                `ImagePipelineOutput` has no field for them.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
        """
        # Sample gaussian noise to begin the reverse-diffusion loop.
        if isinstance(self.unet.config.sample_size, int):
            image_shape = (
                batch_size,
                self.unet.config.in_channels,
                self.unet.config.sample_size,
                self.unet.config.sample_size,
            )
        else:
            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

        if self.device.type == "mps":
            # randn does not work reproducibly on mps
            image = randn_tensor(image_shape, generator=generator)
            image = image.to(self.device)
        else:
            image = randn_tensor(image_shape, generator=generator, device=self.device)

        # set step values
        # this also sets Cs values for the sampler
        # try doing outside pipeline, otherwise called every batch
        # It happens in self.scheduler.set_timesteps()

        intermediates = []
        earlier_outputs = []   # history of model outputs fed back into the multistep scheduler
        score_abs = []         # per-step mean |score| — diagnostic only, not returned
        Ls = []                # per-step L values from the scheduler — diagnostic only
        for idx, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            # for indexing the taus which go 0->1 (timesteps iterate in the opposite direction)
            time_idx = len(self.scheduler.inference_taus) - 1 - idx

            # 1. predict noise model_output
            with torch.no_grad():
                model_output = self.unet(image, t).sample

            # 2. compute previous image: x_t -> x_t-1
            out = self.scheduler.step(
                model_output, earlier_outputs,
                time_idx,
                image, generator=generator
            )
            image, earlier_outputs = out.prev_sample, out.earlier_outputs
            score_est, L = out.score_est, out.L
            # `score_mean_abs` (not `abs`) so the builtin is not shadowed.
            score_mean_abs = torch.abs(score_est).mean().item()
            score_abs.append(score_mean_abs)
            Ls.append(L)

            if return_intermediates:
                intermediates.append(image)

        if return_intermediates:
            # Stack to (steps, batch, C, H, W), map from [-1, 1] to [0, 1], then to NHWC numpy.
            intermediates = torch.stack(intermediates, 0)
            intermediates = (intermediates / 2 + 0.5).clamp(0, 1)
            intermediates = intermediates.cpu().permute(0, 1, 3, 4, 2).detach().numpy()

        # Map final sample from [-1, 1] to [0, 1] and convert to NHWC numpy.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).detach().numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            # `intermediates` is an empty list when return_intermediates is False.
            return (image, intermediates)

        return ImagePipelineOutput(images=image)