In [1]:
import torch
import torch.nn as nn
import torchvision.models
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn.functional as F

import albumentations as A
from albumentations.pytorch import ToTensorV2

from tqdm import tqdm
from PIL import Image
from PIL.Image import open as open_image
import cv2
import matplotlib.pyplot as plt
import numpy as np

import os
from os.path import join as path_join

from time import time

### Get the data

In [2]:
if not os.path.exists("data.zip"):
    %pip install --upgrade gdown
    import gdown
    url = 'https://drive.google.com/uc?id=10f1H2T-5W-BiqabHHtlZ4ASs19TZmg8R'
    output = 'data.zip'
    gdown.download(url, output, quiet=False)
    del gdown
    del url
    del output
    !unzip data.zip

In [3]:
def set_random_seed(seed: int):
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    if "random" in globals():
        random.seed(seed)

set_random_seed(1)

### Utilities (0.5 point)

Complete dataset to load prepared images and masks. Don't forget to use augmentations.

Some of the images are 1 channels, so use `gray2rgb`.

In [4]:
IMG_HEIGHT = 224
IMG_WIDTH = 224
CHANNELS_COUNT = 3
PIXEL_MAX_VALUE = 255
WHITE_PIXEL_VALUE = PIXEL_MAX_VALUE
MASK_PIXEL_VALUE = WHITE_PIXEL_VALUE

def gray2rgb(img: np.ndarray) -> np.ndarray:
    match img.shape:
        case (_, _):
            return np.dstack([img, img, img])
        case (_, _, 3):
            return img
        case _:
            raise ValueError(f"Invalid img.shape: {img.shape}")

def get_iou(gt, pred):
    pred = pred > 0.5
    un = (gt | pred).sum()
    assert un > 0 or (gt & pred).sum() == 0
    return (gt & pred).sum() / un if un else 0

def imagenet_mean() -> tuple[float, float, float]:
    return (0.485, 0.456, 0.406)

def imagenet_std() -> tuple[float, float, float]:
    return (0.229, 0.224, 0.225)

class BirdsDataset(Dataset):
    def __init__(self, folder, is_train: None | bool = None) -> None:
        assert is_train is None or isinstance(is_train, bool)

        images_folder: str = os.path.join(folder, 'images')
        gt_folder: str = os.path.join(folder, 'gt')

        calc_manually = False
        if calc_manually:
            approx_mean = np.zeros(shape=(CHANNELS_COUNT,))
            approx_std = np.zeros(shape=(CHANNELS_COUNT,))

        img_index_to_class_name: list[str] = []
        img_index_to_file_name: list[str] = []
        for class_name in os.listdir(images_folder):
            class_images_dir: str = path_join(images_folder, class_name)
            class_gt_masks_dir: str = path_join(gt_folder, class_name)
            assert os.path.exists(class_gt_masks_dir), f"no dir with ground truth masks {class_gt_masks_dir} for the class {class_name}"
            for fname in os.listdir(class_images_dir):
                gt_mask_file_path = path_join(class_gt_masks_dir, BirdsDataset._gt_mask_filename_from_img_filename(fname))
                assert os.path.exists(gt_mask_file_path), \
                    f"no ground truth mask {gt_mask_file_path} for the file {fname} with class {class_name}"

                if calc_manually:
                    img_pixels: np.ndarray = gray2rgb(np.asarray(open_image(path_join(class_images_dir, fname)))).astype(np.float32)
                    assert len(img_pixels.shape) == 3 and img_pixels.shape[2] == CHANNELS_COUNT
                    img_pixels /= PIXEL_MAX_VALUE
                    approx_mean += img_pixels.mean(axis=(0, 1))
                    approx_std += img_pixels.std(axis=(0, 1))

                img_index_to_file_name.append(fname)

                # Note: only 1 instance of bytes sequence corresponding to the `class_name`
                #  exists in memory due to the reference semantic of the Python language
                img_index_to_class_name.append(class_name)

        if calc_manually:
            total_images = len(img_index_to_class_name)
            assert total_images == len(img_index_to_file_name) > 0
            approx_mean /= total_images
            approx_std /= total_images

            assert approx_mean.ndim == 1 and len(approx_mean) == CHANNELS_COUNT \
                and np.all((0 < approx_mean) & (approx_mean < PIXEL_MAX_VALUE))
            assert approx_std.ndim == 1 and len(approx_std) == CHANNELS_COUNT \
                and np.all((0 < approx_std) & (approx_std < PIXEL_MAX_VALUE))

        if is_train is None:
            is_train = "train" in folder
        self.transform = BirdsDataset.make_data_transformer(is_train=is_train)
        self._img_index_to_class_name = img_index_to_class_name
        self._img_index_to_file_name = img_index_to_file_name
        self._images_folder = images_folder
        self._gt_folder = gt_folder

    @staticmethod
    def preprocess_image(img: np.ndarray) -> np.ndarray:
        # assert np.issubdtype(img.dtype, np.integer)
        # assert np.all(img <= MASK_PIXEL_VALUE)
        # assert np.all(0 <= img)
        return gray2rgb(img)

    @staticmethod
    def preprocess_mask(mask: np.ndarray) -> np.ndarray:
        if mask.ndim != 2:
            mask = mask[:,:,0]
        # assert np.issubdtype(mask.dtype, np.integer)
        mask = (mask == MASK_PIXEL_VALUE).astype(np.uint8)
        # assert np.all((mask == 0) ^ (mask == 1))
        return mask

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        img, gt_mask = self._read_img_and_gt_mask_by_index(index)
        match self.transform(image=BirdsDataset.preprocess_image(img), mask=BirdsDataset.preprocess_mask(gt_mask)):
            case {"image": t_img, "mask": t_msk}:
                return t_img, t_msk.long()
        assert False

    def __len__(self) -> int:
        return len(self._img_index_to_file_name)

    @staticmethod
    def make_data_transformer(is_train: bool = False, mean=imagenet_mean(), std=imagenet_std()):
        assert len(mean) == len(std) == CHANNELS_COUNT
        actions: list[A.BasicTransform] = [
            A.Normalize(mean=mean, std=std, max_pixel_value=PIXEL_MAX_VALUE),
        ]
        if is_train:
            actions.extend([
                A.RandomResizedCrop(size=(IMG_HEIGHT, IMG_WIDTH)),
                A.HorizontalFlip(p=0.3),
                A.RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.3),
            ])
        else:
            actions.extend([
                A.Resize(height=IMG_HEIGHT, width=IMG_WIDTH),
            ])
        actions.extend([
            ToTensorV2(),
        ])
        return A.Compose(actions)

    def _read_img_and_gt_mask_by_index(self, i: int) -> tuple[np.ndarray, np.ndarray]:
        cnm = self._img_index_to_class_name[i]
        fnm = self._img_index_to_file_name[i]
        return (
            np.asarray(open_image(path_join(self._images_folder, cnm, fnm))),
            np.asarray(open_image(path_join(self._gt_folder, cnm, self._gt_mask_filename_from_img_filename(fnm)))),
        )

    @staticmethod
    def _gt_mask_filename_from_img_filename(img_fname: str) -> str:
        return f"{img_fname.removesuffix('jpg')}png"


