In [1]:
import torch
import torch.nn as nn
import torchvision.models
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn.functional as F

import albumentations as A
from albumentations.pytorch import ToTensorV2

from tqdm import tqdm
from PIL import Image
from PIL.Image import open as open_image
import cv2
import matplotlib.pyplot as plt
import numpy as np

import os
from os.path import join as path_join

from time import time

### Get the data

In [2]:
# Fetch and unpack the dataset only when the extracted `data/` folder is missing.
# Checking the folder (not just the archive) means an interrupted extraction is redone,
# and an already-unpacked dataset is never downloaded again.
if not os.path.exists("data"):
    if not os.path.exists("data.zip"):
        import gdown
        gdown.download('https://drive.google.com/uc?id=10f1H2T-5W-BiqabHHtlZ4ASs19TZmg8R', 'data.zip', quiet=False)
        del gdown
    # zipfile instead of `!unzip`: pure Python, no external binary, never prompts on overwrite.
    # Assumes the archive unpacks into `data/`, which is where the rest of the notebook reads from.
    import zipfile
    with zipfile.ZipFile("data.zip") as data_archive:
        data_archive.extractall()
    del zipfile, data_archive

In [3]:
def set_random_seed(seed: int):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global RNG, and torch (CPU and all CUDA
    devices), and switches cuDNN to deterministic algorithms.

    Args:
        seed: non-negative integer below 2**32 (NumPy's limit).
    """
    # Imported here so `random` is always seeded; the previous `"random" in globals()`
    # check never fired because the notebook does not import `random`.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode picks algorithms by timing, which can make runs non-reproducible.
    torch.backends.cudnn.benchmark = False

set_random_seed(1)

### Utilities (0.5 point)

Complete the dataset class so that it loads the prepared images and masks. Don't forget to use augmentations.

Some of the images have only 1 channel, so convert them with `gray2rgb`.

In [4]:
# for class_name in os.listdir("data/train/images"):
#     class_images_dir: str = path_join("data/train/images", class_name)
#     class_gt_masks_dir: str = path_join("data/train/gt", class_name)
#     assert os.path.exists(class_gt_masks_dir), f"no dir with ground truth masks {class_gt_masks_dir} for the class {class_name}"
#     for fname in os.listdir(class_images_dir):
#         gt_mask_file_path = path_join(class_gt_masks_dir, BirdsDataset._gt_mask_filename_from_img_filename(fname))
#         assert os.path.exists(gt_mask_file_path), \
#             f"no ground truth mask {gt_mask_file_path} for the file {fname} with class {class_name}"

#         mask = np.array(open_image(gt_mask_file_path))
#         if mask.ndim != 2:
#             assert len(mask.shape) == 3 and mask.shape[-1] == 2, mask.shape
#             print(gt_mask_file_path)
#             print(path_join(class_images_dir, fname))
#             # print(mask)
#             # display(Image.fromarray(mask[:,:,0]))
#             # display(Image.fromarray(mask[:,:,1]))
#             assert mask[:,:,1].ndim == 2
#             assert 255 in np.unique(mask[:,:,1])
#             # print(mask)
#             # assert False

#         # not_gt_mask = mask != WHITE_PIXEL_VALUE
#         # if not not np.all(not_gt_mask):
#         #     # print(f"{np.unique(mask)}\n{mask.reshape(-1)}")
#         #     pass
#         # else:
#         #     pass
#         #     # print(gt_mask_file_path)
#         # mask[not_gt_mask] = 0.0
#         # mask.astype(np.float32)

In [5]:
# Network input resolution: the data pipeline resizes/crops every image to this size.
IMG_HEIGHT = 224
IMG_WIDTH = 224
# Number of colour channels fed to the encoder (images are converted to RGB).
CHANNELS_COUNT = 3
# Largest 8-bit pixel value; passed to Normalize as max_pixel_value.
PIXEL_MAX_VALUE = 255
WHITE_PIXEL_VALUE = PIXEL_MAX_VALUE
# Ground-truth mask pixels equal to this value are treated as the bird (foreground).
MASK_PIXEL_VALUE = WHITE_PIXEL_VALUE

def gray2rgb(img: np.ndarray) -> np.ndarray:
    match img.shape:
        case (_, _):
            return np.dstack([img, img, img])
        case (_, _, 3):
            return img
        case _:
            raise ValueError(f"Invalid img.shape: {img.shape}")

def get_iou(gt, pred):
    """Intersection-over-union of a boolean ground-truth mask and a prediction.

    Args:
        gt: boolean mask (numpy array or torch tensor).
        pred: prediction of the same shape; binarized with ``> 0.5``, so booleans,
            0/1 labels and probabilities are all accepted.

    Returns:
        The IoU as a scalar of the input's library (numpy scalar or 0-d tensor).
        If both masks are empty (union is 0) the score is 0. It is returned as a zero
        of that same type, so callers can keep calling ``.cpu()`` / ``.item()`` on it.
    """
    pred = pred > 0.5
    intersection = (gt & pred).sum()
    union = (gt | pred).sum()
    if union == 0:
        return union * 0.0
    return intersection / union

def imagenet_mean() -> tuple[float, float, float]:
    """Per-channel (R, G, B) ImageNet mean, in [0, 1] pixel units."""
    red, green, blue = 0.485, 0.456, 0.406
    return red, green, blue

def imagenet_std() -> tuple[float, float, float]:
    """Per-channel (R, G, B) ImageNet standard deviation, in [0, 1] pixel units."""
    red, green, blue = 0.229, 0.224, 0.225
    return red, green, blue

class BirdsDataset(Dataset):
    def __init__(self, folder, is_train: None | bool = None) -> None:
        assert is_train is None or isinstance(is_train, bool)

        images_folder: str = os.path.join(folder, 'images')
        gt_folder: str = os.path.join(folder, 'gt')

        calc_manually = False
        if calc_manually:
            # These values are not the real mean and std, but calculating
            #  them would require iterating over images twice
            approx_mean = np.zeros(shape=(CHANNELS_COUNT,))
            approx_std = np.zeros(shape=(CHANNELS_COUNT,))

        img_index_to_class_name: list[str] = []
        img_index_to_file_name: list[str] = []
        for class_name in os.listdir(images_folder):
            class_images_dir: str = path_join(images_folder, class_name)
            class_gt_masks_dir: str = path_join(gt_folder, class_name)
            assert os.path.exists(class_gt_masks_dir), f"no dir with ground truth masks {class_gt_masks_dir} for the class {class_name}"
            for fname in os.listdir(class_images_dir):
                gt_mask_file_path = path_join(class_gt_masks_dir, BirdsDataset._gt_mask_filename_from_img_filename(fname))
                assert os.path.exists(gt_mask_file_path), \
                    f"no ground truth mask {gt_mask_file_path} for the file {fname} with class {class_name}"

                if calc_manually:
                    img_pixels: np.ndarray = gray2rgb(np.asarray(open_image(path_join(class_images_dir, fname)))).astype(np.float32)
                    assert len(img_pixels.shape) == 3 and img_pixels.shape[2] == CHANNELS_COUNT
                    img_pixels /= PIXEL_MAX_VALUE
                    approx_mean += img_pixels.mean(axis=(0, 1))
                    approx_std += img_pixels.std(axis=(0, 1))

                img_index_to_file_name.append(fname)

                # Note: only 1 instance of bytes sequence corresponding to the `class_name`
                #  exists in memory due to the reference semantic of the Python language
                img_index_to_class_name.append(class_name)

        if calc_manually:
            total_images = len(img_index_to_class_name)
            assert total_images == len(img_index_to_file_name) > 0
            approx_mean /= total_images
            approx_std /= total_images

            assert approx_mean.ndim == 1 and len(approx_mean) == CHANNELS_COUNT \
                and np.all((0 < approx_mean) & (approx_mean < 1))
            assert approx_std.ndim == 1 and len(approx_std) == CHANNELS_COUNT \
                and np.all((0 < approx_std) & (approx_std < 1))
            print(f"approx_mean = {approx_mean}")
            print(f"approx_std = {approx_std}")

        if is_train is None:
            is_train = "train" in folder
        self.transform = BirdsDataset.make_data_transformer(is_train=is_train)
        self._img_index_to_class_name = img_index_to_class_name
        self._img_index_to_file_name = img_index_to_file_name
        self._images_folder = images_folder
        self._gt_folder = gt_folder

    @staticmethod
    def preprocess_image(img: np.ndarray) -> np.ndarray:
        # assert np.issubdtype(img.dtype, np.integer), f"img.dtype = {img.dtype}"
        # assert np.all(img <= MASK_PIXEL_VALUE)
        # assert np.all(0 <= img)
        return gray2rgb(img)

    @staticmethod
    def preprocess_mask(mask: np.ndarray) -> np.ndarray:
        if mask.ndim != 2:
            mask = mask[:,:,0]
        # assert np.issubdtype(mask.dtype, np.integer), f"img.dtype = {mask.dtype}"
        mask = (mask == MASK_PIXEL_VALUE).astype(np.uint8)
        # assert np.all((mask == 0) ^ (mask == 1))
        return mask

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        img, gt_mask = self._read_img_and_gt_mask_by_index(index)
        match self.transform(image=BirdsDataset.preprocess_image(img), mask=BirdsDataset.preprocess_mask(gt_mask)):
            case {"image": t_img, "mask": t_msk}:
                return t_img, t_msk.long()
        assert False

    def __len__(self) -> int:
        return len(self._img_index_to_file_name)

    @staticmethod
    def make_data_transformer(is_train: bool = False, mean=imagenet_mean(), std=imagenet_std()):
        assert len(mean) == len(std) == CHANNELS_COUNT
        actions: list[A.BasicTransform] = [
            A.Normalize(mean=mean, std=std, max_pixel_value=PIXEL_MAX_VALUE),
        ]
        if is_train:
            actions.extend([
                A.RandomResizedCrop(size=(IMG_HEIGHT, IMG_WIDTH)),
                A.HorizontalFlip(p=0.3),
                A.RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.3),
            ])
        else:
            actions.extend([
                A.Resize(height=IMG_HEIGHT, width=IMG_WIDTH),
            ])
        actions.extend([
            ToTensorV2(),
        ])
        return A.Compose(actions)

    def _read_img_and_gt_mask_by_index(self, i: int) -> tuple[np.ndarray, np.ndarray]:
        cnm = self._img_index_to_class_name[i]
        fnm = self._img_index_to_file_name[i]
        return (
            np.asarray(open_image(path_join(self._images_folder, cnm, fnm))),
            np.asarray(open_image(path_join(self._gt_folder, cnm, self._gt_mask_filename_from_img_filename(fnm)))),
        )

    @staticmethod
    def _gt_mask_filename_from_img_filename(img_fname: str) -> str:
        return f"{img_fname.removesuffix('jpg')}png"


In [6]:
# bd = BirdsDataset(os.path.join("data", "train"))
# assert len(bd) > 0
# x = []
# for i in range(len(bd)):
#   x.extend(np.unique(bd[i][1]).tolist())
# assert np.all(np.unique(x) == np.array([0., 1.]))
# del bd

### Architecture (1 point)
Your task for today is to build your own Unet to solve the segmentation problem.

As the encoder, you may use torchvision models (or parts of them) pre-trained on ImageNet. The decoder must be trained from scratch.
Using any data from outside the `data` folder is forbidden.

I advise you to experiment with the number of blocks so that you do not overfit the training set and still get good quality on the validation set.

In [7]:
# B = 2
# CR = 256
# CX = 512
# N = 102
# r = torch.randn(size=(B, CR, N, N))
# x = torch.randn(size=(B, CX, N, N))
# assert torch.cat((r, x), dim=1).shape == torch.Size((B, CR + CX, N, N))
# nn.Conv2d(in_channels=CHANNELS_COUNT, out_channels=256, kernel_size=3, stride=1, padding=1)(torch.randn((1, CHANNELS_COUNT, IMG_HEIGHT, IMG_WIDTH))).shape
# nn.MaxPool2d(kernel_size=2)(torch.randn((1, 256, IMG_HEIGHT, IMG_WIDTH))).shape
# nn.MaxPool2d(kernel_size=3, stride=2, padding=1)(torch.randn((1, 256, IMG_HEIGHT, IMG_WIDTH))).shape
# from torchvision import models
# from torchvision.models.segmentation.deeplabv3 import DeepLabV3
# from torchvision.models.segmentation.fcn import FCN
# m: DeepLabV3 = models.segmentation.deeplabv3_mobilenet_v3_large()
# m: torchvision.models.resnet.ResNet = models.resnet152(weights=models.ResNet152_Weights.DEFAULT)

# from torchvision import models
# # m: torchvision.models.resnet.ResNet = models.resnet152(weights=models.ResNet152_Weights.DEFAULT)
# m = models.segmentation.fcn_resnet50(weights_backbone=models.ResNet50_Weights.IMAGENET1K_V1).eval().backbone
# [s for s in dir(m) if not s.startswith('_') and ("layer" in s or "conv" in s or "relu" in s or "sampl" in s or "pool" in s)]
# B = 2
# N = 32
# M = 16
# assert m.conv1(torch.zeros((B, CHANNELS_COUNT, IMG_HEIGHT, IMG_WIDTH))).shape == torch.Size((B, 64, IMG_HEIGHT // 2, IMG_WIDTH // 2))
# assert m.bn1(torch.zeros((B, 64, IMG_HEIGHT // 2, IMG_WIDTH // 2))).shape == torch.Size((B, 64, IMG_HEIGHT // 2, IMG_WIDTH // 2))
# assert m.relu(torch.zeros((B, 64, IMG_HEIGHT // 2, IMG_WIDTH // 2))).shape == torch.Size((B, 64, IMG_HEIGHT // 2, IMG_WIDTH // 2))
# assert m.maxpool(torch.zeros((B, 64, IMG_HEIGHT // 2, IMG_WIDTH // 2))).shape == torch.Size((B, 64, IMG_HEIGHT // 4, IMG_WIDTH // 4))
# assert m.layer1(torch.zeros((B, 64, IMG_HEIGHT // 4, IMG_WIDTH // 4))).shape == torch.Size((B, 256, IMG_HEIGHT // 4, IMG_WIDTH // 4))
# assert m.layer2(torch.zeros((B, 256, IMG_HEIGHT // 4, IMG_WIDTH // 4))).shape == torch.Size((B, 512, IMG_HEIGHT // 8, IMG_WIDTH // 8))
# # assert m.layer3(torch.zeros((B, 512, IMG_HEIGHT // 8, IMG_WIDTH // 8))).shape == torch.Size((B, 1024, IMG_HEIGHT // 16, IMG_WIDTH // 16))
# assert m.layer3(torch.zeros((B, 512, IMG_HEIGHT // 8, IMG_WIDTH // 8))).shape == torch.Size((B, 1024, IMG_HEIGHT // 8, IMG_WIDTH // 8))

In [8]:
from torchvision import models

class DecoderBlock(nn.Module):
    """One U-Net decoder stage.

    ``forward_with_resid`` upsamples the deeper features 2x (nearest), concatenates the
    skip connection in front of them along the channel axis, then applies
    conv3x3 -> ReLU -> conv3x3 -> ReLU.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        # Attribute names are the state_dict keys of saved checkpoints; keep them stable.
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()

    def forward_with_resid(self, resid: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x``, prepend the skip tensor ``resid`` channel-wise, and decode."""
        upsampled = self.upsample(x)
        merged = torch.cat((resid, upsampled), dim=1)
        return self.forward(merged)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Two 3x3 convolutions, each followed by ReLU; spatial size is preserved."""
        hidden = self.relu1(self.conv1(x))
        return self.relu2(self.conv2(hidden))

class Unet(nn.Module):
    """U-Net whose encoder reuses the ImageNet-pretrained ResNet-152 stem and layer1..layer3.

    The decoder (three ``DecoderBlock`` stages plus a 1x1 head) is trained from scratch.
    ``forward`` returns 2-channel logits (class 0 = background, class 1 = bird) at the input
    resolution. Input height and width are assumed divisible by 16 so the skip
    connections line up.
    """

    def __init__(self):
        super().__init__()

        # pretrained model whose layers are reused as the encoder
        enc_model: torchvision.models.resnet.ResNet = models.resnet152(weights=models.ResNet152_Weights.DEFAULT)
        Unet._check_encoder_shapes(enc_model)

        self.inp = nn.Sequential(enc_model.conv1, enc_model.bn1, enc_model.relu)

        # encoder blocks
        self.encoder1 = nn.Sequential(enc_model.maxpool, enc_model.layer1)
        self.encoder2 = enc_model.layer2
        self.encoder3 = enc_model.layer3

        # decoder blocks: in_channels = skip channels + channels coming from below
        self.decoder1 = DecoderBlock(in_channels=512 + 1024, mid_channels=1024, out_channels=512)
        self.decoder2 = DecoderBlock(in_channels=256 + 512, mid_channels=512, out_channels=256)
        self.decoder3 = DecoderBlock(in_channels=64 + 256, mid_channels=128, out_channels=64)

        self.out = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1),
        )

    @staticmethod
    def _check_encoder_shapes(enc_model: nn.Module) -> None:
        """Assert the channel/stride layout the decoder is wired for.

        The dummy forward pass runs in eval mode under ``no_grad``. In train mode it would
        overwrite the pretrained BatchNorm running statistics with statistics of a zero
        tensor, and it would build an autograd graph for nothing. The original training
        flag is restored afterwards.
        """
        was_training = enc_model.training
        enc_model.eval()
        try:
            with torch.no_grad():
                x = torch.zeros((1, CHANNELS_COUNT, IMG_HEIGHT, IMG_WIDTH))
                x = enc_model.relu(enc_model.bn1(enc_model.conv1(x)))
                assert x.shape == torch.Size((1, 64, IMG_HEIGHT // 2, IMG_WIDTH // 2))
                x = enc_model.layer1(enc_model.maxpool(x))
                assert x.shape == torch.Size((1, 256, IMG_HEIGHT // 4, IMG_WIDTH // 4))
                x = enc_model.layer2(x)
                assert x.shape == torch.Size((1, 512, IMG_HEIGHT // 8, IMG_WIDTH // 8))
                x = enc_model.layer3(x)
                assert x.shape == torch.Size((1, 1024, IMG_HEIGHT // 16, IMG_WIDTH // 16))
        finally:
            enc_model.train(was_training)

    def forward(self, x: torch.Tensor):
        """Map normalized images to 2-channel logits of the same spatial size."""
        r1 = self.inp(x)           # 64 ch,   1/2 resolution
        r2 = self.encoder1(r1)     # 256 ch,  1/4
        r3 = self.encoder2(r2)     # 512 ch,  1/8
        x = self.encoder3(r3)      # 1024 ch, 1/16
        x = self.decoder1.forward_with_resid(r3, x)
        x = self.decoder2.forward_with_resid(r2, x)
        x = self.decoder3.forward_with_resid(r1, x)
        return self.out(x)

    def mutable_parameters(self, recurse=True):
        """Parameters handed to the optimizer; currently every parameter (encoder included)."""
        return self.parameters(recurse)


In [9]:
# assert Unet()(torch.zeros(1, CHANNELS_COUNT, IMG_HEIGHT, IMG_WIDTH)).shape == torch.Size((1, 2, IMG_HEIGHT, IMG_WIDTH))

### Train script (0.5 point)

Complete the train and predict scripts.

In [10]:
# pred = torch.tensor([
#     [
#         [
#             [0, 0.5, 0.5],
#             [0, 0.5, 0.5],
#             [0.5, 0.5, 0.5],
#         ],
#         [
#             [10, 0.5, 0.5],
#             [10, 0.5, 0.5],
#             [0.5, 0.5, 0.5],
#         ],
#     ],
# ])
# target = torch.tensor([
#     [
#         [1, 0, 0],
#         [1, 0, 0],
#         [0, 0, 1],
#     ],
# ])
# nn.CrossEntropyLoss()(pred, target)

In [11]:
# get_iou(
#     gt=train_dataset[0][1].unsqueeze(0).to(DEVICE) == 1,
#     pred=model(train_dataset[0][0].unsqueeze(0).to(DEVICE)).detach().argmax(1) == 1
# )

In [12]:
def read_image_and_normalize(img_path: str) -> torch.Tensor:
    """Load an image file and apply the evaluation pipeline (resize + normalize + to tensor)."""
    rgb = BirdsDataset.preprocess_image(np.asarray(open_image(img_path)))
    eval_transform = BirdsDataset.make_data_transformer(is_train=False)
    transformed = eval_transform(image=rgb)
    return transformed["image"]

def resize_and_normalize_mask(gt_mask: np.ndarray) -> np.ndarray:
    """Binarize a raw ground-truth mask and resize it to the network resolution.

    Returns:
        An (IMG_HEIGHT, IMG_WIDTH) uint8 array of 0/1 values.
    """
    binary_mask = BirdsDataset.preprocess_mask(gt_mask)
    # Resize the mask on its own with nearest-neighbour (keeps it binary) instead of
    # pushing it through the image pipeline as a fake image, which also ran it through
    # 3-channel ImageNet normalization.
    mask = cv2.resize(binary_mask, (IMG_WIDTH, IMG_HEIGHT), interpolation=cv2.INTER_NEAREST)
    assert tuple(mask.shape) == (IMG_HEIGHT, IMG_WIDTH)
    assert np.all((mask == 0) ^ (mask == 1))
    return mask

def predict(model, img_path: str) -> np.ndarray:
    """Segment one image file.

    Returns:
        An (H, W) integer array of class indices (1 = bird, 0 = background) at the
        image's ORIGINAL resolution, so it can be compared directly with the full-size
        ground-truth mask (as the grading loop does).
    """
    model.eval()
    model_device = next(model.parameters()).device
    # Opening lazily reads only the header; that is enough to get the original size.
    with open_image(img_path) as pil_img:
        orig_width, orig_height = pil_img.size
    with torch.no_grad():
        x = read_image_and_normalize(img_path).unsqueeze(0).to(model_device)
        logits = model(x)
        assert tuple(logits.shape) == (1, 2, IMG_HEIGHT, IMG_WIDTH)
        # Interpolate the logits (not the hard labels) back to the original size, so the
        # boundary is smooth, then pick the most likely class per pixel.
        logits = F.interpolate(logits.float(), size=(orig_height, orig_width), mode="bilinear", align_corners=False)
        class_pred = logits.argmax(1).squeeze(0)
        assert tuple(class_pred.shape) == (orig_height, orig_width)
        return class_pred.cpu().numpy()

def get_model(path) -> Unet:
    """Build a Unet and load a checkpoint saved with ``torch.save(model.state_dict(), path)``.

    The returned model is on the CPU and in eval mode.
    """
    model = Unet()
    # map_location: a checkpoint written on a GPU must still load on a CPU-only machine.
    # weights_only: a state_dict is plain tensors, so refuse arbitrary pickled objects.
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model

In [13]:
def _run_epoch(model, dataloader, criterion, device, optimizer=None) -> tuple[float, float]:
    """Run ``model`` once over ``dataloader``.

    If ``optimizer`` is given, every batch is back-propagated and an optimizer step is taken.
    Otherwise the pass only measures. The caller sets train/eval mode and decides whether
    gradients are tracked.

    Returns:
        ``(mean batch loss, mean batch IoU of class 1)``.
    """
    batch_losses: list[float] = []
    batch_ious: list[float] = []
    for inputs, masks in tqdm(dataloader, leave=False):
        inputs = inputs.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        if optimizer is not None:
            optimizer.zero_grad()
        masks_pred = model(inputs)
        loss = criterion(masks_pred, masks)
        if optimizer is not None:
            loss.backward()
            optimizer.step()

        batch_losses.append(loss.item())
        # float() accepts both a 0-d tensor and a plain Python number, so a batch whose
        # masks and predictions are both empty cannot crash the loop.
        batch_ious.append(float(get_iou(masks == 1, masks_pred.detach().argmax(1) == 1)))
    return float(np.mean(batch_losses)), float(np.mean(batch_ious))


def train_segmentation_model(data_path):
    """Fine-tune a Unet on ``<data_path>/train`` and validate on ``<data_path>/val``.

    After every epoch the state_dict is saved as ``model_<epoch>.pth`` and the epoch's mean
    train/val loss and IoU are printed.
    """
    BATCH_SIZE = 8
    N_EPOCH = 15
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # os.path.join works whether or not data_path ends with a separator.
    train_dataset = BirdsDataset(os.path.join(data_path, 'train'))
    val_dataset = BirdsDataset(os.path.join(data_path, 'val'))
    # Pinned host memory is what makes the non_blocking host->GPU copies asynchronous.
    pin_memory = DEVICE.type == 'cuda'
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=pin_memory)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=pin_memory)

    model = Unet().to(DEVICE)
    optimizer = torch.optim.Adam(params=model.mutable_parameters(), lr=0.0007)
    criterion = nn.CrossEntropyLoss()
    losses_train, losses_val, ious_train, ious_val = [], [], [], []

    for epoch in (epochs_bar := tqdm(range(N_EPOCH))):
        epochs_bar.set_description(f"Epoch #{epoch}", refresh=True)

        model.train()
        epoch_loss, epoch_iou = _run_epoch(model, train_dataloader, criterion, DEVICE, optimizer)
        losses_train.append(epoch_loss)
        ious_train.append(epoch_iou)

        model.eval()
        with torch.no_grad():
            epoch_loss, epoch_iou = _run_epoch(model, val_dataloader, criterion, DEVICE)
        losses_val.append(epoch_loss)
        ious_val.append(epoch_iou)

        torch.save(model.state_dict(), f'model_{epoch}.pth')

        print(f"Epoch: {epoch}, train loss: {losses_train[-1]}, val loss: {losses_val[-1]}, train iou: {ious_train[-1]}, val iou: {ious_val[-1]}")

In [None]:
set_random_seed(seed=0xDEADBEEF)
train_segmentation_model(data_path='data/')

Epoch #0...:   0%|          | 0/15 [00:00<?, ?it/s]

You can also experiment with models and write a short report about the results. If the report is meaningful, you will receive an extra point.

### Testing (8 points)
Your model will be tested on new data similar to the validation data, so use techniques that prevent overfitting.

* IoU > 0.85 — 8 points
* IoU > 0.80 — 7 points
* IoU > 0.75 — 6 points
* IoU > 0.70 — 5 points
* IoU > 0.60 — 4 points
* IoU > 0.50 — 3 points
* IoU > 0.40 — 2 points
* IoU > 0.30 — 1 point

In [139]:
def read_image_and_normalize(img_path: str) -> torch.Tensor:
    """Load an image file and apply the evaluation pipeline (resize + normalize + to tensor)."""
    rgb = BirdsDataset.preprocess_image(np.asarray(open_image(img_path)))
    eval_transform = BirdsDataset.make_data_transformer(is_train=False)
    transformed = eval_transform(image=rgb)
    return transformed["image"]

def resize_and_normalize_mask(gt_mask: np.ndarray) -> np.ndarray:
    """Binarize a raw ground-truth mask and resize it to the network resolution.

    Returns:
        An (IMG_HEIGHT, IMG_WIDTH) uint8 array of 0/1 values.
    """
    binary_mask = BirdsDataset.preprocess_mask(gt_mask)
    # Resize the mask on its own with nearest-neighbour (keeps it binary) instead of
    # pushing it through the image pipeline as a fake image, which also ran it through
    # 3-channel ImageNet normalization.
    mask = cv2.resize(binary_mask, (IMG_WIDTH, IMG_HEIGHT), interpolation=cv2.INTER_NEAREST)
    assert tuple(mask.shape) == (IMG_HEIGHT, IMG_WIDTH)
    assert np.all((mask == 0) ^ (mask == 1))
    return mask

def predict(model, img_path: str) -> np.ndarray:
    model.eval()
    model_device = next(model.parameters()).device
    with torch.no_grad():
        x = read_image_and_normalize(img_path).unsqueeze(0).to(model_device)
        assert np.all(x.cpu().numpy() <= 1)
        logits_pred = model(x).detach().squeeze(0)
        assert tuple(logits_pred.shape) == (2, IMG_HEIGHT, IMG_WIDTH)
        class_pred = logits_pred.argmax(0)
        assert tuple(class_pred.shape) == (IMG_HEIGHT, IMG_WIDTH)
        return class_pred.cpu().numpy()

def get_model(path) -> Unet:
    model = Unet()
    model.load_state_dict(torch.load(path, weights_only=False))
    model.eval()
    return model

In [99]:
model = get_model('model_14.pth').to('cuda')

In [143]:
ious, times = [], []
test_dir = 'data/val/'

for class_name in tqdm(sorted(os.listdir(os.path.join(test_dir, 'images')))):
    print('\n')
    for img_name in sorted(os.listdir(os.path.join(test_dir, 'images', class_name))):
        img_path=os.path.join(test_dir, 'images', class_name, img_name)
        t_start = time()
        pred = predict(model, img_path)
        times.append(time() - t_start)

        gt_name = img_name.replace('jpg', 'png')
        gt = np.asarray(Image.open(os.path.join(test_dir, 'gt', class_name, gt_name)), dtype = np.uint8)
        if len(gt.shape) > 2:
            gt = gt[:, :, 0]

        assert tuple(pred.shape) == (IMG_HEIGHT, IMG_WIDTH)
        assert np.all((pred == 0) ^ (pred == 1))
        print(pred.sum())
        gt = resize_and_normalize_mask(gt)
        print(gt.sum())
        iou = get_iou(gt == 1, pred > 0.5)
        ious.append(iou)

np.mean(ious), np.mean(times)

  0%|          | 0/200 [00:00<?, ?it/s]



0
3983
0
5201
0
2863
0
5499
0
3188
0
9671
0
4078


  0%|          | 1/200 [00:00<02:26,  1.35it/s]

0
5695


0
1155
0
6583
0
10002
0
22901
0
12993
0
11250


  1%|          | 2/200 [00:01<01:57,  1.69it/s]

0
1410


0
5427
0
6261
0
3935


  2%|▏         | 3/200 [00:01<01:38,  2.00it/s]

0
3601
0
6706
0
8772
0
6464


0
5796
0
187
0
2435
0
6575


  2%|▏         | 4/200 [00:02<01:33,  2.10it/s]

0
8433
0
2428
0
3953


0
10233
0
6431
0
6535
0
3582
0
8531


  2%|▎         | 5/200 [00:02<01:25,  2.29it/s]

0
1381


0
23972
0
5636


  3%|▎         | 6/200 [00:02<01:17,  2.49it/s]

0
10475
0
7449
0
13259


0
26786
0
5129
0
5801
0
6356
0
3267


  4%|▎         | 7/200 [00:03<01:14,  2.61it/s]

0
8398


0
13209
0
1616
0
5076


  4%|▍         | 8/200 [00:03<01:08,  2.81it/s]

0
10885
0
13655


0
9857
0
2406
0
7049
0
5539
0
8283


  4%|▍         | 9/200 [00:03<01:15,  2.52it/s]

0
22499
0
4767
0
5703
0
6091


0
4279
0
22073
0
9970
0
4239
0
3941
0
3794


  4%|▍         | 9/200 [00:04<01:29,  2.13it/s]


KeyboardInterrupt: 

### Compression (1 point)

Try to speed up the model in any way without losing more than 1% of IoU score.
For example [torch2trt](https://github.com/NVIDIA-AI-IOT/torch2trt)

In [None]:
def get_fast_model():
    # YOUR CODE HERE
    return model

In [None]:
fast_model = get_fast_model().to(device='cuda')

In [None]:
# Grading loop: mean IoU (at ground-truth resolution) and mean per-image latency of fast_model.
ious, times = [], []
test_dir = 'data/val/'
images_root = os.path.join(test_dir, 'images')
gt_root = os.path.join(test_dir, 'gt')

for class_name in tqdm(sorted(os.listdir(images_root))):
    class_dir = os.path.join(images_root, class_name)
    for img_name in sorted(os.listdir(class_dir)):
        t_start = time()
        pred = predict(fast_model, os.path.join(class_dir, img_name))
        times.append(time() - t_start)

        gt_path = os.path.join(gt_root, class_name, img_name.replace('jpg', 'png'))
        gt = np.asarray(Image.open(gt_path), dtype = np.uint8)
        gt = gt[:, :, 0] if gt.ndim > 2 else gt

        ious.append(get_iou(gt == 255, pred > 0.5))

np.mean(ious), np.mean(times)

**Bonus:** The best IoU scores on the test set (without compression) in the group earn 1.5, 1, and 0.5 extra points (for 1st, 2nd, and 3rd place respectively).