# Segmentation with UNet

In this notebook we will train a simple UNet model for foreground/background segmentation.

**Goal.** The goal of this notebook is to get experience in working with pix-to-pix models and training segmentation models.

You need the following extra libraries beyond PyTorch:
* pytorch_lightning
* albumentations
* torchvision
* Pillow (PIL)
* matplotlib
* (optional) segmentation_models_pytorch

In [None]:
import albumentations as A
import gc
import numpy as np
import pytorch_lightning as pl
import random
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image
from matplotlib import pyplot as plt
from torchvision.datasets import VOCSegmentation

print("Have CUDA:", torch.cuda.is_available())
print("Torch version:", torch.__version__)

# Side length (pixels) of the square images fed to the network.
IMAGE_SIZE = 256
# Number of images per training / validation batch.
BATCH_SIZE = 8
# Number of samples shown in TensorBoard logs and by show_segmentations.
VIZ_IMAGES = 4

# Pascal VOC location, release and class names (index 0 is the background).
DATA_ROOT = "."
VOC_YEAR = "2012"
LABELS = [
    "background",
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "potted_plant", "sheep", "sofa", "train", "tv/monitor",
]

# Helper tools.
# You can skip this block.

class Module(pl.LightningModule):
    """Lightning wrapper that trains a binary segmentation network with a pluggable loss.

    Args:
        model: Network returning single-channel logits, shape (B, 1, H, W).
        loss: Callable ``loss(predictions, masks)`` returning a scalar tensor.
    """

    def __init__(self, model, loss):
        super().__init__()
        self.model = model
        self.loss = loss

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0001)

    def training_step(self, batch):
        images, masks = batch
        predictions = self(images)
        loss = self.loss(predictions, masks)
        self.log("train/loss", loss, prog_bar=True)
        if self.global_step % 100 == 0:
            self._log_samples(images, masks, predictions)
        return loss

    @torch.no_grad()
    def _log_samples(self, images, masks, predictions):
        """Send an (image / GT mask / probability / binary prediction) grid to TensorBoard.

        All tensors are detached so that nothing stored here keeps the autograd
        graph of the training step alive.
        """
        s_image = images[:VIZ_IMAGES].detach().float()
        s_image = s_image - s_image.min()
        s_image = s_image / s_image.max()  # (B, C, H, W) rescaled to [0, 1].
        s_mask = masks[:VIZ_IMAGES].detach().float().unsqueeze(1).repeat(1, 3, 1, 1)  # (B, 3, H, W).
        s_pred = torch.sigmoid(predictions[:VIZ_IMAGES].detach()).repeat(1, 3, 1, 1)  # (B, 3, H, W) for 1-channel logits.
        s_pred_bin = (s_pred > 0.5).float()
        # Stack the four views vertically, then place samples side by side: (C, 4H, B*W).
        log_image = torch.cat([s_image, s_mask, s_pred, s_pred_bin], dim=2).permute(1, 2, 0, 3).flatten(2, 3)
        for logger in self.trainer.loggers:
            if isinstance(logger, pl.loggers.TensorBoardLogger):
                logger.experiment.add_image("Result", log_image, self.global_step)
        self._data = [s_image, s_mask, s_pred, s_pred_bin, predictions[:VIZ_IMAGES].detach()]

    def validation_step(self, batch):
        images, masks = batch
        predictions = self(images)
        loss = self.loss(predictions, masks)
        self.log("val/loss", loss, on_epoch=True)
        # Jaccard index (IoU) of the foreground over the whole batch.
        assert predictions.shape[1] == 1
        pred_masks = predictions.squeeze(1) > 0  # Logit > 0 <=> probability > 0.5; (B, H, W).
        assert pred_masks.shape == masks.shape
        intersection = torch.logical_and(pred_masks, masks).sum()
        union = pred_masks.sum() + masks.sum() - intersection
        jaccard_index = intersection / union.clip(min=1)  # Avoid 0 / 0 when both masks are empty.
        self.log("val/Jaccard", jaccard_index, on_epoch=True)


def show_segmentations(model, dataset):
    """Plot the predictions of ``model`` for VIZ_IMAGES random samples of ``dataset``.

    Each figure shows: the input image, the ground-truth mask, the foreground
    probability, the thresholded (p > 0.5) mask, and the image with the
    predicted background painted white.

    Args:
        model: Network mapping a (1, C, H, W) batch to logits; channel 0 is used
            as the foreground logit.
        dataset: Indexable returning ``(image, mask)`` with ``image`` a (C, H, W) tensor.

    Inputs are moved to the device of the model's parameters, and the model's
    train/eval mode is restored when the function returns.
    """
    was_training = model.training
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    try:
        for _ in range(VIZ_IMAGES):
            image, mask_gt = dataset[random.randint(0, len(dataset) - 1)]
            with torch.no_grad():
                logits = model(image.unsqueeze(0).to(device))[0]  # (C, H, W).
                predicted = torch.sigmoid(logits[0]).cpu()  # (H, W) foreground probability.
            foreground = predicted > 0.5
            predicted_labels = foreground.numpy().astype(np.uint8)
            image = image.permute(1, 2, 0)
            image = image - image.min()
            image = image / image.max()  # (H, W, C) in [0, 1] for imshow.
            # Float RGB images must stay within [0, 1]: use 1.0 (white), not 255, for the background.
            selected = torch.where(foreground[..., None], image, torch.ones_like(image))

            fig, axs = plt.subplots(1, 5, figsize=(12, 5))
            axs[0].imshow(image)
            axs[1].imshow(mask_gt)
            axs[2].imshow(predicted)
            axs[3].imshow(predicted_labels)
            axs[4].imshow(selected)
            plt.show()
    finally:
        model.train(was_training)


def clean_memory():
    """Delete notebook-level models and trainers, then garbage-collect.

    Removes every global of this module that is a ``torch.nn.Module`` (which
    includes LightningModules) or a ``pl.Trainer``. Trainers are removed as well
    because a fitted Trainer keeps a reference to its LightningModule
    (``trainer.lightning_module``); otherwise that model could not be freed.
    When CUDA is available, cached GPU memory blocks are released too.
    """
    stale_names = [
        name for name, value in globals().items()
        if isinstance(value, (torch.nn.Module, pl.LightningModule, pl.Trainer))
    ]
    for name in stale_names:
        del globals()[name]
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


class CheckerError(Exception):
    """Raised by the ``check_*`` helpers when an assignment solution is wrong."""


def check_conv1x1(layer_fn):
    """Validate a ``make_conv1x1_bn``-style factory.

    Builds ``layer_fn(2, 5)`` and checks that it is ``Sequential(Conv2d 1x1 without
    bias, BatchNorm2d)`` mapping 2 to 5 channels while preserving the spatial size.

    Raises:
        CheckerError: if any check fails.
    """
    layer = layer_fn(2, 5)
    if (not isinstance(layer, torch.nn.Sequential)) or (len(layer) != 2):
        raise CheckerError("conv1x1: Need an nn.Sequential module with 2 layers.")
    conv = layer[0]
    bn = layer[1]
    if (not isinstance(conv, torch.nn.Conv2d)) or (not isinstance(bn, torch.nn.BatchNorm2d)):
        raise CheckerError("conv1x1: The first layer must be convolution and the second layer must be batch normalization.")
    if conv.kernel_size != (1, 1):
        raise CheckerError("conv1x1: wrong kernel size")
    if conv.in_channels != 2 or conv.out_channels != 5:
        raise CheckerError("conv1x1: wrong channels")
    if conv.bias is not None:
        raise CheckerError("conv1x1: Don't use bias before batch norm.")
    x = torch.randn(3, 2, 7, 9)
    try:
        output = layer(x)
    except Exception as err:
        raise CheckerError("conv1x1: Inference error") from err
    # Stride / padding mistakes still run but change the feature-map size.
    if tuple(output.shape) != (3, 5, 7, 9):
        raise CheckerError("conv1x1: wrong output shape (check stride and padding)")


def check_conv3x3(layer_fn):
    """Validate a ``make_conv3x3_bn``-style factory.

    Builds ``layer_fn(2, 5)`` and checks that it is ``Sequential(Conv2d 3x3 with
    padding 1 and no bias, BatchNorm2d)`` mapping 2 to 5 channels while
    preserving the spatial size.

    Raises:
        CheckerError: if any check fails.
    """
    layer = layer_fn(2, 5)
    if (not isinstance(layer, torch.nn.Sequential)) or (len(layer) != 2):
        raise CheckerError("conv3x3: Need an nn.Sequential module with 2 layers.")
    conv = layer[0]
    bn = layer[1]
    if (not isinstance(conv, torch.nn.Conv2d)) or (not isinstance(bn, torch.nn.BatchNorm2d)):
        raise CheckerError("conv3x3: The first layer must be convolution and the second layer must be batch normalization.")
    if conv.kernel_size != (3, 3):
        raise CheckerError("conv3x3: wrong kernel size")
    if conv.padding != (1, 1):
        raise CheckerError("conv3x3: need SAME padding")
    if conv.in_channels != 2 or conv.out_channels != 5:
        raise CheckerError("conv3x3: wrong channels")
    if conv.bias is not None:
        raise CheckerError("conv3x3: Don't use bias before batch norm.")
    x = torch.randn(3, 2, 7, 9)
    try:
        output = layer(x)
    except Exception as err:
        raise CheckerError("conv3x3: Inference error") from err
    # A default stride other than 1 still runs but shrinks the feature map.
    if tuple(output.shape) != (3, 5, 7, 9):
        raise CheckerError("conv3x3: wrong output shape (check stride)")


def check_upsampling(layer_fn):
    """Validate a ``make_upspampling``-style factory.

    Checks that ``layer_fn(scale_factor=2)`` is a bilinear ``nn.Upsample`` that
    doubles the spatial size.

    Raises:
        CheckerError: if any check fails.
    """
    layer = layer_fn(scale_factor=2)
    if not isinstance(layer, torch.nn.Upsample):
        raise CheckerError("upsample: Need upsampling layer.")
    if layer.scale_factor != 2:
        raise CheckerError("upsample: Wrong scale factor.")
    if layer.mode != "bilinear":
        raise CheckerError("upsample: Use 'bilinear' mode.")
    x = torch.randn(3, 2, 7, 9)
    try:
        output = layer(x)
    except Exception as err:
        raise CheckerError("upsample: Inference error") from err
    if tuple(output.shape) != (3, 2, 14, 18):
        raise CheckerError("upsample: wrong output shape")


def check_bce_loss(loss_class):
    """Check a BCE-with-logits segmentation loss on two hand-computed pixels.

    Pixel 1 is positive with p=0.9 (loss -log 0.9); pixel 2 is negative with
    p=0.2 (loss -log 0.8). The expected value is their mean.

    Raises:
        CheckerError: if the loss differs from the reference or is NaN.
    """
    loss_computer = loss_class()
    def to_logits(probs):
        return np.log(probs) - np.log(1 - probs)
    logits = torch.from_numpy(to_logits(np.array([0.9, 0.2]))).reshape(2, 1, 1, 1).float()
    # Boolean labels, matching the masks ToBinary produces for the real dataset.
    labels = torch.from_numpy(np.array([1, 0])).reshape(2, 1, 1).bool()
    value = loss_computer(logits, labels).item()
    expected = -np.mean(np.log([0.9, 0.8]))
    # "not <=" so that a NaN loss is rejected instead of slipping through.
    if not abs(value - expected) <= 1e-5:
        raise CheckerError("Wrong BCE loss")


def check_focal_loss(loss_class):
    """Check a focal loss (gamma=2, mean reduction) on two hand-computed pixels.

    Pixel 1 is positive with p=0.9: -(1 - 0.9)^2 log 0.9 = -0.01 log 0.9.
    Pixel 2 is negative with p=0.2: -(0.2)^2 log 0.8 = -0.04 log 0.8.

    Raises:
        CheckerError: if the loss differs from the reference or is NaN.
    """
    loss_computer = loss_class()
    def to_logits(probs):
        return np.log(probs) - np.log(1 - probs)
    p = np.array([0.9, 0.2])
    logits = torch.from_numpy(to_logits(p)).reshape(2, 1, 1, 1).float()
    labels = torch.from_numpy(np.array([1, 0])).reshape(2, 1, 1).bool()
    value = loss_computer(logits, labels).item()
    gt_1 = - 0.01 * np.log(0.9)
    gt_2 = - 0.04 * np.log(0.8)
    gt = 0.5 * (gt_1 + gt_2)
    # "not <=" so that a NaN loss is rejected instead of slipping through.
    if not abs(value - gt) <= 1e-5:
        raise CheckerError("Wrong Focal Loss")


def check_dice_loss(loss_class):
    """Check ``loss_class().single_class_dice_loss`` on two hand-computed pixels.

    Probabilities (0.9, 0.2) with labels (1, 0) give an intersection of 0.9 and
    an area sum of (0.9 + 0.2) + 1 = 2.1, so the loss is -log(2 * 0.9 / 2.1).

    Raises:
        CheckerError: if the value differs from the reference or is NaN.
    """
    loss_computer = loss_class()
    def to_logits(probs):
        return np.log(probs) - np.log(1 - probs)
    p = np.array([0.9, 0.2])
    logits = torch.from_numpy(to_logits(p)).reshape(2, 1, 1, 1).float()
    labels = torch.from_numpy(np.array([1, 0])).reshape(2, 1, 1).bool()
    value = loss_computer.single_class_dice_loss(logits, labels).item()
    gt = -np.log(2 * 0.9 / (1.1 + 1))
    # "not <=" so that a NaN loss is rejected instead of slipping through.
    if not abs(value - gt) <= 1e-5:
        raise CheckerError("Wrong Dice Loss")

# Dataset

We use the Pascal VOC dataset. Masks contain two special values: 0 for background and 255 for object contours. When masks are binarized, every non-background value, including the 255 contour, is treated as foreground.

In [None]:
# Download Pascal VOC into DATA_ROOT; later cells open the "train" and "val" splits without download=True.
VOCSegmentation(root=DATA_ROOT, year=VOC_YEAR, image_set="train", download=True)

In [None]:
valset_raw = VOCSegmentation(root=DATA_ROOT, year=VOC_YEAR, image_set="val")

# Show a random raw validation pair and the distinct values present in its mask.
image, mask = valset_raw[random.randint(0, len(valset_raw) - 1)]
print("Mask values:", set(np.asarray(mask).ravel().tolist()))

fig, axs = plt.subplots(1, 2)
for ax, picture in zip(axs, (image, mask)):
    ax.imshow(picture)
plt.show()

In [None]:
# Print the (width, height) of a few random validation images; the seed makes the choice reproducible.
random.seed(11)
for _ in range(5):
    i = random.randint(0, len(valset_raw) - 1)
    width_height = valset_raw[i][0].size
    print(f"Image {i:05d} size: {width_height}")

Images have different sizes, and sometimes the longest side is less than 500 pixels. To train the model we must handle this problem using one of the following approaches:
1. Scale images to a fixed size (change aspect ratios when necessary).
2. Scale images to a fixed size by keeping aspect ratio (with padding).
3. Scale images to a fixed size by cropping central parts with the required aspect ratio.
4. Use the models capable of working with different image sizes.

Note:
- While UNet can work with images of different sizes, such images can't be stacked into a single batch, so training would be slow.
- UNet is a convolutional network, so each output pixel depends only on a neighborhood of input pixels.

We will use the second option with mirror (reflection) padding.

### Data module

We will use Albumentations to augment both [images and masks](https://albumentations.ai/docs/examples/example_kaggle_salt/).

In [None]:
class ToBinary(A.DualTransform):
    """A custom Albumentations transform that converts multiclass masks to foreground / background.

    The image is returned unchanged. The mask becomes a boolean array that is
    ``True`` wherever the label differs from ``background_label``. For Pascal VOC
    this means contour pixels (value 255) count as foreground.

    Args:
        background_label: Mask value treated as background.
    """

    def __init__(self, background_label=0):
        super().__init__(p=1)  # Always apply.
        self.background_label = background_label

    def apply(self, img, **params):
        return img

    def apply_to_mask(self, mask, **params):
        return mask != self.background_label

    def get_transform_init_args_names(self):
        # Albumentations uses this to build repr() and to serialize the pipeline;
        # without it those operations may fail for custom transforms (the base
        # implementation raises NotImplementedError in Albumentations 1.x).
        return ("background_label",)


class TransformedDataset(torch.utils.data.Dataset):
    """Wrap a dataset of ``(image, mask)`` pairs and run a joint transform on every item.

    ``transform`` is called as ``transform(image=..., mask=...)`` on NumPy arrays
    and must return a dict with ``"image"`` and ``"mask"`` keys (the
    Albumentations convention).
    """

    def __init__(self, dataset, transform):
        super().__init__()
        self.dataset, self.transform = dataset, transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        raw_image, raw_mask = self.dataset[index]
        out = self.transform(image=np.asarray(raw_image), mask=np.asarray(raw_mask))
        return out["image"], out["mask"]


class Data(pl.LightningDataModule):
    """Pascal VOC foreground / background segmentation data module.

    Resize an image to the specified size keeping aspect ratio and fill
    borders with reflections (``PadIfNeeded``). Training samples are also
    randomly flipped and crop-resized. Masks are binarized with ``ToBinary``.

    Args:
        size: Output image size (images become ``size x size``).
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
    """

    def __init__(self, size, batch_size=BATCH_SIZE, num_workers=4):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers

        def make_transform(augment):
            """Build the resize/pad/normalize pipeline, adding train-only augmentations if requested."""
            steps = [
                A.LongestMaxSize(size),
                A.PadIfNeeded(size, size, position="center"),
            ]
            if augment:
                steps += [
                    A.HorizontalFlip(),
                    A.RandomResizedCrop((size, size), scale=(0.7, 1.0)),
                ]
            steps += [A.Normalize(), ToBinary(), ToTensorV2()]
            return A.Compose(steps)

        self.trainset = TransformedDataset(
            VOCSegmentation(DATA_ROOT, VOC_YEAR, "train"),
            make_transform(augment=True),
        )
        self.valset = TransformedDataset(
            VOCSegmentation(DATA_ROOT, VOC_YEAR, "val"),
            make_transform(augment=False),
        )

    def _make_loader(self, dataset, train):
        """Create a DataLoader; training loaders shuffle and drop the truncated last batch."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,  # The number of images in the batch.
            shuffle=train,  # Reshuffle training samples every epoch; keep validation order fixed.
            num_workers=self.num_workers,  # The number of concurrent readers and preprocessors.
            drop_last=train,  # Drop the truncated last batch during training only.
            pin_memory=torch.cuda.is_available(),  # Optimize CUDA data transfer.
        )

    def train_dataloader(self):
        return self._make_loader(self.trainset, train=True)

    def val_dataloader(self):
        return self._make_loader(self.valset, train=False)

data = Data(IMAGE_SIZE)
for caption, subset in (("Train set size:", data.trainset), ("Validation set size:", data.valset)):
    print(caption, len(subset))

# Inspect one preprocessed validation sample.
image, mask = data.valset[5]
print(f"Image dtype: {image.dtype}, image shape: {image.shape}")
print(f"Mask dtype: {mask.dtype}, mask shape: {mask.shape}")
fig, axs = plt.subplots(1, 2)
shifted = image - image.min()
image = shifted / shifted.max()  # Rescale to [0, 1] for display.
axs[0].imshow(image.permute(1, 2, 0))
axs[1].imshow(mask)
plt.show()

Note that the mask is transformed in the same way as the image.

# UNet

In this block we will implement a model similar to [UNet](https://arxiv.org/pdf/1505.04597.pdf).

UNet is composed of an encoder and a decoder. The encoder computes an embedding, while the decoder reconstructs an image-sized output (here, the mask). The decoder has a structure similar to the reversed encoder. The main difference between them is that the encoder uses max pooling or strided convolutions to decrease the spatial dimensions, while the decoder uses upsampling layers or transposed convolutions to increase them. Since the hidden activations of corresponding encoder and decoder layers have similar shapes, we can concatenate encoder activations to the inputs of the decoder layers (skip connections). This way, information about the input image reaches the decoder with minimal loss, and the predicted mask can be fine-grained.

<img src="example-image.jpg" align="left" hspace="20" width="20%" height="20%"/> 
<img src="u-net.jpg" align="left" hspace="20" width="50%" height="50%"/> 
<img src="example-mask.jpg" align="left" hspace="20" width="20%" height="20%"/> 
<div style="clear:both;"></div>

We will use a custom implementation of UNet to speed up training and improve quality:
* Image size 256 instead of 572
* Add residual connections, as in ResNet
* Add Batch Normalization

### Assignment 1. Implement basic blocks.

Cheat sheet:

```python
torch.nn.Sequential(*args: Module)

torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)

torch.nn.ReLU(inplace=False)

torch.nn.Upsample(size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None)
```

In [None]:
def make_conv1x1_bn(in_channels, out_channels, stride=1):
    """Return ``Sequential(Conv2d 1x1 without bias, BatchNorm2d)``.

    The convolution has no bias because the following BatchNorm2d adds its own
    shift; ``stride`` is forwarded to the convolution.
    """
    conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
    norm = torch.nn.BatchNorm2d(out_channels)
    return torch.nn.Sequential(conv, norm)


def make_conv3x3_bn(in_channels, out_channels, bias=False, stride=1):
    """Create a 3x3 convolution with SAME padding followed by batch normalization.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        bias: Whether the convolution has a bias term. Off by default: the
            following BatchNorm2d has its own learnable shift, so a conv bias
            would be redundant. (Previously this argument was ignored and no bias
            was ever created; the default keeps that behavior.)
        stride: Convolution stride; ``padding=1`` keeps the spatial size at stride 1.
    """
    return torch.nn.Sequential(
        torch.nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=bias, stride=stride),
        torch.nn.BatchNorm2d(out_channels),
    )


def make_upspampling(scale_factor=2):
    """Return a bilinear ``nn.Upsample`` that enlarges height and width by ``scale_factor``."""
    return torch.nn.Upsample(scale_factor=scale_factor, mode="bilinear")

def make_activation():
    """Return a ReLU that overwrites its input in place to save memory."""
    return torch.nn.ReLU(inplace=True)

# Validate the three building-block factories; any mistake raises CheckerError.
for checker, factory in (
    (check_conv1x1, make_conv1x1_bn),
    (check_conv3x3, make_conv3x3_bn),
    (check_upsampling, make_upspampling),
):
    checker(factory)
print("OK")

In [None]:
class ResConvBlock(torch.nn.Module):
    """Two 3x3 conv-BN layers with a ReLU in between, plus a 1x1 conv-BN shortcut.

    ``stride`` is applied by the first 3x3 convolution and by the shortcut, so
    both paths yield tensors of the same shape; they are summed and passed
    through a final ReLU.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        conv_a = make_conv3x3_bn(in_channels, out_channels, stride=stride)
        relu_a = make_activation()
        conv_b = make_conv3x3_bn(out_channels, out_channels)
        self.main_path = torch.nn.Sequential(conv_a, relu_a, conv_b)
        self.residual_path = make_conv1x1_bn(in_channels, out_channels, stride=stride)
        self.last_relu = make_activation()

    def forward(self, x):
        main = self.main_path(x)
        shortcut = self.residual_path(x)
        return self.last_relu(main + shortcut)

Our UNet is composed of four parts:
* input convolutions
* encoder
* decoder
* final projection

`Input convolutions` is a simple transform for image preprocessing. Encoder and decoder branches include multiple blocks. `UNetDown` class implements a computation block between two downsampling layers, starting from a strided convolution. Similarly, `UNetUp` implements a computation block between two upsampling layers, starting from upsampling. The `final projection` is a simple 1x1 convolution that extracts the required number of channels from the decoder output.

The final model structure:
```
input -> conv -> conv -> Down x 4 -> Up x 4 -> conv -> output
```

In [None]:
class UNetDown(torch.nn.Sequential):
    """Encoder stage: a single ``block`` built with ``stride=2`` to downsample its input.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        block: Factory called as ``block(in_channels, out_channels, stride=2)``.
    """

    def __init__(self, in_channels, out_channels, block):
        strided_block = block(in_channels, out_channels, stride=2)
        super().__init__(strided_block)


class UNetUp(torch.nn.Module):
    """Decoder stage: upsample, concatenate the skip connection, then apply ``block``.

    Args:
        in_channels: Channels of the lower-resolution decoder input ``x_up``.
        out_channels: Channels of the skip connection ``x_down`` and of the output.
        block: Factory called as ``block(in_channels + out_channels, out_channels)``.
    """

    def __init__(self, in_channels, out_channels, block):
        super().__init__()
        self.upsample = make_upspampling(scale_factor=2)
        self.layer = block(in_channels + out_channels, out_channels)

    def forward(self, x_down, x_up):
        x_up = self.upsample(x_up)
        if x_up.shape[-2:] != x_down.shape[-2:]:
            # With stride-2, padding-1 convolutions an odd size is rounded up on the
            # way down, so doubling can overshoot the skip connection by a pixel.
            # Resample to the skip connection's exact size so concatenation works.
            x_up = torch.nn.functional.interpolate(
                x_up, size=x_down.shape[-2:], mode="bilinear", align_corners=False
            )
        x = torch.cat((x_down, x_up), dim=1)  # Concatenate along the channel dim.
        return self.layer(x)


class UNet(torch.nn.Module):
    def __init__(self, num_classes=1, num_scales=4, base_filters=64, block=ResConvBlock):
        """Create UNet.

        Args:
            num_classes: The number of output logits.
            num_scales: The number of downsampling and upsampling layers.
            base_filters: The number of filters of the first convolution. All other channels are relative to this value.
            block: Factory ``block(in_channels, out_channels[, stride=...])`` used by every stage.
        """
        super().__init__()
        self.input_convolutions = block(3, base_filters)

        # Width after each encoder stage: the first keeps base_filters, each next one doubles it.
        widths = [base_filters * 2 ** level for level in range(max(num_scales, 1))]
        encoder_inputs = [base_filters] + widths[:-1]
        self.down_layers = torch.nn.Sequential(*[
            UNetDown(c_in, c_out, block) for c_in, c_out in zip(encoder_inputs, widths)
        ])

        # The decoder walks the widths back down and ends at base_filters.
        decoder_outputs = widths[-2::-1] + [base_filters]
        self.up_layers = torch.nn.Sequential(*[
            UNetUp(c_in, c_out, block) for c_in, c_out in zip(widths[::-1], decoder_outputs)
        ])

        self.output_convolution = torch.nn.Conv2d(base_filters, num_classes, 1)

    def forward(self, x):
        # Keep every encoder activation; the decoder consumes them deepest-first.
        skips = [self.input_convolutions(x)]
        for down in self.down_layers:
            skips.append(down(skips[-1]))
        x = skips.pop()
        for up in self.up_layers:
            x = up(skips.pop(), x)
        return self.output_convolution(x)

unet = UNet()

# Push one training batch through the untrained network to check the output shape.
sample_batch = next(iter(data.train_dataloader()))
print("Output shape:", unet(sample_batch[0]).shape)

print(unet)

# Training

### Assignment 2. Implement BCE loss for segmentation.
Cheat sheet:

```python
torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)
```

In [None]:
class BCELoss(torch.nn.Module):
    """Binary cross-entropy on logits for single-channel segmentation, averaged over pixels."""

    def forward(self, predicted, masks):
        """Compute the loss.

        Implemented as ``forward`` (not ``__call__``) so that ``nn.Module.__call__``
        still runs registered hooks.

        Args:
            predicted: Logits of shape (B, 1, H, W).
            masks: Foreground masks of shape (B, H, W); converted to float targets.

        Raises:
            ValueError: if ``predicted`` has more than one channel.
        """
        if predicted.shape[1] != 1:
            raise ValueError("Need binary predictions")
        predicted = predicted.squeeze(1)  # (B, H, W) logits.
        return torch.nn.functional.binary_cross_entropy_with_logits(predicted, masks.float())
        
# Compare the BCE implementation with a hand-computed reference.
check_bce_loss(loss_class=BCELoss)
print("OK")

In [None]:
# Launch TensorBoard inside the notebook to follow the runs logged under ./lightning_logs/.
%load_ext tensorboard
%tensorboard --logdir ./lightning_logs/

In [None]:
clean_memory()

# Train UNet from scratch with plain BCE; logs go to lightning_logs/BCE.
model = Module(model=UNet(), loss=BCELoss())
trainer = pl.Trainer(
    max_epochs=10,
    logger=pl.loggers.TensorBoardLogger(save_dir="lightning_logs", name="BCE"),
)
trainer.fit(model, datamodule=data)
show_segmentations(model, data.valset)

### Assignment 3. Implement Focal Loss

For positive class (mask = 1): $loss = -(1 - p_+)^\gamma \log p_+$

For negative class (mask = 0): $loss = -(1 - p_-)^\gamma \log p_-$

where $p_+$ is the predicted foreground probability and $p_- = 1 - p_+$.

Note that the model predicts logits rather than probabilities. Use the *sigmoid* activation and average the loss over all pixels.

In [None]:
class FocalLoss:
    """Binary focal loss on logits, averaged over all pixels.

    For a pixel with foreground probability p, positive pixels contribute
    -(1 - p)^gamma * log(p) and negative pixels -p^gamma * log(1 - p).
    With ``gamma=0`` this reduces to binary cross-entropy.

    Args:
        gamma: Focusing exponent; larger values down-weight confidently classified pixels.
    """

    def __init__(self, gamma=2):
        self.gamma = gamma

    def __call__(self, predicted, masks):
        """Compute the loss.

        Args:
            predicted: Logits of shape (B, 1, H, W).
            masks: Foreground masks of shape (B, H, W); non-zero means foreground.
                Non-bool masks are converted to bool.

        Raises:
            ValueError: if ``predicted`` has more than one channel.
        """
        if predicted.shape[1] != 1:
            raise ValueError("{} can't be applied to a multi-class problem".format(type(self)))
        logits = predicted.squeeze(1)  # (B, H, W).
        masks = masks.bool()  # torch.where needs a boolean condition.

        p_pos = torch.sigmoid(logits)
        p_neg = 1 - p_pos
        # logsigmoid stays finite for large |logits|, unlike log(sigmoid(x)).
        log_p_pos = torch.nn.functional.logsigmoid(logits)
        log_p_neg = torch.nn.functional.logsigmoid(-logits)
        pos_term = (p_neg ** self.gamma) * log_p_pos
        neg_term = (p_pos ** self.gamma) * log_p_neg
        return -torch.where(masks, pos_term, neg_term).mean()

# Compare the focal loss implementation with a hand-computed reference.
check_focal_loss(loss_class=FocalLoss)
print("OK")

In [None]:
# Launch TensorBoard inside the notebook to follow the runs logged under ./lightning_logs/.
%load_ext tensorboard
%tensorboard --logdir ./lightning_logs/

In [None]:
clean_memory()

# Train UNet from scratch with the focal loss; logs go to lightning_logs/FocalLoss.
model = Module(model=UNet(), loss=FocalLoss())
trainer = pl.Trainer(
    max_epochs=10,
    logger=pl.loggers.TensorBoardLogger(save_dir="lightning_logs", name="FocalLoss"),
)
trainer.fit(model, datamodule=data)
show_segmentations(model, data.valset)

### Assignment 4. Implement Dice Loss

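$
\mathrm{DiceLoss} = -\log\frac{2\sum\limits_{i,j} x_{i,j} y_{i,j}}{\sum\limits_{i,j} \left(x_{i,j} + y_{i,j}\right)}
$

where $x_{i,j}$ is the predicted foreground probability of pixel $(i, j)$ and $y_{i,j} \in \{0, 1\}$ is its ground-truth label.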

In [None]:
class DiceLoss:
    """Symmetric log-Dice loss plus a focal-loss term, for binary segmentation.

    The Dice term is averaged over the foreground view ``(logits, masks)`` and
    the background view ``(-logits, ~masks)``. A ``FocalLoss`` term is then added.

    Args:
        focal_gamma: ``gamma`` of the added FocalLoss term.
    """

    def __init__(self, focal_gamma=2):
        self.focal = FocalLoss(gamma=focal_gamma)

    def single_class_dice_loss(self, predicted, masks):
        """Return -log(2 * sum(p * y) / (sum(p) + sum(y))), summed over the whole batch.

        Args:
            predicted: Logits of shape (B, 1, H, W); p = sigmoid(predicted).
            masks: Ground-truth masks of shape (B, H, W) (bool or 0/1).

        Raises:
            ValueError: if ``predicted`` has more than one channel.
        """
        if predicted.shape[1] != 1:
            raise ValueError("{} can't be applied to a multi-class problem".format(type(self)))
        predicted = predicted.squeeze(1)  # (B, H, W).
        masks = masks.float()
        p = torch.sigmoid(predicted)
        double_intersection = 2 * (p * masks).sum()
        area_sum = p.sum() + masks.sum()
        # Clip both terms so the logs stay finite when a class is absent from the batch.
        loss = - double_intersection.clip(min=1e-6).log() + area_sum.clip(min=1e-6).log()
        return loss

    def __call__(self, predicted, masks):
        # Convert first: for integer masks ``~`` would be a bitwise NOT (0 -> -1, 1 -> -2).
        masks = masks.bool()
        positive = self.single_class_dice_loss(predicted, masks)
        negative = self.single_class_dice_loss(-predicted, ~masks)
        return 0.5 * (positive + negative) + self.focal(predicted, masks)

# Compare the Dice loss implementation with a hand-computed reference.
check_dice_loss(loss_class=DiceLoss)
print("OK")

In [None]:
# Launch TensorBoard inside the notebook to follow the runs logged under ./lightning_logs/.
%load_ext tensorboard
%tensorboard --logdir ./lightning_logs/

In [None]:
clean_memory()

# Train UNet from scratch with Dice + focal loss; logs go to lightning_logs/Dice.
model = Module(model=UNet(), loss=DiceLoss())
trainer = pl.Trainer(
    max_epochs=10,
    logger=pl.loggers.TensorBoardLogger(save_dir="lightning_logs", name="Dice"),
)
trainer.fit(model, datamodule=data)
show_segmentations(model, data.valset)

# Use pretrained model (Optional)

In [None]:
import segmentation_models_pytorch as smp

### Before tuning

In [None]:
# ResNet-34 UNet with ImageNet encoder weights, shown before any training on VOC.
unet_pretrained = smp.Unet(encoder_name="resnet34", encoder_weights="imagenet", classes=1)
show_segmentations(model=unet_pretrained, dataset=data.valset)

In [None]:
# Launch TensorBoard inside the notebook to follow the runs logged under ./lightning_logs/.
%load_ext tensorboard
%tensorboard --logdir ./lightning_logs/

In [None]:
clean_memory()

# Fine-tune a ResNet-34 UNet (ImageNet encoder weights) with Dice + focal loss.
unet_pretrained = smp.Unet(encoder_name="resnet34", encoder_weights="imagenet", classes=1)
model = Module(model=unet_pretrained, loss=DiceLoss())
trainer = pl.Trainer(
    max_epochs=10,
    logger=pl.loggers.TensorBoardLogger(save_dir="lightning_logs", name="Pretrained"),
)
trainer.fit(model, datamodule=data)
show_segmentations(unet_pretrained, data.valset)

# Summary
* UNet architecture predicts fine-grained masks
* All loss functions have similar quality in terms of Jaccard Index
* One possible reason is low class imbalance in the selected problem
* Pretraining largely affects the final quality

# Homework (optional)

Try to implement multi-class segmentation on the same dataset. Is it possible to achieve good quality without pretraining?