In [1]:
"""
README: DeepLabV3-MobileNetV3 Semantic Segmentation - Training & Evaluation Pipeline

This script provides a full pipeline for training and evaluating a semantic segmentation model
(DeepLabV3 with MobileNetV3 backbone) on a custom dataset using PyTorch and Albumentations.

Key Features:
- Custom Dataset class to load image-mask pairs from specified folders.
- Data augmentations and normalization using Albumentations.
- Model: torchvision.models.segmentation.deeplabv3_mobilenet_v3_large (customizable class count).
- Weighted CrossEntropyLoss for imbalanced datasets.
- Training loop with checkpointing and persistent tracking of training loss (JSON + plot).
- Evaluation with per-class metrics (IoU, Precision, Recall, F1) and saving results as JSON and PNG plots.
- Multi-GPU support via DataParallel.

How to use:
1. Organize your data as follows:
    YOUR_DATASET_DIR/
        train/
            image1.jpg
            image1_mask.png
            ...
        valid/          (same image/mask layout as train/)
        test/           (same image/mask layout as train/)
2. Set the desired number of classes and class weights.
3. Configure paths and hyperparameters as needed.
4. Run the script to train and evaluate your segmentation model.

Requirements:
- torch
- torchvision
- albumentations
- opencv-python
- numpy
- tqdm
- matplotlib

Author: Bahadir Akin Akgul
Date: 13.07.2025
"""

import os
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import cv2
import json
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import gc

# Output folder for checkpoints, metric files and plots (created on first run)
RESULTS_DIR = "YOUR_RESULTS_DIR"
os.makedirs(RESULTS_DIR, exist_ok=True)

# Run on CUDA when PyTorch can see a GPU, otherwise fall back to the CPU
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}, GPU count: {torch.cuda.device_count()}")

# Dataset root and its three split sub-folders
DATA_DIR = "YOUR_DATASET_DIR"
TRAIN_DIR, VALID_DIR, TEST_DIR = (
    os.path.join(DATA_DIR, split) for split in ("train", "valid", "test")
)

# Preprocessing pipeline, applied jointly to the image and its mask:
# resize -> random horizontal flip -> per-channel normalization -> tensor conversion
_pipeline_steps = [
    A.Resize(1024, 768),
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
]
transform = A.Compose(_pipeline_steps, additional_targets={'mask': 'mask'})

