In [1]:
import torch
from torch import nn

import glob
import os
from tqdm import tqdm
from datetime import datetime
import json
import yaml

from argparse import ArgumentParser

from itertools import combinations, product
import warnings

from functools import partial

import torchvision
from torchvision.transforms import v2
from torchvision import tv_tensors
from torchvision import models

from torchvision.models.vision_transformer import EncoderBlock as VitEncoderBlock, MLPBlock

import types

import torch.nn.functional as F

import segmentation_models_pytorch as smp
from segmentation_models_pytorch.base import (
    ClassificationHead,
    SegmentationHead,
    SegmentationModel,
)

from segmentation_models_pytorch.base import modules as md

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
from typing import Any, Dict, Optional, Union, Callable, Sequence, List, Literal

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger, Logger
from lightning.pytorch.utilities.rank_zero import rank_zero_only

from torchmetrics import classification
from torchmetrics import segmentation

from sklearn.model_selection import train_test_split
from sklearn import metrics

from copy import deepcopy

import einops as eo

import numpy as np

import pandas as pd

from matplotlib import pyplot as plt
%matplotlib inline

  from .autonotebook import tqdm as notebook_tqdm


In [32]:

class ComputeWeights(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size, padding):
        mod_lst = []
        mod_lst.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding))
        mod_lst.append(nn.Sigmoid())
        super().__init__(*mod_lst)
        

class ChannelAtt(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding):
        super().__init__()
        self.in2_proj = ComputeWeights(in_channels, out_channels, kernel_size, padding)

    def forward(self, in1, in2):
        weights = self.in2_proj(in2)
        bs, ch1, rows1, cols1 = in1.shape
        bs, ch2, rows2, cols2 = weights.shape
        winr = rows1//rows2
        winc = cols1//cols2

        weights = weights.view(bs, ch2, rows2, cols2, 1, 1)

        in1 = eo.rearrange(in1, 'b ch (rn wr) (cn wc) -> b ch rn cn wr wc', rn=rows2, cn=cols2, wr=winr, wc=winc)

        in1 = in1*weights

        in1 = eo.rearrange(in1, 'b ch rn cn wr wc -> b ch (rn wr) (cn wc)', rn=rows2, cn=cols2, wr=winr, wc=winc)
        return in1
in1 = torch.randn(1, 200, 96, 96)
in2 = torch.randn(1, 200, 24, 24)

chat = ChannelAtt(in_channels=200, out_channels=200, kernel_size=3, padding=1)
ret = chat(in1, in2)
ret.shape

torch.Size([1, 200, 96, 96])

In [None]:
in1 = torch.arange(0, 72).reshape(1, 2, 6, 6,)
in2 = torch.arange(0, 18).reshape(1, 2, 3, 3)



In [None]:
in1 = torch.arange(0, 72).reshape(1, 2, 6, 6,)
in2 = torch.arange(0, 18).reshape(1, 2, 3, 3)

bs, ch1, rows1, cols1 = in1.shape
bs, ch2, rows2, cols2 = in2.shape

winr = rows1//rows2
winc = cols1//cols2

in2 = in2.view(bs, ch2, rows2, cols2, 1, 1)

in1 = eo.rearrange(in1, 'b ch (rn wr) (cn wc) -> b ch rn cn wr wc', rn=rows2, cn=cols2, wr=winr, wc=winc)
ret = in1 * in2
in1[0, 0], in2[0, 0], ret[0,0]

(tensor([[[[ 0,  1],
           [ 6,  7]],
 
          [[ 2,  3],
           [ 8,  9]],
 
          [[ 4,  5],
           [10, 11]]],
 
 
         [[[12, 13],
           [18, 19]],
 
          [[14, 15],
           [20, 21]],
 
          [[16, 17],
           [22, 23]]],
 
 
         [[[24, 25],
           [30, 31]],
 
          [[26, 27],
           [32, 33]],
 
          [[28, 29],
           [34, 35]]]]),
 tensor([[[[0]],
 
          [[1]],
 
          [[2]]],
 
 
         [[[3]],
 
          [[4]],
 
          [[5]]],
 
 
         [[[6]],
 
          [[7]],
 
          [[8]]]]),
 tensor([[[[  0,   0],
           [  0,   0]],
 
          [[  2,   3],
           [  8,   9]],
 
          [[  8,  10],
           [ 20,  22]]],
 
 
         [[[ 36,  39],
           [ 54,  57]],
 
          [[ 56,  60],
           [ 80,  84]],
 
          [[ 80,  85],
           [110, 115]]],
 
 
         [[[144, 150],
           [180, 186]],
 
          [[182, 189],
           [224, 231]],
 
          [[

In [28]:
in1


tensor([[[[[[ 0,  2,  4],
            [12, 14, 16],
            [24, 26, 28]],

           [[ 1,  3,  5],
            [13, 15, 17],
            [25, 27, 29]]],


          [[[ 6,  8, 10],
            [18, 20, 22],
            [30, 32, 34]],

           [[ 7,  9, 11],
            [19, 21, 23],
            [31, 33, 35]]]],



         [[[[36, 38, 40],
            [48, 50, 52],
            [60, 62, 64]],

           [[37, 39, 41],
            [49, 51, 53],
            [61, 63, 65]]],


          [[[42, 44, 46],
            [54, 56, 58],
            [66, 68, 70]],

           [[43, 45, 47],
            [55, 57, 59],
            [67, 69, 71]]]]]])

In [11]:
any([n in 'pidoras' for n in ['huy', 'pidor']])

True

In [10]:
acc = classification.Accuracy(task='multiclass', num_classes=4)
micro_avg_acc = classification.Accuracy(task='multiclass', average='micro', num_classes=4)
macro_avg_acc = classification.Accuracy(task='multiclass', average='macro', num_classes=4)
target = torch.tensor([0, 2, 2, 3])
preds = torch.tensor([0, 2, 1, 2])
acc.update(preds, target)
micro_avg_acc.update(preds, target)
macro_avg_acc.update(preds, target)

target = torch.tensor([0, 2, 2, 3])
preds = torch.tensor([0, 2, 2, 1])
acc.update(preds, target)
micro_avg_acc.update(preds, target)
macro_avg_acc.update(preds, target)

target = torch.tensor([0, 1, 3, 3])
preds = torch.tensor([0, 2, 2, 2])
acc.update(preds, target)
micro_avg_acc.update(preds, target)
macro_avg_acc.update(preds, target)

micro_avg_acc.compute(), macro_avg_acc.compute(), acc.compute()

(tensor(0.5000), tensor(0.4375), tensor(0.5000))

# Работа с данными

In [2]:
# datasets
class SegmentationDataset(torch.utils.data.Dataset):
    def __init__(self, path_to_dataset_root:str, samples_df:pd.DataFrame, channel_indices:list, transforms:v2._transform.Transform, dtype:torch.dtype, device:torch.device):
        '''
        In:
            path_to_dataset_root - путь до корневой папки с датасетом
            samples_df - pandas.DataFrame с информацией о файлах
            channel_indices - список с номерами каналов мультиспектрального изображения
            transforms - аугментация изображений
        '''
        super().__init__()

        self.path_to_dataset_root = path_to_dataset_root
        self.samples_df = samples_df
        self.specter_bands_list = [i for i in channel_indices if isinstance(i, int)]
        self.specter_indices_names = [s for s in channel_indices if isinstance(s, str)]
        self.dtype_trasform = v2.ToDtype(dtype=dtype, scale=True)
        self.other_transforms = transforms
        self.device = device

    def __len__(self):
        return len(self.samples_df)
    @staticmethod
    def compute_spectral_index(index_name, image):
        if index_name.lower() == 'ndvi':
            b0 = image[7] # NIR, B8
            b1 = image[3] # RED, B4
            
        elif index_name.lower() == 'ndbi':
            b0 = image[10] #SWIR, B11
            b1 = image[7] #NIR, B8

        elif index_name.lower() == 'ndwi':
            b0 = image[2] #green, B3
            b1 = image[7] #NIR, B8

        elif index_name.lower() == 'ndre':
            b0 = image[7] #NIR, B8
            b1 = image[5] #Red Edge, B6
        
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            index = (b0 - b1)/(b0 + b1)
            
        index = np.nan_to_num(index, nan=-5)

        return index    

    def __getitem__(self, idx):
        sample = self.samples_df.iloc[idx]

        file_name = sample['file_name']

        path_to_image = os.path.join(self.path_to_dataset_root, 'images', f'{file_name}.npy')
        path_to_labels = os.path.join(self.path_to_dataset_root, 'labels', f'{file_name}.npy')

        image = np.load(path_to_image)
        spectral_indices = []
        # вычисляем спектральные индексы
        if len(self.specter_indices_names) > 0:
            for sp_index_name in self.specter_indices_names:
                spectral_index = self.compute_spectral_index(sp_index_name, image)
                spectral_index = torch.as_tensor(spectral_index)
                spectral_indices.append(spectral_index.unsqueeze(0))

            spectral_indices = torch.cat(spectral_indices)
            spectral_indices = self.dtype_trasform(spectral_indices)

        image = torch.as_tensor(image[self.specter_bands_list], dtype=torch.int16)
        image = self.dtype_trasform(image)
        # добавляем спектральные индексы
        if len(self.specter_indices_names) > 0:
            image = torch.cat([image, spectral_indices], dim=0) 
        #image = np.load(path_to_image)
        # метки читаем как одноканальное изображение
        label = np.load(path_to_labels)
        label = np.where(label >= 0, label, 0)
        #label = torch.as_tensor(np.load(path_to_labels), dtype=torch.uint8).long()
        label = torch.as_tensor(label, dtype=torch.uint8).long()
        
        image = tv_tensors.Image(image, device=self.device)
        label = tv_tensors.Mask(label, device=self.device)

        transforms_dict = {'image':image, 'mask':label}
        transformed = self.other_transforms(transforms_dict)
        return transformed['image'], transformed['mask']#, image

# Описание модуля Lightning

In [3]:
def compute_pred_mask(pred):
    '''
    Определение маски классов на основе сгенерированной softmax маски
    '''
    #pred = pred.detach()
    _, pred_mask = pred.max(dim=1)
    return pred_mask#.cpu().numpy()

class CSVLoggerMetricsAndConfusion(CSVLogger):
    @rank_zero_only
    def save_confusion(self, epoch_idx, confusion_matrix, class_names, mode):
        os.makedirs(self.log_dir, exist_ok=True)
        path_to_saving_file = os.path.join(self.log_dir, f'{mode}_confusion_matrices.csv')
        if os.path.isfile(path_to_saving_file):
            # читаем матрицы ошибок
            confusion_df = pd.read_csv(path_to_saving_file)
            multiindex = pd.MultiIndex.from_arrays([confusion_df['epoch'], confusion_df['classes']])
            confusion_df = confusion_df.set_index(multiindex)
            confusion_df = confusion_df.drop(columns=['epoch', 'classes'])
        else:
            confusion_df = pd.DataFrame()

        multiindex = pd.MultiIndex.from_product([[epoch_idx], class_names], names=['epoch', 'classes'])
        epoch_confusion_df = pd.DataFrame(data=confusion_matrix, columns=class_names, index=multiindex)
        confusion_df = pd.concat([confusion_df, epoch_confusion_df])
        confusion_df.to_csv(path_to_saving_file)

class LightningSegmentationModule(L.LightningModule):
    def __init__(self, model:nn.Module, criterion:nn.Module, optimizer_cfg:dict, metrics_dict:dict, name2class_idx_dict:dict) -> None:
        '''
        Модуль Lightning для обучения сегментационной сети
        In:
            model - нейронная сеть
            criterion - функция потерь
            
            name2class_idx_dict - словарь с отображением {class_name(str): class_idx(int)}
        '''
        super().__init__()
        self.model = model
        self.criterion = criterion
        self.optimizer_cfg = optimizer_cfg
        self.metrics_dict = metrics_dict
        
        self.name2class_idx_dict = name2class_idx_dict
        # словарь, выполняющий обратное отображение class_idx в class_name
        self.class_idx2name_dict = {v:k for k, v in name2class_idx_dict.items()}
        
    def configure_optimizers(self):
        optimizer = self.optimizer_cfg['optmizer'](self.parameters(), **self.optimizer_cfg['optimizer_args'])
        ret_dict = {'optimizer': optimizer}
        if self.optimizer_cfg['lr_scheduler'] is not None:
            scheduler = self.optimizer_cfg['lr_scheduler'](optimizer, **self.optimizer_cfg['lr_scheduler_args'])
            ret_dict['lr_scheduler'] = {'scheduler': scheduler}
            ret_dict['lr_scheduler'].update(self.optimizer_cfg['lr_scheduler_params'])
        
        return ret_dict

    def compute_metrics(self, pred_labels, true_labels, mode):
        metrics_names_list = self.metrics_dict[mode].keys()
        for metric_name in metrics_names_list:
            if 'dice' in metric_name.lower():
                self.metrics_dict[mode][metric_name].update(pred_labels, true_labels)
            else:
                self.metrics_dict[mode][metric_name].update(pred_labels.reshape(-1), true_labels.reshape(-1))    
    
    def training_step(self, batch, batch_idx):
        data, true_labels = batch
        pred = self.model(data)
        loss = self.criterion(pred, true_labels)
        # вычисление сгенерированной маски
        pred_labels = compute_pred_mask(pred)
        #true_labels = true_labels.detach().cpu().numpy()
        
        self.compute_metrics(pred_labels=pred_labels, true_labels=true_labels, mode='train')

        # т.к. мы вычисляем общую ошибку на всей эпохе, то записываем в лог только значение функции потерь
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        data, true_labels = batch
        pred = self.model(data)
        loss = self.criterion(pred, true_labels)
        pred_labels = compute_pred_mask(pred)
        self.compute_metrics(pred_labels=pred_labels, true_labels=true_labels, mode='val')
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss
    
    def log_metrics(self, mode):
        for metric_name, metric in self.metrics_dict[mode].items():
            metric_val = metric.compute()
            if 'confusion' in metric_name.lower():
                disp_name = f'{mode}_{metric_name}'
                class_names = [self.class_idx2name_dict[i] for i in range(len(self.class_idx2name_dict))]
                if isinstance(self.logger, CSVLoggerMetricsAndConfusion):
                    self.logger.save_confusion(
                        epoch_idx=self.current_epoch,
                        confusion_matrix=metric_val.cpu().tolist(),
                        class_names=class_names,
                        mode=mode)
            else:
                for i, value in enumerate(metric_val):
                    class_name = self.class_idx2name_dict[i]
                    disp_name = f'{mode}_{metric_name}_{class_name}'
                    self.log(disp_name, value, on_step=False, on_epoch=True, prog_bar=True)
                disp_name = f'{mode}_{metric_name}_mean'
                self.log(disp_name, metric_val.mean(), on_step=False, on_epoch=True, prog_bar=True)
            self.metrics_dict[mode][metric_name].reset()

    def on_train_epoch_end(self):
        '''
        Декодирование результатов тренировочной эпохи и запись их в лог
        '''
        self.log_metrics(mode='train')
 
    def on_validation_epoch_end(self):
        '''
        Декодирование результатов тестовой эпохи и запись их в лог
        (работает точно также, как и )
        '''
        self.log_metrics(mode='val')




# Новые функции потерь (Dice-Crossentropy)

In [4]:
class DiceCELoss(nn.Module):
    def __init__(
            self,
            ce_weight,
            ce_ignore_index,
            ce_reducion,
            ce_label_smoothing,
            dice_mode,
            dice_classes,
            dice_log_loss,
            dice_from_logits,
            dice_smooth,
            dice_ignore_index,
            dice_eps,
            losses_weight: List = [0.5, 0.5],
            is_trainable_weights: bool = False,
            weights_processing_type: str = None,
            ):
        super().__init__()
        self.dice = smp.losses.DiceLoss(
            mode=dice_mode,
            classes=dice_classes,
            log_loss=dice_log_loss,
            from_logits=dice_from_logits,
            smooth=dice_smooth,
            ignore_index=dice_ignore_index,
            eps=dice_eps
            )
        self.ce = nn.CrossEntropyLoss(
            weight=ce_weight,
            ignore_index=ce_ignore_index,
            reduction=ce_reducion,
            label_smoothing=ce_label_smoothing,
        )
        self.loss_weights = torch.tensor(losses_weight)
        if is_trainable_weights:
            self.loss_weights = nn.Parameter(self.loss_weights)
        self.weights_processing_type = weights_processing_type

    def forward(self, pred, true):
        weights = self.loss_weights
        if self.weights_processing_type == 'softmax':
            weights = weights.softmax(dim=0)
        elif self.weights_processing_type == 'sigmoid':
            weights = weights.softmax(dim=0)

        ce_loss = self.ce(pred, true) * weights[0]
        dice_loss = self.dice(pred, true) * weights[1]
        return ce_loss + dice_loss

# Адаптация UNet++

In [5]:
class UnetppDecoderBlockMod(nn.Module):
    def __init__(
        self,
        in_channels: int,
        skip_channels: int,
        out_channels: int,
        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
        attention_type: Optional[str] = None,
        interpolation_mode: str = "nearest",
    ):
        super().__init__()
        self.conv1 = md.Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_norm=use_norm,
        )
        self.attention1 = md.Attention(
            attention_type, in_channels=in_channels + skip_channels
        )
        self.conv2 = md.Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_norm=use_norm,
        )
        self.attention2 = md.Attention(attention_type, in_channels=out_channels)
        self.interpolation_mode = interpolation_mode

    def forward(
        self, x: torch.Tensor, skip: Optional[torch.Tensor] = None, scale_factor: float = 2.0
    ) -> torch.Tensor:
        if scale_factor != 1:
            x = F.interpolate(x, scale_factor=scale_factor, mode=self.interpolation_mode)
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
            x = self.attention1(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.attention2(x)
        return x


class UnetppCenterBlock(nn.Sequential):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
    ):
        conv1 = md.Conv2dReLU(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_norm=use_norm,
        )
        conv2 = md.Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_norm=use_norm,
        )
        super().__init__(conv1, conv2)


class UnetPlusPlusDecoderMod(nn.Module):
    def __init__(
        self,
        encoder_channels: Sequence[int],
        decoder_channels: Sequence[int],
        n_blocks: int = 5,
        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
        attention_type: Optional[str] = None,
        interpolation_mode: str = "nearest",
        center: bool = False,
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                f"Model depth is {n_blocks}, but you provide `decoder_channels` for {len(decoder_channels)} blocks."
            )

        # remove first skip with same spatial resolution
        encoder_channels = encoder_channels[1:]
        # reverse channels to start from head of encoder
        encoder_channels = encoder_channels[::-1]

        # computing blocks input and output channels
        head_channels = encoder_channels[0]
        self.in_channels = [head_channels] + list(decoder_channels[:-1])
        self.skip_channels = list(encoder_channels[1:]) + [0]
        self.out_channels = decoder_channels
        if center:
            self.center = UnetppCenterBlock(
                head_channels,
                head_channels,
                use_norm=use_norm,
            )
        else:
            self.center = nn.Identity()

        # combine decoder keyword arguments
        kwargs = dict(
            use_norm=use_norm,
            attention_type=attention_type,
            interpolation_mode=interpolation_mode,
        )

        blocks = {}
        for layer_idx in range(len(self.in_channels) - 1):
            for depth_idx in range(layer_idx + 1):
                if depth_idx == 0:
                    in_ch = self.in_channels[layer_idx]
                    skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1)
                    out_ch = self.out_channels[layer_idx]
                else:
                    out_ch = self.skip_channels[layer_idx]
                    skip_ch = self.skip_channels[layer_idx] * (
                        layer_idx + 1 - depth_idx
                    )
                    in_ch = self.skip_channels[layer_idx - 1]
                blocks[f"x_{depth_idx}_{layer_idx}"] = UnetppDecoderBlockMod(
                    in_ch, skip_ch, out_ch, **kwargs
                )
        blocks[f"x_{0}_{len(self.in_channels) - 1}"] = UnetppDecoderBlockMod(
            self.in_channels[-1], 0, self.out_channels[-1], **kwargs
        )
        self.blocks = nn.ModuleDict(blocks)
        self.depth = len(self.in_channels) - 1

    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        # если 
        bs, channels, img_rows, img_cols = features[0].shape
        _, _, feat1_rows, feat1_cols = features[1].shape
        #upsample_scale_factors_list = [2.0 for i in range(len(self.blocks))]
        if (img_rows, img_cols) == (feat1_rows, feat1_cols):
            output_upsample_scaling_factor = 1.0
        else:
            output_upsample_scaling_factor = 2.0

        features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder

        # start building dense connections
        dense_x = {}
        for layer_idx in range(len(self.in_channels) - 1):
            for depth_idx in range(self.depth - layer_idx):
                if layer_idx == 0:
                    output = self.blocks[f"x_{depth_idx}_{depth_idx}"](
                        features[depth_idx], features[depth_idx + 1]
                    )
                    dense_x[f"x_{depth_idx}_{depth_idx}"] = output
                else:
                    dense_l_i = depth_idx + layer_idx
                    cat_features = [
                        dense_x[f"x_{idx}_{dense_l_i}"]
                        for idx in range(depth_idx + 1, dense_l_i + 1)
                    ]
                    cat_features = torch.cat(
                        cat_features + [features[dense_l_i + 1]], dim=1
                    )
                    dense_x[f"x_{depth_idx}_{dense_l_i}"] = self.blocks[
                        f"x_{depth_idx}_{dense_l_i}"
                    ](dense_x[f"x_{depth_idx}_{dense_l_i - 1}"], cat_features)
        
        dense_x[f"x_{0}_{self.depth}"] = self.blocks[f"x_{0}_{self.depth}"](
            dense_x[f"x_{0}_{self.depth - 1}"], scale_factor=output_upsample_scaling_factor
        )
        return dense_x[f"x_{0}_{self.depth}"]

class UnetPlusPlusMod(SegmentationModel):
    """Unet++ is a fully convolution neural network for image semantic segmentation. Consist of *encoder*
    and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial
    resolution (skip connections) which are used by decoder to define accurate segmentation mask. Decoder of
    Unet++ is more complex than in usual Unet.

    Args:
        encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
            to extract features of different spatial resolution
        encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features
            two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features
            with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
            Default is 5
        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
            other pretrained weights (see table with available weights for each encoder_name)
        decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder.
            Length of the list should be the same as **encoder_depth**
        decoder_use_norm:     Specifies normalization between Conv2D and activation.
            Accepts the following types:
            - **True**: Defaults to `"batchnorm"`.
            - **False**: No normalization (`nn.Identity`).
            - **str**: Specifies normalization type using default parameters. Available values:
              `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`.
            - **dict**: Fully customizable normalization settings. Structure:
              ```python
              {"type": <norm_type>, **kwargs}
              ```
              where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation.

            **Example**:
            ```python
            decoder_use_norm={"type": "layernorm", "eps": 1e-2}
            ```
        decoder_attention_type: Attention module used in decoder of the model.
            Available options are **None** and **scse** (https://arxiv.org/abs/1808.08127).
        decoder_interpolation: Interpolation mode used in decoder of the model. Available options are
            **"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"nearest"**.
        in_channels: A number of input channels for the model, default is 3 (RGB images)
        classes: A number of classes for output mask (or you can think as a number of channels of output mask)
        activation: An activation function to apply after the final convolution layer.
            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
            **callable** and **None**. Default is **None**.
        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
            on top of encoder if **aux_params** is not **None** (default). Supported params:
                - classes (int): A number of classes
                - pooling (str): One of "max", "avg". Default is "avg"
                - dropout (float): Dropout factor in [0, 1)
                - activation (str): An activation function to apply "sigmoid"/"softmax"
                    (could be **None** to return logits)
        kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.

    Returns:
        ``torch.nn.Module``: **Unet++**

    Reference:
        https://arxiv.org/abs/1807.10165

    """

    _is_torch_scriptable = False

    @supports_config_loading
    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
        decoder_channels: Sequence[int] = (256, 128, 64, 32, 16),
        decoder_attention_type: Optional[str] = None,
        decoder_interpolation: str = "nearest",
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, Callable]] = None,
        aux_params: Optional[dict] = None,
        **kwargs: dict[str, Any],
    ):
        super().__init__()

        if encoder_name.startswith("mit_b"):
            raise ValueError(
                "UnetPlusPlus is not support encoder_name={}".format(encoder_name)
            )

        decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None)
        if decoder_use_batchnorm is not None:
            warnings.warn(
                "The usage of decoder_use_batchnorm is deprecated. Please modify your code for decoder_use_norm",
                DeprecationWarning,
                stacklevel=2,
            )
            decoder_use_norm = decoder_use_batchnorm

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
            **kwargs,
        )

        self.decoder = UnetPlusPlusDecoderMod(
            encoder_channels=self.encoder.out_channels,
            decoder_channels=decoder_channels,
            n_blocks=encoder_depth,
            use_norm=decoder_use_norm,
            center=True if encoder_name.startswith("vgg") else False,
            attention_type=decoder_attention_type,
            interpolation_mode=decoder_interpolation,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels[-1],
            out_channels=classes,
            activation=activation,
            kernel_size=3,
        )

        if aux_params is not None:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )
        else:
            self.classification_head = None

        self.name = "unetplusplus-{}".format(encoder_name)
        self.initialize()

# Адаптация MANet

