# U-Net Paper Replication

- Original Paper: [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597)

In [1]:
# Show the visible NVIDIA GPU(s), the driver/CUDA versions and current memory usage.
!nvidia-smi

Mon Apr 29 14:07:38 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.67                 Driver Version: 550.67         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 2060        Off |   00000000:01:00.0  On |                  N/A |
| N/A   44C    P0             20W /   80W |      83MiB /   6144MiB |     32%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
import torch
import torchvision
from torchvision import transforms

# Record the library versions this notebook was run with, one per line.
print(
    "\n".join(
        [
            f"torch version: {torch.__version__}",
            f"torchvision version: {torchvision.__version__}",
        ]
    )
)

torch version: 2.2.2+cu121
torchvision version: 0.17.2+cu121


In [None]:
import os
import sys
from pathlib import Path

# The notebook runs from a directory one level below the project root. The root
# is expected to contain the `src` package and the `data` directory.
project_root = Path(os.getcwd()).parent

# Make `src` importable. Guarding the insert keeps the cell idempotent, so
# re-running it does not keep prepending duplicate entries to sys.path.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# Input images and their segmentation-mask targets.
input_path = project_root / "data/JPEGImages"
target_path = project_root / "data/SegmentationClass"

In [3]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

Device: cuda


In [1]:
# Imported here too so this cell also runs when executed first in a fresh kernel.
import os
import random

import torch

# Passed as `output_size` to the custom Rescale/RandomCrop transforms and used as
# the spatial input size in the model summary.
IMAGE_SIZE = 512
BATCH_SIZE = 16
# Fraction of the image names assigned to the training split; the rest is the test split.
TRAIN_SPLIT = 0.7
# os.cpu_count() returns None when the count cannot be determined, but
# DataLoader needs an int, so fall back to 0 (load in the main process).
NUM_WORKERS = os.cpu_count() or 0

# Seed Python's and torch's global RNGs so weight initialisation and random
# augmentation are reproducible. NOTE(review): the custom transforms may draw
# from either RNG, so both are seeded; this does not make multi-worker
# DataLoader augmentation fully deterministic.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

## 01. Data

In [None]:
import random
from torch.utils.data import DataLoader

import src.data.transforms as transforms_custom
from src.data.dataset import DatasetVOC

# Dedicated seed for the train/test split, so every run and every kernel restart
# assigns the same images to the same split.
SPLIT_SEED = 42

# Training pipeline: rescale, random crop, random horizontal flip, then tensor conversion.
transform = transforms.Compose(
    [
        transforms_custom.Rescale(output_size=IMAGE_SIZE),
        transforms_custom.RandomCrop(output_size=IMAGE_SIZE),
        transforms_custom.RandomHorizontalFlip(p=0.5),
        transforms_custom.ToTensor(),
    ]
)

# TODO: the test set below reuses this random (augmenting) transform; give it a
# deterministic evaluation transform instead.

# Keep only images that have a mask in target_path. The image folder may hold
# more images than there are masks (as in the standard Pascal VOC layout).
# Matching on the file stem lets image and mask extensions differ. The image
# file names themselves are what get passed to DatasetVOC.
mask_stems = {Path(name).stem for name in os.listdir(target_path)}
# Sort before shuffling so the split does not depend on os.listdir's arbitrary order.
names = sorted(
    name for name in os.listdir(input_path) if Path(name).stem in mask_stems
)
if not names:
    raise FileNotFoundError(
        f"No images in {input_path} have a matching mask in {target_path}"
    )

# random.shuffle works in place and returns None, so shuffle the list itself
# rather than assigning its return value.
random.Random(SPLIT_SEED).shuffle(names)

split_index = int(len(names) * TRAIN_SPLIT)
train_names = names[:split_index]
test_names = names[split_index:]

train_dataset = DatasetVOC(
    names=train_names,
    input_path=input_path,
    target_path=target_path,
    transform=transform,
)

test_dataset = DatasetVOC(
    names=test_names,
    input_path=input_path,
    target_path=target_path,
    transform=transform,
)

train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=True,
)

# Evaluation loader: fixed order so metrics are comparable between runs.
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
)

## 02. Model

In [5]:
from src.models.unet.unet import UNet
from torchinfo import summary

# Build the network and put it on the device selected earlier, so the summary
# and the training run use the same device as the data.
model = UNet().to(device)

# Layer-by-layer overview for a single 3-channel IMAGE_SIZE x IMAGE_SIZE input,
# created on the same device as the model.
summary(
    model,
    input_size=(1, 3, IMAGE_SIZE, IMAGE_SIZE),
    device=device,
    verbose=0,
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

## 03. Train

In [None]:
from src.models.train import train

# Train the U-Net built above.
# NOTE(review): only `model` is passed here. The data loaders, loss, optimizer,
# device and epoch count presumably come from train()'s own defaults, and the
# train_dataloader / BATCH_SIZE defined in this notebook are not forwarded.
# Confirm against src/models/train.py.
train(
    model=model,
)