In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBlock(nn.Module):
    """
    Basic U-Net++ building block made of two identical conv stages.

    Each stage is:
        Conv2d(3x3, padding=1) -> BatchNorm2d (if ``use_bn``) -> ReLU
        -> Dropout2d (if ``dropout > 0``)

    The first stage maps ``in_ch`` -> ``out_ch`` channels and the second keeps
    ``out_ch``. A 3x3 kernel with padding 1 (stride 1) preserves H and W.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        use_bn: if True, insert BatchNorm2d after each conv and build the convs
            without bias (the BatchNorm shift makes a conv bias redundant).
        dropout: Dropout2d probability applied after each ReLU; 0.0 disables it.
    """

    def __init__(self, in_ch, out_ch, use_bn=True, dropout=0.0):
        super().__init__()
        modules = []
        # Both stages are appended in order, so the Sequential indices (and
        # hence the state_dict keys) follow conv, [bn], relu, [dropout] twice.
        for stage_in in (in_ch, out_ch):
            modules.append(
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1, bias=not use_bn)
            )
            if use_bn:
                modules.append(nn.BatchNorm2d(out_ch))
            modules.append(nn.ReLU(inplace=True))
            if dropout > 0.0:
                modules.append(nn.Dropout2d(dropout))
        self.block = nn.Sequential(*modules)

    def forward(self, x):
        """Run both conv stages: (B, in_ch, H, W) -> (B, out_ch, H, W)."""
        return self.block(x)


