# Compare models

In [None]:
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import torch
import glob
from sklearn.model_selection import train_test_split
from PIL import Image
import seaborn as sns
import math
from skimage.metrics import structural_similarity as ssim
import random
from scipy import stats

## Set up data 

In [None]:
# Data locations. Every directory string ends in '/' because the dataset
# class builds file paths by plain string concatenation (dir + file name).
main_dir = '/home/jovyan/InSAR_denoising_CNN'

train_signal_dir = f'{main_dir}/train_subsets_v2/veloc/'
train_noise_dir = f'{main_dir}/train_subsets_v2/int/'
train_dem_dir = f'{main_dir}/train_subsets_v2/dem/'
train_era5_dir = f'{main_dir}/train_subsets_v2/era5/'

val_signal_dir = f'{main_dir}/val_subsets_v2/veloc/'
val_noise_dir = f'{main_dir}/val_subsets_v2/int/'
val_dem_dir = f'{main_dir}/val_subsets_v2/dem/'
val_era5_dir = f'{main_dir}/val_subsets_v2/era5/'

# Raw directory listings; these may still contain non-.tif files.
train_signal_fns = os.listdir(path=train_signal_dir)
train_noise_fns = os.listdir(path=train_noise_dir)
train_dem_fns = os.listdir(path=train_dem_dir)
train_era5_fns = os.listdir(path=train_era5_dir)

val_signal_fns = os.listdir(path=val_signal_dir)
val_noise_fns = os.listdir(path=val_noise_dir)
val_dem_fns = os.listdir(path=val_dem_dir)
val_era5_fns = os.listdir(path=val_era5_dir)

# exclude non tif files, e.g. metadata
def list_tifs(my_fns):
    """Return the names in ``my_fns`` that end in '.tif', keeping their order.

    Everything else in a directory listing (e.g. metadata files) is dropped.
    The match is case-sensitive and does not accept '.tiff'.
    """
    return [name for name in my_fns if name.endswith('.tif')]

# Keep only .tif files from each listing (drops e.g. metadata files).
train_signal_list = list_tifs(train_signal_fns)
train_noise_list = list_tifs(train_noise_fns)
train_dem_list = list_tifs(train_dem_fns)
train_era5_list = list_tifs(train_era5_fns)

val_signal_list = list_tifs(val_signal_fns)
val_noise_list = list_tifs(val_noise_fns)
val_dem_list = list_tifs(val_dem_fns)
val_era5_list = list_tifs(val_era5_fns)

# A scene is usable only if it exists in all four directories of its split.
# A set intersection avoids the quadratic list-membership scan. Sorting fixes
# the order: os.listdir returns entries in arbitrary order, so without it the
# scene order is not reproducible across runs or machines.
train_list = sorted(set(train_signal_list).intersection(train_noise_list, train_dem_list, train_era5_list))
val_list = sorted(set(val_signal_list).intersection(val_noise_list, val_dem_list, val_era5_list))

In [None]:
# define transforms
# Only a PIL -> tensor conversion. The dataset applies this transform to each
# file separately, so any random augmentation (flip, crop, ...) would draw
# different parameters per file and misalign the input with its target.
my_transforms = transforms.Compose([transforms.ToTensor()])

