In [5]:
import torch

input = torch.Tensor([[0, 1, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0]])

In [84]:
# now run a prediction on a slice 
from unet.dataset import SegThorImagesDataset
from torch.utils.data import DataLoader

train_csv = "data/train_patient_idx_sorted.csv"
data_dir = 'data/train'

train_dataset = SegThorImagesDataset(
    patient_idx_file=train_csv,
    root_dir=data_dir,
    img_crop_size=312, 
    mask_output_size=220,
    cache_size=1
    )

train_dl = DataLoader(train_dataset, batch_size=2)



torch.Size([2, 1, 312, 312])
torch.Size([2, 220, 220])


In [None]:
dl_iter = iter(train_dl)
for i in range(550):
    X, Y = next(dl_iter)

In [121]:
print(X.shape)
print(Y.shape)

preds = torch.rand([2, 5, 220, 220])
print(f'preds shape: {preds.shape}')

# List of classes
classes = torch.tensor([1, 2, 3, 4])

# need to transform Y into a multi-channel tensor to align with preds
multi_channel_Y = torch.stack([(Y == c).int() for c in classes], dim=1)

# add one value to multi_channel_Y
multi_channel_Y[:, 1, 1, 1] = 1

print(multi_channel_Y.shape)

# Sum the boolean masks to get the count of each class
target_class_weights = torch.sum(multi_channel_Y, dim=(-1, -2))
print(f"Target class weights: {target_class_weights}")

class_weights = 1 / (target_class_weights * target_class_weights).clamp(1e-6)
print(f"Class Weights: {class_weights}")

torch.Size([2, 1, 312, 312])
torch.Size([2, 220, 220])
preds shape: torch.Size([2, 5, 220, 220])
torch.Size([2, 4, 220, 220])
Target class weights: tensor([[ 167,    1,  405, 1199],
        [ 160,    1,  254, 1212]])
Class Weights: tensor([[3.5856e-05, 1.0000e+00, 6.0966e-06, 6.9560e-07],
        [3.9062e-05, 1.0000e+00, 1.5500e-05, 6.8076e-07]])


In [122]:
# ignore the 0 class 

# intersection mask 
intersection = (preds[:, 1:, :, :] * multi_channel_Y).sum(dim=(-1, -2))
print(f"Intersection: {intersection}")

# union 
union = (preds[:, 1:, :, :] + multi_channel_Y).sum(dim=(-1, -2))

print(f"Union: {union}")

Intersection: tensor([[8.2069e+01, 4.6458e-01, 2.0316e+02, 6.2095e+02],
        [7.8675e+01, 8.1349e-02, 1.2879e+02, 6.0542e+02]])
Union: tensor([[24260.6055, 24248.4062, 24694.6797, 25433.5703],
        [24292.5195, 24104.7344, 24458.5605, 25467.1445]])


In [126]:
# average over labels 
2*( (intersection*class_weights) / (union*class_weights) ).mean(dim=1)
# average over batch 


tensor([0.0180, 0.0161])

In [None]:

# GDL weighting: the contribution of each label is corrected by the inverse of its volume
class_weight = preds.sum(-1)
w_l = 1 / (w_l * w_l).clamp(min=self.epsilon)
w_l.requires_grad = False

intersect = (input * target).sum(-1)
intersect = intersect * w_l

denominator = (input + target).sum(-1)
denominator = (denominator * w_l).clamp(min=self.epsilon)

return 2 * (intersect.sum() / denominator.sum())

In [15]:
w_l = input.sum()

w_l = 1 / (w_l * w_l).clamp(min=0.2)

In [16]:
w_l

tensor(0.1111)

In [136]:
import torch.nn as nn

class GeneralizedDiceLoss(nn.Module):
    """
    https://arxiv.org/pdf/1707.03237.pdf
    """

    def __init__(self, classes = torch.tensor([1, 2, 3, 4]), epsilon=1e-6, ):
        super(GeneralizedDiceLoss, self).__init__()
        self.classes = classes 
        # normalize the logits, since each channel is predicting its own class
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=1) # channel

        self.epsilon = epsilon

    def normalization(self, input):
        ## sigmoid to squash inputs
        ## softmax across channels to get to probabilities
        logits = self.sigmoid(input)
        proba = self.softmax(logits)
        return proba


    def forward(self, input, target):
        # get probabilities from logits
        input = self.normalization(input)

        # compute per batch item multi-label dice
        per_item_dice = self.dice(input, target)
        batch_losses = 1. - per_item_dice
        # sum across batches
        return batch_losses.sum()

    def dice(self, input, target):
        """
        input shape expects batch x class x h x w (drop the 0 class)
        target shape expects batch x h x w
        """
        # List of classes
        classes = torch.tensor([1, 2, 3, 4])
        # need to transform target into a multi-channel tensor to align with preds
        multi_channel_target = torch.stack([(target == c).int() for c in self.classes], dim=1)
        # Sum the boolean masks to get the count of each class
        target_class_weights = torch.sum(multi_channel_target, dim=(-1, -2))
        class_weights = 1 / (target_class_weights * target_class_weights).clamp(1e-6)
        class_weights.requires_grad = False

        intersection = (input * multi_channel_target).sum(dim=(-1, -2))
        union = (input + multi_channel_target).sum(dim=(-1, -2))
        # take average over labels 
        batch_dice = 2*( (intersection*class_weights) / (union*class_weights) ).mean(dim=1)
        return batch_dice




In [142]:
loss = GeneralizedDiceLoss()

In [143]:
print(Y.shape)

preds = torch.rand([2, 5, 220, 220])

print(preds.shape)

torch.Size([2, 220, 220])
torch.Size([2, 5, 220, 220])


In [150]:
loss(preds[:, 1:, :, :], Y)

tensor(1.9659)