Skip to content

Commit

Permalink
♻️ Remove repetitive reduction code segment
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Dec 4, 2020
1 parent 7a83262 commit c9719b2
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 37 deletions.
11 changes: 3 additions & 8 deletions spiq/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch.nn as nn
import torchvision.models as models

from spiq.utils import normalize_tensor, Intermediary
from spiq.utils import build_reduce, normalize_tensor, Intermediary

_SHIFT = torch.Tensor([0.485, 0.456, 0.406])
_SCALE = torch.Tensor([0.229, 0.224, 0.225])
Expand Down Expand Up @@ -85,7 +85,7 @@ def __init__(
for y in x:
y.requires_grad = False

self.reduction = reduction
self.reduce = build_reduce(reduction)

def forward(
self,
Expand All @@ -106,9 +106,4 @@ def forward(

l = torch.cat(residuals, dim=-1).sum(dim=-1)

if self.reduction == 'mean':
l = l.mean()
elif self.reduction == 'sum':
l = l.sum()

return l
return self.reduce(l)
11 changes: 4 additions & 7 deletions spiq/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import torch
import torch.nn as nn

from spiq.utils import build_reduce

from typing import Tuple


Expand Down Expand Up @@ -51,7 +53,7 @@ def __init__(self, value_range: float = 1., reduction: str = 'mean'):
super().__init__()

self.value_range = value_range
self.reduction = reduction
self.reduce = build_reduce(reduction)

def forward(
self,
Expand All @@ -65,9 +67,4 @@ def forward(
value_range=self.value_range,
)

if self.reduction == 'mean':
l = l.mean()
elif self.reduction == 'sum':
l = l.sum()

return l
return self.reduce(l)
18 changes: 4 additions & 14 deletions spiq/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import torch.nn as nn
import torch.nn.functional as F

from spiq.utils import gaussian_kernel
from spiq.utils import build_reduce, gaussian_kernel

_SIGMA = 1.5
_K1, _K2 = 0.01, 0.03
Expand Down Expand Up @@ -191,7 +191,7 @@ def __init__(
self.register_buffer('window', create_window(window_size, n_channels))

self.value_range = value_range
self.reduction = reduction
self.reduce = build_reduce(reduction)

def forward(
self,
Expand All @@ -205,12 +205,7 @@ def forward(
value_range=self.value_range,
)[0].mean(-1)

if self.reduction == 'mean':
l = l.mean()
elif self.reduction == 'sum':
l = l.sum()

return l
return self.reduce(l)


class MSSSIM(SSIM):
Expand Down Expand Up @@ -244,9 +239,4 @@ def forward(
weights=self.weights,
).mean(-1)

if self.reduction == 'mean':
l = l.mean()
elif self.reduction == 'sum':
l = l.sum()

return l
return self.reduce(l)
11 changes: 4 additions & 7 deletions spiq/tv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import torch
import torch.nn as nn

from spiq.utils import build_reduce


def tv(x: torch.Tensor, norm: str = 'L2') -> torch.Tensor:
r"""Returns the TV of `x`.
Expand Down Expand Up @@ -51,14 +53,9 @@ def __init__(self, norm: str = 'L2', reduction: str = 'mean'):
super().__init__()

self.norm = norm
self.reduction = reduction
self.reduce = build_reduce(reduction)

def forward(self, input: torch.Tensor) -> torch.Tensor:
l = tv(input, norm=self.norm)

if self.reduction == 'mean':
l = l.mean()
elif self.reduction == 'sum':
l = l.sum()

return l
return self.reduce(l)
23 changes: 22 additions & 1 deletion spiq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,28 @@
import torch
import torch.nn as nn

from typing import List, Tuple
from typing import Callable, List, Tuple


def build_reduce(
reduction: str = 'mean',
dim: Tuple[int, ...] = (),
keepdim: bool = False,
) -> Callable[[torch.Tensor], torch.Tensor]:
r"""Returns a reduce function.
Args:
reduction: A reduction type (`'mean'`, `'sum'` or `'none'`).
dim: The dimension(s) along which to reduce.
keepdim: Whether the output tensor has `dim` retained or not.
"""

if reduction == 'mean':
return lambda x: x.mean(dim=dim, keepdim=keepdim)
elif reduction == 'sum':
return lambda x: x.sum(dim=dim, keepdim=keepdim)

return lambda x: x


def gaussian_kernel(
Expand Down

0 comments on commit c9719b2

Please sign in to comment.