In [None]:
# define dataset
class dataset(torch.utils.data.Dataset):
    """Evaluation dataset that synthesises a noisy InSAR patch plus reference images.

    For each file name in ``file_list`` the same-named image is read from
    ``signal_dir``, ``noise_dir``, ``dem_dir`` and ``era5_dir``. Paths are
    joined by plain string concatenation, so directories must end in '/'.
    A noisy sample is built as ``noise + scalar * signal``, where ``scalar``
    is a log-normal draw from ``numpy.random`` made on every ``__getitem__``.

    ``__getitem__`` returns ``(train, signal, noise, era5, dem, era5_corr, hp_corr)``:
      train     -- synthetic noisy sample
      signal    -- scaled target signal
      noise     -- noise image (optionally Gaussian-blurred)
      era5      -- ERA5-based noise estimate, offset to match ``train`` at the
                   pixel of smallest |signal|
      dem       -- DEM image
      era5_corr -- ``train`` minus the ERA5 estimate
      hp_corr   -- ``train`` minus a Gaussian-blurred copy of itself (high-pass)

    Args:
        transform: callable applied to each opened PIL image. ``None`` falls
            back to ``transforms.ToTensor()``.
        norm: rescale to [-1, 1]. train/signal/era5_corr/hp_corr share the
            joint train+signal range; noise, dem and era5 each use their own
            min/max. A constant image maps to all zeros instead of NaN.
        center: subtract the signal's median (after normalisation, if enabled)
            from train, signal, era5_corr and hp_corr.
        invert: with probability 0.5 (``random.random``) negate every returned
            image except dem.
        blurnoise: Gaussian-blur the noise image before it is added to the
            signal. The ERA5 estimate is built from the unblurred noise.
    """
    def __init__(self, file_list, signal_dir, noise_dir, dem_dir, era5_dir, transform=None,
                 norm=True, center=True, invert=False, blurnoise=False):
        self.file_list = file_list
        # Without a fallback the documented default (None) would crash in __getitem__.
        self.transform = transform if transform is not None else transforms.ToTensor()
        self.signal_dir = signal_dir
        self.noise_dir = noise_dir
        self.dem_dir = dem_dir
        self.era5_dir = era5_dir
        self.norm = norm
        self.center = center
        self.invert = invert
        self.blurnoise = blurnoise

    #dataset length
    def __len__(self):
        self.filelength = len(self.file_list)
        return self.filelength

    def _load(self, directory, fn):
        """Open ``directory + fn``, apply the transform, and close the file handle."""
        with Image.open(directory + fn) as img:
            return self.transform(img)

    @staticmethod
    def _scale(x, lo, hi):
        """Linearly map [lo, hi] to [-1, 1]; return zeros when hi == lo (avoids NaN)."""
        if hi == lo:
            return torch.zeros_like(x)
        return 2 * ((x - lo) / (hi - lo)) - 1

    #load images
    def __getitem__(self, idx):
        fn = self.file_list[idx]
        signal = self._load(self.signal_dir, fn)
        noise = self._load(self.noise_dir, fn)
        dem = self._load(self.dem_dir, fn)
        era5 = self._load(self.era5_dir, fn)

        # Generate era5 noise estimates.
        # Phase -> displacement: 12.5663706 is 4*pi; 0.05546576 is presumably the
        # radar wavelength in metres (Sentinel-1 C-band) -- confirm.
        era5 = era5 * (0.05546576 / 12.5663706)
        # NOTE(review): noise - (era5 * -1) equals noise + era5; confirm the intended sign.
        era5 = (noise - (era5 * -1))

        # Blur noise to mitigate noise from non-atmospheric sources.
        if self.blurnoise:
            gblur = transforms.GaussianBlur(kernel_size=(7, 7), sigma=5)
            noise = gblur(noise)

        # Synthetic noisy sample: signal scaled by a random log-normal factor.
        scalar = np.round(np.random.lognormal(0.05, 1.), 3)
        signal = signal * scalar
        train = noise + signal

        # Offset the ERA5 estimate so it equals train at the pixel of lowest |signal|.
        ref_index = signal.abs().argmin().item()
        corr_diff = (train.flatten()[ref_index] - era5.flatten()[ref_index]).item()
        era5 = era5 + corr_diff

        # ERA5-corrected sample.
        era5_corr = train - era5

        # High-pass correction: subtract a Gaussian-blurred (low-pass) copy.
        hp_filter = transforms.GaussianBlur(kernel_size=(25, 25), sigma=3)
        hp_corr = train - hp_filter(train)

        # normalization between -1 and 1 as in Zhao et al. https://doi.org/10.1016/j.isprsjprs.2021.08.009
        if self.norm:
            # Joint range of train and signal, taken before either is rescaled.
            norm_min = torch.minimum(train.min(), signal.min())
            norm_max = torch.maximum(train.max(), signal.max())

            signal = self._scale(signal, norm_min, norm_max)
            noise = self._scale(noise, noise.min(), noise.max())
            dem = self._scale(dem, dem.min(), dem.max())
            train = self._scale(train, norm_min, norm_max)
            era5 = self._scale(era5, era5.min(), era5.max())
            era5_corr = self._scale(era5_corr, norm_min, norm_max)
            hp_corr = self._scale(hp_corr, norm_min, norm_max)

        # Center target-comparable images on the signal's median.
        if self.center:
            center_median = signal.median()
            train = train - center_median
            signal = signal - center_median
            era5_corr = era5_corr - center_median
            hp_corr = hp_corr - center_median

        # invert images to remove bias towards negative signal
        if self.invert and random.random() < 0.5:
            train = train * -1
            signal = signal * -1
            noise = noise * -1
            era5 = era5 * -1
            era5_corr = era5_corr * -1
            hp_corr = hp_corr * -1

        return train, signal, noise, era5, dem, era5_corr, hp_corr

