# In [None]:
# Imports
import keras.utils
import numpy as np
import matplotlib.pyplot as plt
import random
import os
import time

import tensorflow as tf
import keras.backend as K
import pickle

from TinyDataset import TinyImageDataset
from StegModels import CNNModels
import matplotlib.patches as mpatches

# Seed Python `random`, NumPy and TensorFlow through Keras so runs are reproducible.
keras.utils.set_random_seed(42)

# Dedicated graph; the training run at the bottom of the file executes inside
# `graph.as_default()`.
# NOTE(review): confirm the seed set above also governs ops created inside this graph.
graph = tf.Graph()


def pickle_file(path, filename, data):
    """Serialize ``data`` to disk with :mod:`pickle`.

    The destination is ``path + filename``. The two are joined by plain string
    concatenation, so a directory ``path`` must already end with a separator.
    The file is opened in binary write mode and overwritten if it exists.
    """
    destination = path + filename
    with open(destination, mode="wb") as stream:
        pickle.dump(data, stream)


# Training split of Tiny-ImageNet-200 (relative to the current working directory).
train_path = os.path.join("datasets/tiny-imagenet-200/", "train")

with tf.device('gpu:0'):
    def rev_loss(s_true, s_pred):
        """Secret-reconstruction loss: ``beta``-weighted sum of squared errors.

        ``beta`` is looked up in module scope each time this function runs.
        """
        squared_error = K.square(s_true - s_pred)
        return beta * K.sum(squared_error)


    def cover_loss(c_true, c_pred):
        """Cover-reconstruction loss: unweighted sum of squared errors."""
        residual = c_true - c_pred
        return K.sum(K.square(residual))


    def full_loss(y_true, y_pred):
        """Total loss for the one-secret model: secret loss plus cover loss.

        The last axis of ``y_true``/``y_pred`` holds the secret in channels
        0-2 and the cover in channels 3-5.

        Returns:
            ``rev_loss`` on the secret channels plus ``cover_loss`` on the cover
            channels.
        """
        s_true, c_true = y_true[..., 0:3], y_true[..., 3:6]
        s_pred, c_pred = y_pred[..., 0:3], y_pred[..., 3:6]

        # Reuse the per-term loss helpers so each term is defined in one place.
        return rev_loss(s_true, s_pred) + cover_loss(c_true, c_pred)


    def process(_batch_size, _epochs, save_path, save_interval, activation, filter1, filter2, filter3, verbose,
                _sec1_input, _cov_input):
        """Train the one-secret CNN models and report how long training took.

        Every argument is forwarded to ``CNNModels.train_one_secret_65_filters``:
        ``_batch_size``/``_epochs`` as ``batch_size``/``epochs``, ``save_path`` as
        ``path``, ``_sec1_input``/``_cov_input`` as ``secret_input``/``cover_input``.
        The remaining arguments are passed under their own names.

        Returns:
            ``(encoder_model, decoder_model, autoencoder_model)`` as returned by
            ``train_one_secret_65_filters``.
        """
        cnn_model = CNNModels()
        # Per-sample shape: drop the first (sample) axis.
        input_shape = _sec1_input.shape[1:]

        # perf_counter is monotonic, so the duration is immune to system clock changes.
        start = time.perf_counter()
        _encoder_model, _decoder_model, _autoencoder_model = cnn_model.train_one_secret_65_filters(
            batch_size=_batch_size,
            epochs=_epochs,
            path=save_path,
            shape=input_shape,
            rev_loss=rev_loss,
            full_loss=full_loss,
            secret_input=_sec1_input,
            cover_input=_cov_input,
            verbose=verbose,
            save_interval=save_interval,
            activation=activation,
            filter1=filter1,
            filter2=filter2,
            filter3=filter3
        )
        elapsed = round(time.perf_counter() - start)

        if elapsed > 60:
            # Round minutes for display instead of printing a long float.
            print(f"Model Finished Training in: {elapsed / 60:.2f} m")
        else:
            print(f"Model Finished Training in: {elapsed} s")

        return _encoder_model, _decoder_model, _autoencoder_model


    def train_model(epochs, activation_function, batch_size, filters, _beta):
        """Train a one-secret model and save its plots and reconstruction results.

        Loads images with ``TinyImageDataset`` from ``train_path`` (25 classes,
        normalized). The first 2500 images are used as secrets and the next 2500
        as covers. The models are trained through ``process``. The function then
        writes ``loss.png``, ``image_comparison.png`` and several pickled
        reconstruction arrays/errors into ``save_path``.

        Args:
            epochs: number of training epochs.
            activation_function: activation name forwarded to the model builder.
            batch_size: training batch size.
            filters: sequence whose first three entries are filter1..filter3.
            _beta: weight of the secret-reconstruction term. It is assigned to the
                module-level ``beta`` read by ``rev_loss``, so the trained loss
                uses the same value recorded in ``save_path``.
        """
        # rev_loss reads the module-level ``beta``; bind it to this run's value so
        # the optimised loss matches the BETA encoded in save_path.
        global beta
        beta = _beta

        total_filters = sum(filters)
        f1, f2, f3 = filters[0], filters[1], filters[2]

        save_path = f"model-data/{total_filters}F_{batch_size}BS_{epochs}EP_{activation_function}_BETA{_beta}/"
        dataset_local = TinyImageDataset(path=train_path, num_classes=25, normalize=True)
        X_train_local = dataset_local.load_data()
        # First 2500 images are secrets, the following 2500 are covers.
        sec1_input_local = X_train_local[0:2500]
        cov_input_local = X_train_local[2500:5000]

        encoder_model, decoder_model, autoencoder_model = process(
            _batch_size=batch_size,
            _epochs=epochs,
            save_path=save_path,
            save_interval=1,
            activation=activation_function,
            filter1=f1,
            filter2=f2,
            filter3=f3,
            verbose=1,
            _sec1_input=sec1_input_local,
            _cov_input=cov_input_local
        )

        print(f"Model Saved at: {save_path}")

        def run_loss_history():
            """Plot the pickled loss history and save it as ``loss.png`` in save_path."""
            with open(save_path + "loss_history.pckl", "rb") as f:
                loss_history = pickle.load(f)

            # Draw on an explicit figure so no stale pyplot figure is reused, and
            # so exactly this figure can be closed afterwards.
            fig, ax = plt.subplots()
            ax.plot(loss_history)
            ax.set_title(f'Loss Curve For: 1X: {total_filters}F_{batch_size}BS_{epochs}EP_{activation_function}_{_beta}')
            ax.set_ylabel('Loss')
            ax.set_xlabel('Epoch')

            # Legend entries serve only as a text box listing the run settings.
            epoch_patch = mpatches.Patch(color='blue', label=f'Total Epochs: {epochs}')
            beta_patch = mpatches.Patch(color='blue', label=f'Beta: {_beta}')
            batch_patch = mpatches.Patch(color='blue', label=f'Batch Size: {batch_size}')
            act_patch = mpatches.Patch(color='blue', label=f'Activation Function: {activation_function}')
            ax.legend(handles=[beta_patch, epoch_patch, batch_patch, act_patch], loc="upper right")

            fig.savefig(f"{save_path}loss.png")
            plt.close(fig)

        run_loss_history()

        decoded = autoencoder_model.predict([sec1_input_local, cov_input_local])
        # Output channels 0-2 are the decoded secret, 3-5 the reconstructed cover.
        decoded_S1, decoded_C = decoded[..., 0:3], decoded[..., 3:6]

        def pixel_errors(input_S1, input_C, decoded_S1, decoded_C):
            """Return RMS errors (secret, cover) after scaling differences by 255.

            NOTE(review): the 255 factor assumes inputs normalized to [0, 1]; confirm
            against TinyImageDataset(normalize=True).
            """
            see_S1pixel = np.sqrt(np.mean(np.square(255 * (input_S1 - decoded_S1))))
            see_Cpixel = np.sqrt(np.mean(np.square(255 * (input_C - decoded_C))))
            return see_S1pixel, see_Cpixel

        S1_error, C_error = pixel_errors(sec1_input_local, cov_input_local, decoded_S1, decoded_C)
        diff_S1, diff_C = np.abs(decoded_S1 - sec1_input_local), np.abs(decoded_C - cov_input_local)

        def show_image_results():
            """Save a grid of random cover/secret/reconstruction samples to image_comparison.png."""
            num_imgs = 4
            n_col = 8
            # Keep the original upper bound of 375, but never exceed the available samples.
            max_index = min(375, len(sec1_input_local) - 1, len(cov_input_local) - 1)
            random_index = [random.randint(0, max_index) for _ in range(num_imgs)]
            fig = plt.figure(figsize=(11, 12))

            def show_image(img, n_rows, num_col, index, first_row=False, title=None):
                ax = plt.subplot(n_rows, num_col, index)
                plt.imshow(img)
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)
                if first_row:
                    plt.title(title)

            # Only columns 1, 2, 5 and 6 of the 8-column grid are filled.
            # NOTE(review): the empty columns presumably come from a multi-secret layout; confirm.
            for i, idx in enumerate(random_index):
                base = i * n_col
                first_row = i == 0
                show_image(cov_input_local[idx], num_imgs, n_col, base + 1, first_row=first_row, title='Cover')
                show_image(sec1_input_local[idx], num_imgs, n_col, base + 2, first_row=first_row, title='Secret1')
                show_image(decoded_C[idx], num_imgs, n_col, base + 5, first_row=first_row, title='Cover*')
                show_image(decoded_S1[idx], num_imgs, n_col, base + 6, first_row=first_row, title='Decoded1')

            fig.savefig(f"{save_path}image_comparison.png")
            plt.close(fig)

        show_image_results()

        pickle_file(save_path, "decoded_secret1.pckl", decoded_S1)
        pickle_file(save_path, "decoded_cover.pckl", decoded_C)
        pickle_file(save_path, "secret1_diff.pckl", diff_S1)
        pickle_file(save_path, "cover_diff.pckl", diff_C)
        pickle_file(save_path, "secret1_pixel_error.pckl", S1_error)
        pickle_file(save_path, "cover_pixel_error.pckl", C_error)

        print(f"Model Saved at: {save_path}")

        print("Error per pixel - distance from original RGB")
        print(f"S1 Pixel Error: {S1_error}")
        print(f"C Pixel Error: {C_error}")

    # Hyper-parameter candidates; this run uses only the first entry of each.
    actives = ["relu", "selu", "gelu", "swish"]
    betas = [0.25, 0.5, 0.75, 1.0]
    beta = betas[0]  # module-level name read by rev_loss
    with graph.as_default():
        train_model(
            epochs=1000,
            activation_function=actives[0],
            batch_size=32,
            filters=(50, 20, 10),
            _beta=beta,
        )
