In [1]:
#requires you to pip install nnunetv2 and some others
#for install details see: https://github.com/MIC-DKFZ/nnUNet
import nnunetv2
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
import os,sys
import pandas as pd
import torch
import numpy as np
from tqdm import tqdm
from torch import nn
from nnunetv2.training.loss.dice import SoftDiceLoss, MemoryEfficientSoftDiceLoss
from nnunetv2.training.loss.robust_ce_loss import RobustCrossEntropyLoss, TopKLoss
from nnunetv2.utilities.helpers import softmax_helper_dim1
from nnunetv2.training.myloss.cldice_loss.cldice import soft_cldice

#sys.path.append('/home/hvv/Documents/git_repo') #not required if in the same dir
# from nnunet_utils.utils import np2sitk, set_env_nnunet, write_envlines_nnunet, assign_trainjobs_to_gpus
# from nnunet_utils.preprocess import write_as_nnunet, nnunet_directory_structure, preprocess_data
# from nnunet_utils.run import train_single_model, nnunet_train_shell

In [9]:

class DC_CE_clDC_loss(nn.Module):
    def __init__(self, soft_dice_kwargs, 
                 ce_kwargs, cldice_kwargs, 
                 weight_ce=1, weight_dice=1, weight_cldice=1, 
                 ignore_label=None,
                 dice_class=SoftDiceLoss):
        """
        Weights for CE and Dice do not need to sum to one. You can set whatever you want.
        :param soft_dice_kwargs:
        :param ce_kwargs:
        :param aggregate:
        :param square_dice:
        :param weight_ce:
        :param weight_dice:
        """
        super(DC_CE_clDC_loss, self).__init__()
        if ignore_label is not None:
            ce_kwargs['ignore_index'] = ignore_label

        self.weight_dice = weight_dice
        self.weight_ce = weight_ce
        self.weight_cldice = weight_cldice
        self.ignore_label = ignore_label

        self.ce = RobustCrossEntropyLoss(**ce_kwargs)
        self.dc = dice_class(apply_nonlin=softmax_helper_dim1, **soft_dice_kwargs)
        self.cldc = soft_cldice(iter_=10, smooth = 1.,y_true_skel_input_channel=None,
                                exclude_background=False, **cldice_kwargs)

    def forward(self, net_output: torch.Tensor, target: torch.Tensor):
        """
        target must be b, c, x, y(, z) with c=1
        :param net_output:
        :param target:
        :return:
        """
        if self.ignore_label is not None:
            assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \
                                         '(DC_and_CE_loss)'
            mask = target != self.ignore_label
            # remove ignore label from target, replace with one of the known labels. It doesn't matter because we
            # ignore gradients in those areas anyway
            target_dice = torch.where(mask, target, 0)
            num_fg = mask.sum()
        else:
            target_dice = target
            mask = None

        dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \
            if self.weight_dice != 0 else 0
        ce_loss = self.ce(net_output, target[:, 0]) \
            if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0
        cldc_loss = self.cldc(net_output, target_dice) \
            if self.weight_cldice != 0 else 0

        result = self.weight_ce * ce_loss + self.weight_dice * dc_loss + self.weight_cldice*cldc_loss
        return result
    
    
loss = DC_CE_clDC_loss({'batch_dice': 0,
                        'smooth': 1e-5, 'do_bg': False, 'ddp': False}, {}, {}, weight_ce=1, weight_dice=1,
                                  ignore_label=None, dice_class=MemoryEfficientSoftDiceLoss).type(torch.float32).cuda()


x = torch.ones(1,1,32,24,24).type(torch.float32).cuda()
y = torch.zeros(1,1,32,24,24).type(torch.float32).cuda()


res = loss(x,y)
res

tensor(nan, device='cuda:0')