In [None]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import json
import numpy as np
import cv2
import random
from pathlib import Path
from PIL import Image
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from collections import defaultdict
import time
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Environment report: confirm the imports succeeded and show the PyTorch build / GPU availability.
print(
    "‚úì Librer√≠as importadas correctamente",
    f"‚úì PyTorch version: {torch.__version__}",
    f"‚úì CUDA disponible: {torch.cuda.is_available()}",
    sep="\n",
)

In [None]:
def set_seed(seed=1111):
    """Seed every random number generator this notebook touches.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator plus the current and all CUDA devices), and switches cuDNN
    to deterministic, non-benchmarking mode.

    ``PYTHONHASHSEED`` is also exported. The interpreter reads that variable at
    startup, so setting it here only influences subprocesses launched afterwards,
    not string hashing in the running kernel.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicit CUDA seeding: current device first, then every visible device.
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False
    os.environ['PYTHONHASHSEED'] = str(seed)
    print(f"‚úì Seed fijado en {seed}")

set_seed(1111)

In [None]:
class ConvBlock(nn.Module):
    """Two stacked ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` stages.

    Spatial size is preserved (3x3 kernels with padding 1); channels go
    ``in_ch -> out_ch -> out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First conv maps in_ch -> out_ch, the second keeps out_ch.
        for src_ch in (in_ch, out_ch):
            stages += [
                nn.Conv2d(src_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # One Sequential named `conv`, so parameters are registered as conv.0 ... conv.5.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x``."""
        return self.conv(x)

class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Five parallel branches read the same input, and each yields ``out_ch`` channels:
      * a 1x1 conv;
      * three 3x3 convs with dilation 6, 12 and 18;
      * a global-average-pool branch.

    The branches are concatenated (``5 * out_ch`` channels) and fused back to
    ``out_ch`` by a 1x1 projection. Spatial size is preserved.

    NOTE(review): the pooled branch applies BatchNorm to a 1x1 map. In training
    mode a batch of size 1 therefore has a single value per channel, which PyTorch
    rejects. This is presumably why the training loop skips size-1 batches.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()

        def conv_bn_relu(kernel, dilation=1):
            # For the 3x3 branches, padding == dilation keeps the spatial size; 1x1 needs none.
            pad = dilation if kernel == 3 else 0
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel, padding=pad, dilation=dilation, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            )

        self.conv1 = conv_bn_relu(1)
        self.conv2 = conv_bn_relu(3, dilation=6)
        self.conv3 = conv_bn_relu(3, dilation=12)
        self.conv4 = conv_bn_relu(3, dilation=18)
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
        self.project = nn.Sequential(
            nn.Conv2d(out_ch * 5, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run all five branches on ``x`` and fuse them with the 1x1 projection."""
        branches = [branch(x) for branch in (self.conv1, self.conv2, self.conv3, self.conv4)]
        # Broadcast the 1x1 pooled descriptor back to the input's spatial size.
        pooled = self.pool(x)
        branches.append(F.interpolate(pooled, size=x.shape[2:], mode='bilinear', align_corners=True))
        return self.project(torch.cat(branches, dim=1))

class PyramidPooling(nn.Module):
    """PSPNet-style pyramid pooling module.

    For each bin size in ``sizes``, the input is processed in three steps:
      1. average-pooled to a ``size x size`` grid;
      2. reduced to ``out_ch`` channels by a 1x1 conv + BN + ReLU;
      3. bilinearly upsampled back to the input resolution.

    The input and all upsampled maps are concatenated and fused into ``out_ch``
    channels by a 3x3 conv + BN + ReLU.
    """

    def __init__(self, in_ch, out_ch, sizes=(1, 2, 3, 6)):
        super().__init__()
        branches = []
        for bins in sizes:
            branches.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(bins),
                nn.Conv2d(in_ch, out_ch, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ))
        self.stages = nn.ModuleList(branches)

        # The input itself plus one out_ch map per bin size.
        fused_ch = in_ch + len(sizes) * out_ch
        self.bottleneck = nn.Sequential(
            nn.Conv2d(fused_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Concatenate ``x`` with its upsampled pooled summaries and fuse them."""
        target_hw = x.shape[2:]
        upsampled = [
            F.interpolate(stage(x), size=target_hw, mode='bilinear', align_corners=True)
            for stage in self.stages
        ]
        return self.bottleneck(torch.cat([x, *upsampled], dim=1))

print("‚úì M√≥dulos base definidos")

In [None]:
class DeepLabV3Plus(nn.Module):
    """DeepLabV3+-style segmenter on a plain four-stage ConvBlock encoder.

    Encoder: four ConvBlocks (64/128/256/512 channels) separated by 2x2 max
    pooling, so the deepest map is at 1/8 of the input resolution. An ASPP head
    (512 -> 256) summarizes it.

    Decoder steps:
      1. upsample the ASPP output to the size of the enc1 features after they are
         projected to 48 channels (enc1 is at full resolution);
      2. concatenate the two;
      3. refine with two 3x3 conv + BN + ReLU stages;
      4. predict ``num_classes`` logits per pixel with a 1x1 conv.
    """

    def __init__(self, in_channels=3, num_classes=3):
        super().__init__()
        self.enc1 = ConvBlock(in_channels, 64)
        self.enc2 = ConvBlock(64, 128)
        self.enc3 = ConvBlock(128, 256)
        self.enc4 = ConvBlock(256, 512)

        self.pool = nn.MaxPool2d(2, 2)
        self.aspp = ASPP(512, 256)

        # Low-level branch: compress enc1 features to 48 channels.
        self.decoder_conv1 = nn.Sequential(
            nn.Conv2d(64, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True),
        )

        # Refinement: (256 ASPP + 48 low-level) -> 256 -> 256.
        refine = []
        for cin in (256 + 48, 256):
            refine += [
                nn.Conv2d(cin, 256, 3, padding=1, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
            ]
        self.decoder_conv2 = nn.Sequential(*refine)

        self.out = nn.Conv2d(256, num_classes, 1)

    def forward(self, x):
        in_hw = x.shape[2:]
        enc1 = self.enc1(x)
        deep = enc1
        # enc2..enc4, each preceded by 2x max pooling; only enc1 and enc4 are reused below.
        for encoder in (self.enc2, self.enc3, self.enc4):
            deep = encoder(self.pool(deep))

        context = self.aspp(deep)
        low_level = self.decoder_conv1(enc1)
        context = F.interpolate(context, size=low_level.shape[2:], mode='bilinear', align_corners=True)
        fused = self.decoder_conv2(torch.cat([context, low_level], dim=1))
        # enc1 is already at input resolution, so this resize keeps the size; retained unchanged.
        fused = F.interpolate(fused, size=in_hw, mode='bilinear', align_corners=True)

        return self.out(fused)

class LinkNet(nn.Module):
    """LinkNet-style encoder/decoder with additive skip connections.

    Encoder: four ConvBlocks (64/128/256/512 channels) with 2x2 max pooling
    between them.

    Decoder: each stage is a 3x3 conv + BN + ReLU that halves the channels
    (512 -> 256 -> 128 -> 64). Its output is bilinearly resized to the next
    shallower encoder map and added to it element-wise. A final 64 -> 64 stage
    and a 1x1 conv produce ``num_classes`` logits at the input resolution.
    """

    @staticmethod
    def _conv_bn_relu(cin, cout):
        """Size-preserving 3x3 conv (with bias) + BatchNorm + ReLU."""
        return nn.Sequential(
            nn.Conv2d(cin, cout, 3, padding=1),
            nn.BatchNorm2d(cout),
            nn.ReLU(inplace=True),
        )

    def __init__(self, in_channels=3, num_classes=3):
        super().__init__()
        self.enc1 = ConvBlock(in_channels, 64)
        self.enc2 = ConvBlock(64, 128)
        self.enc3 = ConvBlock(128, 256)
        self.enc4 = ConvBlock(256, 512)

        self.pool = nn.MaxPool2d(2, 2)

        # Output widths match the encoder map each stage is added to.
        self.dec4 = self._conv_bn_relu(512, 256)
        self.dec3 = self._conv_bn_relu(256, 128)
        self.dec2 = self._conv_bn_relu(128, 64)
        self.dec1 = self._conv_bn_relu(64, 64)

        self.out = nn.Conv2d(64, num_classes, 1)

    def forward(self, x):
        in_hw = x.shape[2:]

        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        d = e4
        for stage, skip in ((self.dec4, e3), (self.dec3, e2), (self.dec2, e1)):
            d = F.interpolate(stage(d), size=skip.shape[2:], mode='bilinear', align_corners=True) + skip

        logits = self.out(self.dec1(d))

        if logits.shape[2:] != in_hw:
            logits = F.interpolate(logits, size=in_hw, mode='bilinear', align_corners=True)

        return logits

print("‚úì DeepLabV3+ y LinkNet definidos")

In [None]:
class PSPNet(nn.Module):
    """PSPNet-style segmenter with a U-Net-like decoder.

    Encoder: a four-stage ConvBlock encoder (64/128/256/512 channels, 2x2 max
    pooling between stages, deepest map at 1/8 resolution).

    Pyramid pooling: the deepest map goes through a pyramid pooling module
    (512 -> 128 channels).

    Decoder: three 2x transposed-conv stages climb back to full resolution. Each
    stage concatenates the matching encoder output and refines the result with a
    ConvBlock. A 1x1 conv then produces ``num_classes`` logits.

    When the input side is not divisible by 8, floor division in max pooling
    leaves a transposed-conv output one pixel away from its skip. In that case it
    is bilinearly resized to the skip before concatenation, so any input size
    works instead of failing in ``torch.cat``.
    """

    def __init__(self, in_channels=3, num_classes=3):
        super().__init__()

        self.enc1 = ConvBlock(in_channels, 64)
        self.enc2 = ConvBlock(64, 128)
        self.enc3 = ConvBlock(128, 256)
        self.enc4 = ConvBlock(256, 512)

        self.pool = nn.MaxPool2d(2, 2)
        self.psp = PyramidPooling(512, 128, sizes=(1, 2, 3, 6))

        self.up3 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec3 = ConvBlock(320, 256)   # 64 (up3) + 256 (e3)

        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = ConvBlock(256, 128)   # 128 (up2) + 128 (e2)

        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = ConvBlock(128, 64)    # 64 (up1) + 64 (e1)

        self.out = nn.Conv2d(64, num_classes, 1)

    @staticmethod
    def _merge(upsampled, skip):
        """Concatenate ``upsampled`` with ``skip`` on channels, resizing it to ``skip`` first if needed."""
        if upsampled.shape[2:] != skip.shape[2:]:
            upsampled = F.interpolate(upsampled, size=skip.shape[2:], mode='bilinear', align_corners=True)
        return torch.cat([upsampled, skip], dim=1)

    def forward(self, x):
        size = x.shape[2:]

        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        psp_out = self.psp(e4)

        d3 = self.dec3(self._merge(self.up3(psp_out), e3))
        d2 = self.dec2(self._merge(self.up2(d3), e2))
        d1 = self.dec1(self._merge(self.up1(d2), e1))

        out = self.out(d1)
        # d1 is aligned to e1 (input resolution); kept as a safety net.
        if out.shape[2:] != size:
            out = F.interpolate(out, size=size, mode='bilinear', align_corners=True)

        return out

class UNetPlusPlus(nn.Module):
    """U-Net++ (nested U-Net) with a five-level ConvBlock encoder (64..1024 channels).

    Node ``x{i}_{j}`` lives at depth ``i`` (2x max pooling per level) and is the
    ``j``-th node of its row:
      * ``x{i}_0`` are the encoder nodes;
      * every ``x{i}_{j>0}`` concatenates all earlier nodes of its row with the
        upsampled ``x{i+1}_{j-1}`` and applies a ConvBlock.

    Logits come from ``x0_4`` through a 1x1 conv.

    If 2x upsampling does not land exactly on a row's spatial size (input sides
    not divisible by 16), the upsampled map is resized to that row. Such inputs
    are handled instead of failing in ``torch.cat``.
    """

    def __init__(self, in_channels=3, num_classes=3):
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)

        # Encoder column (j = 0).
        self.conv0_0 = ConvBlock(in_channels, 64)
        self.conv1_0 = ConvBlock(64, 128)
        self.conv2_0 = ConvBlock(128, 256)
        self.conv3_0 = ConvBlock(256, 512)
        self.conv4_0 = ConvBlock(512, 1024)

        # Nested nodes: input = j same-row maps + one upsampled map from the row below.
        self.conv0_1 = ConvBlock(64 + 128, 64)
        self.conv1_1 = ConvBlock(128 + 256, 128)
        self.conv2_1 = ConvBlock(256 + 512, 256)
        self.conv3_1 = ConvBlock(512 + 1024, 512)

        self.conv0_2 = ConvBlock(64 * 2 + 128, 64)
        self.conv1_2 = ConvBlock(128 * 2 + 256, 128)
        self.conv2_2 = ConvBlock(256 * 2 + 512, 256)

        self.conv0_3 = ConvBlock(64 * 3 + 128, 64)
        self.conv1_3 = ConvBlock(128 * 3 + 256, 128)

        self.conv0_4 = ConvBlock(64 * 4 + 128, 64)

        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.out = nn.Conv2d(64, num_classes, 1)

    def _up_to(self, x, ref):
        """2x-upsample ``x`` and, if needed, resize it to ``ref``'s spatial size."""
        y = self.up(x)
        if y.shape[2:] != ref.shape[2:]:
            y = F.interpolate(y, size=ref.shape[2:], mode='bilinear', align_corners=True)
        return y

    def forward(self, x):
        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self._up_to(x1_0, x0_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self._up_to(x2_0, x1_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self._up_to(x1_1, x0_0)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self._up_to(x3_0, x2_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self._up_to(x2_1, x1_0)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self._up_to(x1_2, x0_0)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self._up_to(x4_0, x3_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self._up_to(x3_1, x2_0)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self._up_to(x2_2, x1_0)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self._up_to(x1_3, x0_0)], 1))

        return self.out(x0_4)

print("‚úì PSPNet y U-Net++ definidos")

In [None]:
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel gate.

    The block works in three steps:
      1. global-average-pool each channel into a ``(B, C)`` descriptor;
      2. pass it through a bias-free bottleneck MLP
         (C -> C // reduction -> C, ReLU in between, Hardsigmoid at the end);
      3. multiply every input channel by the resulting gate.
    """

    def __init__(self, channels, reduction=4):
        super().__init__()
        hidden = channels // reduction
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Hardsigmoid(inplace=True),
        )

    def forward(self, x):
        batch, chans, _, _ = x.shape
        squeezed = self.pool(x).view(batch, chans)
        gate = self.fc(squeezed).view(batch, chans, 1, 1)
        return x * gate

class MobileNetV3Block(nn.Module):
    """MobileNetV3 inverted-residual block.

    ``self.conv`` holds, in order:
      * an optional 1x1 expansion ``in_channels -> exp_size`` + BN + activation
        (omitted when ``exp_size == in_channels``);
      * a depthwise ``kernel_size`` conv with ``stride`` + BN + activation;
      * an optional SEBlock on the expanded channels (``use_se``);
      * a 1x1 linear projection ``exp_size -> out_channels`` + BN (no activation).

    ``use_hs`` selects Hardswish instead of ReLU as the activation. The input is
    added back when the block keeps both resolution and width
    (``stride == 1`` and ``in_channels == out_channels``).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, exp_size, use_se=False, use_hs=False):
        super().__init__()
        self.stride = stride
        self.use_residual = stride == 1 and in_channels == out_channels

        act = nn.Hardswish if use_hs else nn.ReLU

        expand = []
        if exp_size != in_channels:
            expand = [
                nn.Conv2d(in_channels, exp_size, 1, bias=False),
                nn.BatchNorm2d(exp_size),
                act(inplace=True),
            ]

        depthwise = [
            nn.Conv2d(exp_size, exp_size, kernel_size, stride=stride,
                      padding=kernel_size // 2, groups=exp_size, bias=False),
            nn.BatchNorm2d(exp_size),
            act(inplace=True),
        ]
        squeeze = [SEBlock(exp_size)] if use_se else []
        project = [
            nn.Conv2d(exp_size, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]

        self.conv = nn.Sequential(*expand, *depthwise, *squeeze, *project)

    def forward(self, x):
        y = self.conv(x)
        return x + y if self.use_residual else y

class MobileNetV3SmallEncoder(nn.Module):
    """MobileNetV3-Small-style encoder returning multi-scale skip features.

    Returns ``[x1, x2, x3, x5]``:
      * ``x1``: 16 channels at 1/4 resolution;
      * ``x2``: 24 channels at 1/8;
      * ``x3``: 40 channels at 1/16;
      * ``x5``: 96 channels at 1/32.

    The stage-4 output (48 channels at 1/16) only feeds stage 5 and is not returned.
    """

    @staticmethod
    def _stage(*specs):
        """Sequential of MobileNetV3Blocks from (in, out, kernel, stride, exp, use_se, use_hs) tuples."""
        return nn.Sequential(*[MobileNetV3Block(*spec) for spec in specs])

    def __init__(self, in_channels=3):
        super().__init__()

        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.Hardswish(inplace=True),
        )

        #                          in  out  k  s  exp   SE     hswish
        self.stage1 = self._stage((16, 16, 3, 2, 16, True, False))
        self.stage2 = self._stage((16, 24, 3, 2, 72, False, False),
                                  (24, 24, 3, 1, 88, False, False))
        self.stage3 = self._stage((24, 40, 5, 2, 96, True, True),
                                  (40, 40, 5, 1, 240, True, True),
                                  (40, 40, 5, 1, 240, True, True))
        self.stage4 = self._stage((40, 48, 5, 1, 120, True, True),
                                  (48, 48, 5, 1, 144, True, True))
        self.stage5 = self._stage((48, 96, 5, 2, 288, True, True),
                                  (96, 96, 5, 1, 576, True, True))

    def forward(self, x):
        x1 = self.stage1(self.stem(x))      # stem 1/2 -> 1/4
        x2 = self.stage2(x1)                # 1/8
        x3 = self.stage3(x2)                # 1/16
        x5 = self.stage5(self.stage4(x3))   # stage4 stays at 1/16, stage5 -> 1/32

        return [x1, x2, x3, x5]

class SimpleUNetDecoder(nn.Module):
    """Lightweight U-Net decoder for the MobileNetV3-Small encoder features.

    Expects ``features = [x1, x2, x3, x5]``, ordered shallow to deep, whose
    channel counts are ``encoder_channels``.

    Three stages each:
      1. upsample 2x with a transposed conv;
      2. concatenate the matching skip;
      3. fuse with conv + BN + ReLU (64, 32 and 16 channels).

    A final 2x transposed conv plus conv + BN + ReLU yields a 16-channel map at
    twice ``x1``'s resolution.

    If an upsampled map does not exactly match its skip's spatial size (input
    sides that are not a multiple of 32), it is bilinearly resized to the skip
    before concatenation instead of failing in ``torch.cat``.
    """

    def __init__(self, encoder_channels=(16, 24, 40, 96)):
        # Tuple default avoids a shared mutable list; any indexable of 4 ints works.
        super().__init__()

        # x5 -> x3 resolution, fused to 64 channels.
        self.up4 = nn.ConvTranspose2d(encoder_channels[3], 64, 2, stride=2)
        self.dec4 = nn.Sequential(
            nn.Conv2d(64 + encoder_channels[2], 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # -> x2 resolution, 32 channels.
        self.up3 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.dec3 = nn.Sequential(
            nn.Conv2d(32 + encoder_channels[1], 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )

        # -> x1 resolution, 16 channels.
        self.up2 = nn.ConvTranspose2d(32, 16, 2, stride=2)
        self.dec2 = nn.Sequential(
            nn.Conv2d(16 + encoder_channels[0], 16, 3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True)
        )

        # One more 2x step without a skip connection.
        self.up1 = nn.ConvTranspose2d(16, 16, 2, stride=2)
        self.final = nn.Sequential(
            nn.Conv2d(16, 16, 3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _concat_skip(upsampled, skip):
        """Concatenate ``upsampled`` and ``skip`` on channels, resizing ``upsampled`` to ``skip`` if needed."""
        if upsampled.shape[2:] != skip.shape[2:]:
            upsampled = F.interpolate(upsampled, size=skip.shape[2:], mode='bilinear', align_corners=True)
        return torch.cat([upsampled, skip], dim=1)

    def forward(self, features):
        x1, x2, x3, x5 = features

        d4 = self.dec4(self._concat_skip(self.up4(x5), x3))
        d3 = self.dec3(self._concat_skip(self.up3(d4), x2))
        d2 = self.dec2(self._concat_skip(self.up2(d3), x1))

        return self.final(self.up1(d2))

class ConvClassifier(nn.Module):
    """Per-pixel classification head.

    A size-preserving 3x3 conv + BN + ReLU refinement that keeps ``in_channels``,
    followed by a 1x1 conv producing ``num_classes`` logits.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        refine = [
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
        ]
        to_logits = nn.Conv2d(in_channels, num_classes, 1)
        self.classifier = nn.Sequential(*refine, to_logits)

    def forward(self, x):
        return self.classifier(x)

class MobileNetV3UNet(nn.Module):
    """MobileNetV3-Small encoder + simple U-Net decoder + convolutional classifier head.

    Returns ``num_classes`` logits per pixel at the input resolution.
    """

    def __init__(self, in_channels=3, num_classes=3):
        super().__init__()
        self.encoder = MobileNetV3SmallEncoder(in_channels)
        # Channel counts of the encoder outputs [x1, x2, x3, x5].
        self.decoder = SimpleUNetDecoder(encoder_channels=[16, 24, 40, 96])
        self.classifier = ConvClassifier(16, num_classes)

    def forward(self, x):
        target_hw = x.shape[2:]

        # encoder features -> decoded 16-channel map -> class logits
        logits = self.classifier(self.decoder(self.encoder(x)))

        # Resize back to the input resolution only when the decoder output differs.
        if logits.shape[2:] == target_hw:
            return logits
        return F.interpolate(logits, size=target_hw, mode='bilinear', align_corners=True)

print("‚úì MobileNetV3-UNet definido")

In [None]:
class MobileViTBlock(nn.Module):
    """Simplified MobileViT block: local conv, transformer over pixels, 1x1 fusion, residual.

    The block works in four steps:
      1. a depthwise 3x3 conv and a pointwise conv (each + BN + SiLU) map
         ``in_channels`` to ``transformer_dim``;
      2. every spatial position becomes one token, and ``num_layers``
         TransformerEncoderLayers apply self-attention over all H*W tokens (no
         patch unfolding and no positional encoding in this version);
      3. a 1x1 conv + BN + SiLU projects back to ``in_channels``;
      4. the result is added to the input.
    """

    def __init__(self, in_channels, transformer_dim, num_heads=4, num_layers=2):
        super().__init__()

        self.local_rep = nn.Sequential(
            # depthwise 3x3
            nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels),
            nn.BatchNorm2d(in_channels),
            nn.SiLU(inplace=True),
            # pointwise projection to the transformer width
            nn.Conv2d(in_channels, transformer_dim, 1),
            nn.BatchNorm2d(transformer_dim),
            nn.SiLU(inplace=True),
        )

        encoder_layers = []
        for _ in range(num_layers):
            encoder_layers.append(nn.TransformerEncoderLayer(
                d_model=transformer_dim,
                nhead=num_heads,
                dim_feedforward=transformer_dim * 2,
                dropout=0.0,
                activation='gelu',
                batch_first=True,
            ))
        self.transformer = nn.ModuleList(encoder_layers)

        self.fusion = nn.Sequential(
            nn.Conv2d(transformer_dim, in_channels, 1),
            nn.BatchNorm2d(in_channels),
            nn.SiLU(inplace=True),
        )

    def forward(self, x):
        feat = self.local_rep(x)
        batch, dim, height, width = feat.shape

        tokens = feat.flatten(2).transpose(1, 2)  # (B, H*W, dim)
        for layer in self.transformer:
            tokens = layer(tokens)

        grid = tokens.transpose(1, 2).reshape(batch, dim, height, width)
        return x + self.fusion(grid)

class MobileViTXSmallEncoder(nn.Module):
    """MobileViT-XSmall-style encoder built from MobileNetV3 blocks and MobileViT blocks.

    Total stride is 32. Returns ``[x1, x2, x3, x5]``, shallow to deep:
      * ``x1``: 24 channels at 1/4 resolution;
      * ``x2``: 32 channels at 1/8, after a MobileViT block;
      * ``x3``: 48 channels at 1/16;
      * ``x5``: 80 channels at 1/32, after a MobileViT block at 1/32.

    For a 256x256 input these maps are 64, 32, 16 and 8 pixels per side.
    """

    def __init__(self, in_channels=3):
        super().__init__()
        block = MobileNetV3Block

        # 1 -> 1/2
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.SiLU(inplace=True),
        )

        # 1/2 -> 1/4
        self.stage1 = nn.Sequential(
            block(16, 16, 3, 1, 32, use_se=False, use_hs=False),
            block(16, 24, 3, 2, 48, use_se=False, use_hs=False),
        )

        # 1/4 -> 1/8, then global mixing with attention
        self.stage2_conv = block(24, 32, 3, 2, 64, use_se=False, use_hs=False)
        self.stage2_vit = MobileViTBlock(32, transformer_dim=64, num_heads=4, num_layers=2)

        # 1/8 -> 1/16
        self.stage3 = nn.Sequential(
            block(32, 48, 3, 2, 96, use_se=True, use_hs=True),
            block(48, 48, 3, 1, 128, use_se=True, use_hs=True),
        )

        # 1/16 -> 1/32, then global mixing with attention
        self.stage4_conv = block(48, 64, 3, 2, 160, use_se=True, use_hs=True)
        self.stage4_vit = MobileViTBlock(64, transformer_dim=80, num_heads=4, num_layers=2)

        # stays at 1/32, widens to 80 channels
        self.stage5 = nn.Sequential(
            block(64, 80, 3, 1, 256, use_se=True, use_hs=True),
        )

    def forward(self, x):
        x1 = self.stage1(self.stem(x))
        x2 = self.stage2_vit(self.stage2_conv(x1))
        x3 = self.stage3(x2)
        x5 = self.stage5(self.stage4_vit(self.stage4_conv(x3)))

        return [x1, x2, x3, x5]

class SimplifiedDecoder(nn.Module):
    """Decoder for the MobileViT-XSmall encoder features.

    Expects ``features = [x1, x2, x3, x5]``, shallow to deep, whose channel
    counts are ``encoder_channels``. The concat widths are derived from
    ``encoder_channels``; with the defaults they are 96, 64 and 48.

    Three stages each:
      1. upsample 2x (bilinear + 3x3 conv + BN + ReLU, to 48 / 32 / 24 channels);
      2. resize to the matching skip if sizes differ;
      3. concatenate the skip and fuse with a 3x3 conv + BN + ReLU.

    Two further 2x upsampling stages end in a 16-channel map at 4x ``x1``'s
    resolution.
    """

    def __init__(self, encoder_channels=(24, 32, 48, 80)):
        super().__init__()
        skip1_ch, skip2_ch, skip3_ch, deep_ch = tuple(encoder_channels)[:4]

        # Creation order (up4, dec4, up3, dec3, up2, dec2, up1, final_up) kept stable.
        self.up4 = self._upsample_block(deep_ch, 48)
        self.dec4 = self._fuse_block(48 + skip3_ch, 48)
        self.up3 = self._upsample_block(48, 32)
        self.dec3 = self._fuse_block(32 + skip2_ch, 32)
        self.up2 = self._upsample_block(32, 24)
        self.dec2 = self._fuse_block(24 + skip1_ch, 24)
        self.up1 = self._upsample_block(24, 16)
        self.final_up = self._upsample_block(16, 16)

    @staticmethod
    def _upsample_block(cin, cout):
        """2x bilinear upsampling followed by a 3x3 conv + BN + ReLU."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(cin, cout, 3, padding=1),
            nn.BatchNorm2d(cout),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _fuse_block(cin, cout):
        """3x3 conv + BN + ReLU applied after concatenating a skip connection."""
        return nn.Sequential(
            nn.Conv2d(cin, cout, 3, padding=1),
            nn.BatchNorm2d(cout),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _align(y, ref):
        """Bilinearly resize ``y`` to ``ref``'s spatial size when they differ."""
        if y.shape[2:] != ref.shape[2:]:
            y = F.interpolate(y, size=ref.shape[2:], mode='bilinear', align_corners=True)
        return y

    def forward(self, features):
        x1, x2, x3, x5 = features

        d = x5
        for up, fuse, skip in ((self.up4, self.dec4, x3),
                               (self.up3, self.dec3, x2),
                               (self.up2, self.dec2, x1)):
            d = fuse(torch.cat([self._align(up(d), skip), skip], dim=1))

        # Two skip-free 2x steps: x1 resolution -> 2x -> 4x.
        return self.final_up(self.up1(d))

class MobileViTXSmall(nn.Module):
    """MobileViT-XSmall encoder + simplified decoder + convolutional classifier head.

    Returns ``num_classes`` logits per pixel at the input resolution.
    """

    def __init__(self, in_channels=3, num_classes=3):
        super().__init__()
        self.encoder = MobileViTXSmallEncoder(in_channels)
        # Actual channel counts of the encoder outputs [x1, x2, x3, x5].
        self.decoder = SimplifiedDecoder(encoder_channels=[24, 32, 48, 80])
        self.classifier = ConvClassifier(16, num_classes)

    def forward(self, x):
        input_size = x.shape[2:]

        # Encoder: convolutions capture local detail, attention adds global context.
        features = self.encoder(x)

        # Decoder recovers spatial resolution for localization.
        decoded = self.decoder(features)

        # Classifier predicts per-pixel class logits.
        out = self.classifier(decoded)

        # Resize to the input resolution if needed.
        if out.shape[2:] != input_size:
            out = F.interpolate(out, size=input_size, mode='bilinear', align_corners=True)

        return out

print("‚úì MobileViT-XSmall definido")

In [None]:
class VertebrasDataset(Dataset):
    """COCO-annotated radiographs -> (normalized image tensor, class-index mask).

    Data sources:
      * reads ``<base_path>/Anotaciones v√©rtebras/<json_filename>``, a COCO-style
        file with ``images`` and ``annotations``;
      * keeps every image that has at least one annotation and whose file exists
        under ``<base_path>/Radiograf√≠as``.

    Each item contains:
      * the RGB image resized to ``target_size`` and normalized with fixed
        per-channel mean/std;
      * an integer mask of the same size. Annotations named ``'T1'`` become 1,
        ``'V'`` become 2, and ``'F'`` / ``'background'`` become 0. Annotations
        with any other name are ignored, and unannotated pixels stay 0.

    NOTE(review): the folder names above are reproduced byte-for-byte from the
    code. They look like a mis-decoded "vértebras" / "Radiografías"; confirm
    they match the directories on disk.
    """

    def __init__(self, base_path, json_filename='coco_anotaciones_actualizadas_23sep.json',
                 target_size=(256, 256)):

        self.base_path = Path(base_path)
        self.json_path = self.base_path / "Anotaciones v√©rtebras" / json_filename
        self.radiografias_path = self.base_path / "Radiograf√≠as"
        # Passed as-is to PIL resize and cv2.resize, both of which take (width, height).
        self.target_size = target_size

        with open(self.json_path, 'r', encoding='utf-8') as f:
            self.coco_data = json.load(f)

        self.class_names = ['Background', 'T1', 'V']

        # Annotation 'name' -> mask value; names not listed here are skipped.
        self.name_to_class = {
            'F': 0,
            'background': 0,
            'T1': 1,
            'V': 2
        }

        self.samples = self._preparar_samples()

        # Fixed per-channel statistics (presumably computed on this dataset -- TODO confirm).
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.399637, 0.400040, 0.392532],
                std=[0.212403, 0.211738, 0.207753]
            )
        ])

        print(f"‚úì Clases definidas: {self.class_names}")
        print(f"‚úì Mapeo de nombres: {self.name_to_class}")

    def _preparar_samples(self):
        """Pair each annotated image entry with its file on disk.

        The file path is taken from ``file_name``, or from ``toras_path`` when
        ``file_name`` is missing or empty. It is first looked up as a relative
        path under the radiographs folder, then by bare file name.

        Images are skipped if they have no annotations, no usable path, or no
        matching regular file.

        Returns:
            list of dicts with ``image_id``, ``image_path`` (str) and ``annotations``.
        """
        anns_por_imagen = defaultdict(list)
        for ann in self.coco_data['annotations']:
            img_id = ann.get('image_id')
            if img_id is not None:
                anns_por_imagen[img_id].append(ann)

        samples = []
        for img_info in self.coco_data['images']:
            img_id = img_info.get('id')
            if img_id not in anns_por_imagen:
                continue

            file_name = img_info.get('file_name') or img_info.get('toras_path', '') or ''
            # Strip *all* leading slashes: joining an absolute component with `/`
            # would discard radiografias_path.
            file_name = file_name.lstrip('/')
            if not file_name:
                # An empty name would resolve to the radiographs directory itself.
                continue

            img_path = self.radiografias_path / file_name
            if not img_path.is_file():
                img_path = self.radiografias_path / Path(file_name).name

            if img_path.is_file():
                samples.append({
                    'image_id': img_id,
                    'image_path': str(img_path),
                    'annotations': anns_por_imagen[img_id]
                })

        return samples

    def _parsear_segmentacion(self, segmentation):
        """Extract polygons from a COCO polygon ``segmentation`` field.

        Accepts a list of flat ``[x1, y1, x2, y2, ...]`` coordinate lists. Returns
        one ``(N, 2)`` int32 array per polygon with at least 3 vertices.

        Anything else yields no polygon: None, RLE dicts, nested or malformed
        entries, and odd-length coordinate lists.
        """
        poligonos = []
        if not segmentation or not isinstance(segmentation, list):
            return poligonos

        for seg_item in segmentation:
            if not (isinstance(seg_item, list) and seg_item):
                continue
            if not isinstance(seg_item[0], (int, float)):
                continue
            # Need >= 3 (x, y) pairs; an odd count cannot be reshaped into pairs.
            if len(seg_item) < 6 or len(seg_item) % 2:
                continue
            coords = np.array(seg_item).reshape(-1, 2)
            poligonos.append(coords.astype(np.int32))
        return poligonos

    def _crear_mascara(self, annotations, orig_height, orig_width):
        """Rasterize ``annotations`` into an ``(orig_height, orig_width)`` uint8 class mask.

        Each annotation's polygons are filled with the class id of its ``name``.
        If an annotation has no usable polygon, its ``bbox`` ([x, y, w, h],
        truncated to int and clipped to the image) is filled instead. Later
        annotations overwrite earlier ones where they overlap.
        """
        mask = np.zeros((orig_height, orig_width), dtype=np.uint8)

        for ann in annotations:
            # `or ''` also covers an explicit null name, which would break .strip().
            name = (ann.get('name') or '').strip()

            if name not in self.name_to_class:
                continue

            class_id = self.name_to_class[name]
            poligonos = self._parsear_segmentacion(ann.get('segmentation'))

            for poly in poligonos:
                cv2.fillPoly(mask, [poly], class_id)

            if not poligonos and 'bbox' in ann:
                x, y, w, h = [int(v) for v in ann['bbox']]
                if x >= 0 and y >= 0 and w > 0 and h > 0:
                    x2, y2 = min(x + w, orig_width), min(y + h, orig_height)
                    mask[y:y2, x:x2] = class_id

        return mask

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        ``image`` is the normalized tensor from ``self.transform``. ``mask`` is a
        long tensor of class indices. Both are resized to ``target_size``.
        """
        sample = self.samples[idx]

        # The context manager releases the file handle even if decoding fails.
        with Image.open(sample['image_path']) as img:
            image = img.convert('RGB')
        orig_width, orig_height = image.size
        mask = self._crear_mascara(sample['annotations'], orig_height, orig_width)

        image = image.resize(self.target_size, Image.BILINEAR)
        # Nearest-neighbour keeps mask values as valid class ids.
        mask = cv2.resize(mask, self.target_size, interpolation=cv2.INTER_NEAREST)

        image = self.transform(image)
        mask = torch.from_numpy(mask).long()

        return image, mask

print("‚úì Dataset definido")

In [None]:
class DiceLoss(nn.Module):
    """Soft multi-class Dice loss computed from raw logits.

    ``pred`` holds logits of shape (B, C, H, W), and ``target`` holds integer class
    indices of shape (B, H, W). Softmax probabilities are compared with the
    one-hot target for every (sample, class) pair, background included. The
    loss is ``1 - mean(dice)`` over those pairs, with ``smooth`` added to the
    numerator and the denominator.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        probs = pred.softmax(dim=1)
        n_classes = probs.shape[1]
        # (B, H, W) -> (B, H, W, C) -> (B, C, H, W) to line up with probs.
        onehot = F.one_hot(target, num_classes=n_classes).permute(0, 3, 1, 2).float()

        spatial = (2, 3)
        overlap = (probs * onehot).sum(dim=spatial)
        total = probs.sum(dim=spatial) + onehot.sum(dim=spatial)

        dice_per_class = (2.0 * overlap + self.smooth) / (total + self.smooth)
        return 1.0 - dice_per_class.mean()

class CombinedLoss(nn.Module):
    """Weighted sum of pixel-wise cross-entropy and soft Dice loss.

    Returns ``ce_weight * CrossEntropy(pred, target) + dice_weight * Dice(pred, target)``.
    """

    def __init__(self, ce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.ce_weight, self.dice_weight = ce_weight, dice_weight
        self.ce_loss = nn.CrossEntropyLoss()
        self.dice_loss = DiceLoss()

    def forward(self, pred, target):
        weighted_ce = self.ce_weight * self.ce_loss(pred, target)
        weighted_dice = self.dice_weight * self.dice_loss(pred, target)
        return weighted_ce + weighted_dice

def calcular_metricas_detalladas(pred, target, num_classes, class_names, empty_score=1.0):
    """Compute per-class IoU / Dice, foreground means and pixel accuracy.

    Args:
        pred: class scores (logits or probabilities) with the class axis at
            dim 1, e.g. (N, C, H, W); the predicted class is the argmax.
        target: integer class ids with ``pred``'s shape minus the class axis.
        num_classes: classes ``0 .. num_classes-1`` are evaluated.
        class_names: names used in the metric keys; index 0 is background.
            Classes beyond ``len(class_names)`` are named ``class_<id>``.
        empty_score: IoU and Dice assigned to a class absent from both the
            prediction and the target. Defaults to 1.0 (absence correctly
            predicted); scoring it 0.0 would penalise a correct prediction.

    Returns:
        dict with ``iou_<name>`` / ``dice_<name>`` for every class, plus
        ``mean_iou`` / ``mean_dice`` averaged over classes 1.. (background
        excluded; 0.0 when there are none) and ``accuracy`` (fraction of
        pixels whose predicted class equals the target, background included).
    """
    pred_classes = torch.argmax(pred, dim=1)
    # Resolve names once so per-class keys and the foreground means always agree
    # (indexing class_names directly for the means could raise IndexError).
    names = [class_names[c] if c < len(class_names) else f'class_{c}'
             for c in range(num_classes)]

    metricas = {}
    for c, class_name in enumerate(names):
        pred_c = pred_classes == c
        target_c = target == c
        intersection = (pred_c & target_c).sum().item()
        union = (pred_c | target_c).sum().item()

        if union > 0:
            iou = intersection / union
            # union > 0 guarantees a non-zero denominator here
            dice = 2.0 * intersection / (pred_c.sum().item() + target_c.sum().item())
        else:
            iou = dice = float(empty_score)

        metricas[f'iou_{class_name}'] = iou
        metricas[f'dice_{class_name}'] = dice

    foreground = names[1:]
    metricas['mean_iou'] = float(np.mean([metricas[f'iou_{n}'] for n in foreground])) if foreground else 0.0
    metricas['mean_dice'] = float(np.mean([metricas[f'dice_{n}'] for n in foreground])) if foreground else 0.0
    metricas['accuracy'] = (pred_classes == target).float().mean().item()

    return metricas

print("‚úì Loss functions y m√©tricas definidas")

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device, num_classes, class_names):
    """Run one training epoch and return the averaged loss and metrics.

    Batches containing a single sample are skipped. Presumably this avoids
    BatchNorm failing on one-sample batches (e.g. after global pooling);
    TODO confirm. Loss and metrics are averaged over the samples actually
    processed (weighted by batch size), so skipped batches no longer bias the
    averages downward.

    Returns:
        (avg_loss, avg_metrics): float loss and a dict with the keys produced by
        ``calcular_metricas_detalladas``.

    Raises:
        RuntimeError: if no batch was processed (empty loader or only 1-sample batches).
    """
    model.train()
    total_loss = 0.0
    total_metrics = defaultdict(float)
    n_samples = 0

    for images, masks in dataloader:
        batch_size = images.size(0)
        if batch_size == 1:
            continue

        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * batch_size
        # Metrics need no gradient graph
        metricas = calcular_metricas_detalladas(outputs.detach(), masks, num_classes, class_names)
        for k, v in metricas.items():
            total_metrics[k] += v * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise RuntimeError("train_epoch: ningun batch procesado (dataloader vacio o solo batches de 1 muestra)")

    avg_loss = total_loss / n_samples
    avg_metrics = {k: v / n_samples for k, v in total_metrics.items()}
    return avg_loss, avg_metrics

def validate(model, dataloader, criterion, device, num_classes, class_names):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Loss and metrics are averaged over batches weighted by batch size. This
    way a smaller final batch (the validation loader uses ``drop_last=False``)
    does not count as much as a full one. Per-batch IoU/Dice are pooled over
    the batch's pixels before this averaging.

    Returns:
        (avg_loss, avg_metrics): float loss and a dict with the keys produced by
        ``calcular_metricas_detalladas``.

    Raises:
        RuntimeError: if the dataloader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_metrics = defaultdict(float)
    n_samples = 0

    with torch.no_grad():
        for images, masks in dataloader:
            batch_size = images.size(0)
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)

            total_loss += loss.item() * batch_size
            metricas = calcular_metricas_detalladas(outputs, masks, num_classes, class_names)
            for k, v in metricas.items():
                total_metrics[k] += v * batch_size
            n_samples += batch_size

    if n_samples == 0:
        raise RuntimeError("validate: el dataloader no produjo ninguna muestra")

    avg_loss = total_loss / n_samples
    avg_metrics = {k: v / n_samples for k, v in total_metrics.items()}
    return avg_loss, avg_metrics

print("‚úì Funciones de entrenamiento definidas")

In [None]:
class ModelTrainer:
    """Train and evaluate a single segmentation model.

    Uses AdamW (weight_decay=1e-5), ``CombinedLoss`` (0.5 CE + 0.5 Dice) and a
    ``ReduceLROnPlateau`` scheduler driven by validation mean IoU. The epoch with
    the best validation mean IoU is checkpointed to ``<sanitised name>_best.pth``.
    By default its weights are also loaded back into ``self.model`` when training
    ends, so later evaluation/visualisation uses the best model rather than the
    last epoch.

    Args:
        model_name: display name; also used (sanitised) in the checkpoint filename.
        model: ``nn.Module`` to train; moved to ``device``.
        train_loader, val_loader: iterables of ``(images, masks)`` batches.
        device: torch device used for training.
        num_classes: number of output classes.
        class_names: class names; index 0 is background.
        lr: initial AdamW learning rate.
        max_epochs: maximum number of epochs.
        early_stopping_patience: epochs without validation mean-IoU improvement
            before training stops.
    """

    def __init__(self, model_name, model, train_loader, val_loader, device,
                 num_classes, class_names, lr=0.0001, max_epochs=50,
                 early_stopping_patience=15):
        self.model_name = model_name
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.num_classes = num_classes
        self.class_names = class_names
        self.max_epochs = max_epochs
        self.early_stopping_patience = early_stopping_patience

        self.optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
        self.criterion = CombinedLoss(ce_weight=0.5, dice_weight=0.5)
        # Halve the LR after 7 epochs without validation mean-IoU improvement
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', factor=0.5, patience=7
        )

        self.history = defaultdict(list)
        self.best_iou = 0
        self.best_dice = 0
        self.best_epoch = 0  # 1-based; 0 means no epoch has been evaluated yet
        self.training_time = 0
        self.best_metrics_per_class = {}

    def _checkpoint_path(self):
        """Filename of the best-model checkpoint, derived from the sanitised model name."""
        slug = self.model_name.lower().replace(" ", "_").replace("+", "plus").replace("-", "_")
        return f'{slug}_best.pth'

    def train(self, restore_best=True):
        """Train up to ``max_epochs`` epochs with early stopping.

        Args:
            restore_best: if True, load the best epoch's weights into
                ``self.model`` after training.

        Returns:
            (history, best_iou): per-epoch metric lists and the best validation mean IoU.
        """
        print(f"\n{'='*70}")
        print(f"üöÄ Entrenando: {self.model_name}")
        print(f"{'='*70}")

        start_time = time.time()
        patience_counter = 0
        patience = self.early_stopping_patience
        best_state = None

        for epoch in range(self.max_epochs):
            print(f"\nüìä Epoca {epoch+1}/{self.max_epochs}")
            print("-"*70)

            train_loss, train_metrics = train_epoch(
                self.model, self.train_loader, self.criterion,
                self.optimizer, self.device, self.num_classes, self.class_names
            )

            val_loss, val_metrics = validate(
                self.model, self.val_loader, self.criterion,
                self.device, self.num_classes, self.class_names
            )

            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            for key in val_metrics:
                self.history[f'train_{key}'].append(train_metrics[key])
                self.history[f'val_{key}'].append(val_metrics[key])

            print(f"Train | Loss: {train_loss:.4f} | IoU: {train_metrics['mean_iou']:.4f} | "
                  f"Dice: {train_metrics['mean_dice']:.4f}")
            print(f"Val   | Loss: {val_loss:.4f} | IoU: {val_metrics['mean_iou']:.4f} | "
                  f"Dice: {val_metrics['mean_dice']:.4f}")

            print("\n  M√©tricas por clase (Validaci√≥n):")
            for c in range(1, self.num_classes):
                class_name = self.class_names[c]
                iou = val_metrics[f'iou_{class_name}']
                dice = val_metrics[f'dice_{class_name}']
                print(f"    {class_name:>10}: IoU={iou:.4f} | Dice={dice:.4f}")

            # The first evaluated epoch always counts as the best so far, so a
            # checkpoint and per-class metrics exist even if mean IoU stays at 0.
            if self.best_epoch == 0 or val_metrics['mean_iou'] > self.best_iou:
                self.best_iou = val_metrics['mean_iou']
                self.best_dice = val_metrics['mean_dice']
                self.best_epoch = epoch + 1
                patience_counter = 0

                self.best_metrics_per_class = {
                    class_name: {
                        'iou': val_metrics[f'iou_{class_name}'],
                        'dice': val_metrics[f'dice_{class_name}']
                    }
                    for class_name in self.class_names[1:]
                }

                # CPU copy of the best weights, restored after training
                best_state = {k: v.detach().cpu().clone() for k, v in self.model.state_dict().items()}

                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'best_iou': self.best_iou,
                    'best_dice': self.best_dice,
                    'best_metrics_per_class': self.best_metrics_per_class,
                    'history': dict(self.history)
                }, self._checkpoint_path())
                print(f"‚úì üèÜ Mejor modelo guardado! IoU: {self.best_iou:.4f} | Dice: {self.best_dice:.4f}")
            else:
                patience_counter += 1

            self.scheduler.step(val_metrics['mean_iou'])

            if patience_counter >= patience:
                print(f"\n‚èπÔ∏è  Early stopping en epoca {epoch+1}")
                break

        self.training_time = time.time() - start_time
        print(f"\n‚è±Ô∏è  Tiempo: {self.training_time/60:.2f} min")
        print(f"üèÜ Mejor IoU: {self.best_iou:.4f} | Dice: {self.best_dice:.4f} (Epoca {self.best_epoch})")

        if restore_best and best_state is not None:
            self.model.load_state_dict(best_state)
            print(f"   Pesos de la epoca {self.best_epoch} restaurados en el modelo")

        return self.history, self.best_iou

    def get_results(self):
        """Summary dict of the run: best scores, epoch, time, parameter count and history."""
        total_params = sum(p.numel() for p in self.model.parameters())
        return {
            'model_name': self.model_name,
            'best_iou': self.best_iou,
            'best_dice': self.best_dice,
            'best_epoch': self.best_epoch,
            'training_time': self.training_time,
            'total_params': total_params,
            'history': dict(self.history),
            'best_metrics_per_class': self.best_metrics_per_class
        }

print("‚úì Model Trainer definido")

In [None]:
def visualizar_predicciones_modelo(model, dataset, device, model_name, class_names, num_samples=4, seed=1111):
    """Save a figure comparing a model's predictions with the ground truth.

    Samples are drawn without replacement, reproducibly from ``seed``. Each
    row shows:
      - the de-normalised input image;
      - the ground-truth mask;
      - the predicted mask, titled with that sample's foreground mean IoU/Dice;
      - the prediction overlaid on the input.

    The PNG is written to ``pred_<sanitised model name>_seed<seed>.png``.

    Args:
        model: segmentation model returning per-class logits; put in eval mode.
        dataset: indexable dataset yielding ``(image_tensor, mask_tensor)``.
        device: device the model lives on.
        model_name: used in the figure title and filename.
        class_names: class names; index 0 (background) is left out of the legend.
        num_samples: requested number of rows; capped at ``len(dataset)``.
        seed: seed for picking samples. A private RNG is used, so NumPy's global
            random state is left untouched.
    """
    model.eval()

    n_rows = min(num_samples, len(dataset))
    if n_rows <= 0:
        print(f"   Sin muestras para visualizar ({model_name})")
        return

    # Fixed colours for background and the first two classes; any further class
    # gets a tab10 colour so masks and legend never hit a missing entry.
    colors = {
        0: [0, 0, 0],
        1: [0, 255, 0],
        2: [0, 0, 255]
    }
    for class_id in range(len(colors), len(class_names)):
        r, g, b, _ = plt.cm.tab10(class_id % 10)
        colors[class_id] = [int(round(255 * r)), int(round(255 * g)), int(round(255 * b))]
    num_classes = len(colors)

    fig, axes = plt.subplots(n_rows, 4, figsize=(18, n_rows * 4.5), squeeze=False)

    # Same indices as np.random.seed(seed) + np.random.choice(...), but without
    # resetting the global RNG used elsewhere (e.g. by training).
    rng = np.random.RandomState(seed)
    indices = rng.choice(len(dataset), n_rows, replace=False)

    # Per-channel normalisation constants; presumably the ones applied by the
    # dataset transform (not visible here). Keep them in sync; TODO confirm.
    mean = torch.tensor([0.399637, 0.400040, 0.392532]).view(3, 1, 1)
    std = torch.tensor([0.212403, 0.211738, 0.207753]).view(3, 1, 1)

    with torch.no_grad():
        for i, idx in enumerate(indices):
            image, mask = dataset[idx]

            image_input = image.unsqueeze(0).to(device)
            pred = model(image_input).cpu().squeeze(0)
            pred_classes = torch.argmax(pred, dim=0).numpy()

            image_denorm = image * std + mean
            img_np = np.clip(image_denorm.permute(1, 2, 0).numpy(), 0, 1)

            mask_np = mask.numpy()

            pred_colored = np.zeros((*pred_classes.shape, 3), dtype=np.uint8)
            mask_colored = np.zeros((*mask_np.shape, 3), dtype=np.uint8)
            for class_id, color in colors.items():
                pred_colored[pred_classes == class_id] = color
                mask_colored[mask_np == class_id] = color

            metricas = calcular_metricas_detalladas(
                pred.unsqueeze(0),
                mask.unsqueeze(0),
                num_classes,
                class_names
            )

            axes[i, 0].imshow(img_np)
            axes[i, 0].set_title('üñºÔ∏è Imagen Original', fontsize=12, fontweight='bold')
            axes[i, 0].axis('off')

            axes[i, 1].imshow(mask_colored)
            axes[i, 1].set_title('‚úì Ground Truth', fontsize=12, fontweight='bold')
            axes[i, 1].axis('off')

            axes[i, 2].imshow(pred_colored)
            axes[i, 2].set_title(f'ü§ñ Predicci√≥n\nIoU: {metricas["mean_iou"]:.3f} | Dice: {metricas["mean_dice"]:.3f}',
                                fontsize=12, fontweight='bold')
            axes[i, 2].axis('off')

            axes[i, 3].imshow(img_np)
            axes[i, 3].imshow(pred_colored / 255.0, alpha=0.6)
            axes[i, 3].set_title('üé® Overlay', fontsize=12, fontweight='bold')
            axes[i, 3].axis('off')

    legend_elements = [
        plt.Rectangle((0, 0), 1, 1, fc=np.array(colors[i])/255.0,
                     edgecolor='black', linewidth=2,
                     label=class_names[i])
        for i in range(1, len(class_names))
    ]
    fig.legend(handles=legend_elements, loc='upper center', ncol=max(1, len(class_names)-1),
              fontsize=12, frameon=True, fancybox=True, shadow=True)

    fig.suptitle(f'üîç Predicciones de {model_name}', fontsize=16, fontweight='bold', y=0.995)
    fig.tight_layout(rect=[0, 0, 1, 0.98])

    filename = f'pred_{model_name.lower().replace(" ", "_").replace("+", "plus").replace("-", "_")}_seed{seed}.png'
    fig.savefig(filename, dpi=200, bbox_inches='tight', facecolor='white')
    plt.close(fig)
    print(f"   üì∏ Predicciones guardadas en: {filename}")

print("‚úì Funciones de visualizaci√≥n definidas")

In [None]:
def _curvas_por_modelo(ax, resultados, key, colores, ylabel, title,
                       label_size=12, title_size=13, legend_size=8, ylim=(0, 1)):
    """Plot ``history[key]`` of every model on ``ax`` (models lacking the key are skipped)."""
    for color, result in zip(colores, resultados):
        curva = result['history'].get(key)
        if curva is not None:
            ax.plot(curva, label=result['model_name'], linewidth=2.5, alpha=0.8, color=color)
    ax.set_xlabel('Epoca', fontsize=label_size, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=label_size, fontweight='bold')
    ax.set_title(title, fontsize=title_size, fontweight='bold')
    ax.legend(fontsize=legend_size, loc='best')
    ax.grid(True, alpha=0.3, linestyle='--')
    if ylim is not None:
        ax.set_ylim(list(ylim))


def _barras_ranking(ax, nombres, valores, colores, xlabel, title, fmt, text_size):
    """Horizontal bars of per-model scores on a [0, 1] axis, each annotated with its value."""
    bars = ax.barh(nombres, valores, color=colores, edgecolor='black', linewidth=1.5)
    ax.set_xlabel(xlabel, fontsize=11, fontweight='bold')
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3, axis='x')
    ax.set_xlim([0, 1])
    for bar, valor in zip(bars, valores):
        ax.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height() / 2,
                format(valor, fmt), va='center', ha='left', fontsize=text_size, fontweight='bold')


def _heatmap_por_clase(ax, matriz, clases, nombres, title, cbar_label):
    """Models x classes heatmap on a fixed [0, 1] RdYlGn scale with annotated cells."""
    im = ax.imshow(matriz, cmap='RdYlGn', aspect='auto', vmin=0, vmax=1)
    ax.set_xticks(range(len(clases)))
    ax.set_xticklabels(clases, fontsize=11, fontweight='bold')
    ax.set_yticks(range(len(nombres)))
    ax.set_yticklabels(nombres, fontsize=11)
    ax.set_title(title, fontsize=13, fontweight='bold')
    for i, fila in enumerate(matriz):
        for j, valor in enumerate(fila):
            ax.text(j, i, f'{valor:.3f}', ha="center", va="center",
                    color="black", fontsize=10, fontweight='bold')
    ax.figure.colorbar(im, ax=ax, label=cbar_label)


def visualizar_comparacion_con_clases(resultados, class_names, output_file='comparacion_multimodal_actualizada.png'):
    """Save a 5x4 dashboard comparing all trained models.

    Rows:
      0. validation mean IoU / mean Dice curves;
      1. per-class validation IoU curves, validation loss and accuracy;
      2. best-IoU ranking bars, overall and per class;
      3. per-class best IoU / Dice heatmaps;
      4. parameter counts and an IoU-vs-size scatter.

    Each model keeps the same colour in every curve and in the scatter. The
    bar charts use a rank gradient instead.

    NOTE: the grid is laid out for two foreground classes (``class_names[1:]``);
    with more, row-1 panels overlap and row 2 runs out of columns.

    Args:
        resultados: list of result dicts as returned by ``ModelTrainer.get_results``.
        class_names: class names; index 0 (background) is excluded from per-class panels.
        output_file: path of the PNG to write.
    """
    if not resultados:
        print("\nSin resultados que comparar; no se genera la figura")
        return

    nombres = [r['model_name'] for r in resultados]
    vertebra_classes = class_names[1:]
    # One fixed tab10 colour per model, shared by all curves and the scatter
    model_colors = plt.cm.tab10(np.arange(len(resultados)) % 10)
    rank_colors = plt.cm.RdYlGn(np.linspace(0.3, 0.9, len(nombres)))

    fig = plt.figure(figsize=(24, 18))
    gs = fig.add_gridspec(5, 4, hspace=0.4, wspace=0.35)

    # Row 0: mean IoU / mean Dice
    _curvas_por_modelo(fig.add_subplot(gs[0, :2]), resultados, 'val_mean_iou', model_colors,
                       'Validation IoU (Promedio)', 'üìà Comparaci√≥n IoU Promedio',
                       label_size=13, title_size=15, legend_size=10)
    _curvas_por_modelo(fig.add_subplot(gs[0, 2:]), resultados, 'val_mean_dice', model_colors,
                       'Validation Dice (Promedio)', 'üìà Comparaci√≥n Dice Score Promedio',
                       label_size=13, title_size=15, legend_size=10)

    # Row 1: per-class IoU curves (columns 0..), loss (column 2), accuracy (column 3)
    for idx, vertebra in enumerate(vertebra_classes):
        _curvas_por_modelo(fig.add_subplot(gs[1, idx]), resultados, f'val_iou_{vertebra}', model_colors,
                           f'IoU {vertebra}', f'üìä IoU - V√©rtebra {vertebra}')
    _curvas_por_modelo(fig.add_subplot(gs[1, 2]), resultados, 'val_loss', model_colors,
                       'Loss', 'üìâ Validation Loss', ylim=None)
    _curvas_por_modelo(fig.add_subplot(gs[1, 3]), resultados, 'val_accuracy', model_colors,
                       'Accuracy', 'üéØ Validation Accuracy')

    # Row 2: rankings
    _barras_ranking(fig.add_subplot(gs[2, 0]), nombres, [r['best_iou'] for r in resultados],
                    rank_colors, 'Best IoU (Promedio)', 'üèÜ Ranking IoU General', '.4f', 9)
    for idx, vertebra in enumerate(vertebra_classes):
        ious_class = [r['best_metrics_per_class'][vertebra]['iou'] for r in resultados]
        _barras_ranking(fig.add_subplot(gs[2, idx + 1]), nombres, ious_class, rank_colors,
                        f'IoU {vertebra}', f'üèÜ IoU - {vertebra}', '.3f', 8)

    # Row 3: heatmaps
    _heatmap_por_clase(fig.add_subplot(gs[3, :2]),
                       [[r['best_metrics_per_class'][v]['iou'] for v in vertebra_classes] for r in resultados],
                       vertebra_classes, nombres, 'üî• Heatmap IoU por Modelo y Clase', 'IoU Score')
    _heatmap_por_clase(fig.add_subplot(gs[3, 2:]),
                       [[r['best_metrics_per_class'][v]['dice'] for v in vertebra_classes] for r in resultados],
                       vertebra_classes, nombres, 'üî• Heatmap Dice por Modelo y Clase', 'Dice Score')

    # Row 4: model size and efficiency
    params = [r['total_params'] / 1e6 for r in resultados]
    ious = [r['best_iou'] for r in resultados]

    ax_size = fig.add_subplot(gs[4, :2])
    colors_size = plt.cm.viridis(np.linspace(0.2, 0.9, len(nombres)))
    bars = ax_size.barh(nombres, params, color=colors_size, edgecolor='black', linewidth=1.5)
    ax_size.set_xlabel('Par√°metros (Millones)', fontsize=11, fontweight='bold')
    ax_size.set_title('üìä Tama√±o de Modelos (Par√°metros)', fontsize=12, fontweight='bold')
    ax_size.grid(True, alpha=0.3, axis='x')
    text_offset = max(params) * 0.01
    for bar, param, result in zip(bars, params, resultados):
        # 4 bytes per parameter
        size_mb = (result['total_params'] * 4) / (1024 * 1024)
        ax_size.text(bar.get_width() + text_offset, bar.get_y() + bar.get_height() / 2,
                     f'{param:.2f}M ({size_mb:.1f}MB)',
                     va='center', ha='left', fontsize=9, fontweight='bold')

    ax_eff = fig.add_subplot(gs[4, 2:])
    ax_eff.scatter(params, ious, c=model_colors, s=300, edgecolor='black', linewidth=2, alpha=0.8)
    for p, iou, nombre in zip(params, ious, nombres):
        ax_eff.annotate(nombre, (p, iou),
                        xytext=(10, 5), textcoords='offset points',
                        fontsize=9, fontweight='bold',
                        bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.3))
    ax_eff.set_xlabel('Par√°metros (Millones)', fontsize=11, fontweight='bold')
    ax_eff.set_ylabel('Best IoU', fontsize=11, fontweight='bold')
    ax_eff.set_title('‚öñÔ∏è Eficiencia: IoU vs Tama√±o del Modelo', fontsize=12, fontweight='bold')
    ax_eff.grid(True, alpha=0.3, linestyle='--')
    y_lo, y_hi = min(ious) * 0.95, max(ious) * 1.05
    if y_hi <= y_lo:
        # Identical (e.g. all-zero) IoUs would otherwise give a singular y-range
        y_lo, y_hi = y_lo - 0.05, y_hi + 0.05
    ax_eff.set_ylim([y_lo, y_hi])

    fig.suptitle('üî¨ COMPARACI√ìN COMPLETA: MODELOS CL√ÅSICOS + NUEVOS MODELOS LIGEROS',
                 fontsize=18, fontweight='bold', y=0.995)

    fig.savefig(output_file, dpi=200, bbox_inches='tight', facecolor='white')
    plt.close(fig)
    print(f"\nüìä Comparaci√≥n detallada guardada en: {output_file}")

print("‚úì Funci√≥n de comparaci√≥n completa definida")

In [None]:
def _tabla_eficiencia(resultados):
    """Return the efficiency table as text lines: header, rule, then one row per model.

    Size assumes 4 bytes per parameter; "Efic." is best IoU per million parameters x 100.
    """
    lineas = [
        f"{'Modelo':<20} {'Params':<12} {'Tama√±o':<12} {'IoU':<10} {'Efic.':<10}",
        f"{'-'*80}",
    ]
    for r in resultados:
        params_m = r['total_params'] / 1e6
        size_mb = (r['total_params'] * 4) / (1024 * 1024)
        eficiencia = r['best_iou'] / params_m * 100
        lineas.append(f"{r['model_name']:<20} {params_m:>6.2f}M      {size_mb:>6.1f} MB   "
                      f"{r['best_iou']:<10.4f} {eficiencia:>8.2f}")
    return lineas


def main():
    """Train every model on the vertebrae dataset, compare them and write reports.

    Outputs:
      - per-model prediction figures;
      - a comparison dashboard;
      - a text summary;
      - ``*_best.pth`` checkpoints, written by ``ModelTrainer``.

    The dataset root defaults to the original local path and can be overridden
    with the ``VERTEBRAS_BASE_PATH`` environment variable.

    Returns:
        list of result dicts, sorted by best validation mean IoU (descending).
    """
    # Configuration
    SEED = 1111
    BASE_PATH = os.environ.get('VERTEBRAS_BASE_PATH', r"C:\Users\User\Documents\Proyectofinal")
    BATCH_SIZE = 8
    MAX_EPOCHS = 100
    LEARNING_RATE = 0.0001
    IMAGE_SIZE = 256
    NUM_CLASSES = 3

    print("\n" + "="*80)
    print("üî¨ AN√ÅLISIS MULTI-MODELO: CL√ÅSICOS + NUEVOS MODELOS LIGEROS")
    print("   Modelos Cl√°sicos: DeepLabV3+, LinkNet, PSPNet, U-Net++")
    print("   Nuevos Modelos: MobileNetV3-UNet, MobileViT-XSmall")
    print(f"   SEED: {SEED}")
    print("="*80)

    set_seed(SEED)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nüíª Dispositivo: {device}")
    if device.type == 'cuda':
        print(f"   GPU: {torch.cuda.get_device_name(0)}")

    # Load dataset and split 80/20 reproducibly
    print("\nüì¶ Cargando dataset...")
    full_dataset = VertebrasDataset(BASE_PATH, target_size=(IMAGE_SIZE, IMAGE_SIZE))
    class_names = full_dataset.class_names

    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size

    generator = torch.Generator().manual_seed(SEED)
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size], generator=generator
    )

    print(f"‚úì Total: {len(full_dataset)} | Train: {len(train_dataset)} | Val: {len(val_dataset)}")
    print(f"‚úì Clases: {class_names}")

    def seed_worker(worker_id):
        worker_seed = SEED + worker_id
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    # Shuffling generator of the training loader; reseeded before each model
    g = torch.Generator()
    g.manual_seed(SEED)

    train_loader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=0, worker_init_fn=seed_worker, generator=g,
        drop_last=True
    )

    val_loader = DataLoader(
        val_dataset, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=0, worker_init_fn=seed_worker,
        drop_last=False
    )

    modelos = {
        # Classic models
        'DeepLabV3+': DeepLabV3Plus(in_channels=3, num_classes=NUM_CLASSES),
        'LinkNet': LinkNet(in_channels=3, num_classes=NUM_CLASSES),
        'PSPNet': PSPNet(in_channels=3, num_classes=NUM_CLASSES),
        'U-Net++': UNetPlusPlus(in_channels=3, num_classes=NUM_CLASSES),

        # New lightweight models
        'MobileNetV3-UNet': MobileNetV3UNet(in_channels=3, num_classes=NUM_CLASSES),
        'MobileViT-XSmall': MobileViTXSmall(in_channels=3, num_classes=NUM_CLASSES)
    }
    nombres_clasicos = ['DeepLabV3+', 'LinkNet', 'PSPNet', 'U-Net++']
    nombres_nuevos = ['MobileNetV3-UNet', 'MobileViT-XSmall']

    print(f"\n{'='*80}")
    print("üìä INFORMACI√ìN DE MODELOS")
    print(f"{'='*80}")
    for nombre, modelo in modelos.items():
        params = sum(p.numel() for p in modelo.parameters())
        size_mb = (params * 4) / (1024 * 1024)
        print(f"  {nombre:20s}: {params/1e6:>6.2f}M params | {size_mb:>6.1f} MB")

    resultados = []

    print(f"\n{'='*80}")
    print(f"üèÅ INICIANDO ENTRENAMIENTO DE {len(modelos)} MODELOS")
    print(f"{'='*80}\n")

    for idx, (nombre, modelo) in enumerate(modelos.items(), 1):
        print(f"\n{'‚ñ∂'*3} Modelo {idx}/{len(modelos)}: {nombre} {'‚óÄ'*3}")

        set_seed(SEED)
        # Reset the shuffling generator so every model sees the same batch order
        g.manual_seed(SEED)

        trainer = ModelTrainer(
            model_name=nombre,
            model=modelo,
            train_loader=train_loader,
            val_loader=val_loader,
            device=device,
            num_classes=NUM_CLASSES,
            class_names=class_names,
            lr=LEARNING_RATE,
            max_epochs=MAX_EPOCHS
        )

        trainer.train()
        resultados.append(trainer.get_results())

        print(f"\nüì∏ Generando visualizaciones para {nombre}...")
        visualizar_predicciones_modelo(
            modelo, val_dataset, device, nombre, class_names,
            num_samples=4, seed=SEED
        )

        # The trained model stays referenced by `modelos` and the trainer holds the
        # optimizer state, so move the model off the device and drop the trainer;
        # otherwise empty_cache() cannot release anything.
        modelo.to('cpu')
        del trainer
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        print(f"\n{'='*80}")

    resultados = sorted(resultados, key=lambda x: x['best_iou'], reverse=True)

    print(f"\n{'='*80}")
    print("üìä GENERANDO VISUALIZACIONES COMPARATIVAS")
    print(f"{'='*80}")
    visualizar_comparacion_con_clases(resultados, class_names)

    # Detailed summary
    print(f"\n{'='*80}")
    print("üèÜ RESUMEN DETALLADO DE RESULTADOS")
    print(f"{'='*80}\n")

    for i, r in enumerate(resultados, 1):
        emoji = "ü•á" if i == 1 else "ü•à" if i == 2 else "ü•â" if i == 3 else "  "
        params_m = r['total_params'] / 1e6
        size_mb = (r['total_params'] * 4) / (1024 * 1024)

        print(f"{emoji} {i}. {r['model_name']}")
        print(f"   IoU Promedio: {r['best_iou']:.4f} | Dice Promedio: {r['best_dice']:.4f}")
        print(f"   Par√°metros: {params_m:.2f}M | Tama√±o: {size_mb:.2f} MB")
        print(f"   M√©tricas por clase:")
        for vertebra in class_names[1:]:
            iou = r['best_metrics_per_class'][vertebra]['iou']
            dice = r['best_metrics_per_class'][vertebra]['dice']
            print(f"     {vertebra:>10}: IoU={iou:.4f} | Dice={dice:.4f}")
        print(f"   Epoca: {r['best_epoch']} | Tiempo: {r['training_time']/60:.1f} min\n")

    # Efficiency, split at 5M parameters
    print(f"\n{'='*80}")
    print("‚ö° AN√ÅLISIS DE EFICIENCIA")
    print(f"{'='*80}\n")

    grupos = (
        ("Modelos Ligeros (< 5M params):", lambda n: n < 5e6),
        ("\nModelos Pesados (>= 5M params):", lambda n: n >= 5e6),
    )
    for titulo, pertenece in grupos:
        print(titulo)
        for r in resultados:
            if pertenece(r['total_params']):
                params_m = r['total_params'] / 1e6
                eficiencia = r['best_iou'] / params_m * 100
                print(f"  {r['model_name']:20s}: IoU={r['best_iou']:.4f} | "
                      f"Params={params_m:.2f}M | Eficiencia={eficiencia:.2f}")

    # New models vs best classic model
    print(f"\n{'='*80}")
    print("üîç COMPARACI√ìN: NUEVOS MODELOS vs MEJOR MODELO CL√ÅSICO")
    print(f"{'='*80}\n")

    modelos_clasicos = [r for r in resultados if r['model_name'] in nombres_clasicos]
    modelos_nuevos = [r for r in resultados if r['model_name'] in nombres_nuevos]

    mejor_clasico = None
    comparaciones = []  # (result, % parameter reduction, IoU difference)
    if modelos_clasicos and modelos_nuevos:
        mejor_clasico = max(modelos_clasicos, key=lambda x: x['best_iou'])
        for nuevo in modelos_nuevos:
            reduccion_params = (1 - nuevo['total_params'] / mejor_clasico['total_params']) * 100
            diff_iou = nuevo['best_iou'] - mejor_clasico['best_iou']
            comparaciones.append((nuevo, reduccion_params, diff_iou))

        print(f"Mejor Cl√°sico: {mejor_clasico['model_name']}")
        print(f"  IoU: {mejor_clasico['best_iou']:.4f}")
        print(f"  Params: {mejor_clasico['total_params']/1e6:.2f}M")
        print(f"  Tiempo: {mejor_clasico['training_time']/60:.1f} min\n")

        for nuevo, reduccion_params, diff_iou in comparaciones:
            print(f"{nuevo['model_name']}:")
            print(f"  IoU: {nuevo['best_iou']:.4f}")
            print(f"  Params: {nuevo['total_params']/1e6:.2f}M")
            print(f"  Tiempo: {nuevo['training_time']/60:.1f} min")
            print(f"  Reducci√≥n de Par√°metros: {reduccion_params:.1f}%")
            print(f"  Diferencia de IoU: {diff_iou:+.4f}")

            # Check "outperforms" first: any positive difference also satisfies
            # the ">= -0.02" tolerance, which previously made this branch unreachable.
            if diff_iou > 0:
                print(f"  üéâ SUPERA al mejor cl√°sico con {reduccion_params:.1f}% menos par√°metros!")
            elif diff_iou >= -0.02:
                print(f"  ‚úÖ Logra {reduccion_params:.1f}% menos par√°metros con rendimiento similar!")
            print()

    tabla = _tabla_eficiencia(resultados)

    # Text summary
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f'resultados_multimodal_seed{SEED}_{timestamp}.txt'

    with open(filename, 'w', encoding='utf-8') as f:
        f.write("="*80 + "\n")
        f.write("AN√ÅLISIS COMPLETO: MODELOS CL√ÅSICOS + NUEVOS MODELOS LIGEROS\n")
        f.write(f"SEED: {SEED}\n")
        f.write("="*80 + "\n\n")

        f.write(f"Fecha: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Dataset: {len(full_dataset)} im√°genes\n")
        f.write(f"Clases: {', '.join(class_names)}\n\n")

        f.write("MODELOS EVALUADOS:\n")
        f.write("-"*80 + "\n")
        f.write("Cl√°sicos: DeepLabV3+, LinkNet, PSPNet, U-Net++\n")
        f.write("Nuevos: MobileNetV3-UNet, MobileViT-XSmall\n\n")

        f.write("RESULTADOS POR MODELO:\n")
        f.write("-"*80 + "\n\n")

        for i, r in enumerate(resultados, 1):
            params_m = r['total_params'] / 1e6
            size_mb = (r['total_params'] * 4) / (1024 * 1024)

            f.write(f"{i}. {r['model_name']}:\n")
            f.write(f"   IoU Promedio: {r['best_iou']:.4f}\n")
            f.write(f"   Dice Promedio: {r['best_dice']:.4f}\n")
            f.write(f"   Par√°metros: {params_m:.2f}M ({r['total_params']:,})\n")
            f.write(f"   Tama√±o: {size_mb:.2f} MB\n")
            f.write(f"   Mejor Epoca: {r['best_epoch']}\n")
            f.write(f"   Tiempo: {r['training_time']/60:.2f} min\n\n")
            f.write(f"   M√©tricas por clase:\n")
            for vertebra in class_names[1:]:
                iou = r['best_metrics_per_class'][vertebra]['iou']
                dice = r['best_metrics_per_class'][vertebra]['dice']
                f.write(f"     {vertebra}: IoU={iou:.4f}, Dice={dice:.4f}\n")
            f.write("\n")

        f.write(f"{'='*80}\n")
        f.write(f"MEJOR MODELO GENERAL: {resultados[0]['model_name']}\n")
        f.write(f"IoU: {resultados[0]['best_iou']:.4f}\n")

        if mejor_clasico is not None:
            f.write(f"\n{'='*80}\n")
            f.write("COMPARACI√ìN: NUEVOS MODELOS vs CL√ÅSICOS\n")
            f.write(f"{'='*80}\n\n")
            f.write(f"Mejor Cl√°sico: {mejor_clasico['model_name']} (IoU: {mejor_clasico['best_iou']:.4f})\n\n")

            for nuevo, reduccion_params, diff_iou in comparaciones:
                f.write(f"{nuevo['model_name']}: IoU={nuevo['best_iou']:.4f}\n")
                f.write(f"  Reducci√≥n de Par√°metros: {reduccion_params:.1f}%\n")
                f.write(f"  Diferencia de IoU: {diff_iou:+.4f}\n\n")

        f.write(f"\n{'='*80}\n")
        f.write("TABLA COMPARATIVA DE EFICIENCIA\n")
        f.write(f"{'='*80}\n\n")
        for linea in tabla:
            f.write(linea + "\n")

    print(f"\nüìÑ Resumen guardado en: {filename}")

    # Final table on the console
    print(f"\n{'='*80}")
    print("üìä TABLA COMPARATIVA FINAL")
    print(f"{'='*80}")
    for linea in tabla:
        print(linea)

    print(f"\n‚úÖ AN√ÅLISIS COMPLETADO")
    print(f"\nüìã ARCHIVOS GENERADOS:")
    print(f"   ‚úì Visualizaciones por modelo: pred_*_seed{SEED}.png")
    print(f"   ‚úì Comparaci√≥n completa: comparacion_multimodal_actualizada.png")
    print(f"   ‚úì Resumen detallado: {filename}")
    print(f"   ‚úì Checkpoints: *_best.pth")

    return resultados

print("‚úì Funci√≥n principal definida")

In [None]:
if __name__ == "__main__":
    import traceback

    try:
        # Assigned at module level so the results stay inspectable in the notebook
        resultados = main()
    except KeyboardInterrupt:
        print("\n\n‚ö†Ô∏è  Proceso interrumpido")
    except Exception as err:
        # Best-effort: report the failure with its traceback instead of re-raising
        print(f"\n\n‚ùå Error: {err!s}")
        traceback.print_exc()
    finally:
        # Release cached GPU memory whatever the outcome
        if torch.cuda.is_available():
            torch.cuda.empty_cache()