# U-Net Training

Notebook to train the U-Net model with deepgeo.

## 1. Importing libraries and changing the working directory

References for the libraries:

+ [numpy](https://numpy.org/)
+ [gdal](https://gdal.org/api/python.html)
+ [os](https://docs.python.org/3/library/os.html)
+ [datetime](https://docs.python.org/3/library/datetime.html)
+ [time](https://docs.python.org/3/library/time.html)
+ [deepgeo](https://github.com/rvmaretto/deepgeo)

In [None]:
import numpy as np
from osgeo import gdal  # GDAL Python bindings; the top-level `import gdal` is deprecated
import os
from datetime import datetime
import time
import deepgeo.networks.model_builder as mb

In [None]:
# Move to the Data folder, two levels above the current directory, where all data is stored
os.chdir(os.getcwd().rsplit('/',2)[0]+'/Data')
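
`os.getcwd().rsplit('/', 2)[0]` drops the last two components of the current path, so the notebook is expected to sit two folders below the project root that contains `Data`. A minimal sketch of the same computation as a pure function (the name `data_dir_from` and the example path are hypothetical), using `pathlib` instead of splitting on a hard-coded `/`:

```python
from pathlib import PurePosixPath


def data_dir_from(cwd: str) -> str:
    """Return the 'Data' folder inside the grandparent of cwd (hypothetical helper)."""
    # parents[1] is the grandparent directory, the same prefix as cwd.rsplit('/', 2)[0]
    return str(PurePosixPath(cwd).parents[1] / 'Data')


cwd = '/home/user/project/notebooks/unet'   # example path, not from the original
print(data_dir_from(cwd))
print(cwd.rsplit('/', 2)[0] + '/Data')      # the expression used in the notebook
```

`PurePosixPath` is used only so the example prints the same result on any OS. In the notebook itself, `os.chdir(Path.cwd().parents[1] / 'Data')` (with `from pathlib import Path`) would be an OS-independent equivalent.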

## 2. Creating the model folder and defining the model's ID

In [None]:
models_folder = './models/UNET/'
# Create the folder if needed; exist_ok=True avoids an error when it already exists
os.makedirs(models_folder, exist_ok=True)

In [None]:
now = datetime.now()

identifier = '{0}-{1}-{2}_{3}-{4}-{5}'.format(now.year, 
                                              str(now.month).zfill(2), 
                                              str(now.day).zfill(2), 
                                              str(now.hour).zfill(2), 
                                              str(now.minute).zfill(2), 
                                              str(now.second).zfill(2))

print('Identifier:', identifier)
model_dir = models_folder+identifier
os.mkdir(model_dir)
print('Model Folder:', model_dir)
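
The `format` call zero-pads each field to build an identifier of the form `YYYY-MM-DD_HH-MM-SS`. The same string can be produced with a single `strftime` call; a sketch with a fixed timestamp so the result is reproducible (`make_identifier` is a hypothetical name):

```python
from datetime import datetime


def make_identifier(moment: datetime) -> str:
    """Format a timestamp as YYYY-MM-DD_HH-MM-SS, the layout used for `identifier`."""
    return moment.strftime('%Y-%m-%d_%H-%M-%S')


print(make_identifier(datetime(2021, 3, 5, 7, 8, 9)))  # 2021-03-05_07-08-09
```

Because the identifier has one-second resolution, two runs started within the same second map to the same folder, and `os.mkdir(model_dir)` raises `FileExistsError` for the second one.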

## 3. Defining parameters

In [None]:
# Approach for the training samples.
sit = 'appr1'
# State
state = 'BA'
# Satellite platform (imagery source)
platform = 'Sentinel'

train_tfrecord = f'./training_samples/UNET/{sit}.{state}.{platform}/{sit}.{state}.{platform}_train.tfrecord'
test_tfrecord  = f'./training_samples/UNET/{sit}.{state}.{platform}/{sit}.{state}.{platform}_test.tfrecord'
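
Both paths follow the pattern `./training_samples/UNET/<sit>.<state>.<platform>/<sit>.<state>.<platform>_<split>.tfrecord`. When several approaches, states, or platforms are compared, a small helper can keep the naming consistent; `tfrecord_path` below is hypothetical and not part of deepgeo:

```python
def tfrecord_path(sit: str, state: str, platform: str, split: str,
                  root: str = './training_samples/UNET') -> str:
    """Build the TFRecord path for one split, mirroring the f-strings above."""
    if split not in ('train', 'test'):
        raise ValueError(f"split must be 'train' or 'test', got {split!r}")
    prefix = f'{sit}.{state}.{platform}'
    return f'{root}/{prefix}/{prefix}_{split}.tfrecord'


print(tfrecord_path('appr1', 'BA', 'Sentinel', 'train'))
print(tfrecord_path('appr1', 'BA', 'Sentinel', 'test'))
```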

In [None]:
params = {
    'network': 'unet',
    'epochs': 100,
    'batch_size': 5,
    'chip_size': 284,
    'bands': 2,
    'learning_rate': 0.01,
    'learning_rate_decay': True,
    'decay_rate': 0.92,
    'l2_reg_rate': 0.0005,
    'chips_tensorboard': 2,
    'loss_func': 'avg_soft_dice',
    'data_aug_ops': ['rot90', 'rot180', 'rot270', 'flip_left_right',
                     'flip_up_down', 'flip_transpose'],
    'data_aug_per_chip': 6,
    'num_classes': 3,
    'class_names': ['no_data', 'Deforestation', 'Not Deforestation'],
    'num_compositions': 1,
    'bands_plot': [0, 0, 1],
    'Notes': train_tfrecord
}
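
`data_aug_ops` lists six geometric transforms, and `data_aug_per_chip` is also 6, which suggests each chip may receive all six variants; that is an inference, not something this notebook confirms. How deepgeo applies the operations is not shown here, so the NumPy sketch below is an assumption about what each name means for a chip of shape (height, width, bands). It illustrates the key requirement: the image and its label mask must receive exactly the same transform. `AUG_OPS` and `augment` are hypothetical names.

```python
import numpy as np

# Assumed meaning of each name in params['data_aug_ops'] (not taken from deepgeo).
# Arrays are (height, width[, bands]); axes 0 and 1 are the spatial axes.
AUG_OPS = {
    'rot90': lambda a: np.rot90(a, k=1, axes=(0, 1)),
    'rot180': lambda a: np.rot90(a, k=2, axes=(0, 1)),
    'rot270': lambda a: np.rot90(a, k=3, axes=(0, 1)),
    'flip_left_right': lambda a: a[:, ::-1],
    'flip_up_down': lambda a: a[::-1],
    'flip_transpose': lambda a: np.swapaxes(a, 0, 1),  # spatial transpose
}


def augment(chip, label, ops):
    """Return one (chip, label) pair per op, transforming both identically."""
    return [(AUG_OPS[name](chip), AUG_OPS[name](label)) for name in ops]


chip = np.arange(8).reshape(2, 2, 2)   # tiny 2x2 chip with 2 bands
label = np.array([[0, 1], [2, 1]])     # matching 2x2 class mask
for name, (c, l) in zip(['rot90', 'flip_transpose'],
                        augment(chip, label, ['rot90', 'flip_transpose'])):
    print(name, l.tolist())
```

Rotations by 90 and 270 degrees and the transpose swap height and width, so they preserve the array shape only for square chips; this holds here because `chip_size` defines square 284 x 284 chips.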

## 4. Training the model

In [None]:
t1 = time.time()

In [None]:
model = mb.ModelBuilder(params)
model.train(train_tfrecord, test_tfrecord, model_dir)
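
`params['loss_func']` is set to `'avg_soft_dice'`. deepgeo's exact formula is not reproduced here; as an assumption, a common definition of an averaged soft Dice loss computes, for each class c, 1 - (2·Σ p·y + ε) / (Σ p + Σ y + ε) over softmax probabilities p and one-hot labels y, then averages over the classes. A NumPy sketch of that generic formulation (`avg_soft_dice_loss` is a hypothetical name):

```python
import numpy as np


def avg_soft_dice_loss(probs, labels, num_classes, eps=1e-7):
    """Mean over classes of (1 - soft Dice); a generic formulation, not deepgeo's code.

    probs:  float array (N, H, W, C) of softmax outputs
    labels: int array (N, H, W) of class indices
    """
    one_hot = np.eye(num_classes)[labels]            # (N, H, W, C)
    axes = (0, 1, 2)                                  # sum over batch and pixels
    intersection = np.sum(probs * one_hot, axis=axes)
    denom = np.sum(probs, axis=axes) + np.sum(one_hot, axis=axes)
    dice = (2.0 * intersection + eps) / (denom + eps)  # per-class soft Dice
    return float(np.mean(1.0 - dice))


labels = np.array([[[0, 1], [2, 1]]])    # one 2x2 label mask, classes 0..2
perfect = np.eye(3)[labels]               # probabilities equal to the one-hot labels
uniform = np.full((1, 2, 2, 3), 1.0 / 3)  # maximally uncertain prediction
print(avg_soft_dice_loss(perfect, labels, 3))
print(avg_soft_dice_loss(uniform, labels, 3))
```

Averaging per-class scores, rather than pooling all pixels, keeps a small class such as `Deforestation` from being dominated by larger ones; ε avoids division by zero when a class is absent from both the prediction and the label.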

In [None]:
t2 = time.time()
print('Elapsed time: %.3f min.' % ((t2-t1)/60))