In [1]:
import numpy as np
import torch
import torch.nn as nn
from pathlib import *

torch.__version__  # installed PyTorch version; the recorded output below is a 1.0.0 nightly build

'1.0.0.dev20181128'

In [2]:
def printWithDec(v, title=None, d=2):
    """Print a tensor (or array) rounded to `d` decimals, without scientific notation.

    :param v: a torch tensor, or anything ``np.asarray`` accepts. Tensors are
        detached and copied to the CPU first, so tensors that require grad or
        live on a GPU can be printed too.
    :param title: optional label; when given the output is ``"<title>: <values>"``.
    :param d: number of decimals to show.
    """
    if isinstance(v, torch.Tensor):
        # Tensor.numpy() refuses tensors that require grad or are not on the CPU
        v = v.detach().cpu().numpy()
    values = np.asarray(v)
    with np.printoptions(precision=d, suppress=True):
        if title is None:
            print(values)
        else:
            print(f"{title}:", values)

In [3]:
# Example
# bs, c, width, height: batch size, number of classes, and the two spatial dims of each image.
# x: predictions (raw scores) for every image and class, shape (bs, c, width, height).
# y: the ground truth as a mask holding the class index of each pixel
#    (a compact representation of a one-hot encoding), shape (bs, width, height).
# w: the weight of each class in the loss function, normalized to sum to 1.
torch.manual_seed(0)  # fixed seed so Restart & Run All reproduces the same example

bs, c, width, height = 4, 6, 2, 5
x = torch.randn(bs, c, width, height)
y = torch.randint(c, (bs, width, height))
w = torch.rand(c)
w = w / w.sum()  # normalize
eps = 1e-6
smooth = 0
print(f"Size of x, y, w: {x.size()}, {y.size()}, {w.size()}")

Size of x, y, w: torch.Size([4, 6, 2, 5]), torch.Size([4, 2, 5]), torch.Size([6])


In [6]:
def dice(x, y, smooth=0., l1norm=True, eps=1e-6):
    """Soft Dice loss computed jointly over all classes and pixels of a batch.

    The class scores are turned into probabilities with a softmax over dim 1.
    The integer mask ``y`` is one-hot encoded. A single Dice coefficient is
    then computed over every (pixel, class) pair of the batch; classes are
    neither weighted nor averaged separately.

    :param x: raw class scores of shape (bs, nc, *spatial), e.g. (bs, nc, H, W)
        or (bs, nc, D, H, W).
    :param y: integer class indices in ``[0, nc)`` of shape (bs, *spatial).
    :param smooth: additive smoothing term on numerator and denominator.
    :param l1norm: if True the denominator is ``sum(p) + sum(g)``, otherwise
        the squared form ``sum(p**2) + sum(g**2)``.
    :param eps: small constant that keeps the denominator non-zero.
    :return: 0-dim tensor equal to ``1 - dice``.
    """
    nc = x.size(1)
    # Move the class axis last so that each row of the flattened tensor is one
    # pixel; works for any number of spatial dims (the 4-D case is permute(0, 2, 3, 1)).
    spatial_dims = list(range(2, x.dim()))
    xp = x.permute(0, *spatial_dims, 1).reshape(-1, nc).softmax(dim=1)

    # One-hot encoding of the ground truth, row-aligned with xp
    # (reshape, not view: y may be non-contiguous).
    labels = y.reshape(-1, 1).long()
    yhot = torch.zeros_like(xp).scatter_(1, labels, 1.)

    intersection = 2. * (xp * yhot).sum()
    if l1norm:
        normalize = (xp + yhot).sum()
    else:
        normalize = (xp ** 2 + yhot ** 2).sum()

    return 1 - (intersection + smooth) / (normalize + smooth + eps)

# Dice loss of the random example batch (default arguments spelled out)
dice(x, y, smooth=0., l1norm=True)

tensor(0.7690)

In [None]:
# class scores of the first image, rounded to 2 decimals
printWithDec(x[0], title="x[0]", d=2)

In [None]:
# ground-truth class-index mask of the first image
y[0, ...]

# Dice loss: simple (per-class weighted) multiclass version

