# # TRES MODELOS MEJORADOS PARA SEGMENTACIÓN DE VÉRTEBRAS
# 
# 1. DeepLabV3++ (DeepLabV3+ con Decoder Denso tipo U-Net++)
# 2. Hybrid++ (Tu modelo original mejorado con técnicas de DeepLab)
# 3. UNet++Lite (U-Net++ optimizada y ligera)
# 
# Todos con segmentaci√≥n de alta calidad estilo U-Net++

# In [None]:
# Configuración inicial
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import json
import numpy as np
import cv2
import random
from pathlib import Path
from PIL import Image
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from collections import defaultdict
import time
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# In [None]:
# SEED
def set_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all
    CUDA devices). It also puts cuDNN in deterministic mode and disables its
    autotuner (``benchmark``), which may otherwise pick different kernels
    between runs.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # PYTHONHASHSEED is only read when an interpreter starts. Setting it here
    # does not change str hashing in this process; it only affects child
    # processes launched afterwards (e.g. spawned DataLoader workers).
    os.environ['PYTHONHASHSEED'] = str(seed)
    print(f"✓ Seed fijado en {seed}")

# In [None]:
# MÓDULOS COMPARTIDOS

class SpatialAttention(nn.Module):
    """Spatial attention gate.

    The input is summarised along the channel axis twice, as a mean map and
    a max map. The resulting 2-channel tensor goes through a 7x7 conv,
    BatchNorm and a sigmoid. That single-channel map then rescales every
    channel of the input, so the output has the same shape as the input.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        # Both summaries keep dim 1 so they can be concatenated as channels.
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = torch.max(x, dim=1, keepdim=True).values
        gate = self.conv(torch.cat((mean_map, max_map), dim=1))
        return x * gate


class ChannelAttention(nn.Module):
    """Channel attention gate using both average and max global pooling.

    Both pooled descriptors pass through the same bottleneck MLP, built from
    1x1 convs with no bias. Their sum goes through a sigmoid, and the result
    rescales each input channel.

    Args:
        channels: Number of input (and output) channels.
        reduction: Bottleneck reduction ratio. The hidden width is
            ``channels // reduction``, clamped to at least 1. This keeps
            narrow inputs (``channels < reduction``) from getting a zero-width
            layer, which would reduce the gate to a constant 0.5.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = max(1, channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # self.fc is shared between the two pooled descriptors.
        weights = self.fc(self.avg_pool(x)) + self.fc(self.max_pool(x))
        return x * self.sigmoid(weights)


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling followed by channel attention.

    Several branches run in parallel on the same input, each ``out_ch``
    channels wide:
    - one 1x1 conv;
    - one dilated 3x3 conv per entry of ``rates``;
    - a global-average-pool branch, upsampled back to the input size.
    Their concatenation is reweighted by ``ChannelAttention``. It is then
    projected back to ``out_ch`` channels (1x1 conv + BN + ReLU + Dropout2d).
    The spatial size is preserved.

    Args:
        in_ch: Input channels.
        out_ch: Channels of every branch and of the output.
        rates: Dilation rates of the 3x3 branches. Any iterable is accepted.
            The default is an immutable tuple to avoid a shared mutable
            default argument.

    Note:
        The global-pool branch applies BatchNorm to a 1x1 map. In training
        mode a batch of size 1 therefore gives BatchNorm a single value per
        channel, which PyTorch rejects with an error.
    """

    def __init__(self, in_ch, out_ch, rates=(6, 12, 18)):
        super().__init__()
        # Materialise once: rates is iterated here and measured with len() below.
        rates = tuple(rates)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

        self.atrous_blocks = nn.ModuleList([
            nn.Sequential(
                # padding == dilation keeps the spatial size for a 3x3 kernel.
                nn.Conv2d(in_ch, out_ch, 3, padding=r, dilation=r, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True)
            ) for r in rates
        ])

        self.global_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

        # 1x1 branch + one branch per rate + global-pool branch.
        total_ch = out_ch * (len(rates) + 2)
        self.channel_attention = ChannelAttention(total_ch, reduction=16)

        self.project = nn.Sequential(
            nn.Conv2d(total_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.2)
        )

    def forward(self, x):
        size = x.shape[2:]
        feats = [self.conv1(x)] + [block(x) for block in self.atrous_blocks]
        # Broadcast the 1x1 image-level descriptor back to the input size.
        feats.append(F.interpolate(self.global_pool(x), size=size, mode='bilinear', align_corners=True))
        feats = torch.cat(feats, dim=1)
        feats = self.channel_attention(feats)
        return self.project(feats)


class ConvBlock(nn.Module):
    """Two stages of 3x3 conv -> BatchNorm -> ReLU, with Dropout2d between them.

    The spatial size is preserved (padding=1). Channels go from ``in_ch`` to
    ``out_ch``, and the second stage keeps ``out_ch``.

    Args:
        in_ch: Input channels.
        out_ch: Output channels; also the width of the intermediate stage.
        dropout: Channel-dropout probability applied after the first stage.
    """

    def __init__(self, in_ch, out_ch, dropout=0.1):
        super().__init__()

        def conv_bn_relu(cin):
            return [
                nn.Conv2d(cin, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]

        # Layer order (and thus the Sequential indices used in state_dict
        # keys) is: conv, bn, relu, dropout, conv, bn, relu.
        stages = conv_bn_relu(in_ch) + [nn.Dropout2d(dropout)] + conv_bn_relu(out_ch)
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)


class AttentionGate(nn.Module):
    """Additive attention gate that rescales skip features ``x`` using a gating signal ``g``.

    Both inputs are projected to ``F_int`` channels with a 1x1 conv + BN.
    The projections are added and passed through ReLU, then mapped by
    1x1 conv + BN + sigmoid to a single-channel mask. That mask multiplies
    ``x``, so the output has the shape of ``x``.

    Args:
        F_g: Channels of the gating signal ``g``.
        F_l: Channels of the skip features ``x``.
        F_int: Width of the intermediate projection.

    ``g`` and ``x`` must have broadcast-compatible spatial sizes; no
    resampling is done here.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def project(cin, cout):
            return nn.Sequential(nn.Conv2d(cin, cout, 1, bias=True), nn.BatchNorm2d(cout))

        self.W_g = project(F_g, F_int)
        self.W_x = project(F_l, F_int)
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, 1, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        attention = self.psi(self.relu(self.W_g(g) + self.W_x(x)))
        return x * attention

# In [None]:
# MODELO 1: DEEPLABV3++ (DeepLabV3+ con Decoder Denso)

class DeepLabV3PlusPlus(nn.Module):
    """
    DeepLabV3++ - DeepLabV3+-style encoder/ASPP with a 4-level U-Net-like decoder.

    Changes relative to a standard DeepLabV3+ (as listed by the author):
    - 4-level decoder (vs 2 in the original)
    - Skip connections at every level
    - Attention gates at levels 4 and 3, spatial attention at level 2
    - Progressive refinement

    Input: (B, 3, H, W) images. Output: (B, num_classes, H, W) logits.

    Parameters: ~14M (author's estimate).

    NOTE(review): each decoder level runs one scale finer than the skip it
    fuses:
    - enc4 (H/8) is upsampled to H/4;
    - enc3 (H/4) is upsampled to H/2;
    - enc2 (H/2) is upsampled to H;
    - up1 produces a 2H x 2W map, which is bilinearly downsampled back to H.
    So skips are interpolated rather than fused at their native resolution,
    and up1 works on 4x the input pixel count. This is left unchanged here
    because altering it changes the model's behaviour.
    """
    def __init__(self, num_classes=3):
        super().__init__()

        # Encoder: four ConvBlocks separated by 2x2 max-pooling
        self.enc1 = ConvBlock(3, 64, dropout=0.05)
        self.enc2 = ConvBlock(64, 128, dropout=0.1)
        self.enc3 = ConvBlock(128, 256, dropout=0.1)
        self.enc4 = ConvBlock(256, 512, dropout=0.15)
        self.pool = nn.MaxPool2d(2, 2)

        # ASPP at the bottleneck (256 output channels)
        self.aspp = ASPP(512, 256, rates=[6, 12, 18])

        # 1x1 projection of enc4: 512 -> 256 channels to match the ASPP output
        self.enc4_proj = nn.Sequential(
            nn.Conv2d(512, 256, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )

        # Decoder: each level concatenates one upsampled decoder tensor with one skip
        # Level 4: up4(ASPP) (256) + gated enc4_proj (256) = 512
        self.up4 = nn.ConvTranspose2d(256, 256, 2, stride=2)
        self.att4 = AttentionGate(256, 256, 128)
        self.dec4 = ConvBlock(512, 256, dropout=0.1)

        # Level 3: up3(dec4) (128) + gated enc3_proj (128) = 256
        self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.enc3_proj = nn.Sequential(
            nn.Conv2d(256, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.att3 = AttentionGate(128, 128, 64)
        self.dec3 = ConvBlock(256, 128, dropout=0.1)

        # Level 2: up2(dec3) (64) + spatially-attended enc2_proj (64) = 128
        self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.enc2_proj = nn.Sequential(
            nn.Conv2d(128, 64, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.spatial_att = SpatialAttention()
        self.dec2 = ConvBlock(128, 64, dropout=0.05)

        # Level 1: up1(dec2) (64) + enc1 (64) = 128, no attention
        self.up1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
        self.dec1 = ConvBlock(128, 64, dropout=0.05)

        # Output head: 3x3 conv + BN + ReLU, then 1x1 conv to class logits
        self.out = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, num_classes, 1)
        )

    def forward(self, x):
        # Encoder (sizes below assume H and W are multiples of 8)
        enc1 = self.enc1(x)              # 64 ch, HxW
        enc2 = self.enc2(self.pool(enc1))  # 128 ch, H/2xW/2
        enc3 = self.enc3(self.pool(enc2))  # 256 ch, H/4xW/4
        enc4 = self.enc4(self.pool(enc3))  # 512 ch, H/8xW/8

        # ASPP bottleneck
        bottleneck = self.aspp(enc4)     # 256 ch, H/8xW/8

        # Decoder: at every level the skip is resized to the decoder tensor
        # Level 4
        dec4 = self.up4(bottleneck)      # 256 ch, H/4xW/4
        enc4_reduced = self.enc4_proj(enc4)  # 512->256 ch, H/8xW/8
        # Upsample the projected enc4 to dec4's size (H/8 -> H/4)
        enc4_reduced = F.interpolate(enc4_reduced, size=dec4.shape[2:],
                                    mode='bilinear', align_corners=True)
        enc4_att = self.att4(dec4, enc4_reduced)
        dec4 = torch.cat([dec4, enc4_att], dim=1)  # 256+256=512
        dec4 = self.dec4(dec4)           # 256 ch, H/4xW/4

        # Level 3
        dec3 = self.up3(dec4)            # 128 ch, H/2xW/2
        enc3_reduced = self.enc3_proj(enc3)  # 256->128 ch, H/4xW/4
        # Upsample the projected enc3 to dec3's size (H/4 -> H/2)
        enc3_reduced = F.interpolate(enc3_reduced, size=dec3.shape[2:],
                                    mode='bilinear', align_corners=True)
        enc3_att = self.att3(dec3, enc3_reduced)
        dec3 = torch.cat([dec3, enc3_att], dim=1)  # 128+128=256
        dec3 = self.dec3(dec3)           # 128 ch, H/2xW/2

        # Level 2
        dec2 = self.up2(dec3)            # 64 ch, HxW
        enc2_reduced = self.enc2_proj(enc2)  # 128->64 ch, H/2xW/2
        # Upsample the projected enc2 to dec2's size (H/2 -> H)
        enc2_reduced = F.interpolate(enc2_reduced, size=dec2.shape[2:],
                                    mode='bilinear', align_corners=True)
        # Spatial attention here is computed from the skip alone (no gating signal)
        enc2_att = self.spatial_att(enc2_reduced)
        dec2 = torch.cat([dec2, enc2_att], dim=1)  # 64+64=128
        dec2 = self.dec2(dec2)           # 64 ch, HxW

        # Level 1
        dec1 = self.up1(dec2)            # 64 ch, 2Hx2W
        # Downsample back to enc1's size (2H -> H)
        dec1 = F.interpolate(dec1, size=enc1.shape[2:],
                           mode='bilinear', align_corners=True)
        dec1 = torch.cat([dec1, enc1], dim=1)  # 64+64=128
        dec1 = self.dec1(dec1)           # 64 ch, HxW

        # Safety net: dec1 already has enc1's size, which equals the input's
        # (ConvBlock keeps the spatial size), so this resize should not trigger.
        output = self.out(dec1)
        if output.shape[2:] != x.shape[2:]:
            output = F.interpolate(output, size=x.shape[2:],
                                 mode='bilinear', align_corners=True)

        return output

# In [None]:
# MODELO 2: HYBRID++ (Tu modelo con técnicas DeepLab)

class HybridPlusPlus(nn.Module):
    """
    Hybrid++ - the author's earlier "Hybrid" model extended with DeepLab techniques.

    Changes listed by the author:
    - Stronger ASPP with Channel Attention
    - Dilated convolutions in the deepest encoder stage (enc4, dilation=2)
    - Better decoder refinement
    - Spatial Attention at several decoder levels (4, 3 and 2)

    Input: (B, 3, H, W) images. Output: (B, num_classes, H, W) logits.
    Each decoder tensor is resized to the size of the encoder feature it is
    fused with.

    Parameters: ~18M (author's estimate).

    NOTE(review): at level 4, up4 doubles the bottleneck resolution
    (H/8 -> H/4), and the following F.interpolate immediately resizes it back
    to enc4's H/8. The resolution therefore round-trips through H/4 before
    fusion. The interpolations at levels 3-1 are no-ops when H and W are
    multiples of 8; otherwise they absorb the off-by-one sizes left by
    MaxPool flooring.
    """
    def __init__(self, num_classes=3):
        super().__init__()

        # Encoder; the deepest stage (enc4) uses dilated convolutions
        self.enc1 = ConvBlock(3, 64, dropout=0.05)
        self.enc2 = ConvBlock(64, 128, dropout=0.1)
        self.enc3 = ConvBlock(128, 256, dropout=0.1)

        # enc4: two dilated 3x3 convs (dilation=2 with padding=2 keeps the spatial size)
        self.enc4 = nn.Sequential(
            nn.Conv2d(256, 512, 3, padding=2, dilation=2, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.15),
            nn.Conv2d(512, 512, 3, padding=2, dilation=2, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )

        self.pool = nn.MaxPool2d(2, 2)

        # ASPP with 256 output channels (kept narrow to limit the parameter count)
        self.aspp = ASPP(512, 256, rates=[6, 12, 18])

        # 1x1 projections that halve the skip channels to match the decoder width
        self.enc4_proj = nn.Sequential(
            nn.Conv2d(512, 256, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        self.enc3_proj = nn.Sequential(
            nn.Conv2d(256, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.enc2_proj = nn.Sequential(
            nn.Conv2d(128, 64, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Decoder; the skip is attended at levels 4, 3 and 2 (not at level 1)
        self.up4 = nn.ConvTranspose2d(256, 256, 2, stride=2)
        self.att4 = AttentionGate(256, 256, 128)
        self.spatial_att4 = SpatialAttention()
        self.dec4 = ConvBlock(512, 256, dropout=0.1)

        self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.att3 = AttentionGate(128, 128, 64)
        self.spatial_att3 = SpatialAttention()
        self.dec3 = ConvBlock(256, 128, dropout=0.1)

        self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.spatial_att2 = SpatialAttention()
        self.dec2 = ConvBlock(128, 64, dropout=0.05)

        self.up1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
        self.dec1 = ConvBlock(128, 64, dropout=0.05)

        # Output head: 3x3 conv + BN + ReLU + Dropout2d, then 1x1 conv to class logits
        self.out = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(32, num_classes, 1)
        )

    def forward(self, x):
        # Encoder (sizes assume H and W are multiples of 8)
        enc1 = self.enc1(x)              # 64 ch, HxW
        enc2 = self.enc2(self.pool(enc1))  # 128 ch, H/2
        enc3 = self.enc3(self.pool(enc2))  # 256 ch, H/4
        enc4 = self.enc4(self.pool(enc3))  # 512 ch, H/8

        # ASPP
        bottleneck = self.aspp(enc4)     # 512 -> 256 ch, H/8

        # Decoder: attention gate + spatial attention on the skip, decoder resized to the skip
        # Level 4: up4(bottleneck) (256) + attended enc4_proj (256) = 512
        dec4 = self.up4(bottleneck)  # H/8 -> H/4
        dec4 = F.interpolate(dec4, size=enc4.shape[2:], mode='bilinear', align_corners=True)  # back to H/8
        enc4_reduced = self.enc4_proj(enc4)  # 512 -> 256
        enc4_att = self.att4(dec4, enc4_reduced)
        enc4_att = self.spatial_att4(enc4_att)
        dec4 = torch.cat([dec4, enc4_att], dim=1)  # 256+256=512
        dec4 = self.dec4(dec4)  # 512 -> 256, H/8

        # Level 3: up3(dec4) (128) + attended enc3_proj (128) = 256
        dec3 = self.up3(dec4)
        dec3 = F.interpolate(dec3, size=enc3.shape[2:], mode='bilinear', align_corners=True)
        enc3_reduced = self.enc3_proj(enc3)  # 256 -> 128
        enc3_att = self.att3(dec3, enc3_reduced)
        enc3_att = self.spatial_att3(enc3_att)
        dec3 = torch.cat([dec3, enc3_att], dim=1)  # 128+128=256
        dec3 = self.dec3(dec3)  # 256 -> 128, H/4

        # Level 2: up2(dec3) (64) + spatially-attended enc2_proj (64) = 128
        dec2 = self.up2(dec3)
        dec2 = F.interpolate(dec2, size=enc2.shape[2:], mode='bilinear', align_corners=True)
        enc2_reduced = self.enc2_proj(enc2)  # 128 -> 64
        enc2_att = self.spatial_att2(enc2_reduced)
        dec2 = torch.cat([dec2, enc2_att], dim=1)  # 64+64=128
        dec2 = self.dec2(dec2)  # 128 -> 64, H/2

        # Level 1: up1(dec2) (64) + enc1 (64) = 128
        dec1 = self.up1(dec2)
        dec1 = F.interpolate(dec1, size=enc1.shape[2:], mode='bilinear', align_corners=True)
        dec1 = torch.cat([dec1, enc1], dim=1)  # 64+64=128
        dec1 = self.dec1(dec1)  # 128 -> 64, HxW

        # Safety net: dec1 already has the input's size, so this should not trigger.
        output = self.out(dec1)
        if output.shape[2:] != x.shape[2:]:
            output = F.interpolate(output, size=x.shape[2:], mode='bilinear', align_corners=True)

        return output

# In [None]:
# MODELO 3: UNET++ LITE (U-Net++ Optimizada)

class UNetPlusPlusLite(nn.Module):
    """
    U-Net++ Lite - a slimmed-down, U-Net++-inspired encoder/decoder.

    Features:
    - Several decoder nodes per level, named after U-Net++'s X^{i,j} grid
    - Optional deep supervision
    - Fewer channels than the original U-Net++ (48/96/192/384, 768 bottleneck)

    Parameters: ~22M according to the author (vs ~36M for the original U-Net++).

    Args:
        num_classes: Number of output logit channels.
        deep_supervision: If True, ``forward`` returns a list
            ``[main, aux(x1_3), aux(x1_2), aux(x1_1)]`` of logits at input
            resolution. Otherwise it returns a single
            (B, num_classes, H, W) tensor.

    Inputs whose sides are not multiples of 16 are supported: upsampled
    tensors are resized to their skip before concatenation.

    NOTE(review): the skip pathways are not densely nested as in U-Net++.
    Every decoder node concatenates only the encoder feature of its level
    and one upsampled node (e.g. x3_2 does not see x3_1). Making them truly
    nested would change conv input widths and therefore checkpoints.
    """
    def __init__(self, num_classes=3, deep_supervision=False):
        super().__init__()
        self.deep_supervision = deep_supervision

        # Encoder
        self.enc1 = ConvBlock(3, 48, dropout=0.05)
        self.enc2 = ConvBlock(48, 96, dropout=0.1)
        self.enc3 = ConvBlock(96, 192, dropout=0.1)
        self.enc4 = ConvBlock(192, 384, dropout=0.15)
        self.pool = nn.MaxPool2d(2, 2)

        # Bottleneck
        self.bottleneck = ConvBlock(384, 768, dropout=0.2)

        # Decoder nodes. Each upX_Y doubles the resolution and halves the
        # channels; each convX_Y consumes [skip, upsampled] (twice its output width).
        self.up1_0 = nn.ConvTranspose2d(768, 384, 2, stride=2)
        self.conv1_0 = ConvBlock(768, 384, dropout=0.15)

        self.up2_0 = nn.ConvTranspose2d(384, 192, 2, stride=2)
        self.conv2_0 = ConvBlock(384, 192, dropout=0.1)
        self.up1_1 = nn.ConvTranspose2d(384, 192, 2, stride=2)
        self.conv1_1 = ConvBlock(384, 192, dropout=0.1)

        self.up3_0 = nn.ConvTranspose2d(192, 96, 2, stride=2)
        self.conv3_0 = ConvBlock(192, 96, dropout=0.1)
        self.up2_1 = nn.ConvTranspose2d(192, 96, 2, stride=2)
        self.conv2_1 = ConvBlock(192, 96, dropout=0.1)
        self.up1_2 = nn.ConvTranspose2d(192, 96, 2, stride=2)
        self.conv1_2 = ConvBlock(192, 96, dropout=0.1)

        self.up4_0 = nn.ConvTranspose2d(96, 48, 2, stride=2)
        self.conv4_0 = ConvBlock(96, 48, dropout=0.05)
        self.up3_1 = nn.ConvTranspose2d(96, 48, 2, stride=2)
        self.conv3_1 = ConvBlock(96, 48, dropout=0.05)
        self.up2_2 = nn.ConvTranspose2d(96, 48, 2, stride=2)
        self.conv2_2 = ConvBlock(96, 48, dropout=0.05)
        self.up1_3 = nn.ConvTranspose2d(96, 48, 2, stride=2)
        self.conv1_3 = ConvBlock(96, 48, dropout=0.05)

        # Output
        self.out = nn.Conv2d(48, num_classes, 1)

        # Deep-supervision heads. They read the 48-channel, full-resolution
        # nodes x1_3, x1_2 and x1_1, so each auxiliary output has the same
        # size as the main output.
        if deep_supervision:
            self.out1 = nn.Conv2d(48, num_classes, 1)
            self.out2 = nn.Conv2d(48, num_classes, 1)
            self.out3 = nn.Conv2d(48, num_classes, 1)

    @staticmethod
    def _cat_skip(skip, up):
        """Concatenate ``[skip, up]`` on channels, resizing ``up`` to ``skip`` if needed.

        Sizes only differ when an input side is not a multiple of 16
        (MaxPool floors odd sizes). Matching sizes are left untouched.
        """
        if up.shape[2:] != skip.shape[2:]:
            up = F.interpolate(up, size=skip.shape[2:], mode='bilinear', align_corners=True)
        return torch.cat([skip, up], dim=1)

    def forward(self, x):
        # Encoder: each level halves the spatial size.
        x1_0 = self.enc1(x)
        x2_0 = self.enc2(self.pool(x1_0))
        x3_0 = self.enc3(self.pool(x2_0))
        x4_0 = self.enc4(self.pool(x3_0))
        x5_0 = self.bottleneck(self.pool(x4_0))

        # Path that produces the main output: x5_0 -> x4_1 -> x3_2 -> x2_3 -> x1_4.
        x4_1 = self.conv1_0(self._cat_skip(x4_0, self.up1_0(x5_0)))
        x3_2 = self.conv1_1(self._cat_skip(x3_0, self.up1_1(x4_1)))
        x2_3 = self.conv1_2(self._cat_skip(x2_0, self.up1_2(x3_2)))
        x1_4 = self.conv1_3(self._cat_skip(x1_0, self.up1_3(x2_3)))
        output = self.out(x1_4)

        if not self.deep_supervision:
            # The remaining nodes only feed the auxiliary heads. Skipping
            # them leaves the main output unchanged and saves their compute.
            return output

        x3_1 = self.conv2_0(self._cat_skip(x3_0, self.up2_0(x4_1)))
        x2_1 = self.conv3_0(self._cat_skip(x2_0, self.up3_0(x3_1)))
        x2_2 = self.conv2_1(self._cat_skip(x2_0, self.up2_1(x3_2)))
        x1_1 = self.conv4_0(self._cat_skip(x1_0, self.up4_0(x2_1)))
        x1_2 = self.conv3_1(self._cat_skip(x1_0, self.up3_1(x2_2)))
        x1_3 = self.conv2_2(self._cat_skip(x1_0, self.up2_2(x2_3)))
        return [output, self.out1(x1_3), self.out2(x1_2), self.out3(x1_1)]

# In [None]:
# DATASET

class VertebrasDataset(Dataset):
    """Vertebra segmentation dataset read from a COCO-style JSON export.

    Each item is an ``(image, mask)`` pair:
    - image: float tensor from ToTensor + Normalize (dataset mean/std below);
    - mask: int64 tensor of class ids (0=Background, 1=T1, 2=V).
    Both are resized to ``target_size``, given as (width, height) because it
    is passed to PIL and OpenCV resize.

    Expected layout under ``base_path``:
        Anotaciones vértebras/<json_filename>
        Radiografías/<image files>

    Args:
        base_path: Dataset root folder.
        json_filename: Annotation file inside ``Anotaciones vértebras``.
        target_size: Output (width, height).
    """

    def __init__(self, base_path, json_filename='coco_anotaciones_actualizadas_23sep.json',
                 target_size=(256, 256)):

        self.base_path = Path(base_path)
        # NOTE(review): these accented folder names were previously
        # mis-encoded ('v√©rtebras', 'Radiograf√≠as'); confirm they match
        # the on-disk names.
        self.json_path = self.base_path / "Anotaciones vértebras" / json_filename
        self.radiografias_path = self.base_path / "Radiografías"
        self.target_size = target_size

        with open(self.json_path, 'r', encoding='utf-8') as f:
            self.coco_data = json.load(f)

        self.class_names = ['Background', 'T1', 'V']
        # 'F' and 'background' both map to 0, so their shapes overwrite any
        # label painted earlier at the same pixels (annotation order matters).
        self.name_to_class = {'F': 0, 'background': 0, 'T1': 1, 'V': 2}
        self.samples = self._preparar_samples()

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.399637, 0.400040, 0.392532],
                std=[0.212403, 0.211738, 0.207753]
            )
        ])

    def _preparar_samples(self):
        """Pair every annotated image with its annotations.

        Images without annotations are skipped, as are images with no file
        name or whose file cannot be found under ``radiografias_path``. The
        file is looked up first by relative path, then by bare file name.
        """
        anns_por_imagen = defaultdict(list)
        for ann in self.coco_data['annotations']:
            img_id = ann.get('image_id')
            if img_id is not None:
                anns_por_imagen[img_id].append(ann)

        samples = []
        for img_info in self.coco_data['images']:
            img_id = img_info.get('id')
            if img_id not in anns_por_imagen:
                continue

            # 'toras_path' is a fallback key, presumably written by the
            # annotation export tool -- verify against the JSON.
            file_name = img_info.get('file_name') or img_info.get('toras_path', '') or ''
            # Keep the join relative to radiografias_path.
            file_name = file_name.lstrip('/')
            if not file_name:
                # An empty name would resolve to the image folder itself.
                continue

            img_path = self.radiografias_path / file_name
            if not img_path.is_file():
                img_path = self.radiografias_path / Path(file_name).name

            if img_path.is_file():
                samples.append({
                    'image_id': img_id,
                    'image_path': str(img_path),
                    'annotations': anns_por_imagen[img_id]
                })

        return samples

    def _parsear_segmentacion(self, segmentation):
        """Return the polygons of a COCO ``segmentation`` field as (N, 2) int32 arrays.

        Only the polygon format is handled: a list of flat
        [x1, y1, x2, y2, ...] lists. Anything else (e.g. an RLE dict) yields
        no polygons. Polygons with fewer than 3 points are dropped. A
        dangling odd coordinate is ignored instead of making ``reshape`` fail.
        """
        poligonos = []
        if not segmentation or not isinstance(segmentation, list):
            return poligonos

        for seg_item in segmentation:
            if not (isinstance(seg_item, list) and seg_item):
                continue
            if not isinstance(seg_item[0], (int, float)):
                continue
            n_coords = len(seg_item) - len(seg_item) % 2
            if n_coords < 6:
                continue
            coords = np.array(seg_item[:n_coords]).reshape(-1, 2)
            poligonos.append(coords.astype(np.int32))
        return poligonos

    def _crear_mascara(self, annotations, orig_height, orig_width):
        """Rasterise annotations into an (orig_height, orig_width) uint8 class-id mask.

        Annotations are painted in order, so later ones overwrite earlier
        ones where they overlap. Polygons are used when present; otherwise
        the ``bbox`` ([x, y, w, h]) is filled after clipping it to the image.
        Annotations whose ``name`` is missing or unknown are skipped.
        """
        mask = np.zeros((orig_height, orig_width), dtype=np.uint8)

        for ann in annotations:
            name = str(ann.get('name') or '').strip()
            class_id = self.name_to_class.get(name)
            if class_id is None:
                continue

            poligonos = self._parsear_segmentacion(ann.get('segmentation'))
            for poly in poligonos:
                cv2.fillPoly(mask, [poly], class_id)

            if not poligonos and 'bbox' in ann:
                x, y, w, h = [int(v) for v in ann['bbox']]
                if w <= 0 or h <= 0:
                    continue
                # Clip boxes that start outside the image instead of dropping them.
                x1, y1 = max(x, 0), max(y, 0)
                x2, y2 = min(x + w, orig_width), min(y + h, orig_height)
                if x2 > x1 and y2 > y1:
                    mask[y1:y2, x1:x2] = class_id

        return mask

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        # Context manager so the underlying file handle is always closed.
        with Image.open(sample['image_path']) as img:
            image = img.convert('RGB')
        orig_width, orig_height = image.size
        mask = self._crear_mascara(sample['annotations'], orig_height, orig_width)

        # target_size is (width, height) for both PIL and OpenCV. Nearest
        # neighbour keeps mask values as valid class ids.
        image = image.resize(self.target_size, Image.BILINEAR)
        mask = cv2.resize(mask, self.target_size, interpolation=cv2.INTER_NEAREST)

        image = self.transform(image)
        mask = torch.from_numpy(mask).long()

        return image, mask

# In [None]:
# LOSS Y MÉTRICAS

class DiceLoss(nn.Module):
    """Soft multi-class Dice loss computed from logits.

    ``pred`` is softmaxed over dim 1. ``target`` holds integer class ids and
    is one-hot encoded to match. Dice is computed per sample and per class
    (background included) by summing over dims 2 and 3. The loss is
    ``1 - mean(dice)``.

    Args:
        smooth: Added to numerator and denominator, which keeps the ratio
            defined for classes absent from both maps.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        n_classes = pred.shape[1]
        probs = pred.softmax(dim=1)
        onehot = F.one_hot(target, num_classes=n_classes).permute(0, 3, 1, 2).float()

        spatial = (2, 3)
        overlap = torch.sum(probs * onehot, dim=spatial)
        total = torch.sum(probs, dim=spatial) + torch.sum(onehot, dim=spatial)

        dice_per_class = (overlap * 2.0 + self.smooth) / (total + self.smooth)
        return 1.0 - dice_per_class.mean()


class FocalLoss(nn.Module):
    """Multi-class focal loss: the mean of ``alpha * (1 - p_t) ** gamma * CE``.

    ``p_t``, the probability assigned to the true class, is recovered from
    the per-element cross-entropy as ``exp(-CE)``.

    Args:
        alpha: Constant scale applied to every element.
        gamma: Focusing exponent; 0 reduces to (scaled) cross-entropy.
    """

    def __init__(self, alpha=1, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred, target):
        per_pixel_ce = F.cross_entropy(pred, target, reduction='none')
        p_true = torch.exp(-per_pixel_ce)
        modulating = (1 - p_true) ** self.gamma
        return (self.alpha * modulating * per_pixel_ce).mean()


class CombinedLoss(nn.Module):
    """Weighted sum ``focal_weight * FocalLoss + dice_weight * DiceLoss``.

    Both terms use their default hyper-parameters and receive the same
    ``(pred, target)`` pair.

    Args:
        focal_weight: Weight of the focal term.
        dice_weight: Weight of the Dice term.
    """

    def __init__(self, focal_weight=0.3, dice_weight=0.7):
        super().__init__()
        self.focal_weight = focal_weight
        self.dice_weight = dice_weight
        self.focal_loss = FocalLoss()
        self.dice_loss = DiceLoss()

    def forward(self, pred, target):
        focal_term = self.focal_weight * self.focal_loss(pred, target)
        dice_term = self.dice_weight * self.dice_loss(pred, target)
        return focal_term + dice_term


def calcular_metricas_detalladas(pred, target, num_classes, class_names, *, empty_score=0.0):
    """Compute per-class IoU/Dice plus summary metrics for a batch of predictions.

    Args:
        pred: Scores with classes on dim 1; the predicted class is the
            argmax over dim 1.
        target: Integer class ids with the shape of ``pred`` minus dim 1.
        num_classes: Classes scored, ids ``0 .. num_classes - 1``.
        class_names: Names used in the result keys. Classes past the end of
            this list are named ``class_<id>``.
        empty_score: Score given to a class absent from both prediction and
            target. The default 0.0 keeps the historical behaviour; 1.0
            treats "correctly predicted as absent" as a perfect score.

    Returns:
        dict containing:
        - ``iou_<name>`` and ``dice_<name>`` for every class;
        - ``mean_iou`` and ``mean_dice`` over classes 1.. (background,
          class 0, excluded; 0.0 if there are none);
        - pixel ``accuracy``.
    """
    pred_classes = torch.argmax(pred, dim=1)
    metricas = {}
    names = []

    for c in range(num_classes):
        pred_c = pred_classes == c
        target_c = target == c
        intersection = (pred_c & target_c).sum().float()
        union = (pred_c | target_c).sum().float()

        if union > 0:
            iou = (intersection / union).item()
            dice = (2 * intersection / (pred_c.sum() + target_c.sum())).item()
        else:
            iou = dice = float(empty_score)

        class_name = class_names[c] if c < len(class_names) else f'class_{c}'
        names.append(class_name)
        metricas[f'iou_{class_name}'] = iou
        metricas[f'dice_{class_name}'] = dice

    # The means skip background (class 0). They reuse the exact names chosen
    # above, so classes without an entry in class_names work as well.
    ious_sin_bg = [metricas[f'iou_{n}'] for n in names[1:]]
    dices_sin_bg = [metricas[f'dice_{n}'] for n in names[1:]]

    metricas['mean_iou'] = float(np.mean(ious_sin_bg)) if ious_sin_bg else 0.0
    metricas['mean_dice'] = float(np.mean(dices_sin_bg)) if dices_sin_bg else 0.0
    metricas['accuracy'] = (pred_classes == target).float().mean().item()

    return metricas


# In [None]:
# FUNCIONES DE ENTRENAMIENTO

def train_epoch(model, dataloader, criterion, optimizer, device, num_classes, class_names):
    """Run one training epoch and return ``(avg_loss, avg_metrics)``.

    Batches of size 1 are skipped. Presumably this is because BatchNorm in
    training mode rejects a single value per channel, which is what a
    1x1-pooled branch (like ASPP's global pool) would produce for one sample.
    Gradients are clipped to an L2 norm of 1.0 before each optimizer step.

    Averages are taken over the batches actually processed, so skipped
    batches no longer deflate the reported loss and metrics.

    Returns:
        ``(avg_loss, avg_metrics)``, or ``(0.0, {})`` when no batch was
        processed (e.g. an empty loader).
    """
    model.train()
    total_loss = 0.0
    total_metrics = defaultdict(float)
    n_batches = 0

    for images, masks in dataloader:
        if images.size(0) == 1:
            continue

        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        n_batches += 1
        total_loss += loss.item()
        metricas = calcular_metricas_detalladas(outputs.detach(), masks, num_classes, class_names)
        for k, v in metricas.items():
            total_metrics[k] += v

    if n_batches == 0:
        return 0.0, {}

    avg_loss = total_loss / n_batches
    avg_metrics = {k: v / n_batches for k, v in total_metrics.items()}
    return avg_loss, avg_metrics


def validate(model, dataloader, criterion, device, num_classes, class_names):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    The model is put in eval mode and is not switched back afterwards. Loss
    and metrics are averaged per batch, counting the batches actually
    iterated, so loaders without a reliable ``len`` also work.

    Returns:
        ``(avg_loss, avg_metrics)``, or ``(0.0, {})`` for an empty loader
        instead of raising ZeroDivisionError.
    """
    model.eval()
    total_loss = 0.0
    total_metrics = defaultdict(float)
    n_batches = 0

    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            total_loss += criterion(outputs, masks).item()

            metricas = calcular_metricas_detalladas(outputs, masks, num_classes, class_names)
            for k, v in metricas.items():
                total_metrics[k] += v
            n_batches += 1

    if n_batches == 0:
        return 0.0, {}

    avg_loss = total_loss / n_batches
    avg_metrics = {k: v / n_batches for k, v in total_metrics.items()}
    return avg_loss, avg_metrics

# In [None]:
# VISUALIZACIONES

def visualizar_comparacion_3modelos(models, model_names, dataset, device, class_names, num_samples=3, seed=42):
    """Save a grid comparing each model's prediction with the ground truth.

    There is one row per sampled image. The columns are the de-normalised
    image, the ground-truth mask, then one predicted mask per model, titled
    with its mean IoU over non-background classes. Despite the name, any
    number of models is supported. The figure is saved as
    ``comparison_3models_seed{seed}.png``.

    Args:
        models: Models that return logits with classes on dim 1. A
            list/tuple output, as with deep supervision, is also accepted;
            its first item is used.
        model_names: Display names, parallel to ``models``.
        dataset: Indexable dataset returning ``(normalized_image, mask)``.
        device: Device the models live on.
        class_names: Class names; index 0 is background.
        num_samples: Number of random images to show, capped at ``len(dataset)``.
        seed: Seed for choosing images. A private RandomState yields the
            same indices as seeding NumPy's global RNG, without mutating
            global state.
    """
    for model in models:
        model.eval()

    colors = {0: [0, 0, 0], 1: [0, 255, 0], 2: [0, 0, 255]}

    n_rows = min(num_samples, len(dataset))
    if n_rows <= 0:
        print("No hay muestras para visualizar")
        return
    n_cols = 2 + len(models)

    # squeeze=False keeps `axes` 2-D even for a single row.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(22 * n_cols / 5, n_rows * 4), squeeze=False)

    rng = np.random.RandomState(seed)
    indices = rng.choice(len(dataset), n_rows, replace=False)

    mean = torch.tensor([0.399637, 0.400040, 0.392532]).view(3, 1, 1)
    std = torch.tensor([0.212403, 0.211738, 0.207753]).view(3, 1, 1)

    def colorize(label_map):
        """Map a 2-D array of class ids to an RGB uint8 image using `colors`."""
        rgb = np.zeros((*label_map.shape, 3), dtype=np.uint8)
        for class_id, color in colors.items():
            rgb[label_map == class_id] = color
        return rgb

    # Titles avoid emoji: matplotlib's default font (DejaVu Sans) has no
    # glyphs for them. '✓' is kept because DejaVu Sans covers it.
    with torch.no_grad():
        for i, idx in enumerate(indices):
            image, mask = dataset[idx]
            image_input = image.unsqueeze(0).to(device)

            # Undo the dataset's Normalize for display.
            img_np = np.clip((image * std + mean).permute(1, 2, 0).numpy(), 0, 1)

            axes[i, 0].imshow(img_np)
            axes[i, 0].set_title('Original', fontsize=11, fontweight='bold')
            axes[i, 0].axis('off')

            axes[i, 1].imshow(colorize(mask.numpy()))
            axes[i, 1].set_title('✓ Ground Truth', fontsize=11, fontweight='bold')
            axes[i, 1].axis('off')

            for j, (model, name) in enumerate(zip(models, model_names), start=2):
                output = model(image_input)
                if isinstance(output, (list, tuple)):
                    output = output[0]
                pred = output.cpu().squeeze(0)
                pred_classes = torch.argmax(pred, dim=0).numpy()

                metricas = calcular_metricas_detalladas(
                    pred.unsqueeze(0), mask.unsqueeze(0), len(colors), class_names
                )

                axes[i, j].imshow(colorize(pred_classes))
                axes[i, j].set_title(f'{name}\nIoU: {metricas["mean_iou"]:.3f}',
                                     fontsize=10, fontweight='bold')
                axes[i, j].axis('off')

    legend_elements = [
        plt.Rectangle((0, 0), 1, 1, fc=np.array(colors[c]) / 255.0,
                      edgecolor='black', linewidth=2, label=class_names[c])
        for c in range(1, len(class_names))
        if c in colors
    ]
    if legend_elements:
        fig.legend(handles=legend_elements, loc='upper center', ncol=len(legend_elements),
                   fontsize=12, frameon=True, fancybox=True, shadow=True)

    plt.suptitle(f'COMPARACIÓN: {len(models)} MODELOS', fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout(rect=[0, 0, 1, 0.98])

    filename = f'comparison_3models_seed{seed}.png'
    plt.savefig(filename, dpi=200, bbox_inches='tight', facecolor='white')
    plt.close(fig)
    print(f"📸 Comparación guardada: {filename}")


def crear_tabla_final_comparativa(resultados):
    """Save a two-panel comparison figure for the trained models.

    The left panel is a horizontal bar chart of each model's best mean IoU.
    The right panel is a scatter plot of mean IoU against parameter count in
    millions. The figure is written to ``final_comparison_3models.png`` in the
    current working directory.

    Args:
        resultados: list of result dicts with at least the keys ``'name'``,
            ``'iou'`` and ``'params'`` (raw parameter count). Any number of
            models is supported; an empty list only prints a notice.
    """
    if not resultados:
        print("No hay resultados para comparar; no se genera la tabla final.")
        return

    nombres = [r['name'] for r in resultados]
    ious = [r['iou'] for r in resultados]
    params = [r['params'] / 1e6 for r in resultados]

    # Exactly one color per model, cycling the base palette. scatter() requires
    # `c` to have one entry per point, so a fixed 3-color list would break for
    # any other number of models.
    paleta = ['#FF6B6B', '#4ECDC4', '#45B7D1']
    colors = [paleta[i % len(paleta)] for i in range(len(resultados))]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 8))

    # Panel 1: IoU per model
    bars = ax1.barh(nombres, ious, color=colors, edgecolor='black', linewidth=2)
    ax1.set_xlabel('IoU Promedio', fontsize=13, fontweight='bold')
    ax1.set_title('üèÜ Comparaci√≥n de IoU', fontsize=15, fontweight='bold')
    ax1.grid(True, alpha=0.3, axis='x')
    # If every IoU is 0 the limit would be [0, 0] (a singular axis); fall back to 1.
    limite_x = max(ious) * 1.1
    if limite_x <= 0:
        limite_x = 1.0
    ax1.set_xlim([0, limite_x])

    for bar, iou in zip(bars, ious):
        ax1.text(bar.get_width() + 0.005, bar.get_y() + bar.get_height() / 2,
                 f'{iou:.4f}', va='center', ha='left', fontsize=12, fontweight='bold')

    # Panel 2: efficiency (IoU vs. model size)
    ax2.scatter(params, ious, s=500, c=colors, edgecolor='black',
                linewidth=3, alpha=0.8)

    for nombre, n_params, iou in zip(nombres, params, ious):
        ax2.annotate(nombre, (n_params, iou),
                     xytext=(10, 10), textcoords='offset points',
                     fontsize=11, fontweight='bold',
                     bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.4))

    ax2.set_xlabel('Par√°metros (Millones)', fontsize=13, fontweight='bold')
    ax2.set_ylabel('IoU Promedio', fontsize=13, fontweight='bold')
    ax2.set_title('‚öñÔ∏è Eficiencia: IoU vs Tama√±o', fontsize=15, fontweight='bold')
    ax2.grid(True, alpha=0.3)

    fig.suptitle('üìä AN√ÅLISIS COMPARATIVO - 3 MODELOS FINALES',
                 fontsize=17, fontweight='bold')
    fig.tight_layout(rect=[0, 0, 1, 0.97])

    filename = 'final_comparison_3models.png'
    fig.savefig(filename, dpi=200, bbox_inches='tight', facecolor='white')
    # Close this specific figure so memory is released even if another figure
    # became "current" in the meantime.
    plt.close(fig)
    print(f"üìä Tabla final guardada: {filename}")

In [None]:
# MAIN - TRAIN AND COMPARE THE 3 MODELS

def _nombre_checkpoint(nombre_modelo):
    """Return the checkpoint filename used for ``nombre_modelo``.

    Spaces become underscores and every ``+`` becomes ``p``.
    Example: ``'U-Net++ Lite'`` -> ``'U-Netpp_Lite_best.pth'``.
    """
    return f'{nombre_modelo.replace(" ", "_").replace("+", "p")}_best.pth'


def _entrenar_modelo(nombre_modelo, modelo, train_loader, val_loader, device,
                     num_classes, class_names, max_epochs, learning_rate,
                     patience=15):
    """Train one model with early stopping on validation mean IoU.

    Training setup:
    - loss: CombinedLoss (focal 0.3 / dice 0.7)
    - optimizer: AdamW (weight_decay=1e-4)
    - scheduler: ReduceLROnPlateau in 'max' mode on validation mean IoU

    Whenever validation mean IoU improves, the weights are saved to
    ``_nombre_checkpoint(nombre_modelo)`` and kept in memory. Training stops
    after ``patience`` epochs without improvement. Before returning, the model
    is restored to its best weights, so the returned model matches the
    reported metrics.

    Returns:
        dict with keys 'name', 'model', 'iou', 'dice', 'params', 'time',
        'epoch' and 'metrics_per_class' (class name -> {'iou', 'dice'} for
        class indices 1..num_classes-1).
    """
    print("\n" + "="*80)
    print(f"üèóÔ∏è ENTRENANDO: {nombre_modelo}")
    print("="*80)

    total_params = sum(p.numel() for p in modelo.parameters())
    print(f"‚úì Par√°metros: {total_params/1e6:.2f}M ({total_params:,})")

    criterion = CombinedLoss(focal_weight=0.3, dice_weight=0.7)
    optimizer = optim.AdamW(modelo.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=7)

    best_iou = 0
    best_dice = 0
    best_epoch = 0
    best_state = None
    # Defaults keep the summary well-defined (and never inherit another model's
    # values) even if no epoch ever beats the initial IoU of 0.
    best_metrics_per_class = {
        class_names[c]: {'iou': 0.0, 'dice': 0.0} for c in range(1, num_classes)
    }
    patience_counter = 0

    print("\nüöÄ Iniciando entrenamiento...")
    start_time = time.time()

    for epoch in range(max_epochs):
        train_epoch(modelo, train_loader, criterion, optimizer, device, num_classes, class_names)
        _, val_metrics = validate(modelo, val_loader, criterion, device, num_classes, class_names)

        if (epoch + 1) % 10 == 0:
            print(f"Epoca {epoch+1}/{max_epochs} | Val IoU: {val_metrics['mean_iou']:.4f} | "
                  f"Dice: {val_metrics['mean_dice']:.4f}")

        if val_metrics['mean_iou'] > best_iou:
            best_iou = val_metrics['mean_iou']
            best_dice = val_metrics['mean_dice']
            best_epoch = epoch + 1
            patience_counter = 0

            best_metrics_per_class = {
                class_names[c]: {
                    'iou': val_metrics[f'iou_{class_names[c]}'],
                    'dice': val_metrics[f'dice_{class_names[c]}']
                }
                for c in range(1, num_classes)
            }

            # Keep a detached copy: training continues and will overwrite the
            # live parameters after this epoch.
            best_state = {k: v.detach().clone() for k, v in modelo.state_dict().items()}
            torch.save(modelo.state_dict(), _nombre_checkpoint(nombre_modelo))
        else:
            patience_counter += 1

        scheduler.step(val_metrics['mean_iou'])

        if patience_counter >= patience:
            print(f"‚èπÔ∏è  Early stopping en epoca {epoch+1}")
            break

    training_time = time.time() - start_time

    # Restore the best weights so later visualizations use the model that
    # produced the reported IoU, not the last (post-plateau) epoch.
    if best_state is not None:
        modelo.load_state_dict(best_state)

    print(f"\n‚úÖ {nombre_modelo} completado!")
    print(f"   Mejor IoU: {best_iou:.4f} | Dice: {best_dice:.4f} (Epoca {best_epoch})")
    print(f"   Tiempo: {training_time/60:.2f} min")

    return {
        'name': nombre_modelo,
        'model': modelo,
        'iou': best_iou,
        'dice': best_dice,
        'params': total_params,
        'time': training_time,
        'epoch': best_epoch,
        'metrics_per_class': best_metrics_per_class
    }


def main():
    """Interactively train the selected model(s) and print or plot a comparison.

    Prompts for one model (1-3) or all of them (4). For each selected model it:
    - re-seeds all RNGs, so the result does not depend on which other models
      were trained in the same run;
    - trains the model with early stopping;
    - reports the best metrics.
    With more than one model it also writes the comparison figures. The
    dataset root defaults to the original path but can be overridden with the
    ``VERTEBRAS_BASE_PATH`` environment variable.

    Returns:
        List of per-model result dicts (see ``_entrenar_modelo``). It is empty
        when the menu choice is invalid.
    """
    print("\n" + "="*80)
    print("üî¨ COMPARACI√ìN: 3 MODELOS OPTIMIZADOS")
    print("   1. DeepLabV3++ (DeepLab con decoder denso)")
    print("   2. Hybrid++ (Tu modelo optimizado)")
    print("   3. U-Net++ Lite (U-Net++ eficiente)")
    print("="*80)

    # CONFIGURATION
    SEED = 42
    BASE_PATH = os.environ.get("VERTEBRAS_BASE_PATH", r"C:\Users\User\Documents\Proyectofinal")
    BATCH_SIZE = 8
    MAX_EPOCHS = 100
    LEARNING_RATE = 0.0001
    IMAGE_SIZE = 256
    NUM_CLASSES = 3

    # Model factories: models are built lazily, right after re-seeding, so that
    # unselected models are never instantiated.
    modelos_config = {
        '1': ('DeepLabV3++', lambda: DeepLabV3PlusPlus(NUM_CLASSES)),
        '2': ('Hybrid++', lambda: HybridPlusPlus(NUM_CLASSES)),
        '3': ('U-Net++ Lite', lambda: UNetPlusPlusLite(NUM_CLASSES, deep_supervision=False)),
    }

    # Which model(s) to train? One of them, or all.
    TRAIN_MODEL = input("\n¬øQu√© modelo entrenar? (1=DeepLabV3++, 2=Hybrid++, 3=UNet++Lite, 4=Todos): ").strip()
    if TRAIN_MODEL == '4':
        modelos_a_entrenar = list(modelos_config)
    elif TRAIN_MODEL in modelos_config:
        modelos_a_entrenar = [TRAIN_MODEL]
    else:
        # Fail fast, before loading the dataset, instead of crashing later on
        # an empty result list.
        print(f"\nOpcion no valida: {TRAIN_MODEL!r}. Usa 1, 2, 3 o 4.")
        return []

    set_seed(SEED)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nüíª Dispositivo: {device}")

    # DATASET
    print("\nüì¶ Cargando dataset...")
    full_dataset = VertebrasDataset(BASE_PATH, target_size=(IMAGE_SIZE, IMAGE_SIZE))
    class_names = full_dataset.class_names

    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size

    generator = torch.Generator().manual_seed(SEED)
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size], generator=generator
    )

    print(f"‚úì Total: {len(full_dataset)} | Train: {len(train_dataset)} | Val: {len(val_dataset)}")

    def seed_worker(worker_id):
        worker_seed = SEED + worker_id
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    g = torch.Generator()
    g.manual_seed(SEED)

    train_loader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=0, worker_init_fn=seed_worker, generator=g, drop_last=True
    )

    val_loader = DataLoader(
        val_dataset, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=0, worker_init_fn=seed_worker, drop_last=False
    )

    resultados_finales = []

    # TRAIN MODELS
    for modelo_id in modelos_a_entrenar:
        nombre_modelo, construir_modelo = modelos_config[modelo_id]
        # Re-seed the global RNGs and the shuffling generator for every model.
        # Each model then gets the same init stream and batch order whether it
        # is trained alone or as part of option 4, so the comparison is fair.
        set_seed(SEED)
        g.manual_seed(SEED)
        modelo = construir_modelo().to(device)

        resultados_finales.append(_entrenar_modelo(
            nombre_modelo, modelo, train_loader, val_loader, device,
            NUM_CLASSES, class_names, MAX_EPOCHS, LEARNING_RATE
        ))

    if not resultados_finales:
        return resultados_finales

    # FINAL COMPARISON
    print("\n" + "="*80)
    print("üìä RESULTADOS FINALES - COMPARACI√ìN")
    print("="*80)

    print(f"\n{'Modelo':<20} {'IoU':<10} {'Dice':<10} {'Params':<12} {'Tiempo':<12}")
    print(f"{'-'*64}")
    for r in resultados_finales:
        print(f"{r['name']:<20} {r['iou']:<10.4f} {r['dice']:<10.4f} "
              f"{r['params']/1e6:>6.2f}M     {r['time']/60:>6.2f} min")

    # Best model
    mejor = max(resultados_finales, key=lambda x: x['iou'])
    print(f"\nüèÜ GANADOR: {mejor['name']} con IoU {mejor['iou']:.4f}")

    # Most efficient model: IoU per million parameters (scaled by 100)
    eficiencias = [(r['name'], r['iou'] / (r['params'] / 1e6) * 100) for r in resultados_finales]
    mas_eficiente = max(eficiencias, key=lambda x: x[1])
    print(f"‚ö° M√ÅS EFICIENTE: {mas_eficiente[0]} ({mas_eficiente[1]:.2f} eficiencia)")

    # COMPARATIVE VISUALIZATIONS
    if len(resultados_finales) > 1:
        print("\nüì∏ Generando visualizaciones comparativas...")

        modelos = [r['model'] for r in resultados_finales]
        nombres = [r['name'] for r in resultados_finales]

        visualizar_comparacion_3modelos(modelos, nombres, val_dataset, device,
                                        class_names, num_samples=4, seed=SEED)

        crear_tabla_final_comparativa(resultados_finales)

    # DETAILED SUMMARY
    print(f"\n{'='*80}")
    print("üìã AN√ÅLISIS DETALLADO POR MODELO")
    print(f"{'='*80}")

    for r in resultados_finales:
        print(f"\nüî¨ {r['name']}:")
        print(f"   IoU: {r['iou']:.4f} | Dice: {r['dice']:.4f}")
        print(f"   Par√°metros: {r['params']/1e6:.2f}M")
        print(f"   Tiempo: {r['time']/60:.2f} min | Epoca: {r['epoch']}")
        print("   M√©tricas por clase:")
        for vertebra in class_names[1:]:
            iou = r['metrics_per_class'][vertebra]['iou']
            dice = r['metrics_per_class'][vertebra]['dice']
            print(f"      {vertebra}: IoU={iou:.4f} | Dice={dice:.4f}")

    # RECOMMENDATIONS
    print(f"\n{'='*80}")
    print("üí° RECOMENDACIONES")
    print(f"{'='*80}")

    print("\nüéØ CU√ÅNDO USAR CADA MODELO:")
    print("   ‚Ä¢ DeepLabV3++: Balance entre precisi√≥n y eficiencia")
    print("   ‚Ä¢ Hybrid++: M√°xima precisi√≥n con atenci√≥n multi-nivel")
    print("   ‚Ä¢ U-Net++ Lite: Segmentaci√≥n tipo U-Net++ pero m√°s ligera")

    print("\nüìÅ ARCHIVOS GENERADOS:")
    for r in resultados_finales:
        filename = _nombre_checkpoint(r["name"])
        # A checkpoint is only written when IoU improved at least once.
        if os.path.exists(filename):
            print(f"   ‚úì {filename}")
    if len(resultados_finales) > 1:
        print(f"   ‚úì comparison_3models_seed{SEED}.png")
        print("   ‚úì final_comparison_3models.png")

    return resultados_finales

In [None]:
# MAIN EXECUTION
if __name__ == "__main__":
    # Script entry point: run main() between a start banner and an end banner.
    # Ctrl+C and any other exception are reported on stdout rather than
    # propagated, and the CUDA cache is emptied whatever the outcome.
    cohete = "üöÄ" * 40
    fiesta = "üéâ" * 40
    try:
        print("\n" + cohete)
        print("SISTEMA DE COMPARACI√ìN: 3 MODELOS OPTIMIZADOS")
        print(cohete + "\n")

        resultados = main()

        for linea in ("\n" + fiesta, "AN√ÅLISIS COMPLETADO", fiesta + "\n"):
            print(linea)
    except KeyboardInterrupt:
        print("\n\n‚ö†Ô∏è  Proceso interrumpido")
    except Exception as err:
        import traceback
        print(f"\n\n‚ùå Error: {err!s}")
        traceback.print_exc()
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print("\nüßπ Memoria GPU liberada")