In [3]:
"""
README: Segmentation Dataset Split, Symlink, and Training Utilities

This script provides utility functions for:
- Splitting a dataset into train/validation/test splits (with symlinks for images/masks),
- Loading segmentation datasets with PyTorch,
- Training a segmentation model and saving the loss history and checkpoints,
- Plotting training loss.

All code is fully modular, robust to data issues, and includes informative error handling.

Requirements:
- torch
- torchvision
- albumentations
- opencv-python
- numpy
- matplotlib
- tqdm

Author: Bahadir Akin Akgul
Date: 13.07.2025
"""

import os
import random
import torch
import json
import gc
import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torchvision

# 🔧 Dataset symlink creation with train/valid/test split
def create_split_with_test_symlink_dataset(original_dataset_dir, new_dataset_dir, ratio=(0.7, 0.3), seed=42):
    """Re-split an image/mask dataset into train/valid/test folders of symlinks.

    Every ``*.jpg`` in ``<original_dataset_dir>/train`` and
    ``<original_dataset_dir>/valid`` is pooled, shuffled with ``seed`` and
    distributed over ``<new_dataset_dir>/{train,valid,test}``. For each image a
    symlink to the image and a symlink to its mask (``<stem>_mask.png`` in the
    image's folder) are created in the target split folder.

    Args:
        original_dataset_dir: Root containing the source ``train`` and
            ``valid`` folders.
        new_dataset_dir: Root under which ``train``/``valid``/``test`` are
            created. Files and symlinks already directly inside those folders
            are removed first (subdirectories are left alone).
        ratio: Either ``(train, rest)``, where the remainder is split evenly
            between valid and test (test gets the extra item on odd counts),
            or ``(train, valid, test)``. Entries must be non-negative and sum
            to 1 (within 1e-6, to tolerate float rounding).
        seed: Seed for a local RNG used to shuffle the pooled image list. The
            global ``random`` state is not modified.

    Returns:
        dict mapping ``"train"``/``"valid"``/``"test"`` to the number of image
        symlinks actually created.

    Raises:
        ValueError: If ``ratio`` has the wrong length, a negative entry, or
            does not sum to 1.
    """
    ratio = tuple(ratio)
    if len(ratio) not in (2, 3):
        raise ValueError("ratio must be (train, rest) or (train, valid, test)")
    # Tolerance instead of == 1.0: e.g. sum((0.7, 0.2, 0.1)) == 0.9999999999999999.
    if any(r < 0 for r in ratio) or abs(sum(ratio) - 1.0) > 1e-6:
        raise ValueError("Total split ratio must be 1.0")

    src_root = Path(original_dataset_dir)
    dst_root = Path(new_dataset_dir)

    # Sort first: glob order depends on the filesystem, so without this the
    # same seed could produce different splits on different machines.
    all_jpgs = sorted(src_root.joinpath("train").glob("*.jpg")) + \
               sorted(src_root.joinpath("valid").glob("*.jpg"))
    # A private Random(seed) yields the same shuffle as random.seed(seed) +
    # random.shuffle, without resetting the caller's global RNG state.
    random.Random(seed).shuffle(all_jpgs)

    total = len(all_jpgs)
    train_count = int(total * ratio[0])
    rest_count = total - train_count
    if len(ratio) == 3:
        valid_count = min(int(total * ratio[1]), rest_count)
    else:
        valid_count = rest_count // 2
    test_count = rest_count - valid_count

    splits = {
        "train": all_jpgs[:train_count],
        "valid": all_jpgs[train_count:train_count + valid_count],
        "test": all_jpgs[train_count + valid_count:],
    }

    for split in splits:
        split_dir = dst_root / split
        split_dir.mkdir(parents=True, exist_ok=True)
        for entry in split_dir.iterdir():
            # Check is_symlink() first: a dangling link reports is_file() == False.
            if entry.is_symlink() or entry.is_file():
                entry.unlink()

    created = {}
    for split_name, img_list in splits.items():
        linked = 0
        for img_path in img_list:
            mask_path = img_path.with_name(img_path.stem + "_mask.png")
            try:
                os.symlink(img_path.resolve(), dst_root / split_name / img_path.name)
            except (OSError, NotImplementedError) as e:
                # e.g. the same file name exists in both source train/ and valid/.
                print(f"[SYMLINK ERROR] {e}")
                continue
            linked += 1

            if not mask_path.exists():
                # Keep the image, but don't create a dangling mask link; a
                # loader will simply find no mask for this image.
                print(f"[MASK WARNING] Mask not found, no mask link created: {mask_path}")
                continue
            try:
                os.symlink(mask_path.resolve(), dst_root / split_name / mask_path.name)
            except (OSError, NotImplementedError) as e:
                print(f"[SYMLINK ERROR] {e}")
        created[split_name] = linked

    print(f"[✓] {new_dataset_dir} → Train: {created['train']}, Valid: {created['valid']}, Test: {created['test']}")
    return created

