# Import dependencies

In [None]:
from tensorflow import keras
import tensorflow.keras.backend as K
import tensorflow as tf
from data_generator import DataGenerator
import numpy as np
import os
import random
import datetime
from train_val_epoch import train_epoch, validation_epoch
from metrics import plot_feature_space, meanf1_iou, plot_confusion_matrix
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
%matplotlib inline

In [None]:
# Report which TensorFlow / Keras versions this notebook is running against.
print(f'TensorFlow {tf.__version__}; Keras {keras.__version__}')

In [None]:
# Name of the GPU device TensorFlow sees ('' when none is available);
# evaluated as the last expression so the notebook displays it.
gpu_device = tf.test.gpu_device_name()
gpu_device


# Define Parameters

In [None]:
# --- Data locations ---------------------------------------------------------
BASE_PATH = "./west/"       # source-domain files (split into train/val/test)
TAR_BASE_PATH = "./east/"   # target-domain files (fed to the adaptation branch)

# --- Problem / input geometry -----------------------------------------------
NUM_CLASSES = 4
# NOTE(review): the model-building cell hard-codes 256x256 instead of reading these.
im_height = 256
im_width = 256

# --- Splitting and training schedule ----------------------------------------
test_ratio = 0.1   # held-out fraction for test, reused for the validation split
BATCH_SIZE = 1
# NOTE(review): the training loop iterates over `epochs`, not NUM_EPOCHS.
NUM_EPOCHS = 50

# Read Dataset

In [None]:
# Extract the file paths of both domains.
# os.listdir returns entries in an arbitrary, filesystem-dependent order, so the
# listings are sorted to make everything downstream (printing, the seeded
# shuffle) deterministic. os.path.join also copes with a base path that lacks
# a trailing slash.
files = [os.path.join(BASE_PATH, name) for name in sorted(os.listdir(BASE_PATH))]
tar_files = [os.path.join(TAR_BASE_PATH, name) for name in sorted(os.listdir(TAR_BASE_PATH))]
print("###SRC FILES###")
print(BASE_PATH)
print(len(files))
print("###TAR FILES###")
print(TAR_BASE_PATH)
print(len(tar_files))

In [None]:
# Reproducible split of the source-domain files into test / validation / train.
# Sort first: os.listdir order is filesystem-dependent, and a seeded shuffle is
# only reproducible when it starts from a fixed order.
files.sort()
tar_files.sort()
random.seed(10)
random.shuffle(files)
random.shuffle(tar_files)

# test_ratio of all source files is held out for testing ...
test_size = int(len(files) * test_ratio)
test_files = files[:test_size]
non_test_files = files[test_size:]

# ... and the same ratio of the remainder is held out for validation.
val_size = int(len(non_test_files) * test_ratio)
val_files = non_test_files[:val_size]
train_files = non_test_files[val_size:]

print("Train size:", len(train_files))
print("Validation size:", len(val_files))
print("Test size:", len(test_files))
if train_files:  # avoid an IndexError on an empty source directory
    print(train_files[0])


# Import models from keras_unet_collection

In [None]:
import models

In [None]:
# Build the domain-adaptation U-Net. models.unet_2d returns two models here:
# d_model (compiled later with binary cross-entropy, presumably the domain
# classifier) and model (the segmentation network with NUM_CLASSES outputs).
# The spatial size comes from the parameters cell instead of being repeated.
input_shape = (im_height, im_width, 12)  # 12 input channels per sample
d_model, model = models.unet_2d(
    input_shape, [16, 32, 32, 64], n_labels=NUM_CLASSES,
    stack_num_down=2, stack_num_up=1,
    activation='GELU', output_activation='Softmax',
    batch_norm=True, pool='max', unpool='nearest', name='unet',
    is_domain_adaptation=True, da_type='conv2d', da_kernels=[64, 32, 32, 16])

In [None]:
# Normalization bounds shared by every generator (presumably the global
# max/min of the input data - TODO confirm where these constants come from).
maximum, minimum = 14.733826, -49.208305

# Settings identical for all four generators.
_common_gen_kwargs = dict(batch_size=BATCH_SIZE, normalize=True,
                          maximum=maximum, minimum=minimum)

