In [None]:
import os
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader, Subset, random_split
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

import albumentations as A
from albumentations.pytorch import ToTensorV2

import matplotlib.pyplot as plt
from tqdm import tqdm
import scipy.ndimage as ndimage
import cv2
import supervisely as sly
import segmentation_models_pytorch as smp

# Configuration — adjust these values before running the notebook
class Config:
    """Central settings read by the rest of the notebook (paths, sizes, model choice)."""
    # Directory that SegmentationDataset lists for '.png' images
    IMAGE_DIR = "/mnt/c/dataset_folder/ds/img/"
    # Directory holding the per-image '<image file name>.json' annotation files
    MASK_DIR = "/mnt/c/dataset_folder/ds/ann/"
    # Directory the best checkpoint is written to and later loaded from
    MODEL_DIR = "/mnt/c/saved_models/"
    # Project meta file, looked up inside MASK_DIR
    META_PATH = os.path.join(MASK_DIR, "meta.json")
    IMG_SIZE = (256, 256)  # unpacked into A.Resize(...)
    BATCH = 10  # batch size used by both DataLoaders
    EPOCHS = 50  # number of passes of the training loop
    ENCODER = 'resnet34'  # forwarded as encoder_name in get_model()
    ENCODER_WEIGHTS = 'imagenet'  # forwarded as encoder_weights in get_model()
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  # falls back to CPU when CUDA is unavailable
    SELECTED_MODEL = 'FPN'  # Options: 'Unet', 'Unet++', 'FPN', 'DeepLabV3', 'PSPNet', 'Segformer'

# Make sure the checkpoint directory exists before training tries to save into it
os.makedirs(Config.MODEL_DIR, exist_ok=True)


In [None]:
# Albumentations transforms applied to every image/mask pair
transform = A.Compose([
    A.Resize(*Config.IMG_SIZE),  # bring every sample to the configured size
    A.Normalize(),  # library defaults; presumably the ImageNet mean/std that denormalize() undoes — verify
    ToTensorV2()  # converts the arrays to torch tensors
])

# Supervisely project meta, passed to sly.Annotation.from_json when parsing annotations
with open(Config.META_PATH) as f:
    meta = sly.ProjectMeta.from_json(json.load(f))

