In [35]:
import torch
from torch import nn

import glob
import os
from tqdm import tqdm
from datetime import datetime
import json

from argparse import ArgumentParser

from itertools import combinations

import torchvision
from torchvision.transforms import v2
from torchvision import tv_tensors
from torchvision import models

import segmentation_models_pytorch as smp

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger

from torchmetrics import classification
from torchmetrics import segmentation

from sklearn.model_selection import train_test_split
from sklearn import metrics

import numpy as np

import pandas as pd

In [18]:
confusion = classification.ConfusionMatrix(task='multiclass', num_classes=5)
target = torch.tensor([2, 1, 0, 0])
preds = torch.tensor([2, 1, 0, 1])
confusion.update(preds, target)

target = torch.tensor([3, 4, 2, 1])
preds = torch.tensor([4, 3, 0, 1])
confusion.update(preds, target)

confusion.compute()


tensor([[1, 1, 0, 0, 0],
        [0, 2, 0, 0, 0],
        [1, 0, 1, 0, 0],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 1, 0]])

In [16]:

torchmetrics.segmentation

AttributeError: module 'torchmetrics' has no attribute 'segmentation'

In [31]:
iou = segmentation.DiceScore(average='none', num_classes=5, input_format='index')

target = torch.tensor([[2, 1, 0, 1]])
preds = torch.tensor([[2, 1, 0, 1]])
iou.update(preds, target)

target = torch.tensor([[3, 4, 2, 1]])
preds = torch.tensor([[4, 3, 0, 1]])
#iou.update(preds, target)
iou.compute()

tensor([1., 1., 1., 0., 0.])

In [32]:
target.reshape(-1)

tensor([3, 4, 2, 1])

In [None]:
metrics_dict = {
    'train': {
        'iou': classification.JaccardIndex(task='multiclass', average='none', num_classes=5),
        'precision': classification.Precision(task='multiclass', average='none', num_classes=5),
        'recall': classification.Precision(task='multiclass', average='none', num_classes=5),
    },
    'test': {
        'iou': classification.JaccardIndex(task='multiclass', average='none', num_classes=5),
        'precision': classification.Precision(task='multiclass', average='none', num_classes=5),
        'recall': classification.Precision(task='multiclass', average='none', num_classes=5),
    }
}


# Работа с данными

In [None]:
# datasets
class SegmentationDataset(torch.utils.data.Dataset):
    def __init__(self, path_to_dataset_root:str, samples_df:pd.DataFrame, channel_indices:list, transforms:v2._transform.Transform, device:torch.device):
        '''
        In:
            path_to_dataset_root - путь до корневой папки с датасетом
            samples_df - pandas.DataFrame с информацией о файлах
            channel_indices - список с номерами каналов мультиспектрального изображения
            transforms - аугментация изображений
        '''
        super().__init__()
        self.path_to_dataset_root = path_to_dataset_root
        self.samples_df = samples_df
        self.channel_indices = channel_indices
        self.transforms = transforms
        self.device = device

    def __len__(self):
        return len(self.samples_df)

    def __getitem__(self, idx):
        sample = self.samples_df.iloc[idx]

        file_name = sample['file_name']

        path_to_image = os.path.join(self.path_to_dataset_root, 'images', f'{file_name}.npy')
        path_to_labels = os.path.join(self.path_to_dataset_root, 'labels', f'{file_name}.npy')

        image = torch.as_tensor(np.load(path_to_image), dtype=torch.int16)[self.channel_indices]
        #image = np.load(path_to_image)
        # метки читаем как одноканальное изображение
        label = np.load(path_to_labels)
        label = np.where(label >= 0, label, 0)
        #label = torch.as_tensor(np.load(path_to_labels), dtype=torch.uint8).long()
        label = torch.as_tensor(label, dtype=torch.uint8).long()
        
        
        image = tv_tensors.Image(image, device=self.device)
        label = tv_tensors.Mask(label, device=self.device)

        transforms_dict = {'image':image, 'mask':label}
        transformed = self.transforms(transforms_dict)
        return transformed['image'], transformed['mask']#, image

# Описание нейронных сетей

In [33]:
def compute_pred_mask(pred):
    '''
    Определение маски классов на основе сгенерированной softmax маски
    '''
    #pred = pred.detach()
    _, pred_mask = pred.max(dim=1)
    return pred_mask#.cpu().numpy()

