In [1]:
%%writefile train.py
# train_METRICS_FIXED.py - FIX CRÍTICO DELLE METRICHE

import os
import warnings
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import segmentation_models_pytorch as smp
from tqdm import tqdm

warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', message='.*UnsupportedFieldAttributeWarning.*')

from dataset import LaneSegmentationDataset
from augmentation import (
    get_training_augmentation,
    get_validation_augmentation,
)
from metrics import LaneMetrics

# ==================== CONFIGURAZIONE KAGGLE ====================
class Config:
    DATA_DIR = '/kaggle/input/tusimple-preprocessed/tusimple_preprocessed'
    
    IMAGES_DIR = os.path.join(DATA_DIR, 'training/frames')
    MASKS_DIR = os.path.join(DATA_DIR, 'training/lane-masks')
    
    TEST_IMAGES_DIR = os.path.join(DATA_DIR, 'test/frames')
    TEST_MASKS_DIR = os.path.join(DATA_DIR, 'test/lane-masks')
    
    TRAIN_RATIO = 0.8
    VAL_RATIO = 0.2
    
    ENCODER = 'resnet50'
    ENCODER_WEIGHTS = 'imagenet'
    CLASSES = 1
    ACTIVATION = 'sigmoid'
    
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    EPOCHS = 50
    BATCH_SIZE = 16
    
    LEARNING_RATE = 1e-3
    WEIGHT_DECAY = 1e-5
    DROPOUT = 0.1
    
    NUM_WORKERS = 4
    
    MODEL_DIR = '/kaggle/working/models'
    BEST_MODEL_PATH = os.path.join(MODEL_DIR, 'best_unet_fixed.pth')
    LAST_MODEL_PATH = os.path.join(MODEL_DIR, 'last_unet_fixed.pth')

def create_model(config):
    model = smp.Unet(
        encoder_name=config.ENCODER,
        encoder_weights=config.ENCODER_WEIGHTS,
        classes=config.CLASSES,
        activation=config.ACTIVATION,
        decoder_dropout=config.DROPOUT,
    )
    return model

def create_dataloaders_with_split(config):
    print("📂 Caricamento dataset completo (training + test)...")
    
    full_dataset = LaneSegmentationDataset(
        images_dir=config.IMAGES_DIR,
        masks_dir=config.MASKS_DIR,
        test_images_dir=config.TEST_IMAGES_DIR,
        test_masks_dir=config.TEST_MASKS_DIR,
        transform=None,
        preprocessing=None,
    )
    
    total_size = len(full_dataset)
    print(f"✅ Dataset caricato: {total_size} immagini totali")
    
    train_size = int(config.TRAIN_RATIO * total_size)
    val_size = total_size - train_size
    
    print(f"\n📊 Split ratio: {config.TRAIN_RATIO*100:.1f}% training / {config.VAL_RATIO*100:.1f}% validation")
    print(f"   Train samples: {train_size}")
    print(f"   Val samples: {val_size}")
    
    train_dataset, val_dataset = random_split(
        full_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )
    
    train_dataset_augmented = TrainAugmentedDataset(
        dataset=train_dataset,
        transform=get_training_augmentation(),
    )
    
    val_dataset_augmented = ValAugmentedDataset(
        dataset=val_dataset,
        transform=get_validation_augmentation(),
    )
    
    train_loader = DataLoader(
        train_dataset_augmented,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
        prefetch_factor=2,
    )
    
    val_loader = DataLoader(
        val_dataset_augmented,
        batch_size=config.BATCH_SIZE,
        shuffle=False,
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
        prefetch_factor=2,
    )
    
    print(f"✅ DataLoaders creati!")
    return train_loader, val_loader


class TrainAugmentedDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        image, mask = self.dataset[idx]
        if self.transform:
            sample = self.transform(image=image, mask=mask)
            image = sample['image']
            mask = sample['mask']
        return image, mask


class ValAugmentedDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        image, mask = self.dataset[idx]
        if self.transform:
            sample = self.transform(image=image, mask=mask)
            image = sample['image']
            mask = sample['mask']
        return image, mask


def train_epoch(model, loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0
    
    loop = tqdm(loader, desc='Training')
    for images, masks in loop:
        images = images.to(device).float()
        masks = masks.unsqueeze(1).to(device).float()
        
        predictions = model(images)
        loss = loss_fn(predictions, masks)
        
        optimizer.zero_grad()
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        
        total_loss += loss.item()
        loop.set_postfix(loss=loss.item())
    
    return total_loss / len(loader)


def validate_epoch(model, loader, loss_fn, device):
    """
    ✅ VERSIONE CORRETTA - Metriche calcolate per OGNI IMMAGINE
    """
    model.eval()
    
    total_loss = 0
    all_metrics = {
        'iou': [],
        'dice': [],
        'sensitivity': [],
        'specificity': [],
        'f1': [],
        'mcc': [],
        'accuracy': []
    }
    
    with torch.no_grad():
        loop = tqdm(loader, desc='Validation')
        for images, masks in loop:
            images = images.to(device).float()
            masks = masks.unsqueeze(1).to(device).float()
            
            predictions = model(images)
            loss = loss_fn(predictions, masks)
            
            # ✅ CRITICO: Conversione a probabilità e binarizzazione
            pred_prob = torch.sigmoid(predictions)
            pred_binary = (pred_prob > 0.5).float()
            
            total_loss += loss.item()
            
            # ✅ CORRETTO: Per OGNI immagine nel batch
            for batch_idx in range(pred_prob.shape[0]):
                # Estrai singola immagine e maschera dal batch
                pred_single = pred_prob[batch_idx].squeeze()      # [H, W]
                pred_bin_single = pred_binary[batch_idx].squeeze() # [H, W]
                mask_single = masks[batch_idx].squeeze()           # [H, W]
                
                # ✅ Calcola IoU CORRETTAMENTE
                intersection = (pred_bin_single * mask_single).sum()
                union = pred_bin_single.sum() + mask_single.sum() - intersection
                iou = (intersection + 1e-6) / (union + 1e-6)
                all_metrics['iou'].append(iou.item())
                
                # ✅ Calcola altre metriche
                all_metrics['dice'].append(LaneMetrics.dice_coefficient(pred_single, mask_single))
                all_metrics['sensitivity'].append(LaneMetrics.sensitivity(pred_single, mask_single))
                all_metrics['specificity'].append(LaneMetrics.specificity(pred_single, mask_single))
                all_metrics['f1'].append(LaneMetrics.f1_score(pred_single, mask_single))
                all_metrics['mcc'].append(LaneMetrics.mcc(pred_single, mask_single))
                all_metrics['accuracy'].append(LaneMetrics.pixel_accuracy(pred_single, mask_single))
            
            # ✅ Mostra metriche batch corrente
            batch_sens = sum(all_metrics['sensitivity'][-len(masks):]) / len(masks)
            batch_spec = sum(all_metrics['specificity'][-len(masks):]) / len(masks)
            loop.set_postfix(
                loss=loss.item(),
                sens=f"{batch_sens:.3f}",
                spec=f"{batch_spec:.3f}"
            )
    
    # ✅ Media finale di TUTTE le metriche
    avg_loss = total_loss / len(loader)
    
    metrics_final = {}
    for k, v in all_metrics.items():
        metrics_final[k] = sum(v) / len(v) if v else 0.0
    
    return avg_loss, metrics_final


class BinaryCrossEntropyWithPosWeight(nn.Module):
    """✅ BCE Loss con class weight dinamico"""
    
    def __init__(self, pos_weight=10.0):
        super().__init__()
        self.pos_weight = pos_weight
    
    def forward(self, predictions, targets):
        loss = nn.functional.binary_cross_entropy_with_logits(
            predictions,
            targets,
            pos_weight=torch.tensor(self.pos_weight)
        )
        return loss


class WeightedTverskyLoss(nn.Module):
    """✅ Tversky Loss AGGRESSIVA per class imbalance"""
    
    def __init__(self, alpha=0.1, beta=0.9, smooth=1.0):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth
    
    def forward(self, predictions, targets):
        probs = torch.sigmoid(predictions)
        
        TP = (probs * targets).sum()
        FP = (probs * (1 - targets)).sum()
        FN = ((1 - probs) * targets).sum()
        
        tversky_index = (TP + self.smooth) / (
            TP + self.alpha * FP + self.beta * FN + self.smooth
        )
        
        return 1 - tversky_index


class CombinedLossAggressivo(nn.Module):
    """✅ Combinazione aggressiva per class imbalance"""
    
    def __init__(self, bce_weight=0.3, tversky_weight=0.7):
        super().__init__()
        self.bce_weight = bce_weight
        self.tversky_weight = tversky_weight
        
        self.bce_loss = BinaryCrossEntropyWithPosWeight(pos_weight=10.0)
        self.tversky_loss = WeightedTverskyLoss(alpha=0.1, beta=0.9)
    
    def forward(self, predictions, targets):
        bce = self.bce_loss(predictions, targets)
        tversky = self.tversky_loss(predictions, targets)
        
        return self.bce_weight * bce + self.tversky_weight * tversky


def main():
    config = Config()
    
    if not os.path.exists(config.IMAGES_DIR):
        raise FileNotFoundError(f"❌ Cartella non trovata: {config.IMAGES_DIR}")
    if not os.path.exists(config.MASKS_DIR):
        raise FileNotFoundError(f"❌ Cartella non trovata: {config.MASKS_DIR}")
    
    os.makedirs(config.MODEL_DIR, exist_ok=True)
    
    print("=" * 70)
    print("🚀 TRAINING U-NET - METRICHE CORRETTE")
    print("🎯 Loss: BCE (pos_weight=10) + Weighted Tversky (alpha=0.1)")
    print("=" * 70)
    print(f"Encoder: {config.ENCODER}")
    print(f"Device: {config.DEVICE}")
    print(f"Learning rate: {config.LEARNING_RATE}")
    print(f"Epochs: {config.EPOCHS}")
    print("=" * 70)
    
    model = create_model(config)
    model.to(config.DEVICE)
    
    train_loader, val_loader = create_dataloaders_with_split(config)
    
    loss_fn = CombinedLossAggressivo(bce_weight=0.3, tversky_weight=0.7)
    
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY
    )
    
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=5,
        T_mult=2,
        eta_min=1e-6
    )
    
    best_specificity = 0.0
    best_metrics = {}
    no_improve_count = 0
    patience = 15
    
    print("\n🏋️ Inizio training...\n")
    
    for epoch in range(config.EPOCHS):
        print(f"\n{'='*70}")
        print(f"Epoch {epoch + 1}/{config.EPOCHS}")
        print(f"{'='*70}")
        
        train_loss = train_epoch(model, train_loader, optimizer, loss_fn, config.DEVICE)
        val_loss, metrics = validate_epoch(model, val_loader, loss_fn, config.DEVICE)
        
        current_lr = optimizer.param_groups[0]['lr']
        scheduler.step()
        
        print(f"\n📈 RISULTATI VALIDAZIONE:")
        print(f"   Train Loss:     {train_loss:.4f}")
        print(f"   Val Loss:       {val_loss:.4f}")
        print(f"   ───────────────────────")
        print(f"   Sensitivity:    {metrics['sensitivity']:.4f} (corsie trovate)")
        print(f"   🎯 Specificity: {metrics['specificity']:.4f} ← METRICA PRINCIPALE")
        print(f"   Dice:           {metrics['dice']:.4f}")
        print(f"   F1:             {metrics['f1']:.4f}")
        print(f"   IoU:            {metrics['iou']:.4f}")
        print(f"   LR:             {current_lr:.6f}")
        
        torch.save(model.state_dict(), config.LAST_MODEL_PATH)
        
        if metrics['specificity'] > best_specificity:
            best_specificity = metrics['specificity']
            best_metrics = metrics
            no_improve_count = 0
            torch.save(model.state_dict(), config.BEST_MODEL_PATH)
            print(f"\n   ✅ NUOVO MIGLIOR MODELLO!")
            print(f"      Specificity: {best_specificity:.4f}")
            print(f"      Sensitivity: {best_metrics['sensitivity']:.4f}")
            print(f"      F1: {best_metrics['f1']:.4f}")
        else:
            no_improve_count += 1
        
        if no_improve_count >= patience:
            print(f"\n⚠️ Early stopping: nessun miglioramento per {patience} epoch")
            break
    
    print("\n" + "=" * 70)
    print("✅ TRAINING COMPLETATO!")
    print("=" * 70)
    print(f"Specificity:  {best_specificity:.4f} (pochi FP) 🎯")
    print(f"Sensitivity:  {best_metrics['sensitivity']:.4f} (corsie trovate)")
    print(f"F1:           {best_metrics['f1']:.4f}")
    print(f"Dice:         {best_metrics['dice']:.4f}")
    print(f"IoU:          {best_metrics['iou']:.4f}")
    print(f"MCC:          {best_metrics['mcc']:.4f}")
    print(f"Modello:      {config.BEST_MODEL_PATH}")
    print("=" * 70)


