Skip to content

Commit

Permalink
🚑 Fix dimension mismatches
Browse files Browse the repository at this point in the history
The bugs were introduced in 3d1fd4c
  • Loading branch information
francois-rozet committed Dec 10, 2020
1 parent 200460a commit fe983b6
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 11 deletions.
2 changes: 1 addition & 1 deletion piqa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
specific image quality assessement metric.
"""

__version__ = '1.0.2'
__version__ = '1.0.4'
20 changes: 14 additions & 6 deletions piqa/tv.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,22 @@ def tv(x: torch.Tensor, norm: str = 'L2') -> torch.Tensor:
`'L1'` | `'L2'` | `'L2_squared'`.
"""

variation = torch.cat([
x[..., :, 1:] - x[..., :, :-1],
x[..., 1:, :] - x[..., :-1, :],
], dim=-2)
w_var = x[..., :, 1:] - x[..., :, :-1]
h_var = x[..., 1:, :] - x[..., :-1, :]

tv = tensor_norm(variation, dim=(-1, -2, -3), norm=norm)
if norm in ['L2', 'L2_squared']:
w_var = w_var ** 2
h_var = h_var ** 2
else: # norm == 'L1'
w_var = w_var.abs()
h_var = h_var.abs()

return tv
var = w_var.sum(dim=(-1, -2, -3)) + h_var.sum(dim=(-1, -2, -3))

if norm == 'L2':
var = torch.sqrt(var)

return var


class TV(nn.Module):
Expand Down
10 changes: 6 additions & 4 deletions piqa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,19 +153,21 @@ def tensor_norm(

def normalize_tensor(
x: torch.Tensor,
dim: Tuple[int, ...] = (),
norm: str = 'L2',
epsilon: float = 1e-8,
**kwargs,
) -> torch.Tensor:
r"""Returns `x` normalized.
Args:
x: An input tensor.
dim: The dimension(s) along which to normalize.
norm: Specifies the norm funcion to use:
`'L1'` | `'L2'` | `'L2_squared'`.
epsilon: A numerical stability term.
`**kwargs` are transmitted to `tensor_norm`.
"""

norm = tensor_norm(x, **kwargs)
norm = tensor_norm(x, dim=dim, keepdim=True, norm=norm)

return x / (norm + epsilon)

Expand Down

0 comments on commit fe983b6

Please sign in to comment.