class SegmentationModule(L.LightningModule):
    def __init__(self, model:nn.Module, criterion:nn.Module, optimizer_cfg:dict, metrics_dict:dict, name2class_idx_dict:dict) -> None:
        '''
        Модуль Lightning для обучения сегментационной сети
        In:
            model - нейронная сеть
            criterion - функция потерь
            
            name2class_idx_dict - словарь с отображением {class_name(str): class_idx(int)}
        '''
        super().__init__()
        self.model = model
        self.criterion = criterion
        self.optimizer_cfg = optimizer_cfg
        self.metrics_dict = metrics_dict
        
        self.name2class_idx_dict = name2class_idx_dict
        # словарь, выполняющий обратное отображение class_idx в class_name
        self.class_idx2name_dict = {v:k for k, v in name2class_idx_dict.items()}
        
        
    def configure_optimizers(self):
        optimizer = self.optimizer_cfg['optmizer'](self.parameters, **self.optimizer_cfg['optimizer_params'])
        ret_dict = {'optimizer': optimizer}
        if self.optimizer_cfg['scheduler'] is not None:
            scheduler = self.optimizer_cfg['scheduler'](optimizer, **self.optimizer_cfg['scheduler_params'])
            ret_dict['lr_scheduler'] = scheduler
        
        return ret_dict

    
    def compute_metrics(self, pred_labels, true_labels, mode):
        metrics_names_list = self.metrics_dict[mode].keys()
        for metric_name in metrics_names_list:
            if 'dice' in metric_name.lower():
                self.metrics_dict[mode][metric_name].update(pred_labels, true_labels)
            else:
                self.metrics_dict[mode][metric_name].update(pred_labels.reshape(-1), true_labels.reshape(-1))
        
    
    def training_step(self, batch, batch_idx):
        data, true_labels = batch
        pred = self.model(data)
        loss = self.criterion(pred, true_labels)
        # вычисление сгенерированной маски
        pred_labels = compute_pred_mask(pred)
        #true_labels = true_labels.detach().cpu().numpy()
        
        self.compute_metrics(pred_labels=pred_labels, true_labels=true_labels, mode='train')

        # т.к. мы вычисляем общую ошибку на всей эпохе, то записываем в лог только значение функции потерь
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        data, true_labels = batch
        pred = self.model(data)
        loss = self.criterion(pred, true_labels)
        pred_labels = compute_pred_mask(pred)
        self.compute_metrics(pred_labels=pred_labels, true_labels=true_labels, mode='val')
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss
    
    def log_metrics(self, mode):
        for metric_name, metric in self.metrics_dict[mode]:
            metric_val = metric.compute()
            if 'confusion' not in metric_name.lower():
                for i, value in enumerate(metric_val):
                    class_name = self.class_idx2name_dict[i]
                    disp_name = f'{mode}_{metric_name}_{class_name}'
                    self.log(disp_name, value, on_step=False, on_epoch=True, prog_bar=True)
                disp_name = f'{mode}_{metric_name}_mean'
                self.log(disp_name, value.mean(), on_step=False, on_epoch=True, prog_bar=True)
            else:
                disp_name = f'{mode}_{metric_name}'
                self.log(disp_name, value, on_step=False, on_epoch=True, prog_bar=True)
            self.metrics_dict[mode][metric_name].reset()


    def on_train_epoch_end(self):
        '''
        Декодирование результатов тренировочной эпохи и запись их в лог
        '''
        self.log_metrics(mode='train')

        
    def on_validation_epoch_end(self):
        '''
        Декодирование результатов тестовой эпохи и запись их в лог
        (работает точно также, как и )
        '''
        self.log_metrics(mode='val')

# Фабрики для создания моделей по конфигурациям