if __name__ == '__main__':
    main()


Writing train.py


In [2]:
%%writefile dataset.py
# dataset_COMBINED.py - COMBINA TRAINING + TEST SET

import os
import cv2
import numpy as np
from torch.utils.data import Dataset

class LaneSegmentationDataset(Dataset):
    """
    Dataset personalizzato che carica:
    - ✅ Training set (frames + lane-masks)
    - ✅ Test set (frames + lane-masks)
    
    Combina entrambi per avere più dati per il training.
    Il split 80/20 interno creerà nuovi train/val set.
    """
    
    def __init__(self, images_dir, masks_dir, test_images_dir=None, test_masks_dir=None, transform=None, preprocessing=None):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform
        self.preprocessing = preprocessing
        
        # ✅ Carica TRAINING SET
        self.images_fps = sorted([
            os.path.join(images_dir, image_id)
            for image_id in os.listdir(images_dir)
            if image_id.endswith(('.jpg', '.png', '.jpeg'))
        ])
        
        self.masks_fps = sorted([
            os.path.join(masks_dir, image_id)
            for image_id in os.listdir(masks_dir)
            if image_id.endswith(('.jpg', '.png', '.jpeg'))
        ])
        
        # ✅ Carica TEST SET (se fornito)
        if test_images_dir and test_masks_dir:
            print(f"\n📂 Caricamento TEST SET da: {test_images_dir}")
            
            test_images = sorted([
                os.path.join(test_images_dir, image_id)
                for image_id in os.listdir(test_images_dir)
                if image_id.endswith(('.jpg', '.png', '.jpeg'))
            ])
            
            test_masks = sorted([
                os.path.join(test_masks_dir, image_id)
                for image_id in os.listdir(test_masks_dir)
                if image_id.endswith(('.jpg', '.png', '.jpeg'))
            ])
            
            print(f"✅ Test set immagini trovate: {len(test_images)}")
            print(f"✅ Test set maschere trovate: {len(test_masks)}")
            
            # Combina training + test set
            self.images_fps.extend(test_images)
            self.masks_fps.extend(test_masks)
            
            print(f"\n✅ TOTALE IMMAGINI (Training + Test): {len(self.images_fps)}")
        else:
            print(f"⚠️ Test set NON fornito - usa solo training set")
        
        # Verifica numero immagini == numero maschere
        assert len(self.images_fps) == len(self.masks_fps), \
            f"Numero immagini ({len(self.images_fps)}) != numero maschere ({len(self.masks_fps)})"
    
    def __len__(self):
        return len(self.images_fps)
    
    def __getitem__(self, idx):
        # Leggi immagine RGB
        image = cv2.imread(self.images_fps[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # Leggi mask (grayscale)
        mask = cv2.imread(self.masks_fps[idx], cv2.IMREAD_GRAYSCALE)
        
        # ✅ Binarizza ma NON dividere per 255!
        mask = (mask > 127).astype(np.float32)
        
        # Applica augmentation (se presente)
        if self.transform:
            sample = self.transform(image=image, mask=mask)
            image = sample['image']
            mask = sample['mask']
        
        # Applica preprocessing dell'encoder (se presente)
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image = sample['image']
            mask = sample['mask']
        
        return image, mask


Writing dataset.py


In [7]:
%%writefile augmentation.py
# augmentation.py - VERSIONE CON NORMALIZZAZIONE ImageNet

import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2

def get_training_augmentation():
    """
    Data augmentation per training set.
    ✅ INCLUDE: Normalizzazione ImageNet + ToTensorV2
    """
    train_transform = [
        A.Resize(256, 256),
        
        A.HorizontalFlip(p=0.5),
        A.Affine(
            scale=(0.9, 1.1),
            translate_percent=(0.0625, 0.0625),
            rotate=(-10, 10),
            p=0.5
        ),
        A.OneOf([
            A.RandomBrightnessContrast(
                brightness_limit=0.2,
                contrast_limit=0.2,
                p=1.0
            ),
            A.HueSaturationValue(
                hue_shift_limit=10,
                sat_shift_limit=20,
                val_shift_limit=20,
                p=1.0
            ),
        ], p=0.5),
        A.GaussianBlur(blur_limit=3, p=0.2),
        
        # ✅ CRITICO: Normalizzazione ImageNet
        # Mean e Std di ImageNet per RGB
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
        ),
        
        # ✅ Conversione a PyTorch format (CxHxW)
        ToTensorV2(),
    ]
    
    return A.Compose(train_transform)


def get_validation_augmentation():
    """
    Augmentation per validation set.
    ✅ INCLUDE: Normalizzazione ImageNet + ToTensorV2
    """
    val_transform = [
        A.Resize(256, 256),
        
        # ✅ CRITICO: Normalizzazione ImageNet
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
        ),
        
        # ✅ Conversione a PyTorch format (CxHxW)
        ToTensorV2(),
    ]
    return A.Compose(val_transform)


