# Custom SegFormer Notebook

_Jupyter notebook implementing a custom SegFormer-like model for 5-channel segmentation_

## 1. Импорты

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2
from pathlib import Path
from typing import List


## 2. Определение PatchEmbed и TransformerBlock

In [None]:
class PatchEmbed(nn.Module):
    """Strided convolutional patch embedding followed by a channel-wise LayerNorm.

    Args:
        in_channels: Number of channels of the incoming feature map.
        embed_dim: Number of output channels produced by the projection.
        patch_size: Kernel size of the projection convolution.
        stride: Stride of the projection convolution.
        padding: Convolution padding; when ``None`` it defaults to
            ``patch_size // 2``.
    """

    def __init__(self, in_channels: int, embed_dim: int, patch_size: int, stride: int, padding: int=None):
        super().__init__()
        pad = patch_size // 2 if padding is None else padding
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=pad,
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` with the convolution and normalise each position over channels.

        The input is unpacked as a 4-D ``(B, C, H, W)`` tensor after the
        projection; the result has the same 4-D layout.
        """
        projected = self.proj(x)
        # LayerNorm acts on the last dimension, so move channels last,
        # normalise, then move them back to dimension 1.
        channels_last = projected.permute(0, 2, 3, 1)
        normalised = self.norm(channels_last)
        return normalised.permute(0, 3, 1, 2)

class _DropPath(nn.Module):
    """Stochastic depth: randomly drop a whole residual branch per sample.

    Active only in training mode. Surviving samples are rescaled by
    ``1 / keep_prob`` so the expected value of the branch is unchanged.
    """

    def __init__(self, drop_prob: float):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.drop_prob <= 0.0:
            return x
        keep_prob = 1.0 - self.drop_prob
        if keep_prob <= 0.0:
            # Every sample is dropped; avoid dividing by zero below.
            return torch.zeros_like(x)
        # One Bernoulli draw per sample, broadcast over all remaining dims.
        mask_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        return x * mask / keep_prob


class TransformerBlock(nn.Module):
    """Pre-norm transformer block operating on 4-D ``(B, C, H, W)`` feature maps.

    The spatial dimensions are flattened into a token sequence, passed through
    self-attention and an MLP (each with a residual connection), and reshaped
    back to the input layout.

    Args:
        dim: Channel / embedding dimension of the input.
        num_heads: Number of attention heads passed to ``nn.MultiheadAttention``.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        drop_path: Probability of dropping each residual branch per sample
            during training (stochastic depth). ``0.0`` disables it.

    Raises:
        ValueError: If ``drop_path`` is outside ``[0, 1]``.
    """

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0, drop_path: float = 0.0):
        super().__init__()
        if not 0.0 <= drop_path <= 1.0:
            raise ValueError(f"drop_path must be in [0, 1], got {drop_path}")
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim)
        )
        # Previously this was always nn.Identity(), silently ignoring the
        # ``drop_path`` argument. Identity is kept for the default so the
        # module tree is unchanged when stochastic depth is disabled.
        self.drop_path = _DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention and MLP residual branches; output shape equals input shape."""
        B, C, H, W = x.shape
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        x_flat = x.flatten(2).transpose(1, 2)

        x_norm = self.norm1(x_flat)
        # The attention weights are never used, so skip computing/returning them.
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)
        x_flat = x_flat + self.drop_path(attn_out)

        mlp_out = self.mlp(self.norm2(x_flat))
        x_flat = x_flat + self.drop_path(mlp_out)

        # (B, H*W, C) -> (B, C, H, W)
        return x_flat.transpose(1, 2).reshape(B, C, H, W)


## 3. Определение CustomSegFormer

In [None]:
class CustomSegFormer(nn.Module):
    """SegFormer-like segmentation network with a four-stage encoder.

    Each encoder stage is a ``PatchEmbed`` followed by ``depths[i]``
    ``TransformerBlock`` layers. The decoder projects every stage output to
    ``decoder_dim`` channels with a 1x1 convolution, resizes them to the
    first stage's resolution, concatenates them, and applies a 1x1
    classification head. The logits are finally resized to the spatial size
    of the input.

    Args:
        in_channels: Number of input channels.
        num_classes: Number of output channels of the classification head.
        embed_dims: Channel width of each of the four stages.
        num_heads: Attention heads for each of the four stages.
        depths: Number of transformer blocks in each of the four stages.
        mlp_ratio: MLP expansion ratio forwarded to every ``TransformerBlock``.
        decoder_dim: Channel width of each projected decoder feature.

    Raises:
        ValueError: If ``embed_dims``, ``num_heads`` or ``depths`` does not
            contain exactly four entries.
    """

    def __init__(
        self,
        in_channels: int = 5,
        num_classes: int = 9,
        embed_dims: List[int] = [64, 128, 320, 512],
        num_heads: List[int] = [1, 2, 5, 8],
        depths: List[int] = [3, 4, 6, 3],
        mlp_ratio: float = 4.0,
        decoder_dim: int = 256,
    ):
        super().__init__()
        # NOTE: the list defaults above are only read, never mutated.
        # Explicit exception instead of ``assert`` so the check survives ``python -O``.
        if not (len(embed_dims) == len(num_heads) == len(depths) == 4):
            raise ValueError(
                "embed_dims, num_heads and depths must each contain exactly 4 entries"
            )

        patch_sizes = [7, 3, 3, 3]
        strides = [4, 2, 2, 2]

        self.stages = nn.ModuleList()
        in_ch = in_channels
        for dim, heads, depth, patch, stride in zip(
            embed_dims, num_heads, depths, patch_sizes, strides
        ):
            layers = [PatchEmbed(in_ch, dim, patch, stride)]
            layers.extend(TransformerBlock(dim, heads, mlp_ratio) for _ in range(depth))
            self.stages.append(nn.Sequential(*layers))
            in_ch = dim

        self.proj_convs = nn.ModuleList([
            nn.Conv2d(dim, decoder_dim, kernel_size=1) for dim in embed_dims
        ])
        self.head = nn.Conv2d(decoder_dim * 4, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-pixel class logits with the same spatial size as ``x``."""
        input_size = tuple(x.shape[2:])

        feats = []
        for stage in self.stages:
            x = stage(x)
            feats.append(x)

        # All decoder features are fused at the resolution of the first stage.
        H0, W0 = feats[0].shape[2:]
        proj_feats = []
        for proj, feat in zip(self.proj_convs, feats):
            p = proj(feat)
            if p.shape[2:] != (H0, W0):
                p = F.interpolate(p, size=(H0, W0), mode='bilinear', align_corners=False)
            proj_feats.append(p)

        x_dec = self.head(torch.cat(proj_feats, dim=1))

        # The fused map sits at the first stage's stride (4), not at the total
        # encoder stride (32). Upsampling by 32 produced an output 8x larger
        # than the input, so resize to the input's spatial size instead; this
        # also handles inputs whose size is not a multiple of the strides.
        x_dec = F.interpolate(x_dec, size=input_size, mode='bilinear', align_corners=False)
        return x_dec


## 4. Тестирование модели

In [None]:
# Smoke test: build the model and push one random 5-channel 512x512 sample
# through it to check the output shape.
model = CustomSegFormer(in_channels=5, num_classes=9)
model.eval()
x = torch.randn(1, 5, 512, 512)
# Gradients are not needed for a shape check; disabling autograd avoids
# retaining every intermediate activation (including the attention maps).
with torch.no_grad():
    y = model(x)
print("Output shape:", y.shape)
