# U-Net Paper Replication

- Original paper: [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597)

In [None]:
# Print the NVIDIA GPU status via the `nvidia-smi` CLI.
# `!` is an IPython shell escape, so this line only works inside Jupyter/IPython.
!nvidia-smi

In [None]:
import torch
import torchvision

# Record the library versions so results can be tied back to the environment they came from.
print(*(f"{lib.__name__} version: {lib.__version__}" for lib in (torch, torchvision)), sep="\n")

In [None]:
import os
import sys
from pathlib import Path

# Prepend the parent of the notebook's working directory to the module search path
# so the project's `src` package is importable in later cells.
sys.path[:0] = [str(Path.cwd().parent)]

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

In [None]:
import random

import numpy as np

# Run-wide settings.
BATCH_SIZE = 8
IMAGE_SIZE = (256, 512)  # (height, width), passed to A.Resize and to the model summary
NUM_WORKERS = 2

SEED = 42

# Seed every RNG the pipeline draws from (Python `random`, NumPy, PyTorch on all devices)
# so weight init, shuffling and augmentation are repeatable between runs.
# NOTE(review): cuDNN kernels can still be non-deterministic; see the PyTorch
# reproducibility notes if bit-exact reruns are required.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## 01. Data

In [None]:
import os
from pathlib import Path

import albumentations as A
from albumentations.pytorch import ToTensorV2

from src.data.dataset import DATASET_NAME
from src.data.dataloader import get_dataloaders

# ImageNet channel statistics. Both pipelines share these constants so train and
# test normalization cannot drift apart.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Dataset location. Set the CITYSCAPES_ROOT environment variable to run on another
# machine; the original author's path is kept as the default.
CITYSCAPES_ROOT = Path(os.environ.get("CITYSCAPES_ROOT", "/home/geri/work/dataset/cityscapes"))

# Training pipeline: resize, random flip / colour jitter / greyscale augmentation, normalize.
transform_train = A.Compose(
    [
        A.Resize(IMAGE_SIZE[0], IMAGE_SIZE[1]),
        A.HorizontalFlip(),
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.5),
        A.ToGray(p=0.3),
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ]
)

# Evaluation pipeline: deterministic resize plus the same normalization, no augmentation.
transform_test = A.Compose(
    [
        A.Resize(IMAGE_SIZE[0], IMAGE_SIZE[1]),
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ]
)

train_dataloader, test_dataloader, num_classes = get_dataloaders(
    dataset=DATASET_NAME.CITYSCAPES,
    root=CITYSCAPES_ROOT,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    transform_train=transform_train,
    transform_test=transform_test,
)

In [None]:
import matplotlib.pyplot as plt
import torch

img, target = next(iter(test_dataloader))

# transform_test normalizes with ImageNet mean/std, so the raw tensor is not a
# displayable image: imshow would clip it to [0, 1] and distort the colours.
# Undo the normalization (x * std + mean) before plotting.
denorm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
denorm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
image_to_show = (img[0] * denorm_std + denorm_mean).clamp(0.0, 1.0).permute(1, 2, 0)

# One row holding two wide (2:1) panels.
fig = plt.figure(figsize=(20, 5))

fig.add_subplot(1, 2, 1)
plt.imshow(image_to_show)
plt.title("Image")
plt.axis("off")

fig.add_subplot(1, 2, 2)
# Nearest interpolation keeps discrete label IDs from being blended on screen.
plt.imshow(target[0], interpolation="nearest")
plt.title("Target")
plt.axis("off")

In [None]:
import numpy as np

# Sanity-check the first sample of the batch: tensor shapes and the label values present.
print("Image Size: {}\tTarget Size: {}".format(img[0].shape, target[0].shape))
print("Labels in target image: " + str(np.unique(target[0])))

## 02. Model

In [None]:
from torchinfo import summary

from src.models.unet.unet import UNet

# U-Net with one output channel per dataset class, moved to the selected device.
model = UNet(out_channels=num_classes).to(device)

# Layer-by-layer overview for a single 3-channel input at the training resolution.
summary(
    model,
    input_size=(1, 3, *IMAGE_SIZE),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
    verbose=0,
)

## 03. Train

In [None]:
import src.utils.loggers as loggers

# Parent of the notebook's working directory; the `src` package is imported from here as well.
root = Path(os.getcwd()).parent

# The `extra` tag names this TensorBoard run, so it must describe the hyperparameters
# actually used. The optimizer cell uses Adam(lr=10e-4, weight_decay=10e-6) with
# DiceLoss and a cosine LR scheduler; the tag previously claimed weight-decay-10e-7.
# The batch size is taken from BATCH_SIZE so that part cannot go stale.
writer = loggers.create_tensorboard_writer(
    path=root / "runs",
    experiment_name="Cityscapes_Segmentation",
    model_name="U-Net",
    extra=f"batch-{BATCH_SIZE}_lr-10e-4_weight-decay-10e-6_dice-loss_lr-scheduler",
)

# Checkpoints go under <root>/checkpoints.
model_saver = loggers.ModelSaver(path=root / "checkpoints", model_name="U-Net")

In [None]:
from torch import nn

from src.models.train import train
from src.utils.losses import DiceLoss

EPOCHS = 75

# Hyperparameters in plain scientific notation: 1e-3, 1e-5 and 1e-4 are exactly the
# values the TensorBoard run tag spells as 10e-4, 10e-6 and 10e-5.
loss_fn = DiceLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
# Cosine decay from lr=1e-3 down to eta_min=1e-4 with period T_max=EPOCHS.
# NOTE(review): this spans the whole run only if train() steps the scheduler once per epoch — confirm.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-4)

train(
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scheduler=scheduler,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    epochs=EPOCHS,
    device=device,
    use_amp=True,
    writer=writer,
    model_saver=model_saver,
)