## Retrain Domain Adaptation Network

#### Setup notes: 
    1) Data file paths must be configured locally (see the 'Model Name and Paths' cell below)
    2) CARLA train directory's expected structure: 
        - images in 'CameraRGB' directory 
        - labels in 'CameraSeg' directory
    3) BDD test directory's expected structure:
        - images in 'images' directory 
        - labels in 'labels' directory

In [None]:
import tensorflow as tf
import helper_functions_bdd as test_hf
import helper_functions_carla as train_hf
import numpy as np
import warnings
import os
import time
import pickle
from sklearn.metrics import precision_recall_fscore_support
from sklearn.utils import shuffle
from functools import reduce
import random

### Model Name and Paths
#### Note: parameter values below must be re-assigned based on local directories

In [None]:
# Checkpoint naming: weights are restored from one version and saved under another.
MODEL_NAME = 'network_v7_da'
MODEL_RESTORE_VER = '00'
MODEL_SAVE_VER = '00_r0'

# Dataset locations: CARLA data for training, BDD data for testing.
DATA_DIR = os.path.join(os.getcwd(), 'datasets', 'da')
TRAIN_DIR, TEST_DIR = (os.path.join(DATA_DIR, sub_dir) for sub_dir in ('train_carla', 'test_bdd'))

# Checkpoint locations; the restore path points at the 'score' checkpoint of the source version.
SAVE_DIR = os.path.join(os.getcwd(), 'saved_models', MODEL_NAME, MODEL_SAVE_VER)
RESTORE_DIR = os.path.join(os.getcwd(), 'saved_models', MODEL_NAME, MODEL_RESTORE_VER, 'score')

In [None]:
# Best metrics seen so far. A checkpoint is written only when an epoch beats one
# of them by more than SAVE_EPSILON (fscore: higher is better; loss: lower is
# better). The defaults below let the first epoch save; when resuming, replace
# them with the previous run's best values.
best_fscore = 0.0           # TODO: set to the previous run's best fscore_avg
best_loss = float('inf')    # TODO: set to the previous run's best loss

EPOCHS = 200
SHUFFLE_PER_EPOCH = True
BATCH_SIZE = 12
L2_REG = 1e-5
STD_DEV = 1e-2
LEARNING_RATE = 1e-5
MOMENTUM = 0.9
KEEP_PROB = 0.5
EPSILON = 1e-6
ADAM_EPSILON = 1e-5
SAVE_EPSILON = 1e-4         # minimum improvement required before saving a checkpoint

# Train (CARLA) preprocessing: keep rows TRAIN_TRIM_IND[0]:TRAIN_TRIM_IND[1].
TRAIN_TRIM = True
TRAIN_TRIM_IND = (115, 523)
TRAIN_RESHAPE = False
TRAIN_NEW_SHAPE = (800, 408)  # (width, height) after trim (without reshape); 523 - 115 = 408
TRAIN_LABEL_CHANNELS = [10, 7, 2, 4, 5, 8, 9, 20, 30]

# Test (BDD) preprocessing.
TEST_TRIM = False
TEST_TRIM_IND = (0, 720)     # original height
TEST_RESHAPE = True
TEST_NEW_SHAPE = (640, 360)
TEST_LABEL_CHANNELS = [13, 0, 4, 11, 5, 1, 8, 20, 30]

FLIP = True
PREPROCESS = True
NEW_LABELS = True
CHANNEL_NAMES = ['Back', 'Vehi', 'Road', 'Fence', 'Ped', 'Poles', 'Side', 'Veg', 'BW', 'OT']
LOSS_WEIGHTS = [0.6, 1.2, 0.6, 1.2, 1.2, 1.2, 0.7, 0.7, 0.7, 0.7]
# One output class per mapped label channel plus one extra class
# (presumably the catch-all 'Back' entry at CHANNEL_NAMES[0]).
NUM_CLASSES = len(TRAIN_LABEL_CHANNELS) + 1

