Skip to content

Commit

Permalink
🚧 Start major refresh
Browse files Browse the repository at this point in the history
🔥 Drop complex module in favor of torch.complex

⚡️ Replace _jit by torch.jit.script_if_tracing

✨ Simplify LPIPS

πŸ“ Add a docstring to forward methods

πŸ“ Improve the format of references
  • Loading branch information
francois-rozet committed Apr 2, 2023
1 parent ded8a27 commit d233ef8
Show file tree
Hide file tree
Showing 13 changed files with 753 additions and 1,013 deletions.
163 changes: 79 additions & 84 deletions piqa/fsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
https://www4.comp.polyu.edu.hk/~cslzhang/IQA/FSIM/FSIM.htm
References:
.. [Zhang2011] FSIM: A Feature Similarity Index for Image Quality Assessment (Zhang et al., 2011)
| FSIM: A Feature Similarity Index for Image Quality Assessment (Zhang et al., 2011)
| https://ieeexplore.ieee.org/document/5705575
.. [Kovesi1999] Image Features From Phase Congruency (Kovesi, 1999)
| Image Features From Phase Congruency (Kovesi, 1999)
"""

import math
Expand All @@ -19,8 +20,7 @@

from torch import Tensor

from .utils import _jit, assert_type, reduce_tensor
from .utils import complex as cx
from .utils import assert_type
from .utils.color import ColorConv
from .utils.functional import (
scharr_kernel,
Expand All @@ -29,36 +29,38 @@
log_gabor,
channel_conv,
l2_norm,
downsample,
reduce_tensor,
)


@_jit
@torch.jit.script_if_tracing
def fsim(
x: Tensor,
y: Tensor,
pc_x: Tensor,
pc_y: Tensor,
kernel: Tensor,
value_range: float = 1.,
value_range: float = 1.0,
t1: float = 0.85,
t2: float = 160. / (255. ** 2),
t3: float = 200. / (255. ** 2),
t4: float = 200. / (255. ** 2),
t2: float = 160 / 255 ** 2,
t3: float = 200 / 255 ** 2,
t4: float = 200 / 255 ** 2,
lmbda: float = 0.03,
) -> Tensor:
r"""Returns the FSIM between :math:`x` and :math:`y`,
without color space conversion and downsampling.
r"""Returns the FSIM between :math:`x` and :math:`y`, without color space
conversion and downsampling.
Args:
x: An input tensor, :math:`(N, 3 \text{ or } 1, H, W)`.
y: A target tensor, :math:`(N, 3 \text{ or } 1, H, W)`.
pc_x: The input phase congruency, :math:`(N, H, W)`.
pc_y: The target phase congruency, :math:`(N, H, W)`.
kernel: A gradient kernel, :math:`(2, 1, K, K)`.
value_range: The value range :math:`L` of the inputs (usually `1.` or `255`).
value_range: The value range :math:`L` of the inputs (usually 1 or 255).
Note:
For the remaining arguments, refer to [Zhang2011]_.
For the remaining arguments, refer to Zhang et al. (2011).
Returns:
The FSIM vector, :math:`(N,)`.
Expand All @@ -71,7 +73,7 @@ def fsim(
>>> pc_y = phase_congruency(y[:, :1], filters)
>>> kernel = gradient_kernel(scharr_kernel())
>>> l = fsim(x, y, pc_x, pc_y, kernel)
>>> l.size()
>>> l.shape
torch.Size([5])
"""

