## PANSFORMERS

# Importing Libraries

In [None]:
!pip install rasterio
!pip install sewar
import tifffile 
from tqdm import tqdm
import numpy as np
import pandas as pd
import pickle as pkl
import cv2
import math
import matplotlib.pyplot as plt
import datetime
import os
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from torch.nn.functional import relu, leaky_relu, sigmoid
from torch import Tensor
from torch.nn import Dropout, BatchNorm1d, Conv2d, ConvTranspose2d, MultiheadAttention, Softmax, Softmax2d, Container, Module, ModuleList
from typing import Optional, Any
import torch.cuda
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torch.nn import Dropout, Linear, LayerNorm, LogSoftmax
from torch.nn.functional import softmax , log_softmax
import rasterio
import glob
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from __future__ import absolute_import, division, print_function
import numpy as np
import torch.cuda
from sewar.full_ref import mse, rmse, psnr, uqi, ssim, scc, sam

In [None]:
# Select the compute device used by the rest of the notebook:
# the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU")
else:
    # "gpu" is not a valid torch device type and raises a RuntimeError;
    # the fallback device must be spelled "cpu".
    device = torch.device("cpu")

# Dataset Preprocessing (Array Creation)

In [None]:
def training_image_creation(img_ms, img_pan, n_factor):
    """
    Build one network input sample from a multispectral (MS) image and its panchromatic (PAN) image.

    Every MS band is smoothed with a 3 x 3 Gaussian kernel, the smoothed image is shrunk by n_factor (area
    interpolation) and enlarged back to the MS size (bicubic interpolation). The PAN image is resized to the MS size
    (area interpolation) and appended as one extra trailing band.

    Inputs:
    - img_ms: Numpy array of the original multispectral image, indexed as (rows, columns, bands)
    - img_pan: Numpy array of the original panchromatic image (2 - dimensional, a band axis is added here)
    - n_factor: The ratio of pixel resolution of multispectral image to that of the panchromatic image

    Outputs:
    - training_sample_array: Numpy array of shape (rows, columns, bands + 1) holding the degraded multispectral bands
                             followed by the resized panchromatic band

    """

    n_rows, n_cols = img_ms.shape[0], img_ms.shape[1]

    # Band-by-band smoothing, written into a float64 buffer of the same shape as the input.
    smoothed = np.zeros((img_ms.shape))
    for band in range(img_ms.shape[2]):
        smoothed[:, :, band] = cv2.GaussianBlur(img_ms[:, :, band], (3, 3), 0)

    # cv2.resize takes its target size as (width, height), i.e. (columns, rows).
    full_size = (n_cols, n_rows)
    reduced_size = (int(n_cols / n_factor), int(n_rows / n_factor))

    # Shrink then re-enlarge so the result has the MS grid but a coarser effective resolution.
    shrunk = cv2.resize(smoothed, reduced_size, interpolation = cv2.INTER_AREA)
    degraded_ms = cv2.resize(shrunk, full_size, interpolation = cv2.INTER_CUBIC)

    pan_band = cv2.resize(img_pan, full_size, interpolation = cv2.INTER_AREA)
    pan_band = pan_band[:, :, np.newaxis]

    return np.concatenate((degraded_ms, pan_band), axis = 2)



def image_clip_to_segment(image_ms_array, train_image_array, image_height_size, image_width_size, percentage_overlap, 
                          buffer):
    """ 
    Cut a multispectral image and its matching training image into fixed-size patches with a sliding window.

    Both arrays are first extended on their bottom and right edges up to the next multiple of the patch size. The
    extension uses symmetric reflection of the existing pixels (np.pad mode 'symmetric'), not zeros. The padding
    amounts are computed from image_ms_array only, so both arrays are expected to share the same height and width.
    Patches are then read from the same window positions of both padded arrays, in row-major order.
    
    Inputs:
    - image_ms_array: Numpy array (rows, columns, bands) of the original multispectral image (target data)
    - train_image_array: Numpy array (rows, columns, bands) of the training sample image (input data)
    - image_height_size: Height of each extracted patch
    - image_width_size: Width of each extracted patch
    - percentage_overlap: Fraction of overlap between neighbouring patches; the window stride is
                          int((1 - percentage_overlap) * patch size)
    - buffer: Fraction of a patch that may consist of reflected padding; window start positions stop before
              padded size - (2 - buffer) * patch size
    
    Output (in this order):
    - train_segment_array: 4 - Dimensional float64 numpy array (patches, height, width, bands) of training patches
    - image_segment_array: 4 - Dimensional float64 numpy array (patches, height, width, bands) of target patches

    Raises:
    - ValueError: if the overlap is so large that the window stride rounds down to zero
    
    """
    
    row_stride = int((1 - percentage_overlap) * image_height_size)
    col_stride = int((1 - percentage_overlap) * image_width_size)
    
    # A zero stride would otherwise surface as the opaque "range() arg 3 must not be zero".
    if row_stride == 0 or col_stride == 0:
        raise ValueError('percentage_overlap = {} gives a sliding window stride of 0 for patches of size {} x {}; '
                         'use a smaller overlap or a larger patch.'.format(percentage_overlap, image_height_size, 
                                                                           image_width_size))
    
    # Always pad to the NEXT multiple of the patch size (a full extra patch when the size is already a multiple).
    y_pad = (image_ms_array.shape[0] // image_height_size + 1) * image_height_size - image_ms_array.shape[0]
    x_pad = (image_ms_array.shape[1] // image_width_size + 1) * image_width_size - image_ms_array.shape[1]
    pad_widths = ((0, y_pad), (0, x_pad), (0, 0))
    
    img_complete = np.pad(image_ms_array, pad_widths, mode = 'symmetric').astype(image_ms_array.dtype)
    train_complete = np.pad(train_image_array, pad_widths, mode = 'symmetric').astype(train_image_array.dtype)
    
    row_starts = range(0, int(img_complete.shape[0] - (2 - buffer) * image_height_size), row_stride)
    col_starts = range(0, int(img_complete.shape[1] - (2 - buffer) * image_width_size), col_stride)
    
    # The number of patches is known up front, so the outputs are filled directly instead of via temporary lists.
    n_segments = len(row_starts) * len(col_starts)
    n_ms_bands = image_ms_array.shape[2]
    image_segment_array = np.zeros((n_segments, image_height_size, image_width_size, n_ms_bands))
    train_segment_array = np.zeros((n_segments, image_height_size, image_width_size, train_image_array.shape[2]))
    
    index = 0
    for i in row_starts:
        for j in col_starts:
            image_segment_array[index] = img_complete[i : i + image_height_size, j : j + image_width_size, 
                                                      0 : n_ms_bands]
            train_segment_array[index] = train_complete[i : i + image_height_size, j : j + image_width_size, :]
            index += 1
        
    return train_segment_array, image_segment_array



def training_data_generation(DATA_DIR, img_height_size, img_width_size, perc, buff, img_num):
    """ 
    Read the multispectral (MS) / panchromatic (PAN) image pairs selected by img_num from DATA_DIR and turn them into
    patch arrays for model training.

    MS files are searched as DATA_DIR + 'MS/MS_<img_num>.tif' and PAN files as DATA_DIR + 'PAN/PAN_<img_num>.tif'
    (plain string concatenation, so DATA_DIR must end with a path separator). Both file lists are sorted so that the
    i-th MS file is paired with the i-th PAN file when img_num is a glob pattern matching several images.
    
    Inputs:
    - DATA_DIR: File path of the folder containing the 'MS' and 'PAN' sub-folders
    - img_height_size: Height of image segment to be used for model training
    - img_width_size: Width of image segment to be used for model training
    - perc: Fraction of overlap between image patches extracted by sliding window (0 to 1)
    - buff: Fraction of an image patch that may be populated by reflected values (0 to 1)
    - img_num: Image number (or glob pattern) inserted into the MS / PAN file names
    
    Outputs:
    - train_full_array: 4 - Dimensional numpy array of concatenated degraded multispectral and resized panchromatic 
                        patches to serve as training data
    - img_full_array: 4 - Dimensional numpy array of original multispectral patches to serve as target data

    Raises:
    - ValueError: if perc or buff lies outside [0, 1], if no MS image matches, or if the numbers of MS and PAN files
                  differ
    
    """
    
    if perc < 0 or perc > 1:
        raise ValueError('Please input a number between 0 and 1 (inclusive) for perc.')
        
    if buff < 0 or buff > 1:
        raise ValueError('Please input a number between 0 and 1 (inclusive) for buff.')

    ms_pattern = DATA_DIR + 'MS/MS_' + str(img_num) +'.tif'
    pan_pattern = DATA_DIR + 'PAN/PAN_' + str(img_num) + '.tif'
    
    # glob returns files in arbitrary order; sorting keeps the MS / PAN pairing deterministic.
    img_MS_files = sorted(glob.glob(ms_pattern))
    img_PAN_files = sorted(glob.glob(pan_pattern))
    
    if not img_MS_files:
        raise ValueError('No multispectral images found matching ' + ms_pattern)
        
    if len(img_MS_files) != len(img_PAN_files):
        raise ValueError('Found {} multispectral but {} panchromatic images; every MS image needs one PAN '
                         'image.'.format(len(img_MS_files), len(img_PAN_files)))
    
    img_array_list = []
    train_array_list = []
    
    for ms_path, pan_path in zip(img_MS_files, img_PAN_files):
        
        with rasterio.open(ms_path) as f:
            metadata = f.profile
            # Read every band (rasterio band indices are 1-based) and move the band axis last.
            ms_img = np.transpose(f.read(tuple(np.arange(metadata['count']) + 1)), [1, 2, 0])
        with rasterio.open(pan_path) as g:
            metadata_pan = g.profile
            pan_img = g.read(1)
            
        # Ratio of the first affine transform coefficients of the two rasters.
        ms_to_pan_ratio = metadata['transform'][0] / metadata_pan['transform'][0]
            
        train_img = training_image_creation(ms_img, pan_img, n_factor = ms_to_pan_ratio)
    
        train_array, img_array = image_clip_to_segment(ms_img, train_img, img_height_size, img_width_size, 
                                                       percentage_overlap = perc, buffer = buff)

        img_array_list.append(img_array)
        train_array_list.append(train_array)
        del train_img
        del train_array
        del img_array

    img_full_array = np.concatenate(img_array_list, axis = 0)
    train_full_array = np.concatenate(train_array_list, axis = 0)
    
    del img_MS_files, img_PAN_files
    gc.collect()
    
    return train_full_array, img_full_array

# Dataloader

The dataloader reads the image arrays generated previously and creates batches of image tiles.

A common dataloader is written for loading training, validation and testing arrays. 

We use NumPy memmaps (`np.memmap`) to improve the training speed.

The memmaps have to be created from the NumPy arrays before the dataloader can be used.

## IKONOS

In [None]:
class IKONOS(Dataset):
    """Dataset of IKONOS tiles read from two float32 memmap files.

    The input file is opened read-only with shape (num_examples, 5, 64, 64)
    and the ground-truth file with shape (num_examples, 4, 64, 64), where
    num_examples is fixed per split.
    """

    # Number of tiles stored in each split's memmap, keyed by the `train` code.
    NUM_EXAMPLES = {0: 29854, 1: 7434, 2: 4066}

    def __init__(self, x_path, y_path, train=0, blur=0):
        """Open the memmaps of one split.

        Args:
            x_path: Path of the memmap file holding the input tiles.
            y_path: Path of the memmap file holding the ground-truth tiles.
            train: Split code, one of 0, 1 or 2; it selects the number of
                examples (the notebook uses 0 / 1 / 2 for train /
                validation / test).
            blur: Accepted for backward compatibility; not used.

        Raises:
            ValueError: If `train` is not a known split code. Previously this
                surfaced as an UnboundLocalError on `num_examples`.
        """
        if train not in self.NUM_EXAMPLES:
            raise ValueError(
                "train must be one of {}, got {!r}".format(sorted(self.NUM_EXAMPLES), train)
            )
        num_examples = self.NUM_EXAMPLES[train]
        self.x = np.memmap(x_path, dtype='float32', mode='r', shape=(num_examples, 5, 64, 64))
        self.y = np.memmap(y_path, dtype='float32', mode='r', shape=(num_examples, 4, 64, 64))

    def __len__(self):
        """Return the number of tiles in the split."""
        return self.x.shape[0]

    def __getitem__(self, idx):
        """Return the pair (ground truth, input) stored at position `idx`."""
        # Ground truth comes first; the training loop unpacks the batch in this order.
        model_input = self.x[idx]
        gt = self.y[idx]
        return (gt, model_input)

In [None]:
# IKONOS datasets for the three splits (split codes 0 / 1 / 2).
ikonostrain = IKONOS(train_ip_path, train_gt_path, train=0, blur=0)
ikonosval = IKONOS(val_ip_path, val_gt_path, train=1)
ikonostest = IKONOS(test_ip_path, test_gt_path, train=2)

In [None]:
# Batch sizes for the IKONOS loaders; this is the single place to change them.
train_batch_size = val_batch_size = test_batch_size = 32

In [None]:
# Keyword arguments for the three IKONOS DataLoaders: shuffled batches,
# loaded in the main process (num_workers = 0).
ikonostrainparams = dict(batch_size=train_batch_size, shuffle=True, num_workers=0)
ikonosvalparams = dict(batch_size=val_batch_size, shuffle=True, num_workers=0)
ikonostestparams = dict(batch_size=test_batch_size, shuffle=True, num_workers=0)

In [None]:
# One DataLoader per IKONOS split, configured by the matching params dict.
IkonosTrainDataloader = DataLoader(dataset=ikonostrain, **ikonostrainparams)
IkonosValDataloader = DataLoader(dataset=ikonosval, **ikonosvalparams)
IkonosTestDataloader = DataLoader(dataset=ikonostest, **ikonostestparams)

## LANDSAT-8

In [None]:
class Landsat(Dataset):
    """Dataset of Landsat-8 tiles read from two float32 memmap files.

    The input file is opened read-only with shape (num_examples, 5, 64, 64)
    and the ground-truth file with shape (num_examples, 4, 64, 64), where
    num_examples is fixed per split.
    """

    # Number of tiles stored in each split's memmap, keyed by the `train` code.
    NUM_EXAMPLES = {0: 54268, 1: 17955, 2: 17955}

    def __init__(self, x_path, y_path, train=0):
        """Open the memmaps of one split.

        Args:
            x_path: Path of the memmap file holding the input tiles.
            y_path: Path of the memmap file holding the ground-truth tiles.
            train: Split code, one of 0, 1 or 2; it selects the number of
                examples (the notebook uses 0 / 1 / 2 for train /
                validation / test).

        Raises:
            ValueError: If `train` is not a known split code. Previously this
                surfaced as an UnboundLocalError on `num_examples`.
        """
        if train not in self.NUM_EXAMPLES:
            raise ValueError(
                "train must be one of {}, got {!r}".format(sorted(self.NUM_EXAMPLES), train)
            )
        num_examples = self.NUM_EXAMPLES[train]
        self.x = np.memmap(x_path, dtype='float32', mode='r', shape=(num_examples, 5, 64, 64))
        self.y = np.memmap(y_path, dtype='float32', mode='r', shape=(num_examples, 4, 64, 64))

    def __len__(self):
        """Return the number of tiles in the split."""
        return self.x.shape[0]

    def __getitem__(self, idx):
        """Return the pair (ground truth, input) stored at position `idx`."""
        # Ground truth comes first; the training loop unpacks the batch in this order.
        model_input = self.x[idx]
        gt = self.y[idx]
        return (gt, model_input)

In [None]:
# Landsat-8 datasets for the three splits (split codes 0 / 1 / 2).
landsattrain = Landsat(train_ip_path, train_gt_path, train=0)
landsatval = Landsat(val_ip_path, val_gt_path, train=1)
landsattest = Landsat(test_ip_path, test_gt_path, train=2)

In [None]:
# Batch sizes for the Landsat-8 loaders; this is the single place to change them.
train_batch_size, val_batch_size, test_batch_size = 32, 64, 8

In [None]:
# Keyword arguments for the three Landsat-8 DataLoaders: shuffled batches,
# loaded in the main process (num_workers = 0).
landsattrainparams = dict(batch_size=train_batch_size, shuffle=True, num_workers=0)
landsatvalparams = dict(batch_size=val_batch_size, shuffle=True, num_workers=0)
landsattestparams = dict(batch_size=test_batch_size, shuffle=True, num_workers=0)

In [None]:
# One DataLoader per Landsat-8 split, configured by the matching params dict.
LandsatTrainDataloader = DataLoader(dataset=landsattrain, **landsattrainparams)
LandsatValDataloader = DataLoader(dataset=landsatval, **landsatvalparams)
LandsatTestDataloader = DataLoader(dataset=landsattest, **landsattestparams)

# Model

## Convolution PCNN

In [None]:
class ConvModelSeparate(nn.Module):
    """Three-layer convolutional block: 9x9 conv (5 -> 64), 5x5 conv (64 -> 32), 5x5 conv (32 -> out).

    All convolutions use 'same' padding, so the spatial size of the input is kept. The last convolution produces
    4 channels when the block is built with final=True and 5 channels otherwise.
    """

    def __init__(self, final=False):
        super(ConvModelSeparate, self).__init__()
        self.convlayer1 = nn.Conv2d(5, 64, (9, 9), padding='same')
        self.convlayer2 = nn.Conv2d(64, 32, (5, 5), padding='same')
        self.relu = nn.ReLU()

        # Explicit comparison (not truthiness) keeps the handling of non-boolean flags unchanged.
        last_channels = 4 if final == True else 5
        self.convlayer3 = nn.Conv2d(32, last_channels, (5, 5), padding='same')

    def forward(self, input_image, final = False):
        """Run the three convolutions, with ReLU after the first two.

        The last convolution is also followed by ReLU unless this call is made with final=True.

        NOTE(review): the `final` flag given to the constructor is not stored, so a block built with final=True
        still applies the last ReLU when forward is called without final=True — confirm this is intended.
        """
        hidden = self.relu(self.convlayer1(input_image))
        hidden = self.relu(self.convlayer2(hidden))
        output = self.convlayer3(hidden)
        return self.relu(output) if final == False else output

## Attention Model

In [None]:
# Self Attention Module using Convolution Layers
class ConvAttentionModule(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    Queries, keys and values are produced by three separate 1x1 convolutions that keep the channel count.
    """

    def __init__(self, in_channels):
        super(ConvAttentionModule, self).__init__()
        self.in_channels = in_channels

        def pointwise_conv():
            # 1x1 convolution mapping in_channels -> in_channels.
            return torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding='same')

        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()

    def forward(self, x):
        """Return the attended feature map, same shape (batch, channels, height, width) as `x`."""
        batch, channels, height, width = x.shape
        positions = height * width

        queries = self.q(x).reshape(batch, channels, positions).permute(0, 2, 1)  # (batch, positions, channels)
        keys = self.k(x).reshape(batch, channels, positions)                      # (batch, channels, positions)
        values = self.v(x).reshape(batch, channels, positions)                    # (batch, channels, positions)

        # scores[b, i, j] = sum_c queries[b, i, c] * keys[b, c, j], scaled by 1 / sqrt(channels)
        scores = torch.bmm(queries, keys) * (int(channels) ** (-0.5))
        # Normalise over the key positions j for every query position i.
        weights = torch.nn.functional.softmax(scores, dim=2)

        # attended[b, c, i] = sum_j values[b, c, j] * weights[b, i, j]
        attended = torch.bmm(values, weights.permute(0, 2, 1))
        return attended.reshape(batch, channels, height, width)

In [None]:
# Only one of the attention classes defined below needs to be used in a given model

class ConvMultiPatchAttention(nn.Module): #Multipatch attention class
    """Patch-wise self-attention: the image is cut into a grid of square patches and every patch position is
    processed by its own ConvAttentionModule.

    With the fixed patch size of 16 and 16 attention modules, the input must split into exactly 16 patches,
    e.g. a 64 x 64 image cut into a 4 x 4 grid.
    """

    PATCH_SIZE = 16  # side length, in pixels, of every square patch

    def __init__(self, inchannels):
        super(ConvMultiPatchAttention, self).__init__()
        self.attentionmodules = nn.ModuleList()
        self.num_patches = 16
        for _ in range(self.num_patches):
            self.attentionmodules.append(ConvAttentionModule(in_channels=inchannels))

    def forward(self, image):
        """Apply the per-patch attention modules and stitch the patches back together.

        Args:
            image: Tensor of shape (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, channels, grid_h * 16, grid_w * 16), with every patch at its original location.

        Raises:
            ValueError: If the image does not split into exactly `num_patches` patches.
        """
        batch_size, channels = image.shape[0], image.shape[1]
        size = self.PATCH_SIZE

        # (batch, channels, grid_h, grid_w, size, size): non-overlapping patches.
        patches = image.unfold(2, size, size).unfold(3, size, size)
        grid_h, grid_w = patches.shape[2], patches.shape[3]
        if grid_h * grid_w != self.num_patches:
            raise ValueError('Expected the image to split into {} patches of {} x {} pixels, got a {} x {} '
                             'grid.'.format(self.num_patches, size, size, grid_h, grid_w))

        # Flatten the grid and move the patch index first: (patch, batch, channels, size, size).
        patches = patches.contiguous().view(batch_size, channels, self.num_patches, size, size)
        patches = patches.permute(2, 0, 1, 3, 4)

        # Stacking the outputs keeps them on the device and dtype of the input (no global `device` needed).
        attended = torch.stack([attention(patch) for attention, patch in zip(self.attentionmodules, patches)])

        # Undo the permutation above. Its inverse is (1, 2, 0, 3, 4); repeating (2, 0, 1, 3, 4) would put the
        # channel axis first and mix data across batch samples and channels in the following view.
        attended = attended.permute(1, 2, 0, 3, 4).contiguous()   # (batch, channels, patch, size, size)

        # Re-assemble the patch grid into the image layout.
        attended = attended.view(batch_size, channels, grid_h, grid_w, size, size)
        attended = attended.permute(0, 1, 2, 4, 3, 5).contiguous()  # (batch, channels, grid_h, size, grid_w, size)
        orig_image = attended.view(batch_size, channels, grid_h * size, grid_w * size)

        return orig_image


class GlobalAttention(nn.Module): # Global Attention Class
    """Self-attention over the whole image using a single ConvAttentionModule."""

    def __init__(self, inchannels):
        super(GlobalAttention, self).__init__()
        self.globalatt = ConvAttentionModule(in_channels=inchannels)

    def forward(self, image):
        """Return the globally attended image (same shape as `image`)."""
        output = self.globalatt(image)

        return output


class CombinedAttentionLayer(nn.Module): # Global + Multipatch Attention Class
    """Sum of multi-patch and global attention, followed by a 1x1 output projection."""

    def __init__(self, inchannels=5):
        # `inchannels` defaults to 5, the channel count used by the Pansformers model in this notebook, so the
        # original zero-argument constructor call keeps working.
        super(CombinedAttentionLayer, self).__init__()
        self.mpa = ConvMultiPatchAttention(inchannels)
        self.ga = GlobalAttention(inchannels)
        self.proj_out = torch.nn.Conv2d(inchannels,
                                        inchannels,
                                        kernel_size=1,
                                        stride=1,
                                        padding='same')

    def forward(self, image):
        """Return proj_out(multi-patch attention + global attention) for `image`."""
        mpa_output = self.mpa(image)
        ga_output = self.ga(image)
        combined_output = mpa_output + ga_output
        final_proj = self.proj_out(combined_output)

        return final_proj

## Pansformers

In [None]:
# Number of Layers of the models can be adjusted here
# Corresponding Attention Layers must be used while running the model (ie; the current configuration is for MultiPatch Attention)

class Pansformers(nn.Module):
    """Pansharpening network: convolutional block -> multi-patch attention -> two 1x1 projections with a skip
    connection -> final convolutional block.

    NOTE(review): the convolutional blocks are built with fixed 5-channel inputs, so `channels` presumably has to
    stay at 5 for the shapes to line up — confirm before changing it.
    """

    def __init__(self, channels = 5):
        super(Pansformers, self).__init__()
        # Feature extraction on the stacked input.
        self.convlayer1 = ConvModelSeparate()
        # Patch-wise attention, placed on the notebook-level `device`.
        self.multipatchattention = ConvMultiPatchAttention(inchannels = channels).to(device)
        # 1x1 projections: expand to 20 channels, then reduce back to `channels`.
        self.projection1 = nn.Conv2d(channels, 20, kernel_size=1,stride=1,padding='same')
        self.projection2 = nn.Conv2d(20, channels, kernel_size=1, stride=1, padding='same')
        # Output block producing the 4-channel result.
        self.finalconv = ConvModelSeparate(final=True)

    def forward(self, src):
        """Map an input batch `src` of shape (batch, 5, height, width) to a 4-channel output batch."""
        features = self.convlayer1(src)
        attended = self.multipatchattention(features)
        projected = self.projection2(self.projection1(attended))
        # Skip connection around the attention branch.
        return self.finalconv(projected + features)

In [None]:
# Loss function and Optimizer

# Network on the selected device, mean-squared-error objective, Adam optimiser with weight decay.
model = Pansformers(channels=5).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=1e-4,
    amsgrad=False,
)

## Training

In [None]:
# Training Loop

L2 = []  # mean training loss of every epoch run in this cell
num_epochs = 100

# Dataset sizes are read from the loaders instead of being hard-coded, so the averages stay
# correct when the datasets change.
train_set_size = len(LandsatTrainDataloader.dataset)
val_set_size = len(LandsatValDataloader.dataset)

# NOTE(review): the epoch counter starts at 61, which looks like the continuation of an earlier
# run; this cell does not load a checkpoint itself — confirm the model state before running.
for epoch in tqdm(range(61,num_epochs+1)):
    model.train()
    cumulative_loss = 0.0    
    btch = 0
    for i, generator_values in enumerate(LandsatTrainDataloader):  #model input (32,5,64,64)
        # The datasets yield (ground truth, input) pairs.
        groundtruth = generator_values[0].float().to(device)
        inputs = generator_values[1].float().to(device)
        model_output = model(inputs)
        main_loss = criterion(model_output, groundtruth)
        # Clear gradients before the backward pass so nothing left over from an interrupted
        # run can leak into this step.
        optimizer.zero_grad()
        main_loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weight it by the batch size to sum per-sample losses.
        cumulative_loss += main_loss.item()*inputs.size(0)
        btch += 1

    epoch_loss = cumulative_loss/float(train_set_size)
    L2.append(epoch_loss)
    print('End of Epoch Training Loss: ' , epoch_loss)
    print('-----------------------------------------------------')
    print('-----------------------------------------------------')

    # Append one log line per epoch.
    with open('/content/drive/MyDrive/LANDSAT8/Outputs/Global_WithProjection_Blur3.txt', 'at') as file :       
        now = datetime.datetime.now()
        current_time = now.strftime("%H:%M:%S")
        file.write("Epoch: {}, Batch_Loss: {}, Loss: {}, Time: {}\n".format(epoch, main_loss.item(), cumulative_loss, current_time))     

    # Validation Set (every 5th epoch)

    if epoch%5==0:
        model.eval()
        with torch.no_grad():
            val_loss = 0.0
            for i, generator_values in enumerate(LandsatValDataloader):  #model input (32,5,64,64)
                val_gt = generator_values[0].float().to(device)
                val_ip = generator_values[1].float().to(device)
                predictions=model(val_ip)
                loss = criterion(predictions, val_gt)
                val_loss += loss.item()*val_ip.shape[0]
            val_loss_final = val_loss/float(val_set_size)
            print("Validation loss: ", val_loss_final)
            print("---------------------------------------------------")
    model.train()

    # Drop references to the last batch before the next epoch.
    del groundtruth
    del inputs
    del model_output

   
    # Checkpoint every 10th epoch; validation has always just run because 10 is a multiple of 5.
    if epoch%10==0:
        torch.save({
          'epoch': epoch,
          'model_state_dict': model.state_dict(),
          'optimizer_state_dict': optimizer.state_dict(),
          'train_loss': epoch_loss,
          # Store the per-sample mean, consistent with 'train_loss' (previously the raw sum).
          'val_loss': val_loss_final
          }, '/content/drive/MyDrive/LANDSAT8/Models/Global_WithProjection_Blur3.pt') 

# Metrics

In [None]:
# Function to determine metrics of the predicted images 

def FullRef_Metrics(predicted, groundtruth, max_value = 2047):
    """Average full-reference quality metrics over a batch of predicted images.

    MSE, RMSE, PSNR, UQI, SCC, SAM and SSIM are computed for every (ground truth, prediction) pair along the first
    axis and averaged. Pairs whose SAM value is NaN are skipped and do not count towards the average.

    Args:
        predicted: Array of predicted images, indexed by image along axis 0.
        groundtruth: Array of reference images with exactly the same shape as `predicted`.
        max_value: Maximum pixel value passed to PSNR and SSIM; it depends on the bit depth of the images
            (the default 2047 is the value previously hard-coded).

    Returns:
        pandas DataFrame with the columns 'Metrics' and 'Values', one row per metric. All values are NaN when
        no pair could be evaluated.

    Raises:
        ValueError: If `predicted` and `groundtruth` have different shapes.
    """
    if groundtruth.shape != predicted.shape:
        # Previously this was only printed, and the function then crashed on an unbound variable.
        raise ValueError("Error: Array Shape Mismatch ({} vs {})".format(predicted.shape, groundtruth.shape))

    names = ['MSE', 'RMSE', 'PSNR', 'UQI', 'SCC', 'SAM', 'SSIM']
    totals = [0] * len(names)
    evaluated = 0  # number of pairs that contributed to the totals

    for i in tqdm(range(0, predicted.shape[0])):
        predictedimage = predicted[i]
        gtimage = groundtruth[i]
        s = sam(gtimage, predictedimage)
        if math.isnan(s):
            continue
        totals[0] += mse(gtimage, predictedimage)
        totals[1] += rmse(gtimage, predictedimage)
        totals[2] += psnr(gtimage, predictedimage, MAX = max_value)
        totals[3] += uqi(gtimage, predictedimage, ws = 2)
        totals[4] += scc(gtimage, predictedimage)
        totals[5] += s
        ssm, cs = ssim(gtimage, predictedimage, ws = 2, MAX = max_value)
        totals[6] += ssm
        evaluated += 1

    if evaluated:
        results = [value / evaluated for value in totals]
    else:
        # Nothing to average (empty input or every pair skipped): report NaN instead of dividing by zero.
        results = [float('nan')] * len(names)

    return pd.DataFrame(list(zip(names, results)), columns = ['Metrics', 'Values'])