-
Notifications
You must be signed in to change notification settings - Fork 18
/
fid.py
219 lines (167 loc) · 5.77 KB
/
fid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
r"""Fréchet Inception Distance (FID)
This module implements the FID in PyTorch.
Original:
https://github.com/bioinf-jku/TTUR
Wikipedia:
https://wikipedia.org/wiki/Frechet_inception_distance
References:
| GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium (Heusel et al., 2017)
| https://arxiv.org/abs/1706.08500
"""
import torch
import torch.nn as nn
import torchvision
from torch import Tensor
from typing import *
from .utils import assert_type
from .utils.color import ImageNetNorm
@torch.jit.script_if_tracing
def sqrtm(sigma: Tensor) -> Tensor:
    r"""Returns the square root of a positive semi-definite matrix.

    .. math:: \sqrt{\Sigma} = Q \sqrt{\Lambda} Q^T

    where :math:`Q \Lambda Q^T` is the eigendecomposition of :math:`\Sigma`.

    Args:
        sigma: A positive semi-definite matrix, :math:`(*, D, D)`.

    Example:
        >>> V = torch.randn(4, 4, dtype=torch.double)
        >>> A = V @ V.T
        >>> B = sqrtm(A @ A)
        >>> torch.allclose(A, B)
        True
    """

    eigvals, eigvecs = torch.linalg.eigh(sigma)

    # Numerical noise can produce tiny negative eigenvalues even for a PSD
    # input; clamp them to zero before taking the root.
    roots = eigvals.clamp(min=0.0).sqrt()

    # Reconstruct Q sqrt(Λ) Q^T by scaling the rows of Q^T.
    return eigvecs @ (roots.unsqueeze(-1) * eigvecs.mT)
@torch.jit.script_if_tracing
def frechet_distance(
    mu_x: Tensor,
    sigma_x: Tensor,
    mu_y: Tensor,
    sigma_y: Tensor,
) -> Tensor:
    r"""Returns the Fréchet distance between two multivariate Gaussian distributions.

    .. math:: d^2 = \left\| \mu_x - \mu_y \right\|_2^2 +
        \operatorname{tr} \left( \Sigma_x + \Sigma_y - 2 \sqrt{\Sigma_y^{\frac{1}{2}} \Sigma_x \Sigma_y^{\frac{1}{2}}} \right)

    Wikipedia:
        https://wikipedia.org/wiki/Frechet_distance

    Args:
        mu_x: The mean :math:`\mu_x` of the first distribution, :math:`(*, D)`.
        sigma_x: The covariance :math:`\Sigma_x` of the first distribution, :math:`(*, D, D)`.
        mu_y: The mean :math:`\mu_y` of the second distribution, :math:`(*, D)`.
        sigma_y: The covariance :math:`\Sigma_y` of the second distribution, :math:`(*, D, D)`.

    Returns:
        The squared Fréchet distance, :math:`(*,)`.

    Example:
        >>> mu_x = torch.arange(3).float()
        >>> sigma_x = torch.eye(3)
        >>> mu_y = 2 * mu_x + 1
        >>> sigma_y = 2 * sigma_x + 1
        >>> frechet_distance(mu_x, sigma_x, mu_y, sigma_y)
        tensor(15.8710)
    """

    sigma_y_12 = sqrtm(sigma_y)

    a = (mu_x - mu_y).square().sum(dim=-1)
    # Tensor.trace() only supports 2-D matrices, which broke the documented
    # batched (*, D, D) signature. A batched trace is the sum of the diagonal
    # over the last two dimensions; for 2-D inputs it is identical to trace().
    b = torch.diagonal(sigma_x + sigma_y, dim1=-2, dim2=-1).sum(dim=-1)
    c = torch.diagonal(
        sqrtm(sigma_y_12 @ sigma_x @ sigma_y_12),
        dim1=-2,
        dim2=-1,
    ).sum(dim=-1)

    return a + b - 2 * c
class InceptionV3(nn.Sequential):
    r"""Pretrained Inception-v3 network.

    References:
        | Rethinking the Inception Architecture for Computer Vision (Szegedy et al., 2015)
        | https://arxiv.org/abs/1512.00567

    Args:
        logits: Whether to return the class logits or the last pooling features.

    Example:
        >>> x = torch.randn(5, 3, 256, 256)
        >>> inception = InceptionV3()
        >>> logits = inception(x)
        >>> logits.shape
        torch.Size([5, 1000])
    """

    def __init__(self, logits: bool = True):
        net = torchvision.models.inception_v3(weights='DEFAULT')

        # Every feature-extraction sub-module of inception_v3, in forward
        # order, up to (and including) the final average pooling.
        names = [
            'Conv2d_1a_3x3',
            'Conv2d_2a_3x3',
            'Conv2d_2b_3x3',
            'maxpool1',
            'Conv2d_3b_1x1',
            'Conv2d_4a_3x3',
            'maxpool2',
            'Mixed_5b',
            'Mixed_5c',
            'Mixed_5d',
            'Mixed_6a',
            'Mixed_6b',
            'Mixed_6c',
            'Mixed_6d',
            'Mixed_6e',
            'Mixed_7a',
            'Mixed_7b',
            'Mixed_7c',
            'avgpool',
        ]

        layers = [getattr(net, name) for name in names]
        # Flatten the pooled (N, 2048, 1, 1) maps into (N, 2048) vectors.
        layers.append(nn.Flatten(-3))

        # Optionally keep the classification head to output (N, 1000) logits.
        if logits:
            layers.append(net.fc)

        super().__init__(*layers)
class FID(nn.Module):
    r"""Measures the FID between two set of inception features.

    Note:
        See :meth:`FID.features` for how to get inception features.

    Example:
        >>> criterion = FID()
        >>> x = torch.randn(1024, 256)
        >>> y = torch.randn(2048, 256)
        >>> l = criterion(x, y)
        >>> l.shape
        torch.Size([])
    """

    def __init__(self):
        super().__init__()

        # ImageNet normalization
        self.normalize = ImageNetNorm()

        # Inception-v3, truncated before the classification head
        self.inception = InceptionV3(logits=False)
        self.inception.eval()

        # Freeze every parameter: the metric is fixed, not trainable.
        for p in self.parameters():
            p.requires_grad = False

    def features(self, x: Tensor, no_grad: bool = True) -> Tensor:
        r"""Returns the inception features of an input.

        Tip:
            If you cannot get the inception features of your input at once, for instance
            because of memory limitations, you can split it in smaller batches and
            concatenate the outputs afterwards.

        Args:
            x: An input tensor, :math:`(N, 3, H, W)`.
            no_grad: Whether to disable gradients or not.

        Returns:
            The features, :math:`(N, 2048)`.
        """

        assert_type(
            x,
            device=self.normalize.shift.device,
            dim_range=(4, 4),
            n_channels=3,
            value_range=(0.0, 1.0),
        )

        # ImageNet normalization
        x = self.normalize(x)

        # Features
        if not no_grad:
            return self.inception(x)

        with torch.no_grad():
            return self.inception(x)

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        r"""
        Args:
            x: An input tensor, :math:`(M, D)`.
            y: A target tensor, :math:`(N, D)`.

        Returns:
            The FID, :math:`()`.
        """

        # Sample mean and covariance of each feature set
        mu_x = x.mean(dim=0)
        sigma_x = x.t().cov()
        mu_y = y.mean(dim=0)
        sigma_y = y.t().cov()

        # Fréchet distance between the two fitted Gaussians
        return frechet_distance(mu_x, sigma_x, mu_y, sigma_y)