In [1]:
import numpy as np
import pandas as pd
import cv2

import pytorch_lightning as pl
import torch
import torchmetrics
from torch import nn, optim
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader

In [2]:
class ConvBlock(nn.Module):
    """Conv2d -> optional BatchNorm2d -> LeakyReLU.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        norm: if truthy, apply BatchNorm2d after the convolution.
        padding: forwarded unchanged to ``nn.Conv2d`` (a padding string or int).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, norm=False, padding='valid'):
        super().__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        # Created unconditionally, so its parameters/buffers are in state_dict
        # regardless of `norm`; keep it that way for checkpoint compatibility.
        self.BrchNorm = nn.BatchNorm2d(out_channels)
        self.active = nn.LeakyReLU(inplace=True)
        self.norm = norm

    def forward(self, X):
        X = self.conv(X)
        # Truthiness (not `is True`) so flags such as norm=1 also enable BN.
        if self.norm:
            X = self.BrchNorm(X)
        return self.active(X)

In [3]:
class deConvBlock(nn.Module):
    """ConvTranspose2d -> optional BatchNorm2d -> ReLU.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the transposed convolution.
        kernel_size: transposed-convolution kernel size.
        stride: transposed-convolution stride.
        padding: forwarded unchanged to ``nn.ConvTranspose2d``.
        norm: if truthy, apply BatchNorm2d after the transposed convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0, norm=False):
        super().__init__()

        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        # Created unconditionally, so its parameters/buffers are in state_dict
        # regardless of `norm`; keep it that way for checkpoint compatibility.
        self.BrchNorm = nn.BatchNorm2d(out_channels)
        self.active = nn.ReLU(inplace=True)
        self.norm = norm

    def forward(self, X):
        X = self.deconv(X)
        # Truthiness (not `is True`) so flags such as norm=1 also enable BN.
        if self.norm:
            X = self.BrchNorm(X)
        return self.active(X)

In [4]:
class EncoderBlock(nn.Module):
    """One encoder stage: a ``ConvBlock`` optionally followed by ``Dropout(0.5)``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the stage.
        kernel_size: convolution kernel size (default 4).
        stride: convolution stride (default 2).
        norm: forwarded to ``ConvBlock`` (controls its BatchNorm2d).
        padding: forwarded to ``nn.Conv2d``. With kernel 4 / stride 2,
            ``padding=33`` maps a 64x64 input to a 64x64 output.
        use_dropout: if truthy, apply dropout (p=0.5) after the convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, norm=False, padding='valid', use_dropout=False):
        super().__init__()

        self.Conv = ConvBlock(in_channels, out_channels, kernel_size, stride, norm=norm, padding=padding)
        # Always registered; it has no parameters, so state_dict is unaffected.
        self.dropout = nn.Dropout(0.5)
        self.use_dropout = use_dropout

    def forward(self, X):
        X = self.Conv(X)
        # Truthiness (not `is True`) so flags such as use_dropout=1 also apply.
        if self.use_dropout:
            X = self.dropout(X)
        return X

In [5]:
class DecoderBlock(nn.Module):
    """One decoder stage: optional skip concat -> ``deConvBlock`` -> optional ``Dropout(0.5)``.

    Args:
        in_channels: channels entering the transposed convolution, i.e. the
            channels of X plus those of ``skip`` when a skip is given.
        out_channels: channels produced by the stage.
        kernel_size: transposed-convolution kernel size (default 4).
        stride: transposed-convolution stride (default 2).
        padding: forwarded to ``nn.ConvTranspose2d``. With kernel 4 / stride 2,
            the default ``padding=33`` maps a 64x64 input to a 64x64 output.
        norm: forwarded to ``deConvBlock`` (controls its BatchNorm2d).
        use_dropout: if truthy, apply dropout (p=0.5) after the deconvolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=33, norm=True, use_dropout=True):
        super().__init__()

        self.deconv = deConvBlock(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, norm=norm)
        # Always registered; it has no parameters, so state_dict is unaffected.
        self.dropout = nn.Dropout(0.5)
        self.use_dropout = use_dropout

    def forward(self, X, skip):
        """Concatenate ``skip`` (if not None) onto X along the channel dim, then upsample.

        ``skip`` must match X in every dimension except dim 1 (torch.cat rule).
        """
        if skip is not None:
            X = torch.cat([X, skip], dim=1)
        X = self.deconv(X)
        # Truthiness (not `is True`) so flags such as use_dropout=1 also apply.
        if self.use_dropout:
            X = self.dropout(X)
        return X

