# U-Net implementation

[Original video](https://youtu.be/IHq1t7NxS8k)

[Paper walkthrough](https://youtu.be/oLvmLJkmXuc)

[Paper](https://arxiv.org/abs/1505.04597)

![U-Net](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png)

## Get previously saved model file

## Import libraries

In [None]:
import os
import torch
import shutil
import torchvision
import numpy as np
import torch.nn as nn
import multiprocessing
import albumentations as A
import torch.optim as optim
import torchvision.transforms.functional as TF

from PIL import Image
from tqdm.notebook import tqdm
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader

## UNet model architecture

In [None]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class UNET(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of U-Net
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of U-Net
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]  # reverse list

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_conn = skip_connections[idx//2]

            # Sizes differ when H or W is not divisible by 16 (MaxPool2d floors odd sizes)
            if x.shape != skip_conn.shape:
                x = TF.resize(x, size=skip_conn.shape[2:])

            concat_skip = torch.cat((skip_conn, x), dim=1)  # batch x channel x h x w
            x = self.ups[idx+1](concat_skip)  # double conv

        return self.final(x)


def test():
    x = torch.randn((16, 3, 299, 299))
    model = UNET(in_channels=3, out_channels=1)
    preds = model(x)
    print(preds.shape)
    assert preds.shape == (16, 1, 299, 299)


test()

torch.Size([16, 1, 299, 299])
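The `TF.resize` branch in `forward` only fires when a side length is not divisible by 2^4 = 16 (four pooling steps): `MaxPool2d(kernel_size=2, stride=2)` floors odd sizes, while `ConvTranspose2d(kernel_size=2, stride=2)` exactly doubles them, so an upsampled map can come back one pixel short of its skip connection. The pure-Python sketch below traces the spatial sizes for the 299x299 input used in `test()`; `trace_unet_sizes` is a hypothetical helper and assumes square inputs and the default four `features` levels.

```python
def trace_unet_sizes(size, depth=4):
    """Return (skip sizes, [(upsampled size, skip size), ...]) for a square input."""
    skips = []
    for _ in range(depth):
        skips.append(size)   # DoubleConv with padding=1 keeps H and W unchanged
        size = size // 2     # MaxPool2d(kernel_size=2, stride=2) floors odd sizes
    pairs = []
    for skip in reversed(skips):
        size = size * 2      # ConvTranspose2d(kernel_size=2, stride=2) doubles H and W
        pairs.append((size, skip))
        size = skip          # after the optional TF.resize, the decoder matches the skip size
    return skips, pairs


skips, pairs = trace_unet_sizes(299)
print(skips)  # [299, 149, 74, 37]
print(pairs)  # [(36, 37), (74, 74), (148, 149), (298, 299)]
```

With 299, three of the four decoder stages need a resize. With the 512x512 images used for training (`IMAGE_HEIGHT = IMAGE_WIDTH = 512`), every pair matches and the branch is never taken.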


## Dataset from Kaggle

**NOTE:** I can no longer download data from the [Carvana Image Masking Challenge](https://www.kaggle.com/competitions/carvana-image-masking-challenge), so I used another dataset instead: [BRISC 2025](https://www.kaggle.com/datasets/briscdataset/brisc2025).

BRISC is a high-quality, expert-annotated MRI dataset curated for brain tumor segmentation and classification. It addresses common limitations in existing datasets (e.g., BraTS, Figshare), including class imbalance, narrow tumor focus, and annotation inconsistencies.

<br />
<br />

First, sign in to [Kaggle](https://www.kaggle.com/settings) and get a `kaggle.json` file: it holds the API token needed to interact with the Kaggle API.

   * On the [Kaggle Settings](https://www.kaggle.com/settings) page, press the "Create new token" button and save your personal `kaggle.json` file.

   * Upload that file to this Google Colab notebook.

   * Make sure the file ends up at `~/.kaggle/kaggle.json`, where the API expects to find it.

In [None]:
# Colab's file access feature
from google.colab import files

# Upload YOUR PERSONAL `kaggle.json` file
uploaded = files.upload()

# Retrieve uploaded file and print results
for fn in uploaded.keys():
  print(f'User uploaded file "{fn}" with length {len(uploaded[fn])} bytes')
  os.rename(fn, 'kaggle.json')

# Then copy kaggle.json into the folder where the API expects to find it.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!ls -hal ~/.kaggle

Saving kaggle_dzmitrypaulenka.json to kaggle_dzmitrypaulenka.json
User uploaded file "kaggle_dzmitrypaulenka.json" with length 71 bytes
total 16K
drwxr-xr-x 2 root root 4.0K Aug 12 12:41 .
drwx------ 1 root root 4.0K Aug 12 12:41 ..
-rw------- 1 root root   71 Aug 12 12:41 kaggle.json
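The three shell steps above (`mkdir -p`, `cp`, `chmod 600`) can also be done in plain Python, for example outside Colab where `!` shell escapes are not available. A minimal sketch: `install_kaggle_token` is a hypothetical helper, and the demo writes placeholder credentials into a temporary directory instead of your real home folder.

```python
import os
import shutil
import stat
import tempfile
from pathlib import Path


def install_kaggle_token(token_path, home=None):
    """Copy a kaggle.json token to <home>/.kaggle/kaggle.json with 0600 permissions."""
    home = Path(home) if home is not None else Path.home()
    kaggle_dir = home / '.kaggle'
    kaggle_dir.mkdir(parents=True, exist_ok=True)  # like `mkdir -p ~/.kaggle`
    dest = kaggle_dir / 'kaggle.json'
    shutil.copyfile(token_path, dest)              # like `cp kaggle.json ~/.kaggle/`
    os.chmod(dest, stat.S_IRUSR | stat.S_IWUSR)    # like `chmod 600`: owner read/write only
    return dest


# Demo against a throwaway home directory, so the real ~/.kaggle is untouched
demo_home = tempfile.mkdtemp()
token = Path(demo_home) / 'kaggle_upload.json'
token.write_text('{"username": "YOUR_USERNAME", "key": "YOUR_KEY"}')  # placeholder values
installed = install_kaggle_token(token, home=demo_home)
print(installed, oct(installed.stat().st_mode & 0o777))  # .../.kaggle/kaggle.json 0o600
```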


**NOTE:** to download the [Carvana Image Masking Challenge](https://www.kaggle.com/competitions/carvana-image-masking-challenge) data, you have to accept the competition rules at
https://www.kaggle.com/c/carvana-image-masking-challenge/rules

Without accepting them, the API returns errors like:

    401 Client Error: Unauthorized for url: https://www.kaggle.com/api/v1/competitions/data/list/$competition_name?pagesize=20
    401 Client Error: Unauthorized for url: https://www.kaggle.com/api/v1/competitions/data/download/$competition_name/train.zip
    401 Client Error: Unauthorized for url: https://www.kaggle.com/api/v1/competitions/data/download/$competition_name/train_masks.zip

To accept the competition rules and download the dataset:
   * go to the [web page](https://www.kaggle.com/c/carvana-image-masking-challenge/rules) with the rules;
   * click the **"Late Submission"** button **at the bottom of the page** to accept the competition’s terms and conditions;
   * once the rules are accepted, the message **"You have accepted the rules for this competition. Good luck!"** should appear at the top of the page.

In [None]:
# competition_name = 'carvana-image-masking-challenge'

# BRISC 2025: Brain Tumor MRI Dataset for Segmentation and Classification
competition_name = 'briscdataset/brisc2025'

!kaggle datasets list -s "{competition_name}"
# !kaggle datasets files "{competition_name}"

!kaggle datasets download "{competition_name}"

ref                     title                size  lastUpdated                 downloadCount  voteCount  usabilityRating  
----------------------  -------------  ----------  --------------------------  -------------  ---------  ---------------  
briscdataset/brisc2025  🧠 BRISC 2025    260542520  2025-07-20 19:37:33.853000           1009         19  0.75             
Dataset URL: https://www.kaggle.com/datasets/briscdataset/brisc2025
License(s): Attribution 4.0 International (CC BY 4.0)
Downloading brisc2025.zip to /content
100% 248M/248M [00:00<00:00, 703MB/s] 
100% 248M/248M [00:00<00:00, 741MB/s]


In [None]:
# Extract data
import zipfile
import tarfile

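def extract(fname):
    """Extract a .zip, .tar, .tar.gz/.tgz or .tar.bz2/.tbz archive into the current folder."""
    if fname.endswith(('.tar.gz', '.tgz')):
        ref = tarfile.open(fname, mode='r:gz')
    elif fname.endswith('.tar'):
        ref = tarfile.open(fname, mode='r:')
    elif fname.endswith(('.tar.bz2', '.tbz')):
        ref = tarfile.open(fname, mode='r:bz2')
    elif fname.endswith('.zip'):
        ref = zipfile.ZipFile(fname, mode='r')
    else:
        # Without this branch `ref` would be undefined and raise a confusing NameError
        raise ValueError(f'Unsupported archive format: {fname}')

    with ref:  # closes the archive even if extraction fails
        ref.extractall()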

extract('brisc2025.zip')
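`BriscDataset` (defined further down) finds each mask by swapping `.jpg` for `.png` in the image file name, so every image needs a mask with the same stem. A quick sanity check right after extraction can catch missing pairs before training fails halfway through an epoch. A minimal sketch; `find_unpaired` is a hypothetical helper that works on plain lists of file names, and the demo names are made up.

```python
import os


def find_unpaired(image_files, mask_files, image_ext='.jpg', mask_ext='.png'):
    """Return (images without a mask, masks without an image), matched by file stem."""
    image_stems = {os.path.splitext(f)[0] for f in image_files if f.endswith(image_ext)}
    mask_stems = {os.path.splitext(f)[0] for f in mask_files if f.endswith(mask_ext)}
    return sorted(image_stems - mask_stems), sorted(mask_stems - image_stems)


# Tiny illustration with made-up file names
print(find_unpaired(['a.jpg', 'b.jpg'], ['a.png', 'c.png']))  # (['b'], ['c'])
```

On the extracted data, the idea would be to call `find_unpaired(os.listdir(images_dir), os.listdir(masks_dir))` for both the `train` and `test` folders under `/content/brisc2025/segmentation_task/` and expect two empty lists.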

In [None]:
# # Create test dataset. Move several examples from train to test
# !mkdir -p test
# !mkdir -p test_masks

# test_files = ['00087a6bd4dc', '02159e548029', '03a857ce842d']
# for f in test_files:
#     !mv train/'$f'_??.jpg test
#     !mv train_masks/'$f'_??_mask.gif test_masks

# train = os.listdir('train')
# train_masks = os.listdir('train_masks')
# print('train size: ', len(train))
# print('train masks:', len(train_masks))

# test = os.listdir('test')
# test_masks = os.listdir('test_masks')
# print('test size:  ', len(test))
# print('test masks: ', len(test_masks))

In [None]:
images_dir = '/content/brisc2025/segmentation_task/test/images/'
masks_dir  = '/content/brisc2025/segmentation_task/test/masks/'
image_name = 'brisc2025_test_00001_gl_ax_t1.jpg'
mask_name  = 'brisc2025_test_00001_gl_ax_t1.png'

image = np.array(Image.open(images_dir + image_name))
mask = np.array(Image.open(masks_dir + mask_name), dtype=np.float32)
print(f'min: {np.min(image)}; max: {np.max(image)}; shape: {image.shape}')  # [0, 255]
print(f'min: {np.min(mask)}; max: {np.max(mask)}; shape: {mask.shape}')  # [0.0, 255.0]

min: 0; max: 255; shape: (512, 512, 3)
min: 0.0; max: 255.0; shape: (512, 512)
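The printed ranges show that masks are stored as 0/255 values, which is why `BriscDataset` below divides each mask by `np.max(mask)`. That works as long as every mask contains some foreground; an all-zero mask would give 0/0 and produce NaN. A numpy sketch comparing max-normalization with simple thresholding (both helpers are hypothetical names): on 0/255 masks they give the same {0, 1} result, and thresholding stays defined for empty masks.

```python
import numpy as np


def normalize_by_max(mask):
    # Mirrors `mask / np.max(mask)`; an all-zero mask gives 0/0 = NaN
    with np.errstate(invalid='ignore'):  # silence the "invalid value" warning for 0/0
        return mask / np.max(mask)


def binarize(mask, threshold=0.0):
    # Every pixel above `threshold` becomes 1.0, everything else 0.0
    return (mask > threshold).astype(np.float32)


tumor = np.array([[0, 255], [0, 255]], dtype=np.float32)
empty = np.zeros((2, 2), dtype=np.float32)
print(normalize_by_max(tumor))                  # [[0. 1.] [0. 1.]]
print(binarize(tumor))                          # same values
print(np.isnan(normalize_by_max(empty)).all())  # True
print(binarize(empty))                          # all zeros
```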


In [None]:
class BriscDataset(Dataset):
    def __init__(self, images_dir, masks_dir, transform=None):
        super().__init__()
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform
        self.images = os.listdir(self.images_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.images_dir, self.images[index])
        mask_path = os.path.join(self.masks_dir, self.images[index].replace('.jpg', '.png'))
        image = np.array(Image.open(img_path))  # [0, 255]
        if len(image.shape) == 2:  # image shape is (512, 512)
            image = image[:, :, np.newaxis].repeat(3, axis=2)  # convert 2D to 3D array
        mask = np.array(Image.open(mask_path), dtype=np.float32)  # [0.0, 255.0]
        mask = mask / np.max(mask)  # normalize to [0.0, 1.0]

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image, mask = augmentations['image'], augmentations['mask']

        return image, mask
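`BriscDataset` turns 2-D grayscale MRI slices into 3-channel arrays because the model is built with `in_channels=3` and `A.Normalize` uses a 3-value mean and std. A numpy sketch of that step (`to_three_channels` is a hypothetical name), plus the (H, W, C) to (C, H, W) reordering that `ToTensorV2` performs afterwards, emulated here with `transpose`:

```python
import numpy as np


def to_three_channels(image):
    # Same idea as BriscDataset: (H, W) -> (H, W, 1) -> (H, W, 3)
    if image.ndim == 2:
        image = image[:, :, np.newaxis].repeat(3, axis=2)
    return image


gray = np.random.randint(0, 256, size=(512, 512), dtype=np.uint8)
rgb = to_three_channels(gray)
print(rgb.shape)  # (512, 512, 3)

# The three channels are identical copies of the grayscale slice
print(all(np.array_equal(rgb[:, :, c], gray) for c in range(3)))  # True

# ToTensorV2 later puts channels first: (H, W, C) -> (C, H, W)
print(rgb.transpose(2, 0, 1).shape)  # (3, 512, 512)
```

Already-RGB images pass through unchanged, so the dataset can mix both kinds of files.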

In [None]:
# class CarvanaDataset(Dataset):
#     def __init__(self, image_dir, mask_dir, transform=None):
#         super().__init__()
#         self.image_dir = image_dir
#         self.mask_dir = mask_dir
#         self.transform = transform
#         self.images = os.listdir(self.image_dir)

#     def __len__(self):
#         return len(self.images)

#     def __getitem__(self, index):
#         img_path = os.path.join(self.image_dir, self.images[index])
#         mask_path = os.path.join(self.mask_dir, self.images[index].replace('.jpg', '_mask.gif'))
#         image = np.array(Image.open(img_path))  # [0, 255]
#         mask = np.array(Image.open(mask_path), dtype=np.float32)  # [0.0, 1.0]

#         if self.transform is not None:
#             augmentations = self.transform(image=image, mask=mask)
#             image, mask = augmentations['image'], augmentations['mask']

#         return image, mask

## Helper functions

In [None]:
def save_checkpoint(model, optimizer, filename):
    print('=> Saving checkpoint')
    checkpoint = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print('=> Loading checkpoint')
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])

    # Replace old learning rate from the saved model
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr



def get_loaders(train_dir, train_maskdir, test_dir, test_maskdir, batch_size,
                train_transform, test_transform,
                num_workers=multiprocessing.cpu_count(), pin_memory=True):

    train_ds = BriscDataset(images_dir=train_dir, masks_dir=train_maskdir,
                            transform=train_transform)
    train_loader = DataLoader(train_ds, batch_size=batch_size,
                              num_workers=num_workers,
                              pin_memory=pin_memory, shuffle=True)
    test_ds = BriscDataset(images_dir=test_dir, masks_dir=test_maskdir,
                           transform=test_transform)
    test_loader = DataLoader(test_ds, batch_size=batch_size,
                             num_workers=num_workers,
                             pin_memory=pin_memory, shuffle=False)

    return train_loader, test_loader


def dice_score(input, target):
    smooth = 1.
    intersection = (input * target).sum()

    return ((2. * intersection + smooth) /
            (input.sum() + target.sum() + smooth))



def dice_loss(input, target):
    return 1 - dice_score(input, target)


def check_accuracy(loader, model, device='cuda'):
    num_correct, num_pixels, d_score = 0, 0, 0

    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)  # (N, h, w) ==> (N, 1, h, w)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum().item()  # plain int, prints without tensor(...)
            num_pixels += torch.numel(preds)
            d_score += dice_score(preds, y).item()

    print(f'Got {num_correct}/{num_pixels} '
          f'with accuracy {num_correct/num_pixels*100:.2f}%')
    print(f'Dice-score: {d_score/len(loader)*100:.2f}%')
    model.train()


def save_predictions_as_imgs(loader, model, folder='saved_images', device='cuda'):
    os.makedirs(folder, exist_ok=True)
    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))  # prediction (0, 1)
            preds = (preds > 0.5).float()
        torchvision.utils.save_image(preds, f'{folder}/pred_{idx}.jpg')
        torchvision.utils.save_image(y.unsqueeze(1), f'{folder}/{idx}.jpg')
    model.train()
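`dice_score` implements the Dice coefficient with additive smoothing, (2 * overlap + s) / (sum(pred) + sum(target) + s) with s = 1. Because it sums over every element it receives, it treats a whole batch as one big mask, so `check_accuracy` reports the mean of per-batch Dice values rather than a per-image average. The numpy sketch below re-implements the same formula for illustration only (`dice_score_np` and `pixel_accuracy_np` are hypothetical names; the notebook versions work on torch tensors). It shows a tiny worked example and why pixel accuracy alone is misleading when the foreground is small, using a hypothetical 512x512 mask with a 72x72 "tumor" (about 2% of the pixels).

```python
import numpy as np


def dice_score_np(pred, target, smooth=1.0):
    # Same formula as the notebook's dice_score:
    # (2 * overlap + smooth) / (sum(pred) + sum(target) + smooth)
    intersection = (pred * target).sum()
    return float((2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth))


def pixel_accuracy_np(pred, target):
    # Fraction of pixels where the binary prediction equals the mask
    return float((pred == target).mean())


# 1) Worked example: one overlapping pixel, sum(pred) = 2, sum(target) = 1
pred = np.array([[1, 1], [0, 0]], dtype=np.float32)
target = np.array([[1, 0], [0, 0]], dtype=np.float32)
print(dice_score_np(pred, target))            # (2*1 + 1) / (2 + 1 + 1) = 0.75
print(dice_score_np(pred, target, smooth=0))  # 2 / 3, about 0.667

# 2) Class imbalance: predicting "all background" on a small tumor
mask = np.zeros((512, 512), dtype=np.float32)
mask[100:172, 100:172] = 1.0
all_background = np.zeros_like(mask)
print(pixel_accuracy_np(all_background, mask))  # about 0.980
print(dice_score_np(all_background, mask))      # 1 / 5185, about 0.0002
```

The training log below shows the same effect: after the first epoch the model already reaches 91.56% pixel accuracy but a Dice-score of only about 0.30, so the Dice-score is the more informative of the two numbers here. The smoothing term also makes two empty masks score 1.0 instead of 0/0.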

## Train

In [None]:
# Hyperparameters
LEARNING_RATE = 1e-4
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
USE_CUDA = torch.cuda.is_available()  # check if CUDA is available
BATCH_SIZE = 16
NUM_EPOCHS = 10
NUM_WORKERS = multiprocessing.cpu_count()
IMAGE_HEIGHT = IMAGE_WIDTH = 512
PIN_MEMORY = True
LOAD_MODEL = True
CHECKPOINT = 'UNet.pth.tar'

TRAIN_IMG_DIR  = '/content/brisc2025/segmentation_task/train/images/'
TRAIN_MASK_DIR = '/content/brisc2025/segmentation_task/train/masks/'
TEST_IMG_DIR   = '/content/brisc2025/segmentation_task/test/images/'
TEST_MASK_DIR  = '/content/brisc2025/segmentation_task/test/masks/'


def train(loader, model, optimizer, loss_fn, scaler, epoch, num_epochs):
    # progress bar
    loop = tqdm(enumerate(loader), total=len(loader), leave=False)
    loop.set_description(f'Epoch [{epoch}/{num_epochs-1}]')

    for batch_idx, (data, targets) in loop:
        data = data.to(DEVICE)
        targets = targets.to(DEVICE).unsqueeze(1)  # (N, h, w) ==> (N, 1, h, w)

        # forward (float16 autocast only when CUDA is available)
        with torch.autocast(DEVICE, dtype=torch.float16, enabled=USE_CUDA):
            predictions = model(data)
            # loss = loss_fn(predictions, targets)
            loss = dice_loss(torch.sigmoid(predictions), targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())


def main():
    train_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        # A.HorizontalFlip(p=0.5),
        # A.VerticalFlip(p=0.1),  # seemed to improve the Dice-score by about 0.1%

        # Apply 1 of the 8 possible D4 dihedral group transformations to a square-shaped input
        A.D4(p=1.0),

        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ])

    test_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ])

    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
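    # BCEWithLogitsLoss applies the sigmoid itself, because the model outputs raw logits
    # (no torch.sigmoid at the end); use nn.CrossEntropyLoss for multiple classes.
    # NOTE: train() currently optimizes dice_loss, so loss_fn is passed but not used.
    loss_fn = nn.BCEWithLogitsLoss()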
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    if LOAD_MODEL and os.path.exists(CHECKPOINT):
        load_checkpoint(CHECKPOINT, model, optimizer, LEARNING_RATE)

    train_loader, test_loader = get_loaders(
        TRAIN_IMG_DIR, TRAIN_MASK_DIR,
        TEST_IMG_DIR, TEST_MASK_DIR, BATCH_SIZE,
        train_transform, test_transform,
        NUM_WORKERS, PIN_MEMORY)

    scaler = torch.amp.GradScaler(enabled=USE_CUDA)

    # x, y = next(iter(train_loader))
    # print(x)
    # print(y)

    for epoch in range(NUM_EPOCHS):
        train(train_loader, model, optimizer, loss_fn, scaler, epoch, NUM_EPOCHS)

        # save model
        save_checkpoint(model, optimizer, CHECKPOINT)

        # check accuracy
        check_accuracy(test_loader, model, device=DEVICE)

        # save examples
        save_predictions_as_imgs(test_loader, model, folder='saved_images', device=DEVICE)

In [None]:
# from google.colab import drive
# drive.mount('/content/gdrive')

# copy_from = '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial'

# !cp -rf '$copy_from/model.pth.tar' '.'

In [None]:
if __name__ == '__main__':
    main()

=> Saving checkpoint
Got 206416501/225443840 with acc 91.56
Dice score: 0.2975063920021057

=> Saving checkpoint
Got 220796496/225443840 with acc 97.94
Dice score: 0.592595636844635

=> Saving checkpoint
Got 220136171/225443840 with acc 97.65
Dice score: 0.5737591981887817

=> Saving checkpoint
Got 222479477/225443840 with acc 98.69
Dice score: 0.7290018796920776

=> Saving checkpoint
Got 222854798/225443840 with acc 98.85
Dice score: 0.7444464564323425

=> Saving checkpoint
Got 222956563/225443840 with acc 98.90
Dice score: 0.762932300567627

=> Saving checkpoint
Got 222934014/225443840 with acc 98.89
Dice score: 0.7447306513786316

=> Saving checkpoint
Got 222877172/225443840 with acc 98.86
Dice score: 0.7659452557563782

=> Saving checkpoint
Got 223125761/225443840 with acc 98.97
Dice score: 0.7891131043434143

=> Saving checkpoint
Got 223083386/225443840 with acc 98.95
Dice score: 0.7698049545288086
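As written, `main()` calls `save_checkpoint` after every epoch, so `UNet.pth.tar` ends up holding the weights from the last epoch (Dice 0.7698 in the log) rather than the best one (0.7891, ninth epoch). A minimal sketch of keeping only the best checkpoint, assuming `check_accuracy` is changed to return its Dice-score (it only prints it now); `BestCheckpointTracker` is a hypothetical helper, not part of the notebook.

```python
class BestCheckpointTracker:
    """Hypothetical helper: decide whether the current epoch beats the best score so far."""

    def __init__(self):
        self.best_score = float('-inf')
        self.best_epoch = None

    def should_save(self, epoch, score):
        # Higher Dice-score is better; ties keep the earlier checkpoint
        if score > self.best_score:
            self.best_score = score
            self.best_epoch = epoch
            return True
        return False


# Dice-scores from the log above, rounded to 4 decimals (epochs 0..9)
scores = [0.2975, 0.5926, 0.5738, 0.7290, 0.7444, 0.7629, 0.7447, 0.7659, 0.7891, 0.7698]
tracker = BestCheckpointTracker()
saved_epochs = [epoch for epoch, score in enumerate(scores) if tracker.should_save(epoch, score)]
print(saved_epochs)                            # [0, 1, 3, 4, 5, 7, 8]
print(tracker.best_epoch, tracker.best_score)  # 8 0.7891
```

In `main()`, the unconditional save would then become something like `if tracker.should_save(epoch, dice): save_checkpoint(model, optimizer, CHECKPOINT)`, where `dice` is the value returned by the modified `check_accuracy`.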


## Results

### Carvana Image Masking Challenge

For [Carvana Image Masking Challenge](https://www.kaggle.com/competitions/carvana-image-masking-challenge), after 6 epochs:

   * Got 1833077 / 1843200 pixels right, accuracy 99.45%
   * My Dice-score: 98.799%
   * **Best** Kaggle Dice-score: 99.734%

### BRISC 2025

For [BRISC 2025](https://www.kaggle.com/datasets/briscdataset/brisc2025), after 9 epochs (the best of the 10 epochs in the training log):

   * Accuracy: 98.97%
   * My Dice-score: 78.91%

<br />

[Brain Tumor Segmentation with Unet](https://www.kaggle.com/code/nirmalgaud/brain-tumor-segmentation-with-unet) has a Dice-score of 68.43%.

[Brain Tumor Segmentation with Wavelet](https://www.kaggle.com/code/nirmalgaud/brain-tumor-segmentation-with-wavelet) has a Dice-score of 78.40%.

In [None]:
# from google.colab import drive
# drive.mount('/content/gdrive')

# copy_to = '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial'
# !zip -qr saved_images.zip saved_images/

# !cp -rf saved_images.zip     '$copy_to'
# !cp -rf $CHECKPOINT '$copy_to'