# Segmentation using UNet

In [1]:
import random

import numpy as np
import torch
import torch.optim
import torchvision
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2

from datasets.carvana import Carvana
from models.unet import UNet
from utils import (
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    calculate_accuracy,
    save_predictions_as_imgs,
)

## Training parameters

In [2]:
# Hyperparameters and runtime configuration.
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 10
NUM_WORKERS = 2  # NOTE(review): not passed to get_loaders below — confirm its signature before wiring it in
IMAGE_HEIGHT = 160  # 1280 originally
IMAGE_WIDTH = 240  # 1918 originally
SEED = 42

# Seed Python, NumPy and torch RNGs up front so weight init, shuffling and
# (presumably) albumentations' random transforms are repeatable across runs.
# Full determinism on CUDA may additionally need cuDNN settings.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print("Using GPU" if DEVICE == "cuda" else "Using CPU")

Using GPU


## Building image transforms (training augmentation and validation preprocessing)

In [3]:
def _normalize_to_tensor():
    """Return fresh final-stage transforms shared by the train and val pipelines.

    With mean 0, std 1 and max_pixel_value 255, ``A.Normalize`` just divides
    pixel values by 255 (so 0-255 inputs land in [0, 1]); ``ToTensorV2`` then
    converts the array to a torch tensor. New instances are built on every call
    so the two ``Compose`` pipelines never share transform objects.
    """
    return [
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ]


# Training: resize to the configured size (same as validation), then augment
# with a random rotation and horizontal/vertical flips.
train_transforms = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        *_normalize_to_tensor(),
    ],
)

# Validation: resize + normalization only, no random augmentation.
val_transforms = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        *_normalize_to_tensor(),
    ],
)


## Loading dataset

In [4]:
# Build the train / validation DataLoaders from the Carvana folder, applying the
# augmentation pipeline to training data and the plain pipeline to validation data.
train_loader, validation_loader = get_loaders(
    dataset_path="./data/carvana",
    batch_size=BATCH_SIZE,
    train_transform=train_transforms,
    val_transform=val_transforms,
)

## Defining the model, loss function, optimizer and gradient scaler

In [15]:
model = UNet(in_channels=3, out_channels=1).to(DEVICE)
# Binary segmentation with a single output channel, so BCE is applied to raw logits.
loss_fn = torch.nn.BCEWithLogitsLoss()
optimize_fn = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Loss scaling only matters for CUDA mixed precision; enable it explicitly
# rather than relying on the CPU fallback that warns and disables it.
scaler = torch.cuda.amp.GradScaler(enabled=DEVICE == "cuda")

# Baseline metrics of the untrained model, for comparison with the per-epoch results.
calculate_accuracy(validation_loader, model, device=DEVICE)



Got 7569271/35443200 with acc 21.36
Dice score: 0.34863701462745667


In [28]:
def train_fn(loader, model, optimizer, loss_fn, scaler, device=None):
    """Run one training epoch over ``loader`` with optional CUDA mixed precision.

    Args:
        loader: iterable of ``(data, targets)`` batches.
        model: network to optimize; it is put into training mode first.
        optimizer: optimizer stepping ``model``'s parameters.
        loss_fn: criterion called as ``loss_fn(predictions, targets)``.
        scaler: gradient scaler used to scale the loss before ``backward``.
        device: where batches are moved; defaults to the notebook-level ``DEVICE``.

    Returns:
        Mean training loss over the epoch, or ``nan`` if ``loader`` is empty.
    """
    if device is None:
        device = DEVICE
    # The autocast context below is CUDA autocast; only enable it when the
    # batches actually go to a CUDA device.
    use_amp = torch.device(device).type == "cuda"

    # Evaluation helpers may leave the model in eval mode (frozen BatchNorm
    # statistics, disabled Dropout); make sure this epoch really trains.
    model.train()

    loop = tqdm(loader)
    running_loss = 0.0
    num_batches = 0

    for data, targets in loop:
        data = data.to(device=device)
        # NOTE(review): assumes masks arrive as (B, H, W); unsqueeze(1) adds the
        # channel dim so they match single-channel logits — confirm with the dataset.
        targets = targets.float().unsqueeze(1).to(device=device)

        # forward
        with torch.cuda.amp.autocast(enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward through the scaler (a no-op pass-through when it is disabled)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1

        # update tqdm loop
        loop.set_postfix(loss=batch_loss)

    return running_loss / num_batches if num_batches else float("nan")

In [29]:
for epoch in range(NUM_EPOCHS):
    # Label each epoch so the progress bars and metrics below can be told apart.
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")
    train_fn(train_loader, model, optimize_fn, loss_fn, scaler)

    # Checkpoint after every epoch. The completed-epoch count is stored alongside
    # the model/optimizer state so a resumed run knows where it left off.
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimize_fn.state_dict(),
        "epoch": epoch + 1,
    }
    save_checkpoint(checkpoint)

    # Validation accuracy / Dice score after this epoch.
    calculate_accuracy(validation_loader, model, device=DEVICE)

    # Write example predictions to a folder for visual inspection.
    save_predictions_as_imgs(
        validation_loader, model, folder="saved_images/", device=DEVICE
    )



  0%|          | 0/261 [00:00<?, ?it/s]

Got 35137849/35443200 with acc 99.14
Dice score: 0.9797587394714355


  0%|          | 0/261 [00:00<?, ?it/s]

Got 35102682/35443200 with acc 99.04
Dice score: 0.9775564074516296


  0%|          | 0/261 [00:00<?, ?it/s]

Got 35034409/35443200 with acc 98.85
Dice score: 0.9732093214988708


  0%|          | 0/261 [00:00<?, ?it/s]

Got 35227706/35443200 with acc 99.39
Dice score: 0.9856238961219788


  0%|          | 0/261 [00:00<?, ?it/s]

Got 35224617/35443200 with acc 99.38
Dice score: 0.9854153394699097


  0%|          | 0/261 [00:00<?, ?it/s]

Got 35247228/35443200 with acc 99.45
Dice score: 0.9868634343147278


  0%|          | 0/261 [00:00<?, ?it/s]

Got 35244053/35443200 with acc 99.44
Dice score: 0.9866546988487244


  0%|          | 0/261 [00:00<?, ?it/s]

Got 35256092/35443200 with acc 99.47
Dice score: 0.9874642491340637


  0%|          | 0/261 [00:00<?, ?it/s]

Got 35262796/35443200 with acc 99.49
Dice score: 0.9879517555236816


  0%|          | 0/261 [00:00<?, ?it/s]

Got 35253680/35443200 with acc 99.47
Dice score: 0.9873597621917725