In [6]:
class MFABBlockMod(smp.decoders.manet.decoder.MFABBlock):
    def forward(
        self, x: torch.Tensor, skip: Optional[torch.Tensor] = None, scale_factor=2.0,
    ) -> torch.Tensor:
        x = self.hl_conv(x)
        x = F.interpolate(x, scale_factor=scale_factor, mode=self.interpolation_mode)
        attention_hl = self.SE_hl(x)
        if skip is not None:
            attention_ll = self.SE_ll(skip)
            attention_hl = attention_hl + attention_ll
            x = x * attention_hl
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class DecoderBlockMod(smp.decoders.manet.decoder.DecoderBlock):
    def forward(
        self, x: torch.Tensor, skip: Optional[torch.Tensor] = None, scale_factor=2.0,
    ) -> torch.Tensor:
        x = F.interpolate(x, scale_factor=scale_factor, mode=self.interpolation_mode)
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class MAnetDecoderMod(nn.Module):
    def __init__(
        self,
        encoder_channels: List[int],
        decoder_channels: List[int],
        n_blocks: int = 5,
        reduction: int = 16,
        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
        pab_channels: int = 64,
        interpolation_mode: str = "nearest",
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        # remove first skip with same spatial resolution
        encoder_channels = encoder_channels[1:]

        # reverse channels to start from head of encoder
        encoder_channels = encoder_channels[::-1]

        # computing blocks input and output channels
        head_channels = encoder_channels[0]
        in_channels = [head_channels] + list(decoder_channels[:-1])
        skip_channels = list(encoder_channels[1:]) + [0]
        out_channels = decoder_channels

        self.center = smp.decoders.manet.decoder.PABBlock(head_channels, pab_channels=pab_channels)

        # combine decoder keyword arguments
        kwargs = dict(
            use_norm=use_norm, interpolation_mode=interpolation_mode
        )  # no attention type here
        blocks = [
            MFABBlockMod(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs)
            if skip_ch > 0
            else DecoderBlockMod(in_ch, skip_ch, out_ch, **kwargs)
            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
        ]
        # for the last we dont have skip connection -> use simple decoder block
        self.blocks = nn.ModuleList(blocks)

    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        bs, channels, img_rows, img_cols = features[0].shape
        _, _, feat1_rows, feat1_cols = features[1].shape
        upsample_scale_factors_list = [2.0 for i in range(len(self.blocks))]
        if (img_rows, img_cols) == (feat1_rows, feat1_cols):
            upsample_scale_factors_list[-1] = 1.0

        features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder

        head = features[0]
        skips = features[1:]


        x = self.center(head)
        for i, (decoder_block, scale_factor) in enumerate(zip(self.blocks, upsample_scale_factors_list)):
            skip = skips[i] if i < len(skips) else None
            x = decoder_block(x, skip, scale_factor)

        return x
    
class MAnetMod(SegmentationModel):
    """MAnet_ :  Multi-scale Attention Net. The MA-Net can capture rich contextual dependencies based on
    the attention mechanism, using two blocks:
     - Position-wise Attention Block (PAB), which captures the spatial dependencies between pixels in a global view
     - Multi-scale Fusion Attention Block (MFAB), which  captures the channel dependencies between any feature map by
       multi-scale semantic feature fusion

    Args:
        encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
            to extract features of different spatial resolution
        encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features
            two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features
            with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
            Default is 5
        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
            other pretrained weights (see table with available weights for each encoder_name)
        decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder.
            Length of the list should be the same as **encoder_depth**
        decoder_use_norm: Specifies normalization between Conv2D and activation.
            Accepts the following types:
            - **True**: Defaults to `"batchnorm"`.
            - **False**: No normalization (`nn.Identity`).
            - **str**: Specifies normalization type using default parameters. Available values:
              `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`.
            - **dict**: Fully customizable normalization settings. Structure:
              ```python
              {"type": <norm_type>, **kwargs}
              ```
              where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation.

            **Example**:
            ```python
            decoder_use_norm={"type": "layernorm", "eps": 1e-2}
            ```
        decoder_pab_channels: A number of channels for PAB module in decoder.
            Default is 64.
        decoder_interpolation: Interpolation mode used in decoder of the model. Available options are
            **"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"nearest"**.
        in_channels: A number of input channels for the model, default is 3 (RGB images)
        classes: A number of classes for output mask (or you can think as a number of channels of output mask)
        activation: An activation function to apply after the final convolution layer.
            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
            **callable** and **None**. Default is **None**.
        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
            on top of encoder if **aux_params** is not **None** (default). Supported params:
                - classes (int): A number of classes
                - pooling (str): One of "max", "avg". Default is "avg"
                - dropout (float): Dropout factor in [0, 1)
                - activation (str): An activation function to apply "sigmoid"/"softmax"
                    (could be **None** to return logits)
        kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.

    Returns:
        ``torch.nn.Module``: **MAnet**

    .. _MAnet:
        https://ieeexplore.ieee.org/abstract/document/9201310

    """

    @supports_config_loading
    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
        decoder_channels: Sequence[int] = (256, 128, 64, 32, 16),
        decoder_pab_channels: int = 64,
        decoder_interpolation: str = "nearest",
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, Callable]] = None,
        aux_params: Optional[dict] = None,
        **kwargs: dict[str, Any],
    ):
        super().__init__()

        decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None)
        if decoder_use_batchnorm is not None:
            warnings.warn(
                "The usage of decoder_use_batchnorm is deprecated. Please modify your code for decoder_use_norm",
                DeprecationWarning,
                stacklevel=2,
            )
            decoder_use_norm = decoder_use_batchnorm

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
            **kwargs,
        )

        self.decoder = MAnetDecoderMod(
            encoder_channels=self.encoder.out_channels,
            decoder_channels=decoder_channels,
            n_blocks=encoder_depth,
            use_norm=decoder_use_norm,
            pab_channels=decoder_pab_channels,
            interpolation_mode=decoder_interpolation,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels[-1],
            out_channels=classes,
            activation=activation,
            kernel_size=3,
        )

        if aux_params is not None:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )
        else:
            self.classification_head = None

        self.name = "manet-{}".format(encoder_name)
        self.initialize()


#model = MAnetMod()
#model.encoder.conv1.stride = 1
#ret = model(torch.randn(1, 3, 96, 96))
#ret.shape

# Адаптация FPN

In [7]:
class FPNMod(SegmentationModel):
    """FPN_ is a fully convolution neural network for image semantic segmentation.

    Args:
        encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
            to extract features of different spatial resolution
        encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features
            two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features
            with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
            Default is 5
        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
            other pretrained weights (see table with available weights for each encoder_name)
        decoder_pyramid_channels: A number of convolution filters in Feature Pyramid of FPN_
        decoder_segmentation_channels: A number of convolution filters in segmentation blocks of FPN_
        decoder_merge_policy: Determines how to merge pyramid features inside FPN. Available options are **add**
            and **cat**
        decoder_dropout: Spatial dropout rate in range (0, 1) for feature pyramid in FPN_
        decoder_interpolation: Interpolation mode used in decoder of the model. Available options are
            **"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"nearest"**.
        in_channels: A number of input channels for the model, default is 3 (RGB images)
        classes: A number of classes for output mask (or you can think as a number of channels of output mask)
        activation: An activation function to apply after the final convolution layer.
            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
            **callable** and **None**. Default is **None**.
        upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity
        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
            on top of encoder if **aux_params** is not **None** (default). Supported params:
                - classes (int): A number of classes
                - pooling (str): One of "max", "avg". Default is "avg"
                - dropout (float): Dropout factor in [0, 1)
                - activation (str): An activation function to apply "sigmoid"/"softmax"
                    (could be **None** to return logits)
        kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.

    Returns:
        ``torch.nn.Module``: **FPN**

    .. _FPN:
        http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf

    """

    @supports_config_loading
    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        encoder_type:str = 'conv',
        decoder_pyramid_channels: int = 256,
        decoder_segmentation_channels: int = 128,
        decoder_merge_policy: str = "add",
        decoder_dropout: float = 0.2,
        decoder_interpolation: str = "nearest",
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[str] = None,
        upsampling: int = 4,
        aux_params: Optional[dict] = None,
        **kwargs: dict[str, Any],
    ):
        super().__init__()

        # validate input params
        if encoder_name.startswith("mit_b") and encoder_depth != 5:
            raise ValueError(
                "Encoder {} support only encoder_depth=5".format(encoder_name)
            )

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
            **kwargs,
        )

        self.decoder = FPNDecoderMod(
            encoder_channels=self.encoder.out_channels,
            encoder_depth=encoder_depth,
            pyramid_channels=decoder_pyramid_channels,
            segmentation_channels=decoder_segmentation_channels,
            dropout=decoder_dropout,
            merge_policy=decoder_merge_policy,
            interpolation_mode=decoder_interpolation,
            encoder_type=encoder_type,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=self.decoder.out_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=1,
            upsampling=upsampling,
        )

        if aux_params is not None:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )
        else:
            self.classification_head = None

        self.name = "fpn-{}".format(encoder_name)
        self.initialize()

class FPNModBlock(nn.Module):
    def __init__(
        self,
        pyramid_channels: int,
        skip_channels: int,
        interpolation_mode: str = "nearest",
    ):
        super().__init__()
        self.skip_conv = nn.Conv2d(skip_channels, pyramid_channels, kernel_size=1)
        self.interpolation_mode = interpolation_mode

    def forward(self, x: torch.Tensor, skip: torch.Tensor, scale_factor: float) -> torch.Tensor:
        x = F.interpolate(x, scale_factor=scale_factor, mode=self.interpolation_mode)
        if skip.size(1) != 0:
            #print(x.shape, skip.shape)
            skip = self.skip_conv(skip)
            x = x + skip
        return x

class FPNDecoderMod(nn.Module):
    def __init__(
        self,
        encoder_channels: List[int],
        encoder_depth: int = 5,
        pyramid_channels: int = 256,
        segmentation_channels: int = 128,
        dropout: float = 0.2,
        merge_policy: Literal["add", "cat"] = "add",
        interpolation_mode: str = "nearest",
        encoder_type:str = 'conv',
    ):
        super().__init__()

        self.out_channels = (
            segmentation_channels
            if merge_policy == "add"
            else segmentation_channels * 4
        )
        #print(self.out_channels)
        if encoder_depth < 3:
            raise ValueError(
                "Encoder depth for FPN decoder cannot be less than 3, got {}.".format(
                    encoder_depth
                )
            )

        encoder_channels = encoder_channels[::-1]
        encoder_channels = encoder_channels[: encoder_depth + 1]
        
        self.p6 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=1)
        '''
        self.p5 = smp.decoders.fpn.decoder.FPNBlock(pyramid_channels, encoder_channels[1], interpolation_mode)
        self.p4 = smp.decoders.fpn.decoder.FPNBlock(pyramid_channels, encoder_channels[2], interpolation_mode)
        self.p3 = smp.decoders.fpn.decoder.FPNBlock(pyramid_channels, encoder_channels[3], interpolation_mode)
        self.p2 = smp.decoders.fpn.decoder.FPNBlock(pyramid_channels, encoder_channels[4], interpolation_mode)
        '''
        self.p5 = FPNModBlock(pyramid_channels, encoder_channels[1], interpolation_mode)
        self.p4 = FPNModBlock(pyramid_channels, encoder_channels[2], interpolation_mode)
        self.p3 = FPNModBlock(pyramid_channels, encoder_channels[3], interpolation_mode)
        self.p2 = FPNModBlock(pyramid_channels, encoder_channels[4], interpolation_mode)
        
        if encoder_type == 'conv':
            upsamples_list = [4, 3, 2, 1, 0]
        elif encoder_type == 'vit':
            upsamples_list = [3, 2, 1, 0, 0]


        self.seg_blocks = nn.ModuleList(
            [
                smp.decoders.fpn.decoder.SegmentationBlock(
                    pyramid_channels, segmentation_channels, n_upsamples=n_upsamples
                )
                for n_upsamples in upsamples_list
            ]
        )

        self.merge = smp.decoders.fpn.decoder.MergeBlock(merge_policy)
        self.dropout = nn.Dropout2d(p=dropout, inplace=True)

    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        c2, c3, c4, c5, c6 = features[-5:]

        #print([c2.shape, c3.shape, c4.shape, c5.shape, c6.shape])
        #print([f.shape for f in features])
        #print()
        #print(f'c6:{c6.shape}')
        p6 = self.p6(c6)
        #print(f'p6:{p6.shape};c5:{c5.shape}')
        p5 = self.p5(p6, c5, scale_factor=2.0)
        #print(f'p5:{p5.shape};c4:{c4.shape}')
        p4 = self.p4(p5, c4, scale_factor=2.0)
        #print(f'p4:{p4.shape};c3:{c3.shape}')
        p3 = self.p3(p4, c3, scale_factor=2.0)
        #print(f'p3:{p3.shape};c2:{c2.shape}')
        p2 = self.p2(p3, c2, scale_factor=2.0)
        #print(f'p2:{p4.shape}')

        s6 = self.seg_blocks[0](p6)
        s5 = self.seg_blocks[1](p5)
        s4 = self.seg_blocks[2](p4)
        s3 = self.seg_blocks[3](p3)
        s2 = self.seg_blocks[4](p2)

        feature_pyramid = [s6, s5, s4, s3, s2]

        #print([f.shape for f in feature_pyramid])

        x = self.merge(feature_pyramid)
        x = self.dropout(x)
        
        return x

#model = FPNMod(encoder_name='mit_b0', upsampling=0, encoder_type='vit', image_size=(96, 96))
#model.encoder.patch_embed1.proj.stride=1

#model = FPNMod(encoder_name='resnet34', upsampling=0)
#model.encoder.conv1.stride=(1,1)

#ret = model(torch.randn(1, 3, 96, 96))
#ret.shape

# UNet with attention

## Attention and supplementary modules

In [None]:
def scaled_dot_product_attention(query, key, value, mask=None):
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)  # Masking
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, value)
    return output, attention_weights

class ChannelMultiHeadAttentionExpandHeads(nn.Module):
    def __init__(self, embed_dim, num_heads, kdim, vdim):
        '''
        This layer expands query, key, value vectors to dim*num_heads using Q, K, V linear layers
        input linear layers
        '''
        super().__init__()
        self.qdim = embed_dim
        self.kdim = kdim
        self.vdim = vdim

        self.num_heads = num_heads
        self.qhead_dim = embed_dim
        self.khead_dim = kdim
        self.vhead_dim = vdim
        #assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        self.wq = nn.Linear(embed_dim, embed_dim * num_heads)
        self.wk = nn.Linear(kdim, kdim * num_heads)
        self.wv = nn.Linear(vdim, vdim * num_heads)
        self.wo = nn.Linear(embed_dim * num_heads, embed_dim)

    def forward(self, query, key, value, mask=None, need_weights=False, **kwargs):
        batch_size = query.size(0)

        # 1. Linear projections for Q, K, V
        q = self.wq(query)
        k = self.wk(key)
        v = self.wv(value)

        #print(q.shape, k.shape, v.shape)

        # 2. Split into multiple heads
        q = q.view(batch_size, -1, self.num_heads, self.qhead_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.khead_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.vhead_dim).transpose(1, 2)

        #print(q.shape, k.shape, v.shape)

        # 3. Apply scaled dot-product attention
        if need_weights:
            x, attention_weights = scaled_dot_product_attention(q, k, v, mask)
        else:
            x = F.scaled_dot_product_attention(q, k, v, mask)
            attention_weights = None

        #print(x.shape)
        # 4. Concatenate heads and apply final linear projection
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.qdim * self.num_heads)
        #print(x.shape)
        output = self.wo(x)
        
        return output, attention_weights
   
bs = 1
ch = 200
rows = 96
cols=96

rp = 8
cp = 8

rpn = rows//rp
cpn = cols//cp

hsi = torch.randn(bs, ch, rows, cols)
another_hsi = torch.randn(bs, 128, 16, 16)

ch_att = ChannelMultiHeadAttentionExpandHeads(embed_dim=rp*cp, kdim=rp*cp, vdim=rp*cp, num_heads=4)

axes1 = {
    'bs': bs,
    'rpn': rpn,
    'cpn': cpn,
    'rp': rp,
    'cp': cp

}
axes2 = {
    'bs': bs,
    'rpn': 16//rp,
    'cpn': 16//cp,
    'rp': rp,
    'cp': cp
}
rearr_rule = 'bs ch (rp rpn) (cp cpn) -> (rpn cpn) bs ch (rp cp)'
hsi = eo.rearrange(hsi, rearr_rule, **axes1)
another_hsi = eo.rearrange(another_hsi, rearr_rule, **axes2)
another_hsi.shape

#ret = ch_att(query=hsi[0], key=hsi[0], value=hsi[0])
#ret.shape

torch.Size([4, 1, 128, 64])

In [110]:
class VisionTransformerBlock(nn.Module):
    def __init__(
            self,
            num_heads: int,
            hidden_dim: int, # is equal to query_dim
            kdim: int,
            vdim: int,
            mlp_dim: int,
            attention_layer: Callable[..., torch.nn.Module],
            dropout:float,
            norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
                 ):
        super().__init__()
        assert kdim==vdim, f"kdim should be equal to vdim! kdim={kdim}, vdim={vdim}"
        
        self.num_heads = num_heads
        self.kdim = kdim
        self.vdim = vdim
        self.hidden_dim = hidden_dim

        # Attention block
        self.ln_11 = norm_layer(hidden_dim)
        self.ln_12 = norm_layer(kdim)

        if attention_layer is nn.MultiheadAttention:
            
            attention_layer = partial(attention_layer, batch_first=True)
        self.self_attention = attention_layer(hidden_dim, num_heads, kdim=kdim, vdim=vdim,)
        #if kdim==vdim==hidden_dim:    
        #else:self.self_attention = attention_layer(hidden_dim, num_heads,)
        
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, target, source=None, need_weghts=False):
        x = self.ln_11(target)
        
        if source is None:
            # Self Attention
            x, weights = self.self_attention(query=x, key=x, value=x, average_attn_weights=False, need_weights=need_weghts)

        else:
            y = self.ln_12(source)
            #print(x.shape, y.shape)
            # Cross Attention
            x, weights = self.self_attention(query=x, key=y, value=y, average_attn_weights=False, need_weights=need_weghts)
        #print(x.shape)
        x = self.dropout(x)
        x = x + target

        x2 = self.ln_2(x)
        x2 = self.mlp(x2)
        return x2 + x, weights

class WindowVisionTransformer(nn.Module):
    def __init__(
            self,
            cols_in_patch:int,
            rows_in_patch:int,
            channels:int,
            # transformers block params
            num_heads:int,
            mlp_dim:int,
            dropout:float,
            layer_num:int,
            transformer_type:'str', # channels and patches are possible
            positional_encoding: Callable[..., nn.Module] = nn.Identity,
            ):
    
        super().__init__()

        self.cols_in_patch = cols_in_patch
        self.rows_in_patch = rows_in_patch
        self.transformer_type =  transformer_type
        
        if transformer_type == 'channels':
            self.seq_len = channels
            hidden_dim = cols_in_patch * rows_in_patch
            
        elif transformer_type == 'patches':
            self.seq_len = cols_in_patch * rows_in_patch
            hidden_dim = channels

        #print(self.seq_len, hidden_dim)
        
        self.positional_encoding = positional_encoding(num_embeddings=self.seq_len, embedding_dim=hidden_dim)
        # можно создать несколько трансформерных слоев
        transformer_layers_list = [
            VisionTransformerBlock(
                num_heads=num_heads,
                hidden_dim=hidden_dim,
                kdim=hidden_dim,
                vdim=hidden_dim,
                mlp_dim=mlp_dim,
                attention_layer=nn.MultiheadAttention,
                dropout=dropout)
            for i in range(layer_num)
        ]
        self.transformer_layers = nn.ModuleList(transformer_layers_list)
        
    def forward(self, x):
        bs, channels, rows, cols = x.shape
        row_patch_num = rows//self.rows_in_patch
        col_patch_num = cols//self.cols_in_patch
        # размер (bs, channels, rows, cols) преобразовываем в размер (row_patches*col_patches, bs, channels, rows_in_patch*cols_in_patch)
        # rows=row_patches*rows_in_patch, cols=col_patches*cols_in_patch
        if self.transformer_type == 'channels':
            rearrange_pattern = 'bs channels (row_patch_num rows_in_patch) (col_patch_num cols_in_patch) -> (row_patch_num col_patch_num) bs channels (rows_in_patch cols_in_patch)'
            rearrange_args = {
                'cols_in_patch':self.cols_in_patch,
                'rows_in_patch':self.rows_in_patch,
            }
            
        elif self.transformer_type == 'patches':
            rearrange_pattern = 'bs channels (row_patch_num rows_in_patch) (col_patch_num cols_in_patch) -> (row_patch_num col_patch_num) bs (rows_in_patch cols_in_patch) channels'
            rearrange_args = {
                'cols_in_patch':self.cols_in_patch,
                'rows_in_patch':self.rows_in_patch,
            }
        
        h = eo.rearrange(
            x,
            rearrange_pattern,
            **rearrange_args,
        )
        # позиционное кодирование (его может и не быть - nn.Identity)
        layer_outs = self.positional_encoding(h)
        processed_outs = []

        # итерирование по окнам 
        for i, layer_out in enumerate(layer_outs):

            for layer in self.transformer_layers:

                layer_out, layer_att_weights = layer(layer_out)
            processed_outs.append(layer_out.unsqueeze(0))
        
        processed_outs = torch.cat(processed_outs,dim=0)
        if self.transformer_type == 'channels':
            rearrange_pattern = '(row_patch_num col_patch_num) bs channels (rows_in_patch cols_in_patch) -> bs channels (row_patch_num rows_in_patch) (col_patch_num cols_in_patch)'
            rearrange_args = {
                'cols_in_patch':self.cols_in_patch,
                'rows_in_patch':self.rows_in_patch,
                'row_patch_num':row_patch_num,
                'col_patch_num':col_patch_num,
            }
        elif self.transformer_type == 'patches':
            rearrange_pattern = '(row_patch_num col_patch_num) bs (rows_in_patch cols_in_patch) channels -> bs channels (row_patch_num rows_in_patch) (col_patch_num cols_in_patch)'
            rearrange_args = {
                'cols_in_patch':self.cols_in_patch,
                'rows_in_patch':self.rows_in_patch,
                'row_patch_num':row_patch_num,
                'col_patch_num':col_patch_num,
            }
        processed_outs = eo.rearrange(
            processed_outs,
            rearrange_pattern,
            **rearrange_args,
        )
        return processed_outs


class WindowCrossAttention(nn.Module):
    def __init__(
            self,
            cols_in_patch_x:int,
            rows_in_patch_x:int,
            cols_in_patch_y:int,
            rows_in_patch_y:int,
            channels_x:int,
            channels_y:int,
            # transformers block params
            num_heads:int,
            mlp_dim:int,
            dropout:float,
            transformer_type:'str', # channels and patches are possible
            positional_encoding_x: Callable[..., nn.Module] = nn.Identity,
            positional_encoding_y: Callable[..., nn.Module] = nn.Identity,
            ):
    
        super().__init__()
        self.transformer_type =  transformer_type
        self.cols_in_patch_x = cols_in_patch_x
        self.rows_in_patch_x = rows_in_patch_x

        self.cols_in_patch_y = cols_in_patch_y
        self.rows_in_patch_y = rows_in_patch_y
        
        
        if transformer_type == 'channels':
            self.seq_len_x = channels_x
            hidden_dim_x = cols_in_patch_x * rows_in_patch_x
            self.seq_len_y = channels_y
            hidden_dim_y = cols_in_patch_y * rows_in_patch_y
            
        elif transformer_type == 'patches':
            self.seq_len_x = cols_in_patch_x * rows_in_patch_x
            hidden_dim_x = channels_x
            self.seq_len_y = cols_in_patch_y * rows_in_patch_y
            hidden_dim_y = channels_y

        #print(self.seq_len, hidden_dim)
        
        self.positional_encoding_x = positional_encoding_x(num_embeddings=self.seq_len_x, embedding_dim=hidden_dim_x)
        self.positional_encoding_y = positional_encoding_x(num_embeddings=self.seq_len_y, embedding_dim=hidden_dim_y)
        # можно создать несколько трансформерных слоев
        self.cross_attention_block = VisionTransformerBlock(
            num_heads=num_heads,
            hidden_dim=hidden_dim_x,
            kdim=hidden_dim_y,
            vdim=hidden_dim_y,
            mlp_dim=mlp_dim,
            attention_layer=nn.MultiheadAttention,
            dropout=dropout)
        
    def forward(self, x, y):
        bs, channels, rows_x, cols_x = x.shape
        
        row_patch_num_x = rows_x//self.rows_in_patch_x
        col_patch_num_x = cols_x//self.cols_in_patch_x

        bs, channels, rows_y, cols_y = y.shape
        row_patch_num_y = rows_y//self.rows_in_patch_y
        col_patch_num_y = cols_y//self.cols_in_patch_y

        

        assert row_patch_num_x == row_patch_num_y, f'number of window rows in X={row_patch_num_x} and Y={row_patch_num_y} tensors ought to coinside'
        assert col_patch_num_x == col_patch_num_y, f'number of window cols in X={col_patch_num_x} and Y={col_patch_num_y} tensors ought to coinside'

        # размер (bs, channels, rows, cols) преобразовываем в размер (row_patches*col_patches, bs, channels, rows_in_patch*cols_in_patch)
        # rows=row_patches*rows_in_patch, cols=col_patches*cols_in_patch
        if self.transformer_type == 'channels':
            rearrange_pattern = 'bs channels (row_patch_num rows_in_patch) (col_patch_num cols_in_patch) -> (row_patch_num col_patch_num) bs channels (rows_in_patch cols_in_patch)'
            rearrange_args_x = {
                'cols_in_patch':self.cols_in_patch_x,
                'rows_in_patch':self.rows_in_patch_x,
            }

            rearrange_args_y = {
                'cols_in_patch':self.cols_in_patch_y,
                'rows_in_patch':self.rows_in_patch_y,
            }
             
        elif self.transformer_type == 'patches':
            rearrange_pattern = 'bs channels (row_patch_num rows_in_patch) (col_patch_num cols_in_patch) -> (row_patch_num col_patch_num) bs (rows_in_patch cols_in_patch) channels'
            rearrange_args_x = {
                'cols_in_patch':self.cols_in_patch_x,
                'rows_in_patch':self.rows_in_patch_x,
            }
            rearrange_args_y = {
                'cols_in_patch':self.cols_in_patch_y,
                'rows_in_patch':self.rows_in_patch_y,
            }
        
        hx = eo.rearrange(
            x,
            rearrange_pattern,
            **rearrange_args_x,
        )
        hy = eo.rearrange(
            y,
            rearrange_pattern,
            **rearrange_args_y,
        )

        # позиционное кодирование (его может и не быть - nn.Identity)
        
        layer_outs_x = self.positional_encoding_x(hx)
        layer_outs_y = self.positional_encoding_y(hy)
        processed_outs = []

        # итерирование по окнам 
        for i, (layer_out_x, layer_out_y) in enumerate(zip(layer_outs_x, layer_outs_y)):
            layer_out, layer_att_weights = self.cross_attention_block(layer_out_x, layer_out_y)
            processed_outs.append(layer_out.unsqueeze(0))
        
        processed_outs = torch.cat(processed_outs,dim=0)
        if self.transformer_type == 'channels':
            rearrange_pattern = '(row_patch_num col_patch_num) bs channels (rows_in_patch cols_in_patch) -> bs channels (row_patch_num rows_in_patch) (col_patch_num cols_in_patch)'
            rearrange_args = {
                'cols_in_patch':self.cols_in_patch_x,
                'rows_in_patch':self.rows_in_patch_x,
                'row_patch_num':row_patch_num_x,
                'col_patch_num':col_patch_num_x,
            }
        elif self.transformer_type == 'patches':
            rearrange_pattern = '(row_patch_num col_patch_num) bs (rows_in_patch cols_in_patch) channels -> bs channels (row_patch_num rows_in_patch) (col_patch_num cols_in_patch)'
            rearrange_args = {
                'cols_in_patch':self.cols_in_patch_x,
                'rows_in_patch':self.rows_in_patch_x,
                'row_patch_num':row_patch_num_x,
                'col_patch_num':col_patch_num_x,
            }
        processed_outs = eo.rearrange(
            processed_outs,
            rearrange_pattern,
            **rearrange_args,
        )
        return processed_outs
    
