# ConvNeXt using backbone

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

# Build torchvision's ConvNeXt-Base with its default pretrained weights and keep
# only the `.features` sub-module; the classification head is dropped.
# (The older spelling of this call was `models.convnext_base(pretrained=True)`.)
backbone = models.convnext_base(
    weights=models.ConvNeXt_Base_Weights.DEFAULT,
).features

# Replace the first convolution of the stem so the network accepts 1-channel
# input instead of 3-channel input. The output width is copied from the layer
# being replaced.
# NOTE(review): the replacement layer is a freshly constructed nn.Conv2d, so the
# pretrained filters of the original first convolution are not reused.
backbone[0][0] = nn.Conv2d(
    in_channels=1,
    out_channels=backbone[0][0].out_channels,
    kernel_size=4,
    stride=4,
    padding=0,
)

# Definir un modelo de segmentación usando ConvNeXt como backbone
class ConvNeXtSegmentation(nn.Module):
    """Segmentation model: ConvNeXt feature extractor + transposed-conv decoder.

    The encoder is the module-level ``backbone`` object. It is stored by
    reference (not copied), so every model built from the same ``backbone``
    shares that encoder and its weights.

    The decoder expects a 1024-channel feature map and applies five stride-2
    transposed convolutions (GELU in between), i.e. it upsamples its input by
    a factor of 32 and ends with ``num_classes`` channels. No activation is
    applied to the output.

    NOTE(review): the original comment said the decoder is tuned for a
    240x240 output. Assuming the backbone reduces resolution by 32 with
    flooring, a 240x240 input would come out as 224x224 - TODO confirm the
    intended input size.

    Args:
        num_classes: number of output channels.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.backbone = backbone

        # Three kernel-2 / stride-2 stages, each followed by GELU ...
        layers = []
        for in_ch, out_ch in ((1024, 512), (512, 256), (256, 128)):
            layers.append(nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2))
            layers.append(nn.GELU())
        # ... then two kernel-3 stages; padding=1 with output_padding=1 also
        # doubles the spatial size exactly.
        layers.append(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2,
                               padding=1, output_padding=1))
        layers.append(nn.GELU())
        layers.append(
            nn.ConvTranspose2d(64, num_classes, kernel_size=3, stride=2,
                               padding=1, output_padding=1))
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` with the backbone and decode the result."""
        features = self.backbone(x)
        return self.decoder(features)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ConvNeXt-Base with skip connections

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

class UpBlock(nn.Module):
    """Decoder step: upsample by 2, concatenate a skip tensor, refine.

    Args:
        in_ch: channels of the tensor to upsample.
        skip_ch: channels of the skip tensor.
        out_ch: channels produced by the transposed convolution and by the
            block as a whole.
    """

    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)

        merged_ch = out_ch + skip_ch  # width after channel concatenation
        self.conv = nn.Sequential(
            nn.Conv2d(merged_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.GELU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.GELU(),
        )

    def forward(self, x, skip):
        """Return the refined fusion of the upsampled ``x`` and ``skip``."""
        upsampled = self.up(x)
        target_hw = skip.shape[-2:]
        # The doubled size can differ from the skip's size (e.g. when an odd
        # size was halved in the encoder); resample to the skip's grid.
        if upsampled.shape[-2:] != target_hw:
            upsampled = torch.nn.functional.interpolate(
                upsampled, size=target_hw, mode="bilinear",
                align_corners=False)
        merged = torch.cat([upsampled, skip], dim=1)
        return self.conv(merged)


# ---------------------------------------------------------------------
# ConvNeXt-Base como encoder + decoder tipo U-Net con skips
class ConvNeXtUNet_B(nn.Module):
    """ConvNeXt-Base encoder with a U-Net style decoder and skip connections.

    The encoder is torchvision's ``convnext_base(...).features`` (default
    pretrained weights) whose first convolution is replaced by a bias-free
    1-channel convolution. Its eight sub-modules are stored individually so
    the output of each stage can be reused as a skip connection.

    The scale/width remarks below (1/4 .. 1/32, 128 .. 1024 channels) are taken
    from the original comments and match the channel counts the decoder is
    built with.

    Args:
        num_classes: number of output channels of the 1x1 head.

    Returns (forward): raw logits; no activation is applied.
    """

    def __init__(self, num_classes=1):
        super().__init__()

        # -------------------- Encoder ----------------------------------
        features = models.convnext_base(
            weights=models.ConvNeXt_Base_Weights.DEFAULT,
        ).features

        # Accept 1-channel images: swap the first convolution of the stem.
        stem_width = features[0][0].out_channels
        features[0][0] = nn.Conv2d(
            1, stem_width, kernel_size=4, stride=4, padding=0, bias=False)

        # Register the eight encoder pieces under individual names, in order:
        #   enc_stem   1/4 , 128 ch      enc_stage0 1/4 , 128 ch
        #   down1      -> 1/8            enc_stage1 1/8 , 256 ch
        #   down2      -> 1/16           enc_stage2 1/16, 512 ch
        #   down3      -> 1/32           enc_stage3 1/32, 1024 ch
        encoder_parts = (
            "enc_stem", "enc_stage0",
            "down1", "enc_stage1",
            "down2", "enc_stage2",
            "down3", "enc_stage3",
        )
        for index, attr in enumerate(encoder_parts):
            setattr(self, attr, features[index])

        # A trailing norm layer is used only if `features` has a ninth entry.
        if len(features) > 8:
            self.enc_norm = features[8]
        else:
            self.enc_norm = nn.Identity()

        # -------------------- Decoder ----------------------------------
        self.up3 = UpBlock(1024, 512, 512)    # 1/32 -> 1/16
        self.up2 = UpBlock(512, 256, 256)     # 1/16 -> 1/8
        self.up1 = UpBlock(256, 128, 128)     # 1/8  -> 1/4
        # 1/4 -> full resolution in one x4 transposed convolution.
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=4),
            nn.GELU(),
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.GELU(),
        )
        self.head = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        """Run encoder, then decode while fusing the three shallower stages."""
        # ---------- Encoder -------------
        feat_4 = self.enc_stage0(self.enc_stem(x))       # 1/4
        feat_8 = self.enc_stage1(self.down1(feat_4))     # 1/8
        feat_16 = self.enc_stage2(self.down2(feat_8))    # 1/16
        feat_32 = self.enc_stage3(self.down3(feat_16))   # 1/32
        feat_32 = self.enc_norm(feat_32)                 # norm layer or identity

        # ---------- Decoder -------------
        dec = self.up3(feat_32, feat_16)   # 1/16
        dec = self.up2(dec, feat_8)        # 1/8
        dec = self.up1(dec, feat_4)        # 1/4
        dec = self.up0(dec)                # full resolution
        return self.head(dec)              # logits