# Fail fast on inconsistent edits: these lists are indexed per class elsewhere.
if len(TEST_LABEL_CHANNELS) != len(TRAIN_LABEL_CHANNELS):
    raise ValueError('TRAIN_LABEL_CHANNELS and TEST_LABEL_CHANNELS must map the same number of classes')
if len(CHANNEL_NAMES) != NUM_CLASSES or len(LOSS_WEIGHTS) != NUM_CLASSES:
    raise ValueError(f'CHANNEL_NAMES and LOSS_WEIGHTS must each have NUM_CLASSES={NUM_CLASSES} entries')

In [None]:
# Log the complete run configuration, including the epoch count and the save
# thresholds, so each run's output records exactly what produced it.
print(f'MODEL_NAME: {MODEL_NAME}')
print(f'MODEL_RESTORE_VER: {MODEL_RESTORE_VER}')
print(f'MODEL_SAVE_VER: {MODEL_SAVE_VER}')
print(f'TRAIN_DIR: {TRAIN_DIR}')
print(f'TEST_DIR: {TEST_DIR}')
print(f'RESTORE_DIR: {RESTORE_DIR}')
print(f'SAVE_DIR: {SAVE_DIR}')

print(f'EPOCHS: {EPOCHS}')
print(f'SHUFFLE_PER_EPOCH: {SHUFFLE_PER_EPOCH}')
print(f'BATCH_SIZE: {BATCH_SIZE}')
print(f'L2_REG: {L2_REG}')
print(f'STD_DEV: {STD_DEV}')
print(f'LEARNING_RATE: {LEARNING_RATE}')
print(f'MOMENTUM: {MOMENTUM}')
print(f'KEEP_PROB: {KEEP_PROB}')
print(f'EPSILON: {EPSILON}')
print(f'ADAM_EPSILON: {ADAM_EPSILON}')
print(f'SAVE_EPSILON: {SAVE_EPSILON}')
print(f'best_fscore (initial): {best_fscore}')
print(f'best_loss (initial): {best_loss}')

print(f'TRAIN_TRIM: {TRAIN_TRIM}')
print(f'TRAIN_TRIM_IND: {TRAIN_TRIM_IND}')
print(f'TRAIN_RESHAPE: {TRAIN_RESHAPE}')
print(f'TRAIN_NEW_SHAPE: {TRAIN_NEW_SHAPE}')
print(f'TRAIN_LABEL_CHANNELS: {TRAIN_LABEL_CHANNELS}')

print(f'TEST_TRIM: {TEST_TRIM}')
print(f'TEST_TRIM_IND: {TEST_TRIM_IND}')
print(f'TEST_RESHAPE: {TEST_RESHAPE}')
print(f'TEST_NEW_SHAPE: {TEST_NEW_SHAPE}')
print(f'TEST_LABEL_CHANNELS: {TEST_LABEL_CHANNELS}')

print(f'FLIP: {FLIP}')
print(f'PREPROCESS: {PREPROCESS}')
print(f'NEW_LABELS: {NEW_LABELS}')
print(f'CHANNEL_NAMES: {CHANNEL_NAMES}')
print(f'LOSS_WEIGHTS: {LOSS_WEIGHTS}')
print(f'NUM_CLASSES: {NUM_CLASSES}')

# Build the batch generators for both domains. Both helpers receive the same
# option set; the test helper additionally returns `revert_trim_reshape`, which
# is applied to predictions before they are compared with the test labels
# (presumably undoing TEST_TRIM/TEST_RESHAPE - confirm in helper_functions_bdd).
get_train_batch = train_hf.train_batch_gen(
    TRAIN_DIR,
    TRAIN_LABEL_CHANNELS,
    trim=TRAIN_TRIM,
    trim_ind=TRAIN_TRIM_IND,
    reshape=TRAIN_RESHAPE,
    new_shape=TRAIN_NEW_SHAPE,
    preprocess=PREPROCESS,
    new_labels=NEW_LABELS,
)

