# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv3d, ConvTranspose3d, MaxPool3d, BatchNorm3d, ReLU, Softmax

"""Blocs de base : **ResidualBlock3D**"""

class ResidualBlock3D(nn.Module):
    """Residual unit: (conv3x3x3 -> BN -> ReLU -> conv3x3x3 -> BN) + shortcut, then ReLU.

    Spatial size is preserved (kernel 3, padding 1). The shortcut is the
    identity when the channel count is unchanged, otherwise a 1x1x1
    convolution that projects ``in_channels`` to ``out_channels``.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Submodules are created in the same order as before so that seeded
        # parameter initialisation and state_dict key order are unchanged.
        self.conv1 = Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = BatchNorm3d(out_channels)
        self.relu = ReLU()
        self.conv2 = Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = BatchNorm3d(out_channels)

        if in_channels == out_channels:
            self.skip = nn.Identity()
        else:
            # 1x1x1 projection so the shortcut matches the residual's channels.
            self.skip = Conv3d(in_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        shortcut = self.skip(x)
        hidden = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden))
        return self.relu(residual + shortcut)

"""Bloc d'encodage : **EncoderBlock**"""

class EncoderBlock(nn.Module):
    """One encoder stage: a ResidualBlock3D followed by 2x max-pooling.

    ``forward`` returns ``(pooled, pre_pool)``: the pooled tensor feeds the
    next stage, and the pre-pooling features are kept as the skip connection.

    Args:
        in_channels: channels of the stage input.
        features: channels produced by the residual block.
    """

    def __init__(self, in_channels, features):
        super().__init__()
        self.block = ResidualBlock3D(in_channels, features)
        self.pool = MaxPool3d(kernel_size=2, stride=2)

    def forward(self, x):
        pre_pool = self.block(x)
        return self.pool(pre_pool), pre_pool

"""**Encodeur simple pour une modalité : SingleEncoder**"""

class SingleEncoder(nn.Module):
    """Five-level residual encoder for one input modality.

    Four pooled stages (16, 32, 64, 128 channels) are followed by a
    256-channel ResidualBlock3D bottleneck that is not pooled.

    Returns:
        ``(bottleneck, skips)`` where ``skips`` is ordered deepest first:
        ``[skip4, skip3, skip2, skip1]`` (128, 64, 32, 16 channels).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.enc1 = EncoderBlock(in_channels, 16)
        self.enc2 = EncoderBlock(16, 32)
        self.enc3 = EncoderBlock(32, 64)
        self.enc4 = EncoderBlock(64, 128)
        self.enc5 = ResidualBlock3D(128, 256)

    def forward(self, x):
        collected = []
        for stage in (self.enc1, self.enc2, self.enc3, self.enc4):
            x, skip = stage(x)
            collected.append(skip)
        bottleneck = self.enc5(x)
        # Reverse so the deepest skip comes first, matching decoder order.
        return bottleneck, collected[::-1]

"""**Encodeur multi-modalité : MultiModalEncoder**"""

class MultiModalEncoder(nn.Module):
    """Four independent SingleEncoder branches (no weight sharing), one per modality.

    Returns:
        ``(bottlenecks, skip_lists)``: a tuple of the four bottleneck tensors
        in (FLAIR, T1, T1ce, T2) order, and a list of the four per-modality
        skip lists in the same order.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.encoder_flair = SingleEncoder(in_channels)
        self.encoder_t1 = SingleEncoder(in_channels)
        self.encoder_t1ce = SingleEncoder(in_channels)
        self.encoder_t2 = SingleEncoder(in_channels)

    def forward(self, flair, t1, t1ce, t2):
        results = [
            self.encoder_flair(flair),
            self.encoder_t1(t1),
            self.encoder_t1ce(t1ce),
            self.encoder_t2(t2),
        ]
        bottlenecks = tuple(out for out, _ in results)
        skip_lists = [skips for _, skips in results]
        return bottlenecks, skip_lists

"""**Transformées en ondelettes (décomposition pair/impair 3D, sans filtrage) : simple_dwt3d et simple_idwt3d**"""

