In [None]:
import traceback
from os import environ, getcwd, makedirs
from sys import path

# Make the project root and its modules/ folder importable. Both are prepended
# to sys.path, so modules/ ends up ahead of the project root.
_project_root = getcwd()
for _import_dir in (_project_root, _project_root + "/modules/"):
    path.insert(0, _import_dir)
del _project_root, _import_dir

# Expose only GPUs 3 and 1 to CUDA. Presumably this has to be set before
# TensorFlow initializes CUDA, which is why it precedes the model imports.
environ["CUDA_VISIBLE_DEVICES"] = "3,1"

from modules.DataMod import DataSet
from modules.CustomLosses import LSSIM, LPSNRB, L3SSIM
from modules.misc import ssim_metric, psnrb_metric
from modules.ImageMetrics.metrics import three_ssim
from tensorflow.keras.optimizers import Adam

from keras import models

import mlflow.keras

import multiprocessing

## Fetching Datasets

In [None]:
# One DataSet container per source; the combined one is filled further down.
tinyDataSet = DataSet()
cifarDataSet = DataSet()
cifarAndTinyDataSet = DataSet()

# Load the Tiny ImageNet and CIFAR-10 data (noisy variants, judging by the
# loader names). The loaders' return values replace the containers, so they
# are assumed to return the loaded DataSet — TODO confirm in DataMod.
tinyDataSet = tinyDataSet.load_rafael_tinyImagenet_64x64_noise_data()
cifarDataSet = cifarDataSet.load_rafael_cifar_10_noise_data()

# Build a third training set from CIFAR-10 followed by Tiny ImageNet.
cifarAndTinyDataSet = cifarAndTinyDataSet.concatenateDataSets(
    cifarDataSet,
    tinyDataSet,
)

## Training Models

In [None]:
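# TODO:
# - parallelize the training (is it necessary?)
# - the Keras API docs say batch_size should not be specified when the data is a
#   dataset, generator or Sequence (those yield their own batches); does setting
#   it here affect the training?

# training with the L3SSIM, LSSIM and LPSNRB loss functions and the ssim, three_ssim and psnrb metrics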


In [None]:
# TODO: fix batch_size and epochs (how should the number of epochs and the batch size be chosen?)
batch_size = 20
epochs = 15

LOG_DIR = "logs"
LOG_PATH = LOG_DIR + "/run1.txt"
WEIGHTS_DIR = "models/weights/run1/"

# Architectures trained on every dataset, in this order.
MODEL_PATHS = [
    "nNet_models/AutoEncoder-2.3-64x64.json",
    "nNet_models/GANResidualAutoEncoder-0.1-64x64.json",
    "nNet_models/Unet2.3-64x64.json",
]

# Loss factories; each is instantiated once per (architecture, dataset) run.
LOSSES = [L3SSIM, LSSIM, LPSNRB]

makedirs(LOG_DIR, exist_ok=True)
makedirs(WEIGHTS_DIR, exist_ok=True)

# Truncate the error log for this run. Writers re-open it in append mode
# instead of sharing one module-level handle: a forked multiprocessing child
# leaves via os._exit() without flushing Python file buffers, so messages
# written through an inherited buffered handle could be lost.
with open(LOG_PATH, "w"):
    pass

mlflow.keras.autolog()


def _log_error(message: str, exc: BaseException) -> None:
    """Append ``message`` and the full traceback of ``exc`` to LOG_PATH.

    The formatted traceback already includes any chained ``__cause__`` /
    ``__context__`` exceptions, as text (``write`` only accepts ``str``).
    """
    with open(LOG_PATH, "a") as log_file:
        log_file.write(message)
        log_file.writelines(traceback.format_exception(type(exc), exc, exc.__traceback__))


# TODO: parallelize the training
def train_model_paralel(dataset: DataSet):
    """Train every architecture in MODEL_PATHS with every loss in LOSSES on ``dataset``.

    Each (architecture, loss) pair gets a freshly loaded model. The model is
    compiled with Adam and the ssim / three_ssim / psnrb metrics, then fitted on
    ``dataset.x_train`` -> ``dataset.y_train`` inside its own MLflow run. Its
    weights are saved under WEIGHTS_DIR. Load, compile, fit or save failures are
    appended to LOG_PATH and the next combination is tried.

    Args:
        dataset: DataSet providing ``name``, ``x_train`` and ``y_train``.
    """
    for model_path in MODEL_PATHS:
        for loss_factory in LOSSES:
            loss_name = loss_factory.__name__

            # Reload for every loss so each run starts from a freshly loaded
            # model instead of continuing from the previous loss's training.
            # NOTE(review): these are .json files; Keras normally reads a JSON
            # architecture with models.model_from_json rather than load_model — confirm.
            try:
                model = models.load_model(model_path, compile=False)
            except Exception as e:
                _log_error(f"Error {e}: Error loading {model_path} for {dataset.name} dataset\n", e)
                break  # the same file would fail again for the remaining losses

            try:
                model.compile(
                    optimizer=Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-7, amsgrad=False),
                    loss=loss_factory(),
                    metrics=[ssim_metric, three_ssim, psnrb_metric],
                )
            except Exception as e:
                _log_error(
                    f"Error {e}: Error compiling {model.name} with {dataset.name} dataset and {loss_name} loss\n", e
                )
                continue

            # The loss name keeps runs/weight files of different losses from overwriting each other.
            run_name = model.name + dataset.name + loss_name

            # The try sits outside the MLflow run so a failing fit/save leaves the
            # `with` by exception and MLflow records the run as FAILED.
            try:
                with mlflow.start_run(run_name=run_name):
                    # NOTE(review): workers / max_queue_size / use_multiprocessing
                    # only exist in Keras 2's fit(); Keras 3 rejects them.
                    model.fit(
                        x=dataset.x_train,
                        y=dataset.y_train,
                        batch_size=batch_size,
                        epochs=epochs,
                        verbose=1,
                        validation_split=0,
                        shuffle=True,
                        class_weight=None,
                        sample_weight=None,
                        steps_per_epoch=None,
                        validation_steps=None,
                        validation_batch_size=None,
                        validation_freq=1,
                        max_queue_size=10,
                        workers=1,
                        use_multiprocessing=False,
                    )

                    model.save_weights(WEIGHTS_DIR + run_name + ".h5")
            except Exception as e:
                _log_error(
                    f"Error {e}: Error fitting and saving {model.name} with {dataset.name} dataset and {loss_name} loss\n",
                    e,
                )

In [None]:
# One training process per dataset. Each process is started and joined before
# the next one starts, so the datasets are trained one after another. This
# presumably keeps several TensorFlow processes from competing for the same
# GPUs — confirm before starting them all at once.
datasets = [tinyDataSet, cifarDataSet, cifarAndTinyDataSet]
procs = [multiprocessing.Process(target=train_model_paralel, args=(dataset,)) for dataset in datasets]

for dataset, proc in zip(datasets, procs):
    proc.start()
    # waits for this training to finish before starting the next one
    proc.join()
    # A child that crashed (uncaught exception, killed by a signal) would
    # otherwise go unnoticed from the notebook.
    if proc.exitcode != 0:
        print(f"Training process for {dataset.name} dataset exited with code {proc.exitcode}")