Skip to content

Commit

Permalink
✨ Multi-Scale Structural Similarity (MS-SSIM)
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Oct 23, 2020
1 parent c1addd2 commit be1821f
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 3 deletions.
2 changes: 1 addition & 1 deletion spiq/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.0.1'

from .psnr import psnr, PSNR
from .ssim import ssim, SSIM
from .ssim import ssim, msssim, SSIM, MSSSIM
87 changes: 85 additions & 2 deletions spiq/ssim.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
r"""Structural Similarity (SSIM)
r"""Structural Similarity (SSIM) and Multi-Scale Structural Similarity (MS-SSIM)
This module implements the SSIM in PyTorch.
This module implements the SSIM and MS-SSIM in PyTorch.
Wikipedia:
https://en.wikipedia.org/wiki/Structural_similarity
Expand Down Expand Up @@ -34,6 +34,7 @@

# NOTE(review): presumably the standard deviation of the Gaussian window
# built elsewhere in this module — confirm against `create_window`.
_SIGMA = 1.5
# NOTE(review): presumably the SSIM stabilisation constants — confirm
# against `ssim_per_channel`.
_K1 = 0.01
_K2 = 0.03
# Default per-scale exponents used by the MS-SSIM, finest scale first.
_WEIGHTS = torch.tensor(
    [0.0448, 0.2856, 0.3001, 0.2363, 0.1333],
    dtype=torch.float32,
)


#############
Expand Down Expand Up @@ -136,6 +137,52 @@ def ssim(x: torch.Tensor, y: torch.Tensor, window_size: int=11, value_range: flo
return ssim_per_channel(x, y, window, value_range)[0].mean(-1)


def msssim_per_channel(x: torch.Tensor, y: torch.Tensor, window: torch.Tensor, value_range: float=1., weights: torch.Tensor=_WEIGHTS) -> torch.Tensor:
    """Returns the MS-SSIM per channel between `x` and `y`.

    The inputs are compared at `weights.numel()` successive scales, each
    scale being a 2x average-pooled version of the previous one. Every
    scale but the last contributes its contrast-structure term, the last
    one contributes its full SSIM term. Each term is raised to the
    matching entry of `weights` and the terms are multiplied together.

    Args:
        x: input tensor, (N, C, H, W)
        y: target tensor, (N, C, H, W)
        window: convolution window
        value_range: value range of the inputs (usually 1. or 255)
        weights: weights of the scales, 1-D with one entry per scale

    Returns:
        The product over scales of the weighted terms; it has the shape of
        one output of `ssim_per_channel` (presumably (N, C) — confirm
        against `ssim_per_channel`).
    """

    n_scales = weights.numel()
    terms = []

    for i in range(n_scales):
        if i > 0:
            # Halve the resolution. Odd dimensions are padded by one so
            # that the 2x2 pooling does not drop the last row/column.
            padding = (x.shape[-2] % 2, x.shape[-1] % 2)
            x = F.avg_pool2d(x, kernel_size=2, padding=padding)
            y = F.avg_pool2d(y, kernel_size=2, padding=padding)

        ss, cs = ssim_per_channel(x, y, window, value_range)

        # Contrast-structure at the intermediate scales, full SSIM at the
        # coarsest (last) one.
        term = cs if i + 1 < n_scales else ss

        # Both terms can be negative, and a negative base raised to a
        # fractional weight is NaN, so clamp every term (including the
        # final SSIM) to be non-negative before exponentiation.
        terms.append(torch.relu(term))

    # (n_scales, ...) stack; the weights broadcast over the trailing dims.
    stacked = torch.stack(terms, dim=0)
    stacked = stacked ** weights.view(-1, 1, 1)

    return stacked.prod(dim=0)


def msssim(x: torch.Tensor, y: torch.Tensor, window_size: int=11, value_range: float=1., weights: torch.Tensor=_WEIGHTS) -> torch.Tensor:
    r"""Returns the MS-SSIM between `x` and `y`.

    Builds the convolution window for the channel count of `x`, moves the
    window and the scale weights to the device of `x`, and averages the
    per-channel MS-SSIM over its last dimension.

    Args:
        x: input tensor, (N, C, H, W)
        y: target tensor, (N, C, H, W)
        window_size: size of the window
        value_range: value range of the inputs (usually 1. or 255)
        weights: weights of the scales, one entry per scale
    """

    device = x.device

    # The window is created for as many channels as `x` has (dimension 1).
    kernel = create_window(window_size, x.size(1)).to(device)

    per_channel = msssim_per_channel(
        x, y, kernel, value_range, weights.to(device)
    )

    return per_channel.mean(-1)


###########
# Classes #
###########
Expand Down Expand Up @@ -174,3 +221,39 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return l.sum()

return l


class MSSSIM(SSIM):
    r"""Creates a criterion that measures the MS-SSIM between an input and a target.

    All arguments except `weights` are forwarded to `SSIM.__init__`;
    `weights` is registered as a module buffer, so it follows the module
    across devices.

    Args:
        window_size: size of the window
        n_channels: number of channels
        value_range: value range of the inputs (usually 1. or 255)
        weights: weights of the scales, one entry per scale
        reduction: `'mean'` or `'sum'`; any other value disables the reduction
    """

    def __init__(self, window_size: int=11, n_channels: int=3, value_range: float=1., weights: torch.Tensor=_WEIGHTS, reduction='mean'):
        super().__init__(window_size=window_size, n_channels=n_channels, value_range=value_range, reduction=reduction)

        self.register_buffer('weights', weights)

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        r"""Returns the (optionally reduced) MS-SSIM between `input` and `target`.

        Args:
            input: input tensor, (N, C, H, W)
            target: target tensor, (N, C, H, W)
        """

        # NOTE(review): `window`, `value_range` and `reduction` are expected
        # to be set by `SSIM.__init__` — confirm against the parent class.
        per_channel = msssim_per_channel(
            input,
            target,
            window=self.window,
            value_range=self.value_range,
            weights=self.weights,
        )

        # Average over the last dimension of the per-channel result.
        scores = per_channel.mean(-1)

        if self.reduction == 'mean':
            return scores.mean()

        if self.reduction == 'sum':
            return scores.sum()

        return scores

0 comments on commit be1821f

Please sign in to comment.