# Road Line Segmentation (Simplified)

This notebook trains a small U-Net model for road line segmentation with a concise training loop.

In [None]:
import random
from pathlib import Path
from typing import List, Tuple

import numpy as np
from PIL import Image

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import functional as TF
from torchvision.transforms import InterpolationMode

In [None]:
# --- Filesystem layout -------------------------------------------------------
# Root folder that list_pairs() reads "<split>/images" and "<split>/masks" from.
DATA_ROOT = Path("dataset")
# Output folder for checkpoints; created eagerly so saving cannot fail on a
# missing directory.
MODEL_DIR = Path("Model")
MODEL_DIR.mkdir(parents=True, exist_ok=True)
MODEL_PATH = MODEL_DIR.joinpath("roadline_unet_best.pth")

# --- Training hyperparameters ------------------------------------------------
# Target size every image and mask is resized to before batching.
IMAGE_SIZE: Tuple[int, int] = (256, 256)
BATCH_SIZE = 4
NUM_EPOCHS = 10
LEARNING_RATE = 0.001
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def set_seed(seed: int = 42) -> None:
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    generator; when CUDA is available, all CUDA devices are seeded as well.

    Args:
        seed: Value handed to each generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def list_pairs(split: str) -> List[Tuple[Path, Path]]:
    """Return (image_path, mask_path) pairs for one dataset split.

    Files are read from ``DATA_ROOT/<split>/images`` and
    ``DATA_ROOT/<split>/masks``. Both listings are sorted by path and paired
    by position, so image and mask file names must sort in the same order.

    Only regular, non-hidden files are considered: hidden entries such as
    ``.gitkeep`` or ``.DS_Store`` and directories whose name contains a dot
    also match ``*.*`` and would otherwise shift the positional pairing or
    trigger a spurious count mismatch.

    Args:
        split: Name of the split sub-directory (for example ``"train"``).

    Returns:
        List of ``(image_path, mask_path)`` tuples; empty when the split
        contains no matching files.

    Raises:
        ValueError: If the number of images differs from the number of masks.
    """

    def _data_files(directory: Path) -> List[Path]:
        # Keep the original "*.*" pattern, but drop directories and dotfiles.
        return sorted(
            path
            for path in directory.glob("*.*")
            if path.is_file() and not path.name.startswith(".")
        )

    image_files = _data_files(DATA_ROOT / split / "images")
    mask_files = _data_files(DATA_ROOT / split / "masks")
    if len(image_files) != len(mask_files):
        raise ValueError(
            f"Found {len(image_files)} images but {len(mask_files)} masks for split '{split}'."
        )
    return list(zip(image_files, mask_files))

In [None]:
class RoadLineDataset(Dataset):
    """Dataset yielding (image, mask) tensors for binary road-line segmentation.

    Each item is loaded from disk on access, optionally flipped horizontally,
    resized to ``image_size`` and converted to tensors.
    """

    def __init__(self, pairs: List[Tuple[Path, Path]], image_size: Tuple[int, int], augment: bool = False) -> None:
        """Store the sample list and preprocessing options.

        Args:
            pairs: ``(image_path, mask_path)`` tuples, one per sample.
            image_size: Target size passed to ``TF.resize`` for image and mask.
            augment: When true, each sample is horizontally flipped with
                probability 0.5.
        """
        self.pairs = pairs
        self.image_size = image_size
        self.augment = augment

    def __len__(self) -> int:
        """Return the number of (image, mask) pairs."""
        return len(self.pairs)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load, preprocess and return the sample at ``index``.

        Returns:
            A tuple ``(image, mask)``: ``image`` is the RGB image as a float
            tensor normalized with mean 0.5 and std 0.5 per channel; ``mask``
            is an int64 tensor holding 1 where the mask pixel is non-zero and
            0 elsewhere.
        """
        image_path, mask_path = self.pairs[index]
        # convert() yields an independent in-memory image, so the source file
        # can be closed right away instead of lingering until garbage
        # collection (which matters for long runs and multi-frame formats).
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert("L")

        # Apply the same flip to image and mask so they stay aligned.
        if self.augment and random.random() < 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)

        # Nearest-neighbour resizing keeps the mask free of interpolated values.
        image = TF.resize(image, self.image_size, interpolation=InterpolationMode.BILINEAR)
        mask = TF.resize(mask, self.image_size, interpolation=InterpolationMode.NEAREST)

        image = TF.to_tensor(image)
        image = TF.normalize(image, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

        # Any non-zero mask pixel counts as the positive (road line) class.
        mask_array = (np.array(mask) > 0).astype("int64")
        mask_tensor = torch.from_numpy(mask_array)
        return image, mask_tensor

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution -> batch norm -> ReLU stages.

    The convolutions use padding 1 and no bias, so the spatial size of the
    input is preserved and only the channel count changes.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Build the block.

        Args:
            in_channels: Channels of the incoming feature map.
            out_channels: Channels produced by both convolution stages.
        """
        super().__init__()
        stages = []
        # The first stage maps in -> out channels, the second out -> out.
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.block = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through both stages and return the result."""
        return self.block(x)


class Down(nn.Module):
    """Encoder step: 2x2 max pooling followed by a ``DoubleConv``."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Build the step.

        Args:
            in_channels: Channels of the incoming feature map.
            out_channels: Channels produced by the convolution block.
        """
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        # Pooling comes first so the convolutions run at the reduced resolution.
        self.block = nn.Sequential(pool, conv)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample ``x`` and apply the convolution block."""
        return self.block(x)


