In [45]:
import torch
from torch import nn
import skorch
import glob
import os
from tqdm import tqdm

import json

import torchvision
from torchvision.transforms import v2
from torchvision import tv_tensors
from torchvision import models

import lightning as L
from pytorch_lightning.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger

from sklearn.model_selection import train_test_split
from sklearn import metrics

import numpy as np

import pandas as pd

In [30]:
# datasets
class SegmentationDataset(torch.utils.data.Dataset):
    def __init__(self, path_to_dataset_root, samples_df, transforms, device):
        '''
        path_to_dataset - путь до корневой папки с датасетом
        instance_names_list - список имен экземпляров БЕЗ РАСШИРЕНИЯ!
        transforms - аугментация изображений
        '''
        super().__init__()
        self.path_to_dataset_root = path_to_dataset_root
        self.samples_df = samples_df
        self.transforms = transforms
        self.device = device

    def __len__(self):
        return len(self.samples_df)

    def __getitem__(self, idx):
        sample = self.samples_df.iloc[idx]

        file_name = sample['file_name']

        path_to_image = os.path.join(self.path_to_dataset_root, 'images', f'{file_name}.npy')
        path_to_labels = os.path.join(self.path_to_dataset_root, 'labels', f'{file_name}.npy')

        image = torch.as_tensor(np.load(path_to_image))
        image = np.load(path_to_image)
        # метки читаем как одноканальное изображение
        label = torch.as_tensor(np.load(path_to_labels)).long()
        
        
        image = tv_tensors.Image(image, device=self.device)
        label = tv_tensors.Mask(label, device=self.device)

        transforms_dict = {'image':image, 'mask':label}
        transformed = self.transforms(transforms_dict)
        return transformed['image'], transformed['mask']#, image
    
class SegmentationWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    def forward(self, x):
        return self.model(x)['out']


In [39]:
y_true = ["cat", "cat", "cat", "bird", "bird"]

y_pred = ["cat", "cat", "cat", "bird", "cat"]

mcm = metrics.multilabel_confusion_matrix(y_true, y_pred,
                            labels=["ant", "bird", "cat"])
print(mcm)
y_true = ["ant"]
y_pred = ["ant"]
mcm += metrics.multilabel_confusion_matrix(y_true, y_pred,
                            labels=["ant", "bird", "cat"])
mcm

[[[5 0]
  [0 0]]

 [[3 0]
  [1 1]]

 [[1 1]
  [0 3]]]


array([[[5, 0],
        [0, 1]],

       [[4, 0],
        [1, 1]],

       [[2, 1],
        [0, 3]]], dtype=int64)

In [42]:
mcm.sum(axis=0)

array([[11,  1],
       [ 1,  5]], dtype=int64)

In [44]:
path_to_dataset_info_csv = r'i:\LANDCOVER_DATA\MULTISPECTRAL_SATELLITE_DATA\DATA_FOR_TRAINIG\data_info_table.csv'
path_to_dataset_root = r'I:\LANDCOVER_DATA\MULTISPECTRAL_SATELLITE_DATA\DATA_FOR_TRAINIG'
images_df = pd.read_csv(path_to_dataset_info_csv)
train_images_df, test_images_df = train_test_split(images_df, test_size=0.3, random_state=0)

class_num = images_df['class_num'].iloc[0]

with open('surface_classes.json') as fd:
    surface_classes_list = json.load(fd)

class_name2idx_dict = {n:i for i, n in enumerate(surface_classes_list)}

classes_pixels_distribution_df = images_df[surface_classes_list]
classes_pixels_num = classes_pixels_distribution_df.sum()
classes_weights = classes_pixels_num / classes_pixels_num.sum()


transforms = v2.Compose([v2.ToDtype(torch.float32, scale=True)])
channels_num = 13

device = torch.device('cuda:0')

model = models.segmentation.fcn_resnet50(weights=models.segmentation.FCN_ResNet50_Weights.DEFAULT)

# заменяем входной слой
conv1 = model.backbone.conv1
stride = conv1.stride
kernel_size = conv1.kernel_size
out_channels = conv1.out_channels
groups = conv1.groups
padding = conv1.padding
dilation = conv1.dilation
is_bias = model.backbone.conv1.bias is not None

