In [None]:
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class ConvBlock(nn.Module):
    """Two stacked 3x3 convolution stages, each followed by batch norm and ReLU.

    Both convolutions use ``padding=1`` with the default stride, so the spatial
    size of the input is preserved; only the channel count changes
    (``in_channels`` -> ``out_channels`` in the first stage, ``out_channels``
    kept in the second).
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        # Stage order matters: parameters are created first-stage first, and
        # the Sequential indices 0-5 are what the state_dict keys are built
        # from.
        stages = [
            *self._conv_bn_relu(in_channels, out_channels),
            *self._conv_bn_relu(out_channels, out_channels),
        ]
        self.conv = nn.Sequential(*stages)

    @staticmethod
    def _conv_bn_relu(n_in: int, n_out: int) -> list:
        """Return the ``[Conv2d, BatchNorm2d, ReLU]`` modules of one stage."""
        return [
            nn.Conv2d(n_in, n_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(n_out),
            nn.ReLU(inplace=True),
        ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both conv/norm/activation stages to ``x``."""
        features = self.conv(x)
        return features

class TransformerBlock(nn.Module):
    """Run a ``TransformerEncoder`` over the spatial positions of a feature map.

    A ``(batch, channels, height, width)`` input is treated as a sequence of
    ``height * width`` tokens whose feature size is ``channels``. The encoder
    is built with ``d_model=embed_dim``, and ``height`` and ``width`` must both
    equal ``img_size``.
    """

    def __init__(self, embed_dim: int, num_heads: int, num_layers: int, img_size: int) -> None:
        super().__init__()
        self.img_size = img_size
        # Number of tokens in an img_size x img_size map (stored only; this
        # class never reads it back).
        self.flatten_dim = img_size * img_size

        layer = TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
        self.transformer = TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return a tensor with the same shape as ``x``."""
        n, c, h, w = x.shape
        size_ok = (h == self.img_size) and (w == self.img_size)
        assert size_ok, "Input image size must match transformer input size."

        # (n, c, h, w) -> (n, c, h*w) -> (h*w, n, c)
        tokens = x.flatten(2)
        tokens = tokens.permute(2, 0, 1)

        encoded = self.transformer(tokens)

        # (h*w, n, c) -> (n, c, h*w) -> (n, c, h, w)
        restored = encoded.permute(1, 2, 0)
        return restored.reshape(n, c, h, w)

class UNetWithTransformer(nn.Module):
    """U-Net with a transformer encoder applied at the bottleneck.

    The encoder has four ``ConvBlock`` stages separated by 2x2 max-pooling,
    plus a final pooling step, so the bottleneck sits at ``img_size / 16``
    with ``base_channels * 8`` channels. The bottleneck is projected to
    ``embed_dim`` channels with a 1x1 convolution, passed through
    ``TransformerBlock``, and projected back. The decoder then upsamples four
    times, concatenating the matching encoder output (``enc4`` .. ``enc1``)
    at every level.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the output map.
        base_channels: Channels produced by the first encoder stage; deeper
            stages use 2x, 4x and 8x this value.
        embed_dim: Feature size (``d_model``) used inside the transformer.
        num_heads: Attention heads per transformer layer.
        num_layers: Number of transformer encoder layers.
        img_size: Height and width of the (square) input; must be a positive
            multiple of 16 so that every pooling step halves the map exactly
            and the decoder can mirror it.

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of 16.
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=64, embed_dim=256, num_heads=8, num_layers=4, img_size=128):
        super().__init__()

        # Four 2x poolings happen before the transformer; any other size makes
        # the skip connections disagree with the upsampled maps.
        if img_size <= 0 or img_size % 16 != 0:
            raise ValueError(
                f"img_size must be a positive multiple of 16, got {img_size}"
            )

        bottleneck_channels = base_channels * 8

        # Encoder
        self.enc1 = ConvBlock(in_channels, base_channels)
        self.enc2 = ConvBlock(base_channels, base_channels * 2)
        self.enc3 = ConvBlock(base_channels * 2, base_channels * 4)
        self.enc4 = ConvBlock(base_channels * 4, bottleneck_channels)

        # Pooling
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Transformer block. The bottleneck carries base_channels * 8 channels
        # while the transformer works at embed_dim, so 1x1 convolutions map
        # between the two widths.
        self.to_embed = nn.Conv2d(bottleneck_channels, embed_dim, kernel_size=1)
        self.transformer = TransformerBlock(embed_dim, num_heads, num_layers, img_size // 16)
        self.from_embed = nn.Conv2d(embed_dim, bottleneck_channels, kernel_size=1)

        # Decoder: one upsampling stage per pooling step, so each stage lands
        # on the resolution of the encoder output it is concatenated with.
        self.up4 = nn.ConvTranspose2d(bottleneck_channels, bottleneck_channels, kernel_size=2, stride=2)
        self.dec4 = ConvBlock(bottleneck_channels * 2, bottleneck_channels)

        self.up3 = nn.ConvTranspose2d(base_channels * 8, base_channels * 4, kernel_size=2, stride=2)
        self.dec3 = ConvBlock(base_channels * 8, base_channels * 4)

        self.up2 = nn.ConvTranspose2d(base_channels * 4, base_channels * 2, kernel_size=2, stride=2)
        self.dec2 = ConvBlock(base_channels * 4, base_channels * 2)

        self.up1 = nn.ConvTranspose2d(base_channels * 2, base_channels, kernel_size=2, stride=2)
        self.dec1 = ConvBlock(base_channels * 2, base_channels)

        self.final = nn.Conv2d(base_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Map ``(batch, in_channels, img_size, img_size)`` to
        ``(batch, out_channels, img_size, img_size)``."""
        # Encoder (resolutions: 1, 1/2, 1/4, 1/8 of the input)
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))

        # Transformer at 1/16 resolution, wrapped in the channel projections
        bottleneck = self.pool(enc4)
        bottleneck = self.to_embed(bottleneck)
        bottleneck = self.transformer(bottleneck)
        bottleneck = self.from_embed(bottleneck)

        # Decoder: upsample, concatenate the same-resolution skip, refine
        dec4 = self.dec4(torch.cat([self.up4(bottleneck), enc4], dim=1))
        dec3 = self.dec3(torch.cat([self.up3(dec4), enc3], dim=1))
        dec2 = self.dec2(torch.cat([self.up2(dec3), enc2], dim=1))
        dec1 = self.dec1(torch.cat([self.up1(dec2), enc1], dim=1))

        return self.final(dec1)

# Example: instantiate the model with every hyper-parameter spelled out.
model = UNetWithTransformer(
    in_channels=1, out_channels=1,
    base_channels=64,
    embed_dim=256, num_heads=8, num_layers=4,
    img_size=128,
)

# Print the model's module structure.
print(model)


: 