# Train a despeckling neural network

### Preparation

In [None]:
import os.path

import torch
from torch.utils.data import DataLoader, random_split
from torch.nn import MSELoss, L1Loss
from torch.optim import Adam
import numpy as np
from skimage.measure import compare_ssim as ssim
import tqdm

from datasets import NoisyScansDataset
from despeckling import models

Defining the following variables makes our code device-agnostic (it runs on the GPU when one is available, otherwise on the CPU).

In [None]:
# `cuda` records whether a GPU is usable; `device` is the matching torch device.
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

Let's define some helper functions.

In [None]:
def get_noise_args(noise_type):
    """Return the noise configuration used to build the noisy dataset.

    The dict holds a NumPy sampling function under ``'random_variable'``
    plus the keyword arguments to call it with. Both configured
    distributions have mean 1 (normal: ``loc=1``; gamma:
    ``shape * scale = 1``). Presumably this is consumed by
    ``NoisyScansDataset`` to sample multiplicative noise — confirm there.

    Args:
        noise_type: ``'gaussian'`` or ``'gamma'``.

    Returns:
        A fresh dict on every call, so callers may mutate it safely.

    Raises:
        NotImplementedError: if ``noise_type`` is not a supported name.
    """
    if noise_type == 'gaussian':
        return {'random_variable': np.random.normal,
                'loc': 1, 'scale': 0.1}
    if noise_type == 'gamma':
        return {'random_variable': np.random.gamma,
                'shape': 1, 'scale': 1}
    # Previously fell through and hit an UnboundLocalError; fail clearly instead.
    raise NotImplementedError('{} noise does not exist.'.format(noise_type))

In [None]:
def compute_ssim(noisy_batch, clean_batch, median_filter=False, median_size=3):
    """Return the SUM of per-image SSIM scores between two batches.

    Only channel 0 of each image is compared (``batch[:, 0]``). Pixel values
    are assumed to lie in [-1, 1], hence ``data_range=2`` — presumably the
    dataset's normalisation; confirm against NoisyScansDataset.

    Args:
        noisy_batch: tensor whose first two axes are (batch, channel).
        clean_batch: reference tensor with the same layout.
        median_filter: if True, median-filter each noisy image before scoring
            (a classical despeckling baseline).
        median_size: side length of the median-filter window; only used when
            ``median_filter`` is True.

    Returns:
        Sum of SSIM over the batch; divide by the batch size for the mean.
    """
    if median_filter:
        # Local import: only this optional baseline path needs SciPy.
        from scipy.ndimage import median_filter as _median_filter

    ssim_sum = 0
    for noisy, clean in zip(noisy_batch[:, 0], clean_batch[:, 0]):
        noisy = noisy.data.cpu().numpy()
        clean = clean.data.cpu().numpy()

        if median_filter:
            # Quantise [-1, 1] -> [0, 255] as an 8-bit image; clip first so
            # out-of-range values saturate instead of wrapping around in uint8.
            noisy = np.clip((noisy + 1) / 2 * 255, 0, 255).astype(np.uint8)
            # Spatial median filter (np.median would collapse the image to a scalar).
            noisy = _median_filter(noisy, size=median_size)
            # Back to [-1, 1], matching the reference dtype for the SSIM call.
            noisy = ((noisy / 255.0 - 0.5) * 2).astype(clean.dtype)

        ssim_sum += ssim(noisy, clean, data_range=2)
    return ssim_sum

In [None]:
def get_model(model_str, num_layers):
    """Instantiate a despeckling model from its short name.

    Name -> class (all from ``despeckling.models``):
        ``'log_add'``      -> ``LogAddDespeckle``
        ``'log_subtract'`` -> ``LogSubtractDespeckle``
        ``'multiply'``     -> ``MultiplyDespeckle``
        ``'divide'``       -> ``DivideDespeckle``

    Args:
        model_str: one of the names above.
        num_layers: forwarded unchanged to the model constructor
            (presumably the number of convolutional layers — confirm in models).

    Raises:
        NotImplementedError: if ``model_str`` is not a known name.
    """
    if model_str == 'log_add':
        return models.LogAddDespeckle(num_layers)
    elif model_str == 'log_subtract':
        return models.LogSubtractDespeckle(num_layers)
    elif model_str == 'multiply':
        return models.MultiplyDespeckle(num_layers)
    elif model_str == 'divide':
        return models.DivideDespeckle(num_layers)
    else:
        # Separate name and text with a space (was "fooomodel does not exist.").
        raise NotImplementedError('{} model does not exist.'.format(model_str))