def get_preprocessing(preprocessing_fn):
    """
    Preprocessing specifico dell'encoder.
    ⚠️ NON usare questo - lo facciamo direttamente con Normalize
    """
    _transform = [
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ]
    return A.Compose(_transform)


def get_heavy_augmentation():
    """
    Augmentation aggressiva per dataset piccoli.
    ✅ INCLUDE: Normalizzazione ImageNet + ToTensorV2
    """
    heavy_transform = [
        A.Resize(256, 256),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.2),
        A.Affine(
            scale=(0.8, 1.2),
            translate_percent=(0.1, 0.1),
            rotate=(-15, 15),
            shear=(-5, 5),
            p=0.7
        ),
        A.OneOf([
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1.0),
            A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=30, p=1.0),
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=1.0),
        ], p=0.7),
        A.OneOf([
            A.GaussianBlur(blur_limit=5, p=1.0),
            A.MotionBlur(blur_limit=5, p=1.0),
        ], p=0.3),
        
        # ✅ CRITICO: Normalizzazione ImageNet
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
        ),
        
        # ✅ Conversione a PyTorch format (CxHxW)
        ToTensorV2(),
    ]
    
    return A.Compose(heavy_transform)

Overwriting augmentation.py


In [4]:
%%writefile metrics.py
# metrics_CORRECTED.py - METRICHE CORRETTE PER LANE DETECTION

