# U-Net implementation

[Original video](https://youtu.be/IHq1t7NxS8k)

[Paper walkthrough](https://youtu.be/oLvmLJkmXuc)

[Paper](https://arxiv.org/abs/1505.04597)

![U-Net](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png)

In [1]:
# Google Colab ships an old version of the albumentations library; install the
# latest version straight from GitHub (this built 0.5.2 when the notebook was run).
# You may need to restart the runtime afterwards so the new version is imported.
# NOTE(review): consider pinning a version/tag here for reproducibility.
!pip install -U git+https://github.com/albu/albumentations --no-cache-dir

Collecting git+https://github.com/albu/albumentations
  Cloning https://github.com/albu/albumentations to /tmp/pip-req-build-xnf5b3ar
  Running command git clone -q https://github.com/albu/albumentations /tmp/pip-req-build-xnf5b3ar
Collecting imgaug>=0.4.0
[?25l  Downloading https://files.pythonhosted.org/packages/66/b1/af3142c4a85cba6da9f4ebb5ff4e21e2616309552caca5e8acefe9840622/imgaug-0.4.0-py2.py3-none-any.whl (948kB)
[K     |████████████████████████████████| 952kB 20.4MB/s 
Building wheels for collected packages: albumentations
  Building wheel for albumentations (setup.py) ... [?25l[?25hdone
  Created wheel for albumentations: filename=albumentations-0.5.2-cp37-none-any.whl size=88144 sha256=d028f952e428d3384cc9723471571a67f91140bf1994cee245c0eb3df393349b
  Stored in directory: /tmp/pip-ephem-wheel-cache-qr5xxbn8/wheels/45/8b/e4/2837bbcf517d00732b8e394f8646f22b8723ac00993230188b
Successfully built albumentations
Installing collected packages: imgaug, albumentations
  Found exi

In [2]:
import os
import torch
import torchvision
import numpy as np
import torch.nn as nn
import multiprocessing
import albumentations as A
import torch.optim as optim
import torchvision.transforms.functional as TF

from PIL import Image
from tqdm.notebook import tqdm
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader

## Model

In [3]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Both convolutions use stride 1 and padding 1, so the spatial size is
    preserved; only the channel count changes from ``in_channels`` to
    ``out_channels``. The convs have no bias because BatchNorm follows them.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in->out channels, second keeps out channels.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # Kept as `self.conv` (indices 0..5) so checkpoint keys stay unchanged.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.conv(x)
        return out


class UNET(nn.Module):
    """U-Net producing per-pixel logits (no activation is applied at the end).

    Encoder: for each width in ``features`` a DoubleConv followed by 2x2
    max-pooling (the pre-pool activation is kept as a skip connection).
    Bottleneck: a DoubleConv doubling the deepest width.
    Decoder: for each width (deepest first) a 2x2/stride-2 transposed conv
    (doubles H and W, halves channels) followed by a DoubleConv over the
    channel-wise concatenation of the matching skip connection and the
    upsampled tensor. A final 1x1 conv maps to ``out_channels``.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output logits (1 for a binary mask).
        features: encoder widths, shallowest first. Any sequence is accepted;
            the default is a tuple so it cannot be mutated and shared
            between instances.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)  # private copy; accepts any sequence
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of U-Net
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of U-Net: pairs of (transposed conv, DoubleConv) in self.ups
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]  # deepest first, matching self.ups order

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)  # transposed conv: x2 spatial, /2 channels
            skip_conn = skip_connections[idx // 2]

            # Max-pooling floors odd sizes (e.g. 299 -> 149), so the upsampled
            # tensor can be smaller than its skip connection; only the spatial
            # dims can differ, so compare just those and resize to match.
            if x.shape[2:] != skip_conn.shape[2:]:
                x = TF.resize(x, size=skip_conn.shape[2:])

            concat_skip = torch.cat((skip_conn, x), dim=1)  # concat on channels: N x C x H x W
            x = self.ups[idx + 1](concat_skip)  # double conv

        return self.final(x)


def test(batch_size=16, in_channels=3, out_channels=1, size=299):
    """Smoke-test UNET: the output keeps the input's batch and spatial size.

    The default size 299 is odd, so max-pooling floors it and the decoder has
    to resize upsampled tensors to match their skip connections — this
    exercises that path too.
    """
    x = torch.randn((batch_size, in_channels, size, size))
    model = UNET(in_channels=in_channels, out_channels=out_channels)
    # Forward-only check: no need to keep activations around for backprop,
    # which for a 16x3x299x299 batch is a lot of memory.
    with torch.no_grad():
        preds = model(x)
    print(preds.shape)
    assert preds.shape == (batch_size, out_channels, size, size)


test()

torch.Size([16, 1, 299, 299])


## Dataset

In [5]:
# Colab's file access feature
from google.colab import files

# Upload `kaggle.json` file
uploaded = files.upload()

# Retrieve uploaded file and print results
for fn in uploaded.keys():
  print('User uploaded file "{name}" with length {length} bytes'.format(
      name=fn, length=len(uploaded[fn])))


# Then copy kaggle.json into the folder where the API expects to find it.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!ls ~/.kaggle

Saving kaggle.json to kaggle.json
User uploaded file "kaggle.json" with length 65 bytes
kaggle.json


In [6]:
# Note: you will need to accept competition rules at
#     https://www.kaggle.com/c/carvana-image-masking-challenge/rules

competition_name = 'carvana-image-masking-challenge'

# List the competition files, then download only the training images and
# masks (test.zip / test_hq.zip are 8GB / 15GB and are not used here).
# IPython expands `$competition_name` from the Python variable above.
!kaggle competitions files '$competition_name'

!kaggle competitions download '$competition_name' -f train.zip
!kaggle competitions download '$competition_name' -f train_masks.zip

name                        size  creationDate         
-------------------------  -----  -------------------  
test.zip                     8GB  2018-06-22 02:52:10  
train.zip                  405MB  2018-06-22 02:52:10  
29bb3ece3180_11.jpg        107KB  2018-06-22 02:52:10  
train_masks.zip             29MB  2018-06-22 02:52:10  
train_masks.csv.zip         15MB  2018-06-22 02:52:10  
train_hq.zip               804MB  2018-06-22 02:52:10  
metadata.csv.zip            81KB  2018-06-22 02:52:10  
test_hq.zip                 15GB  2018-06-22 02:52:10  
sample_submission.csv.zip  202KB  2018-06-22 02:52:10  
Downloading train.zip to /content
 97% 391M/405M [00:02<00:00, 212MB/s]
100% 405M/405M [00:02<00:00, 177MB/s]
Downloading train_masks.zip to /content
 52% 15.0M/29.1M [00:00<00:00, 155MB/s]
100% 29.1M/29.1M [00:00<00:00, 187MB/s]


In [7]:
# Extract data
import zipfile
import tarfile

def extract(fname, path='.'):
    """Extract a .zip or tar archive, choosing the format from the file extension.

    Supported: .zip, .tar, .tar.gz/.tgz, .tar.bz2/.tbz, .tar.xz/.txz.

    Args:
        fname: path to the archive.
        path: destination directory (default: current working directory).

    Raises:
        ValueError: if the extension is not one of the supported ones.
    """
    if fname.endswith(('.tar.gz', '.tgz')):
        archive = tarfile.open(fname, mode='r:gz')
    elif fname.endswith('.tar'):
        archive = tarfile.open(fname, mode='r:')
    elif fname.endswith(('.tar.bz2', '.tbz')):
        archive = tarfile.open(fname, mode='r:bz2')
    elif fname.endswith(('.tar.xz', '.txz')):
        archive = tarfile.open(fname, mode='r:xz')
    elif fname.endswith('.zip'):
        archive = zipfile.ZipFile(fname, mode='r')
    else:
        # Fail loudly instead of reaching extractall() with no archive opened.
        raise ValueError(f'Unsupported archive format: {fname}')

    # NOTE: extractall() trusts member paths inside the archive; only use it
    # on archives from a trusted source (here: the Kaggle competition files).
    # The context manager closes the archive even if extraction fails.
    with archive:
        archive.extractall(path)

# Unpack the training images and their masks into the working directory.
for archive_name in ('train.zip', 'train_masks.zip'):
    extract(archive_name)

In [8]:
# Create test dataset. Move several examples from train to test
!mkdir -p test
!mkdir -p test_masks

test_files = ['00087a6bd4dc', '02159e548029', '03a857ce842d']
for f in test_files:
    !mv train/'$f'_??.jpg test
    !mv train_masks/'$f'_??_mask.gif test_masks

train = os.listdir('train')
train_masks = os.listdir('train_masks')
print('train size: ', len(train))
print('train masks:', len(train_masks))

test = os.listdir('test')
test_masks = os.listdir('test_masks')
print('test size:  ', len(test))
print('test masks: ', len(test_masks))

train size:  5040
train masks: 5040
test size:   48
test masks:  48


In [9]:
# Sanity check on one image/mask pair: value ranges before any transform.
image = np.array(Image.open('test/00087a6bd4dc_01.jpg'))
mask = np.array(Image.open('test_masks/00087a6bd4dc_01_mask.gif'), dtype=np.float32)
print(image.min(), image.max())  # expected [0, 255]
print(mask.min(), mask.max())    # expected [0.0, 1.0]

0 255
0.0 1.0


In [10]:
class CarvanaDataset(Dataset):
    """Carvana segmentation dataset yielding ``(image, mask)`` pairs.

    Every image ``<stem>.<ext>`` in ``image_dir`` is paired with the mask
    ``<stem>_mask.gif`` in ``mask_dir``.

    Args:
        image_dir: directory with the input images.
        mask_dir: directory with the matching masks.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=..., mask=...)`` and returning a dict with
            ``'image'`` and ``'mask'`` keys.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        super().__init__()
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Sorted so the index -> file mapping is deterministic (os.listdir
        # order is arbitrary); hidden entries such as .ipynb_checkpoints
        # are skipped.
        self.images = sorted(
            name for name in os.listdir(self.image_dir) if not name.startswith('.')
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_name = self.images[index]
        stem, _ = os.path.splitext(image_name)
        img_path = os.path.join(self.image_dir, image_name)
        mask_path = os.path.join(self.mask_dir, stem + '_mask.gif')

        # Context managers make sure the underlying files get closed.
        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'))  # uint8, [0, 255]
        with Image.open(mask_path) as msk:
            mask = np.array(msk, dtype=np.float32)  # [0.0, 1.0] for Carvana masks

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image, mask = augmentations['image'], augmentations['mask']

        return image, mask

## Utils

In [11]:
def save_checkpoint(model, optimizer, filename):
    """Write the model and optimizer state dicts to ``filename`` with torch.save.

    The file holds a dict with keys ``'state_dict'`` (model) and
    ``'optimizer'`` (optimizer), the layout load_checkpoint reads back.
    """
    print('=> Saving checkpoint')
    torch.save(
        {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()},
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr, device=None):
    """Restore model/optimizer state saved by save_checkpoint and reset the LR.

    Args:
        checkpoint_file: path of a file written by save_checkpoint.
        model: module whose weights are overwritten in place.
        optimizer: optimizer whose state is overwritten in place.
        lr: learning rate written into every param group, replacing the saved one.
        device: ``map_location`` for torch.load. Defaults to the device of the
            model's first parameter (CPU if it has none). This avoids depending
            on a global ``DEVICE`` that may not be defined yet on a fresh kernel.
    """
    print('=> Loading checkpoint')
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])

    # Replace old learning rate from the saved model
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr



def get_loaders(train_dir, train_maskdir, test_dir, test_maskdir, batch_size,
                train_transform, test_transform,
                num_workers=multiprocessing.cpu_count(), pin_memory=True):
    """Build the (train, test) DataLoaders over CarvanaDataset.

    Both loaders share ``batch_size``, ``num_workers`` and ``pin_memory``;
    only the training loader shuffles.

    Returns:
        ``(train_loader, test_loader)``.
    """

    def make_loader(image_dir, mask_dir, transform, shuffle):
        # One dataset + loader pair with the shared loader settings.
        dataset = CarvanaDataset(image_dir=image_dir, mask_dir=mask_dir,
                                 transform=transform)
        return DataLoader(dataset, batch_size=batch_size,
                          num_workers=num_workers,
                          pin_memory=pin_memory, shuffle=shuffle)

    train_loader = make_loader(train_dir, train_maskdir, train_transform, shuffle=True)
    test_loader = make_loader(test_dir, test_maskdir, test_transform, shuffle=False)
    return train_loader, test_loader


def dice_score(input, target, smooth=1.):
    """Smoothed Dice coefficient between two same-shaped tensors.

    Computed globally over all elements (the whole batch at once, not per
    sample):

        (2 * sum(input * target) + smooth) / (sum(input) + sum(target) + smooth)

    Args:
        input: predictions, either binarized {0, 1} or probabilities.
        target: ground-truth mask, same shape as ``input``.
        smooth: additive smoothing. It keeps the score defined (== 1) when both
            tensors are all zeros. Defaults to 1.0.

    Returns:
        The Dice score (a 0-dim tensor for tensor inputs); 1.0 means perfect overlap.
    """
    intersection = (input * target).sum()
    return (2. * intersection + smooth) / (input.sum() + target.sum() + smooth)



def dice_loss(input, target):
    """Dice loss, ``1 - dice_score(input, target)``; 0 means perfect overlap.

    No activation is applied here. In this notebook the training loop passes
    ``torch.sigmoid(logits)`` as ``input``.
    """
    score = dice_score(input, target)
    return 1 - score


def check_accuracy(loader, model, device='cuda', threshold=0.5):
    """Print pixel accuracy and mean per-batch Dice score of ``model`` on ``loader``.

    The model's logits go through sigmoid and are binarized at ``threshold``.
    Masks from ``loader`` are (N, h, w) and get a channel axis added to
    compare against (N, 1, h, w) predictions (i.e. a single-channel, binary
    segmentation model is assumed). The model's train/eval mode is restored
    afterwards, even if evaluation raises.

    Returns:
        ``(accuracy_percent, mean_dice)`` as Python floats.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    was_training = model.training
    num_correct, num_pixels, d_score, num_batches = 0, 0, 0.0, 0

    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).unsqueeze(1)  # (N, h, w) ==> (N, 1, h, w)
                preds = (torch.sigmoid(model(x)) > threshold).float()
                num_correct += (preds == y).sum().item()
                num_pixels += preds.numel()
                d_score += dice_score(preds, y).item()
                num_batches += 1
    finally:
        model.train(was_training)

    if num_batches == 0:
        raise ValueError('check_accuracy: loader yielded no batches')

    accuracy = num_correct / num_pixels * 100
    mean_dice = d_score / num_batches
    print(f'Got {num_correct}/{num_pixels} '
          f'with acc {accuracy:.2f}')
    print(f'Dice score: {mean_dice}')
    return accuracy, mean_dice


def save_predictions_as_imgs(loader, model, folder='saved_images', device='cuda', threshold=0.5):
    """Save binarized predictions and ground-truth masks for every batch of ``loader``.

    For batch ``idx`` this writes ``{folder}/pred_{idx}.jpg`` (sigmoid of the
    logits, thresholded at ``threshold``) and ``{folder}/{idx}.jpg`` (the
    target masks) with ``torchvision.utils.save_image``. The model's
    train/eval mode is restored afterwards, even if saving raises.
    NOTE(review): JPEG is lossy for binary masks; PNG would keep them exact.
    """
    os.makedirs(folder, exist_ok=True)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for idx, (x, y) in enumerate(loader):
                x = x.to(device=device)
                preds = torch.sigmoid(model(x))  # prediction (0, 1)
                preds = (preds > threshold).float()
                torchvision.utils.save_image(preds, f'{folder}/pred_{idx}.jpg')
                torchvision.utils.save_image(y.unsqueeze(1), f'{folder}/{idx}.jpg')
    finally:
        model.train(was_training)

## Train

In [18]:
# Hyperparameters
LEARNING_RATE = 1e-4
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 16
NUM_EPOCHS = 3
NUM_WORKERS = multiprocessing.cpu_count()
IMAGE_HEIGHT = 160  # 1280 originally
IMAGE_WIDTH = 240  # 1918 originally
# Pinned host memory only helps host->GPU copies; on a CPU-only runtime the
# DataLoader just warns about it, so enable it only when CUDA is in use.
PIN_MEMORY = DEVICE == 'cuda'
LOAD_MODEL = True  # resume from CHECKPOINT if that file exists
CHECKPOINT = 'model.pth.tar'
TRAIN_IMG_DIR = 'train'
TRAIN_MASK_DIR = 'train_masks'
TEST_IMG_DIR = 'test'
TEST_MASK_DIR = 'test_masks'

# Seed torch's global RNG (model initialization, DataLoader shuffling) so
# runs are repeatable.
SEED = 42
torch.manual_seed(SEED)


def train(loader, model, optimizer, loss_fn, scaler, epoch, num_epochs):
    """Run one training epoch with CUDA autocast and a GradScaler.

    The optimized loss is the Dice loss on ``torch.sigmoid`` of the model's
    logits. NOTE: ``loss_fn`` is accepted but currently unused (the
    ``loss_fn(predictions, targets)`` variant is disabled).
    """
    progress = tqdm(enumerate(loader), total=len(loader), leave=False)
    progress.set_description(f'Epoch [{epoch}/{num_epochs-1}]')

    for _, (images, masks) in progress:
        images = images.to(DEVICE)
        masks = masks.to(DEVICE).unsqueeze(1)  # (N, h, w) ==> (N, 1, h, w)

        # Forward pass under mixed-precision autocast.
        with torch.cuda.amp.autocast():
            logits = model(images)
            loss = dice_loss(torch.sigmoid(logits), masks)

        # Backward pass and optimizer step through the GradScaler.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Show the current batch loss in the progress bar.
        progress.set_postfix(loss=loss.item())


def main():
    """Build transforms, model, optimizer and loaders, then train for NUM_EPOCHS.

    Resumes from CHECKPOINT when LOAD_MODEL is set and the file exists. After
    every epoch: save a checkpoint, print test accuracy/Dice, and write
    predicted vs. true masks to ``saved_images``.
    """

    def normalize_to_tensor():
        # Shared tail of both pipelines: (x/255 - 0.5) / 0.5 maps [0, 255] to
        # [-1, 1] per channel, then convert to a torch tensor.
        return [
            A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
            ToTensorV2(),
        ]

    train_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),  # it helps to improve Dice-score on 0.1% somehow
        *normalize_to_tensor(),
    ])
    test_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        *normalize_to_tensor(),
    ])

    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    # WithLogits, because UNET applies no sigmoid to its output; use cross
    # entropy loss for multiple classes. NOTE: train() currently ignores
    # loss_fn and optimizes the Dice loss instead.
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    if LOAD_MODEL and os.path.exists(CHECKPOINT):
        load_checkpoint(CHECKPOINT, model, optimizer, LEARNING_RATE)

    train_loader, test_loader = get_loaders(
        TRAIN_IMG_DIR, TRAIN_MASK_DIR, TEST_IMG_DIR, TEST_MASK_DIR, BATCH_SIZE,
        train_transform, test_transform, NUM_WORKERS, PIN_MEMORY,
    )
    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        train(train_loader, model, optimizer, loss_fn, scaler, epoch, NUM_EPOCHS)
        save_checkpoint(model, optimizer, CHECKPOINT)
        check_accuracy(test_loader, model, device=DEVICE)
        save_predictions_as_imgs(test_loader, model, folder='saved_images', device=DEVICE)

In [19]:
# from google.colab import drive
# drive.mount('/content/gdrive')

# copy_from = '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial'

# !cp -rf '$copy_from/model.pth.tar' '.'

In [22]:
# Results after 6 epochs in total (this run resumes from the saved checkpoint
# and trains NUM_EPOCHS more):
# Got 1833077/1843200 with acc 99.45
# Dice score: 98.78805875778198 % (printed as the fraction 0.9878805875778198)
#
# Best Kaggle Dice-score: 99.734 %


# Inside a notebook kernel __name__ is '__main__', so this cell always runs main().
if __name__ == '__main__':
    main()

=> Loading checkpoint


HBox(children=(FloatProgress(value=0.0, max=315.0), HTML(value='')))

=> Saving checkpoint
Got 1831975/1843200 with acc 99.39
Dice score: 0.9865783452987671


HBox(children=(FloatProgress(value=0.0, max=315.0), HTML(value='')))

=> Saving checkpoint
Got 1832931/1843200 with acc 99.44
Dice score: 0.9877063632011414


HBox(children=(FloatProgress(value=0.0, max=315.0), HTML(value='')))

=> Saving checkpoint
Got 1833077/1843200 with acc 99.45
Dice score: 0.9878805875778198


In [23]:
# Back up the results to Google Drive: zip the saved prediction images and
# copy the archive plus the model checkpoint into the Drive folder below.
from google.colab import drive
drive.mount('/content/gdrive')

copy_to = '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial'
!zip -qr saved_images.zip saved_images/

# IPython expands $copy_to and $CHECKPOINT from the Python variables.
!cp -rf saved_images.zip     '$copy_to'
!cp -rf $CHECKPOINT '$copy_to'

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