class Up(nn.Module):
    """Decoder step: upsample, merge with a skip connection, then convolve."""

    def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
        """Build the step.

        Args:
            in_channels: Channels of the feature map being upsampled.
            skip_channels: Channels of the skip-connection feature map.
            out_channels: Channels after upsampling and after the final
                convolution block.
        """
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        merged_channels = out_channels + skip_channels
        self.conv = DoubleConv(merged_channels, out_channels)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Upsample ``x1``, align it with skip tensor ``x2`` and fuse both.

        Args:
            x1: Feature map from the previous (coarser) decoder stage.
            x2: Skip-connection feature map from the encoder.

        Returns:
            The fused feature map with ``out_channels`` channels.
        """
        upsampled = self.up(x1)
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        if (gap_h, gap_w) != (0, 0):
            # Pad (left, right, top, bottom) so the upsampled map matches the
            # skip tensor; any odd remainder goes to the right/bottom side.
            left, top = gap_w // 2, gap_h // 2
            upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])
        # Skip features come first along the channel axis.
        return self.conv(torch.cat([x2, upsampled], dim=1))


class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to output channels."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Build the projection.

        Args:
            in_channels: Channels of the incoming feature map.
            out_channels: Channels of the produced map.
        """
        super().__init__()
        projection = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv = projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the 1x1 convolution to ``x``."""
        projected = self.conv(x)
        return projected


class UNet(nn.Module):
    """U-Net with four encoder and four decoder stages.

    The encoder doubles the channel count at every stage, starting from
    ``base_channels``; the decoder mirrors it and merges the matching encoder
    output at each step. The final layer emits ``num_classes`` channels.
    """

    def __init__(self, in_channels: int = 3, num_classes: int = 2, base_channels: int = 32) -> None:
        """Build the network.

        Args:
            in_channels: Channels of the input image.
            num_classes: Number of output channels (one per class).
            base_channels: Width of the first encoder stage.
        """
        super().__init__()
        # Channel widths for the five resolution levels.
        c1, c2, c4, c8, c16 = (base_channels * factor for factor in (1, 2, 4, 8, 16))

        # Encoder.
        self.inc = DoubleConv(in_channels, c1)
        self.down1 = Down(c1, c2)
        self.down2 = Down(c2, c4)
        self.down3 = Down(c4, c8)
        self.down4 = Down(c8, c16)

        # Decoder: (input channels, skip channels, output channels).
        self.up1 = Up(c16, c8, c8)
        self.up2 = Up(c8, c4, c4)
        self.up3 = Up(c4, c2, c2)
        self.up4 = Up(c2, c1, c1)

        self.outc = OutConv(c1, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the per-class output map for input batch ``x``."""
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        x = self.down4(skip4)

        # Walk back up, pairing each decoder stage with its encoder output.
        decoder_stages = (self.up1, self.up2, self.up3, self.up4)
        skips = (skip4, skip3, skip2, skip1)
        for stage, skip in zip(decoder_stages, skips):
            x = stage(x, skip)
        return self.outc(x)