In [None]:
def get_criterion(criterion_str):
    """Return a PyTorch loss module selected by a short name.

    Args:
        criterion_str: ``'mse'`` for ``MSELoss`` or ``'l1'`` for ``L1Loss``.

    Returns:
        A new loss-module instance.

    Raises:
        NotImplementedError: if ``criterion_str`` is not a known name.
    """
    if criterion_str == 'mse':
        return MSELoss()
    if criterion_str == 'l1':
        return L1Loss()
    # Previously the function returned None (no return statement) and an
    # unknown name raised UnboundLocalError; fail explicitly instead.
    raise NotImplementedError('{} criterion does not exist.'.format(criterion_str))

### Define the dataset

The dataset returns a pair of images: an image contaminated with multiplicative noise and its corresponding clean image.

We do a 90/10 train/validation split.

In [None]:
# The dataset yields (noisy, clean) image pairs.
# NOTE(review): `args` (data_root, no_crop, batch_size) and `noise_args` are not
# defined in any visible cell -- presumably carried over from a script version
# (argparse + `noise_args = get_noise_args(...)`). Define them before running.
dataset = NoisyScansDataset(
    args.data_root,
    'F',
    noise_args,
    apply_random_crop=not args.no_crop,
)

# 90/10 split; any rounding remainder ends up in the validation subset.
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Both subsets get identical loader settings (shuffle left at DataLoader's default).
train_dataloader, val_dataloader = (
    DataLoader(subset, batch_size=args.batch_size, num_workers=4)
    for subset in (train_dataset, val_dataset)
)

### Define loss function

Our loss function is a distance between the output of the model and the clean image.  
L1 (Manhattan) and MSE (squared Euclidean) are basic distance measures.

In [None]:
# Mean-squared-error loss; get_criterion also accepts 'l1'.
criterion = get_criterion(criterion_str='mse')

### Define despeckling model

Our model consists of a series of convolutional layers followed by a skip connection to the input.

* We can transform our input image to log space and use an additive skip connection.
* Or we can use a multiplicative or divisive skip connection and work in the original space.

In [None]:
# 'divide' selects models.DivideDespeckle; 3 is forwarded as num_layers.
model = get_model(model_str='divide', num_layers=3)

In [None]:
# Move the network and the loss module onto the GPU when one is available.
if cuda:
    model, criterion = model.to(device), criterion.to(device)

### Define the Adam optimizer

In [None]:
# Adam over all trainable model parameters, learning rate 1e-3.
optimizer = Adam(model.parameters(), lr=1e-3)

### Training loop

In [None]:
# Single source of truth for the epoch count, used by both the loop and the
# progress message (the message previously read an undefined `args.epochs`).
num_epochs = 20

for epoch in range(num_epochs):
    # TRAINING.
    model.train()

    print('Epoch {} of {}'.format(epoch, num_epochs - 1))
    input_and_target = tqdm.tqdm(train_dataloader, total=len(train_dataloader))

    med_loss = 0  # running sum of per-batch training losses
    for x_batch, target_batch in input_and_target:
        x_batch, target_batch = x_batch.float().to(device), target_batch.float().to(device)

        optimizer.zero_grad()
        output_batch = model(x_batch)

        loss = criterion(output_batch, target_batch)
        loss.backward()
        optimizer.step()

        # .item() detaches and copies the scalar loss to the host in one call.
        med_loss += loss.item()
        input_and_target.set_description('Train loss = {0:.3f}'.format(loss.item()))

    print('Mean train loss = {0:.3f}'.format(med_loss / max(len(train_dataloader), 1)))

    # VALIDATION.
    print('Validation:')
    model.eval()
    input_and_target = tqdm.tqdm(val_dataloader, total=len(val_dataloader))

    med_loss_eval = 0   # running sum of losses of the model output
    prev_loss_eval = 0  # loss of the untouched noisy input (baseline to beat)
    # Evaluation needs no gradients; disabling autograd saves memory and time.
    with torch.no_grad():
        for x_batch, target_batch in input_and_target:
            x_batch, target_batch = x_batch.float().to(device), target_batch.float().to(device)
            output_batch = model(x_batch)
            loss = criterion(output_batch, target_batch)
            med_loss_eval += loss.item()
            prev_loss_eval = criterion(x_batch, target_batch).item()

            # compute_ssim returns a SUM; average over the real batch size,
            # since the final batch may be smaller than the configured size.
            batch_len = x_batch.size(0)
            ssim_input = compute_ssim(x_batch, target_batch)
            ssim_output = compute_ssim(output_batch, target_batch)

            input_and_target.set_description(
                'Output loss = {0:.3f}'.format(loss.item())
                + ' Input loss = {0:.3f}'.format(prev_loss_eval)
                + ' Input SSIM = {0:.3f}'.format(ssim_input / batch_len)
                + ' Output SSIM = {0:.3f}'.format(ssim_output / batch_len))

    print('Mean validation loss = {0:.3f}'.format(med_loss_eval / max(len(val_dataloader), 1)))