In [3]:
import torch
import torch.nn as nn

In [6]:
class _DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        return self.conv(x)

class _DownConv(nn.Module):
    """Encoder stage: 2x2 max-pool (halving H and W, flooring odd sizes)
    followed by a `_DoubleConv` from `in_channels` to `out_channels`."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = _DoubleConv(in_channels, out_channels)
        # Order inside the Sequential fixes the state_dict keys (down.0 = pool, down.1 = convs).
        self.down = nn.Sequential(pool, convs)

    def forward(self, x):
        pooled_features = self.down(x)
        return pooled_features

class _UpSample(nn.Module):
    """Decoder stage: upsample x2, concatenate the skip connection, double conv.

    Args:
        in_channels: channel count *after* concatenation, i.e. channels of the
            upsampled input plus channels of the skip tensor.
        out_channels: channel count produced by the internal `_DoubleConv`.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.double_conv = _DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample `x1`, align it to `x2`'s spatial size, concatenate on dim 1, convolve.

        Args:
            x1: deeper, lower-resolution feature map (expected N, C1, H, W).
            x2: skip-connection feature map (expected N, C2, H2, W2).
        """
        x1 = self.up(x1)
        # If the encoder saw an odd H or W, MaxPool2d floored it and the x2
        # upsample comes back 1 px short of the skip; torch.cat would then fail.
        # Pad (or crop, for negative diffs) x1 symmetrically to match x2.
        diff_h = x2.size(-2) - x1.size(-2)
        diff_w = x2.size(-1) - x1.size(-1)
        if diff_h or diff_w:
            x1 = nn.functional.pad(
                x1,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )
        x = torch.cat([x1, x2], dim=1)
        return self.double_conv(x)

class UNet(nn.Module):
    """U-Net producing a per-pixel sigmoid map.

    Level i of the encoder has 2**(scp + i) channels; level 0 keeps the input
    resolution and every later level max-pools by 2. The decoder mirrors the
    encoder, merging each upsampled map with the matching encoder skip, and a
    final 1x1 convolution projects to `out_channels` before the sigmoid.

    Args:
        depth: number of resolution levels (>= 1).
        scp: log2 of the channel width at level 0.
        in_channels: channels of the input image (default 1).
        out_channels: channels of the output map (default 1).

    Raises:
        ValueError: if depth < 1.
    """

    def __init__(self, depth=5, scp=4, in_channels=1, out_channels=1):
        super().__init__()
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")

        self.depth = depth
        self.scp = scp

        widths = [2 ** (scp + i) for i in range(depth)]

        # Encoder: index 0 is the full-resolution double conv, then one _DownConv per level.
        self.down_blocks = nn.ModuleList(
            [_DoubleConv(in_channels, widths[0])]
            + [_DownConv(widths[i - 1], widths[i]) for i in range(1, depth)]
        )

        # Decoder: block i concatenates the current map (widths[depth-1-i] channels)
        # with the encoder skip (widths[depth-2-i] channels) and reduces to the latter.
        up_blocks = [
            _UpSample(widths[depth - 1 - i] + widths[depth - 2 - i], widths[depth - 2 - i])
            for i in range(depth - 1)
        ]
        # Last entry is the 1x1 projection head; kept in up_blocks so state_dict keys are unchanged.
        up_blocks.append(nn.Conv2d(widths[0], out_channels, 1))
        self.up_blocks = nn.ModuleList(up_blocks)

    def forward(self, x):
        """Run the network.

        Args:
            x: input batch, expected shape (N, in_channels, H, W).

        Returns:
            Tensor with `out_channels` channels, passed through a sigmoid.
        """
        # Encoder pass; every level's output is retained for the skip connections.
        skips = []
        h = x
        for block in self.down_blocks:
            h = block(h)
            skips.append(h)

        # h is now the bottleneck (skips[-1]); walk the remaining skips deepest-first.
        for i, up in enumerate(self.up_blocks[:-1]):
            h = up(h, skips[-2 - i])

        return torch.sigmoid(self.up_blocks[-1](h))

In [7]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed over every element of the batch at once.

    loss = 1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)

    Args:
        smooth: additive constant keeping the ratio defined when both
            prediction and target sum to zero (loss is then 0).
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, prediction, target):
        p = prediction.reshape(-1)
        t = target.reshape(-1)
        numerator = 2.0 * (p * t).sum() + self.smooth
        denominator = p.sum() + t.sum() + self.smooth
        return 1 - numerator / denominator