# InSAR Denoiser training, validation, testing

In [9]:
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import torch
import glob
from sklearn.model_selection import train_test_split
from PIL import Image

## Dataset Construction
The intended input file structure consists of four directories: train_image, train_label, test_image, and test_label. The folders contain only images, and each image has the same filename as its corresponding label.
- Filenames in image folders: sub1_S1AA_20190505T135154_20190622T135157_VVP048_INT80_G_weF_011A_los_disp.tif
- Filenames in label folders: sub1_S1AA_20190505T135154_20190622T135157_VVP048_INT80_G_weF_011A_los_disp.tif

In [5]:
# List the cropped-interferogram directories.
# NOTE(review): both listings currently read the same directory, so test_fns
# duplicates train_fns and a test split built from it would overlap the
# training data; point the second path at the test directory before use.
train_fns, test_fns = (
    os.listdir('/Users/qbren/Desktop/uw_courses/2022_spring/inferring-structure/project/data_processing/data_crop'),
    os.listdir('/Users/qbren/Desktop/uw_courses/2022_spring/inferring-structure/project/data_processing/data_crop'),  # change paths before use
)

def keep_tifs(my_fns):
    """Return the entries of *my_fns* whose last four characters are ``.tif``.

    Order is preserved; everything else (other extensions, ``.tiff``,
    upper-case ``.TIF``) is dropped.
    """
    return [fname for fname in my_fns if fname[-4:] == '.tif']
        
    
# Keep only the .tif files from each directory listing.
train_list, test_list = keep_tifs(train_fns), keep_tifs(test_fns)

In [45]:
# Hold out 20% of the training files for validation. A fixed random_state makes
# the split reproducible across runs/kernel restarts.
# NOTE: re-running this cell re-splits the already-reduced train_list; re-run
# the listing cell first if you need to redo the split.
train_list, val_list = train_test_split(train_list, test_size=0.2, random_state=42)

In [46]:
# Only convert to tensors. The dataset applies this transform separately to the
# image and to its label, so any random augmentation (e.g. a flip) would be
# drawn independently for each and break their pixel alignment.
my_transforms = transforms.Compose(transforms=[transforms.ToTensor()])

In [50]:
# normalization between -1 and 1 as in Zhao et al. https://doi.org/10.1016/j.isprsjprs.2021.08.009

# Directory holding the input crops used for both training and validation.
train_img_dir = '/Users/qbren/Desktop/uw_courses/2022_spring/inferring-structure/project/data_processing/data_crop/'
# TODO: placeholders -- set the label and test directories before use. With an
# empty directory string, file lookups resolve relative to the current working
# directory.
train_label_dir = test_img_dir = test_label_dir = ''

class dataset(torch.utils.data.Dataset):
    """Paired (input interferogram, label) dataset read from two directories.

    Each entry of ``file_list`` is a bare filename that must exist under both
    ``img_dir`` and ``label_dir``; the same name is used for the input image
    and for its label.

    Args:
        file_list: filenames (not paths) to load.
        img_dir: directory holding the input images.
        label_dir: directory holding the label images.
        transform: callable applied separately to the input and to the label
            (e.g. ``transforms.ToTensor()``). If ``None``, each image is
            converted directly to a float32 tensor shaped (C, H, W).
        norm: if truthy, each tensor is independently min-max rescaled to
            [-1, 1], as in Zhao et al.
            https://doi.org/10.1016/j.isprsjprs.2021.08.009
    """

    def __init__(self, file_list, img_dir, label_dir, transform=None, norm=True):
        self.file_list = file_list
        self.transform = transform
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.norm = norm

    def __len__(self):
        """Return the number of filenames; also cached on ``self.filelength``."""
        self.filelength = len(self.file_list)
        return self.filelength

    def _to_tensor(self, pil_img):
        """Apply ``self.transform``, or convert to a float32 (C, H, W) tensor if it is None."""
        if self.transform is not None:
            return self.transform(pil_img)
        arr = np.array(pil_img, dtype=np.float32)
        if arr.ndim == 2:
            arr = arr[np.newaxis, ...]  # single-band image: add a channel axis
        else:
            arr = arr.transpose(2, 0, 1)  # H x W x C -> C x H x W
        return torch.from_numpy(np.ascontiguousarray(arr))

    @staticmethod
    def _rescale(t):
        """Min-max rescale tensor *t* to [-1, 1].

        A constant tensor (max == min) maps to all zeros instead of the NaNs a
        0/0 division would produce. NaN values in *t* still propagate.
        """
        t_min = t.min()
        span = t.max() - t_min
        if span == 0:
            return torch.zeros_like(t, dtype=torch.result_type(t, 1.0))
        return 2 * ((t - t_min) / span) - 1

    def __getitem__(self, idx):
        """Load, convert and optionally normalize the input/label pair at *idx*."""
        fname = self.file_list[idx]
        # Context managers close the underlying file handles once the pixel
        # data has been copied into a tensor.
        with Image.open(os.path.join(self.img_dir, fname)) as img:
            img_transformed = self._to_tensor(img)
        with Image.open(os.path.join(self.label_dir, fname)) as label:
            label_transformed = self._to_tensor(label)

        if self.norm:
            img_transformed = self._rescale(img_transformed)
            label_transformed = self._rescale(label_transformed)

        return img_transformed, label_transformed