In [None]:
# create dataloaders
# Each split reads from its own directories (val_* for validation, train_* for
# training); the previous un-prefixed names were never defined.
val_data = dataset(val_list, val_signal_dir, val_noise_dir, val_dem_dir, val_era5_dir,
                   transform=my_transforms, invert=False, blurnoise=True)
train_data = dataset(train_list, train_signal_dir, train_noise_dir, train_dem_dir, train_era5_dir,
                     transform=my_transforms, invert=False, blurnoise=True)

# One image per batch, visited in file-list order, so downstream metrics are per image.
val_loader = torch.utils.data.DataLoader(dataset=val_data, batch_size=1, shuffle=False)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=1, shuffle=False)

## Define networks

In [None]:
class DnCNN_noise(nn.Module):
    """
    DnCNN-style noise estimator for InSAR denoising, adapted from Rouet-Leduc et al., 2021.

    forward(x, dem) concatenates x and dem along dim 1. Together they must
    have 2 channels (presumably one each). The output is a single-channel map
    with the same height and width as the input.

    Eight 3x3 convolutions with 64 hidden channels. Layers 5 and 6 are dilated
    by 5, and padding equals dilation everywhere, so spatial size is preserved.
    The layers are stored as cnn1..cnn8 so saved state_dict keys keep matching.
    """
    def __init__(self):
        super().__init__()
        in_channels, width, ksize = 2, 64, 3
        # (in_channels, out_channels, padding, dilation) for cnn1 .. cnn8
        layer_specs = (
            (in_channels, width, 1, 1),
            (width, width, 1, 1),
            (width, width, 1, 1),
            (width, width, 1, 1),
            (width, width, 5, 5),
            (width, width, 5, 5),
            (width, width, 1, 1),
            (width, 1, 1, 1),
        )
        for number, (c_in, c_out, pad, dil) in enumerate(layer_specs, start=1):
            setattr(self, f'cnn{number}',
                    nn.Conv2d(c_in, c_out, kernel_size=ksize, padding=pad, dilation=dil))

    def forward(self, x, dem):
        hidden = torch.cat((x, dem), dim=1)
        # cnn1..cnn7 are each followed by an in-place ELU; cnn8 is linear.
        for layer in (self.cnn1, self.cnn2, self.cnn3, self.cnn4,
                      self.cnn5, self.cnn6, self.cnn7):
            hidden = F.elu(layer(hidden), inplace=True)
        return self.cnn8(hidden)