In [6]:
class SyntheticS2Image(nn.Module):
    """Three-stream encoder/decoder fusing two Sentinel-2 inputs and one Sentinel-1 input.

    Each input stream (Sentinel-2 "before", Sentinel-2 "after", Sentinel-1)
    has its own six-stage encoder. The same-depth features of the three
    streams are concatenated along channels and used as skip connections
    for a single shared six-stage decoder, which outputs ``out_channels``
    channels.

    Args:
        s2_channels: channels of each Sentinel-2 input (default 4).
        s1_channels: channels of the Sentinel-1 input (default 2).
        out_channels: channels of the generated image (default 4).

    NOTE(review): every stage uses kernel 4 / stride 2 / padding 33. For
    64x64 inputs this keeps every feature map at 64x64 (the stride-2 step
    is offset by large zero padding) instead of shrinking the maps
    stage by stage. Confirm this geometry is intended.
    """

    # Output channels of encoder stages 1..6 (identical for all three streams).
    _ENCODER_WIDTHS = (64, 128, 256, 512, 512, 512)
    # Padding used by every encoder and decoder stage.
    _PADDING = 33

    def __init__(self, s2_channels=4, s1_channels=2, out_channels=4):
        super().__init__()

        # Encoders. The attribute names `<prefix>_conv<k>` and this registration
        # order define the state_dict layout; keep both stable for checkpoints.
        self._add_encoder("s2b", s2_channels)  # Sentinel-2, before
        self._add_encoder("s2a", s2_channels)  # Sentinel-2, after
        self._add_encoder("s1", s1_channels)   # Sentinel-1

        # Decoder. in_channels = upsampled features from the previous stage
        # + three concatenated encoder skips of matching depth (u6 has no
        # upsampled input; it consumes the deepest skip alone).
        p = self._PADDING
        self.u6 = DecoderBlock(3 * 512, 512, padding=p)          # 1536 in
        self.u7 = DecoderBlock(512 + 3 * 512, 512, padding=p)    # 2048 in
        self.u8 = DecoderBlock(512 + 3 * 512, 256, padding=p)    # 2048 in
        self.u9 = DecoderBlock(256 + 3 * 256, 128, padding=p)    # 1024 in
        self.u10 = DecoderBlock(128 + 3 * 128, 64, padding=p)    # 512 in
        # NOTE(review): the output head keeps DecoderBlock's norm=True and the
        # ReLU from deConvBlock, so outputs are batch-normalised and >= 0.
        # Confirm this suits the target value range.
        self.u11 = DecoderBlock(64 + 3 * 64, out_channels, padding=p, use_dropout=False)  # 256 in

    def _add_encoder(self, prefix, in_channels):
        """Register ``<prefix>_conv1`` .. ``<prefix>_conv6``; dropout only on stages 5 and 6."""
        c_in = in_channels
        for k, c_out in enumerate(self._ENCODER_WIDTHS, start=1):
            stage = EncoderBlock(c_in, c_out, padding=self._PADDING, use_dropout=k >= 5)
            # nn.Module.__setattr__ registers the submodule under this name.
            setattr(self, f"{prefix}_conv{k}", stage)
            c_in = c_out

    def _encode(self, prefix, x):
        """Run one stream through its six encoder stages and return every stage's output."""
        feats = []
        for k in range(1, len(self._ENCODER_WIDTHS) + 1):
            x = getattr(self, f"{prefix}_conv{k}")(x)
            feats.append(x)
        return feats

    def forward(self, S2b, S2a, S1):
        """Generate an image from the three input streams.

        Args:
            S2b: Sentinel-2 "before" input, (N, s2_channels, H, W).
            S2a: Sentinel-2 "after" input, (N, s2_channels, H, W).
            S1: Sentinel-1 input, (N, s1_channels, H, W).

        Returns:
            Tensor with ``out_channels`` channels. For 64x64 inputs the output
            is also 64x64 (other sizes unverified here — TODO confirm).
        """
        s2b_feats = self._encode("s2b", S2b)
        s2a_feats = self._encode("s2a", S2a)
        s1_feats = self._encode("s1", S1)

        # skips[k] = channel-wise concat of the three streams after stage k+1.
        skips = [torch.cat(level, dim=1) for level in zip(s2b_feats, s2a_feats, s1_feats)]

        x = self.u6(skips[5], None)
        x = self.u7(x, skips[4])
        x = self.u8(x, skips[3])
        x = self.u9(x, skips[2])
        x = self.u10(x, skips[1])
        return self.u11(x, skips[0])


In [7]:
# Dummy inputs for a forward-pass shape check (batch of 1, 64x64 tiles).
# Seed the global torch RNG so these tensors, and anything drawn after them
# (e.g. model weights), are reproducible on Restart & Run All.
torch.manual_seed(0)
s2b = torch.rand(1, 4, 64, 64)  # Sentinel-2 "before": 4 channels
s2a = torch.rand(1, 4, 64, 64)  # Sentinel-2 "after": 4 channels
s1 = torch.rand(1, 2, 64, 64)   # Sentinel-1: 2 channels

In [8]:
# Instantiate the three-stream generator with PyTorch's default weight init.
# Weights are drawn from torch's global RNG, so they are reproducible only if
# a seed has been set earlier in the notebook. The module starts in train mode.
m = SyntheticS2Image()

In [9]:
# Shape sanity check. Run in eval mode without autograd so re-running this
# cell neither updates BatchNorm running statistics nor builds a graph for
# the whole model; the previous train/eval mode is restored afterwards.
was_training = m.training
m.eval()
try:
    with torch.no_grad():
        out_shape = m(s2b, s2a, s1).shape
finally:
    m.train(was_training)
out_shape

torch.Size([1, 4, 64, 64])