def simple_dwt3d(x):
    """Split a 5-D tensor into its 8 even/odd (polyphase) sub-volumes.

    This is the unfiltered "lazy wavelet" decomposition: no low/high-pass
    filtering is applied. Each sub-band is simply the voxels at one parity
    combination along (depth, height, width). ``simple_idwt3d`` is its exact
    inverse.

    Args:
        x: tensor of shape (batch, channels, depth, height, width). Every
            spatial dimension must be even.

    Returns:
        List of 8 tensors of shape (batch, channels, depth/2, height/2, width/2),
        in the order LLL, HLL, LHL, HHL, LLH, HLH, LHH, HHH. Letter positions
        are (depth, height, width); L means even indices, H means odd indices.

    Raises:
        ValueError: if ``x`` is not 5-D, or if a spatial dimension is odd.
            With odd sizes, the even and odd slices differ in length and
            reconstruction would fail or broadcast silently.
    """
    if x.dim() != 5:
        raise ValueError(
            f"simple_dwt3d expects a 5-D tensor (B, C, D, H, W), got shape {tuple(x.shape)}"
        )
    if any(size % 2 for size in x.shape[2:]):
        raise ValueError(
            f"simple_dwt3d requires even spatial dimensions, got {tuple(x.shape[2:])}"
        )

    return [
        x[:, :, 0::2, 0::2, 0::2],  # LLL
        x[:, :, 1::2, 0::2, 0::2],  # HLL
        x[:, :, 0::2, 1::2, 0::2],  # LHL
        x[:, :, 1::2, 1::2, 0::2],  # HHL
        x[:, :, 0::2, 0::2, 1::2],  # LLH
        x[:, :, 1::2, 0::2, 1::2],  # HLH
        x[:, :, 0::2, 1::2, 1::2],  # LHH
        x[:, :, 1::2, 1::2, 1::2],  # HHH
    ]

def simple_idwt3d(subbands):
    """Interleave 8 even/odd sub-volumes back into one tensor (inverse of ``simple_dwt3d``).

    Args:
        subbands: sequence of 8 tensors in the order LLL, HLL, LHL, HHL, LLH,
            HLH, LHH, HHH. Letter positions are (depth, height, width);
            L means even indices, H means odd indices. All 8 must share the
            same shape (batch, channels, d, h, w).

    Returns:
        Tensor of shape (batch, channels, 2*d, 2*h, 2*w). It has the same
        dtype and device as the first sub-band.

    Raises:
        ValueError: if there are not exactly 8 sub-bands or their shapes differ.
            Mismatched shapes would otherwise be broadcast silently into the
            output.
    """
    subbands = list(subbands)
    if len(subbands) != 8:
        raise ValueError(f"simple_idwt3d expects 8 sub-bands, got {len(subbands)}")
    lll = subbands[0]
    if any(band.shape != lll.shape for band in subbands):
        raise ValueError(
            "simple_idwt3d requires all sub-bands to share one shape, got "
            f"{[tuple(band.shape) for band in subbands]}"
        )

    b, c, d, h, w = lll.shape
    # new_zeros inherits dtype *and* device; torch.zeros(device=...) would
    # silently default to float32 (breaking float64 / float16 pipelines).
    out = lll.new_zeros((b, c, d * 2, h * 2, w * 2))

    # (depth, height, width) parity offsets, in the same order as the sub-bands.
    offsets = (
        (0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0),
        (0, 0, 1), (1, 0, 1), (0, 1, 1), (1, 1, 1),
    )
    for (od, oh, ow), band in zip(offsets, subbands):
        out[:, :, od::2, oh::2, ow::2] = band

    return out