In [None]:
class DnCNN_noise_era5(nn.Module):
    """
    DnCNN-style noise estimator for InSAR denoising, adapted from Rouet-Leduc et al., 2021.

    forward(x, era5, dem) concatenates its three inputs along dim 1. Together
    they must have 3 channels (presumably one each). The output is a
    single-channel map with the same height and width as the input.

    Eight 3x3 convolutions with 64 hidden channels. Layers 5 and 6 are dilated
    by 5, and padding equals dilation everywhere, so spatial size is preserved.
    The layers are stored as cnn1..cnn8 so saved state_dict keys keep matching.
    """
    def __init__(self):
        super().__init__()
        in_channels, width, ksize = 3, 64, 3
        # (in_channels, out_channels, padding, dilation) for cnn1 .. cnn8
        layer_specs = (
            (in_channels, width, 1, 1),
            (width, width, 1, 1),
            (width, width, 1, 1),
            (width, width, 1, 1),
            (width, width, 5, 5),
            (width, width, 5, 5),
            (width, width, 1, 1),
            (width, 1, 1, 1),
        )
        for number, (c_in, c_out, pad, dil) in enumerate(layer_specs, start=1):
            setattr(self, f'cnn{number}',
                    nn.Conv2d(c_in, c_out, kernel_size=ksize, padding=pad, dilation=dil))

    def forward(self, x, era5, dem):
        hidden = torch.cat((x, era5, dem), dim=1)
        # cnn1..cnn7 are each followed by an in-place ELU; cnn8 is linear.
        for layer in (self.cnn1, self.cnn2, self.cnn3, self.cnn4,
                      self.cnn5, self.cnn6, self.cnn7):
            hidden = F.elu(layer(hidden), inplace=True)
        return self.cnn8(hidden)

## Load models

In [None]:
# Noise-only DnCNN (inputs: sample + DEM). Restore the trained weights, then move it to the GPU.
model1 = DnCNN_noise()
_model1_state = torch.load('noisemodelv2.2.1_450epochs')
model1.load_state_dict(_model1_state)
del _model1_state  # don't keep a second copy of the weights alive
model1.to('cuda')

In [None]:
# ERA5-aware DnCNN (inputs: sample + ERA5 estimate + DEM, 3 input channels).
# NOTE(review): this loads the same checkpoint file that model1 loads into the
# 2-channel DnCNN_noise. If that load succeeds, cnn1.weight in the file has 2
# input channels while this network's cnn1 expects 3, so load_state_dict will
# raise a size-mismatch error. Point this at the checkpoint trained for
# DnCNN_noise_era5.
model2 = DnCNN_noise_era5()
model2.load_state_dict(torch.load('noisemodelv2.2.1_450epochs'))
model2.to('cuda')

## Generate SSIM and SNR dataframes

