# In [None]:
import keras.utils
import numpy as np
import matplotlib.pyplot as plt
import random
import os
import time

import keras.backend as K
import pickle

import tensorflow as tf

from TinyDataset import TinyImageDataset
from StegModels import CNNModels
import matplotlib.patches as mpatches

# Seed Python's `random`, NumPy and the TensorFlow backend in one call so runs are repeatable.
keras.utils.set_random_seed(42)


def pickle_file(path, filename, data):
    """Pickle ``data`` into the file named ``path + filename``.

    ``path`` is used as a plain string prefix (no separator is inserted), so
    it must already end with a path separator when it names a directory.
    An existing file at that location is overwritten.
    """
    target = path + filename
    with open(target, 'wb') as sink:
        pickle.dump(data, sink)


# Directory holding the training split of the Tiny ImageNet download.
train_path = os.path.join("datasets/tiny-imagenet-200/", "train")

# Dedicated TensorFlow graph; training is later run under `graph.as_default()`.
graph = tf.Graph()

with tf.device('gpu:0'):

    def decode_loss(s_true, s_pred):
        """Return the beta-weighted sum of squared differences of the inputs.

        The weight is the module-level ``beta``, looked up each time the
        function is called.
        """
        squared_error = K.square(s_true - s_pred)
        return beta * K.sum(squared_error)


    def cover_loss(c_true, c_pred):
        """Return the (unweighted) sum of squared differences of the inputs."""
        residual = c_true - c_pred
        return K.sum(K.square(residual))


    def total_loss3(y_true, y_pred):
        """Combined loss for the three-secret model.

        The last axis of both tensors is read as four 3-channel groups:
        channels 0-2, 3-5 and 6-8 are the three secrets and channels 9-11 are
        the cover. Each secret group contributes ``decode_loss`` (beta
        weighted); the cover group contributes a plain sum of squared errors.
        """
        # One beta-weighted term per secret channel group.
        secret_terms = [
            decode_loss(y_true[..., start:start + 3], y_pred[..., start:start + 3])
            for start in (0, 3, 6)
        ]
        first_secret, second_secret, third_secret = secret_terms

        cover_term = cover_loss(y_true[..., 9:12], y_pred[..., 9:12])

        # Same accumulation order as before: secret 1, cover, secret 2, secret 3.
        return sum((first_secret, cover_term, second_secret, third_secret))


    def process(_batch_size, _epochs, save_path, save_interval, activation, filter1, filter2, filter3, verbose,
                _sec1_input, _sec2_input, _sec3_input, _cov_input):
        """Train the three-secret network and report how long training took.

        Creates a ``CNNModels`` instance and forwards all arguments to its
        ``train_three_secret_65_filters`` method, using ``decode_loss`` and
        ``total_loss3`` as the loss functions. The input shape passed to the
        trainer is the shape of ``_sec1_input`` without its first axis.

        The elapsed wall-clock time is printed in seconds, or in minutes when
        it exceeds 60 seconds.

        Returns:
            The tuple ``(encoder, decoder1, decoder2, decoder3, autoencoder)``
            as produced by the trainer.
        """
        trainer = CNNModels()
        sample_shape = _sec1_input.shape[1:]

        started_at = time.time()
        encoder, decoder_1, decoder_2, decoder_3, autoencoder = trainer.train_three_secret_65_filters(
            batch_size=_batch_size,
            epochs=_epochs,
            path=save_path,
            shape=sample_shape,
            decode_loss=decode_loss,
            total_loss=total_loss3,
            secret1_input=_sec1_input,
            secret2_input=_sec2_input,
            secret3_input=_sec3_input,
            cover_input=_cov_input,
            verbose=verbose,
            save_interval=save_interval,
            activation=activation,
            filter1=filter1,
            filter2=filter2,
            filter3=filter3,
        )
        elapsed = round(time.time() - started_at)

        # Short runs are reported in whole seconds, longer ones in (fractional) minutes.
        if elapsed <= 60:
            print(f"Model Finished Training in: {elapsed} s")
        else:
            print(f"Model Finished Training in: {elapsed / 60} m")

        return encoder, decoder_1, decoder_2, decoder_3, autoencoder


    def train_model(epochs, activation_function, batch_size, filters, _beta):
        """Train the three-secret model once and save its evaluation artifacts.

        Loads the images under ``train_path``, splits them into three secret
        sets and one cover set, trains through ``process`` and then writes
        into the run directory: the loss curve (``loss.png``), a grid of
        input/output images (``image_comparison.png``) and pickles of the
        predictions, absolute difference maps and per-pixel errors.

        Args:
            epochs: Number of training epochs.
            activation_function: Activation name forwarded to the trainer.
            batch_size: Training batch size.
            filters: Indexable of filter counts; items 0-2 are forwarded as
                ``filter1``..``filter3`` and the sum of all items names the run.
            _beta: Beta value used in the run directory name and plot labels.
                NOTE(review): the loss itself uses the module-level ``beta``
                read by ``decode_loss``; keep both in sync.

        Returns:
            None.
        """
        total_filters = sum(filters)
        f1, f2, f3 = filters[0], filters[1], filters[2]

        save_path = f"model-data/3SEC_{total_filters}F_{batch_size}BS_{epochs}EP_{activation_function}_BETA{_beta}/"
        dataset_local = TinyImageDataset(path=train_path, num_classes=25, normalize=True)
        X_train_local = dataset_local.load_data()

        # Four disjoint, consecutive 1250-image slices.
        # NOTE(review): assumes the loader returns at least 5000 images - confirm.
        sec1_input_local = X_train_local[0:1250]
        sec2_input_local = X_train_local[1250:2500]
        sec3_input_local = X_train_local[2500:3750]
        cov_input_local = X_train_local[3750:5000]

        # Only the full autoencoder is needed for the evaluation below.
        _, _, _, _, autoencoder_model = process(
            _batch_size=batch_size,
            _epochs=epochs,
            save_path=save_path,
            save_interval=1,
            activation=activation_function,
            filter1=f1,
            filter2=f2,
            filter3=f3,
            verbose=1,
            _sec1_input=sec1_input_local,
            _sec2_input=sec2_input_local,
            _sec3_input=sec3_input_local,
            _cov_input=cov_input_local
        )

        def run_loss_history():
            """Plot the pickled loss history and save it as ``loss.png``."""
            # NOTE(review): presumably written by the trainer into save_path - confirm.
            # The file is produced locally, so unpickling it is not an untrusted-input risk.
            with open(save_path + "loss_history.pckl", "rb") as f:
                loss_history = pickle.load(f)

            fig = plt.figure()
            plt.plot(loss_history)
            plt.title(f'Loss Curve For: 3X: {total_filters}F_{batch_size}BS_{epochs}EP_{activation_function}_{_beta}')
            plt.ylabel('Loss')
            plt.xlabel('Epoch')

            # The patches only serve as legend entries describing the run.
            epoch_patch = mpatches.Patch(color='blue', label=f'Total Epochs: {epochs}')
            beta_patch = mpatches.Patch(color='blue', label=f'Beta: {_beta}')
            batch_patch = mpatches.Patch(color='blue', label=f'Batch Size: {batch_size}')
            act_patch = mpatches.Patch(color='blue', label=f'Activation Function: {activation_function}')
            plt.legend(handles=[beta_patch, epoch_patch, batch_patch, act_patch], loc="upper right")

            plt.savefig(f"{save_path}loss.png")
            # Close the figure that was actually drawn so repeated runs do not accumulate open figures.
            plt.close(fig)

        run_loss_history()

        revealed = autoencoder_model.predict([sec1_input_local, sec2_input_local, sec3_input_local, cov_input_local])
        # The prediction's last axis is read as secrets 1-3 followed by the cover, 3 channels each.
        revealed_S1 = revealed[..., 0:3]
        revealed_S2 = revealed[..., 3:6]
        revealed_S3 = revealed[..., 6:9]
        revealed_C = revealed[..., 9:12]

        def pixel_rmse(original, reconstructed):
            """Root-mean-square error after scaling the difference by 255."""
            return np.sqrt(np.mean(np.square(255 * (original - reconstructed))))

        S1_error_outputs = pixel_rmse(sec1_input_local, revealed_S1)
        S2_error_outputs = pixel_rmse(sec2_input_local, revealed_S2)
        S3_error_outputs = pixel_rmse(sec3_input_local, revealed_S3)
        C_error_outputs = pixel_rmse(cov_input_local, revealed_C)

        # Element-wise absolute differences, each output against its own input.
        distance_error_outputsS1 = np.abs(revealed_S1 - sec1_input_local)
        distance_error_outputsS2 = np.abs(revealed_S2 - sec2_input_local)
        distance_error_outputsS3 = np.abs(revealed_S3 - sec3_input_local)
        distance_error_outputsC = np.abs(revealed_C - cov_input_local)

        def show_image_results():
            """Save an 8x8 grid comparing inputs with the model's outputs."""
            num_imgs = 8
            n_columnx = 8
            # Sample indices are drawn from 0..375 inclusive (may repeat).
            random_index = [random.randint(0, 375) for _ in range(num_imgs)]
            fig = plt.figure(figsize=(11, 12))

            # One grid column per image source, in display order.
            columns = [
                ('Cover', cov_input_local),
                ('Secret1', sec1_input_local),
                ('Secret2', sec2_input_local),
                ('Secret3', sec3_input_local),
                ('Cover*', revealed_C),
                ('revealed1', revealed_S1),
                ('revealed2', revealed_S2),
                ('revealed3', revealed_S3),
            ]

            for row, idx in enumerate(random_index):
                for col, (title, images) in enumerate(columns, start=1):
                    plt.subplot(num_imgs, n_columnx, row * n_columnx + col)
                    plt.imshow(images[idx])
                    plt.axis("off")
                    # Column titles appear on the first row only.
                    if row == 0:
                        plt.title(title)

            plt.savefig(f"{save_path}image_comparison.png")
            plt.close(fig)

        show_image_results()

        artifacts = {
            "revealed_secret1.pckl": revealed_S1,
            "revealed_secret2.pckl": revealed_S2,
            "revealed_secret3.pckl": revealed_S3,
            "revealed_cover.pckl": revealed_C,
            "secret1_diff.pckl": distance_error_outputsS1,
            "secret2_diff.pckl": distance_error_outputsS2,
            "secret3_diff.pckl": distance_error_outputsS3,
            "cover_diff.pckl": distance_error_outputsC,
            "secret1_pixel_error_outputs.pckl": S1_error_outputs,
            "secret2_pixel_error_outputs.pckl": S2_error_outputs,
            "secret3_pixel_error_outputs.pckl": S3_error_outputs,
            "cover_pixel_error_outputs.pckl": C_error_outputs,
        }
        for filename, payload in artifacts.items():
            pickle_file(save_path, filename, payload)

        print(f"Model Saved at: {save_path}")

        print("error_outputs per pixel - distance from original RGB")
        print(f"S1 Pixel error_outputs: {S1_error_outputs}")
        print(f"S2 Pixel error_outputs: {S2_error_outputs}")
        print(f"S3 Pixel error_outputs: {S3_error_outputs}")
        print(f"C Pixel error_outputs: {C_error_outputs}")


    # Candidate activation names and beta weights; this run uses the first entry of each.
    actives = ["relu", "selu", "gelu", "swish"]
    betas = [0.25, 0.5, 0.75, 1.0]
    # `decode_loss` reads this module-level `beta` when the loss is evaluated,
    # so it has to be assigned before training starts.
    beta = betas[0]
    # Run the training with the explicitly created graph as the default graph.
    with graph.as_default():
        train_model(epochs=1, activation_function=actives[0], batch_size=64, filters=(1, 1, 1), _beta=beta)