# Setup Environment

### System variables

In [None]:
import sys
if 'google.colab' in sys.modules:
    from google.colab import drive
    drive.mount('/content/drive/')
    !cp -r "/content/drive/MyDrive/Training/" "/content/Training"
    !cd "/content/Training"
    sys.path.append("/content/Training")
#For disable GPU
#import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

### Load Libraries

In [None]:
import torch
import os
from skimage import io
from sklearn.model_selection import train_test_split
from TRI2BRI import tIR2bri
from util import imageToInput, inputToImage
from training_GAN import train, train_d, test, Discriminator, adap_train, adap_train_d
import matplotlib.pyplot as plt
import random
from statistics import mean
import numpy as np
import math

### Init random seed

In [None]:
SEED = 14285

# Seed each RNG source the notebook draws from: Python's `random`,
# NumPy's global RNG, and PyTorch.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# Ask cuDNN to use deterministic algorithms.
torch.backends.cudnn.deterministic = True

### Load Model

In [None]:
tIR_model = tIR2bri()
# The discriminator starts from fresh weights; only the generator is restored.
d_model = Discriminator()
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine; the model is moved to the GPU below when one is available.
tIR_model.load_state_dict(
    torch.load("tIR2BRI_e134_0.0001_1513_last.ckpt", map_location="cpu"))
if torch.cuda.is_available():
    tIR_model = tIR_model.cuda()
    d_model = d_model.cuda()


# Data Loader Functions

In [None]:
def get_file_list(input_dir, result_dir):
    """Get file list for model training.

    Pairs every regular file in ``input_dir`` with the regular file of the
    same name in ``result_dir``. Names are visited in sorted order:
    ``os.listdir`` returns entries in arbitrary, filesystem-dependent order,
    and callers that split this list by position (e.g. an 80/20 cut) would
    otherwise get a different split on different machines.

    Args:
        input_dir (str): Input file directory (X)
        result_dir (str): Result file directory (Y)

    Returns:
        list[tuple[str, str]]: (X_path, Y_path) pairs, sorted by filename.
        Names that are sub-directories, or that are missing from
        ``result_dir``, are skipped.
    """
    candidates = (
        (os.path.join(input_dir, name), os.path.join(result_dir, name))
        for name in sorted(os.listdir(input_dir))
    )
    return [
        (fIR, fGRAY)
        for fIR, fGRAY in candidates
        if os.path.isfile(fIR) and os.path.isfile(fGRAY)
    ]

def data_generator(file_list,shuffle=True):
    """Make one pass over ``file_list``, loading each (X, Y) pair from disk.

    Args:
        file_list ([(str,str)]): (X_path, Y_path) pairs for training data.
        shuffle (bool): If True, visit the pairs in random order. The
            shuffle is applied to a copy, so the caller's list is not
            modified.

    Yields:
        (X, Y): both images passed through ``imageToInput``. This notebook
        calls ``.cuda()``/``.cpu()`` on them, so they appear to be torch
        tensors (confirm against util.imageToInput).
    """
    if shuffle:
        pairs = file_list[:]
        random.shuffle(pairs)
    else:
        pairs = file_list

    for ir_path, gray_path in pairs:
        yield imageToInput(io.imread(ir_path)), imageToInput(io.imread(gray_path))

def data_generator_inf(file_list,shuffle=True):
    """Endlessly yield (X, Y) pairs drawn uniformly at random, with replacement.

    Args:
        file_list ([(str,str)]): (X_path, Y_path) pairs for training data.
        shuffle (bool): Ignored; sampling is always random.

    Yields:
        (X, Y): both images passed through ``imageToInput``.

    Raises:
        IndexError: on the first ``next()`` if ``file_list`` is empty
            (raised by ``random.choice``).
    """
    def _load(path):
        # Read one image file and convert it to model-input form.
        return imageToInput(io.imread(path))

    while True:
        ir_path, gray_path = random.choice(file_list)
        sample_x = _load(ir_path)
        sample_y = _load(gray_path)
        yield sample_x, sample_y

# Training Process

### Set Training Variables

In [None]:
# Paired image directories: absolute paths inside Colab, relative otherwise.
if 'google.colab' in sys.modules:
    x_dir, y_dir = "/content/Training/IR-256", "/content/Training/GRAY-256"
else:
    x_dir, y_dir = "./IR-256", "./GRAY-256"

file_list = get_file_list(x_dir, y_dir)

# Positional (unshuffled) 80/20 split: the first 80% of the pairs, rounded
# down, go to training and the rest to testing.
cut = int(len(file_list) * 0.8)
train_file = file_list[:cut]
test_file = file_list[cut:]

n_epochs = 120
# Per-epoch history. It lives at module level, next to `epoch`, so an
# interrupted training cell can be re-run and resume.
all_loss_train, all_loss_test, all_epoch, all_d_loss = [], [], [], []
epoch = 0

### Training Loop (adaptive GAN schedule)

