# Train [U-Net](https://arxiv.org/abs/1505.04597)

The [`nn`](../nn) module defines everything related to neural networks.
These networks assume by default that their input comes from a [tf.data.Dataset](https://www.tensorflow.org/api_docs/python/tf/data/Dataset); such a reader can be created with the [`reader`](../reader) module.

Training a network can be split into the following steps, each covered in its own section below:

In [1]:
# Path management for Jupyter: make the parent of the current working directory
# (the notebook's folder when launched from Jupyter) importable so the project's
# packages (nn, collection, preprocess, reader, ...) can be found.
# This is not necessary if you're running in a terminal or another IDE that already
# puts the project root on sys.path.
import os
import sys
module_path = os.path.abspath('..')  # resolved relative to the current working directory
if module_path not in sys.path:  # guard so re-running this cell doesn't add duplicates
    sys.path.append(module_path)

## Define the Network

In [2]:
# Imports and hyper-parameters for building and training the U-Net.
import tensorflow as tf
from nn import unet

# Network / naming settings
class_num = 2              # number of classes in the ground truth
patch_size = (572, 572)    # size of each input patch
suffix = 'test'            # user-chosen name suffix for the network

# Optimization schedule
lr = 1e-4                  # initial learning rate
ds = 60                    # number of epochs before the learning rate decays
dr = 0.1                   # decay factor: the learning rate becomes lr * dr
epochs = 1                 # number of epochs to train
bs = 5                     # batch size

  from ._conv import register_converters as _register_converters


In [3]:
# Build a fresh graph and instantiate the U-Net.
# The network class is accessed through a module alias because the assignment below
# rebinds the name `unet` to the model instance; calling `unet.UNet` directly would
# raise an AttributeError if this cell were re-run.
from nn import unet as unet_module
tf.reset_default_graph()  # drop any graph left over from a previous run of this cell
unet = unet_module.UNet(class_num, patch_size, suffix=suffix, learn_rate=lr, decay_step=ds, decay_rate=dr,
                        epochs=epochs, batch_size=bs)
# Overlap reported by the network; reused below as the patch-extraction overlap.
overlap = unet.get_overlap()

## Make Collection

In [4]:
# Imports and settings for assembling the data collection.
from collection import collectionMaker, collectionEditor
import numpy as np

ds_name = 'Inria'  # dataset name: used as the collection name and as a prefix for patch-extraction names

In [5]:
# Build the collection of Inria tiles (force_run=False presumably reuses a previously
# built collection), then add a ground-truth channel multiplied by 1/255 (presumably
# mapping 0/255 masks to 0/1 labels) and register it in place of the raw GT channel.
# The data location can be overridden with the INRIA_DATA_DIR environment variable
# instead of editing this machine-specific path.
data_dir = os.environ.get('INRIA_DATA_DIR', r'/media/ei-edl01/data/uab_datasets/inria/data/Original_Tiles')
cm = collectionMaker.read_collection(raw_data_path=data_dir,
                                     field_name='austin,chicago,kitsap,tyrol-w,vienna',  # use all cities
                                     field_id=','.join(str(i) for i in range(37)),  # use all tile ids 0-36
                                     rgb_ext='RGB',
                                     gt_ext='GT',
                                     file_ext='tif',
                                     force_run=False,
                                     clc_name=ds_name)
gt_d255 = collectionEditor.SingleChanMult(cm.clc_dir, 1/255, ['GT', 'gt_d255']).\
    run(force_run=False, file_ext='png', d_type=np.uint8)
cm.replace_channel(gt_d255.files, True, ['GT', 'gt_d255'])
# Mean of the first three channels; passed to the data reader and the image hook below.
chan_mean = cm.meta_data['chan_mean'][:3]

gt_d255 might already exist, skip replacement


In [6]:
# Show the collection's metadata (data path, fields, tile size, channel means, file patterns).
cm.print_meta_data()

raw_data_path: /media/ei-edl01/data/uab_datasets/inria/data/Original_Tiles
field_name: ['austin', 'chicago', 'kitsap', 'tyrol-w', 'vienna']
field_id: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36']
clc_name: Inria
tile_dim: (5000, 5000)
chan_mean: [103.2341876  108.95194825 100.14192002]
Source file: *RGB*.tif
GT file: *gt_d255*.png


## Get Training and Validation Files

In [7]:
# Split tiles by id. `range` excludes its end point, so training uses ids 6-36
# and validation uses ids 0-5.
# NOTE(review): confirm whether tile id 0 exists in the dataset; if tiles are numbered
# from 1, validation effectively uses ids 1-5.
train_ids = range(6, 37)
valid_ids = range(6)
assert not set(train_ids) & set(valid_ids), 'training and validation tiles overlap'
file_list_train = cm.load_files(field_id=','.join(str(i) for i in train_ids), field_ext='RGB,gt_d255')
file_list_valid = cm.load_files(field_id=','.join(str(i) for i in valid_ids), field_ext='RGB,gt_d255')

## Patch Extraction

In [8]:
# Settings for cutting the source tiles into network-sized patches.
from preprocess import patchExtractor

tile_size = (5000, 5000)  # size of each source tile; matches tile_dim in the collection metadata

In [9]:
def _extract_patches(split_name, file_list):
    """Cut every tile in ``file_list`` into patches and return the patch file list.

    The extraction is named ``ds_name + split_name``; with ``force_run=False``
    existing patches are presumably reused. ``overlap`` and ``overlap // 2`` are
    passed positionally to ``PatchExtractor`` (presumably the patch overlap and the
    tile padding — TODO confirm), and ``file_exts`` presumably gives the output
    formats for the image and ground-truth patches.
    """
    extractor = patchExtractor.PatchExtractor(patch_size, tile_size, ds_name + split_name, overlap, overlap // 2)
    return extractor.run(file_list=file_list, file_exts=['jpg', 'png'], force_run=False).get_filelist()


patch_list_train = _extract_patches('_train', file_list_train)
patch_list_valid = _extract_patches('_valid', file_list_valid)

## Create Data Reader

In [10]:
# Data-reader imports and settings.
from reader import dataReaderSegmentation, reader_utils

valid_mult = 5  # batch-size multiplier for validation (validation can afford larger batches)

In [11]:
# One reader serves both splits. Training patches are augmented with flips and rotations
# (random=True presumably also shuffles them). train_init_op / valid_init_op presumably
# switch the reader between the two splits; `feature` and `label` are the tensors it yields.
train_valid_reader = dataReaderSegmentation.DataReaderSegmentationTrainValid(
    patch_size, patch_list_train, patch_list_valid,
    batch_size=bs,
    chan_mean=chan_mean,
    aug_func=[reader_utils.image_flipping, reader_utils.image_rotating],  # augmentation for training data
    random=True,
    has_gt=True,
    gt_dim=1,
    include_gt=True,
    valid_mult=valid_mult,
)
train_init_op, valid_init_op, reader_op = train_valid_reader.read_op()
feature, label = reader_op

## Setup Training

[`hook`](../nn/hook.py) is used here to monitor the training/validation process.

In [12]:
# Imports and settings for the training loop.
import ersaPath
from nn import hook, nn_utils
sfn = 32                   # start filter number (filters in the first layer) of the U-Net
n_train = 1000             # number of training samples per epoch
n_valid_samples = 1000     # number of samples evaluated in each validation pass
# Number of validation steps; each step presumably sees bs * valid_mult samples.
n_valid = n_valid_samples // bs // valid_mult
gpu = 1                    # GPU id passed to nn_utils.set_gpu
verb_step = 200            # print verbose messages every verb_step steps
nn_utils.set_gpu(gpu)

In [13]:
# Build the U-Net on the reader's feature tensor, compile it (loss_type='xent', presumably
# cross-entropy), then create the monitoring hooks. Keep this order: the hooks reference
# attributes (unet.loss, unet.lr_op, unet.loss_iou, unet.pred, ...) that presumably only
# exist after create_graph/compile.
unet.create_graph(feature, sfn)
unet.compile(feature, label, n_train, n_valid, patch_size, ersaPath.PATH['model'], par_dir='test', loss_type='xent')
train_hook = hook.ValueSummaryHook(verb_step, [unet.loss, unet.lr_op], value_names=['train_loss', 'learning_rate'],
                                   print_val=[0]) # every verb_step steps: write loss and lr, print loss (value index 0)
# print&write validation loss every epoch (run_time presumably = number of validation steps)
valid_loss_hook = hook.ValueSummaryHook(unet.get_epoch_step(), [unet.loss],
                                        value_names=['valid_loss'], log_time=True, run_time=unet.n_valid)
# print&write validation IoU every epoch
valid_iou_hook = hook.IoUSummaryHook(unet.get_epoch_step(), unet.loss_iou, log_time=True, run_time=unet.n_valid,
                                     cust_str='\t')
# write validation images (input, label, prediction) every epoch; img_mean presumably undoes the mean subtraction
image_hook = hook.ImageValidSummaryHook(unet.get_epoch_step(), unet.valid_images, feature, label, unet.pred,
                                        nn_utils.image_summary, img_mean=chan_mean)

  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


In [14]:
# Run training. The hooks handle console logging and summaries; train_init_op /
# valid_init_op presumably switch the reader between the training and validation data.
validation_hooks = [valid_loss_hook, valid_iou_hook, image_hook]
unet.train(
    train_hooks=[train_hook],
    valid_hooks=validation_hooks,
    train_init=train_init_op,
    valid_init=valid_init_op,
)

Step 200	train_loss 0.464
Eval @ Epoch 0 Step 200	valid_loss 0.411, Duration: 107.648
	Step 200	IoU 0.205, Duration: 138.107


The hooks also write their data to [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard).