In [1]:
import torch
from torch import nn

import glob
import os
from tqdm import tqdm
from datetime import datetime
import json

import torchvision
from torchvision.transforms import v2
from torchvision import tv_tensors
from torchvision import models

import segmentation_models_pytorch as smp

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger

from sklearn.model_selection import train_test_split
from sklearn import metrics

import numpy as np

import pandas as pd

from itertools import combinations

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
models.segmentation.deeplabv3.DeepLabHead(in_channels=2048, num_classes=9)

In [None]:
path_to_surface_classes_json = os.path.join(path_to_dataset_root, 'surface_classes.json')


In [None]:
model = models.segmentation.deeplabv3_resnet50(weights=models.segmentation.DeepLabV3_ResNet50_Weights.DEFAULT)
print(model.classifier[-1].weight)
for layer in model.classifier.children():
    for m in layer.modules():
        m.reset_parameters()
print(model.classifier[-1].weight)

In [38]:
path_to_dataset_root = r'I:\LANDCOVER_DATA\MULTISPECTRAL_SATELLITE_DATA\DATA_FOR_TRAINIG'
path_to_dataset_info_csv = os.path.join(path_to_dataset_root, 'data_info_table.csv')

images_df = pd.read_csv(path_to_dataset_info_csv)
path_to_surface_classes_json = os.path.join(path_to_dataset_root, 'surface_classes.json')
with open(path_to_surface_classes_json) as fd:
    surface_classes_list = json.load(fd)
surface_classes_list

['UNLABELED',
 'buildings_territory',
 'natural_ground',
 'natural_grow',
 'natural_wetland',
 'natural_wood',
 'quasi_natural_grow',
 'transport',
 'water']

In [None]:
conv = nn.Conv2d(1, 3, 3)
dir(conv)

In [29]:
class FCNSegmentationWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    def forward(self, x):
        return self.model(x)['out']
    
class MultispectralNN(nn.Module):
    def __init__(self, main_model, preprocessing_block):
        super().__init__()
        self.preprocessing_block = preprocessing_block
        self.main_model = main_model

    def forward(self, x):
        x = self.preprocessing_block(x)
        return self.model(x)
    

in_channels = 13
cols = rows = 150
input_tensor = torch.randn(1, 13, cols, rows)
conv = nn.Conv2d(in_channels, 64, kernel_size=[1,1])
conv(input_tensor).shape

torch.Size([1, 64, 150, 150])

In [34]:
class MultispectralFuseOut(nn.Module):
    def __init__(self, main_model, multispectral_preprocessing_block, preprocessing_out_dim, class_num):
        super().__init__()
        self.multispectral_preprocessing_block = multispectral_preprocessing_block
        self.main_model = main_model
        self.multispectral_preout_block = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(in_channels=preprocessing_out_dim,out_channels=class_num, kernel_size=1),
            nn.BatchNorm2d(class_num),
            nn.ReLU()
        )
        self.fusion_block = nn.Sequential(
            #nn.Dropout2d(0.3),
            nn.ChannelShuffle(groups=2),
            nn.Conv2d(in_channels=class_num*2, out_channels=class_num, kernel_size=1)
        )

    def forward(self, x):
        multispectral_preprocessed_out = self.multispectral_preprocessing_block(x)
        multispectral_out = self.multispectral_preout_block(multispectral_preprocessed_out)
        print(multispectral_preprocessed_out.shape)
        print(multispectral_out.shape)
        main_out = self.main_model(multispectral_preprocessed_out)
        concat_out = torch.cat([multispectral_out, main_out], dim=1)

        return self.fusion_block(concat_out)
    
preprocess1_layer = nn.Sequential(
    nn.Conv2d(in_channels=13, out_channels=13, kernel_size=1),
    nn.BatchNorm2d(13)
)
model = models.segmentation.fcn_resnet50()
conv1 = model.backbone.conv1

weights = conv1.weight
new_weight = torch.cat([weights.mean(dim=1).unsqueeze(1)]*13, dim=1)
new_conv1 = nn.Conv2d(
    in_channels=13,
    out_channels=conv1.out_channels,
    kernel_size=conv1.kernel_size,
    stride=conv1.stride,
    padding=conv1.padding,
    dilation=conv1.dilation,
    groups=conv1.groups,
    bias=conv1.bias is not None
)
new_conv1.weight = nn.Parameter(new_weight)
if conv1.bias is not None:
    new_conv1.bias = model.backbone.conv1.bias