Expand All @@ -86,26 +88,26 @@ def fsim(
s_pc = (2 * pc_x * pc_y + t1) / (pc_x ** 2 + pc_y ** 2 + t1)

# Gradient magnitude similarity
pad = kernel.size(-1) // 2
pad = kernel.shape[-1] // 2

g_x = l2_norm(channel_conv(y_x, kernel, padding=pad), dims=[1])
g_y = l2_norm(channel_conv(y_y, kernel, padding=pad), dims=[1])
g_x = l2_norm(channel_conv(y_x, kernel, padding=pad), dim=1)
g_y = l2_norm(channel_conv(y_y, kernel, padding=pad), dim=1)

s_g = (2 * g_x * g_y + t2) / (g_x ** 2 + g_y ** 2 + t2)

# Chrominance similarity
s_l = s_pc * s_g

if x.size(1) == 3:
if x.shape[1] == 3:
i_x, i_y = x[:, 1], y[:, 1]
q_x, q_y = x[:, 2], y[:, 2]

s_i = (2 * i_x * i_y + t3) / (i_x ** 2 + i_y ** 2 + t3)
s_q = (2 * q_x * q_y + t4) / (q_x ** 2 + q_y ** 2 + t4)

s_iq = s_i * s_q
s_iq = cx.complx(s_iq, torch.zeros_like(s_iq))
s_iq_lambda = cx.real(cx.pow(s_iq, lmbda))
s_iq = torch.complex(s_iq, torch.zeros_like(s_iq))
s_iq_lambda = (s_iq ** lmbda).real

s_l = s_l * s_iq_lambda

Expand All @@ -115,13 +117,13 @@ def fsim(
return fs


@_jit
@torch.jit.script_if_tracing
def pc_filters(
x: Tensor,
scales: int = 4,
orientations: int = 4,
wavelength: float = 6.,
factor: float = 2.,
wavelength: float = 6.0,
factor: float = 2.0,
sigma_f: float = 0.5978, # -log(0.55)
sigma_theta: float = 0.6545, # pi / (4 * 1.2)
) -> Tensor:
Expand All @@ -133,7 +135,7 @@ def pc_filters(
orientations: The number of orientations, :math:`S_2`.
Note:
For the remaining arguments, refer to [Kovesi1999]_.
For the remaining arguments, refer to Kovesi (1999).
Returns:
The filters tensor, :math:`(S_1, S_2, H, W)`.
Expand Down Expand Up @@ -175,12 +177,12 @@ def pc_filters(
return filters


@_jit
@torch.jit.script_if_tracing
def phase_congruency(
x: Tensor,
filters: Tensor,
value_range: float = 1.,
k: float = 2.,
value_range: float = 1.0,
k: float = 2.0,
rescale: float = 1.7,
eps: float = 1e-8,
) -> Tensor:
Expand All @@ -189,10 +191,10 @@ def phase_congruency(
Args:
x: An input tensor, :math:`(N, 1, H, W)`.
filters: The frequency domain filters, :math:`(S_1, S_2, H, W)`.
value_range: The value range :math:`L` of the input (usually `1.` or `255`).
value_range: The value range :math:`L` of the input (usually 1 or 255).
Note:
For the remaining arguments, refer to [Kovesi1999]_.
For the remaining arguments, refer to Kovesi (1999).
Returns:
The PC tensor, :math:`(N, H, W)`.
Expand All @@ -201,49 +203,47 @@ def phase_congruency(
>>> x = torch.rand(5, 1, 256, 256)
>>> filters = pc_filters(x)
>>> pc = phase_congruency(x, filters)
>>> pc.size()
>>> pc.shape
torch.Size([5, 256, 256])
"""

x = x * (255. / value_range)
x = x * (255 / value_range)

# Filters
M_hat = filters
M = fft.ifft2(M_hat)
M = cx.real(torch.view_as_real(M))
M = fft.ifft2(M_hat).real

# Even & odd (real and imaginary) responses
eo = fft.ifft2(fft.fft2(x[:, None]) * M_hat)
eo = torch.view_as_real(eo)

# Amplitude
A = cx.mod(eo)
eo = fft.ifft2(fft.fft2(x)[:, None] * M_hat)

# Expected E^2
A2 = A[:, 0] ** 2
median_A2, _ = A2.flatten(-2).median(dim=-1)
A = eo.abs()
A2 = A[:, 0].square()
median_A2 = A2.flatten(-2).median(dim=-1).values
expect_A2 = median_A2 / math.log(2)

expect_M2_hat = (M_hat[0] ** 2).mean(dim=(-1, -2))
expect_M2_hat = M_hat[0].square().mean(dim=(-1, -2))
expect_MiMj = (M[:, None] * M[None, :]).sum(dim=(0, 1, 3, 4))

expect_E2 = expect_A2 * expect_MiMj / expect_M2_hat

# Threshold
sigma_G = expect_E2.sqrt()
mu_R = sigma_G * (math.pi / 2) ** 0.5
sigma_R = sigma_G * (2 - math.pi / 2) ** 0.5
mu_R = sigma_G * math.sqrt(math.pi / 2)
sigma_R = sigma_G * math.sqrt(2 - math.pi / 2)

T = mu_R + k * sigma_R
T = T / rescale # emprirical rescaling
T = T / rescale # empirical rescaling
T = T[..., None, None]

# Phase deviation
FH = eo.sum(dim=1, keepdim=True)
phi_eo = FH / (cx.mod(FH)[..., None] + eps)
fh = eo.sum(dim=1, keepdim=True)
fh = fh / (fh.abs() + eps)

dot = eo.real * fh.real + eo.imag * fh.imag
cross = eo.real * fh.imag - eo.imag * fh.real

E = cx.dot(eo, phi_eo) - cx.dot(eo, cx.turn(phi_eo)).abs()
E = E.sum(dim=1)
E = (dot - cross.abs()).sum(dim=1)

# Phase congruency
pc = (E - T).relu().sum(dim=1) / (A.sum(dim=(1, 2)) + eps)
Expand All @@ -252,34 +252,26 @@ def phase_congruency(


class FSIM(nn.Module):
r"""Creates a criterion that measures the FSIM
between an input and a target.
r"""Measures the FSIM between an input and a target.
Before applying :func:`fsim`, the input and target are converted from
RBG to Y(IQ) and downsampled by a factor :math:`\frac{\min(H, W)}{256}`.
Before applying :func:`fsim`, the input and target are converted from RBG to Y(IQ)
and downsampled to a 256-ish resolution.
Args:
chromatic: Whether to use the chromatic channels (IQ) or not.
downsample: Whether downsampling is enabled or not.
kernel: A gradient kernel, :math:`(2, 1, K, K)`.
If `None`, use the Scharr kernel instead.
If :py:`None`, use the Scharr kernel instead.
reduction: Specifies the reduction to apply to the output:
`'none'` | `'mean'` | `'sum'`.
Note:
`**kwargs` are passed to :func:`fsim`.
Shapes:
input: :math:`(N, 3, H, W)`
target: :math:`(N, 3, H, W)`
output: :math:`(N,)` or :math:`()` depending on `reduction`
`'none'`, `'mean'` or `'sum'`.
kwargs: Keyword arguments passed to :func:`fsim`.
Example:
>>> criterion = FSIM().cuda()
>>> x = torch.rand(5, 3, 256, 256, requires_grad=True).cuda()
>>> y = torch.rand(5, 3, 256, 256).cuda()
>>> criterion = FSIM()
>>> x = torch.rand(5, 3, 256, 256, requires_grad=True)
>>> y = torch.rand(5, 3, 256, 256)
>>> l = 1 - criterion(x, y)
>>> l.size()
>>> l.shape
torch.Size([])
>>> l.backward()
"""
Expand All @@ -298,44 +290,47 @@ def __init__(
kernel = gradient_kernel(scharr_kernel())

self.register_buffer('kernel', kernel)
self.register_buffer('filters', torch.zeros((0, 0, 0, 0)))

self.convert = ColorConv('RGB', 'YIQ' if chromatic else 'Y')
self.downsample = downsample
self.reduction = reduction
self.value_range = kwargs.get('value_range', 1.)
self.value_range = kwargs.get('value_range', 1.0)
self.kwargs = kwargs

def forward(self, input: Tensor, target: Tensor) -> Tensor:
def forward(self, x: Tensor, y: Tensor) -> Tensor:
r"""
Args:
x: An input tensor, :math:`(N, 3, H, W)`.
y: A target tensor, :math:`(N, 3, H, W)`.
Returns:
The FSIM vector, :math:`(N,)` or :math:`()` depending on `reduction`.
"""

assert_type(
input, target,
x, y,
device=self.kernel.device,
dim_range=(4, 4),
n_channels=3,
value_range=(0., self.value_range),
value_range=(0.0, self.value_range),
)

# Downsample
if self.downsample:
_, _, h, w = input.size()
M = round(min(h, w) / 256)

if M > 1:
input = F.avg_pool2d(input, kernel_size=M, ceil_mode=True)
target = F.avg_pool2d(target, kernel_size=M, ceil_mode=True)
x = downsample(x, 256)
y = downsample(y, 256)

# RGB to Y(IQ)
input = self.convert(input)
target = self.convert(target)
x = self.convert(x)
y = self.convert(y)

# Phase congruency
if self.filters.shape[-2:] != input.shape[-2:]:
self.filters = pc_filters(input)
filters = pc_filters(x)

pc_input = phase_congruency(input[:, :1], self.filters, self.value_range)
pc_target = phase_congruency(target[:, :1], self.filters, self.value_range)
pc_x = phase_congruency(x[:, :1], filters, self.value_range)
pc_y = phase_congruency(y[:, :1], filters, self.value_range)

# FSIM
l = fsim(input, target, pc_input, pc_target, kernel=self.kernel, **self.kwargs)
l = fsim(x, y, pc_x, pc_y, kernel=self.kernel, **self.kwargs)

return reduce_tensor(l, self.reduction)

0 comments on commit d233ef8

Please sign in to comment.