In [None]:
def ssim_lists(model, data_loader, seed=0):
    """Compute per-image SSIM against the target signal for four correction strategies.

    For every item yielded by ``data_loader``, SSIM is measured between the
    target signal and each of:

    * the uncorrected sample,
    * the sample minus the model's predicted noise, clamped to [-1, 1],
    * the ERA5-corrected sample supplied by the loader,
    * the high-pass-filtered sample supplied by the loader.

    Args:
        model: noise-predicting network. If its ``forward`` has an ``era5``
            parameter it is called as ``model(sample, era5_noise, dem)``;
            otherwise as ``model(sample, dem)``. Inputs are moved to the
            device of the model's parameters (CPU if it has none).
        data_loader: iterable of ``(sample, signal_target, noise_target,
            era5_noise, dem, era5_corr, hp_corr)`` tensors. Expected to use
            ``batch_size=1``: each tensor is squeezed before SSIM, which
            assumes a single 2-D image per batch.
        seed: if not None, ``numpy.random`` and ``random`` are seeded with it
            before iterating. The loader's dataset may then reproduce the same
            random per-item draws on every pass, so these lists line up
            element-wise with other passes (e.g. ``snr``) or other models.
            ``None`` leaves the global RNG state untouched.

    Returns:
        ``(uncorrected, model_corrected, era5_corrected, hp_corrected)``:
        four lists of SSIM values, one entry per item.
    """
    import inspect  # local: only used to inspect the model's forward signature

    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    # NOTE(review): assumes an ERA5-aware model expects the loader's era5_noise
    # tensor as its era5 input -- confirm against the training code.
    takes_era5 = 'era5' in inspect.signature(model.forward).parameters

    def _ssim_vs_target(image, target):
        # data_range=2.0 is the full dtype range skimage historically assumed
        # for float images; newer versions require it to be explicit.
        return ssim(image.squeeze().detach().cpu().numpy(),
                    target.squeeze().detach().cpu().numpy(),
                    gaussian_weights=True, data_range=2.0)

    ssim_list_uncorrected = []
    ssim_list_model_corrected = []
    ssim_list_era5_corrected = []
    ssim_list_hp_corrected = []

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for sample, signal_target, _noise_target, era5_noise, dem, era5_corr, hp_corr in data_loader:
            # uncorrected SSIM
            ssim_list_uncorrected.append(_ssim_vs_target(sample, signal_target))

            # model corrected SSIM: subtract the predicted noise from the sample
            if takes_era5:
                noise = model(sample.to(device), era5_noise.to(device), dem.to(device))
            else:
                noise = model(sample.to(device), dem.to(device))
            signal = torch.clamp(sample.cpu() - noise.cpu(), -1, 1)
            ssim_list_model_corrected.append(_ssim_vs_target(signal, signal_target))

            # era5 corrected SSIM
            ssim_list_era5_corrected.append(_ssim_vs_target(era5_corr, signal_target))

            # hp filter corrected SSIM
            ssim_list_hp_corrected.append(_ssim_vs_target(hp_corr, signal_target))

    print('mean ssim before correction:', np.mean(ssim_list_uncorrected),
          '\nmean ssim model correction:', np.mean(ssim_list_model_corrected),
          '\nmean ssim era5 correction:', np.mean(ssim_list_era5_corrected),
          '\nmean ssim high pass filter correction:', np.mean(ssim_list_hp_corrected))

    return ssim_list_uncorrected, ssim_list_model_corrected, ssim_list_era5_corrected, ssim_list_hp_corrected

In [None]:
# Model 1: per-image SSIM lists (uncorrected, model, ERA5, high-pass) for each split.
(m1val_ssim_list_uncorrected, m1val_ssim_list_model,
 m1val_ssim_list_era5, m1val_ssim_list_hp) = ssim_lists(model=model1, data_loader=val_loader)
(m1train_ssim_list_uncorrected, m1train_ssim_list_model,
 m1train_ssim_list_era5, m1train_ssim_list_hp) = ssim_lists(model=model1, data_loader=train_loader)

In [None]:
# Model 2: per-image SSIM lists (uncorrected, model, ERA5, high-pass) for each split.
(m2val_ssim_list_uncorrected, m2val_ssim_list_model,
 m2val_ssim_list_era5, m2val_ssim_list_hp) = ssim_lists(model=model2, data_loader=val_loader)
(m2train_ssim_list_uncorrected, m2train_ssim_list_model,
 m2train_ssim_list_era5, m2train_ssim_list_hp) = ssim_lists(model=model2, data_loader=train_loader)

In [None]:
# Calculate SNR
def rms(tensor):
    """Return the root-mean-square of all elements of a CPU tensor."""
    values = tensor.squeeze().numpy()
    return np.sqrt(np.mean(values ** 2))


def snr(model, data_loader, seed=0):
    """Return the per-item signal-to-noise ratio of the samples in ``data_loader``.

    SNR is ``rms(signal_target) / rms(sample - signal_target)``, computed on the
    tensors exactly as the loader yields them.

    Args:
        model: unused; kept so existing calls keep working.
        data_loader: iterable whose items start with ``(sample, signal_target, ...)``.
        seed: if not None, ``numpy.random`` and ``random`` are seeded with it
            before iterating. The loader's dataset may then reproduce the same
            random per-item draws on every pass, so these values line up
            element-wise with SSIM lists computed with the same seed. ``None``
            leaves the global RNG state untouched.

    Returns:
        list of SNR values, one per item.
    """
    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)

    snr_list = [rms(signal_target) / rms(sample - signal_target)
                for sample, signal_target, *_ in data_loader]

    print('mean snr of images:', np.mean(snr_list))

    return snr_list