class SegmentationDataset(Dataset):
    """Dataset of image/mask pairs stored side by side in one directory.

    Every file in ``img_dir`` whose name ends with ``.jpg`` is treated as an
    input image. Its mask is looked up next to it as ``<stem>_mask.png`` and is
    read as a single-channel (grayscale) image.

    Args:
        img_dir: Directory containing the images and their masks.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries. The
            returned mask must support ``.long()``.
    """

    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        # os.listdir() order is arbitrary; sorting makes index -> file stable
        # across runs and platforms.
        self.images = sorted(f for f in os.listdir(img_dir) if f.endswith(".jpg"))

    def __len__(self):
        """Return the number of images found in ``img_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, align and (optionally) transform the ``idx``-th image/mask pair.

        Returns:
            ``(image, mask)``. Without a transform these are the RGB image array
            and the 2-D mask array as read by OpenCV; with a transform they are
            whatever the transform returns (mask converted with ``.long()``).

        Raises:
            RuntimeError: If the image file cannot be read.
        """
        img_name = self.images[idx]
        img_path = os.path.join(self.img_dir, img_name)
        # Swap only the file extension; str.replace() would also rewrite any
        # ".jpg" that happens to occur in a directory name.
        stem, _ = os.path.splitext(img_path)
        mask_path = stem + "_mask.png"

        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise RuntimeError(f"Image could not be loaded: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        if mask is None:
            # Missing/unreadable mask: warn and fall back to an all-zero mask.
            print(f"Warning: {mask_path} could not be loaded! Image: {image.shape}, Mask: None")
            mask = np.zeros(image.shape[:2], dtype=np.uint8)
        elif mask.shape != image.shape[:2]:
            # Size mismatch: bring the mask to the image size. Nearest-neighbour
            # keeps the label values intact (no interpolated in-between labels).
            mask = cv2.resize(mask, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)

        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"].long()

        return image, mask

# Validation/test data must be preprocessed deterministically: same resize and
# normalization as the training pipeline, but WITHOUT the random horizontal
# flip, so that repeated evaluations of the same weights see identical inputs.
# Keep the resize and normalization values in sync with `transform`.
eval_transform = A.Compose([
    A.Resize(1024, 768),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])

train_dataset = SegmentationDataset(TRAIN_DIR, transform=transform)
valid_dataset = SegmentationDataset(VALID_DIR, transform=eval_transform)
test_dataset = SegmentationDataset(TEST_DIR, transform=eval_transform)

# Only the training loader shuffles and drops the last incomplete batch.
train_loader = DataLoader(train_dataset, batch_size=24, shuffle=True, num_workers=8, drop_last=True, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=24, shuffle=False, num_workers=8, drop_last=False, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=24, shuffle=False, num_workers=8, drop_last=False, pin_memory=True)

# Release cached GPU memory and collect garbage before building the model.
torch.cuda.empty_cache()
gc.collect()

NUM_CLASSES = 4
model = torchvision.models.segmentation.deeplabv3_mobilenet_v3_large(weights=None)
# Replace the final 1x1 classifier convolution so it emits NUM_CLASSES channels.
model.classifier[4] = torch.nn.Conv2d(256, NUM_CLASSES, kernel_size=1)
model = model.to(DEVICE)

if torch.cuda.device_count() > 1:
    print(f"Parallelizing model on {torch.cuda.device_count()} GPUs...")
    model = torch.nn.DataParallel(model)

# One weight per class index; class 0 contributes less to the loss than the others.
class_weights = torch.tensor([0.2, 1.0, 1.0, 1.0]).to(DEVICE)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

def load_checkpoint(model, optimizer, checkpoint_path=os.path.join(RESULTS_DIR, "checkpoint.pth")):
    """Restore model and optimizer state from ``checkpoint_path`` when it exists.

    The checkpoint is expected to be a mapping with the keys
    ``"model_state_dict"``, ``"optimizer_state_dict"`` and ``"epoch"``. When the
    model is wrapped in ``torch.nn.DataParallel`` the weights are loaded into
    the wrapped module.

    Args:
        model: Model (optionally DataParallel-wrapped) to load weights into.
        optimizer: Optimizer whose state is restored in place.
        checkpoint_path: Location of the checkpoint file.

    Returns:
        ``(model, optimizer, start_epoch)``; ``start_epoch`` is the epoch value
        stored in the checkpoint, or ``0`` when no checkpoint file is found.
    """
    # Guard clause: nothing to restore.
    if not os.path.exists(checkpoint_path):
        print("No checkpoint found. Starting training from scratch.")
        return model, optimizer, 0

    state = torch.load(checkpoint_path, map_location=DEVICE)

    # Weights are loaded into the underlying module when DataParallel is used.
    if isinstance(model, torch.nn.DataParallel):
        target = model.module
    else:
        target = model
    target.load_state_dict(state["model_state_dict"])
    optimizer.load_state_dict(state["optimizer_state_dict"])

    resume_epoch = state["epoch"]
    print(f"Checkpoint loaded. Continuing from epoch {resume_epoch}.")
    return model, optimizer, resume_epoch

def train_model(model, train_loader, valid_loader, optimizer, criterion, start_epoch=0, epochs=100, checkpoint_path=os.path.join(RESULTS_DIR, "checkpoint.pth")):
    """Train ``model`` for epochs ``start_epoch`` .. ``epochs - 1`` and track the loss.

    After every epoch a checkpoint (epoch counter, model weights, optimizer
    state) is written to ``checkpoint_path`` and the per-epoch mean training
    loss history is written to ``train_losses.json`` in ``RESULTS_DIR``. When
    the loop finishes, the history is plotted to ``train_loss.png``.

    Args:
        model: Model (optionally DataParallel-wrapped) returning a mapping with
            an ``"out"`` entry.
        train_loader: Iterable of ``(images, masks)`` batches; must support ``len()``.
        valid_loader: Currently unused; accepted so existing callers keep working.
        optimizer: Optimizer stepping the model parameters.
        criterion: Loss callable invoked as ``criterion(outputs, masks)``.
        start_epoch: Number of epochs already completed (0 for a fresh run).
        epochs: Total number of epochs to reach.
        checkpoint_path: Where the per-epoch checkpoint is stored.

    Raises:
        ValueError: If there are epochs left to train but ``train_loader``
            yields no batches.
    """
    if start_epoch < epochs and len(train_loader) == 0:
        # Otherwise the mean-loss computation below would divide by zero.
        raise ValueError(
            "train_loader yields no batches; with drop_last=True the dataset "
            "must contain at least one full batch."
        )

    # Load train_losses if available
    train_losses = []
    loss_file = os.path.join(RESULTS_DIR, "train_losses.json")
    if os.path.exists(loss_file):
        with open(loss_file, "r") as f:
            train_losses = json.load(f)
    # Keep only the epochs that precede the resume point, so a restart from an
    # earlier epoch (or from scratch) does not append to a stale history.
    train_losses = train_losses[:start_epoch]

    for epoch in range(start_epoch, epochs):
        model.train()
        running_loss = 0.0
        for images, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, masks = images.to(DEVICE), masks.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)["out"]
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        avg_loss = running_loss / len(train_loader)
        train_losses.append(avg_loss)
        print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {avg_loss:.4f}")

        # Save checkpoint (unwrapped weights, so it loads with or without DataParallel)
        checkpoint = {
            "epoch": epoch + 1,
            "model_state_dict": model.module.state_dict() if isinstance(model, torch.nn.DataParallel) else model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }
        # Write to a temporary file first and swap it in, so an interruption
        # during the write cannot corrupt the existing checkpoint.
        tmp_checkpoint_path = checkpoint_path + ".tmp"
        torch.save(checkpoint, tmp_checkpoint_path)
        os.replace(tmp_checkpoint_path, checkpoint_path)
        print(f"Checkpoint saved: {checkpoint_path}")

        # Save train_losses
        with open(loss_file, "w") as f:
            json.dump(train_losses, f)

    # Plot training loss
    plt.figure()
    plt.plot(range(1, 1+len(train_losses)), train_losses, marker='o')
    plt.title("Train Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.grid()
    plt.savefig(os.path.join(RESULTS_DIR, "train_loss.png"))
    plt.close()

def evaluate_model(model, dataloader, num_classes=NUM_CLASSES):
    """Compute per-class precision, recall, F1 and IoU over ``dataloader``.

    True-positive, false-positive and false-negative pixel counts are
    accumulated over the WHOLE dataloader and the metrics are derived once from
    those totals. Averaging per-batch ratios instead would make the result
    depend on batch size and on how classes happen to be distributed across
    batches.

    Side effects: puts the model in eval mode, writes ``metrics_epoch.json``
    and one ``<metric>_per_class.png`` bar chart per metric into
    ``RESULTS_DIR``.

    Args:
        model: Model returning a mapping with an ``"out"`` entry of per-class
            scores with the class dimension at index 1.
        dataloader: Iterable of ``(images, masks)`` batches.
        num_classes: Number of class indices to evaluate (``0 .. num_classes-1``).

    A class that is neither predicted nor present anywhere in the data is
    reported as ``None`` in the JSON file and left out of the plots.
    """
    model.eval()
    # Row = class index; columns = [TP, FP, FN] pixel totals over the dataset.
    counts = torch.zeros((num_classes, 3), dtype=torch.long, device=DEVICE)

    with torch.no_grad():
        for images, masks in tqdm(dataloader, desc="Evaluating"):
            images, masks = images.to(DEVICE), masks.to(DEVICE)
            outputs = model(images)["out"]
            preds = torch.argmax(outputs, dim=1)

            for cls in range(num_classes):
                pred_cls = preds == cls
                true_cls = masks == cls
                counts[cls, 0] += (pred_cls & true_cls).sum()
                counts[cls, 1] += (pred_cls & ~true_cls).sum()
                counts[cls, 2] += (~pred_cls & true_cls).sum()

    results = {"precision": [], "recall": [], "f1": [], "iou": []}
    # A single device-to-host transfer for all counts.
    for tp, fp, fn in counts.tolist():
        if tp + fp + fn == 0:
            # Class never predicted and never present: metrics are undefined.
            for values in results.values():
                values.append(None)
            continue
        results["precision"].append(tp / (tp + fp) if tp + fp else 0.0)
        results["recall"].append(tp / (tp + fn) if tp + fn else 0.0)
        results["f1"].append(2 * tp / (2 * tp + fp + fn))
        results["iou"].append(tp / (tp + fp + fn))

    with open(os.path.join(RESULTS_DIR, "metrics_epoch.json"), "w") as f:
        json.dump(results, f, indent=2)

    # Plot metrics for classes present in the set
    classes = [i for i, value in enumerate(results["iou"]) if value is not None]
    for metric_name in ["precision", "recall", "f1", "iou"]:
        plt.figure()
        values = [results[metric_name][i] for i in classes]
        plt.bar([str(c) for c in classes], values)
        plt.title(f"{metric_name.upper()} per Class")
        plt.xlabel("Class")
        plt.ylabel(metric_name.upper())
        plt.ylim(0, 1.05)
        plt.grid(True)
        plt.savefig(os.path.join(RESULTS_DIR, f"{metric_name}_per_class.png"))
        plt.close()

    print("Evaluation results saved.")

# Resume from the last checkpoint (if any), train up to 100 epochs, then
# evaluate on the validation loader.
model, optimizer, start_epoch = load_checkpoint(model, optimizer)
train_model(
    model,
    train_loader,
    valid_loader,
    optimizer,
    criterion,
    start_epoch=start_epoch,
    epochs=100,
)
evaluate_model(model, valid_loader)

# Store the weights of the underlying module so the file loads without DataParallel.
if isinstance(model, torch.nn.DataParallel):
    model_to_save = model.module
else:
    model_to_save = model
final_weights_path = os.path.join(RESULTS_DIR, "trained_model.pth")
torch.save(model_to_save.state_dict(), final_weights_path)
print("Model saved successfully!")


Using device: cuda, GPU count: 2
2 GPU ile model paralelleştiriliyor...
Checkpoint yüklendi. Eğitim 100. epoch'tan devam ediyor...


Evaluating: 100%|██████████| 56/56 [00:23<00:00,  2.36it/s]


Değerlendirme sonuçları kaydedildi.
Model başarıyla kaydedildi!
