In [53]:
import kornia as K
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


In [2]:
cdt = K.contrib.ConvDistanceTransform()

In [3]:
import torch

In [103]:
t = torch.zeros((1, 3, 100, 100))
t[0, 0, 40:60, 50:70] = .01
t[0, 2, 10:60, 20:70] = .01

In [5]:
cdt(t)

tensor([[[[49.8104, 48.8733, 47.7875,  ..., 39.7925, 39.7994, 39.8733],
          [49.7140, 48.7994, 47.6844,  ..., 38.6601, 38.6844, 38.7875],
          [49.6982, 48.7925, 47.6601,  ..., 37.6982, 37.7140, 37.8104],
          ...,
          [49.6982, 48.7925, 47.6601,  ..., 37.6982, 37.7140, 37.8104],
          [49.7140, 48.7994, 47.6844,  ..., 38.6601, 38.6844, 38.7875],
          [49.8104, 48.8733, 47.7875,  ..., 39.7925, 39.7994, 39.8733]]]])

In [6]:
t = torch.Tensor(1)

In [7]:
cdt(t)

ValueError: Invalid input shape, we expect BxCxHxW. Got: torch.Size([1])

In [8]:
t = torch.Tensor([[[[3, 2,1]]]]).to(torch.int)
t.dtype

torch.int32

In [9]:
cdt(t)

tensor([[[[0., 0., 0.]]]])

In [10]:
t = torch.Tensor([[[[3, 2,1]]]])
cdt(t)

tensor([[[[0., 0., 0.]]]])

In [11]:
t = torch.Tensor([[[[3, 2,0]]]]).to(torch.int)
cdt(t)

tensor([[[[0.0000, 0.0000, 0.7285]]]])

In [12]:
t = torch.zeros((1,1,7,7))
t[0,0,0,1]=1
cdt(t)

tensor([[[[1.0000, 0.0000, 1.0000, 2.0000, 3.0000, 3.8733, 4.8104],
          [1.4142, 1.0000, 1.4142, 2.2361, 3.1623, 3.7999, 4.7158],
          [2.2361, 2.0000, 2.2361, 2.8284, 3.6056, 3.7999, 4.7158],
          [3.1623, 3.0000, 3.1623, 3.6056, 4.2426, 3.8733, 4.8104],
          [3.8733, 3.7994, 3.7931, 3.7994, 3.8733, 4.3428, 5.1447],
          [4.8104, 4.7140, 4.7000, 4.7140, 4.8104, 5.1447, 5.7546],
          [5.7875, 5.6844, 5.6639, 5.6844, 5.7875, 6.0631, 6.5530]]]])

In [None]:
# Play with value of lambda/h in DT_LogConv from Karam et als work, which is what Pham used


In [110]:
def make_cdt_kernel(
    kernel_size: int,
) -> torch.Tensor:
    # Value of h derived from the parameters used by Pham et. al in their proposal of the algorithm.
    h = -0.35
    grid_range = torch.Tensor(range(kernel_size))
    
    gridx, gridy = torch.meshgrid(grid_range, grid_range)
    gridx = gridx - math.floor(kernel_size / 2)
    gridy = gridy - math.floor(kernel_size / 2)
    
    kernel = torch.hypot(gridx, gridy) 
    kernel = torch.exp(kernel/h)
    
    # for BCHW tensors
    kernel = torch.unsqueeze(kernel, 0)
    kernel = torch.unsqueeze(kernel, 0)

    return kernel

def conv_distance_transform(
    input: torch.Tensor,
    kernel_size: int = 7
) -> torch.Tensor:
    r"""Approximate the Manhattan (city block) distance transform of images using convolutions.
    
    The value at each pixel in the output represents the distance to the nearest non-zero pixel in the input image.
    The transformation is applied independently across the channel dimension of the inputs.

    
    Args:
        input: input image tensor with shape :math:`(B,C,H,W)`.
        kernel_size: size of the convolution kernel.

    Returns:
        tensor with shape :math:`(B,C,H,W)`.

    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if not len(input.shape) == 4:
        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")

    device: torch.device = input.device
    dtype: torch.dtype = input.dtype

    n_iters = math.ceil(max(input.shape[2], input.shape[3]) / math.floor(kernel_size / 2))
    kernel = make_cdt_kernel(kernel_size)

    out = torch.zeros(input.shape, dtype=torch.float32, device=device)

    # It is possible to avoid cloning the input if boundary = input, but this would require modifying the input tensor.
    boundary = input.clone().to(torch.float32)
    kernel.to(device)

    # If input images have multiple channels, view the channels in the batch dimension to match kernel shape.
    if input.shape[1] > 1:
        batch_channel_view_shape = (input.shape[0] * input.shape[1], 1, input.shape[2], input.shape[3])
        out = out.view(*batch_channel_view_shape)
        boundary = boundary.view(*batch_channel_view_shape)

    for i in range(n_iters):
        cdt = F.conv2d(boundary, kernel, padding='same')
        cdt = -0.35*torch.log(cdt)

        # We are calculating log(0) above.
        torch.nan_to_num(cdt, out=cdt, posinf=0.0)

        mask = cdt > 0
        if mask.sum() == 0:
            break
        offset =  i * kernel_size / 2
        out[mask] +=  offset + cdt[mask]
        boundary[mask] = 1

    # View channels in the channel dimension, if they were added to batch dimension during transform.
    if input.shape[1] > 1:
        out = out.view(input.shape)
    return out

In [111]:
conv_distance_transform(t, 37)

tensor([[[[51.3669, 50.4196, 49.4790,  ..., 40.5935, 40.6256, 40.7296],
          [51.0870, 50.1234, 49.1642,  ..., 39.6361, 39.6609, 39.7639],
          [50.8609, 49.8849, 48.9114,  ..., 38.6965, 38.7123, 38.8088],
          ...,
          [50.8609, 49.8849, 48.9114,  ..., 38.6965, 38.7123, 38.8088],
          [51.0870, 50.1234, 49.1642,  ..., 39.6361, 39.6609, 39.7639],
          [51.3669, 50.4196, 49.4790,  ..., 40.5935, 40.6256, 40.7296]],

         [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

         [[20.3088, 19.3731, 22.1112,  ..., 28.1088, 29.0954, 30.0830],
          [20.2123, 19.2993, 2