# ConvNext variants

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

# --------------------------------------------------
# Utilidad: crea backbone ConvNeXt (cualquier variante)
# --------------------------------------------------
def get_convnext_features(variant="base", in_ch=1):
    """Return the ``.features`` module of a pretrained torchvision ConvNeXt.

    The first convolution of the stem is replaced by a new ``nn.Conv2d`` that
    takes ``in_ch`` input channels (same output width, kernel 4, stride 4).

    Args:
        variant: one of ``"base"``, ``"tiny"``, ``"small"``, ``"large"``.
        in_ch: number of input channels of the replacement convolution.

    Raises:
        KeyError: if ``variant`` is not one of the four supported names.
    """
    builders = {
        "base": models.convnext_base,
        "tiny": models.convnext_tiny,
        "small": models.convnext_small,
        "large": models.convnext_large,
    }
    build = builders[variant]

    # Weight enums are named ConvNeXt_<Variant>_Weights in torchvision.
    weights_enum = getattr(models, f"ConvNeXt_{variant.capitalize()}_Weights")
    features = build(weights=weights_enum.DEFAULT).features

    stem_width = features[0][0].out_channels
    features[0][0] = nn.Conv2d(
        in_ch, stem_width, kernel_size=4, stride=4, padding=0)
    return features

# --------------------------------------------------
# Bloque Reverse-Attention sencillo
# --------------------------------------------------
class ReverseAttention(nn.Module):
    """Residual gate built from two 3x3 conv + batch-norm layers.

    A gate ``m`` in (0, 1) is computed from the input and the block returns
    ``x * (1 - m) + x``. Each element is therefore scaled by ``2 - m``: close
    to 1 where the gate is high and close to 2 where it is low.

    Args:
        ch: number of input (and output) channels.
    """

    def __init__(self, ch):
        super().__init__()
        self.conv1 = nn.Conv2d(ch, ch, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(ch)
        self.conv2 = nn.Conv2d(ch, ch, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(ch)

    def forward(self, x):
        """Return ``x`` re-weighted by the reversed gate, plus the residual."""
        hidden = F.gelu(self.bn1(self.conv1(x)))
        gate = torch.sigmoid(self.bn2(self.conv2(hidden)))
        reversed_part = x * (1.0 - gate)
        return reversed_part + x

# --------------------------------------------------
# Bloque Multi-Feature Stem (k3,k5,k7, stride=4)
# --------------------------------------------------
class MultiFeatureStem(nn.Module):
    """Stride-4 stem made of three parallel convolutions (kernels 3, 5, 7).

    ``out_ch`` is split into three parts (the last branch absorbs the
    remainder of the integer division); the branch outputs are concatenated
    along the channel axis, batch-normalised and passed through GELU.

    Args:
        out_ch: total number of output channels.
        in_ch: number of input channels.
    """

    def __init__(self, out_ch, in_ch=1):
        super().__init__()
        third = out_ch // 3
        widths = (third, third, out_ch - 2 * third)      # sums to out_ch
        kernel_and_pad = ((3, 1), (5, 2), (7, 3))        # padding = kernel // 2
        self.branches = nn.ModuleList([
            nn.Conv2d(in_ch, width, kernel, stride=4, padding=pad, bias=False)
            for width, (kernel, pad) in zip(widths, kernel_and_pad)
        ])
        self.bn = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        """Apply every branch to ``x`` and fuse the results."""
        branch_outputs = [branch(x) for branch in self.branches]
        fused = torch.cat(branch_outputs, 1)
        return F.gelu(self.bn(fused))

# --------------------------------------------------
# Modelo 2 : Reverse-Attention en la salida
# --------------------------------------------------
class ConvNeXtSegmentationRA(ConvNeXtSegmentation):
    """``ConvNeXtSegmentation`` with a reverse-attention block in the decoder.

    The parent's decoder is rebuilt so that a ``ReverseAttention(64)`` module
    runs immediately before the parent's second-to-last decoder module. With
    the parent decoder defined in this file that module is the final GELU, so
    the order becomes: ... -> ConvTranspose2d(128, 64) -> RA -> GELU ->
    ConvTranspose2d(64, num_classes).

    The attention module is reachable both as ``self.ra`` and inside
    ``self.decoder``; both names refer to the same object.

    NOTE(review): ``ReverseAttention`` is looked up when the model is
    instantiated, and this file defines that name more than once - confirm
    which definition is active at that point.

    Args:
        num_classes: number of output channels.
    """

    def __init__(self, num_classes=1):
        super().__init__(num_classes)

        # Reverse attention operating on the 64-channel decoder features.
        self.ra = ReverseAttention(64)

        # Split the inherited decoder into: everything but the last two
        # modules, the second-to-last module, and the last module.
        inherited = list(self.decoder)
        leading = inherited[:-2]
        penultimate = inherited[-2]
        final = inherited[-1]

        self.decoder = nn.Sequential(
            *leading,                               # up to ConvTranspose2d(128, 64)
            nn.Sequential(self.ra, penultimate),    # RA followed by the old penultimate module
            final,                                  # last upsample to num_classes channels
        )

# --------------------------------------------------
# Modelo 3 : Multi-Feature Stem + ConvNeXt
# --------------------------------------------------

#class ConvNeXtSegmentationMF(nn.Module): # ConvNeXtSegmentationMF
#    def __init__(self, num_classes=1):
#        super().__init__()
#        # creamos stem y backbone (omitimos patch-embedding interno)
#        tmp_backbone = get_convnext_features("base", in_ch=1)
#        first_out_ch = tmp_backbone[0][0].out_channels
#        self.stem     = MultiFeatureStem(first_out_ch, in_ch=1)
#        self.backbone = nn.Sequential(*tmp_backbone[1:])  # quitamos capa 0 (ya cubierta)

#        # mismo decoder que el modelo base
#        self.decoder  = ConvNeXtSegmentation(num_classes).decoder

#    def forward(self, x):
#        x = self.stem(x)
#        x = self.backbone(x)
#        return self.decoder(x)


# ConvNeXt with backbone and SE

In [None]:
#backbone = models.convnext_base(weights=models.ConvNeXt_Base_Weights.DEFAULT).features  # Only features
# Cambiar la primera capa convolucional para aceptar un canal en lugar de tres
#backbone[0][0] = nn.Conv2d(1, backbone[0][0].out_channels, kernel_size=4, stride=4, padding=0)
# Módulo de atención Squeeze-and-Excitation (SE)
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel attention.

    Squeeze: global average pooling over H and W gives one value per channel.
    Excitation: a two-layer bottleneck MLP (``fc1`` -> ReLU -> ``fc2``)
    followed by a sigmoid produces a per-channel weight in (0, 1).
    The input is multiplied channel-wise by those weights.

    The ReLU between the two linear layers is required: without it the two
    layers compose into a single low-rank linear map and the bottleneck adds
    no modelling capacity.

    Args:
        channels: number of channels of the input feature map.
        reduction: bottleneck reduction ratio. The hidden width is
            ``channels // reduction`` but never less than 1, so small channel
            counts do not create an empty layer.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = max(1, channels // reduction)
        self.fc1 = nn.Linear(channels, hidden)
        self.fc2 = nn.Linear(hidden, channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled per channel; expects a 4-D (B, C, H, W) tensor."""
        b, c, _, _ = x.size()
        squeezed = x.mean(dim=(2, 3))              # global average pooling -> (B, C)
        excited = torch.relu(self.fc1(squeezed))   # non-linear bottleneck
        weights = self.sigmoid(self.fc2(excited)).view(b, c, 1, 1)
        return x * weights                         # broadcast over H and W

# Definir un modelo de segmentación usando ConvNeXt como backbone
class ConvNeXtSegmentation_SE(nn.Module):
    """ConvNeXt encoder -> SE channel attention -> transposed-conv decoder.

    The encoder is the module-level ``backbone`` object, stored by reference,
    so it is whatever ``backbone`` is bound to when the model is created and
    it is shared with any other model built from the same object.

    The decoder takes 1024 channels and doubles the spatial size five times
    (x32 overall), ending with ``num_classes`` channels and no activation.

    NOTE(review): the original comment said the decoder is tuned for a
    240x240 output - TODO confirm against the input size actually used.

    Args:
        num_classes: number of output channels.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.backbone = backbone

        # Squeeze-and-Excitation attention on the 1024-channel bottleneck.
        self.attention = SEBlock(1024)

        def doubling_k2(in_ch, out_ch):
            # kernel 2 / stride 2: exact x2 upsampling
            return nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)

        def doubling_k3(in_ch, out_ch):
            # kernel 3 / stride 2 with padding and output_padding 1: also x2
            return nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2,
                                      padding=1, output_padding=1)

        self.decoder = nn.Sequential(
            doubling_k2(1024, 512),
            nn.GELU(),
            doubling_k2(512, 256),
            nn.GELU(),
            doubling_k2(256, 128),
            nn.GELU(),
            doubling_k3(128, 64),
            nn.GELU(),
            doubling_k3(64, num_classes),
        )

    def forward(self, x):
        """Encode, re-weight channels with SE attention, then decode."""
        encoded = self.backbone(x)
        attended = self.attention(encoded)
        return self.decoder(attended)


# ConvNeXt with backbone and Reverse Attention

In [None]:
# (Re)create the ConvNeXt-Base feature extractor with its default pretrained
# weights; only `.features` is kept.
backbone = models.convnext_base(
    weights=models.ConvNeXt_Base_Weights.DEFAULT,
).features

# Swap the first convolution so the encoder accepts 1-channel input instead of
# 3-channel input; the output width is taken from the layer being replaced.
backbone[0][0] = nn.Conv2d(
    in_channels=1,
    out_channels=backbone[0][0].out_channels,
    kernel_size=4,
    stride=4,
    padding=0,
)

# Bloque de Reverse Attention
class ReverseAttention(nn.Module):
    """Spatial attention with sigmoid-gated (instead of softmax) scores.

    Query/key/value maps are produced by 1x1 convolutions. For N = H * W
    positions the score matrix is ``sigmoid(query^T @ key)`` of shape
    (B, N, N), and the output at position j is the sum over i of
    ``value[:, :, i] * score[i, j]``. There is no residual connection.

    Memory grows with N^2 because the full (N, N) score matrix is built.

    Args:
        channels: number of input/output channels; query and key use
            ``channels // 2`` channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.query_conv = nn.Conv2d(channels, channels // 2, kernel_size=1)
        self.key_conv = nn.Conv2d(channels, channels // 2, kernel_size=1)
        self.value_conv = nn.Conv2d(channels, channels, kernel_size=1)
        # Not used by forward(); kept so the attribute remains available.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return the attended feature map, same shape as ``x`` (B, C, H, W)."""
        batch, channels, height, width = x.shape
        positions = height * width

        # Flatten the spatial grid: (B, C', N)
        query = self.query_conv(x).view(batch, -1, positions)
        key = self.key_conv(x).view(batch, -1, positions)
        value = self.value_conv(x).view(batch, -1, positions)

        # Pairwise dot products between positions: (B, N, N)
        scores = torch.bmm(query.permute(0, 2, 1), key)
        # torch.sigmoid is numerically stable. The hand-written
        # 1 / (1 + exp(-s)) overflows exp() for strongly negative scores and
        # yields NaN gradients (0 * inf) during the backward pass.
        attention = torch.sigmoid(scores)

        out = torch.bmm(value, attention)          # (B, C, N)
        return out.view(batch, channels, height, width)

# Definir un modelo de segmentación usando ConvNeXt como backbone
class ConvNeXtSegmentation_RA(nn.Module):
    """ConvNeXt encoder -> reverse attention -> transposed-conv decoder.

    The encoder is the module-level ``backbone`` object (stored by reference,
    hence shared with other models built from it). ``ReverseAttention(1024)``
    is applied to the encoder output before decoding.

    The decoder takes 1024 channels and applies five x2 transposed
    convolutions (x32 overall), producing ``num_classes`` channels with no
    final activation.

    NOTE(review): the original comment said the decoder is tuned for a
    240x240 output - TODO confirm against the input size actually used.

    Args:
        num_classes: number of output channels.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.backbone = backbone

        # Reverse-attention block on the 1024-channel bottleneck.
        self.reverse_attention = ReverseAttention(1024)

        # (in_channels, out_channels, extra ConvTranspose2d kwargs) per stage.
        k2 = dict(kernel_size=2, stride=2)
        k3 = dict(kernel_size=3, stride=2, padding=1, output_padding=1)
        stage_specs = [
            (1024, 512, k2),
            (512, 256, k2),
            (256, 128, k2),
            (128, 64, k3),
            (64, num_classes, k3),
        ]
        modules = []
        for position, (in_ch, out_ch, kwargs) in enumerate(stage_specs):
            if position > 0:
                modules.append(nn.GELU())   # activation between stages only
            modules.append(nn.ConvTranspose2d(in_ch, out_ch, **kwargs))
        self.decoder = nn.Sequential(*modules)

    def forward(self, x):
        """Encode, apply reverse attention, then decode."""
        encoded = self.backbone(x)
        attended = self.reverse_attention(encoded)
        return self.decoder(attended)


# ConvNextUNet

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


class ConvNeXtUNet(nn.Module):
    """U-Net style segmentation network on a torchvision ConvNeXt encoder.

    The encoder is ``convnext_<type>(...).features`` with its first
    convolution replaced by a 1-channel, kernel-4, stride-4 convolution.
    The decoder upsamples by 2 five times; after each of the first three
    steps the encoder feature map of matching scale is concatenated.

    Args:
        num_classes: number of output channels of the final 1x1 convolution.
        backbone_type: ``"tiny"``, ``"small"``, ``"base"`` or ``"large"``.

    Raises:
        AssertionError: if ``backbone_type`` is not supported, or (in
            ``forward``) if the encoder produced fewer than four scales.
    """

    def __init__(self, num_classes: int = 1, backbone_type: str = "base"):
        super().__init__()

        # ------------------- Available backbones -------------------
        # fn: constructor, w: pretrained weights, ch: channels per scale.
        cfg = {
            "base":  dict(fn=models.convnext_base,  w=models.ConvNeXt_Base_Weights.DEFAULT,
                          ch=[128, 256, 512, 1024]),
            "large": dict(fn=models.convnext_large, w=models.ConvNeXt_Large_Weights.DEFAULT,
                          ch=[192, 384, 768, 1536]),
            "tiny":  dict(fn=models.convnext_tiny,  w=models.ConvNeXt_Tiny_Weights.DEFAULT,
                          ch=[ 96, 192, 384,  768]),
            "small": dict(fn=models.convnext_small, w=models.ConvNeXt_Small_Weights.DEFAULT,
                          ch=[ 96, 192, 384,  768]),
        }
        assert backbone_type in cfg, f"Backbone «{backbone_type}» no soportado."
        c = cfg[backbone_type]

        # ------------------- ConvNeXt encoder -----------------------
        self.backbone = c["fn"](weights=c["w"]).features
        # 1 input channel, stride 4 (new, randomly initialised layer).
        self.backbone[0][0] = nn.Conv2d(1, c["ch"][0], kernel_size=4, stride=4, padding=0)

        # ------------------- Decoder (x2 per step) ------------------
        # Inputs of up3/up2/up1 are doubled because a skip is concatenated.
        self.up4 = self._up_block(c["ch"][3]    , c["ch"][2])       # 1/32 -> 1/16
        self.up3 = self._up_block(c["ch"][2] * 2, c["ch"][1])       # 1/16 -> 1/8
        self.up2 = self._up_block(c["ch"][1] * 2, c["ch"][0])       # 1/8  -> 1/4
        self.up1 = self._up_block(c["ch"][0] * 2, 64)               # 1/4  -> 1/2
        self.up0 = self._up_block(64, 32)                           # 1/2  -> 1

        self.final_conv = nn.Conv2d(32, num_classes, kernel_size=1)

    # ---------- Upsampling block (ConvTranspose2d + 3x3 conv) --------
    @staticmethod
    def _up_block(in_c: int, out_c: int):
        """Return a x2 upsampling block mapping ``in_c`` to ``out_c`` channels."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2),
            nn.GELU(),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.GELU(),
        )

    # ---------- Make two tensors share the same H x W ---------------
    @staticmethod
    def _match(x, ref):
        """Bilinearly resize ``x`` to the spatial size of ``ref`` if needed."""
        if x.shape[2:] != ref.shape[2:]:
            x = F.interpolate(x, size=ref.shape[2:], mode="bilinear", align_corners=False)
        return x

    # ---------------------------- Forward ---------------------------
    def forward(self, x):
        """Return logits with ``num_classes`` channels."""
        # Keep ONE feature map per spatial scale: the output of the last
        # encoder layer that runs at that scale. Recording only the first
        # tensor of each scale would keep the downsampling outputs and drop
        # the result of every stage that follows them, including the deepest
        # stage, whose computation would then be wasted.
        scales = []
        prev_hw = None
        for layer in self.backbone:
            x = layer(x)
            hw = x.shape[-2:]
            if hw != prev_hw:
                scales.append(x)      # first layer at a new resolution
                prev_hw = hw
            else:
                scales[-1] = x        # refined features at the same resolution

        # Four scales are expected (1/4, 1/8, 1/16, 1/32).
        assert len(scales) >= 4, "No se capturaron suficientes escalas del encoder."
        bottleneck, skip3, skip2, skip1 = scales[-1], scales[-2], scales[-3], scales[-4]

        # Decoder
        x = self.up4(bottleneck)
        x = torch.cat([self._match(x, skip3), skip3], 1)

        x = self.up3(x)
        x = torch.cat([self._match(x, skip2), skip2], 1)

        x = self.up2(x)
        x = torch.cat([self._match(x, skip1), skip1], 1)

        x = self.up1(x)
        x = self.up0(x)

        return self.final_conv(x)



In [None]:
import torch
import torch.nn as nn
from torchvision import models

# ----------------------------
# 1) Reverse-Attention module
# ----------------------------
class SkipReverseAttn(nn.Module):
    """Reverse-attention gate for skip connections.

    Two 3x3 conv + batch-norm layers (GELU in between) followed by a sigmoid
    give a map ``a`` in (0, 1); the block returns ``x * (1 - a) + x``, i.e.
    the input plus a copy re-weighted by the complement of the attention.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels)
        self.act = nn.GELU()
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the gated skip tensor (same shape as ``x``)."""
        hidden = self.conv1(x)
        hidden = self.act(self.bn1(hidden))
        logits = self.bn2(self.conv2(hidden))
        attn = self.sigmoid(logits)
        complement = 1.0 - attn
        return x * complement + x


# -----------------------------------------
# 2) ConvNeXt-UNet + Reverse-Attention
# -----------------------------------------
class ConvNeXtRAUNet(nn.Module):
    """ConvNeXt encoder + U-Net decoder with reverse attention on the skips.

    The encoder is ``convnext_<type>(...).features`` with its first
    convolution replaced by a 1-channel, kernel-4, stride-4 convolution.
    Each skip connection passes through a ``SkipReverseAttn`` block before it
    is concatenated with the decoder features.

    NOTE(review): the decoder has four x2 steps starting from the deepest
    encoder scale. If that scale is 1/32 of the input, as the channel layout
    suggests, the logits come out at half the input resolution - confirm
    whether a final upsampling step is expected by the training code.

    Args:
        num_classes: number of output channels of the final 1x1 convolution.
        backbone_type: ``'tiny'``, ``'small'``, ``'base'`` or ``'large'``.

    Raises:
        AssertionError: if ``backbone_type`` is not supported.
    """

    def __init__(self, num_classes=1, backbone_type='base'):
        super().__init__()

        # fn: constructor, w: pretrained weights, ch: channels per scale.
        cfg = {
            'tiny' : dict(fn=models.convnext_tiny ,  w=models.ConvNeXt_Tiny_Weights.DEFAULT ,
                          ch=[ 96, 192, 384,  768]),
            'small': dict(fn=models.convnext_small,  w=models.ConvNeXt_Small_Weights.DEFAULT,
                          ch=[ 96, 192, 384,  768]),
            'base' : dict(fn=models.convnext_base ,  w=models.ConvNeXt_Base_Weights.DEFAULT ,
                          ch=[128, 256, 512, 1024]),
            'large': dict(fn=models.convnext_large,  w=models.ConvNeXt_Large_Weights.DEFAULT,
                          ch=[192, 384, 768, 1536])
        }
        assert backbone_type in cfg, f"Backbone '{backbone_type}' no soportado."
        c = cfg[backbone_type]

        # Encoder
        self.backbone = c['fn'](weights=c['w']).features
        self.backbone[0][0] = nn.Conv2d(1, c['ch'][0], kernel_size=4, stride=4, padding=0)

        # Reverse attention on the three skip connections
        self.ra3 = SkipReverseAttn(c['ch'][2])
        self.ra2 = SkipReverseAttn(c['ch'][1])
        self.ra1 = SkipReverseAttn(c['ch'][0])

        # Decoder; inputs of up3/up2/up1 are doubled by the concatenated skip
        self.up4 = self._up_block(c['ch'][3]      , c['ch'][2])
        self.up3 = self._up_block(c['ch'][2]*2    , c['ch'][1])
        self.up2 = self._up_block(c['ch'][1]*2    , c['ch'][0])
        self.up1 = self._up_block(c['ch'][0]*2    , 64)
        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)

    def _up_block(self, in_ch, out_ch):
        """Return a x2 upsampling block mapping ``in_ch`` to ``out_ch`` channels."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2),
            nn.GELU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.GELU()
        )

    def forward(self, x):
        """Return logits with ``num_classes`` channels."""
        # `features` alternates [downsample-or-stem, stage] four times, so the
        # stage outputs sit at the odd positions 1, 3, 5, 7. Indexing the raw
        # per-layer list from the end (-2, -3, -4) picks tensors at the wrong
        # scale and width (e.g. the 1/32 downsample output for the 1/16 skip),
        # which cannot be concatenated with the decoder features.
        stage_outputs = []
        for index, layer in enumerate(self.backbone):
            x = layer(x)
            if index % 2 == 1:
                stage_outputs.append(x)
        skip1, skip2, skip3, bottleneck = stage_outputs

        x = self.up4(bottleneck)
        x = torch.cat([x, self.ra3(skip3)], dim=1)

        x = self.up3(x)
        x = torch.cat([x, self.ra2(skip2)], dim=1)

        x = self.up2(x)
        x = torch.cat([x, self.ra1(skip1)], dim=1)

        x = self.up1(x)
        return self.final_conv(x)


# ConvNext UNet + MultiFeatures + Skip Reverse Attention

In [None]:


import torch
import torch.nn as nn
from torchvision import models

# --------------------------------------------------
# 1)  Multi-Feature Stem  (tres ramas: k3/k5/k7, s=4)
# --------------------------------------------------
class MultiFeatureStem(nn.Module):
    """Multi-kernel stem: three parallel stride-4 convolutions (k = 3, 5, 7).

    The requested ``out_ch`` is divided between the three branches (the last
    one takes the remainder of the integer division). Branch outputs are
    concatenated on the channel axis, then batch-norm and GELU are applied.

    Args:
        in_ch: number of input channels (keyword-only).
        out_ch: total number of output channels (keyword-only).
    """

    def __init__(self, *, in_ch: int, out_ch: int):
        super().__init__()
        base_width = out_ch // 3
        branch_widths = (base_width, base_width, out_ch - 2 * base_width)  # sums to out_ch
        branch_kernels = (3, 5, 7)
        branch_pads = (1, 2, 3)                                            # kernel // 2

        self.branches = nn.ModuleList(
            nn.Conv2d(in_ch, width, kernel_size=kernel, stride=4, padding=pad, bias=False)
            for width, kernel, pad in zip(branch_widths, branch_kernels, branch_pads)
        )
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.GELU()

    def forward(self, x):
        """Run all branches on ``x`` and return the fused, activated result."""
        per_branch = [conv(x) for conv in self.branches]
        stacked = torch.cat(per_branch, dim=1)
        normalised = self.bn(stacked)
        return self.act(normalised)

# --------------------------------------------------
# 2)  Reverse-Attention para los skips
# --------------------------------------------------
class SkipReverseAttn(nn.Module):
    """Reverse-attention gate applied to skip-connection features.

    A mask ``m`` in (0, 1) is computed with conv3x3 -> BN -> GELU ->
    conv3x3 -> BN -> sigmoid, and the block returns ``x * (1 - m) + x``.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels: int):
        super().__init__()

        def conv3x3():
            # bias-free because a BatchNorm layer follows each convolution
            return nn.Conv2d(channels, channels, 3, padding=1, bias=False)

        self.conv1 = conv3x3()
        self.bn1 = nn.BatchNorm2d(channels)
        self.act = nn.GELU()
        self.conv2 = conv3x3()
        self.bn2 = nn.BatchNorm2d(channels)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Return the gated tensor, same shape as ``x``."""
        first = self.act(self.bn1(self.conv1(x)))
        mask = self.sig(self.bn2(self.conv2(first)))
        inverted = 1.0 - mask
        return x * inverted + x

# --------------------------------------------------
# 3)  ConvNeXt-UNet + MF + RA   (salida H×W)
# --------------------------------------------------
class ConvNeXtSegmentationMF(nn.Module):
    """ConvNeXt U-Net with a multi-feature stem and reverse-attention skips.

    The first entry of torchvision's ConvNeXt ``features`` is dropped and
    replaced by ``MultiFeatureStem`` (1 input channel); the remaining entries
    form the encoder. Three skips are used: the stem output and the last
    encoder outputs whose channel count equals ``ch[1]`` and ``ch[2]``. Each
    skip goes through ``SkipReverseAttn`` before being concatenated with the
    decoder features. The decoder has five x2 upsampling steps.

    Args:
        num_classes: number of output channels of the final 1x1 convolution.
        backbone_type: ``"tiny"``, ``"small"``, ``"base"`` or ``"large"``.

    Raises:
        AssertionError: if ``backbone_type`` is not supported.
    """

    def __init__(self, num_classes: int = 1, backbone_type: str = "base"):
        super().__init__()

        # fn: constructor, w: pretrained weights, ch: channels per scale.
        cfg = {
            "tiny":  dict(fn=models.convnext_tiny , w=models.ConvNeXt_Tiny_Weights.DEFAULT ,
                          ch=[ 96, 192, 384,  768]),
            "small": dict(fn=models.convnext_small, w=models.ConvNeXt_Small_Weights.DEFAULT,
                          ch=[ 96, 192, 384,  768]),
            "base":  dict(fn=models.convnext_base , w=models.ConvNeXt_Base_Weights.DEFAULT ,
                          ch=[128, 256, 512, 1024]),
            "large": dict(fn=models.convnext_large, w=models.ConvNeXt_Large_Weights.DEFAULT,
                          ch=[192, 384, 768, 1536]),
        }
        assert backbone_type in cfg, "Backbone no soportado."
        chosen = cfg[backbone_type]
        self.ch = chosen["ch"]
        c0, c1, c2, c3 = self.ch

        # ---------- Stem ----------
        self.stem = MultiFeatureStem(in_ch=1, out_ch=c0)

        # ---------- Encoder (patch-embedding entry removed) ----------
        all_features = chosen["fn"](weights=chosen["w"]).features
        self.encoder = nn.Sequential(*all_features[1:])

        # ---------- Reverse attention on the skips ----------
        self.ra3 = SkipReverseAttn(c2)
        self.ra2 = SkipReverseAttn(c1)
        self.ra1 = SkipReverseAttn(c0)

        # ---------- Decoder (H/32 -> H) ----------
        # Inputs of up3/up2/up1 are doubled by the concatenated skip.
        self.up4 = self._up(c3, c2)          # H/32 -> H/16
        self.up3 = self._up(c2 * 2, c1)      # H/16 -> H/8
        self.up2 = self._up(c1 * 2, c0)      # H/8  -> H/4
        self.up1 = self._up(c0 * 2, 64)      # H/4  -> H/2
        self.up0 = self._up(64, 32)          # H/2  -> H
        self.final_conv = nn.Conv2d(32, num_classes, 1)

    @staticmethod
    def _up(in_ch: int, out_ch: int) -> nn.Sequential:
        """Return a x2 upsampling block mapping ``in_ch`` to ``out_ch`` channels."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2),
            nn.GELU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.GELU(),
        )

    # ---------- Forward ----------
    def forward(self, x):
        """Return logits with ``num_classes`` channels."""
        x = self.stem(x)                 # H/4
        skip1 = x                        # stem output, ch[0] channels

        # Walk the encoder and remember the most recent tensor of each of the
        # two intermediate widths; later layers overwrite earlier ones.
        skip2 = skip3 = None
        mid_width, deep_width = self.ch[1], self.ch[2]
        for layer in self.encoder:
            x = layer(x)
            width = x.shape[1]
            if width == mid_width:       # H/8
                skip2 = x
                continue
            if width == deep_width:      # H/16
                skip3 = x

        # x is now the bottleneck (H/32). Three decoder steps fuse a skip each.
        fusion_plan = (
            (self.up4, self.ra3, skip3),   # -> H/16
            (self.up3, self.ra2, skip2),   # -> H/8
            (self.up2, self.ra1, skip1),   # -> H/4
        )
        for upsample, gate, skip in fusion_plan:
            x = upsample(x)
            x = torch.cat([x, gate(skip)], dim=1)

        x = self.up1(x)                  # H/2
        x = self.up0(x)                  # H
        return self.final_conv(x)

