# Training a model from auvlib data

Let's start by importing some useful functions. In this tutorial we are going to use [pytorch](https://pytorch.org/) for training a neural network. Let's import the necessary modules...

In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
from data.aligned_dataset import AlignedDataset
from util import visualizer as vs
from util import net_util
import os
import functools
from util import html
from IPython.display import IFrame

## Defining the neural network

We are going to train a convolutional network with residual blocks, a so-called [*resnet*](https://arxiv.org/abs/1512.03385), to map sidescan patches to corresponding sea floor depths. A resnet is a popular form of neural network that uses skip connections to facilitate training; see the sketch below. To begin with, let us define the network structure. ![resnet](ResNets.svg)

In [2]:
class ResnetBlock(nn.Module):
    """Residual block: two padded 3x3 convolutions plus an identity shortcut.

    Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
    """

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Create the block.

        All arguments are handed unchanged to build_conv_block, which
        assembles the convolutional branch; the shortcut itself is added
        in forward.
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    @staticmethod
    def _padding(padding_type):
        """Translate a padding name into (explicit pad layers, Conv2d padding).

        'reflect' and 'replicate' pad with a dedicated layer and leave the
        convolution unpadded; 'zero' uses the convolution's own zero padding.
        Any other name raises NotImplementedError.
        """
        if padding_type == 'zero':
            return [], 1
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Assemble the convolutional branch of the block.

        Parameters:
            dim (int)           -- channel count, identical for input and output of both convs
            padding_type (str)  -- reflect | replicate | zero
            norm_layer          -- callable returning a normalization layer for `dim` channels
            use_dropout (bool)  -- insert Dropout(0.5) after the first ReLU
            use_bias (bool)     -- whether the conv layers carry a bias term

        Returns an nn.Sequential laid out as
        pad -> conv -> norm -> ReLU [-> dropout] -> pad -> conv -> norm
        (the pad layers are absent for 'zero' padding).
        """
        layers = []
        for first_stage in (True, False):
            pad_layers, conv_padding = self._padding(padding_type)
            layers.extend(pad_layers)
            layers.append(nn.Conv2d(dim, dim, kernel_size=3, padding=conv_padding, bias=use_bias))
            layers.append(norm_layer(dim))
            if first_stage:
                # Only the first stage is followed by a non-linearity; the
                # second stage's output is added to the shortcut directly.
                layers.append(nn.ReLU(True))
                if use_dropout:
                    layers.append(nn.Dropout(0.5))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the input plus the output of the convolutional branch."""
        return x + self.conv_block(x)

class ResnetGenerator(nn.Module):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.

    We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Construct a Resnet-based generator

        Parameters:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images
            ngf (int)           -- base filter count: output channels of the first conv layer
                                   and input channels of the last one
            norm_layer          -- normalization layer class, or a functools.partial wrapping one
            use_dropout (bool)  -- if use dropout layers
            n_blocks (int)      -- the number of ResNet blocks (must be >= 0)
            padding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero

        The final layer is a Tanh, so outputs lie in [-1, 1].
        """
        assert n_blocks >= 0, 'n_blocks must be non-negative'
        super(ResnetGenerator, self).__init__()

        # Convolutions get a bias only with InstanceNorm2d. isinstance (rather
        # than an exact type comparison) also unwraps subclasses of partial.
        if isinstance(norm_layer, functools.partial):
            norm_class = norm_layer.func
        else:
            norm_class = norm_layer
        use_bias = norm_class == nn.InstanceNorm2d

        # Stem: 7x7 convolution at full resolution.
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            # Each stride-2 convolution doubles the channel count.
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):       # add ResNet blocks
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            # mult is a power of two >= 2 here, so the integer division is exact.
            out_channels = ngf * mult // 2
            model += [nn.ConvTranspose2d(ngf * mult, out_channels,
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(out_channels),
                      nn.ReLU(True)]

        # Head: 7x7 convolution back to output_nc channels, squashed by Tanh.
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
    
def save_network(save_dir, epoch, network=None):
    """Save a network's state dict to disk.

    Parameters:
        save_dir (str)  -- directory to write into; created if it does not exist
        epoch           -- epoch number or label (e.g. 'latest'); the file is named
                           '%s_resnet.pth' % epoch
        network         -- the module to save; when omitted, the module-level
                           `net` is saved
    """
    if network is None:
        # Backward-compatible default: fall back to the global model.
        network = net

    # torch.save does not create missing directories.
    if save_dir and not os.path.isdir(save_dir):
        os.makedirs(save_dir)

    save_filename = '%s_resnet.pth' % (epoch,)
    save_path = os.path.join(save_dir, save_filename)

    torch.save(network.state_dict(), save_path)
    
def load_network(save_dir, epoch, device):
    """Rebuild the generator and restore its weights from a checkpoint.

    Parameters:
        save_dir (str) -- directory that holds the checkpoint files
        epoch          -- epoch number or label; the file read is '%s_resnet.pth' % epoch
        device         -- torch device (or its name, e.g. "cpu") to place the model on

    Returns the ResnetGenerator with the stored state dict loaded.
    eval() is not called here.
    """
    checkpoint_path = os.path.join(save_dir, '%s_resnet.pth' % (epoch,))

    # The architecture arguments must match those the checkpoint was saved with.
    batch_norm = net_util.get_norm_layer(norm_type='batch')
    restored = ResnetGenerator(1, 1, 64, norm_layer=batch_norm, use_dropout=True, n_blocks=6)
    restored = restored.to(device)

    weights = torch.load(checkpoint_path, map_location=device)
    restored.load_state_dict(weights)
    return restored

## Setting up the training

The network will get loaded on the GPU if there is one available. Note that training requires a GPU, while testing could be done on a CPU.

In [3]:
# Use the first CUDA device when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Normalization layer factory handed to every stage of the generator.
norm_layer = net_util.get_norm_layer(norm_type='batch')

# 1 input channel -> 1 output channel, 64 base filters, 6 residual blocks, dropout on.
net = ResnetGenerator(1, 1, 64, norm_layer=norm_layer, use_dropout=True, n_blocks=6)
net = net.to(device)
net_util.init_weights(net, init_type='normal', init_gain=0.02)

initialize network with normal


We are going to use an L1 loss to force the predicted depths to be close to the ground truth depths of the training examples. We will use the Adam optimizer because we found that it works well while being fast. However, one could also use the standard stochastic gradient descent method.

In [4]:
# Mean absolute error between predicted and ground-truth depth images.
criterion = nn.L1Loss()

# Alternative: plain stochastic gradient descent with momentum.
#optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer = optim.Adam(
    net.parameters(),
    lr=1e-4,
    betas=(0.5, 0.999),
)

Now we are going to feed in the images that we stored in the previous part of the tutorial. For the training part, we will use the images from the `train` subdirectory.

In [5]:
# Training configuration
dataroot = 'sss2depth'  # dataset folder name
batch_size = 1          # examples fed to the network per step
epochs = 100            # passes over the training set

# Start from the default training options (e.g. display config) and
# override the fields that matter for this tutorial.
opt = net_util.get_default_train_options()
opt.batch_size = batch_size
opt.dataroot = dataroot
opt.phase = 'train'

# Iterator that feeds batches to the network; examples are shuffled
# unless opt.serial_batches is set.
loader_kwargs = dict(
    batch_size=opt.batch_size,
    shuffle=not opt.serial_batches,
    num_workers=int(opt.num_threads),
)
dataset = torch.utils.data.DataLoader(AlignedDataset(opt), **loader_kwargs)

## Training the network

Now we are ready to start training our network. Every 5 epochs we will save the current model in `checkpoints/XX_resnet.pth`. You can inspect the process of the training in [visdom](http://localhost:8097/). The website should look something like this: ![visdom](visdom.png).

In [None]:
# create a visualizer that display/save images and plots
visualizer = vs.Visualizer(opt)

# show a colorbar between the min and max depths as defined before
visualizer.display_colorbar_jet(-19., -11.)

dataset_size = len(dataset) # the number of batches per epoch
display_freq = 100 # how often to update web page images
update_html_freq = 1000 # how often to save html on disk
print_freq = 20 # how often (in iterations) to print/plot the running loss
save_epoch_freq = 5 # how often (in epochs) to write a checkpoint

total_iters = 0
# Accumulated over the current print window. It is NOT reset per epoch: the
# window is driven by total_iters and may span an epoch boundary, so resetting
# here would make the first report of an epoch divide a partial sum by print_freq.
running_loss = 0.0
for epoch in range(epochs):  # loop over the dataset multiple times

    for i, data in enumerate(dataset):
        visualizer.reset()

        # get the inputs; data is a dict with the input under 'A' and the target under 'B'
        inputs = data['A'].to(device)
        target = data['B'].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # Pixels whose target is <= -0.99 are excluded from the loss
        # (presumably they mark missing depth -- TODO confirm against the data export).
        mask = (target > -.99).float()
        # forward + backward + optimize
        outputs = net(inputs)

        # NOTE: the L1 mean is taken over all pixels; masked ones contribute zero.
        loss = criterion(mask*outputs, mask*target)
        loss.backward()
        optimizer.step()

        if i % display_freq == 0: # display images on visdom and save images to a HTML file
            save_result = i % update_html_freq == 0
            # mask*(outputs+1.)-1. forces masked pixels of the prediction to -1
            visual_ret = {'Input': inputs, 'Output_Jet': mask*(outputs+1.)-1., 'Real_Jet': target}
            visualizer.display_current_results(visual_ret, epoch, save_result)

        # print statistics
        running_loss += loss.item()
        total_iters += 1
        if total_iters % print_freq == 0: # print every print_freq mini-batches
            mean_loss = running_loss / float(print_freq)
            epoch_fraction = float(i) / dataset_size
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, mean_loss))
            print('Epoch %d / %d, example %d / %d' % (epoch + 1, epochs, i+1, dataset_size))
            losses = {'L1_loss': mean_loss}
            print(epoch_fraction)
            visualizer.plot_current_losses(epoch, epoch_fraction, losses)
            running_loss = 0.0

    if epoch % save_epoch_freq == 0: # cache our model every save_epoch_freq epochs
        print('saving the model at the end of epoch %d, iters %d' % (epoch, i))
        save_network('checkpoints', 'latest')
        save_network('checkpoints', epoch)

print('Finished Training')

## Evaluating the network

We can now load any of the saved epochs and test the trained model on our test dataset. The results will be stored in `results/regress/test_XX`.

In [None]:
# Switch the shared options object over to the test split.
opt.phase = "test"
opt.isTrain = False
opt.dataroot = "sss2depth"
epoch = "latest"  # which checkpoint to evaluate: 'latest' or a saved epoch number

dataset = torch.utils.data.DataLoader(
                AlignedDataset(opt),
                batch_size=opt.batch_size,
                shuffle=not opt.serial_batches,
                num_workers=int(opt.num_threads))

# Rebuild the generator and restore the chosen checkpoint.
# NOTE(review): the model is not switched to eval(), so dropout and batch-norm
# behave as during training -- confirm this is intended.
net = load_network('checkpoints', epoch, device)

# create a website
web_dir = os.path.join("results", opt.name, '%s_%s' % ("test", epoch))  # define the website directory
webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, epoch))

num_test = 25  # maximum number of test batches to process
for i, data in enumerate(dataset):
    if i >= num_test:  # only apply our model to the first num_test batches
        break
    inputs = data['A'].to(device)
    target = data['B'].to(device)

    # Same validity mask as in training: target pixels <= -0.99 are ignored.
    mask = (target > -.99).float()
    # Inference only: skip building the autograd graph to save memory.
    with torch.no_grad():
        outputs = net(inputs)

    # mask*(outputs+1.)-1. forces masked pixels of the prediction to -1
    visuals = {'Input': inputs, 'Output_Jet': mask*(outputs+1.)-1., 'Real_Jet': target}
    img_path = data['A_paths'] # get image paths
    if i % 5 == 0:  # report progress every 5 batches
        print('processing (%04d)-th image... %s' % (i, img_path))
    vs.save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
webpage.save()  # save the HTML

# show the saved web page with results
IFrame(src=os.path.join(web_dir, "index.html"), width=850, height=600)