-
Notifications
You must be signed in to change notification settings - Fork 0
/
sde_lib.py
119 lines (97 loc) · 3.1 KB
/
sde_lib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
SDE library from: https://github.com/yang-song/score_sde_pytorch
"""
import abc
import torch
import numpy as np
def scan(f, init, xs):
    """Fold ``f`` over ``xs`` left to right, threading a carry and collecting outputs.

    Follows the ``(carry, x) -> (carry, y)`` calling convention of
    ``jax.lax.scan``, implemented as a plain Python loop.

    Args:
        f: callable taking ``(carry, x)`` and returning ``(new_carry, y)``.
        init: initial carry value.
        xs: iterable of per-step inputs.

    Returns:
        ``(final_carry, ys)`` where ``ys`` is ``torch.stack`` of every ``y``
        in order. ``torch.stack`` raises on an empty list, so ``xs`` must be
        non-empty.
    """
    state = init
    outputs = []
    for item in xs:
        state, out = f(state, item)
        outputs.append(out)
    return state, torch.stack(outputs)
class SDE(abc.ABC):
    """Abstract base class for an SDE discretized into ``N`` time steps.

    Subclasses provide the time horizon ``T``, the continuous-time
    coefficients ``sde``, ``marginal_prob``, and the prior used by
    ``prior_sampling`` / ``prior_logp``.
    """

    def __init__(self, N):
        """Args:
            N: number of discretization time steps.
        """
        super().__init__()
        self.N = N

    @property
    @abc.abstractmethod
    def T(self):
        """Length of the time interval the SDE is defined on."""

    @abc.abstractmethod
    def sde(self, x, t):
        """Return ``(drift, diffusion)`` of the SDE at state ``x`` and time ``t``."""

    @abc.abstractmethod
    def marginal_prob(self, x, t):
        """Return parameters of the marginal distribution at time ``t`` given ``x``."""

    @abc.abstractmethod
    def prior_sampling(self, shape):
        """Draw a sample of the given ``shape`` from the prior."""

    @abc.abstractmethod
    def prior_logp(self, z):
        """Return the log-density of ``z`` under the prior."""

    def discretize(self, x, t):
        """Euler-Maruyama style discretization of ``sde``.

        Returns:
            ``(f, G)`` where ``f = drift * dt`` and ``G = diffusion * sqrt(dt)``
            with ``dt = 1 / N``.
        """
        # NOTE(review): dt is 1 / N rather than T / N; the two agree only when T == 1.
        dt = 1 / self.N
        drift, diffusion = self.sde(x, t)
        f = drift * dt
        G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))
        return f, G

    def reverse(self, score_fn, probability_flow=False):
        """Build the reverse-time SDE (or probability-flow ODE) of this SDE.

        Args:
            score_fn: callable ``(x, t) -> score`` whose result must broadcast
                against the drift returned by ``sde`` (presumably shaped like ``x``).
            probability_flow: if True, halve the score term and drop the
                diffusion, yielding the deterministic probability-flow ODE.

        Returns:
            An instance of a dynamically created subclass of ``type(self)``
            whose ``sde`` and ``discretize`` return the reverse-time coefficients.
        """
        N = self.N
        T = self.T
        sde_fn = self.sde
        discretize_fn = self.discretize

        def expand_like(coef, x):
            # Reshape a per-sample coefficient of shape (batch,) to (batch, 1, ..., 1)
            # so it broadcasts against x of any rank (equals coef[:, None] for 2-D x).
            return coef.reshape(coef.shape[0], *([1] * (x.dim() - 1)))

        class RSDE(self.__class__):
            def __init__(self):
                # The parent __init__ is not called, so parent-specific attributes are
                # not set here; forward coefficients come from the captured bound methods.
                self.N = N
                self.probability_flow = probability_flow

            @property
            def T(self):
                return T

            def sde(self, x, t):
                drift, diffusion = sde_fn(x, t)
                score = score_fn(x, t)
                weight = 0.5 if self.probability_flow else 1.
                drift = drift - expand_like(diffusion, x) ** 2 * score * weight
                # The probability-flow ODE has no stochastic term.
                diffusion = 0. if self.probability_flow else diffusion
                return drift, diffusion

            def discretize(self, x, t):
                f, G = discretize_fn(x, t)
                weight = 0.5 if self.probability_flow else 1.
                rev_f = f - expand_like(G, x) ** 2 * score_fn(x, t) * weight
                rev_G = torch.zeros_like(G) if self.probability_flow else G
                return rev_f, rev_G

        return RSDE()
class VESDE(SDE):
    """Variance-exploding SDE with zero drift.

    The noise scale grows geometrically in time,
    ``sigma(t) = sigma_min * (sigma_max / sigma_min) ** t``, and the diffusion
    coefficient is ``sigma(t) * sqrt(2 * log(sigma_max / sigma_min))``.
    """

    def __init__(self, sigma_min=0.01, sigma_max=50, N=1000):
        """Args:
            sigma_min: smallest noise scale.
            sigma_max: largest noise scale.
            N: number of discretization steps (also the size of the sigma grid).
        """
        super().__init__(N)  # sets self.N
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        # N noise scales spaced evenly in log-space from sigma_min to sigma_max.
        self.discrete_sigmas = torch.exp(
            torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N)
        )

    @property
    def T(self):
        return 1

    def sde(self, x, t):
        """Return ``(drift, diffusion)``: zero drift and ``sigma(t) * sqrt(2 log(sigma_max / sigma_min))``."""
        sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        drift = torch.zeros_like(x)
        diffusion = sigma * torch.sqrt(
            torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)), device=t.device)
        )
        return drift, diffusion

    def marginal_prob(self, x, t):
        """Return ``(mean, std)`` at time ``t``: mean is ``x`` unchanged, std is ``sigma(t)``."""
        std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        mean = x
        return mean, std

    def prior_sampling(self, shape):
        """Sample from ``N(0, sigma_max**2 I)`` on torch's default device."""
        return torch.randn(*shape) * self.sigma_max

    def prior_logp(self, z):
        """Per-sample log-density of ``z`` under ``N(0, sigma_max**2 I)``.

        The first dimension of ``z`` is the batch; every remaining dimension
        is treated as part of the event, so any rank >= 2 is accepted.
        """
        N = np.prod(z.shape[1:])
        # Flatten non-batch dims so the squared norm matches the element count N
        # (for 2-D z this is the same as summing over dim=1).
        sq_norm = torch.sum(z.flatten(start_dim=1) ** 2, dim=1)
        return -N / 2. * np.log(2 * np.pi * self.sigma_max ** 2) - sq_norm / (2 * self.sigma_max ** 2)

    def discretize(self, x, t):
        """Discretize over the precomputed sigma grid.

        Returns:
            ``(f, G)`` with zero ``f`` and ``G = sqrt(sigma_i**2 - sigma_{i-1}**2)``,
            where ``sigma_{-1}`` is taken as 0.
        """
        timestep = (t * (self.N - 1) / self.T).long()
        # Move the grid once so both lookups index a tensor on the same device as
        # `timestep`; indexing a CPU tensor with non-CPU indices raises.
        sigmas = self.discrete_sigmas.to(t.device)
        sigma = sigmas[timestep]
        # For timestep == 0, index -1 wraps to the last sigma, but torch.where masks it out.
        adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), sigmas[timestep - 1])
        f = torch.zeros_like(x)
        G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
        return f, G