model.backbone.conv1 = new_conv1
model.classifier = models.segmentation.fcn.FCNHead(in_channels=2048, channels=9)
model = FCNSegmentationWrapper(model)

model = MultispectralFuseOut(model, preprocess1_layer, 13, 9)
ret = model(torch.randn(1, 13, 150, 150))
ret.shape

torch.Size([1, 13, 150, 150])
torch.Size([1, 9, 150, 150])


torch.Size([1, 9, 150, 150])

In [28]:
t1 = torch.ones((1, 3, 2, 2))
t2 = torch.arange(0, 12,dtype=torch.float32).view(1, 3, 2, 2)
t = torch.cat([t1, t2], dim=1)
conv = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=1, groups=6//2, bias=False)
torch.nn.init.ones_(conv.weight)
shuffle = nn.ChannelShuffle(groups=2)
h = shuffle(t)

conv(h), t2

(tensor([[[[ 1.,  2.],
           [ 3.,  4.]],
 
          [[ 5.,  6.],
           [ 7.,  8.]],
 
          [[ 9., 10.],
           [11., 12.]]]], grad_fn=<ConvolutionBackward0>),
 tensor([[[[ 0.,  1.],
           [ 2.,  3.]],
 
          [[ 4.,  5.],
           [ 6.,  7.]],
 
          [[ 8.,  9.],
           [10., 11.]]]]))

tensor([[[[ 0.,  1.],
          [ 2.,  3.]],

         [[ 4.,  5.],
          [ 6.,  7.]],

         [[ 8.,  9.],
          [10., 11.]]]])

# Создание парных мультиспектральных индексов

In [27]:
class MakeChannelsCombinations(nn.Module):
    def __init__(self, combinations_list):
        super().__init__()
        self.combinations_list = combinations_list
    def forward(self, x):
        return x[:,self.combinations_list]

class SpectralDiffIndexModule(nn.Module):
    def __init__(self, channel_indices, channels_in_index, out_channels):
        super().__init__()
        self.channel_indices = channel_indices
        combinations_list = list(combinations(channel_indices, channels_in_index))
        
        self.combinations_list = np.array(combinations_list).reshape(-1).tolist()
        in_channels = len(self.combinations_list)
        self.make_channels_combinations = MakeChannelsCombinations(self.combinations_list)
        self.numerator = nn.Conv2d(in_channels=in_channels, out_channels=in_channels//channels_in_index, kernel_size=1, groups=in_channels//channels_in_index, bias=False)
        self.denominator = nn.Conv2d(in_channels=in_channels, out_channels=in_channels//channels_in_index, kernel_size=1, groups=in_channels//channels_in_index, bias=False)
        self.out_block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels//channels_in_index, out_channels=out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            #nn.ReLU()
        )

    def forward(self, x):
        channels_combinations = x[:,self.combinations_list]
        numerator_results = self.numerator(channels_combinations)
        denominator_results = self.denominator(channels_combinations)
        indices = numerator_results / (denominator_results+1e-7)
        output = self.out_block(indices)
        return output

channel_indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
channel_indices
combinations_list = list(combinations(channel_indices, 2))
data = torch.randn(1, len(channel_indices), 150, 150)

combinations_list = np.array(combinations_list).reshape(-1).tolist()
in_channels = len(combinations_list)
channels_combinations = data[:,combinations_list]
'''
channels_combinations = []
for combination in combinations_list:
    channels_combination = data[:,combination]
    channels_combinations.append()

channels_combinations = torch.cat(channels_combinations, dim=1)
channels_combinations.shape
'''
channels_combinations.shape

numerator = nn.Conv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, groups=in_channels//2, bias=False)
denominator = nn.Conv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, groups=in_channels//2, bias=False)
#out_conv = nn.Conv2d(in_channels//2, )
numerator_results = numerator(channels_combinations)
denominator_results = denominator(channels_combinations)
res = numerator_results / (denominator_results+1e-7)
res.shape

preprocess = SpectralDiffIndexModule(channel_indices=channel_indices, channels_in_index=2, out_channels=8)
res = preprocess(data)
res.shape

torch.Size([1, 8, 150, 150])

In [21]:
numerator.weight.shape

torch.Size([78, 2, 1, 1])