In [3]:
import torch
import torch.nn as nn

In [6]:
class _DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        return self.conv(x)

class _DownConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.down = nn.Sequential(
            nn.MaxPool2d(2),
            _DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.down(x)

class _UpSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.double_conv = _DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x1, x2], dim=1)
        x = self.double_conv(x)
        return x

class UNet(nn.Module):
    """Five-level U-Net with a sigmoid output.

    Encoder widths are 16-32-64-128-256. Each decoder stage upsamples, then
    concatenates the encoder feature map of matching depth (skip connection),
    then applies a double convolution. A final 1x1 convolution maps the 16
    decoder channels to ``out_channels``, followed by a sigmoid.

    Args:
        in_channels: channels of the input image (default 1, the original value).
        out_channels: channels of the output map (default 1, the original value).
            With more than one channel the sigmoid is applied per channel
            independently (not a softmax).
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.down_block1 = _DoubleConv(in_channels, 16)
        self.down_block2 = _DownConv(16, 32)
        self.down_block3 = _DownConv(32, 64)
        self.down_block4 = _DownConv(64, 128)
        self.down_block5 = _DownConv(128, 256)

        # Decoder input widths = upsampled channels + skip-connection channels.
        self.up_block1 = _UpSample(256 + 128, 128)
        self.up_block2 = _UpSample(128 + 64, 64)
        self.up_block3 = _UpSample(64 + 32, 32)
        self.up_block4 = _UpSample(32 + 16, 16)
        self.up_block5 = nn.Conv2d(16, out_channels, 1)

    def forward(self, x):
        """Run the encoder-decoder and return sigmoid activations.

        Args:
            x: image batch; presumably (N, in_channels, H, W) since the first
                layer is a Conv2d with ``in_channels`` input channels.

        Returns:
            Tensor with ``out_channels`` channels and values in [0, 1].
        """
        x1 = self.down_block1(x)
        x2 = self.down_block2(x1)
        x3 = self.down_block3(x2)
        x4 = self.down_block4(x3)
        x5 = self.down_block5(x4)

        x6 = self.up_block1(x5, x4)
        x7 = self.up_block2(x6, x3)
        x8 = self.up_block3(x7, x2)
        x9 = self.up_block4(x8, x1)
        x10 = self.up_block5(x9)

        out = torch.sigmoid(x10)
        return out

In [7]:
class DiceLoss(nn.Module):
    """Soft Dice loss, ``1 - dice``, computed over all elements of the batch at once.

    Args:
        smooth: constant added to numerator and denominator so the ratio is
            defined (and equals 1) when both inputs sum to zero.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, prediction, target):
        """Return ``1 - (2*sum(p*t) + smooth) / (sum(p) + sum(t) + smooth)``.

        Both tensors are flattened, so the score is pooled over the whole batch
        rather than averaged per sample.
        """
        pred = prediction.reshape(-1)
        truth = target.reshape(-1)

        overlap = (pred * truth).sum()
        numerator = 2.0 * overlap + self.smooth
        denominator = pred.sum() + truth.sum() + self.smooth
        return 1 - numerator / denominator