# MIT License: Copyright (c) 2021 Lorenzo Loconte, Gennaro Gala

from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch import distributions

from deeprob.torch.base import ProbabilisticModel, DensityEstimator
from deeprob.flows.utils import DequantizeLayer, LogitLayer


class NormalizingFlow(ProbabilisticModel):
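    # Whether the model supports sampling by reparametrization (see rsample below)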
    has_rsample = True

    def __init__(
        self,
        in_features: Union[int, Tuple[int, int, int]],
        dequantize: bool = False,
        logit: Optional[float] = None,
        in_base: Optional[DensityEstimator] = None
    ):
        """
        Initialize an abstract Normalizing Flow model.

        :param in_features: The input size, either an int or a (C, H, W) tuple.
        :param dequantize: Whether to apply the dequantization transformation.
        :param logit: The logit factor to use. Use None to disable the logit transformation.
        :param in_base: The input base distribution to use. If None, the standard Normal distribution is used.
        :raises ValueError: If the number of input features is invalid.
        :raises ValueError: If the logit value is invalid.
        """
        if isinstance(in_features, torch.Size):
            in_features = tuple(in_features)
            if len(in_features) == 1:
                in_features = in_features[0]
        if not isinstance(in_features, int):
            if not isinstance(in_features, tuple) or len(in_features) != 3:
                raise ValueError("The number of input features must be either an int or a (C, H, W) tuple")
        super().__init__()
        self.in_features = in_features

        # Build the dequantization layer
        if dequantize:
            self.dequantize = DequantizeLayer(in_features)
        else:
            self.dequantize = None

        # Build the logit layer
        if logit is not None:
            if logit <= 0.0 or logit >= 1.0:
                raise ValueError("The logit factor must be in (0, 1)")
            self.logit = LogitLayer(in_features, alpha=logit)
        else:
            self.logit = None

        # Build the base distribution, if necessary
        if in_base is None:
            self.in_base_loc = nn.Parameter(torch.zeros(in_features), requires_grad=False)
            self.in_base_scale = nn.Parameter(torch.ones(in_features), requires_grad=False)
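            # Note: loc and scale are registered as non-trainable nn.Parameter objects,
            # so they follow the model across Module.to(device) and state_dict calls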
            self.in_base = distributions.Normal(self.in_base_loc, self.in_base_scale)
        else:
            self.in_base = in_base

        # Initialize the normalizing flow layers
        self.layers = nn.ModuleList()

    def train(self, mode: bool = True, base_mode: bool = True):
        """
        Set the training mode.

        :param mode: The training mode for the flow layers.
        :param base_mode: The training mode for the in_base distribution.
        :return: Itself.
        """
        self.training = mode
        self.layers.train(mode)
        if isinstance(self.in_base, torch.nn.Module):
            self.in_base.train(base_mode)
        return self

    def eval(self):
        """
        Turn off the training mode for both the flow layers and the in_base distribution.

        :return: Itself.
        """
        return self.train(False, False)

    def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Preprocess the data batch before feeding it to the probabilistic model (forward mode).

        :param x: The input data batch.
        :return: The preprocessed data batch and the inverse log-det-jacobian.
        """
        inv_log_det_jacobian = 0.0
        if self.dequantize is not None:
            x, ildj = self.dequantize.apply_backward(x)
            inv_log_det_jacobian += ildj
        if self.logit is not None:
            x, ildj = self.logit.apply_backward(x)
            inv_log_det_jacobian += ildj
        return x, inv_log_det_jacobian

    def unpreprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Invert the preprocessing of the data batch (backward mode).

        :param x: The input data batch.
        :return: The un-preprocessed data batch and the log-det-jacobian.
        """
        log_det_jacobian = 0.0
        if self.logit is not None:
            x, ldj = self.logit.apply_forward(x)
            log_det_jacobian += ldj
        if self.dequantize is not None:
            x, ldj = self.dequantize.apply_forward(x)
            log_det_jacobian += ldj
        return x, log_det_jacobian

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the log-likelihood given complete evidence.

        :param x: The inputs.
        :return: The log-likelihoods.
        """
        # Preprocess the samples
        batch_size = x.shape[0]
        x, inv_log_det_jacobian = self.preprocess(x)

        # Apply backward transformations
        x, ildj = self.apply_backward(x)
        inv_log_det_jacobian += ildj
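
        # Change of variables: log p(x) = log p_Z(z) + log|det dz/dx|, where z is the
        # base-space image of x and the second term is the accumulated inverse log-det-jacobian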
        # Compute the prior log-likelihood
        base_lls = self.in_base.log_prob(x)
        prior = torch.sum(base_lls.view(batch_size, -1), dim=1)

        # Return the final log-likelihood
        return prior + inv_log_det_jacobian

    @torch.no_grad()
    def sample(self, n_samples: int, y: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Sample some values from the modeled distribution.

        :param n_samples: The number of samples.
        :param y: The sample labels. It can be None.
        :return: The samples.
        """
        # Sample from the base distribution
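        # Note: torch.distributions.Distribution.sample takes a sample shape rather
        # than an integer, hence the number of samples is wrapped into a list below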
        if isinstance(self.in_base, distributions.Distribution):
            n_samples = [n_samples]
        x = self.in_base.sample(n_samples)

        # Apply forward transformations
        x, _ = self.apply_forward(x)

        # Apply reversed preprocessing transformation
        x, _ = self.unpreprocess(x)
        return x

    def rsample(self, n_samples: int, y: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Sample some values from the modeled distribution by reparametrization.

        Unlike :func:`NormalizingFlow.sample`, this method allows backpropagation.

        :param n_samples: The number of samples.
        :param y: The sample labels. It can be None.
        :return: The samples.
        :raises NotImplementedError: If the base distribution does not support reparametrized sampling.
        """
        # Sample from the base distribution, which must provide an rsample method
        if not self.in_base.has_rsample:
            raise NotImplementedError("The base distribution must support reparametrized sampling")
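
        # The reparametrization trick expresses each sample as a differentiable
        # function of the distribution parameters and auxiliary noise, so gradients
        # can propagate back through the sampling step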
        if isinstance(self.in_base, distributions.Distribution):
            n_samples = [n_samples]
        x = self.in_base.rsample(n_samples)

        # Apply forward transformations
        x, _ = self.apply_forward(x)

        # Apply reversed preprocessing transformation
        x, _ = self.unpreprocess(x)
        return x

    def apply_backward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply the backward transformation, i.e. from the data space to the base space.

        :param x: The inputs.
        :return: The transformed samples and the inverse log-det-jacobian.
        """
        inv_log_det_jacobian = 0.0
        for layer in self.layers:
            x, ildj = layer.apply_backward(x)
            inv_log_det_jacobian += ildj
        return x, inv_log_det_jacobian

    def apply_forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply the forward transformation, i.e. from the base space to the data space.

        :param x: The inputs.
        :return: The transformed samples and the forward log-det-jacobian.
        """
        log_det_jacobian = 0.0
        for layer in reversed(self.layers):
            x, ldj = layer.apply_forward(x)
            log_det_jacobian += ldj
        return x, log_det_jacobian

    def loss(self, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute the loss given the log-likelihoods produced by the model.

        :param x: The model outputs, i.e. the log-likelihoods.
        :param y: The sample labels. It can be None.
        :return: The average negative log-likelihood.
        """
        # Compute the loss as the average negative log-likelihood
        return -torch.mean(x)
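

# ---------------------------------------------------------------------------
# A minimal usage sketch, not part of the original module. It assumes that
# ProbabilisticModel requires no abstract methods beyond the ones implemented
# above, and that flow layers only need apply_backward/apply_forward methods
# returning a (tensor, log-det-jacobian) pair, as this class uses them.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class AffineLayer(nn.Module):
        """A toy element-wise affine flow layer: z = x * exp(s) + t."""

        def __init__(self, in_features: int):
            super().__init__()
            self.s = nn.Parameter(torch.zeros(in_features))
            self.t = nn.Parameter(torch.zeros(in_features))

        def apply_backward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            # Data-to-base direction: z = x * exp(s) + t, with log|det dz/dx| = sum(s)
            z = x * torch.exp(self.s) + self.t
            return z, torch.sum(self.s).expand(x.shape[0])

        def apply_forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            # Base-to-data direction: x = (z - t) * exp(-s), with log|det dx/dz| = -sum(s)
            x = (z - self.t) * torch.exp(-self.s)
            return x, -torch.sum(self.s).expand(z.shape[0])

    class ToyFlow(NormalizingFlow):
        """A single affine layer on top of the standard Normal base distribution."""

        def __init__(self, in_features: int):
            super().__init__(in_features)
            self.layers.append(AffineLayer(in_features))

    model = ToyFlow(2)
    lls = model(torch.randn(64, 2))  # Log-likelihoods, shape (64,)
    loss = model.loss(lls)           # Average negative log-likelihood
    loss.backward()                  # Gradients reach the affine parameters
    samples = model.sample(16)       # New samples, shape (16, 2)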