# Segmentation using UNet

In [1]:
import numpy as np
from tqdm.notebook import tqdm
import torch
import torch.optim
import torchvision
from datasets.carvana import Carvana
from models.unet import UNet
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
from utils import (
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    calculate_accuracy,
    save_predictions_as_imgs,
)

## Training parameters

In [2]:
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 10
NUM_WORKERS = 2
IMAGE_HEIGHT = 160  # 1280 originally
IMAGE_WIDTH = 240  # 1918 originally

print ('Using GPU' if torch.cuda.is_available() else "Using CPU")

Using GPU


## Building image transformers for data augmentation

In [3]:
train_transforms = A.Compose(
        [
            A.Resize(height=160, width=240),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

val_transforms = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)


## Loading dataset

In [4]:
train_loader, validation_loader = get_loaders(dataset_path="./data/carvana", batch_size=BATCH_SIZE ,train_transform=train_transforms, val_transform=val_transforms)

## Dice loss

In [5]:
class DiceLoss(torch.nn.Module):
    def __init__(self, smooth=1):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, predictions, targets):
        # Flatten predictions and targets
        predictions = predictions.view(-1)
        targets = targets.view(-1)

        intersection = (predictions * targets).sum()
        dice_coefficient = (2. * intersection + self.smooth) / (predictions.sum() + targets.sum() + self.smooth)

        # The Dice Loss is the complement of the Dice Coefficient
        dice_loss = 1 - dice_coefficient

        return dice_loss

## Defining model, optimizer and loss functions

In [6]:
model = UNet(in_channels=3, out_channels=1).to(DEVICE)
# loss_fn = torch.nn.BCEWithLogitsLoss()
loss_fn = DiceLoss()
optimize_fn = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scaler = torch.cuda.amp.GradScaler()
calculate_accuracy(validation_loader, model, device=DEVICE)



Got 27984126/35443200 with acc 78.95
Dice score: 0.0


In [7]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward
        with torch.cuda.amp.autocast():
            predictions = model(data)
            predictions = torch.sigmoid(predictions)
            loss = loss_fn(predictions, targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())

In [8]:
for epoch in range(NUM_EPOCHS):
    train_fn(train_loader, model, optimize_fn, loss_fn, scaler)
    # save model
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer":optimize_fn.state_dict(),
    }
    save_checkpoint(checkpoint)

    # check accuracy
    calculate_accuracy(validation_loader, model, device=DEVICE)

    # print some examples to a folder
    save_predictions_as_imgs(
        validation_loader, model, folder="saved_images/", device=DEVICE
    )



  0%|          | 0/261 [00:00<?, ?it/s]



Got 34928592/35443200 with acc 98.55
Dice score: 0.9664378762245178


  0%|          | 0/261 [00:00<?, ?it/s]

Got 34887270/35443200 with acc 98.43
Dice score: 0.9639837741851807


  0%|          | 0/261 [00:00<?, ?it/s]

Got 35184566/35443200 with acc 99.27
Dice score: 0.9827830791473389


  0%|          | 0/261 [00:00<?, ?it/s]

Got 35218743/35443200 with acc 99.37
Dice score: 0.985034167766571


  0%|          | 0/261 [00:00<?, ?it/s]

Got 35245093/35443200 with acc 99.44
Dice score: 0.9868009090423584


  0%|          | 0/261 [00:00<?, ?it/s]

Got 35243991/35443200 with acc 99.44
Dice score: 0.9866747260093689


  0%|          | 0/261 [00:00<?, ?it/s]

Got 35108410/35443200 with acc 99.06
Dice score: 0.9779373407363892


  0%|          | 0/261 [00:00<?, ?it/s]

Got 35273793/35443200 with acc 99.52
Dice score: 0.988678514957428


  0%|          | 0/261 [00:00<?, ?it/s]

Got 35243573/35443200 with acc 99.44
Dice score: 0.9867320656776428


  0%|          | 0/261 [00:00<?, ?it/s]

Got 35285870/35443200 with acc 99.56
Dice score: 0.9894535541534424


In [13]:
for data, targets in train_loader:
    print(data.shape)
    print(targets.shape)
    break

AttributeError: 'list' object has no attribute 'shape'