# --------------------------------------------------
# 4)  Prueba de forma
# --------------------------------------------------
if __name__ == "__main__":
    # Shape smoke test: build the "base" variant (this constructs the
    # torchvision ConvNeXt with its default pretrained weights) and push a
    # random batch of two 1-channel 512x512 images through the model.
    model = ConvNeXtSegmentationMF(num_classes=1, backbone_type="base")
    x = torch.randn(2, 1, 512, 512)
    y = model(x)
    print("Output:", y.shape)          # expected: torch.Size([2, 1, 512, 512])


# HyperNet model using ConvNeXt and attention modules

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionBlock(nn.Module):
    """Spatial self-attention with a learnable residual weight.

    Query and key are 1x1 projections to ``in_channels // 8`` channels, value
    is a 1x1 projection to ``in_channels`` channels. Attention weights are a
    softmax over all H * W positions. The result is scaled by ``gamma`` and
    added to the input; ``gamma`` is initialised to zero, so a freshly
    constructed block returns its input unchanged.

    Memory grows with (H * W)^2 because the full attention matrix is built.

    Args:
        in_channels: number of input (and output) channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        reduced = in_channels // 8
        self.query_conv = nn.Conv2d(in_channels, reduced, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, reduced, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the shape of ``x``."""
        batch_size, channels, height, width = x.size()
        positions = height * width

        # Flatten the spatial grid; query is transposed to (B, N, C').
        query = self.query_conv(x).view(batch_size, -1, positions).permute(0, 2, 1)
        key = self.key_conv(x).view(batch_size, -1, positions)
        value = self.value_conv(x).view(batch_size, -1, positions)

        # (B, N, N): row i holds the weights position i puts on every position.
        attention = torch.bmm(query, key).softmax(dim=-1)

        # Weighted sum of values for every position, back to (B, C, H, W).
        attended = torch.bmm(value, attention.permute(0, 2, 1))
        attended = attended.view(batch_size, channels, height, width)

        return self.gamma * attended + x