In [None]:
def dice_loss(input, target, class_weights=None, smooth=1.):
    """Weighted sum of per-class soft Dice losses.

    For every class k, the Dice coefficient between ``input[:, k]`` and
    ``target[:, k]`` is computed over the whole batch. The loss is
    ``sum_k w_k * (1 - dice_k)``.

    :param input: per-class scores of shape (bs, nc, *spatial). No softmax is
        applied here, so they are expected to already be probabilities.
    :param target: one-hot ground truth with the same shape as ``input``
        (integer or float dtype).
    :param class_weights: optional sequence/tensor of ``nc`` per-class weights.
        Defaults to 1 for every class. The previous version read undefined
        globals ``n_classes`` and ``class_weights``.
    :param smooth: additive smoothing term on numerator and denominator.
    :return: 0-dim tensor with the weighted loss.
    """
    n_classes = input.size(1)
    # One row per class: (nc, bs * prod(spatial)). reshape, not view, because a
    # per-class slice of a batched tensor is not contiguous (view raises).
    iflat = input.transpose(0, 1).reshape(n_classes, -1)
    tflat = target.transpose(0, 1).reshape(n_classes, -1).to(iflat.dtype)

    if class_weights is None:
        weights = torch.ones(n_classes, dtype=iflat.dtype, device=iflat.device)
    else:
        weights = torch.as_tensor(class_weights, dtype=iflat.dtype, device=iflat.device)

    intersection = (iflat * tflat).sum(dim=1)
    dice_per_class = (2. * intersection + smooth) / (iflat.sum(dim=1) + tflat.sum(dim=1) + smooth)
    return (weights * (1 - dice_per_class)).sum()

# Generalised Dice loss (Sudre et al., 2017) — TensorFlow implementation


In [None]:
def labels_to_one_hot(ground_truth, num_classes=1):
    """
    Converts ground truth labels to one-hot, sparse tensors.
    Used extensively in segmentation losses.

    NOTE(review): written against the TensorFlow 1.x API (`tf.to_int32`,
    `tf.to_int64`, `tf.sparse_reshape`), but `tf` is not imported in the
    visible import cell of this PyTorch notebook, so calling it there raises
    NameError -- TODO import TensorFlow 1.x or port to PyTorch.

    :param ground_truth: ground truth categorical labels (rank `N`)
    :param num_classes: A scalar defining the depth of the one hot dimension
        (see `depth` of `tf.one_hot`); either a Python int or a `tf.Tensor`.
    :return: one-hot sparse tf tensor
        (rank `N+1`; new axis appended at the end).
        When `num_classes == 1`, a dense tensor is returned instead
        (`ground_truth` reshaped to rank `N+1`).
    """
    # read input/output shapes
    if isinstance(num_classes, tf.Tensor):
        num_classes_tf = tf.to_int32(num_classes)
    else:
        num_classes_tf = tf.constant(num_classes, tf.int32)
    input_shape = tf.shape(ground_truth)
    # output shape = input shape with num_classes appended as the last axis
    output_shape = tf.concat(
        [input_shape, tf.reshape(num_classes_tf, (1,))], 0)

    # NOTE(review): when num_classes is a tf.Tensor this is a Python-level `==`,
    # not tf.equal; in TF1 graph mode it presumably never holds, so this shortcut
    # likely only applies to the int case -- confirm.
    if num_classes == 1:
        # need a sparse representation?
        return tf.reshape(ground_truth, output_shape)

    # squeeze the spatial shape
    ground_truth = tf.reshape(ground_truth, (-1,))
    # shape of squeezed output
    dense_shape = tf.stack([tf.shape(ground_truth)[0], num_classes_tf], 0)

    # create a rank-2 sparse tensor
    ground_truth = tf.to_int64(ground_truth)
    ids = tf.range(tf.to_int64(dense_shape[0]), dtype=tf.int64)
    # (row, label) coordinates: exactly one nonzero entry per flattened label
    ids = tf.stack([ids, ground_truth], axis=1)
    one_hot = tf.SparseTensor(
        indices=ids,
        values=tf.ones_like(ground_truth, dtype=tf.float32),
        dense_shape=tf.to_int64(dense_shape))

    # resume the spatial dims
    one_hot = tf.sparse_reshape(one_hot, output_shape)
    return one_hot


