# Entrenamiento Unet con ConvNeXt Tiny

Este cuaderno contiene el 'Script de Rescate', que combina múltiples fases de entrenamiento en un único flujo consolidado de 110 épocas. ConvNeXt Tiny ofrece un buen equilibrio entre precisión y velocidad.

In [None]:
# Instalar librerías necesarias
!pip install segmentation-models-pytorch albumentations timm utils

In [None]:
import os
import random

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler
import segmentation_models_pytorch as smp

# Global configuration
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"✅ Usando dispositivo: {DEVICE}")


def set_seed(seed=42):
    """Seed the random number generators this notebook draws from.

    Seeds Python's ``random`` module (which augmentation libraries such as
    albumentations may use), NumPy's global RNG (used by the dataset's crop
    sampling) and torch on the CPU and on every visible CUDA device.

    Note: setting ``PYTHONHASHSEED`` at runtime does not change the hash
    randomization of the interpreter that is already running; it only affects
    Python subprocesses started afterwards. cuDNN non-determinism is not
    disabled here, so GPU runs may still differ slightly between executions.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers multi-GPU machines; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)


set_seed()

## Utilidades y Clases del Dataset

In [None]:
# ==========================================
# UTILIDADES Y DATASET
# ==========================================

def rle_decode(mask_rle, shape=(256, 1600)):
    """Decode a Run-Length-Encoded mask string into a binary uint8 image.

    The RLE string is a whitespace-separated list of ``start length`` pairs
    with 1-based starts, indexing the image flattened in column-major
    (Fortran) order.

    Args:
        mask_rle: RLE string; NaN/None or an empty string means "no mask".
        shape: (height, width) of the output mask.

    Returns:
        A ``shape``-sized uint8 array with 1 on encoded pixels and 0 elsewhere.
    """
    if pd.isna(mask_rle) or mask_rle == '':
        return np.zeros(shape, dtype=np.uint8)

    tokens = mask_rle.split()
    run_starts = np.asarray(tokens[0::2], dtype=int) - 1  # to 0-based offsets
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    run_ends = run_starts + run_lengths

    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, end in zip(run_starts, run_ends):
        flat[begin:end] = 1
    # Column-major reshape: consecutive pixels run down each column first.
    return flat.reshape(shape, order='F')

class SeverStalDataset(Dataset):
    """Severstal steel-defect segmentation dataset returning square random crops.

    Each item is a ``tile_size`` x ``tile_size`` crop of an image together with
    its integer label mask (0 = background, otherwise the ``ClassID`` of the
    defect painted there). When the image contains defects, with probability
    ``defect_prob`` the crop is centred, with a random jitter of up to
    ``tile_size // 4`` pixels, on a randomly chosen defect pixel; otherwise the
    crop position is uniformly random over all valid offsets.

    NOTE(review): the crop is re-drawn on every access, including when this
    class is used for validation, so validation metrics are computed on
    different crops each epoch.

    Args:
        img_ids: Sequence of image file names (values of the ``ImageID`` column).
        df: DataFrame with ``ImageID``, ``ClassID`` and ``EncodedPixels`` columns.
        img_dir: Directory containing the image files.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a dict with ``'image'`` and ``'mask'`` entries.
        tile_size: Side length of the square crop, in pixels.
        defect_prob: Probability of centring the crop on a defect when the
            image has one (defaults to the original 0.8).
    """

    def __init__(self, img_ids, df, img_dir, transform=None, tile_size=256, defect_prob=0.8):
        self.img_ids = img_ids
        self.img_dir = img_dir
        self.transform = transform
        self.tile_size = tile_size
        self.defect_prob = defect_prob
        # Group the (ClassID, EncodedPixels) rows per image once, so that
        # __getitem__ only needs a dict lookup.
        self.rle_dict = {}
        subset_df = df[df['ImageID'].isin(img_ids)]
        for img_id, group in subset_df.groupby('ImageID'):
            self.rle_dict[img_id] = group[['ClassID', 'EncodedPixels']].values.tolist()

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        img = cv2.imread(os.path.join(self.img_dir, img_id))
        if img is None:
            # Placeholder when the file cannot be read (should not happen).
            return (torch.zeros((3, self.tile_size, self.tile_size)),
                    torch.zeros((self.tile_size, self.tile_size)).long())

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        mask = self._build_mask(img_id, img.shape[:2])
        img, mask = self._pad_to_tile(img, mask)

        y1, x1 = self._choose_crop_origin(mask)
        image_crop = img[y1:y1 + self.tile_size, x1:x1 + self.tile_size]
        mask_crop = mask[y1:y1 + self.tile_size, x1:x1 + self.tile_size]

        if self.transform:
            augmented = self.transform(image=image_crop, mask=mask_crop)
            image_crop = augmented['image']
            mask_crop = augmented['mask']
        else:
            # No transform: return a CHW uint8 tensor so the layout matches the
            # placeholder above (previously this path crashed on `.long()`).
            image_crop = torch.from_numpy(np.ascontiguousarray(image_crop.transpose(2, 0, 1)))

        if not torch.is_tensor(mask_crop):
            mask_crop = torch.from_numpy(np.ascontiguousarray(mask_crop))
        return image_crop, mask_crop.long()

    def _build_mask(self, img_id, shape):
        """Return a ``shape`` uint8 label mask with each decoded RLE set to its ClassID."""
        mask = np.zeros(shape, dtype=np.uint8)
        for cls, rle in self.rle_dict.get(img_id, []):
            if pd.notna(rle):
                # Decode with the image's own size instead of a hard-coded
                # (256, 1600), so differently-sized images cannot mis-align.
                decoded = rle_decode(rle, shape=shape)
                # Later rows overwrite earlier ones where defects overlap.
                mask[decoded == 1] = int(cls)
        return mask

    def _pad_to_tile(self, img, mask):
        """Zero-pad image and mask at the bottom/right so each side is >= tile_size."""
        h, w = mask.shape
        pad_h = max(0, self.tile_size - h)
        pad_w = max(0, self.tile_size - w)
        if pad_h or pad_w:
            img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)))
            mask = np.pad(mask, ((0, pad_h), (0, pad_w)))
        return img, mask

    def _choose_crop_origin(self, mask):
        """Pick the crop's top-left corner (y1, x1); ``mask`` must be >= tile_size per side."""
        h, w = mask.shape
        max_y = h - self.tile_size
        max_x = w - self.tile_size
        if mask.max() > 0 and np.random.rand() < self.defect_prob:
            ys, xs = np.nonzero(mask)
            pick = np.random.randint(len(ys))
            jitter = self.tile_size // 4
            # +1 makes the jitter symmetric in [-jitter, jitter].
            cy = ys[pick] + np.random.randint(-jitter, jitter + 1)
            cx = xs[pick] + np.random.randint(-jitter, jitter + 1)
            y1 = np.clip(cy - self.tile_size // 2, 0, max_y)
            x1 = np.clip(cx - self.tile_size // 2, 0, max_x)
        else:
            # randint's upper bound is exclusive: +1 makes the last valid
            # offset (h - tile_size / w - tile_size) reachable.
            y1 = np.random.randint(0, max_y + 1)
            x1 = np.random.randint(0, max_x + 1)
        return int(y1), int(x1)


## Métricas y Funciones de Entrenamiento

In [None]:
# ==========================================
# MÉTRICAS Y LOOP DE ENTRENAMIENTO
# ==========================================

def compute_iou(pred, target, num_classes):
    """Compute the per-class Intersection-over-Union of two label tensors.

    Args:
        pred: Tensor of predicted class ids (any shape; moved to CPU).
        target: Tensor of ground-truth class ids, same shape as ``pred``.
        num_classes: Number of classes; ids ``0 .. num_classes - 1`` are scored.

    Returns:
        A list of length ``num_classes``; entry ``c`` is |pred==c & target==c| /
        |pred==c | target==c|, or ``nan`` when class ``c`` appears in neither.
    """
    pred_np = pred.cpu().numpy()
    target_np = target.cpu().numpy()

    scores = []
    for class_id in range(num_classes):
        in_pred = pred_np == class_id
        in_target = target_np == class_id
        union = (in_pred | in_target).sum()
        if union == 0:
            # Class absent from both: undefined IoU, skipped later by nanmean.
            scores.append(float('nan'))
            continue
        scores.append((in_pred & in_target).sum() / union)
    return scores

def train_one_epoch(model, loader, criterion, optimizer, scaler, device, num_classes):
    """Run one mixed-precision training epoch and collect batch-level IoU.

    Args:
        model: Segmentation network; its output's dim 1 is taken as the class
            axis (argmax over dim=1).
        loader: Iterable of (images, masks) batches; must support ``len()``.
        criterion: Callable ``(outputs, masks) -> scalar loss tensor``.
        optimizer: Optimizer over the model's parameters.
        scaler: GradScaler used to scale the loss before backward.
        device: Device string or ``torch.device`` the batches are moved to.
        num_classes: Number of label classes scored by ``compute_iou``.

    Returns:
        Tuple ``(mean loss per batch, mean IoU over classes, per-class IoU)``.
        Each class's IoU is averaged over batches ignoring NaN (class absent
        from both prediction and target), then averaged across classes.
    """
    model.train()
    # Derive autocast's device type from `device` instead of hard-coding
    # 'cuda', and only enable mixed precision on CUDA (on CPU the original
    # merely warned and fell back to float32).
    device_type = torch.device(device).type
    use_amp = device_type == 'cuda'

    total_loss = 0.0
    all_ious = []
    pbar = tqdm(loader, desc='Training')

    for images, masks in pbar:
        images = images.to(device).float()
        masks = masks.to(device).long()
        optimizer.zero_grad()

        with torch.amp.autocast(device_type, enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, masks)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # A single host sync per batch instead of two `.item()` calls.
        loss_value = loss.item()
        total_loss += loss_value

        with torch.no_grad():
            preds = outputs.argmax(dim=1)
            all_ious.append(compute_iou(preds, masks, num_classes))

        pbar.set_postfix({'loss': f"{loss_value:.4f}"})

    per_class_iou = np.nanmean(all_ious, axis=0)
    return total_loss / len(loader), np.nanmean(per_class_iou), per_class_iou

def validate(model, loader, criterion, device, num_classes):
    """Evaluate the model without gradients and collect batch-level IoU.

    Runs in full precision (no autocast). The bookkeeping matches the training
    loop: returns ``(mean loss per batch, mean IoU over classes, per-class
    IoU)``, where each class is first averaged over batches ignoring NaN
    (class absent from both prediction and target).
    """
    model.eval()
    running_loss = 0
    batch_scores = []
    with torch.no_grad():
        progress = tqdm(loader, desc='Validation')
        for batch_images, batch_masks in progress:
            batch_images = batch_images.to(device).float()
            batch_masks = batch_masks.to(device).long()
            logits = model(batch_images)
            running_loss += criterion(logits, batch_masks).item()
            batch_scores.append(compute_iou(logits.argmax(dim=1), batch_masks, num_classes))

    class_means = np.nanmean(batch_scores, axis=0)
    return running_loss / len(loader), np.nanmean(class_means), class_means


In [None]:
# Data configuration and loading
SEVERSTAL_ROOT = '/content/drive/MyDrive/severstal'  # Adjust the path to your environment

# Transforms: flips and brightness/contrast jitter for training; validation
# only normalizes. A.Normalize() without arguments uses ImageNet mean/std.
train_trans = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.Normalize(), ToTensorV2()
])

val_trans = A.Compose([
    A.Normalize(), ToTensorV2()
])

# Load the CSV (assumes the dataset is present under SEVERSTAL_ROOT).
try:
    df = pd.read_csv(os.path.join(SEVERSTAL_ROOT, 'train.csv'))
    if 'ImageId_ClassId' in df.columns:
        # Combined "<image>_<class>" column: split on the LAST underscore so an
        # underscore inside the image name cannot shift the class id.
        df['ImageID'] = df['ImageId_ClassId'].apply(lambda x: x.rsplit('_', 1)[0])
        df['ClassID'] = df['ImageId_ClassId'].apply(lambda x: x.rsplit('_', 1)[1])
    else:
        # Separate-column layout ('ImageId' / 'ClassId'): map to the column
        # names SeverStalDataset reads. rename() ignores absent columns.
        # NOTE(review): if this layout lists only defective images, defect-free
        # images never enter the train/val split -- confirm against the CSV.
        df = df.rename(columns={'ImageId': 'ImageID', 'ClassId': 'ClassID'})

    unique_ids = df['ImageID'].unique().tolist()
    train_ids, val_ids = train_test_split(unique_ids, test_size=0.2, random_state=42)

    img_dir = os.path.join(SEVERSTAL_ROOT, 'train_images')
    train_ds = SeverStalDataset(train_ids, df, img_dir, transform=train_trans)
    # NOTE(review): the dataset draws a random crop on every access, so the
    # validation IoU is measured on different crops each epoch.
    val_ds = SeverStalDataset(val_ids, df, img_dir, transform=val_trans)

    # Pinned host memory speeds up host-to-GPU copies; pointless on CPU.
    pin = DEVICE == 'cuda'
    train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=2, pin_memory=pin)
    val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=2, pin_memory=pin)
    print("✅ Datos cargados correctamente")
except (OSError, KeyError, ValueError) as e:
    # Missing files, unexpected CSV columns or an empty split: warn and keep
    # the notebook usable. Other exceptions indicate bugs and now propagate
    # instead of being silently swallowed.
    print(f"⚠️ Aviso: No se pudieron cargar los datos ({e}). Asegúrate de tener el dataset.")


## Definición del Modelo (ConvNeXt Tiny)
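Nota: Cosine Annealing reduce el learning rate de forma suave, lo que permite al modelo afinar los pesos cerca de un mínimo. El parámetro `T_max` debe coincidir con el número total de épocas: si es menor, el learning rate vuelve a subir una vez superado `T_max`.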

In [None]:
# Note: 'tu-convnext_tiny' is a timm encoder (smp's "tu-" prefix), so the timm
# library must be installed.
model = smp.Unet(
    encoder_name='tu-convnext_tiny',
    encoder_weights='imagenet',
    in_channels=3,
    classes=5,  # label 0 is background; 1-4 are the dataset's ClassID values
).to(DEVICE)

optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-2)
# T_max must equal the number of epochs the training loop runs (EPOCHS = 110).
# With the previous T_max=100 the cosine curve bottomed out at epoch 100 and
# the learning rate started rising again for the last 10 epochs.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=110, eta_min=1e-6)

# Weighted loss for hard classes (background is down-weighted to 0.1).
class_weights = torch.tensor([0.1, 2.0, 10.0, 2.0, 2.0]).to(DEVICE)
# Build the loss modules once instead of re-instantiating all three on every
# call, as the previous lambda did.
ce_loss = nn.CrossEntropyLoss(weight=class_weights)
dice_loss = smp.losses.DiceLoss('multiclass', from_logits=True)
focal_loss = smp.losses.FocalLoss('multiclass')


def criterion(p, t):
    """Combined loss: 0.4 * weighted cross-entropy + 0.3 * Dice + 0.3 * Focal.

    Args:
        p: Model logits.
        t: Integer target mask.
    """
    return 0.4 * ce_loss(p, t) + 0.3 * dice_loss(p, t) + 0.3 * focal_loss(p, t)


# Loss scaling is only meaningful for CUDA mixed precision; disable it
# explicitly elsewhere instead of relying on the "CUDA not available" warning.
scaler = GradScaler(enabled=(DEVICE == 'cuda'))

## Bucle de Entrenamiento Consolidado (110 Épocas)
Se han unificado las fases anteriores de 30, 40 y 30 épocas (100 en total) en un solo bucle robusto, que se extiende hasta 110 épocas.

In [None]:
EPOCHS = 110
best_iou = 0.0

for epoch in range(EPOCHS):
    human_epoch = epoch + 1  # 1-based epoch number used in logs and file names
    print(f"Epoch {human_epoch}/{EPOCHS}")

    # 5 = number of output classes of the model (background + ClassIDs).
    t_loss, t_iou, _ = train_one_epoch(model, train_loader, criterion, optimizer, scaler, DEVICE, 5)
    v_loss, v_iou, v_ious = validate(model, val_loader, criterion, DEVICE, 5)

    # The learning-rate schedule advances once per epoch.
    scheduler.step()

    print(f"Train Loss: {t_loss:.4f} | Val IoU: {v_iou:.4f}")
    print(f"Detalle IoU por clase: {v_ious}")

    # Keep the weights with the best validation mean IoU seen so far.
    improved = v_iou > best_iou
    if improved:
        best_iou = v_iou
        torch.save(model.state_dict(), 'convnext_best.pth')
        print("✅ Récord batido - Modelo guardado")

    # Periodic weights-only backup every 10 epochs.
    if human_epoch % 10 == 0:
        torch.save(model.state_dict(), f'convnext_backup_ep{human_epoch}.pth')