In [None]:
segmentation_nns_factory_dict = {
    'unet': smp.Unet
}
config_dict = {
    'segmentation_nn': {
        'nn_architecture': 'unet',
        'params': {
            'encoder_name': "efficientnet-b0",
            'encoder_depth': 5,
            'encoder_weights': "imagenet",
            'decoder_use_norm': "batchnorm",
            'decoder_channels': (256, 128, 64, 32, 16),
            'decoder_attention_type': None,
            'decoder_interpolation': "nearest",
            'in_channels': 3,
            'classes': 11,
            'activation': None,
            'aux_params': None,
        },
        'input_layer_params': {
            'layer_path': 'encoder._conv_stem',
            'stride': (1, 1),
            'padding': (1, 1),
            'layers_num': 1,
        }
    },
    'multispecter_bands_indices': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    'input_image_size': 112,
}

model_name = config_dict['segmentation_nn']['nn_architecture']
model = segmentation_nns_factory_dict[model_name](**config_dict['segmentation_nn']['params'])
multispecter_bands_indices = config_dict['multispecter_bands_indices']
in_channels = len(multispecter_bands_indices)
# замена входного слоя, если кол-во каналов изображения не равно трем
if in_channels != 3:
    input_conv = model.get_submodule(
        config_dict['segmentation_nn']['input_layer_params']['layer_path']
        )
    new_input_conv = nn.Conv2d(
        in_channels=in_channels,
        out_channels=input_conv.out_channels,
        kernel_size=input_conv.kernel_size,
        #stride=conv1.stride,
        stride=config_dict['segmentation_nn']['input_layer_params']['stride'],
        #padding=conv1.padding,
        padding=config_dict['segmentation_nn']['input_layer_params']['padding'],
        dilation=input_conv.dilation,
        groups=input_conv.groups,
        bias=input_conv.bias is not None
    )
    if config_dict['segmentation_nn']['params']['encoder_weights'] is not None:
        new_weight = torch.cat([input_conv.weight.mean(dim=1).unsqueeze(1)]*in_channels, dim=1)
        input_conv.weight = nn.Parameter(new_weight)
        if input_conv.bias is not None:
            new_input_conv.bias = input_conv.bias
    # перезаписываем входной слой
    model.set_submodule(
        config_dict['segmentation_nn']['input_layer_params']['layer_path'],
        new_input_conv
        )

input_image_size = config_dict['input_image_size']
train_transforms = v2.Compose(
    [v2.Resize((input_image_size,input_image_size), antialias=True),v2.ToDtype(torch.float32, scale=True)])
test_transforms = v2.Compose(
    [v2.Resize((input_image_size,input_image_size), antialias=True),v2.ToDtype(torch.float32, scale=True)])
    


model