# 📦 Robust Segmentation Dataset class
class SegmentationDataset(Dataset):
    """Map-style dataset pairing ``*.jpg`` images with ``<stem>_mask.png`` masks.

    Images are read with OpenCV and converted BGR -> RGB. Masks are read as
    single-channel grayscale. If a mask is missing or its height/width differs
    from the image, a warning is printed and an all-zero mask is used instead.
    If a sample fails to load entirely, the following samples are tried
    (wrapping around) so one corrupt file does not crash a DataLoader.
    """

    def __init__(self, img_dir, transform=None):
        """
        Args:
            img_dir: Directory scanned (non-recursively) for ``*.jpg`` files.
            transform: Optional callable invoked as
                ``transform(image=..., mask=...)`` that returns a dict with
                ``"image"`` and ``"mask"``. ``.long()`` is called on the returned
                mask, so presumably the pipeline ends in ``ToTensorV2``; confirm.
        """
        self.img_dir = img_dir
        self.transform = transform
        # Sorted so the index -> file mapping is stable across runs/filesystems.
        self.images = sorted(f for f in os.listdir(img_dir) if f.endswith(".jpg"))

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _mask_path_for(img_path):
        """Return ``<img_path without extension>_mask.png``.

        Only the file extension is replaced, so a ``.jpg`` elsewhere in the
        directory part of the path is left intact.
        """
        return os.path.splitext(img_path)[0] + "_mask.png"

    def _load(self, idx):
        """Load sample ``idx``; raise if the image itself cannot be read."""
        img_path = os.path.join(self.img_dir, self.images[idx])
        mask_path = self._mask_path_for(img_path)

        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"[IMG ERROR] Could not read file: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None or image.shape[:2] != mask.shape:
            print(f"[WARN] Mask error: {mask_path}")
            mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)

        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            return transformed["image"], transformed["mask"].long()

        # Without a transform the image stays a numpy array (H, W, 3) in RGB order.
        return image, torch.tensor(mask).long()

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for ``idx``, falling forward on load errors.

        Raises:
            IndexError: If ``idx`` is out of range (needed so plain iteration stops).
            RuntimeError: If no sample in the dataset can be loaded.
        """
        n = len(self.images)
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")

        # Best-effort: try idx, idx+1, ... (wrapping), at most one full pass,
        # instead of unbounded recursion when every sample is broken.
        for offset in range(n):
            candidate = (idx + offset) % n
            try:
                return self._load(candidate)
            except Exception as e:  # deliberate catch-all: any failure -> skip sample
                print(f"[DATA ERROR] {e}")
        raise RuntimeError(f"No loadable sample found in {self.img_dir}")

# 🎯 Training function with loss recording and checkpointing
def train_model(model, train_loader, optimizer, criterion, start_epoch=0, epochs=100, checkpoint_path="checkpoint.pth", loss_log_path="train_losses.json", device=None):
    """Train ``model`` for epochs ``start_epoch .. epochs-1``, logging and checkpointing each epoch.

    After every epoch the mean training loss is appended to a JSON list at
    ``loss_log_path``, and a checkpoint dict with ``epoch`` (1-based, i.e. the
    next ``start_epoch``), ``model_state_dict`` and ``optimizer_state_dict`` is
    written to ``checkpoint_path``.

    Args:
        model: Module whose forward returns a mapping with an ``"out"`` entry
            (as the indexing below requires). DataParallel/DDP wrappers are
            unwrapped when saving the state dict.
        train_loader: Iterable of ``(images, masks)`` batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Loss callable ``criterion(outputs, masks)``.
        start_epoch: Epoch index to resume from (0 = fresh run). Existing log
            entries at or after this index are discarded.
        epochs: Exclusive upper bound on the epoch index.
        checkpoint_path: Where the checkpoint is saved (overwritten every epoch).
        loss_log_path: JSON file holding the list of per-epoch losses.
        device: Device to move batches to. If None, a module-level ``DEVICE``
            global is used when defined, else the device of the model's first
            parameter (CPU if it has none).
    """
    # os.path.dirname("train_losses.json") == "" and os.makedirs("") raises,
    # so only create parent directories that are actually named.
    for path in (loss_log_path, checkpoint_path):
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    if device is None:
        # Backward compatibility: earlier versions read a notebook-level DEVICE global.
        device = globals().get("DEVICE")
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    all_losses = []
    if os.path.exists(loss_log_path):
        with open(loss_log_path, "r") as f:
            all_losses = json.load(f)
        # Drop entries for epochs that will be re-run (or a stale log when
        # starting fresh) so that entry i always belongs to epoch i+1.
        del all_losses[start_epoch:]

    wrapper_types = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)

    for epoch in range(start_epoch, epochs):
        model.train()
        running_loss = 0.0
        used_batches = 0
        for images, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            # Single-sample batches are skipped, presumably to avoid BatchNorm
            # errors in train mode. NOTE(review): confirm intent.
            if images.size(0) < 2:
                continue
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)["out"]
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            used_batches += 1

        # Average only over batches actually trained on; NaN if all were skipped.
        avg_loss = running_loss / used_batches if used_batches else float("nan")
        all_losses.append(avg_loss)
        print(f"[{epoch+1}/{epochs}] Train Loss: {avg_loss:.4f}")

        with open(loss_log_path, "w") as f:
            json.dump(all_losses, f)

        checkpoint = {
            "epoch": epoch + 1,
            "model_state_dict": model.module.state_dict() if isinstance(model, wrapper_types) else model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }
        # Write to a temp file then atomically replace, so an interruption
        # mid-save never leaves a truncated checkpoint behind.
        tmp_path = f"{checkpoint_path}.tmp"
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, checkpoint_path)

# 📈 Loss plotting function
def plot_losses(loss_log_path="train_losses.json"):
    """Plot the per-epoch training losses stored as a JSON list at *loss_log_path*.

    Epochs are numbered from 1 on the x-axis. When the file does not exist, a
    message is printed and no figure is created.
    """
    if not os.path.exists(loss_log_path):
        print("Loss file not found!")
        return

    with open(loss_log_path, "r") as log_file:
        epoch_losses = json.load(log_file)

    epoch_numbers = range(1, len(epoch_losses) + 1)

    plt.figure(figsize=(10, 5))
    plt.plot(epoch_numbers, epoch_losses, marker='o')
    plt.title("Training Loss per Epoch")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

# 🚀 Main
if __name__ == "__main__":
    # Placeholder locations: replace with the real dataset paths before running.
    base_dataset = "YOUR_ORIGINAL_DATASET_DIR/road-tr-od-ss"
    new_dataset = "YOUR_NEW_DATASET_DIR/road-tr-od-ss-95-5"

    # 95 % of images go to train; the remaining 5 % is halved between valid and test.
    create_split_with_test_symlink_dataset(
        original_dataset_dir=base_dataset,
        new_dataset_dir=new_dataset,
        ratio=(0.95, 0.05),
    )


[✓] yolo-seg-11042025/road-tr-od-ss-95-5 → Train: 7235, Valid: 190, Test: 191