class HyperspectralTransformer(nn.Module):
    def __init__(
            self,
            config
            ):
    
        super().__init__()
        self.config = config
        self.patch_emd = config['patch_emd']['layer'](**config['patch_emd']['params'])
        transformer_layers = {}
        for i, transformer_layer_config in enumerate(config['transformer_layers']):
            #print(transformer_layer_config['params'])
            if transformer_layer_config['layer'] == 'crossatt':
                transformer_layer_config['params']['positional_encoding_x'] = pos_enc_factory_dict[transformer_layer_config['params']['positional_encoding_x']]
                transformer_layer_config['params']['positional_encoding_y'] = pos_enc_factory_dict[transformer_layer_config['params']['positional_encoding_y']]
            else:
                transformer_layer_config['params']['positional_encoding'] = pos_enc_factory_dict[transformer_layer_config['params']['positional_encoding']]
            layer_name = transformer_layer_config['layer']
            layer_creat = transformer_factory_dict[layer_name]
            transformer_layers[f'{i}_{layer_name}'] = layer_creat(**transformer_layer_config['params'])

        self.transformer_layers = nn.ModuleDict(transformer_layers)
        #self.output_layer = config['output_layer']['layer'](**config['patch_emd']['params'])

    def forward(self, x, y: Dict[str, torch.Tensor]):
        x = self.patch_emd(x)
        results = []
        for i, (name, layer) in enumerate(self.transformer_layers.items()):
            if name in y:       
                x = layer(x, y[name])
            else:
                x = layer(x)
        return x

class EmbeddingLayer(nn.Module):
    def __init__(
            self,
            num_embeddings,
            embedding_dim,
            padding_idx=None,
            max_norm=None,
            norm_type=2.0,
            scale_grad_by_freq=False,
            sparse=False,
            _weight=None,
            _freeze=False,
            device=None,
            dtype=None
            ):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
            _freeze,
            device,
            dtype
            )
    def forward(self, x):
        #!!!!
        bs, seq_len, emb_dim = x.shape
        positions = torch.arange(0, seq_len, dtype=torch.long).to(x.device)
        pos_embeddings = self.embedding(positions)
        return x + pos_embeddings
    
class FixedSizeLearnableEmbeddings(nn.Module):
    def __init__(
            self,
            num_embeddings,
            embedding_dim,
            ):
        super().__init__()
        self.positional_encoding = nn.Parameter(torch.empty(1, num_embeddings, embedding_dim).normal_(std=0.02))

    def forward(self,x):
        return x + self.positional_encoding

class AddFeatures(nn.Module):
    def forward(self, x, y):
        return x+y

class ConcatFeatures(nn.Module):
    '''
    Implementation of concatenation. It is nececcary for various aggreagation strategies in UNet decoder
    '''
    def forward(self, *tensors, dim=1):
        return torch.cat(tensors, dim=dim)
    
class PassOneFeature(nn.Module):
    def __init__(self, passing_idx):
        super().__init__()
        self.passing_idx = passing_idx
    def forward(self, *tensors):
        return tensors[self.passing_idx]

    

pos_enc_factory_dict = {
    'fixed_embeddings': FixedSizeLearnableEmbeddings,
    'embedding_layer': EmbeddingLayer,
    'none': nn.Identity
}

transformer_factory_dict = {
    'win_mha': WindowVisionTransformer,
    'crossatt': WindowCrossAttention,
    'none': nn.Identity,
}

feature_aggregation_factory_dict ={
    'add': AddFeatures,
    'pass_one': PassOneFeature,
    'concat': ConcatFeatures,
    'crossatt': WindowCrossAttention
}


aux_transformer_config = {
    'patch_emd':{
        'layer':nn.Conv2d,
        'params':{
            'in_channels': 200,
            'out_channels': 200,
            'kernel_size': 3,
            'stride': 1,
            'padding': 1,
            'groups':200
        }
    },
    'input_transformer': {
            'layer': 'win_mha',
            'params': {
                'cols_in_patch':12,
                'rows_in_patch':12,
                'channels':200,
                'num_heads':12,
                'mlp_dim':12*12*4,
                'dropout':0.2,
                'layer_num':3,
                'transformer_type':'channels',
                'positional_encoding':'fixed_embeddings',
            },
        },
    'hsi_augmentation':{
        'layer': 'add',
        'params': {}
    },
    'intermediate_layers':{
            'layer': 'win_mha',
            'params': {
                'cols_in_patch':12,
                'rows_in_patch':12,
                'channels':200,
                'num_heads':12,
                'mlp_dim':12*12*4,
                'dropout':0.2,
                'layer_num':3,
                'transformer_type':'channels',
                'positional_encoding':'none',
            },
        },

    'output_crossatt': {
            'layer': 'crossatt',
            'params': {
                'rows_in_patch_x':12,
                'cols_in_patch_x':12,
                'rows_in_patch_y':12,
                'cols_in_patch_y':12,
                'channels_x':30,
                'channels_y':200,
                'num_heads':12,
                'mlp_dim':12*12*4,
                'dropout':0.2,
                'transformer_type':'channels',
                'positional_encoding_x':'fixed_embeddings',
                'positional_encoding_y':'none',
            },
        },
    'output_layer':{
        'layer':nn.Conv2d,
        'params':{
            'in_channels': 30,
            'out_channels': 30,
            'kernel_size': 1,
            'stride': 1,
            'padding': 0,
            
        }
    }
}
            
batch_size = 3  
seq_len1 = 14
seq_len2 = 88
hidden_size1 = 128
hidden_size2 = 256

transf_block = VisionTransformerBlock(
    num_heads=8,
    hidden_dim=hidden_size1,
    kdim=hidden_size2,
    vdim=hidden_size2,
    mlp_dim=hidden_size1*4,
    attention_layer=nn.MultiheadAttention,
    dropout=0.2
)

w_transf = WindowVisionTransformer(
    cols_in_patch=12,
    rows_in_patch=12,
    channels=200,
    num_heads=8,
    mlp_dim=16*16*4,
    dropout=0.2,
    layer_num=2,
    transformer_type='patches',
    positional_encoding=FixedSizeLearnableEmbeddings,
)

crossatt = WindowCrossAttention(
    rows_in_patch_x=12,
    cols_in_patch_x=12,
    rows_in_patch_y=12,
    cols_in_patch_y=12,
    channels_x=100,
    channels_y=200,
    num_heads=6,
    mlp_dim=6*6*4,
    dropout=0.2,
    transformer_type='channels',
    positional_encoding_x=FixedSizeLearnableEmbeddings,
    positional_encoding_y=FixedSizeLearnableEmbeddings,

)

#hs_former = HyperspectralTransformer(config)

hsi = torch.randn(1, 200, 96, 96)
hsi2 = torch.randn(1, 30, 96, 96)
#ret = w_transf(hsi)
#ret = crossatt(hsi2, hsi)
#ret.shape

aux_transformer_config['output_crossatt']['params']['positional_encoding_x'] = pos_enc_factory_dict[aux_transformer_config['output_crossatt']['params']['positional_encoding_x']]
aux_transformer_config['output_crossatt']['params']['positional_encoding_y'] = pos_enc_factory_dict[aux_transformer_config['output_crossatt']['params']['positional_encoding_y']]
aux_transformer_config['input_transformer']['params']['positional_encoding'] = pos_enc_factory_dict[aux_transformer_config['input_transformer']['params']['positional_encoding']]
aux_transformer_config['intermediate_layers']['params']['positional_encoding'] = pos_enc_factory_dict[aux_transformer_config['intermediate_layers']['params']['positional_encoding']]
#config['hsi_augmentation']['layer'] = feature_aggregation_factory_dict['hsi_augmentation']['layer']


aux_transf = nn.ModuleDict()
aux_transf['patch_emd'] = aux_transformer_config['patch_emd']['layer'](**aux_transformer_config['patch_emd']['params'])
layer_name = aux_transformer_config['input_transformer']['layer']
aux_transf['input_transformer'] = transformer_factory_dict[layer_name](**aux_transformer_config['input_transformer']['params'])
aux_transf['hsi_augmentation'] = feature_aggregation_factory_dict[aux_transformer_config['hsi_augmentation']['layer']](**aux_transformer_config['hsi_augmentation']['params'])
layer_name = aux_transformer_config['intermediate_layers']['layer']
aux_transf['intermediate_layers'] = transformer_factory_dict[layer_name](**aux_transformer_config['intermediate_layers']['params'])
layer_name = aux_transformer_config['output_crossatt']['layer']
aux_transf['output_crossatt'] = transformer_factory_dict[layer_name](**aux_transformer_config['output_crossatt']['params'])
aux_transf['output_layer'] = aux_transformer_config['output_layer']['layer'](**aux_transformer_config['output_layer']['params'])


            

x = aux_transf['patch_emd'](hsi)
x = aux_transf['input_transformer'](x)
augmented = aux_transf['hsi_augmentation'](hsi, x)
x = aux_transf['intermediate_layers'](x)
out = aux_transf['output_crossatt'](hsi2, x)
out = aux_transf['output_layer'](out)
x.shape, out.shape

(torch.Size([1, 200, 96, 96]), torch.Size([1, 30, 96, 96]))

In [109]:
config = {
    'patch_emd':{
        'layer':nn.Conv2d,
        'params':{
            'in_channels': 200,
            'out_channels': 200,
            'kernel_size': 3,
            'stride': 1,
            'padding': 1,
            'groups':200
        }
    },
    'input_transformer': {
            'layer': 'win_mha',
            'params': {
                'cols_in_patch':12,
                'rows_in_patch':12,
                'channels':200,
                'num_heads':12,
                'mlp_dim':12*12*4,
                'dropout':0.2,
                'layer_num':3,
                'transformer_type':'channels',
                'positional_encoding':'fixed_embeddings',
            },
        },
    'hsi_augmentation':{
        'layer': 'add',
        'params': {}
    },
    'intermediate_layers':{
            'layer': 'win_mha',
            'params': {
                'cols_in_patch':12,
                'rows_in_patch':12,
                'channels':200,
                'num_heads':12,
                'mlp_dim':12*12*4,
                'dropout':0.2,
                'layer_num':3,
                'transformer_type':'channels',
                'positional_encoding':'none',
            },
        },

    'output_crossatt': {
            'layer': 'crossatt',
            'params': {
                'rows_in_patch_x':12,
                'cols_in_patch_x':12,
                'rows_in_patch_y':12,
                'cols_in_patch_y':12,
                'channels_x':30,
                'channels_y':200,
                'num_heads':12,
                'mlp_dim':12*12*4,
                'dropout':0.2,
                'transformer_type':'channels',
                'positional_encoding_x':'fixed_embeddings',
                'positional_encoding_y':'none',
            },
        },
    'output_layer':{
        'layer':nn.Conv2d,
        'params':{
            'in_channels': 30,
            'out_channels': 30,
            'kernel_size': 1,
            'stride': 1,
            'padding': 0,
            
        }
    }
}

with open(r'C:\Users\mokhail\develop\MultispectralSegmentation\training_configs\models\unet_hsi.yaml') as fd:
    new_config = yaml.load(fd, yaml.Loader)

new_config['segmentation_nn']['params']['aux_transformer_config'] = config
with open(r'training_configs\models\unet_aux_tr_hsi.yaml', 'w') as fd:
    yaml.dump(new_config, fd, indent=4)

In [108]:
new_config

{'batch_size': 16,
 'device': 'cuda:0',
 'epoch_num': 300,
 'input_image_size': 96,
 'loss': {'params': {'ignore_index': -100,
   'label_smoothing': 0.15,
   'reduction': 'mean',
   'weight': None},
  'type': 'crossentropy'},
 'lr_scheduler': {'args': {'T_0': 25,
   'T_mult': 1,
   'eta_min': 0,
   'last_epoch': -1},
  'params': {'frequency': 1,
   'interval': 'epoch',
   'monitor': 'val_loss',
   'name': None,
   'strict': True},
  'type': 'cosine_warm_restarts'},
 'multispecter_bands_indices': 200,
 'name_postfix': 'HSI',
 'optimizer': {'args': {}, 'type': 'adam'},
 'path_to_dataset_root': 'C:\\Users\\mokhail\\develop\\DATA\\UAV-HSI-Crop-Dataset',
 'segmentation_nn': {'input_layer_config': {},
  'nn_architecture': 'unet',
  'params': {'activation': None,
   'aux_params': None,
   'classes': 30,
   'decoder_attention_type': None,
   'decoder_channels': (256, 128, 128, 64, 64),
   'decoder_interpolation': 'nearest',
   'decoder_use_norm': 'batchnorm',
   'encoder_depth': 5,
   'in_chan

In [None]:
class EmbeddingLayer(nn.Module):
    def __init__(
            self,
            num_embeddings,
            embedding_dim,
            padding_idx=None,
            max_norm=None,
            norm_type=2.0,
            scale_grad_by_freq=False,
            sparse=False,
            _weight=None,
            _freeze=False,
            device=None,
            dtype=None
            ):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
            _freeze,
            device,
            dtype
            )
    def forward(self, x):
        #!!!!
        bs, seq_len, emb_dim = x.shape
        positions = torch.arange(0, seq_len, dtype=torch.long).to(x.device)
        pos_embeddings = self.embedding(positions)
        return x + pos_embeddings
    
class FixedSizeLearnableEmbeddings(nn.Module):
    def __init__(
            self,
            num_embeddings,
            embedding_dim,
            ):
        super().__init__()
        self.positional_encoding = nn.Parameter(torch.empty(1, num_embeddings, embedding_dim).normal_(std=0.02))

    def forward(self,x):
        return x + self.positional_encoding

class LazyWindowVisionTransformerBlocks(nn.Module):
    def __init__(
            self,
            rows_in_win,
            cols_in_win,
            channels:int,
            layer_num:int,
            num_heads: int,
            transformer_mlp_dim:int,
            attention_dropout:float,
            dropout: float,
            positional_encoding_block: str,
            positional_encoding_block_params: Dict,
            transformer_block: Callable[..., torch.nn.Module]=VitEncoderBlock,
            ):
        super().__init__()
        assert channels % num_heads == 0, f'Channel num should be devisible by the number of MSA heads'
        self.cols_in_win = cols_in_win
        self.rows_in_win = rows_in_win
        self.seq_len = self.rows_in_win * self.cols_in_win
        self.head_dim = channels//num_heads

        #print(pos_enc_factory_dict[positional_encoding_block] is EmbeddingLayer)
        #print(pos_enc_factory_dict[positional_encoding_block] is FixedSizeLearnableEmbeddings)

        #print(positional_encoding_block is FixedSizeLearnableEmbeddings)
        
        
        if positional_encoding_block == 'fixed_embeddings':

            positional_encoding_block_params.update(
                {'seq_len': self.seq_len, 'embedding_dim': channels}
            )
        elif positional_encoding_block == 'embedding_layer':
            positional_encoding_block_params.update(
                {'num_embeddings': self.seq_len, 'embedding_dim': channels}
            )

        positional_encoding_block = pos_enc_factory_dict[positional_encoding_block]
        #print(type(positional_encoding_block)==EmbeddingLayer)
        #print(positional_encoding_block)
        #print(positional_encoding_block_params)
        self.positional_encoding = positional_encoding_block(**positional_encoding_block_params)

        # можно создать несколько трансформерных слоев
        transformer_layers_dict = {
            f'transformer_enc_{i}': transformer_block(
                num_heads=num_heads,
                hidden_dim=channels,
                mlp_dim=transformer_mlp_dim,
                attention_dropout=attention_dropout,
                dropout=dropout)
            for i in range(layer_num)
        }
        self.transformer_layers = nn.ModuleDict(transformer_layers_dict)

    def forward(self, x):
        #print(x.shape)
        bs, channels, rows, cols = x.shape
        rows_win_num = rows//self.rows_in_win
        cols_win_num = cols//self.cols_in_win
        #print(rows_win_num, cols_win_num)

        # размер (bs, channels, rows, cols) преобразовываем в размер (bs, rows_win_num*cols_win_num, rows_in_win*cols_in_win, channels)
        # rows=row_patches*rows_in_patch, cols=col_patches*cols_in_patch
        
        rearrange_pattern = 'bs channels (rows_win_num rows_in_win) (cols_win_num cols_in_win) -> (rows_win_num cols_win_num) bs (rows_in_win cols_in_win) channels'
        rearrange_args = {
            'rows_in_win':self.rows_in_win,
            'rows_in_win':self.cols_in_win,
            'rows_win_num':rows_win_num,
            'cols_win_num':cols_win_num,
        }
            
        h = eo.rearrange(
            x,
            rearrange_pattern,
            **rearrange_args,
        )
        #print(f'h:{h.shape}')
        windows_outs = []
        # итерируем по окнам
        for i, window in enumerate(h):
            #print(f'encoded:{window.shape}')
            # позиционное кодирование длля окна
            window = self.positional_encoding(window)
            #print(f'encoded window:{window.shape}')
            #print()
            # итерируем по слоям трансформера
            for layer_name, layer in self.transformer_layers.items():

                window = layer(window)
            windows_outs.append(window.unsqueeze(0))
        
        windows_outs = torch.cat(windows_outs,dim=0)

        # rearrange back
        rearrange_pattern = '(rows_win_num cols_win_num) bs (rows_in_win cols_in_win) channels -> bs channels (rows_win_num rows_in_win) (cols_win_num cols_in_win)'
        windows_outs = eo.rearrange(
            windows_outs,
            rearrange_pattern,
            **rearrange_args,
        )

        return windows_outs

pos_enc_factory_dict = {
    'fixed_embeddings': FixedSizeLearnableEmbeddings,
    'embedding_layer': EmbeddingLayer,
}



tr_block = LazyWindowVisionTransformerBlocks(
    rows_in_win=6,
    cols_in_win=6,
    channels=256,
    layer_num=1,
    num_heads=8,
    transformer_mlp_dim=512,
    attention_dropout=0.2,
    dropout=0.2,
    positional_encoding_block='embedding_layer',#'fixed_embeddings' 'embedding_layer'
    positional_encoding_block_params={},
    transformer_block=VitEncoderBlock,
)
features = torch.randn(2, 256, 24, 24)
out = tr_block(features)
out.shape

torch.Size([2, 256, 24, 24])

In [None]:
def scaled_dot_product_attention(query, key, value, mask=None):
    # query, key, value are (batch_size, num_heads, seq_len, head_dim)
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9) # Apply mask for padding or causal attention

    attention_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, value)
    return output, attention_weights

class ConvMultiheadAttention(nn.Module):
    '''
    Class for convolutional multihead self-attention Wq, Wk, Wv are replaced from fully connected to convolutional layers 
    '''
    def __init__(
            self,
            q_in_channels,
            k_in_channels,
            v_in_channels,
            q_out_channels,
            k_out_channels,
            v_out_channels,
            in_kernel_size,
            in_padding,
            in_stride,
            head_row_dim,
            head_col_dim,
            head_ch_dim,
            norm=nn.LayerNorm
            ):
        super().__init__()

        self.head_row_dim = head_row_dim
        self.head_col_dim = head_col_dim
        self.head_ch_dim = head_ch_dim

        self.input_norm = norm(head_ch_dim*head_row_dim*head_col_dim, eps=1e-6)
        
        self.conv_q = nn.Conv2d(
            in_channels=q_in_channels,
            out_channels=q_out_channels,
            kernel_size=in_kernel_size,
            padding=in_padding,
            stride=in_stride
            )
        
        self.conv_k = nn.Conv2d(
            in_channels=k_in_channels,
            out_channels=k_out_channels,
            kernel_size=in_kernel_size,
            padding=in_padding,
            stride=in_stride
            )
        
        self.conv_v = nn.Conv2d(
            in_channels=v_in_channels,
            out_channels=v_out_channels,
            kernel_size=in_kernel_size,
            padding=in_padding,
            stride=in_stride
            )
        
        #self.out_conv = nn.Conv2d(in_channels=q_out_channels,out_channels=out_channels,kernel_size=out_kernel_size,padding=out_padding,stride=out_stride)
        
    def forward(self, query, key, value, mask=None):
        bs, ch, rows, cols = value.shape
        head_row_num = rows // self.head_row_dim
        head_col_num = cols // self.head_col_dim
        head_ch_num = ch // self.head_ch_dim

        #print(head_row_num, head_col_num)

        q = self.conv_q(query)
        k = self.conv_k(key)
        v = self.conv_v(value)
        qkv_rearrangement_str = 'bs (head_ch_dim head_ch_num) (head_rdim head_rnum) (head_cdim head_cnum) -> bs (head_rnum head_cnum head_ch_num) (head_rdim head_cdim head_ch_dim) '
        
        q = eo.rearrange(
            q,
            qkv_rearrangement_str,
            head_rdim=self.head_row_dim, head_cdim=self.head_col_dim, head_ch_dim=self.head_ch_dim)

        k = eo.rearrange(
            k,
            qkv_rearrangement_str,
            head_rdim=self.head_row_dim, head_cdim=self.head_col_dim, head_ch_dim=self.head_ch_dim
            )
        
        v = eo.rearrange(
            v,
            qkv_rearrangement_str,
            head_rdim=self.head_row_dim, head_cdim=self.head_col_dim, head_ch_dim=self.head_ch_dim)
        #print(f'q:{q.shape};k:{k.shape};v:{v.shape}')
        
        weighted_v = F.scaled_dot_product_attention(query=q, key=k, value=v)
        #print(f'v_w:{weighted_v.shape}')

        weighted_v = eo.rearrange(
            weighted_v,
            'bs (head_rnum head_cnum head_ch_num) (head_rdim head_cdim head_ch_dim) -> bs (head_ch_dim head_ch_num) (head_rdim head_rnum) (head_cdim head_cnum)',
            head_rdim=self.head_row_dim, head_cdim=self.head_col_dim, head_rnum=head_row_num, head_cnum=head_col_num,
        )
        #print(f'v_w_ra:{weighted_v.shape}')
        return weighted_v

        out = self.out_conv(weighted_v)
        return out
    
class ConvMSABlock(nn.Module):
    '''
    Implementatuion of convolutional multihead self-attention block
    '''
    def __init__(
            self,
            msa_in_channels,
            msa_intermediate_channels,
            msa_in_kernel_size,
            msa_in_padding,
            msa_in_stride,
            msa_out_channels,
            
            msa_head_row_dim,
            msa_head_col_dim,
            msa_head_ch_dim,

            dropout,

            out_conv_hidden_channels,
            out_conv_kernel_size,
            out_conv_padding,
            out_conv_stride,

            out_conv_out_channels,

            out_conv_act,
            
            norm_layer: Callable[..., torch.nn.Module],
            ):
        super().__init__()
        
        self.norm1 = norm_layer(msa_in_channels)
        self.self_att = ConvMultiheadAttention(
            q_in_channels=msa_in_channels,
            k_in_channels=msa_in_channels,
            v_in_channels=msa_in_channels,
            q_out_channels=msa_intermediate_channels,
            k_out_channels=msa_intermediate_channels,
            v_out_channels=msa_intermediate_channels,
            in_kernel_size=msa_in_kernel_size,
            in_padding=msa_in_padding,
            in_stride=msa_in_stride,
            head_row_dim=msa_head_row_dim,
            head_col_dim=msa_head_col_dim,
            head_ch_dim=msa_head_ch_dim,
        )

        self.dropout = nn.Dropout2d(dropout)

        self.out_conv = nn.Sequential(
            torchvision.ops.Conv2dNormActivation(
                in_channels=msa_out_channels,
                out_channels=out_conv_hidden_channels,
                kernel_size=out_conv_kernel_size,
                padding=out_conv_padding,
                stride=out_conv_stride,
                activation_layer=out_conv_act
            ),
            torchvision.ops.Conv2dNormActivation(
                in_channels=out_conv_hidden_channels,
                out_channels=out_conv_out_channels,
                kernel_size=out_conv_kernel_size,
                padding=out_conv_padding,
                stride=out_conv_stride,
                activation_layer=out_conv_act
            )
        )

        self.norm2 = norm_layer(msa_out_channels)
    def forward(self, input_features):
        
        x = self.norm1(input_features)
        #print(x.shape)
        
        x = self.self_att(query=x, key=x, value=x)
        
        x = self.dropout(x)
        x = x + input_features
        x = self.norm2(x)
        #print(x.shape)
        y = self.out_conv(x)
        return x + y
    
