# InSAR denoiser testing

In [None]:
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import torch
import glob
from sklearn.model_selection import train_test_split
from PIL import Image
import seaborn as sns
import math
from skimage.metrics import structural_similarity as ssim
import rioxarray

## Define the network and load the trained model

In [None]:
class DnCNN(nn.Module):
    """
    InSAR denoising network, adapted from Rouet-Leduc et al., 2021.

    The interferogram and DEM are concatenated along dim 1 and passed through
    nine 3x3 convolutions, ``cnn1`` ... ``cnn9``. Layers 5, 6 and 7 are
    dilated (5, 4 and 2) with padding equal to the dilation, so every layer
    keeps the spatial size. An in-place ELU follows every layer except the
    last, which produces a single output channel.

    The ``cnnN`` attribute names are the ``state_dict`` keys, so they must
    stay as-is for saved checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        in_ch = 2   # interferogram + DEM, concatenated in forward()
        width = 64  # hidden feature maps per layer

        # (out_channels, dilation) for cnn1 ... cnn9. With a 3x3 kernel,
        # padding == dilation keeps height and width unchanged.
        layer_specs = [
            (width, 1),
            (width, 1),
            (width, 1),
            (width, 1),
            (width, 5),
            (width, 4),
            (width, 2),
            (width, 1),
            (1, 1),
        ]
        # Register in order cnn1..cnn9 so both the state_dict key order and
        # the random-initialisation order match the layer numbering.
        for idx, (out_ch, dil) in enumerate(layer_specs, start=1):
            setattr(
                self,
                f"cnn{idx}",
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=dil, dilation=dil),
            )
            in_ch = out_ch

    def forward(self, x, dem):
        # Submodules are yielded in registration order: cnn1 ... cnn9.
        *hidden_layers, output_layer = self.children()
        out = torch.cat((x, dem), dim=1)
        for conv in hidden_layers:
            out = F.elu(conv(out), inplace=True)
        return output_layer(out)

In [None]:
# Load the trained weights into a fresh DnCNN.
# map_location='cpu' lets a checkpoint saved from a GPU session load on a
# CPU-only machine; the model's parameters live on the CPU here anyway, and
# load_state_dict copies the loaded tensors into them.
test_model = DnCNN()
test_model.load_state_dict(torch.load('model_100_epochs', map_location='cpu'))
# Inference only from here on. DnCNN has no dropout or batch-norm, so this
# does not change its outputs, but it states intent and guards future edits.
test_model.eval()

# Load GeoTIFFs (interferograms and DEM)

In [None]:
# Helper functions to load interferogram GeoTIFFs into xarray

def xr_read_geotif(geotif_file_path, chunks='auto', masked=True):
    """
    Read a single- or multi-band GeoTIFF into an xarray.Dataset.

    Each raster band becomes a data variable named ``band1``, ``band2``, ...
    in file order, and the ``band`` coordinate is dropped.

    Parameters
    ----------
    geotif_file_path : str or path-like
        GeoTIFF file path.
    chunks : optional
        Forwarded to ``rioxarray.open_rasterio``. The default ``'auto'``
        gives dask-backed (lazy) arrays.
    masked : bool, default True
        Forwarded to ``rioxarray.open_rasterio`` (controls nodata masking).

    Returns
    -------
    ds : xarray.Dataset
        Usable with the rioxarray extension; the ``spatial_ref`` coordinate is
        kept. Top-level attributes are copied from the raster, and length-1
        values are unwrapped (e.g. ``(1,) -> 1``).
    """
    # NOTE(review): this notebook's import cell does not bind ``xr``;
    # ``import xarray as xr`` must run before this function is called.

    # Honour the caller's ``masked`` flag (it must not be hard-coded here).
    da = rioxarray.open_rasterio(geotif_file_path, chunks=chunks, masked=masked)

    # One data variable per band.
    ds = xr.Dataset()
    for i, band_value in enumerate(da.band, start=1):
        band_da = da.sel(band=band_value)
        band_da.name = f"band{i}"
        ds[band_da.name] = band_da

    # Each band carried a scalar ``band`` coordinate; drop it. The
    # ``spatial_ref`` coordinate must stay even though it looks empty: its
    # attributes (see ds.coords.variables) are used by the rioxarray extension.
    del ds.coords["band"]

    # Copy top-level attributes, unwrapping single-element iterables.
    attrs = {}
    for key, value in da.attrs.items():
        try:
            if len(value) == 1:
                value = value[0]
        except TypeError:  # scalars have no len(); keep them unchanged
            pass
        attrs[key] = value
    ds.attrs = attrs

    return ds


def hyp3_to_xarray(hyp3_dir, file_type='unw_phase'):
    
    dirs = os.listdir(hyp3_dir) #list generated interferograms
    datasets = []
    
    for idir in dirs:
        cwd = f'{hyp3_dir}/{idir}'
        os.chdir(cwd) #change to interferogram dir

        ext = f'{file_type}.tif' #end of filename for desired hyp3 product
        
        for fn in os.listdir(cwd): #select appropriate hyp3 product
            if fn[-len(ext):] == ext: 
                tif_fn = fn
        tif_path = f'{hyp3_dir}/{idir}/{tif_fn}'
        dates = f'{tif_fn[5:13]}_{tif_fn[21:29]}' #parse filename for interferogram dates
        
        src = xr_read_geotif(tif_path, masked=False) #read product to xarray ds 
        src = src.assign_coords({"dates": dates})
        src = src.expand_dims("dates")
        
        datasets.append(src)
       
    ds = xr.concat(datasets, dim="dates", combine_attrs="no_conflicts") #create dataset
    return ds 

In [None]:
# Open the interferograms and the DEM.
hyp3_dir = '/Users/qbren/Desktop/taco/projects/atmospheric_correction/data_processing/test_data/asc_crop'

# TODO: set to the DEM GeoTIFF filename inside hyp3_dir.
dem_fn = ''
# Fail fast, before the (slow) interferogram load, if the placeholder is unset;
# an empty name would otherwise point at hyp3_dir itself.
if not dem_fn:
    raise ValueError('set dem_fn to the DEM GeoTIFF filename in hyp3_dir before running this cell')

int_xarray = hyp3_to_xarray(hyp3_dir)

# Read band 1 of the DEM as a NumPy array. rioxarray (already imported) is
# used because no `rio`/rasterio import exists in this notebook; with its
# defaults (masked=False, mask_and_scale=False) the raw band values are
# returned. The `with` block closes the file once the data is loaded.
with rioxarray.open_rasterio(os.path.join(hyp3_dir, dem_fn)) as dem_da:
    dem_np = dem_da.isel(band=0).values

In [None]:
# Convert interferogram and DEM arrays into model-ready tensors
def arrays_to_tensor(int_array, dem_array, norm=True):
    """
    Convert an interferogram and a DEM to float32 tensors for the model.

    Parameters
    ----------
    int_array : array-like
        Interferogram values; anything ``np.asarray`` accepts (NumPy array,
        xarray.DataArray, pandas object, ...).
    dem_array : array-like
        DEM values; same accepted types.
    norm : bool, default True
        If True, min-max scale each tensor independently to [-1, 1].

    Returns
    -------
    (test_tensor, dem_tensor) : tuple of torch.Tensor
        float32 copies of the inputs, with the same shapes.
    """
    # NOTE(review): NaN (masked) pixels are not handled -- confirm inputs
    # are NaN-free before normalising.
    test_tensor = torch.tensor(np.asarray(int_array, dtype=np.float32))
    dem_tensor = torch.tensor(np.asarray(dem_array, dtype=np.float32))

    if norm:
        test_tensor = _scale_to_unit_range(test_tensor)
        dem_tensor = _scale_to_unit_range(dem_tensor)

    return test_tensor, dem_tensor


def _scale_to_unit_range(tensor):
    """Min-max scale ``tensor`` to [-1, 1]; a constant tensor maps to all zeros."""
    lo, hi = tensor.min(), tensor.max()
    span = hi - lo
    if span == 0:
        # Avoid 0/0 = NaN; 0 is the midpoint of the target range.
        return torch.zeros_like(tensor)
    return 2 * (tensor - lo) / span - 1

In [None]:
# run test interferograms through model