-
Notifications
You must be signed in to change notification settings - Fork 6
/
continuous_time.py
319 lines (285 loc) · 11.4 KB
/
continuous_time.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
import math
from functools import partial
from typing import List, Literal
import torch
from torch import nn
from torch.cuda.amp import autocast
from torch.special import expm1
from tqdm.auto import tqdm
from . import base
def _log(t, eps=1e-20):
return torch.log(t.clamp(min=eps))
def _log_snr_schedule_linear(t: torch.Tensor) -> torch.Tensor:
    """Log-SNR of the "linear" schedule, ``-log(expm1(1e-4 + 10 * t**2))``.

    ``t`` holds one timestep per sample (1-D); the result gets three trailing
    singleton dims, i.e. shape (B, 1, 1, 1), so it broadcasts over images.
    """
    inner = expm1(10 * (t**2) + 1e-4)
    log_snr = -_log(inner)
    return log_snr[:, None, None, None]
def _log_snr_schedule_cosine(
    t: torch.Tensor,
    logsnr_min: float = -15,
    logsnr_max: float = 15,
) -> torch.Tensor:
    """Cosine log-SNR schedule bounded to ``[logsnr_min, logsnr_max]``.

    ``t`` in [0, 1] is mapped linearly onto an angle range chosen so that
    ``t=0`` gives ``logsnr_max`` and ``t=1`` gives ``logsnr_min``; the
    log-SNR is ``-2 * log(tan(angle))``.

    Returns:
        Tensor of shape (B, 1, 1, 1) for a 1-D ``t`` of length B.
    """
    angle_lo = math.atan(math.exp(-0.5 * logsnr_max))
    angle_hi = math.atan(math.exp(-0.5 * logsnr_min))
    angle = angle_lo + t * (angle_hi - angle_lo)
    log_snr = -2 * _log(torch.tan(angle))
    return log_snr[:, None, None, None]
def _log_snr_schedule_cosine_shifted(
    t: torch.Tensor,
    image_d: float,
    noise_d: float,
    logsnr_min: float = -15,
    logsnr_max: float = 15,
) -> torch.Tensor:
    """Cosine log-SNR schedule offset by ``2 * log(noise_d / image_d)``.

    The offset is a constant added to every timestep, so the result has the
    same (B, 1, 1, 1) shape as the unshifted cosine schedule.
    """
    unshifted = _log_snr_schedule_cosine(
        t, logsnr_min=logsnr_min, logsnr_max=logsnr_max
    )
    offset = 2 * math.log(noise_d / image_d)
    return unshifted + offset
def _log_snr_schedule_cosine_interpolated(
    t: torch.Tensor,
    image_d: float,
    noise_d_low: float,
    noise_d_high: float,
    logsnr_min: float = -15,
    logsnr_max: float = 15,
) -> torch.Tensor:
    """Blend of two shifted cosine log-SNR schedules, weighted by ``t``.

    At ``t=0`` the result equals the schedule shifted with ``noise_d_high``;
    at ``t=1`` it equals the one shifted with ``noise_d_low``; in between it
    is the per-sample linear interpolation of the two.

    Args:
        t: 1-D tensor of timesteps in [0, 1], one per sample (length B).
        image_d: reference size used in the shift ``2 * log(noise_d / image_d)``.
        noise_d_low: ``noise_d`` of the schedule reached at ``t=1``.
        noise_d_high: ``noise_d`` of the schedule used at ``t=0``.
        logsnr_min, logsnr_max: bounds of the underlying cosine schedule.

    Returns:
        Tensor of shape (B, 1, 1, 1).
    """
    logsnr_low = _log_snr_schedule_cosine_shifted(
        t, image_d, noise_d_low, logsnr_min, logsnr_max
    )
    logsnr_high = _log_snr_schedule_cosine_shifted(
        t, image_d, noise_d_high, logsnr_min, logsnr_max
    )
    # The shifted schedules are (B, 1, 1, 1). The weights must be reshaped to
    # match; otherwise (B,) * (B, 1, 1, 1) broadcasts to (B, 1, 1, B) and
    # mixes the timesteps of different samples.
    weight = t[:, None, None, None]
    return weight * logsnr_low + (1 - weight) * logsnr_high
