-
Notifications
You must be signed in to change notification settings - Fork 18
/
tv.py
105 lines (75 loc) · 2.38 KB
/
tv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
r"""Total Variation (TV)

This module implements the total variation (TV) in PyTorch.

Wikipedia:
    https://en.wikipedia.org/wiki/Total_variation
"""
import torch
import torch.nn as nn
from torch import Tensor
from .utils import assert_type
from .utils.functional import reduce_tensor
@torch.jit.script_if_tracing
def tv(x: Tensor, norm: str = 'L1') -> Tensor:
    r"""Returns the TV of :math:`x`.

    With `'L1'`,

    .. math::
        \text{TV}(x) = \sum_{c, i, j}
            \left| x_{c, i+1, j} - x_{c, i, j} \right| +
            \left| x_{c, i, j+1} - x_{c, i, j} \right|

    Alternatively, with `'L2'`,

    .. math::
        \text{TV}(x) = \left( \sum_{c, i, j}
            (x_{c, i+1, j} - x_{c, i, j})^2 +
            (x_{c, i, j+1} - x_{c, i, j})^2 \right)^{\frac{1}{2}}

    and with `'L2_squared'`, the same quantity without the square root.

    The sums run over the last three dimensions :math:`(C, H, W)`; all
    leading dimensions are kept.

    Args:
        x: An input tensor, :math:`(*, C, H, W)`.
        norm: Specifies the norm function to apply:
            `'L1'`, `'L2'` or `'L2_squared'`.

    Returns:
        The TV tensor, :math:`(*,)`.

    Raises:
        ValueError: If `norm` is not one of the supported values.

    Example:
        >>> x = torch.rand(5, 3, 256, 256)
        >>> l = tv(x)
        >>> l.shape
        torch.Size([5])
    """
    # Reject unknown norms instead of silently treating them as 'L2_squared'.
    if norm not in ['L1', 'L2', 'L2_squared']:
        raise ValueError(
            "Unknown norm '" + norm + "', expected 'L1', 'L2' or 'L2_squared'."
        )

    # Finite differences along the width (last) and height (second-to-last) axes.
    w_var = torch.diff(x, dim=-1)
    h_var = torch.diff(x, dim=-2)

    if norm == 'L1':
        w_var = w_var.abs()
        h_var = h_var.abs()
    else:  # norm in ['L2', 'L2_squared']
        w_var = w_var ** 2
        h_var = h_var ** 2

    # The two difference tensors have different shapes ((..., H, W-1) vs
    # (..., H-1, W)), so each one is reduced over (C, H, W) separately.
    var = w_var.sum(dim=(-1, -2, -3)) + h_var.sum(dim=(-1, -2, -3))

    if norm == 'L2':
        # sqrt has an infinite derivative at 0, so a flat input (var == 0)
        # would produce NaN gradients. Feed sqrt a dummy value where var is
        # zero and select 0 there: the forward value is unchanged, the
        # gradient at those entries becomes 0, and NaN inputs still propagate.
        is_zero = var == 0
        safe = torch.where(is_zero, torch.ones_like(var), var)
        var = torch.where(is_zero, torch.zeros_like(var), torch.sqrt(safe))

    return var
class TV(nn.Module):
    r"""Criterion wrapping :func:`tv` followed by a batch reduction.

    Args:
        reduction: Specifies the reduction to apply to the output:
            `'none'`, `'mean'` or `'sum'`.
        kwargs: Keyword arguments forwarded to :func:`tv` (e.g. `norm`).

    Example:
        >>> criterion = TV()
        >>> x = torch.rand(5, 3, 256, 256, requires_grad=True)
        >>> l = criterion(x)
        >>> l.shape
        torch.Size([])
        >>> l.backward()
    """

    def __init__(self, reduction: str = 'mean', **kwargs):
        super().__init__()

        # Both settings are kept as plain attributes and read on every forward.
        self.reduction = reduction
        self.kwargs = kwargs

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Args:
            x: An input tensor, :math:`(N, C, H, W)`.

        Returns:
            The TV vector, :math:`(N,)` or :math:`()` depending on `reduction`.
        """
        # Presumably restricts x to exactly 4 dimensions (dim_range=(4, 4)),
        # although tv() itself accepts any number of leading dimensions.
        assert_type(x, dim_range=(4, 4))

        per_sample = tv(x, **self.kwargs)

        return reduce_tensor(per_sample, self.reduction)