# denoising.py: reverse-diffusion sampling loops (generalized_steps, ddpm_steps).
import torch
def compute_alpha(beta, t):
    """Look up cumulative alpha products for a batch of timesteps.

    Args:
        beta: 1-D tensor of per-step betas (it is concatenated along dim 0 with
            a length-1 tensor, so it must be 1-D).
        t: Integer index tensor of timesteps. A zero beta is prepended before
            the cumulative product, so ``t`` selects prod_{s <= t} (1 - beta[s])
            and ``t == -1`` selects exactly 1.

    Returns:
        Tensor of shape ``(len(t), 1, 1, 1)`` so it broadcasts against 4-D batches.
    """
    # Keep torch.zeros(1) (default dtype) so dtype promotion in cat is unchanged.
    padded = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
    alpha_bar = torch.cumprod(1 - padded, dim=0)
    picked = alpha_bar.index_select(0, t + 1)
    return picked.view(-1, 1, 1, 1)
def generalized_steps(x, seq, model, b, **kwargs):
    """Run the eta-controlled (DDIM-style) generalized reverse sampling loop.

    Args:
        x: Starting sample. Its batch size is ``x.size(0)``, and its device is
            where timestep tensors are built and where every model call runs.
        seq: Sequence of integer timesteps. It is visited from last to first;
            each step goes from ``seq[k]`` to ``seq[k - 1]``, and the step from
            ``seq[0]`` goes to -1 (whose cumulative alpha is 1).
        model: Callable ``model(x_t, t)``. Its output is used as the noise
            estimate ``eps`` when recovering x_0.
        b: 1-D tensor of betas, looked up through ``compute_alpha``.
        **kwargs: ``eta`` (default 0) scales the freshly sampled noise. With
            ``eta == 0`` the update is deterministic.

    Returns:
        ``(xs, x0_preds)``. ``xs`` begins with the input ``x`` followed by one
        CPU tensor per step. ``x0_preds`` holds one CPU x_0 estimate per step.
    """
    # Run on the caller's device instead of a hard-coded 'cuda', so the loop
    # works on CPU-only hosts and stays consistent with where t/next_t live.
    device = x.device
    eta = kwargs.get("eta", 0)
    with torch.no_grad():
        n = x.size(0)
        seq_next = [-1] + list(seq[:-1])
        x0_preds = []
        xs = [x]
        for i, j in zip(reversed(seq), reversed(seq_next)):
            t = (torch.ones(n) * i).to(device)
            next_t = (torch.ones(n) * j).to(device)
            at = compute_alpha(b, t.long())
            at_next = compute_alpha(b, next_t.long())
            # Earlier iterates are stored on CPU; bring the latest one back.
            xt = xs[-1].to(device)
            et = model(xt, t)
            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
            x0_preds.append(x0_t.to('cpu'))
            # c1: std of injected noise; c2: weight on the predicted noise.
            c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
            c2 = ((1 - at_next) - c1 ** 2).sqrt()
            xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
            xs.append(xt_next.to('cpu'))
    return xs, x0_preds
def ddpm_steps(x, seq, model, b, **kwargs):
    """Run the ancestral (DDPM-style) reverse sampling loop over ``seq``.

    Args:
        x: Starting sample. Its batch size is ``x.size(0)``, and its device is
            where timestep tensors are built and where every model call runs.
        seq: Sequence of integer timesteps. It is visited from last to first;
            each step goes from ``seq[k]`` to ``seq[k - 1]``, and the step from
            ``seq[0]`` goes to -1 (whose cumulative alpha is 1).
        model: Callable ``model(x_t, t)``. Its output is used as the noise
            estimate ``e`` when recovering x_0.
        b: 1-D tensor of betas, looked up through ``compute_alpha``.
        **kwargs: Accepted for signature parity with ``generalized_steps``;
            not used here.

    Returns:
        ``(xs, x0_preds)``. ``xs`` begins with the input ``x`` followed by one
        CPU tensor per step. ``x0_preds`` holds one clamped CPU x_0 estimate
        per step.
    """
    # Capture the caller's device once. The original moved iterates to a
    # hard-coded 'cuda' and rebound ``x`` inside the loop, which broke CPU runs.
    device = x.device
    with torch.no_grad():
        n = x.size(0)
        seq_next = [-1] + list(seq[:-1])
        xs = [x]
        x0_preds = []
        for i, j in zip(reversed(seq), reversed(seq_next)):
            t = (torch.ones(n) * i).to(device)
            next_t = (torch.ones(n) * j).to(device)
            at = compute_alpha(b, t.long())
            atm1 = compute_alpha(b, next_t.long())
            # Effective beta for the (possibly multi-step) jump t -> next_t.
            beta_t = 1 - at / atm1
            xt = xs[-1].to(device)
            e = model(xt, t.float())
            x0_from_e = (1.0 / at).sqrt() * xt - (1.0 / at - 1).sqrt() * e
            x0_from_e = torch.clamp(x0_from_e, -1, 1)
            x0_preds.append(x0_from_e.to('cpu'))
            # Posterior mean of x_{next} given x_t and the clamped x_0 estimate.
            mean = (
                (atm1.sqrt() * beta_t) * x0_from_e
                + ((1 - beta_t).sqrt() * (1 - atm1)) * xt
            ) / (1.0 - at)
            noise = torch.randn_like(xt)
            # Suppress noise for samples whose current timestep is 0.
            mask = (1 - (t == 0).float()).view(-1, 1, 1, 1)
            logvar = beta_t.log()
            # exp(0.5 * log beta_t) == sqrt(beta_t): std of the added noise.
            sample = mean + mask * torch.exp(0.5 * logvar) * noise
            xs.append(sample.to('cpu'))
    return xs, x0_preds