class HyperNetSegmentation(nn.Module):
    """Segmentation model fusing attended features from four encoder levels.

    Encoder: four stride-2 convolution + ReLU stages (1 -> 64 -> 128 -> 256
    -> 512 channels). This is a plain convolutional stack, not a ConvNeXt.
    Each level passes through its own ``AttentionBlock``. The three shallower
    levels are average-pooled to the spatial size of the deepest one so that
    all four can be concatenated along the channel axis (64 + 128 + 256 + 512
    = 960 channels). The decoder then upsamples by 2 four times.

    The output is passed through a sigmoid, so the model returns
    probabilities in [0, 1], not logits; pair it with a loss that expects
    probabilities.

    NOTE(review): ``attn1`` runs at half the input resolution and builds an
    (N, N) attention matrix with N = (H/2) * (W/2); this is very
    memory-hungry for large inputs.

    Args:
        num_classes: number of output channels.
    """

    def __init__(self, num_classes=1):
        super().__init__()

        # Encoder: every stage halves the spatial size.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )

        # One attention module per encoder level.
        self.attn1 = AttentionBlock(64)
        self.attn2 = AttentionBlock(128)
        self.attn3 = AttentionBlock(256)
        self.attn4 = AttentionBlock(512)

        # Decoder. Its first layer must accept the concatenation of all four
        # levels, not just the 512 channels of the deepest one.
        fused_channels = 64 + 128 + 256 + 512
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(fused_channels, 256, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, num_classes, kernel_size=3, stride=2, padding=1, output_padding=1)
        )

    def forward(self, x):
        """Return per-pixel probabilities with ``num_classes`` channels."""
        # Encoder: slice the Sequential into its four conv + ReLU stages.
        x1 = self.encoder[0:2](x)
        x2 = self.encoder[2:4](x1)
        x3 = self.encoder[4:6](x2)
        x4 = self.encoder[6:](x3)

        # Attention on every level.
        x1 = self.attn1(x1)
        x2 = self.attn2(x2)
        x3 = self.attn3(x3)
        x4 = self.attn4(x4)

        # The levels have different spatial sizes and cannot be concatenated
        # directly; bring the shallower ones down to the deepest grid first.
        target_hw = x4.shape[-2:]
        pooled = [F.adaptive_avg_pool2d(level, target_hw) for level in (x1, x2, x3)]
        fused = torch.cat(pooled + [x4], dim=1)

        decoded = self.decoder(fused)

        # Sigmoid turns the decoder output into probabilities in [0, 1].
        return torch.sigmoid(decoded)

# Definir el modelo
#model = HyperNetSegmentation(num_classes=1)  # Para segmentación binaria