Skip to content

Commit

Permalink
Merge pull request #58 from arraiyopensource/feat/depth_smooth
Browse files Browse the repository at this point in the history
Feat/depth smooth
  • Loading branch information
edgarriba committed Jan 30, 2019
2 parents 43bd8c2 + 0da418c commit 42a1d22
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@ torchgeometry.losses
.. currentmodule:: torchgeometry.losses

.. autofunction:: ssim
.. autofunction:: depth_smoothness_loss

.. autoclass:: SSIM
.. autoclass:: DepthSmoothnessLoss
1 change: 1 addition & 0 deletions mypy_files.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
torchgeometry/image/gaussian.py
torchgeometry/losses/ssim.py
torchgeometry/losses/depth_smooth.py
torchgeometry/homography_warper.py
35 changes: 35 additions & 0 deletions test/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,41 @@
from common import TEST_DEVICES


class TestDepthSmoothnessLoss:
def _test_smoke(self):
image = self.image.clone()
depth = self.depth.clone()

criterion = tgm.losses.DepthSmoothnessLoss()
loss = criterion(depth, image)

# TODO: implement me
def _test_1(self):
pass

# TODO: implement me
def _test_jit(self):
pass

def _test_gradcheck(self):
image = self.image.clone()
depth = self.depth.clone()
depth = utils.tensor_to_gradcheck_var(depth) # to var
image = utils.tensor_to_gradcheck_var(image) # to var
assert gradcheck(tgm.losses.depth_smoothness_loss,
(depth, image,), raise_exception=True)

@pytest.mark.parametrize("device_type", TEST_DEVICES)
@pytest.mark.parametrize("batch_shape",
[(1, 1, 10, 16), (2, 4, 8, 15), ])
def test_run_all(self, batch_shape, device_type):
self.image = torch.rand(batch_shape).to(torch.device(device_type))
self.depth = torch.rand(batch_shape).to(torch.device(device_type))

self._test_smoke()
self._test_gradcheck()


@pytest.mark.parametrize("window_size", [5, 11])
@pytest.mark.parametrize("reduction_type", ['none', 'mean', 'sum'])
@pytest.mark.parametrize("device_type", TEST_DEVICES)
Expand Down
1 change: 1 addition & 0 deletions torchgeometry/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .ssim import SSIM, ssim
from .depth_smooth import DepthSmoothnessLoss, depth_smoothness_loss
99 changes: 99 additions & 0 deletions torchgeometry/losses/depth_smooth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# Based on
# https://github.com/tensorflow/models/blob/master/research/struct2depth/model.py#L625-L641


class DepthSmoothnessLoss(nn.Module):
r"""Criterion that computes image-aware depth smoothness loss.
.. math::
\text{loss} = \left | \partial_x d_{ij} \right | e^{-\left \|
\partial_x I_{ij} \right \|} + \left |
\partial_y d_{ij} \right | e^{-\left \| \partial_y I_{ij} \right \|}
Shape:
- Depth: :math:`(N, 1, H, W)`
- Image: :math:`(N, 3, H, W)`
- Output: scalar
Examples::
>>> depth = torch.rand(1, 1, 4, 5)
>>> image = torch.rand(1, 3, 4, 5)
>>> smooth = tgm.losses.DepthSmoothnessLoss()
>>> loss = smooth(depth, image)
"""

def __init__(self) -> None:
super(DepthSmoothnessLoss, self).__init__()

@staticmethod
def gradient_x(img: torch.Tensor) -> torch.Tensor:
assert len(img.shape) == 4, img.shape
return img[:, :, :, :-1] - img[:, :, :, 1:]

@staticmethod
def gradient_y(img: torch.Tensor) -> torch.Tensor:
assert len(img.shape) == 4, img.shape
return img[:, :, :-1, :] - img[:, :, 1:, :]

def forward(self, depth: torch.Tensor, image: torch.Tensor) -> torch.Tensor:
if not torch.is_tensor(depth):
raise TypeError("Input depth type is not a torch.Tensor. Got {}"
.format(type(depth)))
if not torch.is_tensor(image):
raise TypeError("Input image type is not a torch.Tensor. Got {}"
.format(type(image)))
if not len(depth.shape) == 4:
raise ValueError("Invalid depth shape, we expect BxCxHxW. Got: {}"
.format(depth.shape))
if not len(image.shape) == 4:
raise ValueError("Invalid image shape, we expect BxCxHxW. Got: {}"
.format(image.shape))
if not depth.shape[-2:] == image.shape[-2:]:
raise ValueError("depth and image shapes must be the same. Got: {}"
.format(depth.shape, image.shape))
if not depth.device == image.device:
raise ValueError(
"depth and image must be in the same device. Got: {}" .format(
depth.device, image.device))
if not depth.dtype == image.dtype:
raise ValueError(
"depth and image must be in the same dtype. Got: {}" .format(
depth.dtype, image.dtype))
# compute the gradients
depth_dx: torch.Tensor = self.gradient_x(depth)
depth_dy: torch.Tensor = self.gradient_y(depth)
image_dx: torch.Tensor = self.gradient_x(image)
image_dy: torch.Tensor = self.gradient_y(image)

# compute image weights
weights_x: torch.Tensor = torch.exp(
-torch.mean(torch.abs(image_dx), dim=1, keepdim=True))
weights_y: torch.Tensor = torch.exp(
-torch.mean(torch.abs(image_dy), dim=1, keepdim=True))

# apply image weights to depth
smoothness_x: torch.Tensor = torch.abs(depth_dx * weights_x)
smoothness_y: torch.Tensor = torch.abs(depth_dy * weights_y)
return torch.mean(smoothness_x) + torch.mean(smoothness_y)


######################
# functional interface
######################


def depth_smoothness_loss(
depth: torch.Tensor,
image: torch.Tensor) -> torch.Tensor:
r"""Computes image-aware depth smoothness loss.
See :class:`~torchgeometry.losses.DepthSmoothnessLoss` for details.
"""
return DepthSmoothnessLoss()(depth, image)

0 comments on commit 42a1d22

Please sign in to comment.