class UNetPlusPlus(nn.Module):
    """
    U-Net++ (Nested U-Net) for 2D segmentation.

    Args:
        n_channels: number of input channels (3 for RGB).
        n_classes: number of output channels (1 for binary segmentation).
        deep_supervision: if True, attach a 1x1 output head to each of
            X_{0,1}, X_{0,2}, X_{0,3}, X_{0,4} and return the average of the
            four logit maps; otherwise only X_{0,4} is projected to the output.
        use_bn: use BatchNorm inside every ConvBlock.
        base_ch: number of filters at the first level (default 64); level i
            uses ``base_ch * 2**i`` filters.
        dropout: Dropout2d probability forwarded to every ConvBlock
            (0.0 disables it).

    The forward pass returns raw logits (no sigmoid/softmax) of shape
    (B, n_classes, H, W). H and W do not have to be multiples of 16: every
    upsampled decoder feature is resized to the exact spatial size of the skip
    tensor it is concatenated with. Because the encoder applies four 2x2 max
    pools, H and W must be at least 16.
    """
    def __init__(
        self,
        n_channels: int = 3,
        n_classes: int = 1,
        deep_supervision: bool = False,
        use_bn: bool = True,
        base_ch: int = 64,
        dropout: float = 0.0
    ):
        super().__init__()

        self.n_channels = n_channels
        self.n_classes = n_classes
        self.deep_supervision = deep_supervision

        # Channels at each encoder level.
        nb_filter = [
            base_ch,
            base_ch * 2,
            base_ch * 4,
            base_ch * 8,
            base_ch * 16,
        ]

        # ---------------- ENCODER (X_{i,0}) ----------------
        self.conv0_0 = ConvBlock(n_channels, nb_filter[0], use_bn, dropout)
        self.conv1_0 = ConvBlock(nb_filter[0], nb_filter[1], use_bn, dropout)
        self.conv2_0 = ConvBlock(nb_filter[1], nb_filter[2], use_bn, dropout)
        self.conv3_0 = ConvBlock(nb_filter[2], nb_filter[3], use_bn, dropout)
        self.conv4_0 = ConvBlock(nb_filter[3], nb_filter[4], use_bn, dropout)

        self.pool = nn.MaxPool2d(2)

        # -------------- NESTED BLOCKS (decoder) -----------
        # Node X_{i,j} concatenates X_{i,0..j-1} (j tensors of nb_filter[i]
        # channels) with the upsampled X_{i+1,j-1} (nb_filter[i+1] channels).
        # X_{0,1}, X_{1,1}, X_{2,1}, X_{3,1}
        self.conv0_1 = ConvBlock(nb_filter[0] + nb_filter[1], nb_filter[0], use_bn, dropout)
        self.conv1_1 = ConvBlock(nb_filter[1] + nb_filter[2], nb_filter[1], use_bn, dropout)
        self.conv2_1 = ConvBlock(nb_filter[2] + nb_filter[3], nb_filter[2], use_bn, dropout)
        self.conv3_1 = ConvBlock(nb_filter[3] + nb_filter[4], nb_filter[3], use_bn, dropout)

        # X_{0,2}, X_{1,2}, X_{2,2}
        self.conv0_2 = ConvBlock(nb_filter[0] * 2 + nb_filter[1], nb_filter[0], use_bn, dropout)
        self.conv1_2 = ConvBlock(nb_filter[1] * 2 + nb_filter[2], nb_filter[1], use_bn, dropout)
        self.conv2_2 = ConvBlock(nb_filter[2] * 2 + nb_filter[3], nb_filter[2], use_bn, dropout)

        # X_{0,3}, X_{1,3}
        self.conv0_3 = ConvBlock(nb_filter[0] * 3 + nb_filter[1], nb_filter[0], use_bn, dropout)
        self.conv1_3 = ConvBlock(nb_filter[1] * 3 + nb_filter[2], nb_filter[1], use_bn, dropout)

        # X_{0,4}
        self.conv0_4 = ConvBlock(nb_filter[0] * 4 + nb_filter[1], nb_filter[0], use_bn, dropout)

        # Fixed x2 upsampler, kept so existing code referencing `model.up`
        # still works (it has no parameters, so checkpoints are unaffected).
        # forward() uses `_upsample_like`, which matches it for even sizes and
        # also handles odd ones.
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

        # 1x1 output convs (one per supervised node when deep supervision is on).
        if self.deep_supervision:
            self.final1 = nn.Conv2d(nb_filter[0], n_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], n_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], n_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], n_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], n_classes, kernel_size=1)

    @staticmethod
    def _upsample_like(x, ref):
        """Bilinearly upsample ``x`` to the spatial size (H, W) of ``ref``.

        When ``ref`` is exactly twice the size of ``x`` this equals a x2
        bilinear upsample with align_corners=False (the same as ``self.up``).
        When the encoder floored an odd size during pooling, it still returns
        a tensor that can be concatenated with ``ref``. A fixed x2 upsample
        would be one pixel short and make ``torch.cat`` fail.
        """
        return F.interpolate(
            x, size=tuple(ref.shape[-2:]), mode="bilinear", align_corners=False
        )

    def forward(self, x):
        """Return logits of shape (B, n_classes, H, W) for input (B, n_channels, H, W)."""
        up = self._upsample_like

        # ---------------- ENCODER ----------------
        x0_0 = self.conv0_0(x)               # (B, f0, H,    W)
        x1_0 = self.conv1_0(self.pool(x0_0))  # (B, f1, H/2,  W/2)  (floored)
        x2_0 = self.conv2_0(self.pool(x1_0))  # (B, f2, H/4,  W/4)
        x3_0 = self.conv3_0(self.pool(x2_0))  # (B, f3, H/8,  W/8)
        x4_0 = self.conv4_0(self.pool(x3_0))  # (B, f4, H/16, W/16)

        # ---------------- NESTED DECODER ----------------
        # Level j = 1
        x0_1 = self.conv0_1(torch.cat([x0_0, up(x1_0, x0_0)], dim=1))
        x1_1 = self.conv1_1(torch.cat([x1_0, up(x2_0, x1_0)], dim=1))
        x2_1 = self.conv2_1(torch.cat([x2_0, up(x3_0, x2_0)], dim=1))
        x3_1 = self.conv3_1(torch.cat([x3_0, up(x4_0, x3_0)], dim=1))

        # Level j = 2
        x0_2 = self.conv0_2(
            torch.cat([x0_0, x0_1, up(x1_1, x0_0)], dim=1)
        )
        x1_2 = self.conv1_2(
            torch.cat([x1_0, x1_1, up(x2_1, x1_0)], dim=1)
        )
        x2_2 = self.conv2_2(
            torch.cat([x2_0, x2_1, up(x3_1, x2_0)], dim=1)
        )

        # Level j = 3
        x0_3 = self.conv0_3(
            torch.cat([x0_0, x0_1, x0_2, up(x1_2, x0_0)], dim=1)
        )
        x1_3 = self.conv1_3(
            torch.cat([x1_0, x1_1, x1_2, up(x2_2, x1_0)], dim=1)
        )

        # Level j = 4
        x0_4 = self.conv0_4(
            torch.cat([x0_0, x0_1, x0_2, x0_3, up(x1_3, x0_0)], dim=1)
        )

        # ---------------- OUTPUTS ----------------
        if self.deep_supervision:
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            # The four heads are averaged into a single logit map.
            return (output1 + output2 + output3 + output4) / 4.0
        else:
            logits = self.final(x0_4)
            return logits   # (B, n_classes, H, W) logits (no sigmoid)