# Dataset class
class SegmentationDataset(Dataset):
    """Image / binary-mask pairs built from Supervisely annotation files.

    Every '.png' file found in ``image_dir`` is one sample. Its annotation is
    looked up as ``<mask_dir>/<image file name>.json``; when that file is
    missing the mask stays all-zero.
    """

    def __init__(self, image_dir, mask_dir, meta, transform=None):
        """Store the folders, project meta and transform, and index the images.

        Args:
            image_dir: folder that is listed for '.png' files.
            mask_dir: folder containing the '<image file name>.json' annotations.
            meta: project meta handed to ``sly.Annotation.from_json``.
            transform: optional callable invoked as ``transform(image=..., mask=...)``.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.meta = meta
        self.transform = transform
        # Sorted so that an index always refers to the same file across runs
        png_names = [entry for entry in os.listdir(image_dir) if entry.endswith('.png')]
        png_names.sort()
        self.images = png_names

    def __len__(self):
        """Number of indexed '.png' images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        With a transform, the mask is returned as a float tensor with a leading
        channel axis, scaled from 0/255 to 0/1. Without one, the raw image and
        the uint8 0/255 mask are returned unchanged.
        """
        filename = self.images[idx]
        annotation_path = os.path.join(self.mask_dir, filename + '.json')

        image = sly.image.read(os.path.join(self.image_dir, filename))
        # Start from an empty mask with the image's height and width
        mask = np.zeros(image.shape[:2], dtype=np.uint8)

        if os.path.exists(annotation_path):
            with open(annotation_path) as fh:
                annotation = sly.Annotation.from_json(json.load(fh), self.meta)
            for label in annotation.labels:
                # Best effort: a label that cannot be drawn is reported and skipped
                try:
                    label.draw(mask, color=255)
                except Exception as e:
                    print(f"Error drawing label on {filename}: {e}")

        if not self.transform:
            return image, mask

        transformed = self.transform(image=image, mask=mask)
        # [H, W] 0/255 mask -> [1, H, W] float mask in 0..1
        scaled_mask = transformed["mask"].unsqueeze(0).float() / 255.0
        return transformed["image"], scaled_mask

# Build the full dataset from the configured image / annotation folders
full_dataset = SegmentationDataset(Config.IMAGE_DIR, Config.MASK_DIR, meta, transform=transform)

# Fixed train/validation split (80 % / 20 %), driven by a generator seeded with 42
generator = torch.Generator().manual_seed(42)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size  # remainder, so the two sizes always add up
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size], generator=generator)

# Only the training loader shuffles
train_loader = DataLoader(train_dataset, batch_size=Config.BATCH, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=Config.BATCH)


In [None]:
# Model loading
def get_model(name):
    """Build the segmentation model registered under ``name`` and move it to Config.DEVICE.

    Args:
        name: one of 'Unet', 'Unet++', 'FPN', 'DeepLabV3', 'PSPNet', 'Segformer'.

    Returns:
        The instantiated model (single output channel, no final activation,
        i.e. it emits raw logits) placed on ``Config.DEVICE``.

    Raises:
        KeyError: if ``name`` is not one of the supported architectures.
    """
    # Map the user-facing name to the attribute name inside segmentation_models_pytorch.
    # The class is resolved lazily so that an installed smp version lacking one
    # architecture does not break the selection of the others.
    architectures = {
        'Unet': 'Unet',
        'Unet++': 'UnetPlusPlus',
        'FPN': 'FPN',
        'DeepLabV3': 'DeepLabV3',
        'PSPNet': 'PSPNet',
        'Segformer': 'Segformer',
    }
    if name not in architectures:
        raise KeyError(
            f"Unknown model {name!r}; expected one of {sorted(architectures)}"
        )

    model_cls = getattr(smp, architectures[name])
    return model_cls(
        encoder_name=Config.ENCODER,
        encoder_weights=Config.ENCODER_WEIGHTS,
        in_channels=3,
        classes=1,
        activation=None
    ).to(Config.DEVICE)

# Instantiate the architecture selected in Config
model = get_model(Config.SELECTED_MODEL)

# Loss function: binary cross-entropy computed on raw logits (get_model passes activation=None)
loss_fn = nn.BCEWithLogitsLoss()

# Metrics
@torch.no_grad()
def compute_metrics(preds, targets, threshold=0.5):
    """Return the mean per-sample IoU of a batch as a Python float.

    Runs without gradient tracking, so calling it on training-time logits does
    not extend the autograd graph.

    Args:
        preds: raw logits, batch-first (e.g. ``[N, 1, H, W]`` or ``[N, H, W]``).
        targets: ground-truth masks with the same shape as ``preds``; values
            above 0.5 count as foreground.
        threshold: probability cut-off applied to ``sigmoid(preds)``.

    Returns:
        Mean over the batch of ``(intersection + 1e-6) / (union + 1e-6)``.
        The smoothing term makes a sample with empty prediction and empty
        target score 1.0 instead of dividing by zero.
    """
    pred_mask = torch.sigmoid(preds) > threshold
    target_mask = targets > 0.5

    # Reduce over every axis except the batch axis, whatever the input rank is
    intersection = (pred_mask & target_mask).flatten(start_dim=1).float().sum(dim=1)
    union = (pred_mask | target_mask).flatten(start_dim=1).float().sum(dim=1)
    iou = (intersection + 1e-6) / (union + 1e-6)
    return iou.mean().item()

# Training / validation helpers
def train_one_epoch(model, loader, loss_fn, optimizer):
    """Run one optimisation pass over ``loader``.

    Args:
        model: network to train; switched to train mode here.
        loader: iterable of ``(images, masks)`` batches.
        loss_fn: callable ``loss_fn(preds, masks)`` returning a scalar tensor.
        optimizer: optimizer stepping the model's parameters.

    Returns:
        ``(mean_loss, mean_iou)``, each averaged over the number of batches.
    """
    model.train()
    loss_total, iou_total = 0.0, 0.0

    progress = tqdm(loader, desc="Training", leave=False)
    for batch_images, batch_masks in progress:
        batch_images = batch_images.to(Config.DEVICE)
        batch_masks = batch_masks.to(Config.DEVICE)

        # Forward pass and loss on the raw logits
        logits = model(batch_images)
        batch_loss = loss_fn(logits, batch_masks)

        # Clear stale gradients, backpropagate, update the weights
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        loss_total += batch_loss.item()
        iou_total += compute_metrics(logits, batch_masks)

    num_batches = len(loader)
    return loss_total / num_batches, iou_total / num_batches

def evaluate(model, loader, loss_fn):
    """Compute the mean loss and mean IoU of ``model`` over ``loader`` without gradients.

    Args:
        model: network to evaluate; switched to eval mode here.
        loader: iterable of ``(images, masks)`` batches.
        loss_fn: callable ``loss_fn(preds, masks)`` returning a scalar tensor.

    Returns:
        ``(mean_loss, mean_iou)``, each averaged over the number of batches.
    """
    with torch.no_grad():
        model.eval()
        loss_total, iou_total = 0.0, 0.0

        progress = tqdm(loader, desc="Validation", leave=False)
        for batch_images, batch_masks in progress:
            batch_images = batch_images.to(Config.DEVICE)
            batch_masks = batch_masks.to(Config.DEVICE)

            logits = model(batch_images)
            batch_loss = loss_fn(logits, batch_masks)

            loss_total += batch_loss.item()
            iou_total += compute_metrics(logits, batch_masks)

        num_batches = len(loader)
        return loss_total / num_batches, iou_total / num_batches


In [None]:

# Main training loop (skip this cell when only evaluating a saved model)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

best_val_loss = float('inf')  # lowest validation loss seen so far
for epoch in range(Config.EPOCHS):
    print(f"\nEpoch {epoch + 1}/{Config.EPOCHS}")

    train_loss, train_iou = train_one_epoch(model, train_loader, loss_fn, optimizer)
    val_loss, val_iou = evaluate(model, val_loader, loss_fn)

    print(f"Train Loss: {train_loss:.4f}, IoU: {train_iou:.4f}")
    print(f"Val   Loss: {val_loss:.4f}, IoU: {val_iou:.4f}")

    # Checkpoint only when the validation loss improves; the file is overwritten each time
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        SAVED_MODEL_NAME = f"{Config.SELECTED_MODEL}_best.pt" # the checkpoint file is named after the selected architecture
        model_path = os.path.join(Config.MODEL_DIR, SAVED_MODEL_NAME)
        torch.save(model.state_dict(), model_path)
        print(f"✅ Saved new best model at epoch {epoch + 1}")


In [None]:
# Evaluation-metrics class for a trained segmentation model
class TorchSegmentationEvaluator:
    """Pixel-level and object-level evaluation of a binary segmentation model.

    Pixel metrics are accumulated as confusion counts over all samples.
    Object metrics compare connected components of the ground-truth and
    predicted masks; components smaller than ``min_object_size`` pixels are
    ignored.

    The evaluator does not change the model's train/eval mode; the caller is
    expected to pass a model already switched to eval mode.
    """

    def __init__(self, model, val_dataset, threshold=0.5, min_object_size=100):
        """
        Args:
            model: callable mapping a ``[1, C, H, W]`` tensor to logits.
            val_dataset: sized, indexable collection of ``(image, mask)`` tensor pairs.
            threshold: cut-off used to binarise both predicted probabilities and masks.
            min_object_size: minimum pixel count for a connected component to count as an object.
        """
        self.model = model
        self.dataset = val_dataset
        self.threshold = threshold
        self.min_object_size = min_object_size

    def _sigmoid(self, x):
        """Apply a sigmoid to ``x`` and return the result as a NumPy array."""
        # detach() so tensors that require grad can be converted too
        return torch.sigmoid(x).detach().cpu().numpy()

    def _to_numpy(self, tensor):
        """Squeeze singleton axes of ``tensor`` and return it as a NumPy array."""
        return tensor.detach().squeeze().cpu().numpy()

    def _model_device(self):
        """Device inputs must be moved to: the model's own device, else Config.DEVICE."""
        parameters = getattr(self.model, "parameters", None)
        if callable(parameters):
            first_param = next(iter(parameters()), None)
            if first_param is not None:
                return first_param.device
        return torch.device(Config.DEVICE)

    @torch.no_grad()
    def _predict_mask(self, img):
        """Run the model on one image tensor and return its binary uint8 mask.

        Inference runs under ``no_grad`` so no autograd graph is kept per sample.
        """
        logits = self.model(img.unsqueeze(0).to(self._model_device()))
        binary = torch.sigmoid(logits) > self.threshold
        return binary.squeeze().cpu().numpy().astype(np.uint8)

    def _get_connected_components(self, mask):
        """Return one boolean mask per connected component of at least ``min_object_size`` pixels."""
        labeled, num_features = ndimage.label(mask)
        if num_features == 0:
            return []
        # Pixel count of every label in a single pass (index 0 is the background)
        sizes = np.bincount(labeled.ravel(), minlength=num_features + 1)
        return [
            labeled == label_id
            for label_id in range(1, num_features + 1)
            if sizes[label_id] >= self.min_object_size
        ]

    def calculate_iou(self, y_true, y_pred):
        """Intersection over union of two binary arrays (smoothed to avoid division by zero)."""
        intersection = np.sum(y_true * y_pred)
        union = np.sum(y_true) + np.sum(y_pred) - intersection
        return intersection / (union + 1e-6)

    def _match_objects(self, true_objs, pred_objs, iou_threshold=0.5):
        """Match ground-truth objects to predicted objects by IoU.

        A ground-truth object is detected when its best-overlapping prediction
        has IoU above ``iou_threshold``; otherwise it is missed. Every
        predicted object that was not the match of any ground-truth object is
        a false detection.

        Returns:
            ``(good_detections, missed_detections, false_detections)``.
        """
        matched_preds = set()
        good = missed = 0
        for true_obj in true_objs:
            ious = [self.calculate_iou(true_obj, pred_obj) for pred_obj in pred_objs]
            best = int(np.argmax(ious)) if ious else -1
            if best >= 0 and ious[best] > iou_threshold:
                good += 1
                matched_preds.add(best)
            else:
                missed += 1
        return good, missed, len(pred_objs) - len(matched_preds)

    @staticmethod
    def _safe_ratio(numerator, denominator):
        """``numerator / denominator``, or 0.0 when the denominator is zero."""
        return numerator / denominator if denominator else 0.0

    def _evaluate_samples(self, samples):
        """Accumulate metrics over an iterable of ``(image, mask)`` pairs.

        Returns:
            dict with the keys 'pixel_accuracy', 'precision', 'recall', 'IoU',
            'Dice', 'object_detection_rate', 'false_discovery_rate',
            'missed_object_rate', 'total_objects', 'false_detections',
            'missed_detections', 'good_detections'.
        """
        tp = fp = fn = tn = 0
        false_detections = missed_detections = good_detections = 0
        total_objects = 0

        for img, mask in samples:
            pred = self._predict_mask(img)
            true = (self._to_numpy(mask) > self.threshold).astype(np.uint8)

            # Pixel-level confusion counts (no per-pixel Python lists needed)
            pred_fg = pred.astype(bool)
            true_fg = true.astype(bool)
            tp += int(np.count_nonzero(pred_fg & true_fg))
            fp += int(np.count_nonzero(pred_fg & ~true_fg))
            fn += int(np.count_nonzero(~pred_fg & true_fg))
            tn += int(np.count_nonzero(~pred_fg & ~true_fg))

            # Object-level matching on connected components
            true_objs = self._get_connected_components(true)
            pred_objs = self._get_connected_components(pred)
            good, missed, false = self._match_objects(true_objs, pred_objs)
            good_detections += good
            missed_detections += missed
            false_detections += false
            total_objects += len(true_objs)

        return {
            'pixel_accuracy': self._safe_ratio(tp + tn, tp + fp + fn + tn),
            'precision': self._safe_ratio(tp, tp + fp),
            'recall': self._safe_ratio(tp, tp + fn),
            'IoU': tp / (tp + fp + fn + 1e-6),
            'Dice': self._safe_ratio(2 * tp, 2 * tp + fp + fn),
            'object_detection_rate': good_detections / (total_objects + 1e-6),
            'false_discovery_rate': false_detections / (false_detections + good_detections + 1e-6),
            'missed_object_rate': missed_detections / (total_objects + 1e-6),
            'total_objects': total_objects,
            'false_detections': false_detections,
            'missed_detections': missed_detections,
            'good_detections': good_detections
        }

    def evaluate(self):
        """Evaluate the model on the whole dataset and return the metrics dict."""
        return self._evaluate_samples(tqdm(self.dataset, desc="Evaluating"))

    def visualize_samples(self, n=3, show_analysis=True):
        """Plot image, ground truth and prediction for the first ``n`` samples.

        Args:
            n: number of samples to show; capped at the dataset size.
            show_analysis: when True, add a fourth panel colour-coding
                false negatives (1), false positives (2) and true positives (3).
        """
        # Never index past the end of the dataset
        n = min(n, len(self.dataset))
        if n <= 0:
            return

        plt.figure(figsize=(18, 5 * n))
        for i in range(n):
            img, mask = self.dataset[i]
            pred_bin = self._predict_mask(img)
            true_bin = (self._to_numpy(mask) > self.threshold).astype(np.uint8)

            plt.subplot(n, 4, i * 4 + 1)
            plt.imshow(self.denormalize(img))
            plt.title("Image")
            plt.axis("off")

            plt.subplot(n, 4, i * 4 + 2)
            plt.imshow(true_bin, cmap='gray')
            plt.title("Ground Truth")
            plt.axis("off")

            plt.subplot(n, 4, i * 4 + 3)
            plt.imshow(pred_bin, cmap='gray')
            plt.title("Prediction")
            plt.axis("off")

            if show_analysis:
                error_mask = np.zeros_like(true_bin)
                error_mask[(true_bin == 1) & (pred_bin == 0)] = 1  # FN
                error_mask[(true_bin == 0) & (pred_bin == 1)] = 2  # FP
                error_mask[(true_bin == 1) & (pred_bin == 1)] = 3  # TP

                plt.subplot(n, 4, i * 4 + 4)
                plt.imshow(error_mask, cmap='jet', vmin=0, vmax=3)
                plt.title("Errors")
                plt.axis("off")

        plt.tight_layout()
        plt.show()

    # De-normalisation, used for display only
    def denormalize(self, tensor):
        """Undo mean/std normalisation of a ``[C, H, W]`` image tensor for plotting.

        Returns:
            ``[H, W, C]`` NumPy array clipped to the 0..1 range.
        """
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img = tensor.detach().permute(1, 2, 0).cpu().numpy()  # CxHxW -> HxWxC
        img = (img * std + mean)
        img = np.clip(img, 0, 1)
        return img


In [None]:
# Load the saved checkpoint for evaluation and print the metrics.
# The architecture is taken from Config.SELECTED_MODEL so that it always matches
# the checkpoint file name below (a hard-coded name could disagree with it).
loaded_model = get_model(Config.SELECTED_MODEL)  # Options: 'Unet', 'Unet++', 'FPN', 'DeepLabV3', 'PSPNet', 'Segformer'
model_path = os.path.join(Config.MODEL_DIR, f"{Config.SELECTED_MODEL}_best.pt")
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
loaded_model.load_state_dict(torch.load(model_path, map_location=Config.DEVICE))
loaded_model.eval().to(Config.DEVICE)

evaluator = TorchSegmentationEvaluator(loaded_model, val_dataset)

metrics = evaluator.evaluate()
# Floats are printed with four decimals, counters as-is
for k, v in metrics.items():
    print(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")

In [None]:
# Visualise results: up to 50 samples, capped at the size of the validation split
# so that a small dataset cannot be indexed out of range
evaluator.visualize_samples(n=min(50, len(val_dataset)), show_analysis=True)