import torch
import numpy as np

class LaneMetrics:
    """Metriche specializzate per lane detection - VERSIONE CORRETTA"""
    
    # ==================== 1️⃣ DICE COEFFICIENT ====================
    @staticmethod
    def dice_coefficient(pred, target, smooth=1.0):
        """
        ✅ CORRETTO
        Formula: Dice = (2 * TP) / (2 * TP + FP + FN)
        """
        # Converti a binario (pred potrebbe già essere binario o probabilità)
        if pred.max() > 1.0:
            pred_binary = (pred > 0.5).float()
        else:
            pred_binary = pred
        
        target = target.float()
        
        intersection = (pred_binary * target).sum()
        dice = (2.0 * intersection + smooth) / (pred_binary.sum() + target.sum() + smooth)
        return dice.item()
    
    
    # ==================== 2️⃣ SENSITIVITY (RECALL) ====================
    @staticmethod
    def sensitivity(pred, target, smooth=1e-6):
        """
        ✅ CORRETTO
        Formula: Sensitivity = TP / (TP + FN)
        
        ⚠️ BUG ORIGINALE: Non convertiva pred a binario!
        """
        # ✅ FIX: Converti a binario
        if pred.max() > 1.0:
            pred_binary = (pred > 0.5).float()
        else:
            pred_binary = pred
        
        target = target.float()
        
        TP = (pred_binary * target).sum()
        FN = ((1 - pred_binary) * target).sum()
        
        sensitivity = TP / (TP + FN + smooth)
        return sensitivity.item()
    
    
    # ==================== 3️⃣ SPECIFICITY ====================
    @staticmethod
    def specificity(pred, target, smooth=1e-6):
        """
        ✅ CORRETTO
        Formula: Specificity = TN / (TN + FP)
        
        ⚠️ BUG ORIGINALE: Non convertiva pred a binario!
        """
        # ✅ FIX: Converti a binario
        if pred.max() > 1.0:
            pred_binary = (pred > 0.5).float()
        else:
            pred_binary = pred
        
        target = target.float()
        
        TN = ((1 - pred_binary) * (1 - target)).sum()
        FP = (pred_binary * (1 - target)).sum()
        
        specificity = TN / (TN + FP + smooth)
        return specificity.item()
    
    
    # ==================== 4️⃣ F1 SCORE ====================
    @staticmethod
    def f1_score(pred, target, smooth=1e-6):
        """
        ✅ CORRETTO
        Formula: F1 = 2 * (Precision * Recall) / (Precision + Recall)
        """
        # ✅ FIX: Converti a binario
        if pred.max() > 1.0:
            pred_binary = (pred > 0.5).float()
        else:
            pred_binary = pred
        
        target = target.float()
        
        TP = (pred_binary * target).sum()
        FP = (pred_binary * (1 - target)).sum()
        FN = ((1 - pred_binary) * target).sum()
        
        precision = TP / (TP + FP + smooth)
        recall = TP / (TP + FN + smooth)
        f1 = 2 * (precision * recall) / (precision + recall + smooth)
        
        return f1.item()
    
    
    # ==================== 5️⃣ MATTHEWS CORRELATION COEFFICIENT ====================
    @staticmethod
    def mcc(pred, target, smooth=1e-6):
        """
        ✅ CORRETTO
        Formula: MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
        
        ⚠️ BUG ORIGINALE: Non gestiva correttamente i tensori
        """
        # ✅ FIX: Converti a binario
        if pred.max() > 1.0:
            pred_binary = (pred > 0.5).float()
        else:
            pred_binary = pred
        
        target = target.float()
        
        TP = (pred_binary * target).sum()
        TN = ((1 - pred_binary) * (1 - target)).sum()
        FP = (pred_binary * (1 - target)).sum()
        FN = ((1 - pred_binary) * target).sum()
        
        numerator = TP * TN - FP * FN
        denominator = torch.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN) + smooth)
        
        # ✅ FIX: Gestisci il caso di denominatore zero
        if denominator == 0:
            return 0.0
        
        mcc = numerator / denominator
        return mcc.item()
    
    
    # ==================== 6️⃣ PIXEL ACCURACY ====================
    @staticmethod
    def pixel_accuracy(pred, target, smooth=1e-6):
        """
        ✅ CORRETTO
        Formula: Accuracy = (TP + TN) / Total
        """
        # ✅ FIX: Converti a binario
        if pred.max() > 1.0:
            pred_binary = (pred > 0.5).float()
        else:
            pred_binary = pred
        
        target = target.float()
        
        correct = (pred_binary == target).float().sum()
        total = target.numel()
        
        # ✅ FIX: Gestisci il caso di total zero
        if total == 0:
            return 0.0
        
        return (correct / total).item()
    
    
    # ==================== 7️⃣ IoU (Intersection over Union) ====================
    @staticmethod
    def iou(pred, target, smooth=1e-6):
        """
        ✅ AGGIUNTO - Lo standard per segmentazione
        Formula: IoU = TP / (TP + FP + FN)
        """
        # ✅ FIX: Converti a binario
        if pred.max() > 1.0:
            pred_binary = (pred > 0.5).float()
        else:
            pred_binary = pred
        
        target = target.float()
        
        intersection = (pred_binary * target).sum()
        union = pred_binary.sum() + target.sum() - intersection
        
        iou_score = (intersection + smooth) / (union + smooth)
        return iou_score.item()


