-
Notifications
You must be signed in to change notification settings - Fork 18
/
haarpsi.py
189 lines (145 loc) · 5.05 KB
/
haarpsi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
r"""Haar Perceptual Similarity Index (HaarPSI)
This module implements the HaarPSI in PyTorch.
Original:
https://github.com/rgcda/haarpsi
Wikipedia:
https://wikipedia.org/wiki/Haar_wavelet
References:
| A Haar Wavelet-Based Perceptual Similarity Index for Image Quality Assessment (Reisenhofer et al., 2018)
| https://arxiv.org/abs/1607.06140
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from .utils import assert_type
from .utils.color import ColorConv
from .utils.functional import (
haar_kernel,
gradient_kernel,
channel_conv,
reduce_tensor,
)
@torch.jit.script_if_tracing
def haarpsi(
x: Tensor,
y: Tensor,
n_kernels: int = 3,
value_range: float = 1.0,
c: float = 30 / 255 ** 2,
alpha: float = 4.2,
) -> Tensor:
r"""Returns the HaarPSI between :math:`x` and :math:`y`, without color space
conversion.
Args:
x: An input tensor, :math:`(N, 3 \text{ or } 1, H, W)`.
y: A target tensor, :math:`(N, 3 \text{ or } 1, H, W)`.
n_kernels: The number of Haar wavelet kernels to use.
value_range: The value range :math:`L` of the inputs (usually 1 or 255).
Note:
For the remaining arguments, refer to Reisenhofer et al. (2018).
Returns:
The HaarPSI vector, :math:`(N,)`.
Example:
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> l = haarpsi(x, y)
>>> l.shape
torch.Size([5])
"""
c *= value_range ** 2
# Y
y_x, y_y = x[:, :1], y[:, :1]
## Gradient similarity(ies)
gs = []
for j in range(1, n_kernels + 1):
kernel_size = int(2 ** j)
### Haar wavelet kernel
kernel = gradient_kernel(haar_kernel(kernel_size)).to(x.device)
### Haar filter (gradient)
pad = kernel_size // 2
g_x = channel_conv(y_x, kernel, padding=pad)[..., 1:, 1:].abs()
g_y = channel_conv(y_y, kernel, padding=pad)[..., 1:, 1:].abs()
if j < n_kernels:
gs.append((2 * g_x * g_y + c) / (g_x ** 2 + g_y ** 2 + c))
else:
gs.append(g_x)
gs.append(g_y)
## Local similarity(ies)
ls = torch.stack(gs[:-2], dim=-1).mean(dim=-1) # (N, 2, H, W)
## Weight(s)
w = torch.maximum(gs[-2], gs[-1]) # (N, 2, H, W)
# IQ
if x.shape[1] == 3:
iq_x, iq_y = x[:, 1:], y[:, 1:]
## Mean filter
m_x = F.avg_pool2d(iq_x, 2, stride=1, padding=1)[..., 1:, 1:].abs()
m_y = F.avg_pool2d(iq_y, 2, stride=1, padding=1)[..., 1:, 1:].abs()
## Chromatic similarity(ies)
cs = (2 * m_x * m_y + c) / (m_x ** 2 + m_y ** 2 + c)
## Local similarity(ies)
ls = torch.cat((ls, cs.mean(dim=1, keepdim=True)), dim=1) # (N, 3, H, W)
## Weight(s)
w = torch.cat((w, w.mean(dim=1, keepdim=True)), dim=1) # (N, 3, H, W)
# HaarPSI
hs = torch.sigmoid(ls * alpha)
hpsi = (hs * w).sum(dim=(-1, -2, -3)) / w.sum(dim=(-1, -2, -3))
hpsi = (torch.logit(hpsi) / alpha) ** 2
return hpsi
class HaarPSI(nn.Module):
r"""Measures the HaarPSI between an input and a target.
Before applying :func:`haarpsi`, the input and target are converted from
RBG to Y(IQ) and downsampled by a factor 2.
Args:
chromatic: Whether to use the chromatic channels (IQ) or not.
downsample: Whether downsampling is enabled or not.
reduction: Specifies the reduction to apply to the output:
`'none'`, `'mean'` or `'sum'`.
kwargs: Keyword arguments passed to :func:`haarpsi`.
Example:
>>> criterion = HaarPSI()
>>> x = torch.rand(5, 3, 256, 256, requires_grad=True)
>>> y = torch.rand(5, 3, 256, 256)
>>> l = 1 - criterion(x, y)
>>> l.shape
torch.Size([])
>>> l.backward()
"""
def __init__(
self,
chromatic: bool = True,
downsample: bool = True,
reduction: str = 'mean',
**kwargs,
):
super().__init__()
self.convert = ColorConv('RGB', 'YIQ' if chromatic else 'Y')
self.downsample = downsample
self.reduction = reduction
self.value_range = kwargs.get('value_range', 1.0)
self.kwargs = kwargs
def forward(self, x: Tensor, y: Tensor) -> Tensor:
r"""
Args:
x: An input tensor, :math:`(N, 3, H, W)`.
y: A target tensor, :math:`(N, 3, H, W)`.
Returns:
The HaarPSI vector, :math:`(N,)` or :math:`()` depending on `reduction`.
"""
assert_type(
x, y,
device=self.convert.weight.device,
dim_range=(4, 4),
n_channels=3,
value_range=(0.0, self.value_range),
)
# Downsample
if self.downsample:
x = F.avg_pool2d(x, 2, ceil_mode=True)
y = F.avg_pool2d(y, 2, ceil_mode=True)
# RGB to Y(IQ)
x = self.convert(x)
y = self.convert(y)
# HaarPSI
l = haarpsi(x, y, **self.kwargs)
return reduce_tensor(l, self.reduction)