# Source-domain splits: only the training generator augments and shuffles.
train_gen = DataGenerator(image_paths=train_files, augment=True, shuffle=True,
                          **_common_gen_kwargs)
val_gen = DataGenerator(image_paths=val_files, augment=False, shuffle=False,
                        **_common_gen_kwargs)
test_gen = DataGenerator(image_paths=test_files, augment=False, shuffle=False,
                         **_common_gen_kwargs)

# Target domain: fixed order, no augmentation.
tar_gen = DataGenerator(image_paths=tar_files, augment=False, shuffle=False,
                        **_common_gen_kwargs)



In [None]:
from losses import weightedLoss

In [None]:
# Per-class counts, each summed over three groups (presumably pixel counts per
# data region/split - TODO confirm the origin of these numbers).
other = 17749814 + 17766350 + 22149798
corn = 204516 + 172453 + 235173
cotton = 88734 + 26780 + 1677
rice = 10122026 + 6884977 + 7128782
total = other + corn + cotton + rice

# Inverse-frequency weights scaled by total / n_classes. With this scaling the
# weighted count sum_c(count_c * w_c) equals `total`, so the sum of the weights
# of all examples stays the same and the loss keeps a similar magnitude.
# (A total / 2 factor only achieves this for two classes; with four it doubled
# the total weight.)
class_counts = [other, corn, cotton, rice]
n_classes = len(class_counts)
weight_for_other, weight_for_corn, weight_for_cotton, weight_for_rice = (
    total / (n_classes * count) for count in class_counts)

class_weight = {0: weight_for_other, 1: weight_for_corn, 2: weight_for_cotton, 3: weight_for_rice}

print('Weight for class 0: {:.2f}'.format(weight_for_other))
print('Weight for class 1: {:.2f}'.format(weight_for_corn))
print('Weight for class 2: {:.2f}'.format(weight_for_cotton))
print('Weight for class 3: {:.2f}'.format(weight_for_rice))

# Same weights as a list indexed by class id (consumed by weightedLoss).
weights = [weight_for_other, weight_for_corn, weight_for_cotton, weight_for_rice]

In [None]:
# Segmentation model: categorical cross-entropy wrapped by weightedLoss with
# the per-class `weights` (presumably a class-weighted loss to counter the
# imbalance - see losses.weightedLoss).
seg_loss_fn = weightedLoss(keras.losses.categorical_crossentropy, weights)
model.compile(
    loss=seg_loss_fn,
    optimizer='adam',
    metrics=['accuracy'],
)

In [None]:
# Domain model: plain binary cross-entropy with its own Adam optimizer
# (presumably a source-vs-target discriminator).
d_model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

In [None]:
# Run-level settings for the training loop below.
epochs = 100      # NOTE(review): differs from NUM_EPOCHS (50) in the parameters cell
batch_size = 8    # NOTE(review): not read by any visible cell; the loop uses BATCH_SIZE
iterations = 3    # NOTE(review): not read by any visible cell

# History containers: the training cell appends one per-epoch list per run.
(src_seg_acc_train_list,
 src_seg_acc_test_list,
 tar_seg_acc_test_list,
 src_dom_acc_train_list,
 src_dom_acc_test_list,
 tar_dom_acc_train_list,
 tar_dom_acc_test_list,
 src_dom_loss_train_list,
 tar_dom_loss_train_list) = ([] for _ in range(9))

losses, accuracies = [], []

In [None]:
# Domain-adaptation training loop.
# Each epoch runs train_epoch on (train_gen, tar_gen). Every VAL_EVERY epochs
# it runs validation_epoch on (val_gen, tar_gen) and saves the segmentation
# weights to DAUNet1/<epoch>. It also records, in `best_weights_path`, the
# checkpoint with the highest source-domain validation segmentation accuracy,
# so the "load best weights" cell can restore it. All metrics are written to
# TensorBoard under `writing_path`.
writing_path = "mylogs" + str(0)
writer = tf.summary.create_file_writer(writing_path)

