✨ Add option to disable downsampling (#12)
francois-rozet committed Sep 11, 2021
1 parent e606d25 commit 9fedf9f
Showing 5 changed files with 42 additions and 22 deletions.
16 changes: 10 additions & 6 deletions piqa/fsim.py
@@ -259,6 +259,7 @@ class FSIM(nn.Module):
     Args:
         chromatic: Whether to use the chromatic channels (IQ) or not.
+        downsample: Whether downsampling is enabled or not.
         kernel: A gradient kernel, \((2, 1, K, K)\).
             If `None`, use the Scharr kernel instead.
         reduction: Specifies the reduction to apply to the output:
@@ -279,6 +280,7 @@ class FSIM(nn.Module):
     def __init__(
         self,
         chromatic: bool = True,
+        downsample: bool = True,
         kernel: torch.Tensor = None,
         reduction: str = 'mean',
         **kwargs,
@@ -293,6 +295,7 @@ def __init__(
         self.register_buffer('filters', torch.zeros((0, 0, 0, 0)))

         self.convert = ColorConv('RGB', 'YIQ' if chromatic else 'Y')
+        self.downsample = downsample
         self.reduction = reduction
         self.value_range = kwargs.get('value_range', 1.)
         self.kwargs = kwargs
@@ -314,19 +317,20 @@ def forward(
         )

         # Downsample
-        _, _, h, w = input.size()
-        M = round(min(h, w) / 256)
+        if self.downsample:
+            _, _, h, w = input.size()
+            M = round(min(h, w) / 256)

-        if M > 1:
-            input = F.avg_pool2d(input, kernel_size=M, ceil_mode=True)
-            target = F.avg_pool2d(target, kernel_size=M, ceil_mode=True)
+            if M > 1:
+                input = F.avg_pool2d(input, kernel_size=M, ceil_mode=True)
+                target = F.avg_pool2d(target, kernel_size=M, ceil_mode=True)

         # RGB to Y(IQ)
         input = self.convert(input)
         target = self.convert(target)

         # Phase congruency
-        if self.filters.shape[-2:] != (h, w):
+        if self.filters.shape[-2:] != input.shape[-2:]:
            self.filters = pc_filters(input)

         pc_input = phase_congruency(input[:, :1], self.filters, self.value_range)
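For reference, the block that the new flag guards in FSIM (and, further down, in MDSI and VSI) rescales both images by an adaptive factor M = round(min(H, W) / 256), so that the shorter side lands near 256 pixels; the filter cache is now compared against input.shape[-2:] because h and w are only defined inside the new if self.downsample: branch. A standalone sketch of that pooling logic, not part of the commit, with hypothetical shapes:

```python
import torch
import torch.nn.functional as F

def adaptive_downsample(x: torch.Tensor, y: torch.Tensor, enabled: bool = True):
    """Average-pool both images so that min(H, W) lands near 256 pixels."""
    if enabled:
        _, _, h, w = x.size()
        M = round(min(h, w) / 256)  # pooling factor, e.g. 2 for 512x768 inputs
        if M > 1:
            x = F.avg_pool2d(x, kernel_size=M, ceil_mode=True)
            y = F.avg_pool2d(y, kernel_size=M, ceil_mode=True)
    return x, y

x, y = torch.rand(1, 3, 512, 768), torch.rand(1, 3, 512, 768)
print(adaptive_downsample(x, y)[0].shape)                 # torch.Size([1, 3, 256, 384])
print(adaptive_downsample(x, y, enabled=False)[0].shape)  # torch.Size([1, 3, 512, 768])
```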
8 changes: 6 additions & 2 deletions piqa/gmsd.py
@@ -167,6 +167,7 @@ class GMSD(nn.Module):
     RBG to Y, the luminance color space, and downsampled by a factor 2.

     Args:
+        downsample: Whether downsampling is enabled or not.
         kernel: A gradient kernel, \((2, 1, K, K)\).
             If `None`, use the Prewitt kernel instead.
         reduction: Specifies the reduction to apply to the output:
@@ -191,6 +192,7 @@ class GMSD(nn.Module):

     def __init__(
         self,
+        downsample: bool = True,
         kernel: torch.Tensor = None,
         reduction: str = 'mean',
         **kwargs,
@@ -204,6 +206,7 @@ def __init__(
         self.register_buffer('kernel', kernel)

         self.convert = ColorConv('RGB', 'Y')
+        self.downsample = downsample
         self.reduction = reduction
         self.value_range = kwargs.get('value_range', 1.)
         self.kwargs = kwargs
@@ -225,8 +228,9 @@ def forward(
         )

         # Downsample
-        input = F.avg_pool2d(input, 2, ceil_mode=True)
-        target = F.avg_pool2d(target, 2, ceil_mode=True)
+        if self.downsample:
+            input = F.avg_pool2d(input, 2, ceil_mode=True)
+            target = F.avg_pool2d(target, 2, ceil_mode=True)

         # RGB to Y
         input = self.convert(input)
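For comparison, GMSD (and HaarPSI below) uses a fixed factor-2 average pooling rather than the adaptive factor above, and `downsample=False` now skips it entirely. A small illustration of the effect on shapes, with assumed input sizes:

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 321, 480)  # hypothetical odd-sized input

# Fixed factor-2 pooling as in GMSD/HaarPSI; ceil_mode keeps the odd edge row.
pooled = F.avg_pool2d(x, 2, ceil_mode=True)

print(pooled.shape)  # torch.Size([1, 3, 161, 240]) -- 321 -> ceil(321 / 2) = 161
print(x.shape)       # torch.Size([1, 3, 321, 480]) -- unchanged when pooling is skipped
```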
8 changes: 6 additions & 2 deletions piqa/haarpsi.py
@@ -126,6 +126,7 @@ class HaarPSI(nn.Module):
     Args:
         chromatic: Whether to use the chromatic channels (IQ) or not.
+        downsample: Whether downsampling is enabled or not.
         reduction: Specifies the reduction to apply to the output:
             `'none'` | `'mean'` | `'sum'`.
@@ -149,13 +150,15 @@ class HaarPSI(nn.Module):
     def __init__(
         self,
         chromatic: bool = True,
+        downsample: bool = True,
         reduction: str = 'mean',
         **kwargs,
     ):
         r""""""
         super().__init__()

         self.convert = ColorConv('RGB', 'YIQ' if chromatic else 'Y')
+        self.downsample = downsample
         self.reduction = reduction
         self.value_range = kwargs.get('value_range', 1.)
         self.kwargs = kwargs
@@ -177,8 +180,9 @@ def forward(
         )

         # Downsample
-        input = F.avg_pool2d(input, 2, ceil_mode=True)
-        target = F.avg_pool2d(target, 2, ceil_mode=True)
+        if self.downsample:
+            input = F.avg_pool2d(input, 2, ceil_mode=True)
+            target = F.avg_pool2d(target, 2, ceil_mode=True)

         # RGB to Y(IQ)
         input = self.convert(input)
14 changes: 9 additions & 5 deletions piqa/mdsi.py
@@ -126,6 +126,7 @@ class MDSI(nn.Module):
     RBG to LHM and downsampled by a factor \( \frac{\min(H, W)}{256} \).

     Args:
+        downsample: Whether downsampling is enabled or not.
         kernel: A gradient kernel, \((2, 1, K, K)\).
             If `None`, use the Prewitt kernel instead.
         reduction: Specifies the reduction to apply to the output:
@@ -150,6 +151,7 @@ class MDSI(nn.Module):

     def __init__(
         self,
+        downsample: bool = True,
         kernel: torch.Tensor = None,
         reduction: str = 'mean',
         **kwargs,
@@ -163,6 +165,7 @@ def __init__(
         self.register_buffer('kernel', kernel)

         self.convert = ColorConv('RGB', 'LHM')
+        self.downsample = downsample
         self.reduction = reduction
         self.value_range = kwargs.get('value_range', 1.)
         self.kwargs = kwargs
@@ -184,12 +187,13 @@ def forward(
         )

         # Downsample
-        _, _, h, w = input.size()
-        M = round(min(h, w) / 256)
+        if self.downsample:
+            _, _, h, w = input.size()
+            M = round(min(h, w) / 256)

-        if M > 1:
-            input = F.avg_pool2d(input, kernel_size=M, ceil_mode=True)
-            target = F.avg_pool2d(target, kernel_size=M, ceil_mode=True)
+            if M > 1:
+                input = F.avg_pool2d(input, kernel_size=M, ceil_mode=True)
+                target = F.avg_pool2d(target, kernel_size=M, ceil_mode=True)

         # RGB to LHM
         input = self.convert(input)
18 changes: 11 additions & 7 deletions piqa/vsi.py
@@ -207,6 +207,7 @@ class VSI(nn.Module):
     The visual saliency maps of the input and target are determined by `sdsp`.

     Args:
+        downsample: Whether downsampling is enabled or not.
         kernel: A gradient kernel, \((2, 1, K, K)\).
             If `None`, use the Scharr kernel instead.
         reduction: Specifies the reduction to apply to the output:
@@ -226,6 +227,7 @@ class VSI(nn.Module):

     def __init__(
         self,
+        downsample: bool = True,
         kernel: torch.Tensor = None,
         reduction: str = 'mean',
         **kwargs,
@@ -240,6 +242,7 @@ def __init__(
         self.register_buffer('filter', torch.zeros((0, 0)))

         self.convert = ColorConv('RGB', 'LMN')
+        self.downsample = downsample
         self.reduction = reduction
         self.value_range = kwargs.get('value_range', 1.)
         self.kwargs = kwargs
@@ -261,15 +264,16 @@ def forward(
         )

         # Downsample
-        _, _, h, w = input.size()
-        M = round(min(h, w) / 256)
+        if self.downsample:
+            _, _, h, w = input.size()
+            M = round(min(h, w) / 256)

-        if M > 1:
-            input = F.avg_pool2d(input, kernel_size=M, ceil_mode=True)
-            target = F.avg_pool2d(target, kernel_size=M, ceil_mode=True)
+            if M > 1:
+                input = F.avg_pool2d(input, kernel_size=M, ceil_mode=True)
+                target = F.avg_pool2d(target, kernel_size=M, ceil_mode=True)

-        # Visual saliancy
-        if self.filter.shape != (h, w):
+        # Visual saliency
+        if self.filter.shape != input.shape[-2:]:
             self.filter = sdsp_filter(input)

         vs_input = sdsp(input, self.filter, self.value_range)
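Taken together, each metric now accepts `downsample: bool = True`, so existing code keeps its behaviour while full-resolution evaluation becomes a one-argument change. A minimal usage sketch, assuming the classes are imported from the modules shown above and called on (N, 3, H, W) tensors in [0, 1]:

```python
import torch

from piqa.fsim import FSIM
from piqa.gmsd import GMSD
from piqa.vsi import VSI

x = torch.rand(4, 3, 512, 512)  # input batch
y = torch.rand(4, 3, 512, 512)  # target batch

fsim = FSIM()                  # default: internal downsampling as before
gmsd = GMSD(downsample=False)  # new: skip the factor-2 pooling
vsi = VSI(downsample=False)    # new: skip the adaptive pooling

print(fsim(x, y), gmsd(x, y), vsi(x, y))
```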