# ==================== FUNZIONE DI UTILITÀ ====================

def calculate_all_metrics(pred_batch, target_batch):
    """
    ✅ CORRETTA - Calcola tutte le metriche per un batch
    
    Input:
        pred_batch: tensor [B, H, W] con probabilità [0, 1] o valori > 1
        target_batch: tensor [B, H, W] con valori binari {0, 1}
    
    Output:
        dict con tutte le metriche
    """
    metrics = {
        'iou': [],
        'dice': [],
        'sensitivity': [],
        'specificity': [],
        'f1': [],
        'mcc': [],
        'accuracy': [],
    }
    
    # ✅ FIX: Itera correttamente su ogni elemento del batch
    for batch_idx in range(pred_batch.shape[0]):
        pred_single = pred_batch[batch_idx]    # [H, W]
        target_single = target_batch[batch_idx] # [H, W]
        
        # Calcola tutte le metriche per questa immagine
        metrics['iou'].append(LaneMetrics.iou(pred_single, target_single))
        metrics['dice'].append(LaneMetrics.dice_coefficient(pred_single, target_single))
        metrics['sensitivity'].append(LaneMetrics.sensitivity(pred_single, target_single))
        metrics['specificity'].append(LaneMetrics.specificity(pred_single, target_single))
        metrics['f1'].append(LaneMetrics.f1_score(pred_single, target_single))
        metrics['mcc'].append(LaneMetrics.mcc(pred_single, target_single))
        metrics['accuracy'].append(LaneMetrics.pixel_accuracy(pred_single, target_single))
    
    # ✅ Restituisci media di tutte le metriche
    return {k: np.mean(v) if v else 0.0 for k, v in metrics.items()}