def _log_snr_to_alpha_sigma(log_snr: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
alpha, sigma = log_snr.sigmoid().sqrt(), (-log_snr).sigmoid().sqrt()
return alpha, sigma
class ContinuousTimeGaussianDiffusion(base.GaussianDiffusion):
    """
    Continuous-time Gaussian diffusion
    https://arxiv.org/pdf/2107.00630.pdf

    Timesteps are floats in [0, 1]: sampling starts from Gaussian noise at
    t=1 and steps down to t=0. The noise level at time t is
    ``self.log_snr(t)`` (shape (B, 1, 1, 1)), selected by ``noise_schedule``
    in ``setup_parameters``.
    """

    def __init__(
        self,
        model: nn.Module,
        prediction_type: Literal["eps", "v", "x_0"] = "eps",
        loss_type: Literal["l2", "l1", "huber"] | nn.Module = "l2",
        noise_schedule: Literal[
            "linear", "cosine", "cosine_shifted", "cosine_interpolated"
        ] = "cosine",
        min_snr_loss_weight: bool = True,
        min_snr_gamma: float = 5.0,
        sampling_resolution: tuple[int, int] | None = None,
        clip_sample: bool = True,
        clip_sample_range: float = 1,
        image_d: float | None = None,
        noise_d_low: float | None = None,
        noise_d_high: float | None = None,
    ):
        """
        Args:
            model: denoising network, called as ``model(x_t, log_snr)``.
            prediction_type: network output: noise ("eps"), velocity ("v")
                or the clean sample ("x_0").
            loss_type: forwarded to the base class.
            noise_schedule: log-SNR schedule. "cosine_shifted" needs
                ``image_d`` and ``noise_d_low``; "cosine_interpolated" also
                needs ``noise_d_high``.
            min_snr_loss_weight: clip the SNR at ``min_snr_gamma`` when
                computing loss weights (see ``get_loss_weight``).
            min_snr_gamma: SNR clipping threshold.
            sampling_resolution: forwarded to the base class.
            clip_sample: clamp the predicted x_0 during sampling.
            clip_sample_range: half-width of that clamp.
            image_d: reference size in the shift ``2 * log(noise_d / image_d)``.
            noise_d_low: ``noise_d`` for "cosine_shifted", and the t=1 end of
                "cosine_interpolated".
            noise_d_high: ``noise_d`` at the t=0 end of "cosine_interpolated".
        """
        # setup_parameters() reads these attributes but is never called from
        # this class, so it is presumably called by the base constructor.
        # They must therefore exist before super().__init__() runs. Plain
        # values can be assigned on an nn.Module before Module.__init__.
        self.image_d = image_d
        self.noise_d_low = noise_d_low
        self.noise_d_high = noise_d_high
        super().__init__(
            model=model,
            sampling="ddpm",
            prediction_type=prediction_type,
            loss_type=loss_type,
            num_training_steps=None,
            noise_schedule=noise_schedule,
            min_snr_loss_weight=min_snr_loss_weight,
            min_snr_gamma=min_snr_gamma,
            sampling_resolution=sampling_resolution,
            clip_sample=clip_sample,
            clip_sample_range=clip_sample_range,
        )

    def setup_parameters(self) -> None:
        """Select the log-SNR schedule function stored in ``self.log_snr``.

        Raises:
            ValueError: for an unknown schedule, or when a shifted schedule
                is missing one of its ``image_d`` / ``noise_d_*`` settings.
        """
        if self.noise_schedule == "linear":
            self.log_snr = _log_snr_schedule_linear
        elif self.noise_schedule == "cosine":
            self.log_snr = _log_snr_schedule_cosine
        elif self.noise_schedule == "cosine_shifted":
            self._require_settings("image_d", "noise_d_low")
            self.log_snr = partial(
                _log_snr_schedule_cosine_shifted,
                image_d=self.image_d,
                noise_d=self.noise_d_low,
            )
        elif self.noise_schedule == "cosine_interpolated":
            self._require_settings("image_d", "noise_d_low", "noise_d_high")
            self.log_snr = partial(
                _log_snr_schedule_cosine_interpolated,
                image_d=self.image_d,
                noise_d_low=self.noise_d_low,
                noise_d_high=self.noise_d_high,
            )
        else:
            raise ValueError(f"invalid beta schedule: {self.noise_schedule}")

    def _require_settings(self, *names: str) -> None:
        """Raise ValueError naming any of ``names`` that is unset or None."""
        missing = [name for name in names if getattr(self, name, None) is None]
        if missing:
            raise ValueError(
                f"noise_schedule={self.noise_schedule!r} requires "
                f"{', '.join(missing)} to be set"
            )

    def sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Draw one continuous float32 timestep per sample, uniform in [0, 1)."""
        return torch.rand(batch_size, device=device, dtype=torch.float32)

    def get_network_condition(self, steps):
        """Per-sample log-SNR at ``steps``, flattened to shape (B,)."""
        return self.log_snr(steps)[:, 0, 0, 0]

    def get_target(self, x_0, step_t, noise):
        """Regression target for the configured objective.

        "eps" -> ``noise``; "x_0" -> ``x_0``;
        "v" -> ``alpha_t * noise - sigma_t * x_0``.

        Raises:
            ValueError: for an unknown objective.
        """
        if self.objective == "eps":
            target = noise
        elif self.objective == "x_0":
            target = x_0
        elif self.objective == "v":
            log_snr = self.log_snr(step_t)
            alpha, sigma = _log_snr_to_alpha_sigma(log_snr)
            target = alpha * noise - sigma * x_0
        else:
            raise ValueError(f"invalid objective {self.objective}")
        return target

    def get_loss_weight(self, steps):
        """Per-sample loss weights with shape (B, 1, 1, 1).

        With ``min_snr_loss_weight`` the SNR is clipped at ``min_snr_gamma``.
        Without it the clipped SNR equals the SNR, and the weights reduce to
        1 ("eps"), SNR ("x_0") or SNR / (SNR + 1) ("v").

        Raises:
            ValueError: for an unknown objective.
        """
        log_snr = self.log_snr(steps)
        snr = log_snr.exp()
        clipped_snr = snr.clone()
        if self.min_snr_loss_weight:
            clipped_snr.clamp_(max=self.min_snr_gamma)
        if self.objective == "eps":
            loss_weight = clipped_snr / snr
        elif self.objective == "x_0":
            loss_weight = clipped_snr
        elif self.objective == "v":
            loss_weight = clipped_snr / (snr + 1)
        else:
            raise ValueError(f"invalid objective {self.objective}")
        return loss_weight

    @autocast(enabled=False)
    def q_step_from_x_0(self, x_0, step_t, rng=None):
        """Sample the forward process q(z_t | x_0), 0 < t < 1, with autocast off.

        Returns:
            ``(x_t, noise)`` where ``x_t = alpha_t * x_0 + sigma_t * noise``.
        """
        noise = self.randn_like(x_0, rng=rng)
        log_snr = self.log_snr(step_t)
        alpha, sigma = _log_snr_to_alpha_sigma(log_snr)
        x_t = x_0 * alpha + noise * sigma
        return x_t, noise

    def q_step(self, x_s, step_t, step_s, rng=None):
        """Sample the forward transition q(z_t | z_s), 0 <= s < t <= 1.

        cf. Appendix A of https://arxiv.org/pdf/2107.00630.pdf

        Args:
            x_s: sample at the earlier time ``step_s``.
            step_t: later timesteps, shape (B,).
            step_s: earlier timesteps, shape (B,).
            rng: generator(s) forwarded to ``self.randn_like``.

        Returns:
            A sample at time ``step_t``.
        """
        log_snr_t = self.log_snr(step_t)
        log_snr_s = self.log_snr(step_s)
        alpha_t, sigma_t = _log_snr_to_alpha_sigma(log_snr_t)
        alpha_s, sigma_s = _log_snr_to_alpha_sigma(log_snr_s)
        alpha_ts = alpha_t / alpha_s
        var_noise = self.randn_like(x_s, rng=rng)
        mean = x_s * alpha_ts
        # sigma_{t|s}^2 is >= 0 analytically. The clamp stops round-off
        # (e.g. when t is very close to s) from turning sqrt into NaN.
        var = (sigma_t.pow(2) - alpha_ts.pow(2) * sigma_s.pow(2)).clamp(min=0)
        x_t = mean + var.sqrt() * var_noise
        return x_t

    @torch.inference_mode()
    def p_step(
        self,
        x_t: torch.Tensor,
        step_t: torch.Tensor,
        step_s: torch.Tensor,
        rng: list[torch.Generator] | torch.Generator | None = None,
        mode: Literal["ddpm", "ddim"] = "ddpm",
        eta: float = 0.0,
    ) -> torch.Tensor:
        """One reverse step p(z_s | z_t), 0 <= s < t <= 1.

        Args:
            x_t: current sample at time ``step_t``.
            step_t: current timesteps, shape (B,).
            step_s: target (earlier) timesteps, shape (B,).
            rng: generator(s) forwarded to ``self.randn_like``.
            mode: "ddpm" (ancestral update) or "ddim".
            eta: DDIM stochasticity; 0 gives the deterministic DDIM update.
                Ignored in "ddpm" mode.

        Returns:
            The sample at time ``step_s``. No fresh noise is added to samples
            whose ``step_s`` is 0, i.e. on the final step.

        Raises:
            ValueError: for an unknown objective or mode.
        """
        log_snr_t = self.log_snr(step_t)
        log_snr_s = self.log_snr(step_s)
        alpha_t, sigma_t = _log_snr_to_alpha_sigma(log_snr_t)
        alpha_s, sigma_s = _log_snr_to_alpha_sigma(log_snr_s)

        prediction = self.model(x_t, log_snr_t[:, 0, 0, 0])
        if self.objective == "eps":
            x_0 = (x_t - sigma_t * prediction) / alpha_t
        elif self.objective == "v":
            x_0 = alpha_t * x_t - sigma_t * prediction
        elif self.objective == "x_0":
            x_0 = prediction
        else:
            raise ValueError(f"invalid objective {self.objective}")
        if self.clip_sample:
            x_0.clamp_(-self.clip_sample_range, self.clip_sample_range)

        # The final reverse step ends at s == 0 and should return the mean
        # without fresh noise. step_t is never 0 while sampling (steps go
        # from 1 down to step_s == 0), so the mask must be taken on step_s.
        is_final = step_s == 0

        if mode == "ddpm":
            c = -expm1(log_snr_t - log_snr_s)  # c = 1 - SNR(t) / SNR(s)
            mean = alpha_s * (x_t * (1 - c) / alpha_t + c * x_0)
            var = sigma_s.pow(2) * c
            var_noise = self.randn_like(x_t, rng=rng)
            var_noise[is_final] = 0
            x_s = mean + var.sqrt() * var_noise
        elif mode == "ddim":
            std_dev = eta * sigma_s / sigma_t * (1 - alpha_t**2 / alpha_s**2).sqrt()
            eps = (x_t - alpha_t * x_0) / sigma_t
            # Analytically >= 0 for eta <= 1; the clamp avoids NaN from round-off.
            x_s_dir = (1 - alpha_s**2 - std_dev**2).clamp(min=0).sqrt() * eps
            x_s = alpha_s * x_0 + x_s_dir
            if eta > 0:
                var_noise = self.randn_like(x_t, rng=rng)
                var_noise[is_final] = 0
                x_s = x_s + std_dev * var_noise
        else:
            raise ValueError(f"invalid mode {mode}")
        return x_s

    @torch.inference_mode()
    def sample(
        self,
        batch_size: int,
        num_steps: int,
        progress: bool = True,
        rng: list[torch.Generator] | torch.Generator | None = None,
        return_all: bool = False,
        mode: Literal["ddpm", "ddim"] = "ddpm",
        eta: float = 0.0,
    ):
        """Generate samples with ``num_steps`` uniform reverse steps from t=1 to t=0.

        Args:
            batch_size: number of samples.
            num_steps: number of reverse steps.
            progress: show a tqdm progress bar.
            rng: generator(s) forwarded to the noise helpers.
            return_all: also return every intermediate sample.
            mode: "ddpm" or "ddim" update (see ``p_step``).
            eta: DDIM stochasticity forwarded to ``p_step``; ignored for "ddpm".

        Returns:
            The final samples or, with ``return_all``, a stack along dim 0
            of the initial noise followed by the output of every step.
        """
        x = self.randn(batch_size, *self.sampling_shape, rng=rng, device=self.device)
        if return_all:
            out = [x]
        steps = torch.linspace(1.0, 0.0, num_steps + 1, device=self.device)
        steps = steps[None].repeat_interleave(batch_size, dim=0)
        tqdm_kwargs = dict(desc="sampling", leave=False, disable=not progress)
        for i in tqdm(range(num_steps), **tqdm_kwargs):
            step_t = steps[:, i]
            step_s = steps[:, i + 1]
            x = self.p_step(x, step_t, step_s, rng=rng, mode=mode, eta=eta)
            if return_all:
                out.append(x)
        return torch.stack(out) if return_all else x

    @torch.inference_mode()
    def repaint(
        self,
        known: torch.Tensor,
        mask: torch.Tensor,
        num_steps: int,
        num_resample_steps: int = 1,  # "n" of the RePaint paper
        jump_length: int = 1,
        progress: bool = True,
        rng: list[torch.Generator] | torch.Generator | None = None,
        return_all: bool = False,
    ):
        """Inpaint with a re-implementation of RePaint (https://arxiv.org/abs/2201.09865).

        Each outer step [t, s] is split into ``jump_length`` sub-steps. The
        sample is denoised t -> s, with the known region replaced by
        ``known`` diffused to the current time wherever ``mask`` is 1. It is
        then re-noised s -> t and the process repeats, ``num_resample_steps``
        times per outer step (no re-noising after the last repeat or on the
        final outer step).

        Returns:
            The final sample or, with ``return_all``, a stack of the initial
            noise and the result of every t -> s pass.

        Raises:
            ValueError: if ``num_steps``, ``num_resample_steps`` or
                ``jump_length`` is not positive.
        """
        if num_steps <= 0:
            raise ValueError(f"num_steps must be positive, got {num_steps}")
        if num_resample_steps <= 0:
            raise ValueError(
                f"num_resample_steps must be positive, got {num_resample_steps}"
            )
        if jump_length <= 0:
            raise ValueError(f"jump_length must be positive, got {jump_length}")

        batch_size = known.shape[0]
        x_t = self.randn(batch_size, *self.sampling_shape, rng=rng, device=self.device)
        steps = torch.linspace(1, 0, num_steps + 1, device=self.device)
        steps = steps[None].repeat_interleave(batch_size, dim=0)
        # Fractions splitting each [t, s] interval into jump_length sub-steps;
        # independent of the loop indices, so built once.
        interp = torch.linspace(0, 1, jump_length + 1, device=self.device)[None]
        if return_all:
            out = [x_t]
        for i in tqdm(
            range(num_steps), desc="RePaint", leave=False, disable=not progress
        ):
            step_t = steps[:, [i]]
            step_s = steps[:, [i + 1]]
            # (B, jump_length + 1) sub-step times running from t down to s.
            r_steps = step_t + interp * (step_s - step_t)
            for j in range(num_resample_steps):
                # t->s (reverse diffusion)
                x = x_t
                for k in range(jump_length):
                    r_step_t = r_steps[:, k]
                    r_step_s = r_steps[:, k + 1]
                    known_s, _ = self.q_step_from_x_0(known, r_step_s, rng=rng)
                    unknown_s = self.p_step(x, r_step_t, r_step_s, rng=rng)
                    x = mask * known_s + (1 - mask) * unknown_s
                x_s = x
                if return_all:
                    out.append(x_s)
                if (i == num_steps - 1) or (j == num_resample_steps - 1):
                    x_t = x
                    break
                # s->t (forward diffusion)
                x = x_s
                for k in range(jump_length, 0, -1):
                    r_step_t = r_steps[:, k - 1]
                    r_step_s = r_steps[:, k]
                    x = self.q_step(x, r_step_t, r_step_s, rng=rng)
                x_t = x
        return torch.stack(out) if return_all else x_s