In [None]:
# Adaptive schedule knobs. Both moving averages are taken over the first value
# yielded by adap_train_d / adap_train respectively.
d_threshold = 0.05      # D phase stops once its moving average drops below this
g_threshold = 0.005     # G phase pauses once its moving average drops below this
max_d_iter = 100        # at most this many D steps per D phase
moving_avg_size = 10    # window length of both moving averages
# Write as while loop, so we can resume interrupts
while epoch < n_epochs:
    # Re-enter training mode every epoch. NOTE(review): test() is defined
    # elsewhere and may leave the models in eval mode; this is a no-op if not.
    tIR_model.train()
    d_model.train()

    # Fresh generators each epoch. The D stream samples train_file endlessly
    # (data_generator_inf); the G stream is a single shuffled pass.
    train_d_generator = adap_train_d(
        tIR_model, d_model, data_generator_inf(train_file), d_learning_rate=0.0005)
    train_generator = adap_train(
        tIR_model, d_model, data_generator(train_file), learning_rate=0.0001)
    loss_d = 0  # reported as-is if the D phase never yields

    # Alternate D and G phases until the G pass is exhausted. Breaking out of a
    # `for` over a generator does not close it, so each phase later resumes
    # where it left off.
    g_empty = False
    while not g_empty:
        # Each window is seeded with its threshold, so early steps are averaged
        # against it rather than judged alone.
        sloss_d_list = [d_threshold]
        g_sloss_d_list = [g_threshold]

        # Train Discriminator: stop when the moving average falls below
        # d_threshold or after max_d_iter steps (`>=` caps it at exactly that many).
        for count_d_iter, (sloss_d, loss_d) in enumerate(train_d_generator, start=1):
            sloss_d_list.append(sloss_d)
            if len(sloss_d_list) > moving_avg_size:
                sloss_d_list.pop(0)
            if mean(sloss_d_list) < d_threshold or count_d_iter >= max_d_iter:
                break

        # Train generator: stop when the moving average falls below g_threshold.
        # The for-else `else` runs only if train_generator is exhausted (no
        # break), which ends this epoch.
        for g_sloss_d, current_loss_train in train_generator:
            g_sloss_d_list.append(g_sloss_d)
            if len(g_sloss_d_list) > moving_avg_size:
                g_sloss_d_list.pop(0)
            if mean(g_sloss_d_list) < g_threshold:
                break
        else:
            g_empty = True

    # current_loss_train and loss_d are the last values yielded this epoch.
    current_loss_test = test(tIR_model, d_model, data_generator(test_file))
    all_loss_train.append(current_loss_train)
    all_loss_test.append(current_loss_test)
    all_epoch.append(epoch)
    all_d_loss.append(loss_d)
    epoch += 1
    print("\rEpoch=", epoch, " Loss=", format(current_loss_train, '.5g'),
          ",", format(current_loss_test, '.5g'), " D_LOSS=", format(loss_d, '.5g'), sep="")


### Regular Training (without the adaptive GAN learning process)

In [None]:
# # Write as while loop, so we can resume interrupts
# while epoch < n_epochs:
#     loss_d = train_d(tIR_model, d_model, data_generator(train_file), d_learning_rate=0.005)
#     current_loss_train = train(
#         tIR_model, d_model, data_generator(train_file), learning_rate=0.001)
#     current_loss_test = test(tIR_model, d_model, data_generator(test_file))
#     all_loss_train.append(current_loss_train)
#     all_loss_test.append(current_loss_test)
#     all_epoch.append(epoch)
#     all_d_loss.append(loss_d)
#     epoch += 1
#     print("\rEpoch=", epoch, " Loss=", format(current_loss_train, '.5g'),
#           ",", format(current_loss_test, '.5g'), " D_LOSS=", format(loss_d, '.5g'), sep="")

### Save Model

In [None]:
# Persist only the generator's weights. This is the same path the model was
# loaded from, so the previous checkpoint file is overwritten.
torch.save(obj=tIR_model.state_dict(), f="tIR2BRI_e134_0.0001_1513_last.ckpt")

# Evaluate Model

### Show training statistics

In [None]:
# Train vs. test loss recorded at the end of each epoch.
plt.plot(all_epoch, all_loss_test, label="Test Loss")
plt.plot(all_epoch, all_loss_train, label="Train Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

# Last discriminator loss reported by the D phase in each epoch.
plt.plot(all_epoch, all_d_loss, label="D Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

### Load Data for testing

In [None]:
# Shuffled single pass over the held-out pairs (shuffle=True is the default).
data = data_generator(test_file, shuffle=True)

In [None]:
# Switch the generator to evaluation mode for inference.
tIR_model.eval()
# Pull one held-out (input, target) pair.
sample = next(data)
x, y = sample
# Release unused blocks held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()

### Plot data

In [None]:
if torch.cuda.is_available():
    x = x.cuda()

# io.imsave does not create missing parent directories.
os.makedirs("./output", exist_ok=True)

# Input (thermal IR) image.
output = inputToImage(x.cpu().detach())
io.imshow(output)
plt.title("Thermal IR Image")
plt.show()
io.imsave("./output/Thermal_IR.png", output)

# Generated image. Inference only: no_grad stops autograd from recording the
# forward pass and keeping its activations alive.
with torch.no_grad():
    output = tIR_model(x)
output = inputToImage(output.cpu().detach())
io.imshow(output)
plt.title("Generated Image")
plt.show()
io.imsave("./output/Generated.png", output)

# Ground-truth target image.
output = inputToImage(y.cpu().detach())
io.imshow(output)
plt.title("Target Image")
plt.show()
io.imsave("./output/Target.png", output)