In [None]:
def compute_iou(
    logits: torch.Tensor,
    masks: torch.Tensor,
    positive_class: int = 1,
    empty_value: float = 0.0,
) -> float:
    """Compute intersection-over-union for one class over a whole batch.

    Predictions are taken as the argmax of ``logits`` along dimension 1 and
    compared with ``masks`` for the chosen class. Intersection and union are
    accumulated over every element of the batch, so the result is a single
    batch-level score rather than a mean of per-image scores.

    Args:
        logits: Class scores with the class axis at dimension 1.
        masks: Integer class labels matching the shape of the argmax result.
        positive_class: Class index whose IoU is measured. Defaults to 1.
        empty_value: Score returned when neither the prediction nor the mask
            contains ``positive_class`` (the union is empty). The default of
            0.0 keeps the historical behaviour; pass 1.0 to treat an
            "empty vs. empty" batch as a perfect match.

    Returns:
        The IoU as a Python float, or ``empty_value`` when the union is empty.
    """
    pred_positive = torch.argmax(logits, dim=1) == positive_class
    true_positive = masks == positive_class
    intersection = torch.logical_and(pred_positive, true_positive).float().sum()
    union = torch.logical_or(pred_positive, true_positive).float().sum()
    if union.item() == 0:
        return float(empty_value)
    return float((intersection / union).item())


def train_one_epoch(model: nn.Module, loader: DataLoader, optimizer: Adam, criterion: nn.Module) -> Tuple[float, float]:
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Network to optimize; switched to training mode.
        loader: Yields ``(images, masks)`` batches.
        optimizer: Optimizer updating the model parameters.
        criterion: Loss applied to ``(logits, masks)``.

    Returns:
        ``(mean_loss, mean_iou)``, both weighted by batch size. When the
        loader yields nothing, both values are 0.0.
    """
    model.train()
    loss_sum, iou_sum, seen = 0.0, 0.0, 0

    for images, masks in loader:
        images, masks = images.to(DEVICE), masks.to(DEVICE)

        # Standard optimization step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller last batch does not skew the mean.
        count = images.shape[0]
        loss_sum += loss.item() * count
        iou_sum += compute_iou(logits.detach(), masks) * count
        seen += count

    # Guard against division by zero for an empty loader.
    seen = max(seen, 1)
    return loss_sum / seen, iou_sum / seen


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, criterion: nn.Module) -> Tuple[float, float]:
    """Evaluate ``model`` over ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        loader: Yields ``(images, masks)`` batches.
        criterion: Loss applied to ``(logits, masks)``.

    Returns:
        ``(mean_loss, mean_iou)``, both weighted by batch size. When the
        loader yields nothing, both values are 0.0.
    """
    model.eval()
    loss_sum, iou_sum, seen = 0.0, 0.0, 0

    for images, masks in loader:
        images, masks = images.to(DEVICE), masks.to(DEVICE)
        logits = model(images)

        # Weight by batch size so a smaller last batch does not skew the mean.
        count = images.shape[0]
        loss_sum += criterion(logits, masks).item() * count
        iou_sum += compute_iou(logits, masks) * count
        seen += count

    # Guard against division by zero for an empty loader.
    seen = max(seen, 1)
    return loss_sum / seen, iou_sum / seen

In [None]:
# Seed all generators first so weight initialization and shuffling repeat.
set_seed()

# Collect (image, mask) path pairs for each split.
train_pairs = list_pairs("train")
val_pairs = list_pairs("valid")

# Only the training set is augmented and shuffled.
train_dataset = RoadLineDataset(train_pairs, IMAGE_SIZE, augment=True)
val_dataset = RoadLineDataset(val_pairs, IMAGE_SIZE, augment=False)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Model, loss and optimizer.
model = UNet().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
# Train for NUM_EPOCHS and keep the weights with the highest validation IoU.
#
# The first completed epoch always writes a checkpoint. With a strict
# "val_iou > 0.0" test alone, a run whose validation IoU never rises above
# zero would finish without saving anything while still reporting MODEL_PATH
# (which might then be missing or left over from an earlier run).
best_val_iou = 0.0
checkpoint_saved = False
for epoch in range(1, NUM_EPOCHS + 1):
    train_loss, train_iou = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_iou = evaluate(model, val_loader, criterion)
    if not checkpoint_saved or val_iou > best_val_iou:
        best_val_iou = val_iou
        torch.save(model.state_dict(), MODEL_PATH)
        checkpoint_saved = True
        print(f"Saved checkpoint to {MODEL_PATH}")
    print(
        f"Epoch {epoch:02d} | train_loss={train_loss:.4f} iou={train_iou:.3f} "
        f"| val_loss={val_loss:.4f} iou={val_iou:.3f}"
    )

print(f"Best validation IoU: {best_val_iou:.3f}")
# Only point at the checkpoint file when this run actually wrote it.
if checkpoint_saved:
    print(f"Best model checkpoint: {MODEL_PATH}")
else:
    print("No checkpoint was saved because no epochs were run.")