class ConvCrossAttentionBlock(nn.Module):
    '''
    Implementation of convolutional cross-attention block
    '''
    def __init__(
            self,
            q_in_channels,
            k_in_channels,
            v_in_channels,
            q_out_channels,
            k_out_channels,
            v_out_channels,
            in_kernel_size,
            in_padding,
            in_stride,
            
            head_row_dim,
            head_col_dim,
            head_ch_dim,

            dropout,
            
            norm_layer: Callable[..., torch.nn.Module],
            ):
        super().__init__()
        
        self.kv_inp_norm = norm_layer(k_in_channels)
        self.q_inp_norm = norm_layer(q_in_channels)

        self.cross_att = ConvMultiheadAttention(
            q_in_channels=q_in_channels,
            k_in_channels=k_in_channels,
            v_in_channels=v_in_channels,
            q_out_channels=q_out_channels,
            k_out_channels=k_out_channels,
            v_out_channels=v_out_channels,
            in_kernel_size=in_kernel_size,
            in_padding=in_padding,
            in_stride=in_stride,
            
            head_row_dim=head_row_dim,
            head_col_dim=head_col_dim,
            head_ch_dim=head_ch_dim,
        )
        self.dropout = nn.Dropout2d(dropout)
        self.out_norm = norm_layer(q_out_channels)

    def forward(self, q, kv):
        
        x = self.kv_inp_norm(kv)
        q = self.q_inp_norm(q)
        #print(x.shape)
        
        x = self.cross_att(query=q, key=x, value=x)
        
        x = self.dropout(x)
        #print(x.shape, q.shape)
        x = x + q
        x = self.out_norm(x)
        
        return x

class ConcatDim1(nn.Module):
    '''
    Implementation of concatenation. It is nececcary for various aggreagation strategies in UNet decoder
    '''
    def forward(self, *tensors):
        return torch.cat(tensors, dim=1)

## Код для UnetAtt

In [None]:
class UnetAtt(SegmentationModel):
    """
    U-Net is a fully convolutional neural network architecture designed for semantic image segmentation.

    It consists of two main parts:

    1. An encoder (downsampling path) that extracts increasingly abstract features
    2. A decoder (upsampling path) that gradually recovers spatial details

    The key is the use of skip connections between corresponding encoder and decoder layers.
    These connections allow the decoder to access fine-grained details from earlier encoder layers,
    which helps produce more precise segmentation masks.

    The skip connections work by concatenating feature maps from the encoder directly into the decoder
    at corresponding resolutions. This helps preserve important spatial information that would
    otherwise be lost during the encoding process.

    Args:
        encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
            to extract features of different spatial resolution
        encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features
            two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features
            with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
            Default is 5
        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
            other pretrained weights (see table with available weights for each encoder_name)
        decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder.
            Length of the list should be the same as **encoder_depth**
        decoder_use_norm:     Specifies normalization between Conv2D and activation.
            Accepts the following types:
            - **True**: Defaults to `"batchnorm"`.
            - **False**: No normalization (`nn.Identity`).
            - **str**: Specifies normalization type using default parameters. Available values:
              `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`.
            - **dict**: Fully customizable normalization settings. Structure:
              ```python
              {"type": <norm_type>, **kwargs}
              ```
              where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation.

            **Example**:
            ```python
            decoder_use_norm={"type": "layernorm", "eps": 1e-2}
            ```
        decoder_attention_type: Attention module used in decoder of the model. Available options are
            **None** and **scse** (https://arxiv.org/abs/1808.08127).
        decoder_interpolation: Interpolation mode used in decoder of the model. Available options are
            **"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"nearest"**.
        in_channels: A number of input channels for the model, default is 3 (RGB images)
        classes: A number of classes for output mask (or you can think as a number of channels of output mask)
        activation: An activation function to apply after the final convolution layer.
            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
            **callable** and **None**. Default is **None**.
        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
            on top of encoder if **aux_params** is not **None** (default). Supported params:
                - classes (int): A number of classes
                - pooling (str): One of "max", "avg". Default is "avg"
                - dropout (float): Dropout factor in [0, 1)
                - activation (str): An activation function to apply "sigmoid"/"softmax"
                    (could be **None** to return logits)
        kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.

    Returns:
        ``torch.nn.Module``: Unet

    Example:
        .. code-block:: python

            import torch
            import segmentation_models_pytorch as smp

            model = smp.Unet("resnet18", encoder_weights="imagenet", classes=5)
            model.eval()

            # generate random images
            images = torch.rand(2, 3, 256, 256)

            with torch.inference_mode():
                mask = model(images)

            print(mask.shape)
            # torch.Size([2, 5, 256, 256])

    .. _Unet:
        https://arxiv.org/abs/1505.04597

    """

    requires_divisible_input_shape = False

    @supports_config_loading
    def __init__(
        self,
        decoder_layers_configs: Sequence,
        transformer_branch_config: Dict,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
        decoder_channels: Sequence[int] = (256, 128, 64, 32, 16),
        decoder_interpolation: str = "nearest",
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, Callable]] = None,
        aux_params: Optional[dict] = None,
        **kwargs: dict[str, Any],
    ):
        super().__init__()

        decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None)
        if decoder_use_batchnorm is not None:
            warnings.warn(
                "The usage of decoder_use_batchnorm is deprecated. Please modify your code for decoder_use_norm",
                DeprecationWarning,
                stacklevel=2,
            )
            decoder_use_norm = decoder_use_batchnorm

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
            **kwargs,
        )

        add_center_block = encoder_name.startswith("vgg")

        self.decoder = UnetDecoderAtt(
            encoder_channels=self.encoder.out_channels,
            decoder_channels=decoder_channels,
            decoder_layers_configs=decoder_layers_configs,
            n_blocks=encoder_depth,
            use_norm=decoder_use_norm,
            add_center_block=add_center_block,
            interpolation_mode=decoder_interpolation,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels[-1],
            out_channels=classes,
            activation=activation,
            kernel_size=3,
        )

        if aux_params is not None:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )
        else:
            self.classification_head = None

        self.name = "u-{}".format(encoder_name)
        self.initialize()

    def forward(self, x):
        """Sequentially pass `x` trough model`s encoder, decoder and heads"""

        features = self.encoder(x)
        decoder_output = self.decoder(features)

        masks = self.segmentation_head(decoder_output)

        if self.classification_head is not None:
            labels = self.classification_head(features[-1])
            return masks, labels

        return masks