In [None]:
# Train U-Net++ with dropout 0.3 in every ConvBlock on the "data/train" split.
# NOTE(review): get_seg_dataloaders, BATCH_SIZE, NUM_WORKERS, SEED, base_img_tf,
# base_mask_tf, DEVICE, combined_loss, restaurar_modelo and train are not defined
# in this cell; presumably they come from earlier notebook cells.
# Train and validation use the same base transforms here (no extra training
# augmentation). The test and kaggle loaders are returned but not used in this cell.
train_loader_unet_plus_drop, val_loader_unet_plus_drop, test_loader_unet_plus_drop, kaggle_loader_unet_plus_drop = get_seg_dataloaders(
    "data/train",
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, seed=SEED, rgb=True,
    train_transform_img=base_img_tf,
    train_transform_mask=base_mask_tf,
    val_transform_img=base_img_tf,
    val_transform_mask=base_mask_tf,
)

unet_plus_drop = UNetPlusPlus(
    n_channels=3,
    n_classes=1,
    deep_supervision=False,   # or True to experiment (averages the four intermediate heads)
    use_bn=True,
    base_ch=64,
    dropout=0.3
).to(DEVICE)

# combined_loss is defined elsewhere; presumably a compound segmentation loss
# (e.g. BCE + Dice or Focal Tversky) — confirm against its definition.
criterion = combined_loss
# Adam with a small L2 weight decay.
optimizer_unet_plus_drop = torch.optim.Adam(unet_plus_drop.parameters(), lr=1e-4, weight_decay=1e-5)
# Halve the LR when the monitored metric has not decreased for 4 epochs.
# NOTE(review): which metric is passed to scheduler.step() is decided inside
# train() — presumably the validation loss; confirm.
scheduler_unet_plus_drop = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_unet_plus_drop, mode="min", factor=0.5, patience=4)

# Optional: uncomment to resume from a previously saved checkpoint.
# unet_plus_drop, _, checkpoint = restaurar_modelo(unet_plus_drop, None, "models/unet_plus_plus_2.pth", DEVICE)


# NOTE(review): the exact semantics of do_early_stopping/patience/log_every,
# checkpoint_path and loss_ponderada ("weighted loss") are defined by train();
# the two return values are presumably per-epoch train and validation losses.
epoch_train_errors_unet_plus_drop, epoch_val_errors_unet_plus_drop = train(
    unet_plus_drop, optimizer_unet_plus_drop, criterion, train_loader_unet_plus_drop, val_loader_unet_plus_drop, DEVICE, scheduler=scheduler_unet_plus_drop,
    do_early_stopping=True, patience=15, epochs=80, log_every=5,
    checkpoint_path="models/unet_plus_drop_3.pth",
    loss_ponderada=True
)