Skip to content

Commit

Permalink
update edge loss
Browse files Browse the repository at this point in the history
  • Loading branch information
miaotianyi committed Jul 1, 2021
1 parent 0ccb14e commit e02336d
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 5 deletions.
2 changes: 2 additions & 0 deletions torchimage/contrib/__init__.py
@@ -1,3 +1,5 @@
"""
This module mostly contains algorithms contributed by the community.
These functions may be deleted or modified unexpectedly at any time.
"""
30 changes: 26 additions & 4 deletions torchimage/contrib/edge_loss.py
Expand Up @@ -9,10 +9,32 @@ def __init__(self, kernel_size, sigma, order):
self.gg = GaussianGrad(kernel_size=kernel_size, sigma=sigma, edge_order=order, same_padder="reflect")
self.loss = nn.L1Loss()

def _get_edge_tensor(self, y):
return torch.cat(self.gg.all_components(y, axes=(2, 3)), dim=1)

# y1, y2 must have nchw format
def forward(self, y1: torch.Tensor, y2: torch.Tensor):
edge_1 = torch.cat(self.gg.all_components(y1, axes=(2, 3)), dim=1)
edge_2 = torch.cat(self.gg.all_components(y2, axes=(2, 3)), dim=1)
return self.loss(edge_1, edge_2)
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor):
edge_pred = self._get_edge_tensor(y_pred)
edge_true = self._get_edge_tensor(y_true)
return self.loss(edge_pred, edge_true)


class WeightedMul(GaussianEdgeLoss):
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor):
edge_pred = self._get_edge_tensor(y_pred)
edge_true = self._get_edge_tensor(y_true)
return torch.mean(torch.abs(edge_true - edge_pred) * edge_true)


class WeightedDiv(GaussianEdgeLoss):
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor):
edge_pred = self._get_edge_tensor(y_pred)
edge_true = self._get_edge_tensor(y_true)
return torch.mean(torch.abs(edge_true - edge_pred) / edge_true)


class WeightedSum(GaussianEdgeLoss):
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor):
edge_pred = self._get_edge_tensor(y_pred)
edge_true = self._get_edge_tensor(y_true)
return (edge_true.abs().mean(dim=(2, 3)) - edge_pred.abs().mean(dim=(2, 3))).abs().mean()
2 changes: 1 addition & 1 deletion torchimage/metrics/ssim.py
Expand Up @@ -260,7 +260,7 @@ def _check_shape_large_enough(self, x: torch.Tensor, axes: tuple):
assert x.shape[a] / factor > self.blur.kernel_size[i]

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor,
content_axes=slice(2, None), reduce_axes=slice(1, None), **kwargs):
content_axes=slice(2, None), reduce_axes=slice(1, None), full=False):
# before score computation, channel axes will be averaged first
# so the final axes are just the sample dimensions

Expand Down

0 comments on commit e02336d

Please sign in to comment.