Skip to content

Commit

Permalink
⚡️ Refine cascading arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Dec 5, 2020
1 parent c682e57 commit 3d1fd4c
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 68 deletions.
13 changes: 9 additions & 4 deletions spiq/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,23 @@ class PSNR(nn.Module):
between an input and a target.
Args:
value_range: The value range of the inputs (usually 1. or 255).
reduction: A reduction type (`'mean'`, `'sum'` or `'none'`).
`**kwargs` are transmitted to `psnr`, with
the exception of `dim` and `keepdim`.
Call:
The input and target tensors should be of shape (N, ...).
"""

def __init__(self, value_range: float = 1., reduction: str = 'mean'):
def __init__(self, reduction: str = 'mean', **kwargs):
super().__init__()

self.value_range = value_range
self.reduce = build_reduce(reduction)
self.kwargs = {
k: v for k, v in kwargs.items()
if k not in ['dim', 'keepdim']
}

def forward(
self,
Expand All @@ -64,7 +69,7 @@ def forward(
input,
target,
dim=tuple(range(1, input.ndimension())),
value_range=self.value_range,
**self.kwargs,
)

return self.reduce(l)
83 changes: 43 additions & 40 deletions spiq/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@

from spiq.utils import build_reduce, gaussian_kernel

_SIGMA = 1.5
_K1, _K2 = 0.01, 0.03
_WEIGHTS = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333])


Expand All @@ -37,7 +35,7 @@ def create_window(window_size: int, n_channels: int) -> torch.Tensor:
n_channels: A number of channels.
"""

kernel = gaussian_kernel(window_size, _SIGMA)
kernel = gaussian_kernel(window_size, 1.5)