# ==================== DEBUG ====================

if __name__ == '__main__':
    print("✅ Test metriche corrette...\n")
    
    # Crea dati fittizi
    pred = torch.rand(256, 256)           # Probabilità [0, 1]
    target = torch.randint(0, 2, (256, 256)).float()  # Binario
    
    print("📊 Metriche per Lane Detection:\n")
    print(f"  IoU:              {LaneMetrics.iou(pred, target):.4f}")
    print(f"  Dice:             {LaneMetrics.dice_coefficient(pred, target):.4f}")
    print(f"  Sensitivity:      {LaneMetrics.sensitivity(pred, target):.4f} ← Corsie trovate")
    print(f"  Specificity:      {LaneMetrics.specificity(pred, target):.4f} ← Falsi positivi")
    print(f"  F1 Score:         {LaneMetrics.f1_score(pred, target):.4f}")
    print(f"  MCC:              {LaneMetrics.mcc(pred, target):.4f}")
    print(f"  Accuracy:         {LaneMetrics.pixel_accuracy(pred, target):.4f}")
    
    # Test batch
    print("\n\n📊 Test Batch:\n")
    pred_batch = torch.rand(4, 256, 256)
    target_batch = torch.randint(0, 2, (4, 256, 256)).float()
    
    all_metrics = calculate_all_metrics(pred_batch, target_batch)
    for k, v in all_metrics.items():
        print(f"  {k.upper():12s}: {v:.4f}")


Writing metrics.py


In [5]:
!pip install segmentation_models_pytorch

Collecting segmentation_models_pytorch
  Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl.metadata (17 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.8->segmentation_models_pytorch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.8->segmentation_models_pytorch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.8->segmentation_models_pytorch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.8->segmentation_models_pytorch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.8->segmentation_models_pytorch)
  Downloading nvidia_cublas_cu12-12.4.5.8-

In [None]:
!python train.py

🚀 TRAINING U-NET - METRICHE CORRETTE
🎯 Loss: BCE (pos_weight=10) + Weighted Tversky (alpha=0.1)
Encoder: resnet50
Device: cuda
Learning rate: 0.001
Epochs: 50
📂 Caricamento dataset completo (training + test)...

📂 Caricamento TEST SET da: /kaggle/input/tusimple-preprocessed/tusimple_preprocessed/test/frames
✅ Test set immagini trovate: 2782
✅ Test set maschere trovate: 2782

✅ TOTALE IMMAGINI (Training + Test): 6408
✅ Dataset caricato: 6408 immagini totali

📊 Split ratio: 80.0% training / 20.0% validation
   Train samples: 5126
   Val samples: 1282
✅ DataLoaders creati!

🏋️ Inizio training...


Epoch 1/50
Training: 100%|███████████████████| 321/321 [00:59<00:00,  5.42it/s, loss=0.712]
Validation: 100%|█| 81/81 [00:09<00:00,  8.28it/s, loss=0.715, sens=0.726, spec=

📈 RISULTATI VALIDAZIONE:
   Train Loss:     0.7309
   Val Loss:       0.7178
   ───────────────────────
   Sensitivity:    0.7204 (corsie trovate)
   🎯 Specificity: 0.4838 ← METRICA PRINCIPALE
   Dice:           0.1071
   F1