Unet(
  (encoder): EfficientNetEncoder(
    (_conv_stem): Conv2d(13, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_blocks): ModuleList(
      (0): MBConvBlock(
        (_expand_conv): Identity()
        (_bn0): Identity()
        (_depthwise_conv): Conv2dStaticSamePadding(
          32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
          (static_padding): ZeroPad2d((1, 1, 1, 1))
        )
        (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
        (_se_reduce): Conv2dStaticSamePadding(
          32, 8, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          8, 32, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePaddi

In [None]:
model.

TypeError: 'Unet' object is not subscriptable

In [47]:
model.get_submodule('encoder._conv_stem')
model.set_submodule('encoder._conv_stem', nn.Conv2d(1, 3, 3))
model

Unet(
  (encoder): EfficientNetEncoder(
    (_conv_stem): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))
    (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_blocks): ModuleList(
      (0): MBConvBlock(
        (_expand_conv): Identity()
        (_bn0): Identity()
        (_depthwise_conv): Conv2dStaticSamePadding(
          32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
          (static_padding): ZeroPad2d((1, 1, 1, 1))
        )
        (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
        (_se_reduce): Conv2dStaticSamePadding(
          32, 8, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          8, 32, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePadding(
          32, 16, kernel_s

In [None]:
for m in model.encoder.get_submodule():
    print(m)

('', EfficientNetEncoder(
  (_conv_stem): Conv2dStaticSamePadding(
    3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
    (static_padding): ZeroPad2d((0, 1, 0, 1))
  )
  (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
  (_blocks): ModuleList(
    (0): MBConvBlock(
      (_expand_conv): Identity()
      (_bn0): Identity()
      (_depthwise_conv): Conv2dStaticSamePadding(
        32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
        (static_padding): ZeroPad2d((1, 1, 1, 1))
      )
      (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
      (_se_reduce): Conv2dStaticSamePadding(
        32, 8, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_se_expand): Conv2dStaticSamePadding(
        8, 32, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_project_conv): Conv2dStaticSamePaddi

In [None]:
model.encoder.set_submodule

In [None]:
def prepare_smp_model(
        config_dict,
        nn_in_cannels_num,
        class_num,
        channel_indices_list,
        is_multitask,
        preprocess_params,
        fuze_params):
    '''
    Функция генерирует конфигурацию сегментационной нейронной сети из библиотеки segmentation_models_pytorch
    In:
        model_dict - словарь с параметрами модели. Структура словаря {'model_name': строковое_имя_модели, 'creation_function': функция_создания_модели (класс модели), 'weights': веса модели}
        image_channels:int - количество каналов изображения
        nn_in_cannels_num:int - количество входных каналов нейронной сети
        class_num:int - количество классов
        channel_indices_list:list - список индексов каналов обрабатываемого мультиспектрального изображения
        is_multitask:bool - флаг, обозначающий мнгогозадачное обучение
        preprocess_params:dict - параметры предварительной обработки каналов мультиспектра. Словарь структуры {'type':'1L'}, на месте '1L' мб
            'no' - отсутствие доп. блока предобработки,
            1L - однослойная модель предобработки
            2L - двухслойная модель предобработки
            SpInd - SpatialIndex??? (надо подробнее рассмотреть это потом)
        fuze_params:dict - параметры связывания выхода нейронной сети и блока преобработки. Словарь структуры {'type':'no'}, на месте 'no' мб
            shuffle - модуль channel_shuffle (см. https://arxiv.org/abs/1707.01083)
            concat - конкатенация
            add - сложение
    Out:
        dict с моделью следующей структуры {'name': model_name, 'model': model, 'train_transforms':train_transforms, 'test_transforms':test_transforms}
    '''
    model_creation_unction = model_dict['creation_function']
    encoder_name = model_dict['encoder_name']
    model_name = model_dict['model_name']
    #model = model_creation_unction(encoder_name=encoder_name, encoder_weights='imagenet', encoder_depth=5, decoder_channels=(256, 128, 64, 64, 64), in_channels=nn_in_cannels_num, classes=class_num)
    model = model_creation_unction(encoder_name=encoder_name, encoder_weights=None, in_channels=nn_in_cannels_num, classes=class_num)
    #model = model_creation_unction(encoder_name=encoder_name, in_channels=nn_in_cannels_num, classes=class_num, img_size=112)

    #print(model)
    
    #model = smp.Unet(encoder_name='resnet50', classes=class_num)
    #print(model.encoder)
    #conv1 = model.encoder.model.conv1
    #conv1 = model.encoder.model.conv_stem
    conv1 = model.encoder._conv_stem
    #conv1 = model.encoder.model.patch_embed.proj
    
    '''
    weights = conv1.weight
    new_weight = torch.cat([weights.mean(dim=1).unsqueeze(1)]*nn_in_cannels_num, dim=1)
    new_conv1 = nn.Conv2d(
        in_channels=nn_in_cannels_num,
        out_channels=conv1.out_channels,
        kernel_size=conv1.kernel_size,
        #stride=conv1.stride,
        stride=(1,1),
        padding=conv1.padding,
        dilation=conv1.dilation,
        groups=conv1.groups,
        bias=conv1.bias is not None
    )
    new_conv1.weight = nn.Parameter(new_weight)
    if conv1.bias is not None:
        #new_conv1.bias = model.encoder.conv1.bias
        #new_conv1.bias = model.encoder.model.patch_embed.proj.bias
        new_conv1.bias = model.encoder.model.conv_stem.bias
    # заменяем веса, если количнество входных каналов не равно трем 
    if nn_in_cannels_num != 3:
        #model.encoder.conv1 = new_conv1
        #model.encoder.model.conv_stem = new_conv1
        model.encoder._conv_stem = new_conv1
        #model.encoder.model.patch_embed.proj = new_conv1
    '''
    # делаем два входных слоя для более глубокой обработки низкоуровневых призников
    conv1 = model.encoder._conv_stem
    weights = conv1.weight
    new_weight1 = torch.cat([weights.mean(dim=1).unsqueeze(1)]*nn_in_cannels_num, dim=1)
    new_weight2 = torch.cat([weights.mean(dim=1).unsqueeze(1)]*conv1.out_channels, dim=1)
    new_conv1 = nn.Conv2d(
        in_channels=nn_in_cannels_num,
        out_channels=conv1.out_channels,
        kernel_size=conv1.kernel_size,
        #stride=conv1.stride,
        stride=(1,1),
        #padding=conv1.padding,
        padding=(1, 1),
        dilation=conv1.dilation,
        groups=conv1.groups,
        bias=conv1.bias is not None
    )

    new_conv2 = nn.Conv2d(
        in_channels=conv1.out_channels,
        out_channels=conv1.out_channels,
        kernel_size=conv1.kernel_size,
        #stride=conv1.stride,
        stride=(1,1),
        padding=(1,1),
        dilation=conv1.dilation,
        groups=conv1.groups,
        bias=conv1.bias is not None
    )
        
    new_conv1.weight = nn.Parameter(new_weight1)
    new_conv2.weight = nn.Parameter(new_weight2)
    if conv1.bias is not None:
        #new_conv1.bias = model.encoder.conv1.bias
        #new_conv1.bias = model.encoder.model.patch_embed.proj.bias
        new_conv1.bias = model.encoder.model.conv_stem.bias
        new_conv2.bias = model.encoder.model.conv_stem.bias

    new_conv = nn.Sequential(
        new_conv1,
        nn.BatchNorm2d(conv1.out_channels),
        nn.SiLU(),
        new_conv2,
    )
    # заменяем веса, если количнество входных каналов не равно трем 
    if nn_in_cannels_num != 3:
        #model.encoder.conv1 = new_conv1
        #model.encoder.model.conv_stem = new_conv1
        model.encoder._conv_stem = new_conv
        #model.encoder.model.patch_embed.proj = new_conv1

    if is_multitask:
        model = MultitaskModel(model, nn_output_size=512, appl_class_num=2, surf_class_num=class_num)

    if preprocess_params['type'] == 'no':
        preprocess_layer = nn.Identity()
    elif preprocess_params['type'] == '1L':
        preprocess_layer = nn.Sequential(
            nn.Conv2d(in_channels=image_channels, out_channels=nn_in_cannels_num, kernel_size=1),
            nn.BatchNorm2d(nn_in_cannels_num)
        )
    elif preprocess_params['type'] == '2L':
        preprocess_layer = nn.Sequential(
            nn.Conv2d(in_channels=image_channels, out_channels=32, kernel_size=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=nn_in_cannels_num, kernel_size=1),
            nn.BatchNorm2d(nn_in_cannels_num),
            nn.ReLU())
    elif preprocess_params['type'] == 'SpInd':
        preprocess_layer = SpectralDiffIndexModule(channel_indices_list=channel_indices_list, channels_in_index=2, out_channels=8)
    
    if fuze_params['type'] == 'no':
        model = MultispectralNN(model, preprocess_layer)
    else:
        model = MultispectralFuseOut(model, preprocess_layer, preprocessing_out_dim=nn_in_cannels_num, fusion_type=fuze_params['type'], class_num=class_num)
    
    #train_transforms = v2.Compose([v2.Resize((160,160), antialias=True),v2.ToDtype(torch.float32, scale=True)])
    #test_transforms = v2.Compose([v2.Resize((160,160), antialias=True),v2.ToDtype(torch.float32, scale=True)])
    train_transforms = v2.Compose([v2.Resize((112,112), antialias=True),v2.ToDtype(torch.float32, scale=True)])
    test_transforms = v2.Compose([v2.Resize((112,112), antialias=True),v2.ToDtype(torch.float32, scale=True)])
    
    multitask_str = '_MT' if is_multitask else ''
    
    model_name = f'{model_name}pr'
    #model_name = f'{model_name}'
    if preprocess_params['type'] != 'no':
        model_name += f'-P{preprocess_params["type"]}'
    if fuze_params['type'] != 'no':
        model_name += f'-Fuz{fuze_params["type"].capitalize()}'
    #model_name = f'{model_name}pr-P2L-FuzOutAdd({nn_in_cannels_num})' + multitask_str
    return {'name': model_name, 'model': model, 'train_transforms':train_transforms, 'test_transforms':test_transforms}
    