window = kernel.unsqueeze(0).unsqueeze(0)
window = window.expand(n_channels, 1, window_size, window_size)
Expand All @@ -50,80 +48,91 @@ def ssim_per_channel(
y: torch.Tensor,
window: torch.Tensor,
value_range: float = 1.,
k1: float = 0.01,
k2: float = 0.03,
) -> torch.Tensor:
r"""Returns the SSIM and the contrast sensitivity (CS)
per channel between `x` and `y`.
r"""Returns the SSIM and the contrast sensitivity per channel
between `x` and `y`.
Args:
x: An input tensor, (N, C, H, W).
y: A target tensor, (N, C, H, W).
window: A convolution window.
value_range: The value range of the inputs (usually 1. or 255).
For the remaining arguments, refer to [1].
"""

n_channels, _, window_size, _ = window.size()

c1 = (k1 * value_range) ** 2
c2 = (k2 * value_range) ** 2

# Mean (mu)
mu_x = F.conv2d(x, window, padding=0, groups=n_channels)
mu_y = F.conv2d(y, window, padding=0, groups=n_channels)

mu_x_sq = mu_x ** 2
mu_y_sq = mu_y ** 2
mu_xy = mu_x * mu_y

# Variance (sigma)
sigma_x_sq = F.conv2d(x ** 2, window, padding=0, groups=n_channels)
sigma_x_sq -= mu_x_sq
sigma_y_sq = F.conv2d(y ** 2, window, padding=0, groups=n_channels)
sigma_y_sq -= mu_y_sq
sigma_xy = F.conv2d(x * y, window, padding=0, groups=n_channels)
sigma_xy -= mu_xy

c1 = (_K1 * value_range) ** 2
c2 = (_K2 * value_range) ** 2
# Contrast sensitivity
cs = (2. * sigma_xy + c2) / (sigma_x_sq + sigma_y_sq + c2)

cs_map = (2. * sigma_xy + c2) / (sigma_x_sq + sigma_y_sq + c2)
ssim_map = (2. * mu_x * mu_y + c1) / (mu_x_sq + mu_y_sq + c1) * cs_map
# Structural similarity
ss = (2. * mu_x * mu_y + c1) / (mu_x_sq + mu_y_sq + c1) * cs

return ssim_map.mean((-1, -2)), cs_map.mean((-1, -2))
return ss.mean((-1, -2)), cs.mean((-1, -2))


def ssim(
x: torch.Tensor,
y: torch.Tensor,
window_size: int = 11,
value_range: float = 1.,
**kwargs,
) -> torch.Tensor:
r"""Returns the SSIM between `x` and `y`.
Args:
x: An input tensor, (N, C, H, W).
y: A target tensor, (N, C, H, W).
window_size: The size of the window.
value_range: The value range of the inputs (usually 1. or 255).
`**kwargs` are transmitted to `ssim_per_channel`.
"""

n_channels = x.size(1)
window = create_window(window_size, n_channels).to(x.device)

return ssim_per_channel(x, y, window, value_range)[0].mean(-1)
return ssim_per_channel(x, y, window, **kwargs)[0].mean(-1)


def msssim_per_channel(
x: torch.Tensor,
y: torch.Tensor,
window: torch.Tensor,
value_range: float = 1.,
weights: torch.Tensor = _WEIGHTS,
**kwargs,
) -> torch.Tensor:
"""Returns the MS-SSIM per channel between `x` and `y`.
Args:
x: An input tensor, (N, C, H, W).
y: A target tensor, (N, C, H, W).
window: A convolution window.
value_range: The value range of the inputs (usually 1. or 255).
weights: The weights of the scales, (M,).
`**kwargs` are transmitted to `ssim_per_channel`.
"""

weights = _WEIGHTS.to(x.device)

mcs = []

for i in range(weights.numel()):
Expand All @@ -132,37 +141,36 @@ def msssim_per_channel(
x = F.avg_pool2d(x, kernel_size=2, padding=padding)
y = F.avg_pool2d(y, kernel_size=2, padding=padding)

ssim, cs = ssim_per_channel(x, y, window, value_range)
ss, cs = ssim_per_channel(x, y, window, **kwargs)
mcs.append(torch.relu(cs))

msssim = torch.stack(mcs[:-1] + [ssim], dim=0)
msssim = msssim ** weights.view(-1, 1, 1)
msss = torch.stack(mcs[:-1] + [ss], dim=0)
msss = msss ** weights.view(-1, 1, 1)
msss = msss.prod(dim=0)

return msssim.prod(dim=0)
return msss


def msssim(
x: torch.Tensor,
y: torch.Tensor,
window_size: int = 11,
value_range: float = 1.,
weights: torch.Tensor = _WEIGHTS,
**kwargs,
) -> torch.Tensor:
r"""Returns the MS-SSIM between `x` and `y`.
Args:
x: An input tensor, (N, C, H, W).
y: A target tensor, (N, C, H, W).
window_size: The size of the window.
value_range: The value range of the inputs (usually 1. or 255).
weights: The weights of the scales, (M,).
`**kwargs` are transmitted to `msssim_per_channel`.
"""

n_channels = x.size(1)
window = create_window(window_size, n_channels).to(x.device)
weights = weights.to(x.device)

return msssim_per_channel(x, y, window, value_range, weights).mean(-1)
return msssim_per_channel(x, y, window, **kwargs).mean(-1)


class SSIM(nn.Module):
Expand All @@ -172,9 +180,10 @@ class SSIM(nn.Module):
Args:
window_size: The size of the window.
n_channels: A number of channels.
value_range: The value range of the inputs (usually 1. or 255).
reduction: A reduction type (`'mean'`, `'sum'` or `'none'`).
`**kwargs` are transmitted to `ssim_per_channel`.
Call:
The input and target tensors should be of shape (N, C, H, W).
"""
Expand All @@ -183,15 +192,15 @@ def __init__(
self,
window_size: int = 11,
n_channels: int = 3,
value_range: float = 1.,
reduction: str = 'mean',
**kwargs,
):
super().__init__()

self.register_buffer('window', create_window(window_size, n_channels))

self.value_range = value_range
self.reduce = build_reduce(reduction)
self.kwargs = kwargs