weights = conv1.weight
new_weight = torch.cat([model.backbone.conv1.weight.mean(dim=1).unsqueeze(1)]*channels_num, dim=1)
new_conv1 = nn.Conv2d(
    in_channels=conv1.in_channels,
    out_channels=conv1.out_channels,
    kernel_size=conv1.kernel_size,
    stride=conv1.stride,
    padding=conv1.padding,
    dilation=conv1.dilation,
    groups=conv1.groups,
    bias=conv1.bias is not None
)
new_conv1.weight = nn.Parameter(new_weight)
if conv1.bias is not None:
    new_conv1.bias = model.backbone.conv1.bias
model.backbone.conv1 = new_conv1

# заменяем выходные слои
classifier_conv = model.classifier[-1]
new_classifier_conv = nn.Conv2d(
    in_channels=classifier_conv.in_channels,
    out_channels=class_num,
    kernel_size=classifier_conv.kernel_size,
    stride=classifier_conv.kernel_size,
    padding=classifier_conv.padding,
    dilation=classifier_conv.dilation,
    groups=classifier_conv.groups,
    bias=classifier_conv.bias is not None,
    )
model.classifier[-1] = new_classifier_conv
aux_classifier_conv = model.aux_classifier[-1]
new_aux_classifier_conv = nn.Conv2d(
    in_channels=aux_classifier_conv.in_channels,
    out_channels=class_num,
    kernel_size=aux_classifier_conv.kernel_size,
    stride=aux_classifier_conv.kernel_size,
    padding=aux_classifier_conv.padding,
    dilation=aux_classifier_conv.dilation,
    groups=aux_classifier_conv.groups,
    bias=aux_classifier_conv.bias is not None,
    )
model.aux_classifier[-1] = new_aux_classifier_conv

model = model.to(device)

criterion = nn.CrossEntropyLoss()

model = SegmentationWrapper(model)



train_dataset = SegmentationDataset(path_to_dataset_root=path_to_dataset_root, samples_df=train_images_df, transforms=transforms, device=device)
test_dataset = SegmentationDataset(path_to_dataset_root=path_to_dataset_root, samples_df=test_images_df, transforms=transforms, device=device)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4)
for data, labels in tqdm(test_loader):
    pass
out = model(data)
criterion(out, labels)

100%|██████████| 135/135 [00:01<00:00, 105.11it/s]


tensor(2.4646, device='cuda:0', grad_fn=<NllLoss2DBackward0>)

In [36]:
class_name2idx_dict
{v:k for k, v in class_name2idx_dict.items()}

{0: 'buildings_territory',
 1: 'natural_ground',
 2: 'natural_grow',
 3: 'natural_wetland',
 4: 'natural_wood',
 5: 'quasi_natural_ground',
 6: 'quasi_natural_grow',
 7: 'quasi_natural_wetland',
 8: 'quasi_natural_wood',
 9: 'transport',
 10: 'water',
 11: 'UNLABELED'}

In [26]:
_, indices = out.max(dim=1)
indices.shape

torch.Size([4, 175, 175])

In [75]:
def decode_confusion_matrix_2x2(confusion_matrix):
    tp = confusion_matrix[1, 1]
    tn = confusion_matrix[0, 0]
    fp = confusion_matrix[0, 1]
    fn = confusion_matrix[1, 0]
    return tp, tn, fp, fn

def compute_accuracy_from_confusion(multiclass_confusion_matrix):
    confusion_sum = multiclass_confusion_matrix.sum(axis=0)
    tp, tn, fp, fn = decode_confusion_matrix_2x2(confusion_sum)
    accuracy = 0
    if tp+tn+fp+fn != 0:
        accuracy = (tp+tn)/(tp+tn+fp+fn)
    return accuracy

def compute_iou_from_confusion(multiclass_confusion_matrix, idx2class_name_dict=None):
    #print(f'conf_shape={multiclass_confusion_matrix.shape}')
    mean_iou = 0
    # {class_name: iou_val}
    iou_dict = {}
    actual_classes_num = 0
    for idx, class_confusion in enumerate(multiclass_confusion_matrix):
        #print(f'class_conf_shape={class_confusion.shape}')
        tp, tn, fp, fn = decode_confusion_matrix_2x2(class_confusion)
        if class_confusion.sum() != tn:
            actual_classes_num += 1
        class_iou = 0
        if tp+fp+fn != 0:
            class_iou = tp/(tp+fp+fn)
        mean_iou += class_iou
        class_name = f'iou_{idx}'
        if idx2class_name_dict is not None:
            class_name = f'iou_{idx2class_name_dict[idx]}'
        iou_dict[class_name] = class_iou
    iou_dict['iou_mean'] = mean_iou/actual_classes_num
    return iou_dict