def fuse_subbands(subband_lists):
    """Merge several per-modality sub-band lists into one list of 8 sub-bands.

    Index 0 (the LLL band) is averaged across modalities. Indices 1-7 are
    summed element-wise across modalities.

    Args:
        subband_lists: one list of (at least) 8 sub-band tensors per modality.
            Corresponding bands are expected to share a shape (required by
            ``torch.stack`` for band 0).

    Returns:
        List of 8 fused tensors.
    """
    low = torch.stack([bands[0] for bands in subband_lists]).mean(dim=0)
    highs = [sum(bands[idx] for bands in subband_lists) for idx in range(1, 8)]
    return [low, *highs]

"""**Module de fusion par ondelettes : WaveletFusionModule**"""

class WaveletFusionModule(nn.Module):
    """Parameter-free fusion of four modality feature maps.

    Each input is split with ``simple_dwt3d``. Sub-bands are merged across
    modalities with ``fuse_subbands`` and reassembled with ``simple_idwt3d``.
    The resulting shared map is added to every modality, and the four
    enriched maps are concatenated along dim 1. The output therefore has
    4x the per-modality channel count.
    """

    def forward(self, flair, t1, t1ce, t2):
        modalities = (flair, t1, t1ce, t2)
        decomposed = [simple_dwt3d(m) for m in modalities]
        fused_feature = simple_idwt3d(fuse_subbands(decomposed))
        # Inject the shared fused map into every modality, then concatenate.
        return torch.cat([m + fused_feature for m in modalities], dim=1)

"""**Module d’attention contextuelle : GlobalContextAwareModule**"""