class UnetDecoderBlockAtt(nn.Module):
    """A decoder block in the U-Net architecture that performs upsampling and feature fusion."""

    def __init__(
        self,
        config
    ):
        super().__init__()
        self.interpolation_mode = config['interpolation_mode']
        in_channels = config['in_channels']
        out_channels = config['out_channels']
        skip_channels = config['skip_channels']

        agg_type = config['aggregation_layer']['layer']
        att1_type = config['attention1']['layer']
        
        # настраиваем cross-attention
        if agg_type == 'conv_cross_att':
            if skip_channels != 0:
                config['aggregation_layer']['params']['q_in_channels'] = in_channels
                config['aggregation_layer']['params']['k_in_channels'] = skip_channels
                config['aggregation_layer']['params']['v_in_channels'] = skip_channels
                config['aggregation_layer']['params']['q_out_channels'] = in_channels
                config['aggregation_layer']['params']['k_out_channels'] = skip_channels
                config['aggregation_layer']['params']['v_out_channels'] = skip_channels

        # настраиваем multihead self-attention, в зависимости от параметров слоя объединения
        if att1_type == 'conv_msa':
            if agg_type == 'conv_cross_att':
                # надо добавить выбор q=features OR q=skip
                if skip_channels != 0:
                    att_channels = in_channels
                else:
                    att_channels = in_channels
                config['attention1']['params']['msa_in_channels'] = att_channels
                config['attention1']['params']['msa_intermediate_channels'] = att_channels
                config['attention1']['params']['msa_out_channels'] = att_channels
                config['attention1']['params']['out_conv_out_channels'] = att_channels
            else:
                config['attention1']['params']['msa_in_channels'] = in_channels + skip_channels
                config['attention1']['params']['msa_intermediate_channels'] = in_channels + skip_channels
                config['attention1']['params']['msa_out_channels'] = in_channels + skip_channels
                config['attention1']['params']['out_conv_out_channels'] = in_channels + skip_channels

        elif att1_type == 'win_msa':
            if 'cross_att' in agg_type:
                config['attention1']['params']['channels'] = in_channels
            else:
                config['attention1']['params']['channels'] = in_channels + skip_channels
            #pos_enc_type = config['attention1']['params']['positional_encoding_block']
            #config['attention1']['params']['positional_encoding_block'] = pos_enc_factory_dict[pos_enc_type]
            
        # получаем метод создания слоя агрегации признаков
        create_aggregation = unet_aggregation_factory_dict[agg_type]
        if skip_channels != 0:
            self.aggregation_layer = create_aggregation(**config['aggregation_layer']['params'])
        else:
            self.aggregation_layer = nn.Identity()
        
        # получаем метод создания слоя внимания после агрегации
        create_attention = unet_attention_factory_dict[att1_type]
        self.attention1 = create_attention(**config['attention1']['params'])
        # создаем сверточные слои после слоя внимания
        conv_layers = []
        for idx, params in enumerate(config['conv']):
            if idx == 0:
                if agg_type == 'conv_cross_att' and skip_channels != 0:
                    in_conv_ch = in_channels
                else:
                    in_conv_ch = in_channels + skip_channels
            else:
                in_conv_ch = out_channels

            conv = torchvision.ops.Conv2dNormActivation(
                in_channels=in_conv_ch,
                out_channels=out_channels,
                **params
                )
            conv_layers.append(conv)
        self.conv_layers = nn.Sequential(*conv_layers)

        att2_type = config['attention2']['layer']
        if att2_type == 'conv_msa':
            config['attention2']['params']['msa_in_channels'] = out_channels
            config['attention2']['params']['msa_intermediate_channels'] = out_channels
            config['attention2']['params']['msa_out_channels'] = out_channels
            config['attention2']['params']['out_conv_out_channels'] = out_channels
        elif att2_type == 'win_msa':
            config['attention2']['params']['channels'] = out_channels
            #pos_enc_type = config['attention2']['params']['positional_encoding_block']
            #config['attention2']['params']['positional_encoding_block'] = pos_enc_factory_dict[pos_enc_type]
            
        create_attention = unet_attention_factory_dict[att2_type]
        self.attention2 = create_attention(**config['attention2']['params'])
        

    def forward(
        self,
        feature_map: torch.Tensor,
        target_height: int,
        target_width: int,
        skip_connection: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        feature_map = F.interpolate(
            feature_map,
            size=(target_height, target_width),
            mode=self.interpolation_mode,
        )
        #print('DECODER LAYER!!!')
        if skip_connection is not None:
            #print(f'feat:{feature_map.shape},skip:{skip_connection.shape}')
            feature_map = self.aggregation_layer(feature_map, skip_connection)
            feature_map = self.attention1(feature_map)
        #print(f'att_feat:{feature_map.shape}')
        feature_map = self.conv_layers(feature_map)
        feature_map = self.attention2(feature_map)
        return feature_map

class UnetDecoderAtt(nn.Module):
    """The decoder part of the U-Net architecture.

    Takes encoded features from different stages of the encoder and progressively upsamples them while
    combining with skip connections. This helps preserve fine-grained details in the final segmentation.
    """

    def __init__(
        self,
        encoder_channels: Sequence[int],
        decoder_channels: Sequence[int],
        decoder_layers_configs: Sequence[Dict],
        n_blocks: int = 5,
        add_center_block: bool = False,
        interpolation_mode: str = "nearest",
        use_norm:str = "batchnorm",
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )
        
        if decoder_layers_configs is not None and (n_blocks != len(decoder_layers_configs)):
            raise ValueError(
                "Model depth is {}, but you provide `attention_configs` for {} blocks.".format(
                    n_blocks, len(decoder_layers_configs)
                )
            )

        # remove first skip with same spatial resolution
        encoder_channels = encoder_channels[1:]
        # reverse channels to start from head of encoder
        encoder_channels = encoder_channels[::-1]

        # computing blocks input and output channels
        head_channels = encoder_channels[0]
        in_channels = [head_channels] + list(decoder_channels[:-1])
        skip_channels = list(encoder_channels[1:]) + [0]
        out_channels = decoder_channels

        if add_center_block:
            self.center = smp.decoders.unet.decoder.UnetCenterBlock(
                head_channels,
                head_channels,
                use_norm=use_norm,
            )
        else:
            self.center = nn.Identity()

        # combine decoder keyword arguments
        self.blocks = nn.ModuleList()
        for block_in_channels, block_skip_channels, block_out_channels, decoder_layer_config in zip(
            in_channels, skip_channels, out_channels, decoder_layers_configs
        ):
            #print(f'in:{block_in_channels}, skip:{block_skip_channels}, out:{block_out_channels}')
            #print('-------------')
            decoder_layer_config['in_channels'] = block_in_channels
            decoder_layer_config['skip_channels'] = block_skip_channels
            decoder_layer_config['out_channels'] = block_out_channels
            block = UnetDecoderBlockAtt(
                decoder_layer_config
            )
            self.blocks.append(block)

    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        # spatial shapes of features: [hw, hw/2, hw/4, hw/8, ...]
        spatial_shapes = [feature.shape[2:] for feature in features]
        spatial_shapes = spatial_shapes[::-1]

        features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder

        head = features[0]
        skip_connections = features[1:]

        x = self.center(head)

        for i, decoder_block in enumerate(self.blocks):
            # upsample to the next spatial shape
            height, width = spatial_shapes[i + 1]
            skip_connection = skip_connections[i] if i < len(skip_connections) else None
            x = decoder_block(x, height, width, skip_connection=skip_connection)

        return x

# Адаптация FCN

In [44]:
class FCNDecoderBlock(nn.Module):
    """A decoder block in the FCN architecture that performs upsampling and feature fusion."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
        attention_type: Optional[str] = None,
        interpolation_mode: str = "nearest",
    ):
        super().__init__()
        self.interpolation_mode = interpolation_mode
        self.conv1 = md.Conv2dReLU(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_norm=use_norm,
        )
        self.attention1 = md.Attention(
            attention_type, in_channels=in_channels
        )
        self.conv2 = md.Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_norm=use_norm,
        )
        self.attention2 = md.Attention(attention_type, in_channels=out_channels)

    def forward(
        self,
        feature_map: torch.Tensor,
        target_height: int,
        target_width: int,
        skip_connection: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # сначала интерполяция и свертка
        feature_map = F.interpolate(
            feature_map,
            size=(target_height, target_width),
            mode=self.interpolation_mode,
        )
        feature_map = self.conv1(feature_map)
        feature_map = self.attention1(feature_map)
        
        # потом сложение и выходная свертка
        if skip_connection is not None:
            feature_map = feature_map + skip_connection
        feature_map = self.conv2(feature_map)
        feature_map = self.attention2(feature_map)
        
        return feature_map
    
class FCNDecoder(nn.Module):
    def __init__(
            self,
            encoder_channels: Sequence[int],
            decoder_last_channel: Sequence[int],
            n_blocks: int = 5,
            use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
            attention_type: Optional[str] = None,
            add_center_block: bool = False,
            interpolation_mode: str = "nearest",
        ):
            super().__init__()
            # remove first skip with same spatial resolution
            encoder_channels = encoder_channels[1:]
            # reverse channels to start from head of encoder
            encoder_channels = encoder_channels[::-1]

            # computing blocks input and output channels
            head_channels = encoder_channels[0]
            in_channels = encoder_channels
            out_channels = encoder_channels[1:] + [decoder_last_channel]
            
            if add_center_block:
                self.center = smp.decoders.unet.decoder.UnetCenterBlock(
                    head_channels,
                    head_channels//2,
                    use_norm=use_norm,
                )
            else:
                self.center = nn.Identity()

            # combine decoder keyword arguments
            self.blocks = nn.ModuleList()
            for block_in_channels, block_out_channels in zip(
                in_channels, out_channels
            ):
                block = FCNDecoderBlock(
                    block_in_channels,
                    block_out_channels,
                    use_norm=use_norm,
                    attention_type=attention_type,
                    interpolation_mode=interpolation_mode,
                )
                self.blocks.append(block)

    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        # spatial shapes of features: [hw, hw/2, hw/4, hw/8, ...]
        spatial_shapes = [feature.shape[2:] for feature in features]
        spatial_shapes = spatial_shapes[::-1]

        features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder
        

        head = features[0]
        skip_connections = features[1:]
        

        x = self.center(head)

        for i, decoder_block in enumerate(self.blocks):
            # upsample to the next spatial shape
            height, width = spatial_shapes[i + 1]
            
            skip_connection = skip_connections[i] if i < len(skip_connections) else None
            
            x = decoder_block(x, height, width, skip_connection=skip_connection)

        return x

class FCN(SegmentationModel):
    """
    FCN is a fully convolutional neural network architecture designed for semantic image segmentation.

    It consists of two main parts:

    1. An encoder (downsampling path) that extracts increasingly abstract features
    2. A decoder (upsampling path) that gradually recovers spatial details

    The key is the use of skip connections between corresponding encoder and decoder layers.
    These connections allow the decoder to access fine-grained details from earlier encoder layers,
    which helps produce more precise segmentation masks.

    The skip connections work by concatenating feature maps from the encoder directly into the decoder
    at corresponding resolutions. This helps preserve important spatial information that would
    otherwise be lost during the encoding process.

    Args:
        encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
            to extract features of different spatial resolution
        encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features
            two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features
            with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
            Default is 5
        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
            other pretrained weights (see table with available weights for each encoder_name)
        decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder.
            Length of the list should be the same as **encoder_depth**
        decoder_use_norm:     Specifies normalization between Conv2D and activation.
            Accepts the following types:
            - **True**: Defaults to `"batchnorm"`.
            - **False**: No normalization (`nn.Identity`).
            - **str**: Specifies normalization type using default parameters. Available values:
              `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`.
            - **dict**: Fully customizable normalization settings. Structure:
              ```python
              {"type": <norm_type>, **kwargs}
              ```
              where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation.

            **Example**:
            ```python
            decoder_use_norm={"type": "layernorm", "eps": 1e-2}
            ```
        decoder_attention_type: Attention module used in decoder of the model. Available options are
            **None** and **scse** (https://arxiv.org/abs/1808.08127).
        decoder_interpolation: Interpolation mode used in decoder of the model. Available options are
            **"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"nearest"**.
        in_channels: A number of input channels for the model, default is 3 (RGB images)
        classes: A number of classes for output mask (or you can think as a number of channels of output mask)
        activation: An activation function to apply after the final convolution layer.
            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
            **callable** and **None**. Default is **None**.
        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
            on top of encoder if **aux_params** is not **None** (default). Supported params:
                - classes (int): A number of classes
                - pooling (str): One of "max", "avg". Default is "avg"
                - dropout (float): Dropout factor in [0, 1)
                - activation (str): An activation function to apply "sigmoid"/"softmax"
                    (could be **None** to return logits)
        kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.

    Returns:
        ``torch.nn.Module``: Unet

    Example:
        .. code-block:: python

            import torch
            import segmentation_models_pytorch as smp

            model = smp.Unet("resnet18", encoder_weights="imagenet", classes=5)
            model.eval()

            # generate random images
            images = torch.rand(2, 3, 256, 256)

            with torch.inference_mode():
                mask = model(images)

            print(mask.shape)
            # torch.Size([2, 5, 256, 256])

    .. _Unet:
        https://arxiv.org/abs/1505.04597

    """

    requires_divisible_input_shape = False

    @supports_config_loading
    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
        decoder_last_channel: int = 16,
        decoder_attention_type: Optional[str] = None,
        decoder_interpolation: str = "nearest",
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, Callable]] = None,
        aux_params: Optional[dict] = None,
        **kwargs: dict[str, Any],
    ):
        super().__init__()

        decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None)
        if decoder_use_batchnorm is not None:
            warnings.warn(
                "The usage of decoder_use_batchnorm is deprecated. Please modify your code for decoder_use_norm",
                DeprecationWarning,
                stacklevel=2,
            )
            decoder_use_norm = decoder_use_batchnorm

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
            **kwargs,
        )

        add_center_block = encoder_name.startswith("vgg")

        self.decoder = FCNDecoder(
            encoder_channels=self.encoder.out_channels,
            decoder_last_channel=decoder_last_channel,
            n_blocks=encoder_depth,
            use_norm=decoder_use_norm,
            add_center_block=add_center_block,
            attention_type=decoder_attention_type,
            interpolation_mode=decoder_interpolation,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=decoder_last_channel,
            out_channels=classes,
            activation=activation,
            kernel_size=3,
        )

        if aux_params is not None:
            self.classification_head = ClassificationHead(
                in_channels=decoder_last_channel, **aux_params
            )
        else:
            self.classification_head = None

        self.name = "fcn-{}".format(encoder_name)
        self.initialize()

#model = FCN(decoder_last_channel=32)
#ret = model(torch.randn(1, 3, 224, 224))
#ret.shape

# Фабрики для создания моделей по конфигурациям

In [None]:
def relace_input_layer(model:nn.Module, config:dict):
    pass

def create_weights_from_avg_ch(weight, new_in_channels):
    return torch.cat([weight.mean(dim=1).unsqueeze(1)]*new_in_channels, dim=1)

def cerate_weights_from_repeated_ch(weight, in_channels, new_in_channels):
    ch_multiple = new_in_channels//in_channels
    reminded_channels = new_in_channels%in_channels
    # сначала набираем новые каналы путем подставления друг за другом (stack) каналов изначального изображения,
    # а затем, если количество новых каналов не делится без остатка на количество изначальных, 
    # то набираем оставшиеся новые каналы из оставшихся изначальных    
    new_weight = torch.cat(
        [weight]*ch_multiple + [weight[:,:reminded_channels]], dim=1)
    return new_weight

def create_augmentation_transforms(transforms_dict:Dict[str, Dict]):
    transforms_list = []
    for name, transform_params in transforms_dict.items():
        transform_creation_fn = transforms_factory_dict[name]
        transforms_list.append(transform_creation_fn(**transform_params))
    #return v2.Compose([v2.RandomOrder(transforms_list)])
    return v2.RandomOrder(transforms_list)

def create_model(config_dict, segmentation_nns_factory_dict):
    model_name = config_dict['segmentation_nn']['nn_architecture']
    if 'fpn' in model_name:
        stride = config_dict['segmentation_nn']['input_layer_config']['params']['stride']
        if isinstance(stride, (list, tuple)):
            stride_val = stride[0]
        elif isinstance(stride, (list, tuple)):
            stride_val = stride
        #if stride_val != 1:
        config_dict['segmentation_nn']['params']['upsampling'] = stride_val
    # создаем нейронную сеть из фабрики
    model = segmentation_nns_factory_dict[model_name](**config_dict['segmentation_nn']['params'])
    multispecter_bands_indices = config_dict['multispecter_bands_indices']
    in_channels = len(multispecter_bands_indices)
    # замена входного слоя, если кол-во каналов изображения не равно трем
    input_conv = model.get_submodule(
        config_dict['segmentation_nn']['input_layer_config']['layer_path']
        )
    if 'channels' in config_dict['segmentation_nn']['input_layer_config']['replace_type']:
        new_input_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=input_conv.out_channels,
            kernel_size=input_conv.kernel_size,
            #stride=conv1.stride,
            stride=config_dict['segmentation_nn']['input_layer_config']['params']['stride'],
            #padding=conv1.padding,
            padding=config_dict['segmentation_nn']['input_layer_config']['params']['padding'],
            dilation=input_conv.dilation,
            groups=input_conv.groups,
            bias=input_conv.bias is not None
        )
        if in_channels != 3:
            # получаем входной слой, специфический для конкретной нейронной сети
            
            
            if config_dict['segmentation_nn']['params']['encoder_weights'] is not None:
                # выбор типа обнолвления весов
                if config_dict['segmentation_nn']['input_layer_config']['weight_update_type'] == 'average_all':
                    
                    #new_weight = torch.cat([input_conv.weight.mean(dim=1).unsqueeze(1)]*in_channels, dim=1)
                    new_weight = create_weights_from_avg_ch(input_conv.weight, in_channels)
                    input_conv.weight = nn.Parameter(new_weight)

                elif config_dict['segmentation_nn']['input_layer_config']['weight_update_type'] == 'repeate':
                    '''
                    ch_multiple = in_channels//input_conv.in_channels
                    reminded_channels = in_channels%input_conv.in_channels
                    new_weight = torch.cat(
                        [input_conv.weight]*ch_multiple + [input_conv.weight[:,:reminded_channels]], dim=1)
                    '''
                    new_weight = cerate_weights_from_repeated_ch(input_conv.weight, input_conv.in_channels, in_channels)
                    new_input_conv.weight = nn.Parameter(new_weight)
        else:
            # если у нас три канала на входе, то просто перезаписываем вес
            new_input_conv.weight = nn.Parameter(input_conv.weight)

        if input_conv.bias is not None:
            new_input_conv.bias = input_conv.bias
        # перезаписываем входной слой исходя из специфики оригинальной сети
        model.set_submodule(
                config_dict['segmentation_nn']['input_layer_config']['layer_path'],
                new_input_conv
                )

    elif 'multisize_conv' in config_dict['segmentation_nn']['input_layer_config']['replace_type']:
        multisize_params = config_dict['segmentation_nn']['input_layer_config']['params']
        new_input_conv = MultisizeConv(**multisize_params)

        # Если мы модифицируем входной слой.
        if config_dict['segmentation_nn']['params']['encoder_weights'] is not None:
            # вычленяем словрь с параметрами размеров ядер сверток.
            kernel_sizes_dict = config_dict['segmentation_nn']['input_layer_config']['params']['kernel_size']
            interpolated_kernels_dict = {}
            # выполняем интерполяцию ядер свертки для каждого набора из новых ядер
            for name, kernel_size in kernel_sizes_dict.items():
                if isinstance(kernel_size, int):
                    kernel_size = (kernel_size, kernel_size)
                # получаем интерполированную версию ядер свертки
                interpolated_kernels_dict[name] = [
                    F.interpolate(input_conv.weight, size=kernel_size, mode='bicubic', antialias=True),
                    input_conv.bias]
                '''            
                out_channels_dict = config_dict['segmentation_nn']['input_layer_config']['params']['out_channels']
                for name, out_channels in out_channels_dict.items():
                    weights = interpolated_kernels_dict[name][0]
                    weights = create_weights_from_avg_ch(weights, in_channels)
                    interpolated_kernels_dict[name][0] = weights
                '''
            #out_channels_dict = config_dict['segmentation_nn']['input_layer_config']['params']['out_channels']
            for name in interpolated_kernels_dict.keys():
                weights = interpolated_kernels_dict[name][0]
                if config_dict['segmentation_nn']['input_layer_config']['weight_update_type'] == 'average_all':
                    weights = create_weights_from_avg_ch(weights, new_in_channels=in_channels)
                elif config_dict['segmentation_nn']['input_layer_config']['weight_update_type'] == 'repeat':
                    weights = cerate_weights_from_repeated_ch(weights, in_channels=input_conv.in_channels, new_in_channels=in_channels)
                
                interpolated_kernels_dict[name][0] = weights
                        
            new_input_conv.update_weights(new_weights_dict=interpolated_kernels_dict)
        if config_dict['segmentation_nn']['input_layer_config']['params']['aggregation_type'] == 'cat':
            # Если тип агрегации выхода MultisizeConv - это конкатенация, то изменяем также второй сверточный слой,
            # чтобы число его входных каналов соответствовало числу выходных первого слоя 
            raise NotImplementedError
        # заменяем сходной слой по заранее определенному пути, который может варьировать в зависимости от архитектуры энкодера
        model.set_submodule(
                config_dict['segmentation_nn']['input_layer_config']['layer_path'],
                new_input_conv
                )
    return model

class MultisizeConv(nn.Module):
    def __init__(
            self,
            in_channels:int,
            out_channels:dict,
            kernel_size:dict,
            stride:dict,
            padding:dict,
            dilation:dict,
            groups:dict,
            bias:dict,
            aggregation_type:str,
            ):
        super().__init__()
        
        self.aggregation_type = aggregation_type
        self.multisize_convs = nn.ModuleDict()
        for conv_name in kernel_size.keys():
            self.multisize_convs[conv_name] = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels[conv_name],
                kernel_size=kernel_size[conv_name],
                stride=stride[conv_name],
                padding=padding[conv_name],
                dilation=dilation[conv_name],
                groups=groups[conv_name],
                bias=bias[conv_name]
                )
        
    def update_weights(self, new_weights_dict):
        '''
        На вход принимается словрь со структурой {'имя_свертки': (weight, bias)}
        '''
        for conv_name, (weight, bias) in new_weights_dict.items():
            self.multisize_convs[conv_name].weight = nn.Parameter(weight)
            if self.multisize_convs[conv_name].bias is not None:
                self.multisize_convs[conv_name].bias = nn.Parameter(bias)

    def forward(self, x):
        outputs = []
        for conv_name in self.multisize_convs.keys():
            out = self.multisize_convs[conv_name](x)
            #print(out.shape)
            outputs.append(out)
        
        if self.aggregation_type == 'add':
            outputs = torch.stack(outputs, dim=0)
            outputs = outputs.sum(dim=0)
        elif self.aggregation_type == 'cat':
            outputs = torch.cat(outputs, dim=1)
        else:
            raise ValueError(f'self.aggregation_type should be either "add" or "cat". Got {self.aggregation_type}')
        return outputs

segmentation_nns_factory_dict = {
    'unet': smp.Unet,
    'att_unet': UnetAtt,
    'fpn': smp.FPN,
    'custom_fpn': FPNMod,
    'unet++': UnetPlusPlusMod,
    'fcn': FCN,
    'custom_manet': MAnetMod,
}

unet_aggregation_factory_dict = {
    'concat': ConcatDim1,
    'conv_cross_att': ConvCrossAttentionBlock,
    'cross-att': VisionTransformerBlock,
    
}

pos_enc_factory_dict = {
    'fixed_embeddings': FixedSizeLearnableEmbeddings,
    'embedding_layer': EmbeddingLayer,
    'none': nn.Identity
}

unet_attention_factory_dict = {
    'conv_msa': ConvMSABlock,
    'win_msa': LazyWindowVisionTransformerBlocks,
    'none': nn.Identity,
}

criterion_factory_dict = {
    'crossentropy': nn.CrossEntropyLoss,
    'dice_crossentropy': DiceCELoss,
    'dice': smp.losses.DiceLoss
}

optimizers_factory_dict = {
    'adam': torch.optim.Adam,
    'adamw': torch.optim.AdamW
}

lr_schedulers_factory_dict = {
    'cosine_warm_restarts': torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
    'plateau': torch.optim.lr_scheduler.ReduceLROnPlateau,
    'cosine': torch.optim.lr_scheduler.CosineAnnealingLR,
}

transforms_factory_dict = {
    'affine': v2.RandomAffine,
    'perspective': v2.RandomPerspective,
    'horizontal_flip': v2.RandomHorizontalFlip,
    'vertical_flip': v2.RandomVerticalFlip,
    'crop': v2.RandomCrop,
    'gauss_noise': v2.GaussianNoise,
    'gauss_blur': v2.GaussianBlur,
    'elastic': v2.ElasticTransform,
}

# Конфигурации

In [71]:
unet_config_dict = {
    'name_postfix': '',
    'segmentation_nn': {
        'nn_architecture': 'unet',
        'params': {
            'encoder_name': "efficientnet-b0",
            'encoder_depth': 5,
            'encoder_weights': "imagenet",
            'decoder_use_norm': "batchnorm",
            'decoder_channels': (256, 128, 128, 128, 128),
            'decoder_attention_type': None,
            'decoder_interpolation': "nearest",
            'in_channels': 3,
            'classes': 11,
            'activation': None,
            'aux_params': None,
        },
        'input_layer_config': {
            'layer_path': 'encoder._conv_stem',
            'replace_type': 'channels+stride',
            'weight_update_type': 'repeate',
            'params':{
                'stride': (1, 1),
                'padding': (1, 1),
                }
        }
    },
    'multispecter_bands_indices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    'input_image_size': 96,
    'loss': {
        'type': 'crossentropy',
        #'params': {'weight': 'classes'}
        'params': {
            'weight': None,
            'ignore_index': -100,
            'reduction': "mean",
            'label_smoothing': 0.15,
            },
    },
    'optimizer': {
        'type': 'adam',
        'args': {}
    },
    'lr_scheduler': {
        'type': 'cosine_warm_restarts',
        'args': {
            'T_0': 25,
            'T_mult': 1,
            'eta_min': 0,
            'last_epoch': -1
        },
        'params':{
            'interval': 'epoch',
            'frequency': 1,
            'monitor': 'val_loss',
            'strict': True,
            'name': None,
        },
    },
    'device': 'cuda:0',
    'batch_size': 16,
    'epoch_num':300,
    'train_augmentations': {
        'affine':{
            'degrees': [0, 45],
            'translate': [0, 0.3],
            'scale': [0.7, 1.5],
            'shear': [0, 0.2],
            'fill': [0],
        },
        'horizontal_flip': {
            'p': 0.5,
        },
        'vertical_flip':{
            'p': 0.5,
        },
    },
    'path_to_dataset_root': r'C:\Users\admin\python_programming\DATA\DATA_FOR_TRAINIG_96',
    #'path_to_dataset_root': r'C:\Users\mokhail\develop\DATA\DATA_FOR_TRAINIG_96'
}

win_att1l_cat_agg_2_conv_unet_config_dict = {
    'name_postfix': '1L_win_cat_agg',
    'segmentation_nn': {
        'nn_architecture': 'att_unet',
        'params': {
            'decoder_layers_configs': [
                {
                    'interpolation_mode': "nearest",
                    'aggregation_layer':{
                        'layer': 'concat',
                        'params': {},
                    },
                    'attention1': {
                        'layer':'none',
                        'params':{},
                    },
                    'conv': [
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                    ],
                    'attention2': {
                        'layer':'win_msa',
                        'params':{
                            'rows_in_win':6,
                            'cols_in_win':6,
                            'layer_num':1,
                            'num_heads':8,
                            'transformer_mlp_dim':8,
                            'attention_dropout':0.2,
                            'dropout':0.2,
                            'positional_encoding_block':'embedding_layer',
                            'positional_encoding_block_params':{},
                            'transformer_block':VitEncoderBlock,
                        },
                    },
                },
                {
                    'interpolation_mode': "nearest",
                    'aggregation_layer':{
                        'layer': 'concat',
                        'params': {},
                    },
                    'attention1': {
                        'layer':'none',
                        'params':{},
                    },
                    'conv': [
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                    ],
                    'attention2': {
                        'layer':'win_msa',
                        'params':{
                            'rows_in_win':12,
                            'cols_in_win':12,
                            'layer_num':1,
                            'num_heads':8,
                            'transformer_mlp_dim':8,
                            'attention_dropout':0.2,
                            'dropout':0.2,
                            'positional_encoding_block':'embedding_layer',
                            'positional_encoding_block_params':{},
                            'transformer_block':VitEncoderBlock,
                        },
                    },
                },
                {
                    'interpolation_mode': "nearest",
                    'aggregation_layer':{
                        'layer': 'concat',
                        'params': {},
                    },
                    'attention1': {
                        'layer':'none',
                        'params':{},
                    },
                    'conv': [
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                    ],
                    'attention2': {
                        'layer':'win_msa',
                        'params':{
                            'rows_in_win':24,
                            'cols_in_win':24,
                            'layer_num':1,
                            'num_heads':8,
                            'transformer_mlp_dim':8,
                            'attention_dropout':0.2,
                            'dropout':0.2,
                            'positional_encoding_block':'embedding_layer',
                            'positional_encoding_block_params':{},
                            'transformer_block':VitEncoderBlock,
                        },
                    },
                },
                {
                    'interpolation_mode': "nearest",
                    'aggregation_layer':{
                        'layer': 'concat',
                        'params': {},
                    },
                    'attention1': {
                        'layer':'none',
                        'params':{},
                    },
                    'conv': [
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                    ],
                    'attention2': {
                        'layer':'win_msa',
                        'params':{
                            'rows_in_win':24,
                            'cols_in_win':24,
                            'layer_num':1,
                            'num_heads':8,
                            'transformer_mlp_dim':8,
                            'attention_dropout':0.2,
                            'dropout':0.2,
                            'positional_encoding_block':'embedding_layer',
                            'positional_encoding_block_params':{},
                            'transformer_block':VitEncoderBlock,
                        },
                    },
                },
                {
                    'interpolation_mode': "nearest",
                    'aggregation_layer':{
                        'layer': 'concat',
                        'params': {},
                    },
                    'attention1': {
                        'layer':'none',
                        'params':{},
                    },
                    'conv': [
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                    ],
                    'attention2': {
                        'layer':'win_msa',
                        'params':{
                            'rows_in_win':24,
                            'cols_in_win':24,
                            'layer_num':1,
                            'num_heads':8,
                            'transformer_mlp_dim':8,
                            'attention_dropout':0.2,
                            'dropout':0.2,
                            'positional_encoding_block':'embedding_layer',
                            'positional_encoding_block_params':{},
                            'transformer_block':VitEncoderBlock,
                        },
                    },
                },
                
            ],
            'encoder_name': "efficientnet-b0",
            'encoder_depth': 5,
            'encoder_weights': "imagenet",
            'decoder_use_norm': "batchnorm",
            'decoder_channels': (256, 128, 128, 64, 64),
            'decoder_interpolation': "nearest",
            'in_channels': 3,
            'classes': 11,
            'activation': None,
            'aux_params': None,
        },
        'input_layer_config': {
            'layer_path': 'encoder._conv_stem',
            'replace_type': 'channels+stride',
            'weight_update_type': 'repeate',
            'params':{
                'stride': (1, 1),
                'padding': (1, 1),
                }
        }
    },
    'multispecter_bands_indices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    'input_image_size': 96,
    'loss': {
        'type': 'crossentropy',
        #'params': {'weight': 'classes'}
        'params': {
            'weight': None,
            'ignore_index': -100,
            'reduction': "mean",
            'label_smoothing': 0.15,
            },
    },
    'optimizer': {
        'type': 'adam',
        'args': {}
    },
    'lr_scheduler': {
        'type': 'cosine_warm_restarts',
        'args': {
            'T_0': 25,
            'T_mult': 1,
            'eta_min': 0,
            'last_epoch': -1
        },
        'params':{
            'interval': 'epoch',
            'frequency': 1,
            'monitor': 'val_loss',
            'strict': True,
            'name': None,
        },
    },
    'device': 'cuda:0',
    'batch_size': 4,
    'epoch_num':4,
    'train_augmentations': {
        'affine':{
            'degrees': [0, 45],
            'translate': [0, 0.3],
            'scale': [0.7, 1.5],
            'shear': [0, 0.2],
            'fill': [0],
        },
        'horizontal_flip': {
            'p': 0.5,
        },
        'vertical_flip':{
            'p': 0.5,
        },
    },
    #'path_to_dataset_root': r'C:\Users\admin\python_programming\DATA\DATA_FOR_TRAINIG_96',
    'path_to_dataset_root': r'C:\Users\mokhail\develop\DATA\DATA_FOR_TRAINIG_96'    
}

win_att_cat_agg_2_conv_unet_config_dict = {
    'name_postfix': 'win_cat_agg',
    'segmentation_nn': {
        'nn_architecture': 'att_unet',
        'params': {
            'decoder_layers_configs': [
                {
                    'interpolation_mode': "nearest",
                    'aggregation_layer':{
                        'layer': 'concat',
                        'params': {},
                    },
                    'attention1': {
                        'layer':'win_msa',
                        'params':{
                            'rows_in_win':6,
                            'cols_in_win':6,
                            'layer_num':1,
                            'num_heads':8,
                            'transformer_mlp_dim':8,
                            'attention_dropout':0.2,
                            'dropout':0.2,
                            'positional_encoding_block':'embedding_layer',
                            'positional_encoding_block_params':{},
                            'transformer_block':VitEncoderBlock,
                        },
                    },
                    'conv': [
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                    ],
                    'attention2': {
                        'layer':'win_msa',
                        'params':{
                            'rows_in_win':6,
                            'cols_in_win':6,
                            'layer_num':1,
                            'num_heads':8,
                            'transformer_mlp_dim':8,
                            'attention_dropout':0.2,
                            'dropout':0.2,
                            'positional_encoding_block':'embedding_layer',
                            'positional_encoding_block_params':{},
                            'transformer_block':VitEncoderBlock,
                        },
                    },
                },
                {
                    'interpolation_mode': "nearest",
                    'aggregation_layer':{
                        'layer': 'concat',
                        'params': {},
                    },
                    'attention1': {
                        'layer':'win_msa',
                        'params':{
                            'rows_in_win':12,
                            'cols_in_win':12,
                            'layer_num':1,
                            'num_heads':8,
                            'transformer_mlp_dim':8,
                            'attention_dropout':0.2,
                            'dropout':0.2,
                            'positional_encoding_block':'embedding_layer',
                            'positional_encoding_block_params':{},
                            'transformer_block':VitEncoderBlock,
                        },
                    },
                    'conv': [
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                    ],
                    'attention2': {
                        'layer':'win_msa',
                        'params':{
                            'rows_in_win':12,
                            'cols_in_win':12,
                            'layer_num':1,
                            'num_heads':8,
                            'transformer_mlp_dim':8,
                            'attention_dropout':0.2,
                            'dropout':0.2,
                            'positional_encoding_block':'embedding_layer',
                            'positional_encoding_block_params':{},
                            'transformer_block':VitEncoderBlock,
                        },
                    },
                },
                {
                    'interpolation_mode': "nearest",
                    'aggregation_layer':{
                        'layer': 'concat',
                        'params': {},
                    },
                    'attention1': {
                        'layer':'win_msa',
                        'params':{
                            'rows_in_win':24,
                            'cols_in_win':24,
                            'layer_num':1,
                            'num_heads':8,
                            'transformer_mlp_dim':8,
                            'attention_dropout':0.2,
                            'dropout':0.2,
                            'positional_encoding_block':'embedding_layer',
                            'positional_encoding_block_params':{},
                            'transformer_block':VitEncoderBlock,
                        },
                    },
                    'conv': [
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                    ],
                    'attention2': {
                        'layer':'win_msa',
                        'params':{
                            'rows_in_win':24,
                            'cols_in_win':24,
                            'layer_num':1,
                            'num_heads':8,
                            'transformer_mlp_dim':8,
                            'attention_dropout':0.2,
                            'dropout':0.2,
                            'positional_encoding_block':'embedding_layer',
                            'positional_encoding_block_params':{},
                            'transformer_block':VitEncoderBlock,
                        },
                    },
                },
                {
                    'interpolation_mode': "nearest",
                    'aggregation_layer':{
                        'layer': 'concat',
                        'params': {},
                    },
                    'attention1': {
                        'layer':'win_msa',
                        'params':{
                            'rows_in_win':24,
                            'cols_in_win':24,
                            'layer_num':1,
                            'num_heads':8,
                            'transformer_mlp_dim':8,
                            'attention_dropout':0.2,
                            'dropout':0.2,
                            'positional_encoding_block':'embedding_layer',
                            'positional_encoding_block_params':{},
                            'transformer_block':VitEncoderBlock,
                        },
                    },
                    'conv': [
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                    ],
                    'attention2': {
                        'layer':'win_msa',
                        'params':{
                            'rows_in_win':24,
                            'cols_in_win':24,
                            'layer_num':1,
                            'num_heads':8,
                            'transformer_mlp_dim':8,
                            'attention_dropout':0.2,
                            'dropout':0.2,
                            'positional_encoding_block':'embedding_layer',
                            'positional_encoding_block_params':{},
                            'transformer_block':VitEncoderBlock,
                        },
                    },
                },
                {
                    'interpolation_mode': "nearest",
                    'aggregation_layer':{
                        'layer': 'concat',
                        'params': {},
                    },
                    'attention1': {
                        'layer':'win_msa',
                        'params':{
                            'rows_in_win':24,
                            'cols_in_win':24,
                            'layer_num':1,
                            'num_heads':8,
                            'transformer_mlp_dim':8,
                            'attention_dropout':0.2,
                            'dropout':0.2,
                            'positional_encoding_block':'embedding_layer',
                            'positional_encoding_block_params':{},
                            'transformer_block':VitEncoderBlock,
                        },
                    },
                    'conv': [
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                    ],
                    'attention2': {
                        'layer':'win_msa',
                        'params':{
                            'rows_in_win':24,
                            'cols_in_win':24,
                            'layer_num':1,
                            'num_heads':8,
                            'transformer_mlp_dim':8,
                            'attention_dropout':0.2,
                            'dropout':0.2,
                            'positional_encoding_block':'embedding_layer',
                            'positional_encoding_block_params':{},
                            'transformer_block':VitEncoderBlock,
                        },
                    },
                },
                
            ],
            'encoder_name': "efficientnet-b0",
            'encoder_depth': 5,
            'encoder_weights': "imagenet",
            'decoder_use_norm': "batchnorm",
            'decoder_channels': (256, 128, 128, 64, 64),
            'decoder_interpolation': "nearest",
            'in_channels': 3,
            'classes': 11,
            'activation': None,
            'aux_params': None,
        },
        'input_layer_config': {
            'layer_path': 'encoder._conv_stem',
            'replace_type': 'channels+stride',
            'weight_update_type': 'repeate',
            'params':{
                'stride': (1, 1),
                'padding': (1, 1),
                }
        }
    },
    'multispecter_bands_indices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    'input_image_size': 96,
    'loss': {
        'type': 'crossentropy',
        #'params': {'weight': 'classes'}
        'params': {
            'weight': None,
            'ignore_index': -100,
            'reduction': "mean",
            'label_smoothing': 0.15,
            },
    },
    'optimizer': {
        'type': 'adam',
        'args': {}
    },
    'lr_scheduler': {
        'type': 'cosine_warm_restarts',
        'args': {
            'T_0': 25,
            'T_mult': 1,
            'eta_min': 0,
            'last_epoch': -1
        },
        'params':{
            'interval': 'epoch',
            'frequency': 1,
            'monitor': 'val_loss',
            'strict': True,
            'name': None,
        },
    },
    'device': 'cuda:0',
    'batch_size': 4,
    'epoch_num':4,
    'train_augmentations': {
        'affine':{
            'degrees': [0, 45],
            'translate': [0, 0.3],
            'scale': [0.7, 1.5],
            'shear': [0, 0.2],
            'fill': [0],
        },
        'horizontal_flip': {
            'p': 0.5,
        },
        'vertical_flip':{
            'p': 0.5,
        },
    },
    #'path_to_dataset_root': r'C:\Users\admin\python_programming\DATA\DATA_FOR_TRAINIG_96',
    'path_to_dataset_root': r'C:\Users\mokhail\develop\DATA\DATA_FOR_TRAINIG_96'    
}

att_cat_agg_2_conv_unet_config_dict = {
    'name_postfix': 'cat_agg_2conv',
    'segmentation_nn': {
        'nn_architecture': 'att_unet',
        'params': {
            'decoder_layers_configs': [
                {
                    'interpolation_mode': "nearest",
                    'aggregation_layer':{
                        'layer': 'concat',
                        'params': {},
                    },
                    'attention1': {
                        'layer':'conv_msa',
                        'params':{
                            'msa_in_kernel_size':3,
                            'msa_in_padding':1,
                            'msa_in_stride':1,
                            'msa_head_row_dim':6,
                            'msa_head_col_dim':6,
                            'msa_head_ch_dim':4,
                            'dropout':0.2,
                            'out_conv_hidden_channels':512,
                            'out_conv_kernel_size':1,
                            'out_conv_padding':0,
                            'out_conv_stride':1,
                            'out_conv_act':nn.SiLU,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    },
                    'conv': [
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                    ],
                    'attention2': {
                        'layer':'conv_msa',
                        'params':{
                            'msa_in_kernel_size':3,
                            'msa_in_padding':1,
                            'msa_in_stride':1,
                            'msa_head_row_dim':6,
                            'msa_head_col_dim':6,
                            'msa_head_ch_dim':4,
                            'dropout':0.2,
                            'out_conv_hidden_channels':512,
                            'out_conv_kernel_size':1,
                            'out_conv_padding':0,
                            'out_conv_stride':1,
                            'out_conv_act':nn.SiLU,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    }
                },
                {
                    'interpolation_mode': "nearest",
                    'aggregation_layer':{
                        'layer': 'conv_cross_att',
                        'params': {
                            'in_kernel_size':3,
                            'in_padding':1,
                            'in_stride':1,
                            "head_row_dim":6,
                            'head_col_dim':6,
                            'head_ch_dim':4,
                            'dropout':0.2,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    },
                    'attention1': {
                        'layer':'conv_msa',
                        'params':{
                            'msa_in_kernel_size':3,
                            'msa_in_padding':1,
                            'msa_in_stride':1,
                            'msa_head_row_dim':6,
                            'msa_head_col_dim':6,
                            'msa_head_ch_dim':4,
                            'dropout':0.2,
                            'out_conv_hidden_channels':512,
                            'out_conv_kernel_size':1,
                            'out_conv_padding':0,
                            'out_conv_stride':1,
                            'out_conv_act':nn.SiLU,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    },
                    'conv': [
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                    ],
                    'attention2': {
                        'layer':'conv_msa',
                        'params':{
                            'msa_in_kernel_size':3,
                            'msa_in_padding':1,
                            'msa_in_stride':1,
                            'msa_head_row_dim':6,
                            'msa_head_col_dim':6,
                            'msa_head_ch_dim':4,
                            'dropout':0.2,
                            'out_conv_hidden_channels':512,
                            'out_conv_kernel_size':1,
                            'out_conv_padding':0,
                            'out_conv_stride':1,
                            'out_conv_act':nn.SiLU,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    }
                },
                {
                    'interpolation_mode': "nearest",
                    'aggregation_layer':{
                        'layer': 'conv_cross_att',
                        'params': {
                            'in_kernel_size':3,
                            'in_padding':1,
                            'in_stride':1,
                            "head_row_dim":6,
                            'head_col_dim':6,
                            'head_ch_dim':4,
                            'dropout':0.2,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    },
                    'attention1': {
                        'layer':'conv_msa',
                        'params':{
                            'msa_in_kernel_size':3,
                            'msa_in_padding':1,
                            'msa_in_stride':1,
                            'msa_head_row_dim':6,
                            'msa_head_col_dim':6,
                            'msa_head_ch_dim':4,
                            'dropout':0.2,
                            'out_conv_hidden_channels':512,
                            'out_conv_kernel_size':1,
                            'out_conv_padding':0,
                            'out_conv_stride':1,
                            'out_conv_act':nn.SiLU,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    },
                    'conv': [
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                    ],
                    'attention2': {
                        'layer':'conv_msa',
                        'params':{
                            'msa_in_kernel_size':3,
                            'msa_in_padding':1,
                            'msa_in_stride':1,
                            'msa_head_row_dim':6,
                            'msa_head_col_dim':6,
                            'msa_head_ch_dim':4,
                            'dropout':0.2,
                            'out_conv_hidden_channels':512,
                            'out_conv_kernel_size':1,
                            'out_conv_padding':0,
                            'out_conv_stride':1,
                            'out_conv_act':nn.SiLU,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    }
                },
                {
                    'interpolation_mode': "nearest",
                    'aggregation_layer':{
                        'layer': 'conv_cross_att',
                        'params': {
                            'in_kernel_size':3,
                            'in_padding':1,
                            'in_stride':1,
                            "head_row_dim":6,
                            'head_col_dim':6,
                            'head_ch_dim':4,
                            'dropout':0.2,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    },
                    'attention1': {
                        'layer':'conv_msa',
                        'params':{
                            'msa_in_kernel_size':3,
                            'msa_in_padding':1,
                            'msa_in_stride':1,
                            'msa_head_row_dim':6,
                            'msa_head_col_dim':6,
                            'msa_head_ch_dim':4,
                            'dropout':0.2,
                            'out_conv_hidden_channels':512,
                            'out_conv_kernel_size':1,
                            'out_conv_padding':0,
                            'out_conv_stride':1,
                            'out_conv_act':nn.SiLU,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    },
                    'conv': [
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                    ],
                    'attention2': {
                        'layer':'conv_msa',
                        'params':{
                            'msa_in_kernel_size':3,
                            'msa_in_padding':1,
                            'msa_in_stride':1,
                            'msa_head_row_dim':6,
                            'msa_head_col_dim':6,
                            'msa_head_ch_dim':4,
                            'dropout':0.2,
                            'out_conv_hidden_channels':512,
                            'out_conv_kernel_size':1,
                            'out_conv_padding':0,
                            'out_conv_stride':1,
                            'out_conv_act':nn.SiLU,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    }
                },
                {
                    'interpolation_mode': "nearest",
                    'aggregation_layer':{
                        'layer': 'conv_cross_att',
                        'params': {
                            'in_kernel_size':3,
                            'in_padding':1,
                            'in_stride':1,
                            "head_row_dim":6,
                            'head_col_dim':6,
                            'head_ch_dim':4,
                            'dropout':0.2,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    },
                    'attention1': {
                        'layer':'conv_msa',
                        'params':{
                            'msa_in_kernel_size':3,
                            'msa_in_padding':1,
                            'msa_in_stride':1,
                            'msa_head_row_dim':6,
                            'msa_head_col_dim':6,
                            'msa_head_ch_dim':4,
                            'dropout':0.2,
                            'out_conv_hidden_channels':512,
                            'out_conv_kernel_size':1,
                            'out_conv_padding':0,
                            'out_conv_stride':1,
                            'out_conv_act':nn.SiLU,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    },
                    'conv': [
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                    ],
                    'attention2': {
                        'layer':'conv_msa',
                        'params':{
                            'msa_in_kernel_size':3,
                            'msa_in_padding':1,
                            'msa_in_stride':1,
                            'msa_head_row_dim':6,
                            'msa_head_col_dim':6,
                            'msa_head_ch_dim':4,
                            'dropout':0.2,
                            'out_conv_hidden_channels':512,
                            'out_conv_kernel_size':1,
                            'out_conv_padding':0,
                            'out_conv_stride':1,
                            'out_conv_act':nn.SiLU,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    }
                },
            ],
            'encoder_name': "efficientnet-b0",
            'encoder_depth': 5,
            'encoder_weights': "imagenet",
            'decoder_use_norm': "batchnorm",
            'decoder_channels': (256, 128, 128, 64, 64),
            'decoder_interpolation': "nearest",
            'in_channels': 3,
            'classes': 11,
            'activation': None,
            'aux_params': None,
        },
        'input_layer_config': {
            'layer_path': 'encoder._conv_stem',
            'replace_type': 'channels+stride',
            'weight_update_type': 'repeate',
            'params':{
                'stride': (1, 1),
                'padding': (1, 1),
                }
        }
    },
    'multispecter_bands_indices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    'input_image_size': 96,
    'loss': {
        'type': 'crossentropy',
        #'params': {'weight': 'classes'}
        'params': {
            'weight': None,
            'ignore_index': -100,
            'reduction': "mean",
            'label_smoothing': 0.15,
            },
    },
    'optimizer': {
        'type': 'adam',
        'args': {}
    },
    'lr_scheduler': {
        'type': 'cosine_warm_restarts',
        'args': {
            'T_0': 25,
            'T_mult': 1,
            'eta_min': 0,
            'last_epoch': -1
        },
        'params':{
            'interval': 'epoch',
            'frequency': 1,
            'monitor': 'val_loss',
            'strict': True,
            'name': None,
        },
    },
    'device': 'cuda:0',
    'batch_size': 4,
    'epoch_num':4,
    'train_augmentations': {
        'affine':{
            'degrees': [0, 45],
            'translate': [0, 0.3],
            'scale': [0.7, 1.5],
            'shear': [0, 0.2],
            'fill': [0],
        },
        'horizontal_flip': {
            'p': 0.5,
        },
        'vertical_flip':{
            'p': 0.5,
        },
    },
    #'path_to_dataset_root': r'C:\Users\admin\python_programming\DATA\DATA_FOR_TRAINIG_96',
    'path_to_dataset_root': r'C:\Users\mokhail\develop\DATA\DATA_FOR_TRAINIG_96'    
}

att_cross_agg_2_conv_unet_config_dict = {
    'name_postfix': 'cross_agg_2conv',
    'segmentation_nn': {
        'nn_architecture': 'att_unet',
        'params': {
            'decoder_layers_configs': [
                {
                    'interpolation_mode': "nearest",
                    'aggregation_layer':{
                        'layer': 'conv_cross_att',
                        'params': {
                            'in_kernel_size':3,
                            'in_padding':1,
                            'in_stride':1,
                            "head_row_dim":6,
                            'head_col_dim':6,
                            'head_ch_dim':4,
                            'dropout':0.2,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    },
                    'attention1': {
                        'layer':'conv_msa',
                        'params':{
                            'msa_in_kernel_size':3,
                            'msa_in_padding':1,
                            'msa_in_stride':1,
                            'msa_head_row_dim':6,
                            'msa_head_col_dim':6,
                            'msa_head_ch_dim':4,
                            'dropout':0.2,
                            'out_conv_hidden_channels':512,
                            'out_conv_kernel_size':1,
                            'out_conv_padding':0,
                            'out_conv_stride':1,
                            'out_conv_act':nn.SiLU,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    },
                    'conv': [
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                    ],
                    'attention2': {
                        'layer':'conv_msa',
                        'params':{
                            'msa_in_kernel_size':3,
                            'msa_in_padding':1,
                            'msa_in_stride':1,
                            'msa_head_row_dim':6,
                            'msa_head_col_dim':6,
                            'msa_head_ch_dim':4,
                            'dropout':0.2,
                            'out_conv_hidden_channels':512,
                            'out_conv_kernel_size':1,
                            'out_conv_padding':0,
                            'out_conv_stride':1,
                            'out_conv_act':nn.SiLU,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    }
                },
                {
                    'interpolation_mode': "nearest",
                    'aggregation_layer':{
                        'layer': 'conv_cross_att',
                        'params': {
                            'in_kernel_size':3,
                            'in_padding':1,
                            'in_stride':1,
                            "head_row_dim":6,
                            'head_col_dim':6,
                            'head_ch_dim':4,
                            'dropout':0.2,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    },
                    'attention1': {
                        'layer':'conv_msa',
                        'params':{
                            'msa_in_kernel_size':3,
                            'msa_in_padding':1,
                            'msa_in_stride':1,
                            'msa_head_row_dim':6,
                            'msa_head_col_dim':6,
                            'msa_head_ch_dim':4,
                            'dropout':0.2,
                            'out_conv_hidden_channels':512,
                            'out_conv_kernel_size':1,
                            'out_conv_padding':0,
                            'out_conv_stride':1,
                            'out_conv_act':nn.SiLU,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    },
                    'conv': [
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                    ],
                    'attention2': {
                        'layer':'conv_msa',
                        'params':{
                            'msa_in_kernel_size':3,
                            'msa_in_padding':1,
                            'msa_in_stride':1,
                            'msa_head_row_dim':6,
                            'msa_head_col_dim':6,
                            'msa_head_ch_dim':4,
                            'dropout':0.2,
                            'out_conv_hidden_channels':512,
                            'out_conv_kernel_size':1,
                            'out_conv_padding':0,
                            'out_conv_stride':1,
                            'out_conv_act':nn.SiLU,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    }
                },
                {
                    'interpolation_mode': "nearest",
                    'aggregation_layer':{
                        'layer': 'conv_cross_att',
                        'params': {
                            'in_kernel_size':3,
                            'in_padding':1,
                            'in_stride':1,
                            "head_row_dim":6,
                            'head_col_dim':6,
                            'head_ch_dim':4,
                            'dropout':0.2,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    },
                    'attention1': {
                        'layer':'conv_msa',
                        'params':{
                            'msa_in_kernel_size':3,
                            'msa_in_padding':1,
                            'msa_in_stride':1,
                            'msa_head_row_dim':6,
                            'msa_head_col_dim':6,
                            'msa_head_ch_dim':4,
                            'dropout':0.2,
                            'out_conv_hidden_channels':512,
                            'out_conv_kernel_size':1,
                            'out_conv_padding':0,
                            'out_conv_stride':1,
                            'out_conv_act':nn.SiLU,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    },
                    'conv': [
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                    ],
                    'attention2': {
                        'layer':'conv_msa',
                        'params':{
                            'msa_in_kernel_size':3,
                            'msa_in_padding':1,
                            'msa_in_stride':1,
                            'msa_head_row_dim':6,
                            'msa_head_col_dim':6,
                            'msa_head_ch_dim':4,
                            'dropout':0.2,
                            'out_conv_hidden_channels':512,
                            'out_conv_kernel_size':1,
                            'out_conv_padding':0,
                            'out_conv_stride':1,
                            'out_conv_act':nn.SiLU,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    }
                },
                {
                    'interpolation_mode': "nearest",
                    'aggregation_layer':{
                        'layer': 'conv_cross_att',
                        'params': {
                            'in_kernel_size':3,
                            'in_padding':1,
                            'in_stride':1,
                            "head_row_dim":6,
                            'head_col_dim':6,
                            'head_ch_dim':4,
                            'dropout':0.2,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    },
                    'attention1': {
                        'layer':'conv_msa',
                        'params':{
                            'msa_in_kernel_size':3,
                            'msa_in_padding':1,
                            'msa_in_stride':1,
                            'msa_head_row_dim':6,
                            'msa_head_col_dim':6,
                            'msa_head_ch_dim':4,
                            'dropout':0.2,
                            'out_conv_hidden_channels':512,
                            'out_conv_kernel_size':1,
                            'out_conv_padding':0,
                            'out_conv_stride':1,
                            'out_conv_act':nn.SiLU,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    },
                    'conv': [
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                    ],
                    'attention2': {
                        'layer':'conv_msa',
                        'params':{
                            'msa_in_kernel_size':3,
                            'msa_in_padding':1,
                            'msa_in_stride':1,
                            'msa_head_row_dim':6,
                            'msa_head_col_dim':6,
                            'msa_head_ch_dim':4,
                            'dropout':0.2,
                            'out_conv_hidden_channels':512,
                            'out_conv_kernel_size':1,
                            'out_conv_padding':0,
                            'out_conv_stride':1,
                            'out_conv_act':nn.SiLU,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    }
                },
                {
                    'interpolation_mode': "nearest",
                    'aggregation_layer':{
                        'layer': 'conv_cross_att',
                        'params': {
                            'in_kernel_size':3,
                            'in_padding':1,
                            'in_stride':1,
                            "head_row_dim":6,
                            'head_col_dim':6,
                            'head_ch_dim':4,
                            'dropout':0.2,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    },
                    'attention1': {
                        'layer':'conv_msa',
                        'params':{
                            'msa_in_kernel_size':3,
                            'msa_in_padding':1,
                            'msa_in_stride':1,
                            'msa_head_row_dim':6,
                            'msa_head_col_dim':6,
                            'msa_head_ch_dim':4,
                            'dropout':0.2,
                            'out_conv_hidden_channels':512,
                            'out_conv_kernel_size':1,
                            'out_conv_padding':0,
                            'out_conv_stride':1,
                            'out_conv_act':nn.SiLU,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    },
                    'conv': [
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                        {
                            'kernel_size': 3,
                            'stride':  1,
                            'padding': 1,
                            'groups': 1,
                            'norm_layer': nn.BatchNorm2d,
                            'activation_layer': nn.ReLU,
                            'dilation': 1,
                            'inplace': True,
                            'bias': True,
                        },
                    ],
                    'attention2': {
                        'layer':'conv_msa',
                        'params':{
                            'msa_in_kernel_size':3,
                            'msa_in_padding':1,
                            'msa_in_stride':1,
                            'msa_head_row_dim':6,
                            'msa_head_col_dim':6,
                            'msa_head_ch_dim':4,
                            'dropout':0.2,
                            'out_conv_hidden_channels':512,
                            'out_conv_kernel_size':1,
                            'out_conv_padding':0,
                            'out_conv_stride':1,
                            'out_conv_act':nn.SiLU,
                            'norm_layer':nn.BatchNorm2d,
                        },
                    }
                },
            ],
            'encoder_name': "efficientnet-b0",
            'encoder_depth': 5,
            'encoder_weights': "imagenet",
            'decoder_use_norm': "batchnorm",
            'decoder_channels': (256, 128, 128, 64, 64),
            'decoder_interpolation': "nearest",
            'in_channels': 3,
            'classes': 11,
            'activation': None,
            'aux_params': None,
        },
        'input_layer_config': {
            'layer_path': 'encoder._conv_stem',
            'replace_type': 'channels+stride',
            'weight_update_type': 'repeate',
            'params':{
                'stride': (1, 1),
                'padding': (1, 1),
                }
        }
    },
    'multispecter_bands_indices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    'input_image_size': 96,
    'loss': {
        'type': 'crossentropy',
        #'params': {'weight': 'classes'}
        'params': {
            'weight': None,
            'ignore_index': -100,
            'reduction': "mean",
            'label_smoothing': 0.15,
            },
    },
    'optimizer': {
        'type': 'adam',
        'args': {}
    },
    'lr_scheduler': {
        'type': 'cosine_warm_restarts',
        'args': {
            'T_0': 25,
            'T_mult': 1,
            'eta_min': 0,
            'last_epoch': -1
        },
        'params':{
            'interval': 'epoch',
            'frequency': 1,
            'monitor': 'val_loss',
            'strict': True,
            'name': None,
        },
    },
    'device': 'cuda:0',
    'batch_size': 8,
    'epoch_num': 300,
    'train_augmentations': {
        'affine':{
            'degrees': [0, 45],
            'translate': [0, 0.3],
            'scale': [0.7, 1.5],
            'shear': [0, 0.2],
            'fill': [0],
        },
        'horizontal_flip': {
            'p': 0.5,
        },
        'vertical_flip':{
            'p': 0.5,
        },
    },
    #'path_to_dataset_root': r'C:\Users\admin\python_programming\DATA\DATA_FOR_TRAINIG_96',
    'path_to_dataset_root': r'C:\Users\mokhail\develop\DATA\DATA_FOR_TRAINIG_96'    
}

unetpp_config_dict = {
    'name_postfix': '',
    'segmentation_nn': {
        'nn_architecture': 'unet++',
        'params': {
            'encoder_name': "efficientnet-b0",
            'encoder_depth': 5,
            'encoder_weights': "imagenet",
            'decoder_use_norm': "batchnorm",
            'decoder_channels': [256, 128, 128, 128, 128],
            'decoder_attention_type': None,
            'decoder_interpolation': "nearest",
            'in_channels': 3,
            'classes': 11,
            'activation': None,
            'aux_params': None,
        },
        'input_layer_config': {
            'layer_path': 'encoder._conv_stem',
            'replace_type': 'channels+stride',
            'weight_update_type': 'repeate',
            'params':{
                'stride': (1, 1),
                'padding': (1, 1),
                }
        }
    },
    'multispecter_bands_indices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    'input_image_size': 96,
    'loss': {
        'type': 'crossentropy',
        #'params': {'weight': 'classes'}
        'params': {
            'weight': None,
            'ignore_index': -100,
            'reduction': "mean",
            'label_smoothing': 0.15,
            },
    },
    'optimizer': {
        'type': 'adam',
        'args': {}
    },
    'lr_scheduler': {
        'type': 'cosine_warm_restarts',
        'args': {
            'T_0': 25,
            'T_mult': 1,
            'eta_min': 0,
            'last_epoch': -1
        },
        'params':{
            'interval': 'epoch',
            'frequency': 1,
            'monitor': 'val_loss',
            'strict': True,
            'name': None,
        },
    },
    'device': 'cuda:0',
    'batch_size': 16,
    'epoch_num':300,
    'train_augmentations': {
       'gauss_noise':{
            'mean': 0.0,
            'sigma': 0.0008,
            'clip': False,
        },
        'affine':{
            'degrees': [0, 45],
            'translate': [0, 0.3],
            'scale': [0.7, 1.5],
            'shear': [0, 0.2],
            'fill': [0],
        },
        'perspective':{
            'distortion_scale': 0.2,
            'p': 0.3,
            'fill': [0],
        },
        'horizontal_flip': {
            'p': 0.5,
        },
        'vertical_flip':{
            'p': 0.5,
        },
    },
    'path_to_dataset_root': r'C:\Users\admin\python_programming\DATA\DATA_FOR_TRAINIG_96',
    #'path_to_dataset_root': r'C:\Users\mokhail\develop\DATA\DATA_FOR_TRAINIG_96'
}

scse_unet_config_dict = {
    'name_postfix': 'scse_att',
    'segmentation_nn': {
        'nn_architecture': 'unet',
        'params': {
            'encoder_name': "efficientnet-b0",
            'encoder_depth': 5,
            'encoder_weights': "imagenet",
            'decoder_use_norm': "batchnorm",
            'decoder_channels': (256, 128, 128, 128, 128),
            'decoder_attention_type': 'scse',
            'decoder_interpolation': "nearest",
            'in_channels': 3,
            'classes': 11,
            'activation': None,
            'aux_params': None,
        },
        'input_layer_config': {
            'layer_path': 'encoder._conv_stem',
            'replace_type': 'channels+stride',
            'weight_update_type': 'average_all',
            'params':{
                'stride': (1, 1),
                'padding': (1, 1),
                }
        }
    },
    'multispecter_bands_indices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    'input_image_size': 96,
    'loss': {
        'type': 'crossentropy',
        #'params': {'weight': 'classes'}
        'params': {
            'weight': None,
            'ignore_index': -100,
            'reduction': "mean",
            'label_smoothing': 0.15,
            },
    },
    'optimizer': {
        'type': 'adam',
        'args': {}
    },
    'lr_scheduler': {
        'type': 'cosine_warm_restarts',
        'args': {
            'T_0': 25,
            'T_mult': 1,
            'eta_min': 0,
            'last_epoch': -1
        },
        'params':{
            'interval': 'epoch',
            'frequency': 1,
            'monitor': 'val_loss',
            'strict': True,
            'name': None,
        },
    },
    'device': 'cuda:0',
    'batch_size': 16,
    'epoch_num':300,
    'train_augmentations': {
       'gauss_noise':{
            'mean': 0.0,
            'sigma': 0.0008,
            'clip': False,
        },
        'affine':{
            'degrees': [0, 45],
            'translate': [0, 0.3],
            'scale': [0.7, 1.5],
            'shear': [0, 0.2],
            'fill': [0],
        },
        'perspective':{
            'distortion_scale': 0.2,
            'p': 0.3,
            'fill': [0],
        },
        'horizontal_flip': {
            'p': 0.5,
        },
        'vertical_flip':{
            'p': 0.5,
        },
    },
    'path_to_dataset_root': r'C:\Users\admin\python_programming\DATA\DATA_FOR_TRAINIG_96',
    #'path_to_dataset_root': r'C:\Users\mokhail\develop\DATA\DATA_FOR_TRAINIG_96'
}

unet_mit_config_dict = {
    'name_postfix': '',
    'segmentation_nn': {
        'nn_architecture': 'unet',
        'params': {
            'encoder_name': "mit_b0",
            'encoder_depth': 5,
            'encoder_weights': "imagenet",
            'decoder_use_norm': "batchnorm",
            'decoder_channels': (256, 128, 128, 128, 128),
            'decoder_attention_type': None,
            'decoder_interpolation': "nearest",
            'in_channels': 3,
            'classes': 11,
            'activation': None,
            'aux_params': None,
        },
        'input_layer_config': {
            'layer_path': 'encoder.patch_embed1.proj',
            'replace_type': 'channels+stride',
            'weight_update_type': 'average_all',
            'params':{
                'stride': (1, 1),
                'padding': (3, 3),
                }
        }
    },
    'multispecter_bands_indices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    'input_image_size': 96,
    'loss': {
        'type': 'crossentropy',
        #'params': {'weight': 'classes'}
        'params': {
            'weight': None,
            'ignore_index': -100,
            'reduction': "mean",
            'label_smoothing': 0.15,
            },
    },
    'optimizer': {
        'type': 'adam',
        'args': {}
    },
    'lr_scheduler': {
        'type': 'cosine_warm_restarts',
        'args': {
            'T_0': 25,
            'T_mult': 1,
            'eta_min': 0,
            'last_epoch': -1
        },
        'params':{
            'interval': 'epoch',
            'frequency': 1,
            'monitor': 'val_loss',
            'strict': True,
            'name': None,
        },
    },
    'device': 'cuda:0',
    'batch_size': 16,
    'epoch_num':300,
    'train_augmentations': {
       'gauss_noise':{
            'mean': 0.0,
            'sigma': 0.0008,
            'clip': False,
        },
        'affine':{
            'degrees': [0, 45],
            'translate': [0, 0.3],
            'scale': [0.7, 1.5],
            'shear': [0, 0.2],
            'fill': [0],
        },
        'perspective':{
            'distortion_scale': 0.2,
            'p': 0.3,
            'fill': [0],
        },
        'horizontal_flip': {
            'p': 0.5,
        },
        'vertical_flip':{
            'p': 0.5,
        },
    },
    'path_to_dataset_root': r'C:\Users\admin\python_programming\DATA\DATA_FOR_TRAINIG_96',
    #'path_to_dataset_root': r'C:\Users\mokhail\develop\DATA\DATA_FOR_TRAINIG_96'
}

unet_maxvit_config_dict = {
    'name_postfix': '',
    'segmentation_nn': {
        'nn_architecture': 'unet',
        'params': {
            'encoder_name': 'tu-maxvit_tiny_rw_224',
            'encoder_depth': 5,
            'encoder_weights': "imagenet",
            'decoder_use_norm': "batchnorm",
            'decoder_channels': (256, 128, 128, 128, 128),
            'decoder_attention_type': None,
            'decoder_interpolation': "nearest",
            'in_channels': 3,
            'classes': 11,
            'activation': None,
            'aux_params': None,
            'img_size':96,
        },
        'input_layer_config': {
            'layer_path': 'encoder.model.stem.conv1',
            'replace_type': 'channels+stride',
            'weight_update_type': 'average_all',
            'params':{
                'stride': (1, 1),
                'padding': (1, 1),
                }
        }
    },
    'multispecter_bands_indices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    'input_image_size': 96,
    'loss': {
        'type': 'crossentropy',
        #'params': {'weight': 'classes'}
        'params': {
            'weight': None,
            'ignore_index': -100,
            'reduction': "mean",
            'label_smoothing': 0.15,
            },
    },
    'optimizer': {
        'type': 'adam',
        'args': {}
    },
    'lr_scheduler': {
        'type': 'cosine_warm_restarts',
        'args': {
            'T_0': 25,
            'T_mult': 1,
            'eta_min': 0,
            'last_epoch': -1
        },
        'params':{
            'interval': 'epoch',
            'frequency': 1,
            'monitor': 'val_loss',
            'strict': True,
            'name': None,
        },
    },
    'device': 'cuda:0',
    'batch_size': 16,
    'epoch_num':300,
    'train_augmentations': {
       'gauss_noise':{
            'mean': 0.0,
            'sigma': 0.0008,
            'clip': False,
        },
        'affine':{
            'degrees': [0, 45],
            'translate': [0, 0.3],
            'scale': [0.7, 1.5],
            'shear': [0, 0.2],
            'fill': [0],
        },
        'perspective':{
            'distortion_scale': 0.2,
            'p': 0.3,
            'fill': [0],
        },
        'horizontal_flip': {
            'p': 0.5,
        },
        'vertical_flip':{
            'p': 0.5,
        },
    },
    #'path_to_dataset_root': r'C:\Users\admin\python_programming\DATA\DATA_FOR_TRAINIG_96',
    'path_to_dataset_root': r'/home/mikhail_u/develop/DATA/DATA_FOR_TRAINIG_96'
}

net_hgnetv2_config_dict = {
    'name_postfix': '',
    'segmentation_nn': {
        'nn_architecture': 'unet',
        'params': {
            'encoder_name': 'tu-hgnetv2_b1',
            'encoder_depth': 5,
            'encoder_weights': "imagenet",
            'decoder_use_norm': "batchnorm",
            'decoder_channels': (256, 128, 128, 128, 128),
            'decoder_attention_type': None,
            'decoder_interpolation': "nearest",
            'in_channels': 3,
            'classes': 11,
            'activation': None,
            'aux_params': None,
            #'img_size':96,
        },
        'input_layer_config': {
            'layer_path': 'encoder.model.stem.stem1.conv',
            'replace_type': 'channels+stride',
            'weight_update_type': 'average_all',
            'params':{
                'stride': (1, 1),
                'padding': (1, 1),
                }
        }
    },
    'multispecter_bands_indices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    'input_image_size': 96,
    'loss': {
        'type': 'crossentropy',
        #'params': {'weight': 'classes'}
        'params': {
            'weight': None,
            'ignore_index': -100,
            'reduction': "mean",
            'label_smoothing': 0.15,
            },
    },
    'optimizer': {
        'type': 'adam',
        'args': {}
    },
    'lr_scheduler': {
        'type': 'cosine_warm_restarts',
        'args': {
            'T_0': 25,
            'T_mult': 1,
            'eta_min': 0,
            'last_epoch': -1
        },
        'params':{
            'interval': 'epoch',
            'frequency': 1,
            'monitor': 'val_loss',
            'strict': True,
            'name': None,
        },
    },
    'device': 'cuda:0',
    'batch_size': 16,
    'epoch_num':300,
    'train_augmentations': {
       'gauss_noise':{
            'mean': 0.0,
            'sigma': 0.0008,
            'clip': False,
        },
        'affine':{
            'degrees': [0, 45],
            'translate': [0, 0.3],
            'scale': [0.7, 1.5],
            'shear': [0, 0.2],
            'fill': [0],
        },
        'perspective':{
            'distortion_scale': 0.2,
            'p': 0.3,
            'fill': [0],
        },
        'horizontal_flip': {
            'p': 0.5,
        },
        'vertical_flip':{
            'p': 0.5,
        },
    },
    #'path_to_dataset_root': r'C:\Users\admin\python_programming\DATA\DATA_FOR_TRAINIG_96',
    'path_to_dataset_root': r'/home/mikhail_u/develop/DATA/DATA_FOR_TRAINIG_96'
}

net_mambaout_config_dict = {
    'name_postfix': '',
    'segmentation_nn': {
        'nn_architecture': 'unet',
        'params': {
            'encoder_name': 'tu-mambaout_small',
            'encoder_depth': 5,
            'encoder_weights': "imagenet",
            'decoder_use_norm': "batchnorm",
            'decoder_channels': (256, 128, 128, 128, 128),
            'decoder_attention_type': None,
            'decoder_interpolation': "nearest",
            'in_channels': 3,
            'classes': 11,
            'activation': None,
            'aux_params': None,
            #'img_size':96,
        },
        'input_layer_config': {
            'layer_path': 'encoder.model.stem.conv1',
            'replace_type': 'channels+stride',
            'weight_update_type': 'average_all',
            'params':{
                'stride': (1, 1),
                'padding': (1, 1),
                }
        }
    },
    'multispecter_bands_indices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    'input_image_size': 96,
    'loss': {
        'type': 'crossentropy',
        #'params': {'weight': 'classes'}
        'params': {
            'weight': None,
            'ignore_index': -100,
            'reduction': "mean",
            'label_smoothing': 0.15,
            },
    },
    'optimizer': {
        'type': 'adam',
        'args': {}
    },
    'lr_scheduler': {
        'type': 'cosine_warm_restarts',
        'args': {
            'T_0': 25,
            'T_mult': 1,
            'eta_min': 0,
            'last_epoch': -1
        },
        'params':{
            'interval': 'epoch',
            'frequency': 1,
            'monitor': 'val_loss',
            'strict': True,
            'name': None,
        },
    },
    'device': 'cuda:0',
    'batch_size': 16,
    'epoch_num':300,
    'train_augmentations': {
       'gauss_noise':{
            'mean': 0.0,
            'sigma': 0.0008,
            'clip': False,
        },
        'affine':{
            'degrees': [0, 45],
            'translate': [0, 0.3],
            'scale': [0.7, 1.5],
            'shear': [0, 0.2],
            'fill': [0],
        },
        'perspective':{
            'distortion_scale': 0.2,
            'p': 0.3,
            'fill': [0],
        },
        'horizontal_flip': {
            'p': 0.5,
        },
        'vertical_flip':{
            'p': 0.5,
        },
    },
    #'path_to_dataset_root': r'C:\Users\admin\python_programming\DATA\DATA_FOR_TRAINIG_96',
    'path_to_dataset_root': r'/home/mikhail_u/develop/DATA/DATA_FOR_TRAINIG_96'
}

unet_hrnet_config_dict = {
    'name_postfix': '',
    'segmentation_nn': {
        'nn_architecture': 'unet',
        'params': {
            'encoder_name': 'tu-hrnet_w18',
            'encoder_depth': 5,
            'encoder_weights': "imagenet",
            'decoder_use_norm': "batchnorm",
            'decoder_channels': (256, 128, 128, 128, 128),
            'decoder_attention_type': None,
            'decoder_interpolation': "nearest",
            'in_channels': 3,
            'classes': 11,
            'activation': None,
            'aux_params': None,
            #'img_size':96,
        },
        'input_layer_config': {
            'layer_path': 'encoder.model.conv1',
            'replace_type': 'channels+stride',
            'weight_update_type': 'average_all',
            'params':{
                'stride': (1, 1),
                'padding': (1, 1),
                }
        }
    },
    'multispecter_bands_indices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    'input_image_size': 96,
    'loss': {
        'type': 'crossentropy',
        #'params': {'weight': 'classes'}
        'params': {
            'weight': None,
            'ignore_index': -100,
            'reduction': "mean",
            'label_smoothing': 0.15,
            },
    },
    'optimizer': {
        'type': 'adam',
        'args': {}
    },
    'lr_scheduler': {
        'type': 'cosine_warm_restarts',
        'args': {
            'T_0': 25,
            'T_mult': 1,
            'eta_min': 0,
            'last_epoch': -1
        },
        'params':{
            'interval': 'epoch',
            'frequency': 1,
            'monitor': 'val_loss',
            'strict': True,
            'name': None,
        },
    },
    'device': 'cuda:0',
    'batch_size': 16,
    'epoch_num':300,
    'train_augmentations': {
       'gauss_noise':{
            'mean': 0.0,
            'sigma': 0.0008,
            'clip': False,
        },
        'affine':{
            'degrees': [0, 45],
            'translate': [0, 0.3],
            'scale': [0.7, 1.5],
            'shear': [0, 0.2],
            'fill': [0],
        },
        'perspective':{
            'distortion_scale': 0.2,
            'p': 0.3,
            'fill': [0],
        },
        'horizontal_flip': {
            'p': 0.5,
        },
        'vertical_flip':{
            'p': 0.5,
        },
    },
    #'path_to_dataset_root': r'C:\Users\admin\python_programming\DATA\DATA_FOR_TRAINIG_96',
    'path_to_dataset_root': r'/home/mikhail_u/develop/DATA/DATA_FOR_TRAINIG_96'
}

unet_cspresnext_config_dict = {
    'name_postfix': '',
    'segmentation_nn': {
        'nn_architecture': 'unet',
        'params': {
            'encoder_name': 'tu-cspresnext50',
            'encoder_depth': 5,
            'encoder_weights': "imagenet",
            'decoder_use_norm': "batchnorm",
            'decoder_channels': (256, 128, 128, 128, 128),
            'decoder_attention_type': None,
            'decoder_interpolation': "nearest",
            'in_channels': 3,
            'classes': 11,
            'activation': None,
            'aux_params': None,
            #'img_size':96,
        },
        'input_layer_config': {
            'layer_path': 'encoder.model.stem_conv1.conv',
            'replace_type': 'channels+stride',
            'weight_update_type': 'average_all',
            'params':{
                'stride': (1, 1),
                'padding': (1, 1),
                }
        }
    },
    'multispecter_bands_indices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    'input_image_size': 96,
    'loss': {
        'type': 'crossentropy',
        #'params': {'weight': 'classes'}
        'params': {
            'weight': None,
            'ignore_index': -100,
            'reduction': "mean",
            'label_smoothing': 0.15,
            },
    },
    'optimizer': {
        'type': 'adam',
        'args': {}
    },
    'lr_scheduler': {
        'type': 'cosine_warm_restarts',
        'args': {
            'T_0': 25,
            'T_mult': 1,
            'eta_min': 0,
            'last_epoch': -1
        },
        'params':{
            'interval': 'epoch',
            'frequency': 1,
            'monitor': 'val_loss',
            'strict': True,
            'name': None,
        },
    },
    'device': 'cuda:0',
    'batch_size': 16,
    'epoch_num':300,
    'train_augmentations': {
       'gauss_noise':{
            'mean': 0.0,
            'sigma': 0.0008,
            'clip': False,
        },
        'affine':{
            'degrees': [0, 45],
            'translate': [0, 0.3],
            'scale': [0.7, 1.5],
            'shear': [0, 0.2],
            'fill': [0],
        },
        'perspective':{
            'distortion_scale': 0.2,
            'p': 0.3,
            'fill': [0],
        },
        'horizontal_flip': {
            'p': 0.5,
        },
        'vertical_flip':{
            'p': 0.5,
        },
    },
    'path_to_dataset_root': r'C:\Users\admin\python_programming\DATA\DATA_FOR_TRAINIG_96',
    #'path_to_dataset_root': r'/home/mikhail_u/develop/DATA/DATA_FOR_TRAINIG_96'
}

fpn_config_dict = {
    'name_postfix': '',
    'segmentation_nn': {
        'nn_architecture': 'custom_fpn',
        'params': {
            'encoder_name': "efficientnet-b0",
            'encoder_depth': 5,
            'encoder_weights': "imagenet",
            'decoder_pyramid_channels': 128,
            'decoder_segmentation_channels': 128,
            'decoder_merge_policy': "add",
            'decoder_dropout': 0.2,
            'decoder_interpolation': "nearest",
            'in_channels': 3,
            'classes': 11,
            'activation': None,
            'upsampling': 0,
            'aux_params': None,
        },
        'input_layer_config': {
            'layer_path': 'encoder._conv_stem',
            'replace_type': 'channels+stride',
            'weight_update_type': 'repeate',
            'params':{
                'stride': (1, 1),
                'padding': (1, 1),
                },
        },
    },
    'multispecter_bands_indices': [1, 2, 3, 7, 'ndvi'],
    'input_image_size': 96,
    'loss': {
        'type': 'crossentropy',
        #'params': {'weight': 'classes'},
        'params': {
            'weight': None,
            'ignore_index': -100,
            'reduction': "mean",
            'label_smoothing': 0.15,
            },
    },
    'optimizer': {
        'type': 'adam',
        'args': {}
    },
    'lr_scheduler': {
        'type': 'cosine_warm_restarts',
        'args': {
            'T_0': 25,
            'T_mult': 1,
            'eta_min': 0,
            'last_epoch': -1
        },
        'params':{
            'interval': 'epoch',
            'frequency': 1,
            'monitor': 'val_loss',
            'strict': True,
            'name': None,
        },
    },
    'device': 'cuda:0',
    'batch_size': 16,
    'epoch_num':300,
    'train_augmentations': {
       'gauss_noise':{
            'mean': 0.0,
            'sigma': 0.0008,
            'clip': False,
        },
        'affine':{
            'degrees': [0, 45],
            'translate': [0, 0.3],
            'scale': [0.7, 1.5],
            'shear': [0, 0.2],
            'fill': [0],
        },
        'perspective':{
            'distortion_scale': 0.2,
            'p': 0.3,
            'fill': [0],
        },
        'horizontal_flip': {
            'p': 0.5,
        },
        'vertical_flip':{
            'p': 0.5,
        },
    },
    'path_to_dataset_root': r'C:\Users\admin\python_programming\DATA\DATA_FOR_TRAINIG_96',
    #'path_to_dataset_root': r'C:\Users\mokhail\develop\DATA\DATA_FOR_TRAINIG_96',
}

fcn_config_dict = {
    'name_postfix': '',
    'segmentation_nn': {
        'nn_architecture': 'fcn',
        'params': {
            'encoder_name': "efficientnet-b0",
            'encoder_depth': 5,
            'encoder_weights': "imagenet",
            'decoder_use_norm': 'batchnorm',
            'decoder_last_channel': 16,
            'decoder_attention_type': None,
            'decoder_interpolation': "nearest",
            'in_channels': 3,
            'classes': 11,
            'activation':  None,
            'aux_params': None,
        },
        'input_layer_config': {
            'layer_path': 'encoder._conv_stem',
            'replace_type': 'channels+stride',
            'weight_update_type': 'repeate',
            'params':{
                'stride': (1, 1),
                'padding': (1, 1),
                },
        },
    },
    'multispecter_bands_indices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    'input_image_size': 96,
    'loss': {
        'type': 'crossentropy',
        'params': {
            'weight': None,
            'ignore_index': -100,
            'reduction': "mean",
            'label_smoothing': 0.15,
            },
        #'params': {},
    },
    'optimizer': {
        'type': 'adam',
        'args': {}
    },
    'lr_scheduler': {
        'type': 'cosine_warm_restarts',
        'args': {
            'T_0': 25,
            'T_mult': 1,
            'eta_min': 0,
            'last_epoch': -1
        },
        'params':{
            'interval': 'epoch',
            'frequency': 1,
            'monitor': 'val_loss',
            'strict': True,
            'name': None,
        },
    },
    'device': 'cuda:0',
    'batch_size': 16,
    'epoch_num':300,
    'train_augmentations': {
       'gauss_noise':{
            'mean': 0.0,
            'sigma': 0.0008,
            'clip': False,
        },
        'affine':{
            'degrees': [0, 45],
            'translate': [0, 0.3],
            'scale': [0.7, 1.5],
            'shear': [0, 0.2],
            'fill': [0],
        },
        'perspective':{
            'distortion_scale': 0.2,
            'p': 0.3,
            'fill': [0],
        },
        'horizontal_flip': {
            'p': 0.5,
        },
        'vertical_flip':{
            'p': 0.5,
        },
    },
    'path_to_dataset_root': r'C:\Users\admin\python_programming\DATA\DATA_FOR_TRAINIG_96',
    #'path_to_dataset_root': r'C:\Users\mokhail\develop\DATA\DATA_FOR_TRAINIG_96',
}

fpn_config_dict_dice_ce = {
    'name_postfix': '',
    'segmentation_nn': {
        'nn_architecture': 'custom_fpn',
        'params': {
            'encoder_name': "efficientnet-b0",
            'encoder_depth': 5,
            'encoder_weights': "imagenet",
            'decoder_pyramid_channels': 128,
            'decoder_segmentation_channels': 128,
            'decoder_merge_policy': "add",
            'decoder_dropout': 0.2,
            'decoder_interpolation': "nearest",
            'in_channels': 3,
            'classes': 11,
            'activation': None,
            'upsampling': 0,
            'aux_params': None,
        },
        'input_layer_config': {
            'layer_path': 'encoder._conv_stem',
            'replace_type': 'channels+stride',
            'weight_update_type': 'repeate',
            'params':{
                'stride': (1, 1),
                'padding': (1, 1),
                },
        },
    },
    'multispecter_bands_indices': [1, 2, 3, 7, 'ndvi'],
    'input_image_size': 96,
    'loss': {
        'type': 'dice_crossentropy',
        #'params': {'weight': 'classes'},
        'params': {
            'ce_weight':None,
            'ce_ignore_index':-100,
            'ce_reducion':'mean',
            'ce_label_smoothing':0.15,
            'dice_mode':'multiclass',
            'dice_classes': None,
            'dice_log_loss':False,
            'dice_from_logits':True,
            'dice_smooth':0.15,
            'dice_ignore_index':-100,
            'dice_eps': 1e-7,
            'losses_weight': [0.5, 0.5],
            'is_trainable_weights': True,
            'weights_processing_type': 'softmax',
            },
    },
    'optimizer': {
        'type': 'adam',
        'args': {}
    },
    'lr_scheduler': {
        'type': 'cosine_warm_restarts',
        'args': {
            'T_0': 25,
            'T_mult': 1,
            'eta_min': 0,
            'last_epoch': -1
        },
        'params':{
            'interval': 'epoch',
            'frequency': 1,
            'monitor': 'val_loss',
            'strict': True,
            'name': None,
        },
    },
    'device': 'cuda:0',
    'batch_size': 16,
    'epoch_num':300,
    'train_augmentations': {
       'gauss_noise':{
            'mean': 0.0,
            'sigma': 0.0008,
            'clip': False,
        },
        'affine':{
            'degrees': [0, 45],
            'translate': [0, 0.3],
            'scale': [0.7, 1.5],
            'shear': [0, 0.2],
            'fill': [0],
        },
        'perspective':{
            'distortion_scale': 0.2,
            'p': 0.3,
            'fill': [0],
        },
        'horizontal_flip': {
            'p': 0.5,
        },
        'vertical_flip':{
            'p': 0.5,
        },
    },
    'path_to_dataset_root': r'C:\Users\admin\python_programming\DATA\DATA_FOR_TRAINIG_96',
    'path_to_dataset_root': r'C:\Users\mokhail\develop\DATA\DATA_FOR_TRAINIG_96',
}

fpn_config_dict_dice = {
    'name_postfix': '',
    'segmentation_nn': {
        'nn_architecture': 'custom_fpn',
        'params': {
            'encoder_name': "efficientnet-b0",
            'encoder_depth': 5,
            'encoder_weights': "imagenet",
            'decoder_pyramid_channels': 128,
            'decoder_segmentation_channels': 128,
            'decoder_merge_policy': "add",
            'decoder_dropout': 0.2,
            'decoder_interpolation': "nearest",
            'in_channels': 3,
            'classes': 11,
            'activation': None,
            'upsampling': 0,
            'aux_params': None,
        },
        'input_layer_config': {
            'layer_path': 'encoder._conv_stem',
            'replace_type': 'channels+stride',
            'weight_update_type': 'repeate',
            'params':{
                'stride': (1, 1),
                'padding': (1, 1),
                },
        },
    },
    'multispecter_bands_indices': [1, 2, 3, 7, 'ndvi'],
    'input_image_size': 96,
    'loss': {
        'type': 'dice',
        #'params': {'weight': 'classes'},
        'params': {
            'mode':'multiclass',
            'classes': None,
            'log_loss':False,
            'from_logits':True,
            'smooth':0.15,
            'ignore_index':-100,
            'eps': 1e-7,
            },
    },
    'optimizer': {
        'type': 'adam',
        'args': {}
    },
    'lr_scheduler': {
        'type': 'cosine_warm_restarts',
        'args': {
            'T_0': 25,
            'T_mult': 1,
            'eta_min': 0,
            'last_epoch': -1
        },
        'params':{
            'interval': 'epoch',
            'frequency': 1,
            'monitor': 'val_loss',
            'strict': True,
            'name': None,
        },
    },
    'device': 'cuda:0',
    'batch_size': 16,
    'epoch_num':300,
    'train_augmentations': {
       'gauss_noise':{
            'mean': 0.0,
            'sigma': 0.0008,
            'clip': False,
        },
        'affine':{
            'degrees': [0, 45],
            'translate': [0, 0.3],
            'scale': [0.7, 1.5],
            'shear': [0, 0.2],
            'fill': [0],
        },
        'perspective':{
            'distortion_scale': 0.2,
            'p': 0.3,
            'fill': [0],
        },
        'horizontal_flip': {
            'p': 0.5,
        },
        'vertical_flip':{
            'p': 0.5,
        },
    },
    'path_to_dataset_root': r'C:\Users\admin\python_programming\DATA\DATA_FOR_TRAINIG_96',
    #'path_to_dataset_root': r'C:\Users\mokhail\develop\DATA\DATA_FOR_TRAINIG_96',
}

fpn_config_dict_ce = {
    'name_postfix': '',
    'segmentation_nn': {
        'nn_architecture': 'custom_fpn',
        'params': {
            'encoder_name': "efficientnet-b0",
            'encoder_depth': 5,
            'encoder_weights': "imagenet",
            'decoder_pyramid_channels': 128,
            'decoder_segmentation_channels': 128,
            'decoder_merge_policy': "add",
            'decoder_dropout': 0.2,
            'decoder_interpolation': "nearest",
            'in_channels': 3,
            'classes': 11,
            'activation': None,
            'upsampling': 0,
            'aux_params': None,
        },
        'input_layer_config': {
            'layer_path': 'encoder._conv_stem',
            'replace_type': 'channels+stride',
            'weight_update_type': 'average_all',
            'params':{
                'stride': (1, 1),
                'padding': (1, 1),
                },
        },
    },
    'multispecter_bands_indices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,],
    'input_image_size': 96,
    'loss': {
        'type': 'crossentropy',
        #'params': {'weight': 'classes'},
        'params': {
            'weight': None,
            'ignore_index': -100,
            'reduction': "mean",
            'label_smoothing': 0.15,
            },
    },
    'optimizer': {
        'type': 'adam',
        'args': {}
    },
    'lr_scheduler': {
        'type': 'cosine_warm_restarts',
        'args': {
            'T_0': 25,
            'T_mult': 1,
            'eta_min': 0,
            'last_epoch': -1
        },
        'params':{
            'interval': 'epoch',
            'frequency': 1,
            'monitor': 'val_loss',
            'strict': True,
            'name': None,
        },
    },
    'device': 'cuda:0',
    'batch_size':16,
    'epoch_num':300,
    'train_augmentations': {
       'gauss_noise':{
            'mean': 0.0,
            'sigma': 0.0008,
            'clip': False,
        },
        'affine':{
            'degrees': [0, 45],
            'translate': [0, 0.3],
            'scale': [0.7, 1.5],
            'shear': [0, 0.2],
            'fill': [0],
        },
        'perspective':{
            'distortion_scale': 0.2,
            'p': 0.3,
            'fill': [0],
        },
        'horizontal_flip': {
            'p': 0.5,
        },
        'vertical_flip':{
            'p': 0.5,
        },
    },
    #'path_to_dataset_root': r'C:\Users\admin\python_programming\DATA\DATA_FOR_TRAINIG_96',
    'path_to_dataset_root': r'C:\Users\mokhail\develop\DATA\DATA_FOR_TRAINIG_96',
}

fpn_multisize_input_config_dict_ce = {
    'name_postfix': '',
    'segmentation_nn': {
        'nn_architecture': 'custom_fpn',
        'params': {
            'encoder_name': "efficientnet-b0",
            'encoder_depth': 5,
            'encoder_weights': "imagenet",
            'decoder_pyramid_channels': 128,
            'decoder_segmentation_channels': 128,
            'decoder_merge_policy': "add",
            'decoder_dropout': 0.2,
            'decoder_interpolation': "nearest",
            'in_channels': 3,
            'classes': 11,
            'activation': None,
            'upsampling': 0,
            'aux_params': None,
        },
        'input_layer_config': {
            'layer_path': 'encoder._conv_stem',
            'replace_type': 'multisize_conv',
            'weight_update_type': 'repeat', # avearge_all OR repeate
            'params':{
                'in_channels': 3,
                'out_channels': {
                    '1x1': 32,
                    '3x3': 32,
                    #'5x5': 32, 
                },
                'kernel_size': {
                    '1x1': 1,
                    '3x3': 3,
                    #'5x5': 5, 
                },
                'stride': {
                    '1x1': 1,
                    '3x3': 1,
                    #'5x5': 1, 
                },
                'padding': {
                    '1x1': 0,
                    '3x3': 1,
                    #'5x5': 2, 
                },
                'dilation': {
                    '1x1': 1,
                    '3x3': 1,
                    #'5x5': 1, 
                },
                'groups': {
                    '1x1': 1,
                    '3x3': 1,
                    #'5x5': 1, 
                },
                'bias': {
                    '1x1': False,
                    '3x3': False,
                    #'5x5': False, 
                },
                'aggregation_type': 'add',
            },
        },
    },
    'multispecter_bands_indices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,],
    'input_image_size': 96,
    'loss': {
        'type': 'crossentropy',
        #'params': {'weight': 'classes'},
        'params': {
            'weight': None,
            'ignore_index': -100,
            'reduction': "mean",
            'label_smoothing': 0.15,
            },
    },
    'optimizer': {
        'type': 'adam',
        'args': {}
    },
    'lr_scheduler': {
        'type': 'cosine_warm_restarts',
        'args': {
            'T_0': 25,
            'T_mult': 1,
            'eta_min': 0,
            'last_epoch': -1
        },
        'params':{
            'interval': 'epoch',
            'frequency': 1,
            'monitor': 'val_loss',
            'strict': True,
            'name': None,
        },
    },
    'device': 'cuda:0',
    'batch_size': 16,
    'epoch_num':300,
    'train_augmentations': {
       'gauss_noise':{
            'mean': 0.0,
            'sigma': 0.0008,
            'clip': False,
        },
        'affine':{
            'degrees': [0, 45],
            'translate': [0, 0.3],
            'scale': [0.7, 1.5],
            'shear': [0, 0.2],
            'fill': [0],
        },
        'perspective':{
            'distortion_scale': 0.2,
            'p': 0.3,
            'fill': [0],
        },
        'horizontal_flip': {
            'p': 0.5,
        },
        'vertical_flip':{
            'p': 0.5,
        },
    },
    #'path_to_dataset_root': r'C:\Users\admin\python_programming\DATA\DATA_FOR_TRAINIG_96',
    'path_to_dataset_root': r'C:\Users\mokhail\develop\DATA\DATA_FOR_TRAINIG_96',
}

fpn_mit_config_dict_ce = {
    'name_postfix': '',
    'segmentation_nn': {
        'nn_architecture': 'custom_fpn',
        'params': {
            'encoder_name': "mit_b0",
            'encoder_depth': 5,
            'encoder_weights': "imagenet",
            'decoder_pyramid_channels': 128,
            'decoder_segmentation_channels': 128,
            'decoder_merge_policy': "add",
            'decoder_dropout': 0.2,
            'decoder_interpolation': "nearest",
            'in_channels': 3,
            'classes': 11,
            'activation': None,
            'upsampling': 0,
            'aux_params': None,
        },
        'input_layer_config': {
            'layer_path': 'encoder._conv_stem',
            'replace_type': 'channels+stride',
            'weight_update_type': 'repeate',
            'params':{
                'stride': (1, 1),
                'padding': (1, 1),
                },
        },
    },
    'multispecter_bands_indices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,],
    'input_image_size': 96,
    'loss': {
        'type': 'crossentropy',
        #'params': {'weight': 'classes'},
        'params': {
            'weight': None,
            'ignore_index': -100,
            'reduction': "mean",
            'label_smoothing': 0.15,
            },
    },
    'optimizer': {
        'type': 'adam',
        'args': {}
    },
    'lr_scheduler': {
        'type': 'cosine_warm_restarts',
        'args': {
            'T_0': 25,
            'T_mult': 1,
            'eta_min': 0,
            'last_epoch': -1
        },
        'params':{
            'interval': 'epoch',
            'frequency': 1,
            'monitor': 'val_loss',
            'strict': True,
            'name': None,
        },
    },
    'device': 'cuda:0',
    'batch_size': 16,
    'epoch_num':300,
    'train_augmentations': {
       'gauss_noise':{
            'mean': 0.0,
            'sigma': 0.0008,
            'clip': False,
        },
        'affine':{
            'degrees': [0, 45],
            'translate': [0, 0.3],
            'scale': [0.7, 1.5],
            'shear': [0, 0.2],
            'fill': [0],
        },
        'perspective':{
            'distortion_scale': 0.2,
            'p': 0.3,
            'fill': [0],
        },
        'horizontal_flip': {
            'p': 0.5,
        },
        'vertical_flip':{
            'p': 0.5,
        },
    },
    'path_to_dataset_root': r'C:\Users\admin\python_programming\DATA\DATA_FOR_TRAINIG_96',
    #'path_to_dataset_root': r'C:\Users\mokhail\develop\DATA\DATA_FOR_TRAINIG_96',
}

custom_manet_config_dict = {
    'name_postfix': '',
    'segmentation_nn': {
        'nn_architecture': 'custom_manet',
        'params': {
            'encoder_name': "efficientnet-b0",
            'encoder_depth': 5,
            'encoder_weights': "imagenet",
            'decoder_use_norm': "batchnorm",
            'decoder_channels': [256, 128, 64, 64, 64],
            'decoder_pab_channels': 64,
            'decoder_interpolation': "nearest",
            'in_channels': 3,
            'classes': 11,
            'activation': None,
            'aux_params': None,
        },
        'input_layer_config': {
            'layer_path': 'encoder._conv_stem',
            'replace_type': 'channels+stride',
            'weight_update_type': 'repeate',
            'params':{
                'stride': (1, 1),
                'padding': (1, 1),
                },
        },
    },
    'multispecter_bands_indices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,],
    'input_image_size': 96,
    'loss': {
        'type': 'crossentropy',
        #'params': {'weight': 'classes'},
        'params': {
            'weight': None,
            'ignore_index': -100,
            'reduction': "mean",
            'label_smoothing': 0.15,
            },
    },
    'optimizer': {
        'type': 'adam',
        'args': {}
    },
    'lr_scheduler': {
        'type': 'cosine_warm_restarts',
        'args': {
            'T_0': 25,
            'T_mult': 1,
            'eta_min': 0,
            'last_epoch': -1
        },
        'params':{
            'interval': 'epoch',
            'frequency': 1,
            'monitor': 'val_loss',
            'strict': True,
            'name': None,
        },
    },
    'device': 'cuda:0',
    'batch_size':8,
    'epoch_num':300,
    'train_augmentations': {
       'gauss_noise':{
            'mean': 0.0,
            'sigma': 0.0008,
            'clip': False,
        },
        'affine':{
            'degrees': [0, 45],
            'translate': [0, 0.3],
            'scale': [0.7, 1.5],
            'shear': [0, 0.2],
            'fill': [0],
        },
        'perspective':{
            'distortion_scale': 0.2,
            'p': 0.3,
            'fill': [0],
        },
        'horizontal_flip': {
            'p': 0.5,
        },
        'vertical_flip':{
            'p': 0.5,
        },
    },
    'path_to_dataset_root': r'C:\Users\admin\python_programming\DATA\DATA_FOR_TRAINIG_96',
    #'path_to_dataset_root': r'C:\Users\mokhail\develop\DATA\DATA_FOR_TRAINIG_96',
}

config_dict = deepcopy(fpn_config_dict_ce)

config_dict

{'name_postfix': '',
 'segmentation_nn': {'nn_architecture': 'custom_fpn',
  'params': {'encoder_name': 'efficientnet-b0',
   'encoder_depth': 5,
   'encoder_weights': 'imagenet',
   'decoder_pyramid_channels': 128,
   'decoder_segmentation_channels': 128,
   'decoder_merge_policy': 'add',
   'decoder_dropout': 0.2,
   'decoder_interpolation': 'nearest',
   'in_channels': 3,
   'classes': 11,
   'activation': None,
   'upsampling': 0,
   'aux_params': None},
  'input_layer_config': {'layer_path': 'encoder._conv_stem',
   'replace_type': 'channels+stride',
   'weight_update_type': 'average_all',
   'params': {'stride': (1, 1), 'padding': (1, 1)}}},
 'multispecter_bands_indices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
 'input_image_size': 96,
 'loss': {'type': 'crossentropy',
  'params': {'weight': None,
   'ignore_index': -100,
   'reduction': 'mean',
   'label_smoothing': 0.15}},
 'optimizer': {'type': 'adam', 'args': {}},
 'lr_scheduler': {'type': 'cosine_warm_restarts',
  'args'

In [None]:
configs_to_save = [
    unet_config_dict,
    win_att1l_cat_agg_2_conv_unet_config_dict,
    win_att_cat_agg_2_conv_unet_config_dict,
    att_cross_agg_2_conv_unet_config_dict,
    att_cat_agg_2_conv_unet_config_dict,
    unetpp_config_dict,
    unet_mit_config_dict,
    unet_maxvit_config_dict,
    net_hgnetv2_config_dict,
    net_mambaout_config_dict,
    unet_hrnet_config_dict,
    unet_cspresnext_config_dict,
    fpn_config_dict,
    fcn_config_dict,
    fpn_config_dict_dice_ce,
    fpn_config_dict_dice,
    fpn_config_dict_ce,
    fpn_multisize_input_config_dict_ce,
    custom_manet_config_dict,
]
for config_dict in configs_to_save:
    name_postfix = config_dict["name_postfix"]
    model_name = f'{config_dict["segmentation_nn"]["nn_architecture"]}_{config_dict["segmentation_nn"]["params"]["encoder_name"]}'
    if name_postfix is not None and len(name_postfix) != 0:
        model_name = f'{model_name}_{name_postfix}'
        
    path_to_save = os.path.join('training_configs', f'{model_name}.yaml')
    with open(path_to_save, 'w', encoding='utf-8') as fd:
        yaml.dump(config_dict, fd, indent=4)

In [68]:
name_postfix is not None or len(name_postfix) != 0

True

# Создание модели

In [72]:
#with open(r'training_configs\custom_fpn_ce_efficientnet-b0_RGB.yaml') as fd:
#    config_dict = yaml.load(fd, yaml.Loader)

#config_dict['path_to_dataset_root'] = r'C:\Users\mokhail\develop\DATA\DATA_FOR_TRAINIG_96'
#config_dict

path_to_dataset_root = config_dict['path_to_dataset_root']

path_to_dataset_info_csv = os.path.join(path_to_dataset_root, 'data_info_table.csv')
path_to_surface_classes_json = os.path.join(path_to_dataset_root, 'surface_classes.json')

input_image_size = config_dict['input_image_size']
multispecter_bands_indices = config_dict['multispecter_bands_indices']
device = config_dict['device']

# чтение списка имен классов поверхностей
with open(path_to_surface_classes_json) as fd:
    surface_classes_list = json.load(fd)
# чтение таблицы с информацией о каждом изображении в выборке
images_df = pd.read_csv(path_to_dataset_info_csv)

path_to_partition_json = os.path.join(path_to_dataset_root, 'dataset_partition.json')
# чтение словаря со списками квадратов, находящихся в обучающей и тестовой выборке
with open(path_to_partition_json) as fd:
    partition_dict = json.load(fd)

# формирование pandas DataFrame-ов с информацией об изображениях обучающей и тестовой выборках
train_images_df = []
for train_square in partition_dict['train_squares']:
    train_images_df.append(images_df[images_df['square_id']==train_square])
train_images_df = pd.concat(train_images_df, ignore_index=True)

test_images_df = []
for test_square in partition_dict['test_squares']:
    test_images_df.append(images_df[images_df['square_id']==test_square])
test_images_df = pd.concat(test_images_df, ignore_index=True)

#train_images_df, test_images_df = train_test_split(images_df, test_size=0.3, random_state=0)

class_num = images_df['class_num'].iloc[0]

# формирование словаря, отображающейго имя класса поверхности в индекс класса
class_name2idx_dict = {n:i for i, n in enumerate(surface_classes_list)}

# вычисление распределений пикселей в классах поверхностей 
classes_pixels_distribution_df = images_df[surface_classes_list]
classes_pixels_num = classes_pixels_distribution_df.sum()
classes_weights = classes_pixels_num / classes_pixels_num.sum()
classes_weights = classes_weights[surface_classes_list].to_numpy().astype(np.float32)

input_image_size = config_dict['input_image_size']
'''
train_transforms = v2.Compose(
    [v2.Resize((input_image_size,input_image_size), antialias=True),v2.ToDtype(torch.float32, scale=True)])
test_transforms = v2.Compose(
    [v2.Resize((input_image_size,input_image_size), antialias=True),v2.ToDtype(torch.float32, scale=True)])
'''
train_transforms = create_augmentation_transforms(config_dict['train_augmentations'])
test_transforms = nn.Identity()
# если ф-ция потерь перекрестная энтропия, то проверяем, есть ли там веса классов
if config_dict['loss']['type'] == 'crossentropy':
    # если в параметрах функции потерь стоит строка 'classes', надо передать в функцию вектор весов классов
    if 'weight' in config_dict['loss']['params']:
        if isinstance(config_dict['loss']['params']['weight'], (list, tuple)):
            config_dict['loss']['params']['weight'] = torch.tensor(config_dict['loss']['params']['weight'])
        
        elif config_dict['loss']['params']['weight'] is not None:
            config_dict['loss']['params']['weight'] = torch.tensor(classes_weights)

# создание функции потерь
criterion = criterion_factory_dict[config_dict['loss']['type']](**config_dict['loss']['params'])

# если ф-ция потерь перекрестная энтропия, то проверяем, есть ли там веса классов
if config_dict['loss']['type'] == 'crossentropy':
    # если в параметрах функции потерь стоит строка 'classes', надо передать в функцию вектор весов классов
    if 'weight' in config_dict['loss']['params']:
        if isinstance(config_dict['loss']['params']['weight'], torch.Tensor):
            config_dict['loss']['params']['weight'] = config_dict['loss']['params']['weight'].cpu().tolist()

model = create_model(config_dict, segmentation_nns_factory_dict)
model = model.to(device)

# создаем датасеты и даталоадеры
train_dataset = SegmentationDataset(path_to_dataset_root=path_to_dataset_root, samples_df=train_images_df, channel_indices=multispecter_bands_indices, transforms=train_transforms, dtype=torch.float32, device=device)
test_dataset = SegmentationDataset(path_to_dataset_root=path_to_dataset_root, samples_df=test_images_df, channel_indices=multispecter_bands_indices, transforms=test_transforms, dtype=torch.float32, device=device)
#train_dataset = SegmentationDatasetApplSurf(path_to_dataset_root=path_to_dataset_root, samples_df=test_images_df, channel_indices=channel_indices, name2class_idx_dict=class_name2idx_dict, applicable_surfaces_dict=applicable_surfaces_dict, transforms=test_transforms, device=device)
#test_dataset = SegmentationDatasetApplSurf(path_to_dataset_root=path_to_dataset_root, samples_df=test_images_df, channel_indices=channel_indices, name2class_idx_dict=class_name2idx_dict, applicable_surfaces_dict=applicable_surfaces_dict, transforms=test_transforms, device=device)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config_dict['batch_size'], shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=config_dict['batch_size'])

# тестовое чтение данных
for data, labels in test_loader:
    break
    pred = model(data)
    loss = criterion(pred, labels)
    loss.backward()
    
# тестовая обработка данных нейронной сетью
ret = model(data)
print(data.shape, ret.shape)

createion_time_str = datetime.now().strftime('%Y-%m-%dT%H-%M-%S')
nn_arch_str = config_dict["segmentation_nn"]["nn_architecture"]
nn_encoder_str = config_dict["segmentation_nn"]["params"]["encoder_name"]
name_postfix = config_dict["name_postfix"]
if name_postfix is not None:
    model_name = f'{nn_arch_str}_{nn_encoder_str}_{name_postfix} {createion_time_str}'
else:
    model_name = f'{nn_arch_str}_{nn_encoder_str} {createion_time_str}'
model_name

torch.Size([16, 13, 96, 96]) torch.Size([16, 11, 96, 96])


'custom_fpn_efficientnet-b0_ 2025-10-05T00-56-04'

In [54]:
model

FPNMod(
  (encoder): EfficientNetEncoder(
    (_conv_stem): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_blocks): ModuleList(
      (0): MBConvBlock(
        (_expand_conv): Identity()
        (_bn0): Identity()
        (_depthwise_conv): Conv2dStaticSamePadding(
          32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
          (static_padding): ZeroPad2d((1, 1, 1, 1))
        )
        (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
        (_se_reduce): Conv2dStaticSamePadding(
          32, 8, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          8, 32, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePadd

# Обучение

In [37]:
epoch_num = config_dict['epoch_num']

print('#############################')
print(model_name)
print('#############################')
print()

# создаем список словарей с информацией о вычисляемых метриках с помощью multiclass confusion matrix
# см. подробнее ддокументацию к функции compute_metric_from_confusion
metrics_dict = {
    'train': {
        'iou': classification.JaccardIndex(task='multiclass', average='none', num_classes=len(class_name2idx_dict)).to(device),
        'precision': classification.Precision(task='multiclass', average='none', num_classes=len(class_name2idx_dict)).to(device),
        'recall': classification.Precision(task='multiclass', average='none', num_classes=len(class_name2idx_dict)).to(device),
        'confusion': classification.ConfusionMatrix(task='multiclass', num_classes=len(class_name2idx_dict)).to(device),
    },
    'val': {
        'iou': classification.JaccardIndex(task='multiclass', average='none', num_classes=len(class_name2idx_dict)).to(device),
        'precision': classification.Precision(task='multiclass', average='none', num_classes=len(class_name2idx_dict)).to(device),
        'recall': classification.Precision(task='multiclass', average='none', num_classes=len(class_name2idx_dict)).to(device),
        'confusion': classification.ConfusionMatrix(task='multiclass', num_classes=len(class_name2idx_dict)).to(device),
    }
}

optimizer_cfg = {
    'optmizer': optimizers_factory_dict[config_dict['optimizer']['type']],
    'optimizer_args':config_dict['optimizer']['args'],
    'lr_scheduler': lr_schedulers_factory_dict[config_dict['lr_scheduler']['type']],
    'lr_scheduler_args': config_dict['lr_scheduler']['args'],
    'lr_scheduler_params': config_dict['lr_scheduler']['params']

}

# Создаем модуль Lightning
segmentation_module = LightningSegmentationModule(model, criterion, optimizer_cfg, metrics_dict, class_name2idx_dict)

# задаем путь до папки с логгерами и создаем логгер, записывающий результаты в csv
path_to_saving_dir = 'saving_dir'

#csv_logger = CSVLogger(save_dir = path_to_saving_dir, name=model_name, flush_logs_every_n_steps=1,)
csv_logger = CSVLoggerMetricsAndConfusion(save_dir = path_to_saving_dir, name=model_name, flush_logs_every_n_steps=1,)


# создаем объект, записывающий в чекпоинт лучшую модель
path_to_save_model_dir = os.path.join(path_to_saving_dir, model_name)
os.makedirs(path_to_save_model_dir, exist_ok=True)
checkpoint_callback = ModelCheckpoint(
    mode="max",
    filename=model_name+"-{epoch:02d}-{val_iou_mean:.3}",
    dirpath=path_to_save_model_dir, 
    save_top_k=1, monitor="val_iou_mean"
    )

trainer = L.Trainer(logger=[csv_logger],
        max_epochs=epoch_num, 
        callbacks=[checkpoint_callback],
        accelerator = 'gpu'
        )

# сохраняем конфигурацию
path_to_config = os.path.join(path_to_save_model_dir, 'training_config.yaml')
with open(path_to_config, 'w', encoding='utf-8') as fd:
    #json.dump(config_dict, fd, indent=4)
    yaml.dump(config_dict, fd, indent=4)

trainer.fit(segmentation_module , train_loader, test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
c:\Users\mokhail\miniconda3\envs\deep_learning\Lib\site-packages\lightning\pytorch\callbacks\model_checkpoint.py:658: Checkpoint directory C:\Users\mokhail\develop\MultispectralSegmentation\saving_dir\att_unet_efficientnet-b0_1L_win_cat_agg 2025-10-01T03-13-25 exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | model     | UnetAtt          | 7.2 M  | train
1 | criterion | CrossEntropyLoss | 0      | train
-------------------------------------------------------
7.2 M     Trainable params
0         Non-trainable params
7.2 M     Total params
28.947    Total estimated model params size (MB)
399       Modules in train mode
0         Modules in eval mode


#############################
att_unet_efficientnet-b0_1L_win_cat_agg 2025-10-01T03-13-25
#############################

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

c:\Users\mokhail\miniconda3\envs\deep_learning\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


                                                                           

c:\Users\mokhail\miniconda3\envs\deep_learning\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 1:  25%|██▍       | 105/424 [1:18:11<3:57:31,  0.02it/s, v_num=0, val_loss=2.160, val_iou_UNLABELED=0.000, val_iou_buildings_territory=0.000, val_iou_natural_ground=0.000, val_iou_natural_grow=0.000212, val_iou_natural_wetland=0.00556, val_iou_natural_wood=0.469, val_iou_quasi_natural_ground=0.000, val_iou_quasi_natural_grow=0.181, val_iou_quasi_natural_wetland=0.000, val_iou_transport=0.000, val_iou_water=0.000, val_iou_mean=0.0596, val_precision_UNLABELED=0.000, val_precision_buildings_territory=0.000, val_precision_natural_ground=0.000, val_precision_natural_grow=0.176, val_precision_natural_wetland=0.0708, val_precision_natural_wood=0.470, val_precision_quasi_natural_ground=0.000, val_precision_quasi_natural_grow=0.499, val_precision_quasi_natural_wetland=0.000, val_precision_transport=0.000, val_precision_water=0.000, val_precision_mean=0.111, val_recall_UNLABELED=0.000, val_recall_buildings_territory=0.000, val_recall_natural_ground=0.000, val_recall_natural_grow=0.176, val


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

# Черновики

In [None]:
# Список трансформеров, входные слои которых можно безопасно изменять 
[
    'mit', # (7,7) Ньюансы с FPN
    'tu-davit', # (7,7) Dual-Attention vision transformer + 
    'tu-efficientvit_b0',#(3,3) +
    'tu-fastvit', #(3,3)
    'tu-hgnet',#(3,3)
    'tu-hgnetv2',#(3,3)
    'tu-mambaout',#(3,3)
    'tu-maxvit',#(3,3)
    'tu-mvitv2', #(7,7) +
    'tu-nextvit',#(3,3)
    'tu-poolformer',#(7,7)
    'tu-poolformerv2',#(7,7)
    'tu-pvt_v2',#(7,7)
    'tu-repvit',#(3,3)
    'tu-sam2_hiera',#(7,7)
    'tu-tiny_vit',#(3,3)
]

In [None]:
model = smp.MAnet()
model.encoder.conv1.stride=(1,1)

def custom_forward(self, x):
    """Sequentially pass `x` trough model`s encoder, decoder and heads"""
    features = self.encoder(x)
    '''
    for f in features:
        print(f.shape)
    print()
    '''

    decoder_output = self.decoder(features)
    #print(decoder_output.shape)

    masks = self.segmentation_head(decoder_output)

    if self.classification_head is not None:
        labels = self.classification_head(features[-1])
        return masks, labels

    return masks


def decoder_custom_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
    features = features[1:]  # remove first skip with same spatial resolution
    features = features[::-1]  # reverse channels to start from head of encoder

    head = features[0]
    skips = features[1:]
    print('Features shapes')
    print(f'head:{head.shape}')
    print('Skips:')
    for skip in skips:
        print(skip.shape)

    print('----------------------')

    x = self.center(head)
    print(f'Center:{x.shape}')

    for i, decoder_block in enumerate(self.blocks):
        skip = skips[i] if i < len(skips) else None
        x = decoder_block(x, skip)
        if skip is not None:
            print(f'x:{x.shape}; skip:{skip.shape}')
        else:
            print(f'x:{x.shape}; skip:{skip}')

    return x

model.forward = types.MethodType(custom_forward, model)
model.decoder.forward = types.MethodType(decoder_custom_forward, model.decoder)

ret = model(torch.randn(1, 3, 96, 96))
ret.shape

In [None]:
idx_in_batch = 1
#transform = v2.RandomPerspective(distortion_scale=0.3, p=1.0, fill={tv_tensors.Image:0.0, tv_tensors.Mask:0})
transform = v2.RandomAffine(degrees=[0, 30], translate=[0.0, 0.3], scale=[0.3, 0.5], shear=[0.0, 0.4], fill=[0])
#transform = v2.RandomResize(min_size=96, max_size=256)
#transform = v2.RandomRotation(degrees=(0, 45))

transforms_factory_dict = {
    'affine': v2.RandomAffine,
    'perspective': v2.RandomPerspective,
    'horizontal_flip': v2.RandomHorizontalFlip,
    'vertical_flip': v2.RandomVerticalFlip,
    'crop': v2.RandomCrop,
    'gauss_noise': v2.GaussianNoise,
    'gauss_blur': v2.GaussianBlur,
    'elastic': v2.ElasticTransform,
}

def create_transforms(transforms_dict:Dict[str, Dict]):
    transforms_list = []
    for name, transform_params in transforms_dict.items():
        transform_creation_fn = transforms_factory_dict[name]
        transforms_list.append(transform_creation_fn(**transform_params))
    #return v2.Compose([v2.RandomOrder(transforms_list)])
    return v2.RandomOrder(transforms_list)
        
transforms_dict = {
        'gauss_noise':{
            'mean': 0.0,
            'sigma': 0.0008,
            'clip': False,
        },
        'affine':{
            'degrees': [0, 45],
            'translate': [0, 0.3],
            'scale': [0.7, 1.5],
            'shear': [0, 0.2],
            'fill': [0],
        },
        'perspective':{
            'distortion_scale': 0.2,
            'p': 0.3,
            'fill': [0],
        },
        'horizontal_flip': {
            'p': 0.5,
        },
        'vertical_flip':{
            'p': 0.5,
        },
}

transforms = create_transforms(transforms_dict)
transforms

mask = tv_tensors.Mask(labels[idx_in_batch])

to_transform = {'image':data[idx_in_batch], 'mask':mask}

out = transforms(to_transform)
img_tr, mask_tr = out['image'].detach().cpu(), out['mask'].detach().cpu()

img = data[idx_in_batch].detach().cpu()
mask = labels[idx_in_batch].detach().cpu()

fig, axs = plt.subplots(2,2)
axs[0,0].imshow(img[idx_in_batch])
axs[0,1].imshow(mask)
axs[1,0].imshow(img_tr[idx_in_batch])
axs[1,1].imshow(mask_tr)
plt.show()

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

model = smp.Unet(encoder_name='resnet34')
#input_weight = model.encoder._conv_stem.weight.detach()#.numpy()
input_weight = model.encoder.conv1.weight.detach()#.numpy()
interpolated_weight = F.interpolate(input_weight, size=(11, 11), mode='bicubic', antialias=True, align_corners=False)

print(input_weight.shape)
print(interpolated_weight.shape)

filter_index = 9
conv_filter = input_weight[filter_index]
interpolated_conv_filter = interpolated_weight[filter_index]
fig1, axs1 = plt.subplots(1,3)
for idx, img in enumerate(conv_filter):
    axs1[idx].imshow(img)

fig2, axs2 = plt.subplots(1,3)
for idx, img in enumerate(interpolated_conv_filter):
    axs2[idx].imshow(img)


filter_index += 1
conv_filter = input_weight[filter_index]
interpolated_conv_filter = interpolated_weight[filter_index]
fig3, axs3 = plt.subplots(1,3)
for idx, img in enumerate(conv_filter):
    axs3[idx].imshow(img)

fig4, axs4 = plt.subplots(1,3)
for idx, img in enumerate(interpolated_conv_filter):
    axs4[idx].imshow(img)


In [123]:
indices = [1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 'ndvi', 'ndbi', 'ndwi', 'ndre']
basic_indices = [1, 2, 3, 7]
rest_indices = set(indices) - set(basic_indices)
rest_indices = list(rest_indices)

rest_indices = [x for x in rest_indices if isinstance(x, int)] + [x for x in rest_indices if isinstance(x, str)]
#for combination in combinations()
for k in range(len(rest_indices)):
    k+=1
    for combination_of_indices in combinations(rest_indices, k):
        indices_to_test = basic_indices + list(combination_of_indices)
        config_dict['multispecter_bands_indices'] = indices_to_test

In [None]:
class PABBlock(nn.Module):
    def __init__(self, in_channels: int, pab_channels: int = 64):
        super().__init__()

        # Series of 1x1 conv to generate attention feature maps
        self.pab_channels = pab_channels
        self.in_channels = in_channels
        self.top_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
        self.center_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
        self.bottom_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.map_softmax = nn.Softmax(dim=1)
        self.out_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, _, height, width = x.shape

        x_top = self.top_conv(x)
        x_center = self.center_conv(x)
        x_bottom = self.bottom_conv(x)

        x_top = x_top.flatten(2)
        x_center = x_center.flatten(2).transpose(1, 2)
        x_bottom = x_bottom.flatten(2).transpose(1, 2)

        sp_map = torch.matmul(x_center, x_top)
        sp_map = self.map_softmax(sp_map.view(batch_size, -1))
        sp_map = sp_map.view(batch_size, height * width, height * width)

        sp_map = torch.matmul(sp_map, x_bottom)
        sp_map = sp_map.reshape(batch_size, self.in_channels, height, width)

        x = x + sp_map
        x = self.out_conv(x)
        return x
    

block = PABBlock(in_channels=64, pab_channels=64)


pab_channels = 256
in_channels = 128
top_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
center_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
bottom_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
map_softmax = nn.Softmax(dim=1)
out_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

x = torch.randn(1, in_channels, 12, 12)
batch_size, _, height, width = x.shape

x_top = top_conv(x)
x_center = center_conv(x)
x_bottom = bottom_conv(x)
print('After conv:')
print(f'x_top={x_top.shape}, x_center={x_center.shape}, x_bottom={x_bottom.shape},')

x_top = x_top.flatten(2)
x_center = x_center.flatten(2).transpose(1, 2)
x_bottom = x_bottom.flatten(2).transpose(1, 2)
print('After reshape and transpose:')
print(f'x_top_r={x_top.shape}, x_center_rt={x_center.shape}, x_bottom_rt={x_bottom.shape},')

sp_map = torch.matmul(x_center, x_top)
print(f'sp_map={sp_map.shape} (x_center_rt × x_top_r)')
sp_map = map_softmax(sp_map.view(batch_size, -1))
print(f'sp_map after softmax={sp_map.shape}')
sp_map = sp_map.view(batch_size, height * width, height * width)
print(f'sp_map after reshape={sp_map.shape}')
sp_map = torch.matmul(sp_map, x_bottom)
print(f'sp_map × x_bottom = {sp_map.shape}')
sp_map = sp_map.reshape(batch_size, in_channels, height, width)
print(f'sp_map after reshape = {sp_map.shape}')

[4, 5, 6, 8, 11, 12, 'ndre', 'ndvi', 'ndwi', 'ndbi']