VAL_EVERY = 5  # validate + checkpoint every VAL_EVERY epochs
best_src_seg_acc = float('-inf')
best_weights_path = None  # stays None if no validation epoch runs (epochs < VAL_EVERY)


def _log_scalars(scalars, step):
    """Write every ``tag -> value`` pair of *scalars* to the default summary writer."""
    for tag, value in scalars.items():
        tf.summary.scalar(tag, value, step=step)


with writer.as_default():
    src_seg_acc_train_list_ = list()
    src_seg_acc_test_list_ = list()
    tar_seg_acc_test_list_ = list()
    src_dom_acc_train_list_ = list()
    src_dom_acc_test_list_ = list()
    tar_dom_acc_train_list_ = list()
    tar_dom_acc_test_list_ = list()
    src_dom_loss_train_list_ = list()
    tar_dom_loss_train_list_ = list()
    src_seg_loss_train_list_ = list()

    for epoch in range(epochs):
        # train model in one epoch
        (seg_loss, src_dom_loss, tar_dom_loss,
         seg_train_acc, src_dom_acc, tar_dom_acc) = train_epoch(
            model, d_model, BATCH_SIZE, train_gen, tar_gen)

        # keep results in lists
        src_seg_acc_train_list_.append(seg_train_acc)
        src_seg_loss_train_list_.append(seg_loss)
        src_dom_acc_train_list_.append(src_dom_acc)
        src_dom_loss_train_list_.append(src_dom_loss)
        tar_dom_acc_train_list_.append(tar_dom_acc)
        tar_dom_loss_train_list_.append(tar_dom_loss)

        _log_scalars({
            'seg_loss_train': seg_loss,
            'src_dom_loss_train': src_dom_loss,
            'tar_dom_loss_train': tar_dom_loss,
            'seg_acc_train': seg_train_acc,
            'src_dom_acc_train': src_dom_acc,
            'tar_dom_acc_train': tar_dom_acc,
        }, epoch)

        print('Train: Epoch %s: Seg Loss: %.4f, Src Dom Loss: %.4f, Tar Dom Loss: %.4f, Seg Acc: %.4f, Src Dom Acc: %.4f, Tar Dom Acc: %.4f' %
              (epoch, seg_loss, src_dom_loss, tar_dom_loss, seg_train_acc, src_dom_acc, tar_dom_acc))

        if (epoch + 1) % VAL_EVERY == 0:
            # Note: this reuses the src/tar_dom_* names; the training values
            # above were already stored and logged.
            (src_seg_loss, tar_seg_loss, src_dom_loss, tar_dom_loss,
             src_seg_acc, tar_seg_acc, src_dom_acc, tar_dom_acc) = validation_epoch(
                model, d_model, BATCH_SIZE, val_gen, tar_gen)

            src_seg_acc_test_list_.append(src_seg_acc)
            tar_seg_acc_test_list_.append(tar_seg_acc)
            src_dom_acc_test_list_.append(src_dom_acc)
            tar_dom_acc_test_list_.append(tar_dom_acc)

            _log_scalars({
                'seg_loss_test_src': src_seg_loss,
                'seg_loss_test_tar': tar_seg_loss,
                'src_dom_loss_test': src_dom_loss,
                'tar_dom_loss_test': tar_dom_loss,
                'src_seg_acc': src_seg_acc,
                'tar_seg_acc': tar_seg_acc,  # previously mistyped as 'tar_seg_scc'
                'src_dom_acc': src_dom_acc,
                'tar_dom_acc': tar_dom_acc,
            }, epoch)

            print('Test: Epoch %s: Src Seg Loss: %.4f, Tar Seg Loss: %.4f, Src Dom Loss: %.4f, Tar Dom Loss: %.4f, Src Seg Acc: %.4f, Tar Seg Acc: %.4f, Src Dom Acc: %.4f, Tar Dom Acc: %.4f' %
                  (epoch, src_seg_loss, tar_seg_loss, src_dom_loss, tar_dom_loss, src_seg_acc, tar_seg_acc, src_dom_acc, tar_dom_acc))

            checkpoint_path = f'DAUNet1/{epoch}'
            model.save_weights(checkpoint_path)
            # Model selection uses the source-domain validation split only, so
            # target labels are not needed to pick a checkpoint.
            # Switch to tar_seg_acc if that is intended.
            if float(src_seg_acc) > best_src_seg_acc:
                best_src_seg_acc = float(src_seg_acc)
                best_weights_path = checkpoint_path

    src_seg_acc_train_list.append(src_seg_acc_train_list_)
    src_seg_acc_test_list.append(src_seg_acc_test_list_)
    tar_seg_acc_test_list.append(tar_seg_acc_test_list_)
    src_dom_acc_train_list.append(src_dom_acc_train_list_)
    src_dom_acc_test_list.append(src_dom_acc_test_list_)
    tar_dom_acc_train_list.append(tar_dom_acc_train_list_)
    tar_dom_acc_test_list.append(tar_dom_acc_test_list_)
    src_dom_loss_train_list.append(src_dom_loss_train_list_)
    tar_dom_loss_train_list.append(tar_dom_loss_train_list_)