In [None]:
# Model 1: per-image SNR for the validation and training splits.
m1val_snr_list = snr(model=model1, data_loader=val_loader)
m1train_snr_list = snr(model=model1, data_loader=train_loader)

In [None]:
# Model 2: per-image SNR for the validation and training splits.
m2val_snr_list = snr(model=model2, data_loader=val_loader)
m2train_snr_list = snr(model=model2, data_loader=train_loader)

In [None]:
def df_for_plotting(snr_list, ssim_list_uncorrected, ssim_list_model, ssim_list_era5, ssim_list_hp,
                    roll_count=200, q_low=25, q_high=75):
    """Build a per-image dataframe of SNR and SSIM plus rolling summary bands.

    The raw columns are ``snr``, ``ssim_uncorrected``, ``ssim_model``,
    ``ssim_era5`` and ``ssim_hp``; the input lists must be equally long and
    aligned per image. For each method ``m`` three columns are added:
    ``ssim_m_median``, ``ssim_m_q{q_low}`` and ``ssim_m_q{q_high}``.

    Each is a centred rolling statistic over a ``roll_count``-row window of
    the rows ordered by SNR, written back to each row's own index. Rows near
    either end of the SNR ordering get NaN, and so does every row when there
    are fewer than ``roll_count`` images.

    Args:
        roll_count: rolling window length (rows).
        q_low, q_high: lower/upper band percentiles (0-100).

    Returns:
        pandas.DataFrame with the columns described above.
    """
    ssim_df = pd.DataFrame({'snr': snr_list,
                            'ssim_uncorrected': ssim_list_uncorrected,
                            'ssim_model': ssim_list_model,
                            'ssim_era5': ssim_list_era5,
                            'ssim_hp': ssim_list_hp})

    # Sort once by SNR. The rolling results keep the original row labels, so
    # assigning them back re-aligns every value with its own image.
    by_snr = ssim_df.sort_values(by=['snr'])

    for method in ('uncorrected', 'model', 'era5', 'hp'):
        window = by_snr[f'ssim_{method}'].rolling(roll_count, center=True)
        ssim_df[f'ssim_{method}_median'] = window.median()
        # Quantile passed positionally: its keyword name differs across pandas versions.
        ssim_df[f'ssim_{method}_q{q_low}'] = window.quantile(q_low / 100)
        ssim_df[f'ssim_{method}_q{q_high}'] = window.quantile(q_high / 100)

    return ssim_df

In [None]:
# One plotting dataframe per (model, split): per-image SNR/SSIM plus rolling
# median and quantile bands ordered by SNR.
m1val_ssim_df = df_for_plotting(m1val_snr_list,
                                m1val_ssim_list_uncorrected, m1val_ssim_list_model,
                                m1val_ssim_list_era5, m1val_ssim_list_hp)
m1train_ssim_df = df_for_plotting(m1train_snr_list,
                                  m1train_ssim_list_uncorrected, m1train_ssim_list_model,
                                  m1train_ssim_list_era5, m1train_ssim_list_hp)
m2val_ssim_df = df_for_plotting(m2val_snr_list,
                                m2val_ssim_list_uncorrected, m2val_ssim_list_model,
                                m2val_ssim_list_era5, m2val_ssim_list_hp)
m2train_ssim_df = df_for_plotting(m2train_snr_list,
                                  m2train_ssim_list_uncorrected, m2train_ssim_list_model,
                                  m2train_ssim_list_era5, m2train_ssim_list_hp)