## 1. dist_map_transform

In [1]:
from typing import Callable, BinaryIO, Match, Pattern, Tuple, Union, Optional, cast
from typing import Any, Callable, Iterable, List, Set, Tuple, TypeVar, Union, cast
from functools import partial
from torch import Tensor
import torch
from torchvision import transforms
import numpy as np
from operator import itemgetter, mul

In [2]:

def uniq(a: Tensor) -> Set:
    return set(torch.unique(a.cpu()).numpy())

def sset(a: Tensor, sub) -> bool:
    return uniq(a).issubset(sub)

def simplex(t: Tensor, axis=1) -> bool:
    _sum = cast(Tensor, t.sum(axis).type(torch.float32))
    _ones = torch.ones_like(_sum, dtype=torch.float32)
    return torch.allclose(_sum, _ones)

def one_hot(t: Tensor, axis=1) -> bool:
    return simplex(t, axis) and sset(t, [0, 1])

def class2one_hot(seg: Tensor, K: int) -> Tensor:
    # Breaking change but otherwise can't deal with both 2d and 3d
    # if len(seg.shape) == 3:  # Only w, h, d, used by the dataloader
    #     return class2one_hot(seg.unsqueeze(dim=0), K)[0]

    assert sset(seg, list(range(K))), (uniq(seg), K)

    b, *img_shape = seg.shape  # type: Tuple[int, ...]

    device = seg.device
    res = torch.zeros((b, K, *img_shape), dtype=torch.int32, device=device).scatter_(1, seg[:, None, ...], 1)

    assert res.shape == (b, K, *img_shape)
    assert one_hot(res)

    return res


In [3]:
def gt_transform(resolution: Tuple[float, ...], K: int):
        return transforms.Compose([
                lambda img: np.array(img)[...],
                lambda nd: torch.tensor(nd, dtype=torch.int64)[None, ...],  # Add one dimension to simulate batch
                partial(class2one_hot, K=K),
                itemgetter(0)  # Then pop the element to go back to img shape
        ])


In [4]:
from scipy.ndimage import distance_transform_edt as eucl_distance

def one_hot2dist(seg: np.ndarray, resolution: Tuple[float, float, float] = None,
                 dtype=None) -> np.ndarray:
    assert one_hot(torch.tensor(seg), axis=0)
    K: int = len(seg)

    res = np.zeros_like(seg, dtype=dtype)
    for k in range(K):
        posmask = seg[k].astype(bool)

        if posmask.any():
            negmask = ~posmask
            res[k] = eucl_distance(negmask, sampling=resolution) * negmask \
                - (eucl_distance(posmask, sampling=resolution) - 1) * posmask
        # The idea is to leave blank the negative classes
        # since this is one-hot encoded, another class will supervise that pixel

    return res


In [5]:
def dist_map_transform(resolution: Tuple[float, ...], K: int):
        return transforms.Compose([
                gt_transform(resolution, K),
                lambda t: t.cpu().numpy(),
                partial(one_hot2dist, resolution=resolution),
                lambda nd: torch.tensor(nd, dtype=torch.float32)
        ])

In [6]:
BATCH_SIZE = 8
NUM_CHANNELS = 3
NUM_CLASSES = 2
HEIGHT = 256
WIDTH = 256

images = torch.randint(0, 1, (BATCH_SIZE, NUM_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32)
labels = torch.randint(0, NUM_CLASSES, (BATCH_SIZE, HEIGHT, WIDTH), dtype=torch.int64)

images.shape, labels.shape

(torch.Size([8, 3, 256, 256]), torch.Size([8, 256, 256]))

In [7]:
one_hot_labels = torch.stack([class2one_hot(label.unsqueeze(0), NUM_CLASSES).squeeze(0) for label in labels])

one_hot_labels.shape

torch.Size([8, 2, 256, 256])

In [8]:
disttransform = dist_map_transform([1, 1], NUM_CLASSES)

In [9]:
dist_map_labels: Tensor = torch.stack([disttransform(label) for label in labels])

dist_map_labels.shape

torch.Size([8, 2, 256, 256])

## 2. GeneralizedDiceLoss

In [10]:
from torch import Tensor, einsum

class GeneralizedDiceLoss():
    def __init__(self, **kwargs):
        # Self.idc is used to filter out some classes of the target mask. Use fancy indexing
        self.idc: List[int] = kwargs["idc"]
        print(f"Initialized {self.__class__.__name__} with {kwargs}")

    def __call__(self, probs: Tensor, target: Tensor) -> Tensor:
        assert simplex(probs) and simplex(target)

        pc = probs[:, self.idc, ...].type(torch.float32)
        tc = target[:, self.idc, ...].type(torch.float32)

        w: Tensor = 1 / ((einsum("bkwh->bk", tc).type(torch.float32) + 1e-10) ** 2)
        intersection: Tensor = w * einsum("bkwh,bkwh->bk", pc, tc)
        union: Tensor = w * (einsum("bkwh->bk", pc) + einsum("bkwh->bk", tc))

        divided: Tensor = 1 - 2 * (einsum("bk->b", intersection) + 1e-10) / (einsum("bk->b", union) + 1e-10)

        loss = divided.mean()

        return loss


## 3. BoundaryLoss

In [11]:
class SurfaceLoss():
    def __init__(self, **kwargs):
        # Self.idc is used to filter out some classes of the target mask. Use fancy indexing
        self.idc: List[int] = kwargs["idc"]
        print(f"Initialized {self.__class__.__name__} with {kwargs}")

    def __call__(self, probs: Tensor, dist_maps: Tensor) -> Tensor:
        assert simplex(probs)
        assert not one_hot(dist_maps)

        pc = probs[:, self.idc, ...].type(torch.float32)
        dc = dist_maps[:, self.idc, ...].type(torch.float32)

        multipled = einsum("bkwh,bkwh->bkwh", pc, dc)

        loss = multipled.mean()

        return loss

BoundaryLoss = SurfaceLoss

In [12]:
dice_loss = GeneralizedDiceLoss(idc=[0, 1])
boundary_loss = BoundaryLoss(idc=[1])

Initialized GeneralizedDiceLoss with {'idc': [0, 1]}
Initialized SurfaceLoss with {'idc': [1]}


In [13]:
images.shape, labels.shape, one_hot_labels.shape, dist_map_labels.shape

(torch.Size([8, 3, 256, 256]),
 torch.Size([8, 256, 256]),
 torch.Size([8, 2, 256, 256]),
 torch.Size([8, 2, 256, 256]))

In [14]:
import torch.nn.functional as F

pred_logits = torch.rand(BATCH_SIZE, NUM_CLASSES, HEIGHT, WIDTH)
pred_probs = F.softmax(pred_logits, dim=1)
pred_logits.shape, pred_probs.shape

(torch.Size([8, 2, 256, 256]), torch.Size([8, 2, 256, 256]))

In [15]:
pred_probs.shape, one_hot_labels.shape

(torch.Size([8, 2, 256, 256]), torch.Size([8, 2, 256, 256]))

In [16]:
a = 0.01

gdl_loss = dice_loss(pred_probs, one_hot_labels)
bl_loss = boundary_loss(pred_probs, dist_map_labels)

total_loss = gdl_loss + a * bl_loss
print(f"Total loss: {total_loss:.6f}")
print(f"Dice loss: {gdl_loss:.6f}")
print(f"Boundary loss: {bl_loss:.6f}")

Total loss: 0.502377
Dice loss: 0.499884
Boundary loss: 0.249297