# Make sure all buffered summaries reach the event file on disk.
writer.flush()


# Load model with best weights

# Restore the weights to evaluate. `best_weights_path` is expected to be set
# earlier in the notebook (e.g. by the training cell). If it is missing or
# None, fall back to the most recent checkpoint Keras recorded under DAUNet1/
# (where the training loop saves), instead of failing with a NameError.
try:
    weights_to_load = best_weights_path
except NameError:
    weights_to_load = None
if weights_to_load is None:
    weights_to_load = tf.train.latest_checkpoint('DAUNet1')
if weights_to_load is None:
    raise FileNotFoundError(
        'No checkpoint to restore: best_weights_path is unset and DAUNet1/ '
        'contains no checkpoint.')
model.load_weights(weights_to_load)

# Plot Confusion Matrices

In [None]:
# Confusion matrix on the training split. train_gen was built with
# augment=True and shuffle=True for training. Evaluate an un-augmented,
# ordered copy instead, so the matrix reflects the real training images and
# is repeatable.
train_eval_gen = DataGenerator(image_paths=train_files, batch_size=BATCH_SIZE, augment=False,
                               shuffle=False, normalize=True, maximum=maximum, minimum=minimum)
plot_confusion_matrix(model, train_eval_gen)

In [None]:
# Confusion matrix on the source-domain validation split
# (val_gen: no augmentation, fixed order).
plot_confusion_matrix(model, val_gen)

In [None]:
# Confusion matrix on the held-out source-domain test split
# (test_gen: no augmentation, fixed order).
plot_confusion_matrix(model, test_gen)

In [None]:
# Confusion matrix on the target-domain files (tar_gen: no augmentation,
# fixed order). tar_gen is also passed to train_epoch and validation_epoch,
# so these files are not held out from training.
plot_confusion_matrix(model, tar_gen)

# Calculate mean F1 and IoU

In [None]:
# Mean F1 / IoU on the training split. train_gen was built with augment=True
# and shuffle=True for training. Score an un-augmented, ordered copy instead,
# so the numbers describe the real training images and are repeatable.
# The result (if any) is shown as the cell's output.
train_eval_gen = DataGenerator(image_paths=train_files, batch_size=BATCH_SIZE, augment=False,
                               shuffle=False, normalize=True, maximum=maximum, minimum=minimum)
meanf1_iou(model, train_eval_gen)

In [None]:
# Mean F1 / IoU on the source-domain validation split; the returned value
# (if any) is shown as the cell's output.
meanf1_iou(model, val_gen)

In [None]:
# Mean F1 / IoU on the held-out source-domain test split; the returned value
# (if any) is shown as the cell's output.
meanf1_iou(model, test_gen)

In [None]:
# Mean F1 / IoU on the target-domain files. tar_gen is also passed to
# train_epoch and validation_epoch, so these files are not held out.
meanf1_iou(model, tar_gen)

# Plot bottleneck feature maps

In [None]:
# Plot the bottleneck feature space for source vs. target data. `files` holds
# every source-domain path (train + val + test) and `tar_files` every
# target-domain path. Unlike the cells above, this passes raw file paths
# rather than generators.
plot_feature_space(model, files, tar_files)