get_test_batch, revert_trim_reshape = test_hf.test_batch_gen(
    TEST_DIR,
    TEST_LABEL_CHANNELS,
    trim=TEST_TRIM,
    trim_ind=TEST_TRIM_IND,
    reshape=TEST_RESHAPE,
    new_shape=TEST_NEW_SHAPE,
    preprocess=PREPROCESS,
    new_labels=NEW_LABELS,
)

# Load the whole test set into memory once; it is re-used for evaluation every epoch.
data_load_start = time.time()
test_images = []
test_labels = []
for images, labels, _ in get_test_batch(100):
    test_images.append(images)
    test_labels.append(labels)

if not test_images:
    raise RuntimeError(f'No test batches were produced from TEST_DIR={TEST_DIR}')

# Join batches along the sample axis. Unlike np.array(list_of_batches) + reshape,
# this also works when the final batch holds fewer than 100 samples.
test_images = np.concatenate(test_images, axis=0).astype(np.uint8, copy=False)
test_labels = np.concatenate(test_labels, axis=0).astype(np.uint8, copy=False)
print(f'test_images.shape: {test_images.shape}')
print(f'test_labels.shape: {test_labels.shape}')

print(f'Data load time: {time.time() - data_load_start:#0.1f}s')

# Labels are indexed as (sample, row, col, class) - the evaluation loop takes
# argmax over axis 3. One flat slot per labelled pixel is needed for evaluation.
image_org_shape = (test_labels.shape[1], test_labels.shape[2])
flat_labels_size = reduce(lambda x, y: x*y, test_labels.shape[:-1])
flat_offset = BATCH_SIZE*image_org_shape[0]*image_org_shape[1]

def _print_metrics_table(metrics, class_ids):
    """Print per-class recall, precision, F1 score and support as a fixed-width table.

    Args:
        metrics: the (precision, recall, fscore, support) tuple returned by
            sklearn's precision_recall_fscore_support called with
            labels=class_ids, so column j belongs to class class_ids[j].
        class_ids: sorted class indices into CHANNEL_NAMES.
    """
    # sklearn returns precision FIRST and recall SECOND.
    precision, recall, fscore, support = metrics
    str_title     = '             '
    str_recall    = 'Recall:    '
    str_precision = 'Precision: '
    str_f1        = 'F1 score:  '
    str_support   = 'Support:   '
    for j, class_id in enumerate(class_ids):
        str_title += f'{CHANNEL_NAMES[class_id]:10}'
        str_recall += f'{recall[j]:#10.4f}'
        str_precision += f'{precision[j]:#10.4f}'
        str_f1 += f'{fscore[j]:#10.4f}'
        str_support += f'{support[j]:10}'
    for row in (str_title, str_recall, str_precision, str_f1, str_support):
        print(row)


# Rebuild the previously trained graph from its meta file; this Saver is reused
# below to write the new checkpoints.
saver = tf.train.import_meta_graph(os.path.join(RESTORE_DIR, MODEL_NAME + '.ckpt.meta'))

# Number of pixels per test image at original resolution: the stride of one
# image inside the flattened prediction/label buffers.
pixels_per_image = image_org_shape[0] * image_org_shape[1]