def forward(
self,
Expand All @@ -202,7 +211,7 @@ def forward(
input,
target,
window=self.window,
value_range=self.value_range,
**self.kwargs,
)[0].mean(-1)

return self.reduce(l)
Expand All @@ -213,19 +222,14 @@ class MSSSIM(SSIM):
between an input and a target.
Args:
weights: The weights of the scales, (M,).
All arguments are inherited from `SSIM`.
All other arguments are inherited (see `SSIM`).
`**kwargs` are transmitted to `msssim_per_channel`.
Call:
The input and target tensors should be of shape (N, C, H, W).
"""

def __init__(self, weights: torch.Tensor = _WEIGHTS, *args, **kwargs):
super().__init__(*args, **kwargs)

self.register_buffer('weights', weights)

def forward(
self,
input: torch.Tensor,
Expand All @@ -235,8 +239,7 @@ def forward(
input,
target,
window=self.window,
value_range=self.value_range,
weights=self.weights,
**self.kwargs,
).mean(-1)

return self.reduce(l)
31 changes: 12 additions & 19 deletions spiq/tv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
import torch.nn as nn

from spiq.utils import build_reduce
from spiq.utils import build_reduce, tensor_norm


def tv(x: torch.Tensor, norm: str = 'L2') -> torch.Tensor:
Expand All @@ -20,42 +20,35 @@ def tv(x: torch.Tensor, norm: str = 'L2') -> torch.Tensor:
norm: A norm function name (`'L1'`, `'L2'` or `'L2_squared'`).
"""

w_var = x[..., :, 1:] - x[..., :, :-1]
h_var = x[..., 1:, :] - x[..., :-1, :]
variation = torch.cat([
x[..., :, 1:] - x[..., :, :-1],
x[..., 1:, :] - x[..., :-1, :],
], dim=-2)

if norm in ['L2', 'L2_squared']:
w_var = w_var ** 2
h_var = h_var ** 2
else: # norm == 'L1'
w_var = w_var.abs()
h_var = h_var.abs()
tv = tensor_norm(variation, dim=(-1, -2, -3), norm=norm)

score = w_var.sum(dim=(-1, -2, -3)) + h_var.sum(dim=(-1, -2, -3))

if norm == 'L2':
score = torch.sqrt(score)

return score
return tv


class TV(nn.Module):
r"""Creates a criterion that measures the TV of an input.
Args:
norm: A norm function name (`'L1'`, `'L2'` or `'L2_squared'`).
reduction: A reduction type (`'mean'`, `'sum'` or `'none'`).
`**kwargs` are transmitted to `tv`.
Call:
The input tensor should be of shape (N, C, H, W).
"""

def __init__(self, norm: str = 'L2', reduction: str = 'mean'):
def __init__(self, reduction: str = 'mean', **kwargs):
super().__init__()

self.norm = norm
self.reduce = build_reduce(reduction)
self.kwargs = kwargs

def forward(self, input: torch.Tensor) -> torch.Tensor:
l = tv(input, norm=self.norm)
l = tv(input, **self.kwargs)

return self.reduce(l)
9 changes: 4 additions & 5 deletions spiq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,20 +97,19 @@ def tensor_norm(

def normalize_tensor(
x: torch.Tensor,
dim: Tuple[int, ...] = (),
norm: str = 'L2',
epsilon: float = 1e-8,
**kwargs,
) -> torch.Tensor:
r"""Returns `x` normalized.
Args:
x: An input tensor.
dim: The dimension(s) along which to normalize.
norm: A norm function name (`'L1'`, `'L2'` or `'L2_squared'`).
epsilon: A numerical stability term.
`**kwargs` are transmitted to `tensor_norm`.
"""

norm = tensor_norm(x, dim=dim, keepdim=True, norm=norm)

This comment has been minimized.

Copy link
@francois-rozet

francois-rozet Dec 5, 2020

Author Owner

Removing `keepdim=True` creates a dimension mismatch between `x` and `norm`.
This bug was fixed in 90cccb0.

norm = tensor_norm(x, **kwargs)

return x / (norm + epsilon)

Expand Down

0 comments on commit 3d1fd4c

Please sign in to comment.