In [51]:
# create dataloaders
train_data = dataset(train_list, train_img_dir, train_label_dir, transform=my_transforms)
val_data = dataset(val_list, train_img_dir, train_label_dir, transform=my_transforms)
test_data = dataset(test_list, test_img_dir, test_label_dir, transform=my_transforms)

train_loader = torch.utils.data.DataLoader(dataset = train_data, batch_size=64, shuffle=True )
test_loader = torch.utils.data.DataLoader(dataset = test_data, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset = val_data, batch_size=64, shuffle=True)

## Model definition

In [53]:
# model from Zhang et al. https://doi.org/10.1109/TIP.2017.2662206, as implemented here: https://github.com/SaoYan/DnCNN-PyTorch 
class DnCNN(nn.Module):
    """DnCNN residual denoiser.

    The network output is used as a noise estimate: callers form the denoised
    image as ``x - model(x)``. Spatial size is preserved (3x3 convs, padding 1).

    Args:
        channels: number of input/output image channels.
        num_of_layers: total number of conv layers (first + middle + last);
            must be at least 2.
        batch_norm: if True, insert ``BatchNorm2d`` after each middle conv, as
            in the cited SaoYan implementation. Defaults to False, which keeps
            this notebook's original architecture (and its state_dict keys).

    Raises:
        ValueError: if ``num_of_layers < 2``.
    """

    def __init__(self, channels, num_of_layers=17, batch_norm=False):
        super(DnCNN, self).__init__()
        if num_of_layers < 2:
            raise ValueError(f'num_of_layers must be >= 2, got {num_of_layers}')
        kernel_size = 3
        padding = 1
        features = 64
        layers = []
        layers.append(nn.Conv2d(in_channels=channels, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(num_of_layers - 2):
            layers.append(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
            if batch_norm:
                layers.append(nn.BatchNorm2d(features))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=features, out_channels=channels, kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network's noise estimate for *x* (same shape as *x*)."""
        out = self.dncnn(x)
        return out

## Train Model

In [None]:
%%time

# Define model, optimizer and loss. Use the GPU when one is available so the
# notebook also runs on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = DnCNN(channels=1)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) #original implementation reduces learning rate at milestone 30 epochs
loss_fn = nn.MSELoss()
epochs = 30

# Per-epoch mean losses, used by the plotting cell.
train_loss = []
valid_loss = []

for epoch in range(epochs):
    print(f'\nstarting epoch {epoch}')
    epoch_loss = []
    val_temp_loss = []

    # --- training pass ---
    model.train()
    for sample, target in train_loader:
        # Move each batch to the device once and reuse it.
        sample = sample.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        # The model predicts noise; subtract it from the interferogram and clamp
        # to [-1, 1], the range the dataset rescales labels to when norm=True.
        # NOTE(review): clamping before the loss zeroes the gradient for pixels
        # outside [-1, 1]; consider clamping only at inference time.
        out = torch.clamp(sample - model(sample), -1, 1)
        loss = loss_fn(out, target)
        epoch_loss.append(loss.item())
        loss.backward()  # propagate gradients
        optimizer.step()

    train_loss.append(np.mean(epoch_loss))
    print(f'training loss: {np.mean(epoch_loss)}')

    # --- validation pass (no gradients, eval mode) ---
    model.eval()
    with torch.no_grad():
        for sample, target in val_loader:
            sample = sample.to(device)
            target = target.to(device)
            out = torch.clamp(sample - model(sample), -1, 1)
            loss = loss_fn(out, target)
            val_temp_loss.append(loss.item())

    valid_loss.append(np.mean(val_temp_loss))
    print(f'validation loss: {np.mean(val_temp_loss)}')

### Plot loss

In [None]:
# Plot loss: per-epoch mean training and validation MSE recorded by the
# training cell (previously referenced undefined v1_* names).
f, ax = plt.subplots(figsize=(10, 10))
ax.plot(train_loss, label='training')
ax.plot(valid_loss, label='validation')
ax.set_xlabel('epoch')
ax.set_ylabel('MSE loss')
ax.set_title('Loss')
ax.legend()

### Visualize outputs 

In [None]:
val_loader = torch.utils.data.DataLoader(dataset = val_data, batch_size=1, shuffle=True) #change batch size
num_images = 5  # number of validation examples to display

# Run on whatever device the model's parameters live on, in eval mode.
model_device = next(model.parameters()).device
model.eval()

with torch.no_grad():
    for i, (sample, target) in enumerate(val_loader):
        if i >= num_images:
            break  # stop loading data once enough examples are shown
        sample = sample.to(model_device)
        noise = model(sample)  # predicted noise
        nn_corrected = torch.clamp(sample - noise, -1, 1)

        f, ax = plt.subplots(ncols=4, figsize=(20, 5))
        panels = [
            (sample, 'original interferogram'),
            (noise, 'predicted noise'),
            (nn_corrected, 'NN corrected interferogram'),
            (target, 'ERA5 corrected interferogram'),
        ]
        for axis, (tensor, title) in zip(ax, panels):
            # batch_size=1: take the first sample's first channel as a 2-D
            # image and move it to CPU for matplotlib.
            axis.imshow(tensor[0, 0].cpu().numpy())
            axis.set_title(title)

# Test Model