def compute_pred_mask(pred):
    pred = pred.detach()
    _, pred_mask = pred.max(dim=1)
    return pred_mask.cpu().numpy()



class SegmentationModule(L.LightningModule):
    def __init__(self, model, criterion, name2class_idx_dict) -> None:
        super().__init__()
        self.model = model
        self.criterion = criterion
        self.name2class_idx_dict = name2class_idx_dict
        self.class_idx2name_dict = {v:k for k, v in name2class_idx_dict.items()}
        self.train_confusion_matrix = np.zeros((len(self.name2class_idx_dict), 2, 2), dtype=np.int64)
        self.val_confusion_matrix = np.zeros((len(self.name2class_idx_dict), 2, 2), dtype=np.int64)
        
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())
    def compute_pred_lbels(self, pred):
        return pred.max(dim=1)
    def training_step(self, batch, batch_idx):
        data, true_labels = batch
        pred = self.model(data)
        loss = self.criterion(pred, true_labels)
        
        pred_labels = compute_pred_mask(pred)
        true_labels = true_labels.detach().cpu().numpy()
        batch_confusion_matrix = metrics.multilabel_confusion_matrix(true_labels.reshape(-1), pred_labels.reshape(-1),labels=list(self.class_idx2name_dict.keys()))
        #print(f'train_batch_conf_type={batch_confusion_matrix.dtype}')
        #print(batch_confusion_matrix)
        self.train_confusion_matrix += batch_confusion_matrix.astype(np.int64)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        data, true_labels = batch
        pred = self.model(data)
        loss = self.criterion(pred, labels)
        pred_labels = compute_pred_mask(pred)
        true_labels = true_labels.detach().cpu().numpy()
        batch_confusion_matrix = metrics.multilabel_confusion_matrix(true_labels.reshape(-1), pred_labels.reshape(-1),labels=list(self.class_idx2name_dict.keys()))
        #print(f'pred_labels type={pred_labels.dtype}')
        #print(f'true_labels type={true_labels.dtype}')
        #print(f'val_batch_conf_type={batch_confusion_matrix.dtype}')
        #print(batch_confusion_matrix)
        self.val_confusion_matrix += batch_confusion_matrix.astype(np.int64)
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss
    
    def on_train_epoch_end(self) -> None:
        iou_dict = compute_iou_from_confusion(self.train_confusion_matrix,self.class_idx2name_dict)
        for name, value in iou_dict.items():
            name = f'train_{name}'
            self.log(name, value, on_step=False, on_epoch=True, prog_bar=True)
        self.train_confusion_matrix = np.zeros((len(self.name2class_idx_dict), 2, 2), dtype=np.int64)
        return super().on_train_epoch_end()
    
    def on_validation_epoch_end(self) -> None:
        iou_dict = compute_iou_from_confusion(self.val_confusion_matrix,self.class_idx2name_dict)
        for name, value in iou_dict.items():
            name = f'val_{name}'
            self.log(name, value, on_step=False, on_epoch=True, prog_bar=True)
        self.val_confusion_matrix = np.zeros((len(self.name2class_idx_dict), 2, 2), dtype=np.int64)
        return super().on_validation_epoch_end()



In [77]:
epoch_num = 3

segmentation_module = SegmentationModule(model, nn.CrossEntropyLoss(), class_name2idx_dict)

logger = CSVLogger(
    save_dir = "outputs",
    name="my_exp_name", 
    flush_logs_every_n_steps=1, 
    )

checkpoint_callback = ModelCheckpoint(
    mode="min",
    filename="MLP-{epoch:02d}",
    dirpath="outputs", 
    save_top_k=1, monitor="val_loss"
    
    )

trainer = L.Trainer(logger=logger,
        max_epochs=epoch_num, 
        #callbacks=[checkpoint_callback],
        accelerator = 'cuda'
        )

trainer.fit(segmentation_module , train_loader, test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                | Params | Mode 
----------------------------------------------------------
0 | model     | SegmentationWrapper | 35.3 M | eval 
1 | criterion | CrossEntropyLoss    | 0      | train
----------------------------------------------------------
35.3 M    Trainable params
0         Non-trainable params
35.3 M    Total params
141.387   Total estimated model params size (MB)
1         Modules in train mode
163       Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\mokhail\miniconda3\envs\aggr_rec\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
c:\Users\mokhail\miniconda3\envs\aggr_rec\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.