with tf.Session() as sess:

    saver.restore(sess, tf.train.latest_checkpoint(RESTORE_DIR))
    graph = tf.get_default_graph()

    image_input = graph.get_tensor_by_name('image_input:0')
    label_input = graph.get_tensor_by_name('label_input:0')
    loss_weights = graph.get_tensor_by_name('loss_weights:0')
    keep_prob = graph.get_tensor_by_name('keep_prob:0')
    l_rate = graph.get_tensor_by_name('l_rate:0')
    adam_eps = graph.get_tensor_by_name('adam_eps:0')
    prediction = graph.get_tensor_by_name('output/prediction:0')
    total_loss = graph.get_tensor_by_name('optimize/total_loss:0')
    opt = graph.get_operation_by_name('optimize/Adam')

    fscore_avg = 0.0

    for epoch in range(EPOCHS):
        start_time = time.time()
        print(f'\nTraining epoch: {epoch+1}/{EPOCHS}')

        # Loss of the most recent training batch; stays NaN (and never triggers
        # a loss-based save) if the generator yields no batches.
        loss = float('nan')
        for train_image_batch, train_label_batch, _ in get_train_batch(BATCH_SIZE):
            # Random horizontal flip (axis 2), applied to image and label together.
            if FLIP and random.randint(0, 1) == 0:
                train_image_batch = np.flip(train_image_batch, axis=2)
                train_label_batch = np.flip(train_label_batch, axis=2)

            _, loss = sess.run([opt, total_loss],
                               feed_dict = {image_input: train_image_batch,
                                            label_input: train_label_batch,
                                            loss_weights: LOSS_WEIGHTS,
                                            keep_prob: KEEP_PROB,
                                            l_rate: LEARNING_RATE,
                                            adam_eps: ADAM_EPSILON})
        print(f'Training time: {(time.time() - start_time):#0.1f}s, loss: {loss:#0.5f}')

        # Predict the whole test set; np.empty is safe because the loop below
        # writes every slot (offsets cover all test images).
        sess_time = 0
        total_preds = np.empty((flat_labels_size,), dtype=np.uint8)
        total_labels = np.empty((flat_labels_size,), dtype=np.uint8)
        for offset in range(0, len(test_images), BATCH_SIZE):
            pred_time = time.time()
            test_image_batch = test_images[offset:offset+BATCH_SIZE]
            test_label_batch = test_labels[offset:offset+BATCH_SIZE]
            preds = sess.run(prediction, feed_dict = {image_input: test_image_batch,
                                                     keep_prob: 1.0})

            preds = revert_trim_reshape(preds)
            sess_time += time.time() - pred_time

            preds_result = np.array(preds, dtype=np.uint8).reshape(-1)
            labels_result = test_label_batch.argmax(axis=3).reshape(-1)

            start = offset * pixels_per_image
            end = start + len(test_label_batch) * pixels_per_image
            total_preds[start:end] = preds_result
            total_labels[start:end] = labels_result

        print(f'Prediction session time: {sess_time:#0.1f}s')
        eval_start_time = time.time()
        # Classes present in labels or predictions, sorted - the same set sklearn
        # uses by default, made explicit so table columns can be named by class id.
        class_ids = np.flatnonzero(np.bincount(total_labels, minlength=NUM_CLASSES)
                                   + np.bincount(total_preds, minlength=NUM_CLASSES))
        metrics = precision_recall_fscore_support(total_labels, total_preds, labels=class_ids)
        print(f'Evaluation time: {(time.time() - eval_start_time):#0.3f}s')
        del total_preds, total_labels

        _print_metrics_table(metrics, class_ids)

        fscore_avg = np.mean(metrics[2])
        print(f'fscore_avg: {fscore_avg:#0.5f}')
        print(f'Total time: {time.time()-start_time:#0.1f}s')

        # Save on a better average F1 score first; otherwise fall back to a lower
        # last-batch training loss. Saver.save needs the target directory to exist.
        if fscore_avg - best_fscore > SAVE_EPSILON:
            best_fscore = fscore_avg
            score_dir = os.path.join(SAVE_DIR, 'score')
            os.makedirs(score_dir, exist_ok=True)
            saver.save(sess, os.path.join(score_dir, MODEL_NAME + '.ckpt'))
            print('*************** MODEL SAVED ON SCORE ***************')
        elif best_loss - loss > SAVE_EPSILON:
            best_loss = loss
            loss_dir = os.path.join(SAVE_DIR, 'loss')
            os.makedirs(loss_dir, exist_ok=True)
            saver.save(sess, os.path.join(loss_dir, MODEL_NAME + '.ckpt'))
            print('******* model saved on loss *******')