### Architecture (1 point)
Your task for today is to build your own Unet to solve the segmentation problem.

As an encoder, you can use pre-trained on IMAGENET models(or parts) from torchvision. The decoder must be trained from scratch.
It is forbidden to use data not from the `data` folder.

I advise you to experiment with the number of blocks so as not to overfit on the training sample and get good quality on validation.

In [5]:
from torchvision import models

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        # self.upconv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1)
        self.upsample = nn.UpsamplingNearest2d(scale_factor=2)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=self.conv1.out_channels)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=self.conv2.out_channels)
        self.relu2 = nn.ReLU()

    def forward_with_resid(self, resid: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        return self.forward(torch.cat((resid, self.upsample(x)), dim=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu2(self.bn2(self.conv2(self.relu1(self.bn1(self.conv1(x))))))

class Unet(nn.Module):
    def __init__(self):
        super().__init__()

        # pretrained_model_to_steal_layers_from
        enc_model: torchvision.models.resnet.ResNet = models.resnet152(weights=models.ResNet152_Weights.DEFAULT)

        B = 2
        assert enc_model.conv1(torch.zeros((B, CHANNELS_COUNT, IMG_HEIGHT, IMG_WIDTH))).shape == torch.Size((B, 64, IMG_HEIGHT // 2, IMG_WIDTH // 2))
        assert enc_model.bn1(torch.zeros((B, 64, IMG_HEIGHT // 2, IMG_WIDTH // 2))).shape == torch.Size((B, 64, IMG_HEIGHT // 2, IMG_WIDTH // 2))
        assert enc_model.relu(torch.zeros((B, 64, IMG_HEIGHT // 2, IMG_WIDTH // 2))).shape == torch.Size((B, 64, IMG_HEIGHT // 2, IMG_WIDTH // 2))

        assert enc_model.maxpool(torch.zeros((B, 64, IMG_HEIGHT // 2, IMG_WIDTH // 2))).shape == torch.Size((B, 64, IMG_HEIGHT // 4, IMG_WIDTH // 4))
        assert enc_model.layer1(torch.zeros((B, 64, IMG_HEIGHT // 4, IMG_WIDTH // 4))).shape == torch.Size((B, 256, IMG_HEIGHT // 4, IMG_WIDTH // 4))

        assert enc_model.layer2(torch.zeros((B, 256, IMG_HEIGHT // 4, IMG_WIDTH // 4))).shape == torch.Size((B, 512, IMG_HEIGHT // 8, IMG_WIDTH // 8))

        assert enc_model.layer3(torch.zeros((B, 512, IMG_HEIGHT // 8, IMG_WIDTH // 8))).shape == torch.Size((B, 1024, IMG_HEIGHT // 16, IMG_WIDTH // 16))

        self.inp = nn.Sequential(enc_model.conv1, enc_model.bn1, enc_model.relu)

        # encoder blocks
        self.encoder1 = nn.Sequential(enc_model.maxpool, enc_model.layer1)
        self.encoder2= enc_model.layer2
        self.encoder3= enc_model.layer3

        # decoder blocks
        self.decoder1 = DecoderBlock(in_channels=512 + 1024, mid_channels=1024, out_channels=512)
        self.decoder2 = DecoderBlock(in_channels=256 + 512, mid_channels=512, out_channels=256)
        self.decoder3 = DecoderBlock(in_channels=64 + 256, mid_channels=128, out_channels=64)

        self.out = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1),
        )

    def forward(self, x: torch.Tensor):
        r1 = self.inp(x)
        r2 = self.encoder1(r1)
        r3 = self.encoder2(r2)
        x = self.encoder3(r3)
        x = self.decoder1.forward_with_resid(r3, x)
        x = self.decoder2.forward_with_resid(r2, x)
        x = self.decoder3.forward_with_resid(r1, x)
        return self.out(x)

    def mutable_parameters(self, recurse=True):
        return self.parameters(recurse)


In [6]:
# assert Unet()(torch.zeros(1, CHANNELS_COUNT, IMG_HEIGHT, IMG_WIDTH)).shape == torch.Size((1, 2, IMG_HEIGHT, IMG_WIDTH))

### Train script (0.5 point)

Complete the train and predict scripts.

In [7]:
def get_model(path) -> Unet:
    model = Unet()
    model.load_state_dict(torch.load(path, weights_only=False))
    model.eval()
    return model

In [8]:
def train_segmentation_model(data_path, load_from_state: int = -1):
    BATCH_SIZE = 8
    N_EPOCH = 15
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_dataset = BirdsDataset(data_path + 'train')
    val_dataset = BirdsDataset(data_path + 'val')
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    assert N_EPOCH > load_from_state >= -1
    if load_from_state == -1:
        model = Unet()
    else:
        model = get_model(f"model_{load_from_state}.pth")
    model = model.to(DEVICE)
    optimizer = torch.optim.Adam(params=model.mutable_parameters(), lr=0.0007)
    criterion = nn.CrossEntropyLoss()
    losses_train, losses_val, ious_train, ious_val = [], [], [], []

    for epoch in (epochs_bar := tqdm(range(load_from_state + 1, N_EPOCH))):
        epochs_bar.set_description(f"Epoch #{epoch}...\n", refresh=True)

        model.train()
        losses_over_batches = []
        ious_over_batches = []
        for inputs, masks in tqdm(train_dataloader):
            masks = masks.to(DEVICE, non_blocking=True)
            inputs = inputs.to(DEVICE)

            optimizer.zero_grad()
            masks_pred = model(inputs)
            loss = criterion(masks_pred, masks)

            loss.backward()
            optimizer.step()

            losses_over_batches.append(loss.detach().cpu().item())
            ious_over_batches.append(get_iou(masks == 1, masks_pred.detach().argmax(1) == 1).cpu().item())


        losses_train.append(np.mean(losses_over_batches))
        ious_train.append(np.mean(ious_over_batches))

        model.eval()
        losses_over_batches = []
        ious_over_batches = []
        with torch.no_grad():
            for inputs, masks in tqdm(val_dataloader):
                masks = masks.to(DEVICE, non_blocking=True)
                inputs = inputs.to(DEVICE)

                masks_pred = model(inputs).detach()
                loss = criterion(masks_pred, masks)

                losses_over_batches.append(loss.cpu().item())
                ious_over_batches.append(get_iou(masks == 1, masks_pred.argmax(1) == 1).cpu().item())

        losses_val.append(np.mean(losses_over_batches))
        ious_val.append(np.mean(ious_over_batches))

        torch.save(model.state_dict(), f'model_{epoch}.pth')

        print(f"Epoch: {epoch}, train loss: {losses_train[-1]}, val loss: {losses_val[-1]}, train iou: {ious_train[-1]}, val iou: {ious_val[-1]}")

Note: in the code below training is started from the 8 epoch because google colab runtime has been disconnected several times

In [None]:
set_random_seed(0xdeadbeef)
train_segmentation_model('data/')

Epoch #0...
:   0%|          | 0/15 [00:00<?, ?it/s]
  0%|          | 0/1048 [00:00<?, ?it/s][A
  0%|          | 1/1048 [00:01<23:41,  1.36s/it][A
  0%|          | 2/1048 [00:01<14:55,  1.17it/s][A
  0%|          | 3/1048 [00:02<12:03,  1.45it/s][A
  0%|          | 4/1048 [00:02<10:40,  1.63it/s][A
  0%|          | 5/1048 [00:03<09:53,  1.76it/s][A
  1%|          | 6/1048 [00:03<09:27,  1.84it/s][A
  1%|          | 7/1048 [00:04<09:09,  1.89it/s][A
  1%|          | 8/1048 [00:04<09:00,  1.92it/s][A
  1%|          | 9/1048 [00:05<08:53,  1.95it/s][A
  1%|          | 10/1048 [00:05<08:43,  1.98it/s][A
  1%|          | 11/1048 [00:06<08:38,  2.00it/s][A
  1%|          | 12/1048 [00:06<08:35,  2.01it/s][A
  1%|          | 13/1048 [00:07<08:35,  2.01it/s][A
  1%|▏         | 14/1048 [00:07<09:26,  1.82it/s][A
  1%|▏         | 15/1048 [00:08<09:12,  1.87it/s][A
  2%|▏         | 16/1048 [00:08<08:59,  1.91it/s][A
  2%|▏         | 17/1048 [00:09<08:55,  1.93it/s][A
  2%|▏     

Epoch: 0, train loss: 0.2222982759850512, val loss: 0.10129496696489779, train iou: 0.6162584017151753, val iou: 0.7234745103527199



  0%|          | 0/1048 [00:00<?, ?it/s][A
  0%|          | 1/1048 [00:00<09:15,  1.89it/s][A
  0%|          | 2/1048 [00:01<09:11,  1.90it/s][A
  0%|          | 3/1048 [00:01<09:07,  1.91it/s][A
  0%|          | 4/1048 [00:02<09:08,  1.90it/s][A
  0%|          | 5/1048 [00:02<09:12,  1.89it/s][A
  1%|          | 6/1048 [00:03<09:23,  1.85it/s][A
  1%|          | 7/1048 [00:03<09:26,  1.84it/s][A
  1%|          | 8/1048 [00:04<09:30,  1.82it/s][A
  1%|          | 9/1048 [00:04<09:27,  1.83it/s][A
  1%|          | 10/1048 [00:05<09:21,  1.85it/s][A
  1%|          | 11/1048 [00:05<09:18,  1.86it/s][A
  1%|          | 12/1048 [00:06<09:14,  1.87it/s][A
  1%|          | 13/1048 [00:06<09:11,  1.88it/s][A
  1%|▏         | 14/1048 [00:07<09:10,  1.88it/s][A
  1%|▏         | 15/1048 [00:08<09:09,  1.88it/s][A
  2%|▏         | 16/1048 [00:08<09:07,  1.88it/s][A
  2%|▏         | 17/1048 [00:09<09:08,  1.88it/s][A
  2%|▏         | 18/1048 [00:09<09:07,  1.88it/s][A
  2%|▏    

Epoch: 1, train loss: 0.18564490762802252, val loss: 0.13446356267244977, train iou: 0.6775853195945725, val iou: 0.6577709831974723



  0%|          | 0/1048 [00:00<?, ?it/s][A
  0%|          | 1/1048 [00:00<10:48,  1.62it/s][A
  0%|          | 2/1048 [00:01<09:59,  1.74it/s][A
  0%|          | 3/1048 [00:01<09:48,  1.78it/s][A
  0%|          | 4/1048 [00:02<09:36,  1.81it/s][A
  0%|          | 5/1048 [00:02<09:35,  1.81it/s][A
  1%|          | 6/1048 [00:03<09:35,  1.81it/s][A
  1%|          | 7/1048 [00:03<09:27,  1.83it/s][A
  1%|          | 8/1048 [00:04<09:24,  1.84it/s][A
  1%|          | 9/1048 [00:04<09:21,  1.85it/s][A
  1%|          | 10/1048 [00:05<09:16,  1.86it/s][A
  1%|          | 11/1048 [00:06<09:15,  1.87it/s][A
  1%|          | 12/1048 [00:06<09:14,  1.87it/s][A
  1%|          | 13/1048 [00:07<09:11,  1.88it/s][A
  1%|▏         | 14/1048 [00:07<09:09,  1.88it/s][A
  1%|▏         | 15/1048 [00:08<09:09,  1.88it/s][A
  2%|▏         | 16/1048 [00:08<09:07,  1.89it/s][A
  2%|▏         | 17/1048 [00:09<09:06,  1.89it/s][A
  2%|▏         | 18/1048 [00:09<09:06,  1.89it/s][A
  2%|▏    

Epoch: 2, train loss: 0.17197695181114983, val loss: 0.08052747425707904, train iou: 0.7002544887363911, val iou: 0.759796343235807



  0%|          | 0/1048 [00:00<?, ?it/s][A
  0%|          | 1/1048 [00:00<10:18,  1.69it/s][A
  0%|          | 2/1048 [00:01<09:45,  1.79it/s][A
  0%|          | 3/1048 [00:01<09:32,  1.83it/s][A
  0%|          | 4/1048 [00:02<09:24,  1.85it/s][A
  0%|          | 5/1048 [00:02<09:21,  1.86it/s][A
  1%|          | 6/1048 [00:03<09:17,  1.87it/s][A
  1%|          | 7/1048 [00:03<09:26,  1.84it/s][A
  1%|          | 8/1048 [00:04<09:30,  1.82it/s][A
  1%|          | 9/1048 [00:04<09:32,  1.81it/s][A
  1%|          | 10/1048 [00:05<09:35,  1.80it/s][A
  1%|          | 11/1048 [00:06<09:39,  1.79it/s][A
  1%|          | 12/1048 [00:06<09:29,  1.82it/s][A
  1%|          | 13/1048 [00:07<09:24,  1.83it/s][A
  1%|▏         | 14/1048 [00:07<09:20,  1.85it/s][A
  1%|▏         | 15/1048 [00:08<09:14,  1.86it/s][A
  2%|▏         | 16/1048 [00:08<09:10,  1.87it/s][A
  2%|▏         | 17/1048 [00:09<09:07,  1.88it/s][A
  2%|▏         | 18/1048 [00:09<09:05,  1.89it/s][A
  2%|▏    

Epoch: 3, train loss: 0.16671928757346655, val loss: 0.09021219905381175, train iou: 0.7087113606452032, val iou: 0.7318745429881595



  0%|          | 0/1048 [00:00<?, ?it/s][A
  0%|          | 1/1048 [00:00<10:29,  1.66it/s][A
  0%|          | 2/1048 [00:01<09:56,  1.75it/s][A
  0%|          | 3/1048 [00:01<09:48,  1.78it/s][A
  0%|          | 4/1048 [00:02<09:46,  1.78it/s][A
  0%|          | 5/1048 [00:02<09:46,  1.78it/s][A
  1%|          | 6/1048 [00:03<09:40,  1.79it/s][A
  1%|          | 7/1048 [00:03<09:33,  1.82it/s][A
  1%|          | 8/1048 [00:04<09:28,  1.83it/s][A
  1%|          | 9/1048 [00:04<09:21,  1.85it/s][A
  1%|          | 10/1048 [00:05<09:16,  1.87it/s][A
  1%|          | 11/1048 [00:06<09:14,  1.87it/s][A
  1%|          | 12/1048 [00:06<09:12,  1.88it/s][A
  1%|          | 13/1048 [00:07<09:09,  1.88it/s][A
  1%|▏         | 14/1048 [00:07<09:08,  1.89it/s][A
  1%|▏         | 15/1048 [00:08<09:07,  1.89it/s][A
  2%|▏         | 16/1048 [00:08<09:06,  1.89it/s][A
  2%|▏         | 17/1048 [00:09<09:06,  1.88it/s][A
  2%|▏         | 18/1048 [00:09<09:07,  1.88it/s][A
  2%|▏    

Epoch: 4, train loss: 0.15998926445600073, val loss: 0.09822836842133918, train iou: 0.7177445998571756, val iou: 0.7177956812083721



  0%|          | 0/1048 [00:00<?, ?it/s][A
  0%|          | 1/1048 [00:00<10:37,  1.64it/s][A
  0%|          | 2/1048 [00:01<09:46,  1.78it/s][A
  0%|          | 3/1048 [00:01<09:34,  1.82it/s][A
  0%|          | 4/1048 [00:02<09:27,  1.84it/s][A
  0%|          | 5/1048 [00:02<09:22,  1.85it/s][A
  1%|          | 6/1048 [00:03<09:17,  1.87it/s][A
  1%|          | 7/1048 [00:03<09:16,  1.87it/s][A
  1%|          | 8/1048 [00:04<09:15,  1.87it/s][A
  1%|          | 9/1048 [00:04<09:13,  1.88it/s][A
  1%|          | 10/1048 [00:05<09:11,  1.88it/s][A
  1%|          | 11/1048 [00:05<09:12,  1.88it/s][A
  1%|          | 12/1048 [00:06<09:13,  1.87it/s][A
  1%|          | 13/1048 [00:06<09:10,  1.88it/s][A
  1%|▏         | 14/1048 [00:07<09:11,  1.88it/s][A
  1%|▏         | 15/1048 [00:08<09:13,  1.87it/s][A
  2%|▏         | 16/1048 [00:08<09:22,  1.84it/s][A
  2%|▏         | 17/1048 [00:09<09:27,  1.82it/s][A
  2%|▏         | 18/1048 [00:09<09:27,  1.82it/s][A
  2%|▏    

Epoch: 5, train loss: 0.14748328186941284, val loss: 0.07819546531589533, train iou: 0.736320550603266, val iou: 0.760208597406745



  0%|          | 0/1048 [00:00<?, ?it/s][A
  0%|          | 1/1048 [00:00<10:17,  1.70it/s][A
  0%|          | 2/1048 [00:01<09:39,  1.80it/s][A
  0%|          | 3/1048 [00:01<09:29,  1.83it/s][A
  0%|          | 4/1048 [00:02<09:21,  1.86it/s][A
  0%|          | 5/1048 [00:02<09:16,  1.87it/s][A
  1%|          | 6/1048 [00:03<09:13,  1.88it/s][A
  1%|          | 7/1048 [00:03<09:14,  1.88it/s][A
  1%|          | 8/1048 [00:04<09:21,  1.85it/s][A
  1%|          | 9/1048 [00:04<09:23,  1.84it/s][A
  1%|          | 10/1048 [00:05<09:22,  1.85it/s][A
  1%|          | 11/1048 [00:05<09:25,  1.83it/s][A
  1%|          | 12/1048 [00:06<09:31,  1.81it/s][A
  1%|          | 13/1048 [00:07<09:27,  1.82it/s][A
  1%|▏         | 14/1048 [00:07<09:22,  1.84it/s][A
  1%|▏         | 15/1048 [00:08<09:20,  1.84it/s][A
  2%|▏         | 16/1048 [00:08<09:20,  1.84it/s][A
  2%|▏         | 17/1048 [00:09<09:18,  1.85it/s][A
  2%|▏         | 18/1048 [00:09<09:16,  1.85it/s][A
  2%|▏    

Epoch: 6, train loss: 0.15059305427468006, val loss: 0.07645940969020805, train iou: 0.736668735571945, val iou: 0.7699663803320039



  0%|          | 0/1048 [00:00<?, ?it/s][A
  0%|          | 1/1048 [00:00<10:27,  1.67it/s][A
  0%|          | 2/1048 [00:01<09:43,  1.79it/s][A
  0%|          | 3/1048 [00:01<09:38,  1.80it/s][A
  0%|          | 4/1048 [00:02<09:37,  1.81it/s][A
  0%|          | 5/1048 [00:02<09:45,  1.78it/s][A
  1%|          | 6/1048 [00:03<09:47,  1.77it/s][A
  1%|          | 7/1048 [00:03<09:40,  1.79it/s][A
  1%|          | 8/1048 [00:04<09:33,  1.81it/s][A
  1%|          | 9/1048 [00:04<09:28,  1.83it/s][A
  1%|          | 10/1048 [00:05<09:22,  1.85it/s][A
  1%|          | 11/1048 [00:06<09:20,  1.85it/s][A
  1%|          | 12/1048 [00:06<09:18,  1.86it/s][A
  1%|          | 13/1048 [00:07<09:17,  1.86it/s][A
  1%|▏         | 14/1048 [00:07<09:17,  1.85it/s][A
  1%|▏         | 15/1048 [00:08<09:16,  1.86it/s][A
  2%|▏         | 16/1048 [00:08<09:15,  1.86it/s][A
  2%|▏         | 17/1048 [00:09<09:17,  1.85it/s][A
  2%|▏         | 18/1048 [00:09<09:15,  1.85it/s][A
  2%|▏    

Epoch: 7, train loss: 0.14460160194732646, val loss: 0.06978878425434232, train iou: 0.7421779778878425, val iou: 0.7813330524685708



  0%|          | 0/1048 [00:00<?, ?it/s][A
  0%|          | 1/1048 [00:00<11:00,  1.59it/s][A
  0%|          | 2/1048 [00:01<10:01,  1.74it/s][A
  0%|          | 3/1048 [00:01<09:40,  1.80it/s][A
  0%|          | 4/1048 [00:02<09:32,  1.82it/s][A
  0%|          | 5/1048 [00:02<09:27,  1.84it/s][A
  1%|          | 6/1048 [00:03<09:28,  1.83it/s][A
  1%|          | 7/1048 [00:03<09:25,  1.84it/s][A
  1%|          | 8/1048 [00:04<09:26,  1.84it/s][A
  1%|          | 9/1048 [00:04<09:24,  1.84it/s][A
  1%|          | 10/1048 [00:05<09:23,  1.84it/s][A
  1%|          | 11/1048 [00:06<09:19,  1.85it/s][A
  1%|          | 12/1048 [00:06<09:20,  1.85it/s][A
  1%|          | 13/1048 [00:07<09:19,  1.85it/s][A
  1%|▏         | 14/1048 [00:07<09:18,  1.85it/s][A
  1%|▏         | 15/1048 [00:08<09:16,  1.86it/s][A
  2%|▏         | 16/1048 [00:08<09:14,  1.86it/s][A
  2%|▏         | 17/1048 [00:09<09:21,  1.84it/s][A
  2%|▏         | 18/1048 [00:09<09:21,  1.83it/s][A
  2%|▏    

Epoch: 8, train loss: 0.14548932631919517, val loss: 0.07106479339894246, train iou: 0.7401619164543297, val iou: 0.7779513758353211



  0%|          | 0/1048 [00:00<?, ?it/s][A
  0%|          | 1/1048 [00:00<10:50,  1.61it/s][A
  0%|          | 2/1048 [00:01<09:54,  1.76it/s][A
  0%|          | 3/1048 [00:01<09:41,  1.80it/s][A
  0%|          | 4/1048 [00:02<09:33,  1.82it/s][A
  0%|          | 5/1048 [00:02<09:31,  1.83it/s][A
  1%|          | 6/1048 [00:03<09:27,  1.84it/s][A
  1%|          | 7/1048 [00:03<09:26,  1.84it/s][A
  1%|          | 8/1048 [00:04<09:25,  1.84it/s][A
  1%|          | 9/1048 [00:04<09:24,  1.84it/s][A
  1%|          | 10/1048 [00:05<09:34,  1.81it/s][A
  1%|          | 11/1048 [00:06<09:35,  1.80it/s][A
  1%|          | 12/1048 [00:06<09:42,  1.78it/s][A
  1%|          | 13/1048 [00:07<09:38,  1.79it/s][A
  1%|▏         | 14/1048 [00:07<09:31,  1.81it/s][A
  1%|▏         | 15/1048 [00:08<09:26,  1.82it/s][A
  2%|▏         | 16/1048 [00:08<09:25,  1.82it/s][A
  2%|▏         | 17/1048 [00:09<09:28,  1.81it/s][A
  2%|▏         | 18/1048 [00:09<09:32,  1.80it/s][A
  2%|▏    

Epoch: 9, train loss: 0.14111943864654608, val loss: 0.07155806298198347, train iou: 0.7517791201595132, val iou: 0.7706748165867545



  0%|          | 0/1048 [00:00<?, ?it/s][A
  0%|          | 1/1048 [00:00<09:39,  1.81it/s][A
  0%|          | 2/1048 [00:01<09:35,  1.82it/s][A
  0%|          | 3/1048 [00:01<09:30,  1.83it/s][A
  0%|          | 4/1048 [00:02<09:26,  1.84it/s][A
  0%|          | 5/1048 [00:02<09:29,  1.83it/s][A
  1%|          | 6/1048 [00:03<09:25,  1.84it/s][A
  1%|          | 7/1048 [00:03<09:28,  1.83it/s][A
  1%|          | 8/1048 [00:04<09:29,  1.83it/s][A
  1%|          | 9/1048 [00:04<09:28,  1.83it/s][A
  1%|          | 10/1048 [00:05<09:27,  1.83it/s][A
  1%|          | 11/1048 [00:06<09:24,  1.84it/s][A
  1%|          | 12/1048 [00:06<09:29,  1.82it/s][A
  1%|          | 13/1048 [00:07<09:32,  1.81it/s][A
  1%|▏         | 14/1048 [00:07<09:30,  1.81it/s][A
  1%|▏         | 15/1048 [00:08<09:34,  1.80it/s][A
  2%|▏         | 16/1048 [00:08<09:44,  1.77it/s][A
  2%|▏         | 17/1048 [00:09<09:37,  1.79it/s][A
  2%|▏         | 18/1048 [00:09<09:32,  1.80it/s][A
  2%|▏    

Epoch: 10, train loss: 0.13782482541287105, val loss: 0.08240202611142938, train iou: 0.7545985369196603, val iou: 0.7568928849968043



  0%|          | 0/1048 [00:00<?, ?it/s][A
  0%|          | 1/1048 [00:00<10:47,  1.62it/s][A
  0%|          | 2/1048 [00:01<10:03,  1.73it/s][A
  0%|          | 3/1048 [00:01<09:56,  1.75it/s][A
  0%|          | 4/1048 [00:02<09:48,  1.77it/s][A
  0%|          | 5/1048 [00:02<09:45,  1.78it/s][A
  1%|          | 6/1048 [00:03<09:39,  1.80it/s][A
  1%|          | 7/1048 [00:03<09:32,  1.82it/s][A
  1%|          | 8/1048 [00:04<09:27,  1.83it/s][A
  1%|          | 9/1048 [00:05<09:25,  1.84it/s][A
  1%|          | 10/1048 [00:05<09:25,  1.83it/s][A
  1%|          | 11/1048 [00:06<09:22,  1.84it/s][A
  1%|          | 12/1048 [00:06<09:24,  1.84it/s][A
  1%|          | 13/1048 [00:07<09:24,  1.83it/s][A
  1%|▏         | 14/1048 [00:07<09:22,  1.84it/s][A
  1%|▏         | 15/1048 [00:08<09:21,  1.84it/s][A
  2%|▏         | 16/1048 [00:08<09:18,  1.85it/s][A
  2%|▏         | 17/1048 [00:09<09:19,  1.84it/s][A
  2%|▏         | 18/1048 [00:09<09:17,  1.85it/s][A
  2%|▏    

Epoch: 11, train loss: 0.13361313574030312, val loss: 0.06567122115203264, train iou: 0.7613861706420666, val iou: 0.7889115387065844



  0%|          | 0/1048 [00:00<?, ?it/s][A
  0%|          | 1/1048 [00:00<10:56,  1.60it/s][A
  0%|          | 2/1048 [00:01<09:56,  1.75it/s][A
  0%|          | 3/1048 [00:01<09:43,  1.79it/s][A
  0%|          | 4/1048 [00:02<09:38,  1.81it/s][A
  0%|          | 5/1048 [00:02<09:33,  1.82it/s][A
  1%|          | 6/1048 [00:03<09:30,  1.83it/s][A
  1%|          | 7/1048 [00:03<09:26,  1.84it/s][A
  1%|          | 8/1048 [00:04<09:24,  1.84it/s][A
  1%|          | 9/1048 [00:04<09:27,  1.83it/s][A
  1%|          | 10/1048 [00:05<09:22,  1.84it/s][A
  1%|          | 11/1048 [00:06<09:20,  1.85it/s][A
  1%|          | 12/1048 [00:06<09:22,  1.84it/s][A
  1%|          | 13/1048 [00:07<09:27,  1.82it/s][A
  1%|▏         | 14/1048 [00:07<09:31,  1.81it/s][A
  1%|▏         | 15/1048 [00:08<09:31,  1.81it/s][A
  2%|▏         | 16/1048 [00:08<09:31,  1.81it/s][A
  2%|▏         | 17/1048 [00:09<09:28,  1.81it/s][A
  2%|▏         | 18/1048 [00:09<09:24,  1.82it/s][A
  2%|▏    

Epoch: 12, train loss: 0.13399486670995714, val loss: 0.06764122854325581, train iou: 0.7629052131571843, val iou: 0.7874732481485064



  0%|          | 0/1048 [00:00<?, ?it/s][A
  0%|          | 1/1048 [00:00<09:30,  1.84it/s][A
  0%|          | 2/1048 [00:01<09:31,  1.83it/s][A
  0%|          | 3/1048 [00:01<09:26,  1.84it/s][A
  0%|          | 4/1048 [00:02<09:23,  1.85it/s][A
  0%|          | 5/1048 [00:02<09:25,  1.84it/s][A
  1%|          | 6/1048 [00:03<09:25,  1.84it/s][A
  1%|          | 7/1048 [00:03<09:26,  1.84it/s][A
  1%|          | 8/1048 [00:04<09:26,  1.84it/s][A
  1%|          | 9/1048 [00:04<09:32,  1.81it/s][A
  1%|          | 10/1048 [00:05<09:28,  1.83it/s][A
  1%|          | 11/1048 [00:05<09:24,  1.84it/s][A
  1%|          | 12/1048 [00:06<09:23,  1.84it/s][A
  1%|          | 13/1048 [00:07<09:23,  1.84it/s][A
  1%|▏         | 14/1048 [00:07<09:21,  1.84it/s][A
  1%|▏         | 15/1048 [00:08<09:27,  1.82it/s][A
  2%|▏         | 16/1048 [00:08<09:35,  1.79it/s][A
  2%|▏         | 17/1048 [00:09<09:46,  1.76it/s][A
  2%|▏         | 18/1048 [00:09<09:47,  1.75it/s][A
  2%|▏    

Epoch: 13, train loss: 0.13581611958993528, val loss: 0.07370332190343602, train iou: 0.7605067540159207, val iou: 0.7787053703584454



  0%|          | 0/1048 [00:00<?, ?it/s][A
  0%|          | 1/1048 [00:00<10:51,  1.61it/s][A
  0%|          | 2/1048 [00:01<09:58,  1.75it/s][A
  0%|          | 3/1048 [00:01<09:40,  1.80it/s][A
  0%|          | 4/1048 [00:02<09:35,  1.82it/s][A
  0%|          | 5/1048 [00:02<09:28,  1.83it/s][A
  1%|          | 6/1048 [00:03<09:28,  1.83it/s][A
  1%|          | 7/1048 [00:03<09:26,  1.84it/s][A
  1%|          | 8/1048 [00:04<09:31,  1.82it/s][A
  1%|          | 9/1048 [00:04<09:36,  1.80it/s][A
  1%|          | 10/1048 [00:05<09:36,  1.80it/s][A
  1%|          | 11/1048 [00:06<09:32,  1.81it/s][A
  1%|          | 12/1048 [00:06<09:28,  1.82it/s][A
  1%|          | 13/1048 [00:07<09:25,  1.83it/s][A
  1%|▏         | 14/1048 [00:07<09:22,  1.84it/s][A
  1%|▏         | 15/1048 [00:08<09:20,  1.84it/s][A
  2%|▏         | 16/1048 [00:08<09:21,  1.84it/s][A
  2%|▏         | 17/1048 [00:09<09:19,  1.84it/s][A
  2%|▏         | 18/1048 [00:09<09:18,  1.84it/s][A
  2%|▏    

Epoch: 14, train loss: 0.13243388002453985, val loss: 0.07387564195828004, train iou: 0.7640871916904705, val iou: 0.7709351087158377





You can also experiment with models and write a small report about results. If the report will be meaningful, you will receive an extra point.

### Testing (8 points)
Your model will be tested on the new data, similar to validation, so use techniques to prevent overfitting the model.

* IoU > 0.85 — 8 points
* IoU > 0.80 — 7 points
* IoU > 0.75 — 6 points
* IoU > 0.70 — 5 points
* IoU > 0.60 — 4 points
* IoU > 0.50 — 3 points
* IoU > 0.40 — 2 points
* IoU > 0.30 — 1 points

In [9]:
def read_and_preprocess_image(img_path: str) -> torch.Tensor:
    img = BirdsDataset.preprocess_image(np.asarray(open_image(img_path)))
    return BirdsDataset.make_data_transformer()(image=img)["image"]

def resize_mask(gt_mask: np.ndarray) -> np.ndarray:
    gt_mask = BirdsDataset.preprocess_mask(gt_mask)
    mask = BirdsDataset.make_data_transformer()(image=gt_mask, mask=gt_mask)["mask"].numpy()
    assert tuple(mask.shape) == (IMG_HEIGHT, IMG_WIDTH)
    assert np.all((mask == 0) ^ (mask == 1))
    return mask

def predict(model, img_path: str) -> np.ndarray:
    model.eval()
    try:
        model_device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        model_device = "cuda"

    with torch.no_grad():
        x = read_and_preprocess_image(img_path).unsqueeze(0).to(model_device)
        logits_pred = model(x).detach().squeeze(0)
        assert tuple(logits_pred.shape) == (2, IMG_HEIGHT, IMG_WIDTH)
        class_pred = logits_pred.argmax(0)
        assert tuple(class_pred.shape) == (IMG_HEIGHT, IMG_WIDTH)
        return class_pred.cpu().numpy()

In [10]:
model = get_model('model_14.pth').to('cuda')

In [11]:
ious, times = [], []
test_dir = 'data/val/'

for class_name in tqdm(sorted(os.listdir(os.path.join(test_dir, 'images')))):
    for img_name in sorted(os.listdir(os.path.join(test_dir, 'images', class_name))):
        img_path=os.path.join(test_dir, 'images', class_name, img_name)
        t_start = time()
        pred = predict(model, img_path)
        times.append(time() - t_start)

        gt_name = img_name.replace('jpg', 'png')
        gt = np.asarray(Image.open(os.path.join(test_dir, 'gt', class_name, gt_name)), dtype = np.uint8)
        if len(gt.shape) > 2:
            gt = gt[:, :, 0]

        assert tuple(pred.shape) == (IMG_HEIGHT, IMG_WIDTH)
        assert np.all((pred == 0) ^ (pred == 1))
        gt = resize_mask(gt)
        assert tuple(gt.shape) == (IMG_HEIGHT, IMG_WIDTH)
        iou = get_iou(gt == 1, pred > 0.5)
        ious.append(iou)

np.mean(ious), np.mean(times)

100%|██████████| 200/200 [00:55<00:00,  3.60it/s]


(0.7550115921417373, 0.03667253603467222)

### Compression (1 point)

Try to speed up the model in any way without losing more than 1% in iou score.
For example [torch2trt](https://github.com/NVIDIA-AI-IOT/torch2trt)

In [11]:
!sudo apt update
!sudo apt upgrade
!sudo apt install build-essential gcc g++ clang clang-tools llvm cmake make ninja-build tensorrt
!sudo apt upgrade
!dpkg-query -W tensorrt

%pip install --upgrade pip
%pip install wheel
%pip install onnxruntime
%pip cache remove "tensorrt*"
%pip install --upgrade tensorrt tensorrt-lean tensorrt-dispatch

!git clone https://github.com/NVIDIA-AI-IOT/torch2trt
%cd torch2trt
!python setup.py install
!cmake -B build . && cmake --build build --target install && ldconfig
%cd scripts
!bash build_contrib.sh
%cd ../..

[33m0% [Working][0m            Hit:1 http://security.ubuntu.com/ubuntu jammy-security InRelease
[33m0% [Connecting to archive.ubuntu.com] [Connected to cloud.r-project.org (18.239[0m                                                                               Hit:2 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease
[33m0% [Connecting to archive.ubuntu.com (91.189.91.81)] [Connecting to r2u.stat.il[0m                                                                               Hit:3 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Hit:4 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:5 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Hit:6 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Hit:7 http://archive.ubuntu.com/ubuntu jammy InRelease
Hit:8 http://archive.ubuntu.com/ubuntu jammy-updates InRelease
Hit:9 http://archive.ubuntu.com/ubuntu

In [10]:
from torch2trt import torch2trt

def get_fast_model(device="cuda"):
    EXAMPLE_BATCH_SIZE = 1
    model = torch2trt(get_model('model_14.pth').to(device), [torch.randn((EXAMPLE_BATCH_SIZE, CHANNELS_COUNT, IMG_HEIGHT, IMG_WIDTH), device=device)])
    model.eval()
    if device is not None:
        model = model.to(device)
    return model


In [11]:
set_random_seed(0xbad)
fast_model = get_fast_model().to('cuda')

In [16]:
ious, times = [], []
test_dir = 'data/val/'

for class_name in tqdm(sorted(os.listdir(os.path.join(test_dir, 'images')))):
    for img_name in sorted(os.listdir(os.path.join(test_dir, 'images', class_name))):
        t_start = time()
        pred = predict(fast_model, os.path.join(test_dir, 'images', class_name, img_name))
        times.append(time() - t_start)

        gt_name = img_name.replace('jpg', 'png')
        gt = np.asarray(Image.open(os.path.join(test_dir, 'gt', class_name, gt_name)), dtype = np.uint8)
        if len(gt.shape) > 2:
            gt = gt[:, :, 0]

        gt = resize_mask(gt)
        iou = get_iou(gt == 1, pred > 0.5)
        ious.append(iou)

np.mean(ious), np.mean(times)

100%|██████████| 200/200 [00:37<00:00,  5.34it/s]


(0.7550115921417373, 0.023838715397277225)

**Bonus:** For the best iou score on test(without compression) in group you will get 1.5, 1, 0.5 extra points(for 1st, 2nd, 3rd places).