def generalised_dice_loss(prediction,
                          ground_truth,
                          weight_map=None,
                          type_weight='Square'):
    """
    Function to calculate the Generalised Dice Loss defined in
        Sudre, C. et. al. (2017) Generalised Dice overlap as a deep learning
        loss function for highly unbalanced segmentations. DLMIA 2017

    NOTE(review): TensorFlow 1.x code (`tf.sparse_reduce_sum`, `tf.is_inf`,
    `Dimension.value`); `tf` is not imported in the visible import cell of
    this PyTorch notebook -- TODO import TensorFlow 1.x or port to PyTorch.

    :param prediction: the logits. NOTE(review): no softmax/sigmoid is applied
        below; the values are used directly as soft per-class scores, so the
        caller presumably has to pass probabilities -- confirm.
        The axis-0 reductions and the (n_voxels, num_classes) tiled weight map
        suggest `prediction` is expected flattened to (n_voxels, num_classes)
        -- TODO confirm.
    :param ground_truth: the segmentation ground truth (integer labels); if it
        has the same rank as `prediction`, only its last channel is used.
    :param weight_map: optional per-voxel weights; flattened and repeated for
        every class, then applied to the reference volume, the intersection
        and the predicted volume.
    :param type_weight: type of weighting allowed between labels (choice
        between Square (square of inverse of volume),
        Simple (inverse of volume) and Uniform (no weighting))
    :return: the loss, i.e. ``1 - generalised dice score``
    :raises ValueError: if `type_weight` is not one of the three choices above.
    """
    prediction = tf.cast(prediction, tf.float32)
    if len(ground_truth.shape) == len(prediction.shape):
        ground_truth = ground_truth[..., -1]
    one_hot = labels_to_one_hot(ground_truth, tf.shape(prediction)[-1])

    if weight_map is not None:
        num_classes = prediction.shape[1].value
        # (older variant, kept for reference)
        # weight_map_nclasses = tf.reshape( tf.tile(weight_map, [num_classes]), prediction.get_shape())
        
        # flattened weight map repeated across classes -> (n_voxels, num_classes)
        weight_map_nclasses = tf.tile( tf.expand_dims(tf.reshape(weight_map, [-1]), 1), [1, num_classes])
        # per-class weighted reference (ground-truth) volume
        ref_vol = tf.sparse_reduce_sum( weight_map_nclasses * one_hot, reduction_axes=[0])

        # per-class weighted overlap and weighted predicted volume
        intersect = tf.sparse_reduce_sum( weight_map_nclasses * one_hot * prediction, reduction_axes=[0])
        seg_vol = tf.reduce_sum( tf.multiply(weight_map_nclasses, prediction), 0)
    else:
        # unweighted per-class reference volume, overlap and predicted volume
        ref_vol = tf.sparse_reduce_sum(one_hot, reduction_axes=[0])
        intersect = tf.sparse_reduce_sum(one_hot * prediction, reduction_axes=[0])
        seg_vol = tf.reduce_sum(prediction, 0)
        
    # per-class weights; a class absent from the ground truth (ref_vol == 0)
    # gets an infinite weight under Square/Simple
    if type_weight == 'Square':    weights = tf.reciprocal(tf.square(ref_vol))
    elif type_weight == 'Simple':  weights = tf.reciprocal(ref_vol)
    elif type_weight == 'Uniform': weights = tf.ones_like(ref_vol)
    else:
        # NOTE(review): the message below has an unbalanced escaped quote
        raise ValueError("The variable type_weight \"{} is not defined.".format(type_weight))
        
    # replace infinite weights by the largest finite weight
    new_weights = tf.where(tf.is_inf(weights), tf.zeros_like(weights), weights)
    weights     = tf.where(tf.is_inf(weights), tf.ones_like(weights) * tf.reduce_max(new_weights), weights)
    
    generalised_dice_numerator =  2 * tf.reduce_sum(tf.multiply(weights, intersect))
    
    # (older variant, kept for reference)
    # generalised_dice_denominator = \
    #     tf.reduce_sum(tf.multiply(weights, seg_vol + ref_vol)) + 1e-6
    # each class term is clamped to >= 1 so an empty class cannot make it zero
    generalised_dice_denominator = tf.reduce_sum(tf.multiply(weights, tf.maximum(seg_vol + ref_vol, 1)))
    
    generalised_dice_score = generalised_dice_numerator / generalised_dice_denominator
    # a NaN score (e.g. 0/0 when every weight ends up 0) is mapped to 1.0, i.e. zero loss
    generalised_dice_score = tf.where(tf.is_nan(generalised_dice_score), 1.0, generalised_dice_score)
    return 1 - generalised_dice_score