class GlobalContextAwareModule(nn.Module):
    """Attention-weighted residual refinement; output shape equals input shape.

    A 1x1x1 convolution produces per-voxel logits, normalised by a softmax
    over dim 1 (channels). The input is multiplied by these weights, passed
    through a 1x1x1 conv -> BN -> ReLU -> 1x1x1 conv transform, and added
    back to the input.

    NOTE(review): despite the "global context" name, the softmax runs over
    channels at each voxel rather than over spatial positions. Confirm this
    is the intended design.

    Args:
        in_channels: channel count of the input (and output).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = Conv3d(in_channels, in_channels, kernel_size=1)
        self.softmax = Softmax(dim=1)
        self.transform = nn.Sequential(
            Conv3d(in_channels, in_channels, kernel_size=1),
            BatchNorm3d(in_channels),
            ReLU(),
            Conv3d(in_channels, in_channels, kernel_size=1),
        )

    def forward(self, x):
        weights = self.softmax(self.conv1(x))
        refined = self.transform(x * weights)
        return x + refined

"""**Bloc de décodage : DecoderBlock3D**"""

class DecoderBlock3D(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concatenation, two 3x3x3 convs.

    The first conv is followed by BN + ReLU. The second conv's output is
    returned as-is, with no normalisation or activation.

    Args:
        in_channels: channels of the incoming (coarser) tensor.
        skip_channels: channels of the skip tensor concatenated after upsampling.
        out_channels: channels produced by the stage.
    """

    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        self.upconv = ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv1 = Conv3d(out_channels + skip_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = BatchNorm3d(out_channels)
        self.relu = ReLU()
        self.conv2 = Conv3d(out_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x, skip):
        upsampled = self.upconv(x)
        merged = torch.cat((upsampled, skip), dim=1)
        hidden = self.relu(self.bn1(self.conv1(merged)))
        return self.conv2(hidden)

"""**Modèle complet : BrainTumorSegmentationModel**"""

class BrainTumorSegmentationModel(nn.Module):
    """Multi-modal 3-D encoder/decoder for brain-tumour segmentation.

    Structure:
        - Each modality (FLAIR, T1, T1ce, T2) has its own encoder.
        - The bottleneck features and every skip-connection level are fused
          across modalities by a parameter-free ``WaveletFusionModule``.
          Fusion concatenates the four enriched maps, so channel counts
          are multiplied by 4.
        - The fused bottleneck is refined by a ``GlobalContextAwareModule``.
        - Four ``DecoderBlock3D`` stages decode it, and a 1x1x1 conv plus
          sigmoid produces the output maps.

    Args:
        in_channels: channels of each modality input.
        num_classes: number of output channels (independent sigmoid maps).
        verbose: keyword-only. When True, ``forward`` prints intermediate
            tensor shapes (debugging aid). Defaults to False, so training
            and inference loops are not flooded with output on every call.
    """

    def __init__(self, in_channels, num_classes, *, verbose=False):
        super().__init__()
        self.verbose = verbose
        self.encoder = MultiModalEncoder(in_channels)

        # Wavelet fusion modules: one for the bottleneck, one per skip level.
        self.wfm_final = WaveletFusionModule()
        self.wfm4 = WaveletFusionModule()
        self.wfm3 = WaveletFusionModule()
        self.wfm2 = WaveletFusionModule()
        self.wfm1 = WaveletFusionModule()
        # 4 modalities x 256 bottleneck channels = 1024.
        self.gcam = GlobalContextAwareModule(1024)

        # Decoder blocks. Skip channels are 4x the per-modality encoder width
        # because the fusion module concatenates the four modalities.
        self.dec4 = DecoderBlock3D(1024, 4 * 128, 128)  # 4x128 = 512
        self.dec3 = DecoderBlock3D(128, 4 * 64, 64)     # 4x64 = 256
        self.dec2 = DecoderBlock3D(64, 4 * 32, 32)      # 4x32 = 128
        self.dec1 = DecoderBlock3D(32, 4 * 16, 16)      # 4x16 = 64

        self.final = Conv3d(16, num_classes, kernel_size=1)

    def forward(self, flair, t1, t1ce, t2):
        """Segment one batch from its four MRI modalities.

        Args:
            flair, t1, t1ce, t2: 5-D tensors (batch, in_channels, D, H, W).
                D, H and W must be multiples of 32. The encoder halves them
                four times, and the wavelet fusion at every level, down to
                the /16 bottleneck, needs even sizes.

        Returns:
            Tensor (batch, num_classes, D, H, W) of per-voxel sigmoid
            probabilities.
        """
        # getattr keeps instances pickled before `verbose` existed working.
        verbose = getattr(self, "verbose", False)

        # Per-modality encoders: bottleneck tensors + per-modality skip lists.
        (f, t1_, t1ce_, t2_), skips = self.encoder(flair, t1, t1ce, t2)
        if verbose:
            print(f"Sortie encoder : {f.shape}, {t1_.shape}, {t1ce_.shape}, {t2_.shape}")

        # Fuse the four encoder bottlenecks (4 x 256 -> 1024 channels).
        fused_concat = self.wfm_final(f, t1_, t1ce_, t2_)
        if verbose:
            print(f"Sortie après fusion : {fused_concat.shape}")

        gcam_out = self.gcam(fused_concat)
        if verbose:
            print(f"Sortie après GCAM : {gcam_out.shape}")
            # Shapes of every skip connection (modality i, level j).
            for i, skip_list in enumerate(skips):
                for j, skip in enumerate(skip_list):
                    print(f"Shape skip {i+1}-{j+1}: {skip.shape}")

        # Fuse the skip connections level by level. skips[m][k] is modality m,
        # level k, where k=0 is the deepest skip (128 channels).
        s4 = self.wfm4(*[x[0] for x in skips])
        s3 = self.wfm3(*[x[1] for x in skips])
        s2 = self.wfm2(*[x[2] for x in skips])
        s1 = self.wfm1(*[x[3] for x in skips])

        # Decoder with fused skip connections.
        x = self.dec4(gcam_out, s4)
        if verbose:
            print(f"Sortie après dec4 : {x.shape}")
        x = self.dec3(x, s3)
        if verbose:
            print(f"Sortie après dec3 : {x.shape}")
        x = self.dec2(x, s2)
        if verbose:
            print(f"Sortie après dec2 : {x.shape}")
        x = self.dec1(x, s1)
        if verbose:
            print(f"Sortie après dec1 : {x.shape}")

        # Independent per-class probabilities (sigmoid, not softmax).
        return torch.sigmoid(self.final(x))

# Instance for single-channel volumes per modality, producing 3 output maps.
model = BrainTumorSegmentationModel(num_classes=3, in_channels=1)