In [58]:
import os
import cv2
import pickle
import time
import random
from PIL import Image
import numpy as np
import tensorflow as tf
from numpy.random import rand, randn, randint
from numpy import ones, zeros
from numpy import vstack
from tensorflow import keras
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt

from keras.layers import Input, Conv2D, Conv2DTranspose, MaxPool2D, Flatten, Dense, Concatenate, Dropout, BatchNormalization, LeakyReLU, Activation
from keras.models import Sequential, Model
from keras.utils.vis_utils import plot_model
from tensorflow.keras.optimizers import Adam, RMSprop, Nadam
from keras.layers import Input, BatchNormalization
from keras.initializers.initializers_v1 import RandomNormal
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

In [59]:
CN_Layers = [
    # 1
    {
        'type': 'conv',
        'kernal': 5,
        'filters': 64,
        'dilation': (1, 1),
        'stride': (1, 1),
        'activation': 'relu'
    },
    # 2
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 128,
        'dilation': (1, 1),
        'stride': (2, 2),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 128,
        'dilation': (1, 1),
        'stride': (1, 1),
        'activation': 'relu'
    },
    # 3
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 256,
        'dilation': (1, 1),
        'stride': (2, 2),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 256,
        'dilation': (1, 1),
        'stride': (1, 1),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 256,
        'dilation': (1, 1),
        'stride': (1, 1),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 256,
        'dilation': (2, 2),
        'stride': (1, 1),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 256,
        'dilation': (4, 4),
        'stride': (1, 1),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 256,
        'dilation': (8, 8),
        'stride': (1, 1),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 256,
        'dilation': (16, 16),
        'stride': (1, 1),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 256,
        'dilation': (1, 1),
        'stride': (1, 1),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 256,
        'dilation': (1, 1),
        'stride': (1, 1),
        'activation': 'relu'
    },
    # 4
    {
        'type': 'de_conv',
        'kernal': 4,
        'filters': 128,
        'dilation': (1, 1),
        'stride': (2, 2),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 128,
        'dilation': (1, 1),
        'stride': (1, 1),
        'activation': 'relu'
    },
    # 5
    {
        'type': 'de_conv',
        'kernal': 4,
        'filters': 64,
        'dilation': (1, 1),
        'stride': (2, 2),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 32,
        'dilation': (1, 1),
        'stride': (1, 1),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 3,
        'dilation': (1, 1),
        'stride': (1, 1),
        'activation': 'sigmoid'
    }
]
CN_Layers_Reduced = [
    # 1
    {
        'type': 'conv',
        'kernal': 5,
        'filters': 32,
        'dilation': (1, 1),
        'stride': (1, 1),
        'activation': 'relu'
    },
    # 2
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 64,
        'dilation': (1, 1),
        'stride': (2, 2),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 64,
        'dilation': (1, 1),
        'stride': (1, 1),
        'activation': 'relu'
    },
    # 3
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 128,
        'dilation': (1, 1),
        'stride': (2, 2),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 128,
        'dilation': (1, 1),
        'stride': (1, 1),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 128,
        'dilation': (1, 1),
        'stride': (1, 1),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 128,
        'dilation': (2, 2),
        'stride': (1, 1),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 128,
        'dilation': (4, 4),
        'stride': (1, 1),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 128,
        'dilation': (8, 8),
        'stride': (1, 1),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 128,
        'dilation': (16, 16),
        'stride': (1, 1),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 128,
        'dilation': (1, 1),
        'stride': (1, 1),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 128,
        'dilation': (1, 1),
        'stride': (1, 1),
        'activation': 'relu'
    },
    # 4
    {
        'type': 'de_conv',
        'kernal': 4,
        'filters': 64,
        'dilation': (1, 1),
        'stride': (2, 2),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 64,
        'dilation': (1, 1),
        'stride': (1, 1),
        'activation': 'relu'
    },
    # 5
    {
        'type': 'de_conv',
        'kernal': 4,
        'filters': 32,
        'dilation': (1, 1),
        'stride': (2, 2),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 16,
        'dilation': (1, 1),
        'stride': (1, 1),
        'activation': 'relu'
    },
    {
        'type': 'conv',
        'kernal': 3,
        'filters': 3,
        'dilation': (1, 1),
        'stride': (1, 1),
        'activation': 'sigmoid'
    }
]
IMG_DIM = 128


def Encoder_Module(input_x, filters, kernal, activation, stride, dilation):
    encoding = Conv2D(filters, kernal, stride, padding='same', dilation_rate=dilation)(input_x)
    if activation == 'relu':
        encoding = BatchNormalization(momentum=0.8)(encoding, training=True)
    encoding = Activation(activation)(encoding)
    return encoding


def Decoder_Module(input_x, filters, kernal, activation, stride, dilation):
    decoded = Conv2DTranspose(filters, kernal, stride, padding='same', dilation_rate=dilation)(input_x)
    if activation == 'relu':
        decoded = BatchNormalization(momentum=0.8)(decoded, training=True)
    decoded = Activation(activation)(decoded)
    return decoded


def ContentNetwork():
    ContextNetwork_input_layer = Input(shape=(IMG_DIM, IMG_DIM, 3), )

    ContextNetwork = ContextNetwork_input_layer

    for module in CN_Layers:
        if module['type'] == 'conv':
            ContextNetwork = Encoder_Module(ContextNetwork, module['filters'], module['kernal'], module['activation'],
                                            module['stride'], module['dilation'])
        else:
            ContextNetwork = Decoder_Module(ContextNetwork, module['filters'], module['kernal'], module['activation'],
                                            module['stride'], module['dilation'])

    tmp_model = Model(inputs=ContextNetwork_input_layer, outputs=ContextNetwork)

    return tmp_model


def ContentNetwork2():
    ContextNetwork_input_layer = Input(shape=(IMG_DIM, IMG_DIM, 3), )

    ContextNetwork = ContextNetwork_input_layer

    for module in CN_Layers_Reduced:
        if module['type'] == 'conv':
            ContextNetwork = Encoder_Module(ContextNetwork, module['filters'], module['kernal'], module['activation'],
                                            module['stride'], module['dilation'])
        else:
            ContextNetwork = Decoder_Module(ContextNetwork, module['filters'], module['kernal'], module['activation'],
                                            module['stride'], module['dilation'])

    tmp_model = Model(inputs=ContextNetwork_input_layer, outputs=ContextNetwork)

    return tmp_model


def SC_Decoder_Module(input_x, filters, kernal, activation, stride, dilation, sc=None):
    decoded = Conv2DTranspose(filters, kernal, stride, padding='same', dilation_rate=dilation)(input_x)
    if activation == 'relu':
        decoded = BatchNormalization(momentum=0.8)(decoded, training=True)
    decoded = Activation(activation)(decoded)

    if sc is not None:
        decoded = Concatenate()([decoded, sc])
    return decoded


def skipConnections(input_dim):
    ContextNetwork_input_layer = Input(shape=(input_dim, input_dim, 3), )

    EM1 = Encoder_Module(input_x=ContextNetwork_input_layer, filters=64, kernal=5, activation='relu', stride=1,
                         dilation=1)
    EM2 = Encoder_Module(input_x=EM1, filters=128, kernal=3, activation='relu', stride=2,
                         dilation=1)
    EM3 = Encoder_Module(input_x=EM2, filters=128, kernal=3, activation='relu', stride=1,
                         dilation=1)
    EM4 = Encoder_Module(input_x=EM3, filters=256, kernal=3, activation='relu', stride=2,
                         dilation=1)
    EM5 = Encoder_Module(input_x=EM4, filters=256, kernal=3, activation='relu', stride=1,
                         dilation=1)
    EM6 = Encoder_Module(input_x=EM5, filters=256, kernal=3, activation='relu', stride=1,
                         dilation=1)
    EM7 = Encoder_Module(input_x=EM6, filters=256, kernal=3, activation='relu', stride=1,
                         dilation=2)
    EM8 = Encoder_Module(input_x=EM7, filters=256, kernal=3, activation='relu', stride=1,
                         dilation=4)
    EM9 = Encoder_Module(input_x=EM8, filters=256, kernal=3, activation='relu', stride=1,
                         dilation=8)
    EM10 = Encoder_Module(input_x=EM9, filters=256, kernal=3, activation='relu', stride=1,
                          dilation=16)
    EM11 = Encoder_Module(input_x=EM10, filters=256, kernal=3, activation='relu', stride=1,
                          dilation=1)
    EM12 = Encoder_Module(input_x=EM11, filters=256, kernal=3, activation='relu', stride=1,
                          dilation=1)
    DM1 = SC_Decoder_Module(input_x=EM12, filters=128, kernal=4, activation='relu', stride=2, dilation=1, sc=EM3)
    EM13 = Encoder_Module(input_x=DM1, filters=128, kernal=3, activation='relu', stride=1, dilation=1)
    DM2 = SC_Decoder_Module(input_x=EM13, filters=64, kernal=4, activation='relu', stride=2, dilation=1, sc=EM1)
    EM14 = Encoder_Module(input_x=DM2, filters=32, kernal=3, activation='relu', stride=1, dilation=1)
    EM15 = Encoder_Module(input_x=EM14, filters=3, kernal=3, activation='tanh', stride=1, dilation=1)

    model = Model(inputs=ContextNetwork_input_layer, outputs=EM15, name='Generator')

    return model


In [60]:
OUTPUT_CHANNELS = 3
input = 128

GLOBAL_IMG_DIM = 128
GLOBAL_MODEL_FILTERS = [32, 64, 128]  # [64, 128, 256]  # , 512, 512]  # , 512]

LOCAL_IMG_DIM = 64
LOCAL_MODEL_FILTERS = [64, 128, 256, 512]  # , 512]


# Define the downsample function for the generator model
def downsample(filters, size, apply_batchnorm=True):
    # Weight initialization
    initializer = tf.random_normal_initializer(0., 0.02)

    # Define the model
    model = tf.keras.Sequential()
    model.add(
        tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                               kernel_initializer=initializer, use_bias=False))
  
    # Apply batch normalization if required and add leaky ReLU activation function to the model
    if apply_batchnorm:
        model.add(tf.keras.layers.BatchNormalization())

    model.add(tf.keras.layers.LeakyReLU())

    # Return the model
    return model

# Define the upsample function for the generator model
def upsample(filters, size, apply_dropout=False):
    # Weight initialization
    initializer = tf.random_normal_initializer(0., 0.02)

    # Define the model
    model = tf.keras.Sequential()
    model.add(
        tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                        padding='same',
                                        kernel_initializer=initializer,
                                        use_bias=False))
  
    # Apply batch normalization if required and add ReLU activation function to the model
    model.add(tf.keras.layers.BatchNormalization())

    # Apply dropout if required
    if apply_dropout:
        model.add(tf.keras.layers.Dropout(0.5))

    model.add(tf.keras.layers.ReLU())

    # Return the model
    return model

# Generator network
def generator_model(in_shape=(256,256,3)):
    inputs = tf.keras.layers.Input(in_shape)
    
    down_stack = [
        downsample(64, 4, apply_batchnorm=False),  # (batch_size, 128, 128, 64)
        downsample(128, 4),  # (batch_size, 64, 64, 128)
        downsample(256, 4),  # (batch_size, 32, 32, 256)
        downsample(512, 4),  # (batch_size, 16, 16, 512)
        downsample(512, 4),  # (batch_size, 8, 8, 512)
        downsample(512, 4),  # (batch_size, 4, 4, 512)
        downsample(512, 4),  # (batch_size, 2, 2, 512)
        downsample(512, 4),  # (batch_size, 1, 1, 512)
    ]
    
    up_stack = [
        upsample(512, 4, apply_dropout=True),  # (batch_size, 2, 2, 1024)
        upsample(512, 4, apply_dropout=True),  # (batch_size, 4, 4, 1024)
        upsample(512, 4, apply_dropout=True),  # (batch_size, 8, 8, 1024)
        upsample(512, 4),  # (batch_size, 16, 16, 1024)
        upsample(256, 4),  # (batch_size, 32, 32, 512)
        upsample(128, 4),  # (batch_size, 64, 64, 256)
        upsample(64, 4),  # (batch_size, 128, 128, 128)
    ]
    
    initializer = tf.random_normal_initializer(0., 0.02)
    last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='tanh')  # (batch_size, 256, 256, 3)
    
    x = inputs
    
    # Downsampling through the model
    skips = []
    
    for down in down_stack:
        x = down(x)
        skips.append(x)
        
    skips = reversed(skips[:-1])
    
    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])
        
    x = last(x)
    
    return tf.keras.Model(inputs=inputs, outputs=x)

def Discriminator(LocalGlobal):
    if LocalGlobal == 'Local':
        IMG_DIM = LOCAL_IMG_DIM
        FILTERS = LOCAL_MODEL_FILTERS
    else:
        IMG_DIM = GLOBAL_IMG_DIM
        FILTERS = GLOBAL_MODEL_FILTERS

    def conv2d(xin, filters, count):
        layer = Conv2D(filters, 5, (2, 2), padding='same')(xin)
        layer = LeakyReLU(alpha=0.2)(layer)
        if filters != 64:
            layer = BatchNormalization(momentum=0.8)(layer)

        return layer

    input_layer = Input(shape=(IMG_DIM, IMG_DIM, 3))

    x = input_layer

    for FILTERS, count in zip(FILTERS, range(len(FILTERS))):
        x = conv2d(x, FILTERS, count)

    x = Dropout(0.4)(x)
    x = Flatten()(x)
    x = Dense(128, activation='relu')(x)
    x = Dense(1, activation='sigmoid')(x)
    DiscriminatorModel = Model(inputs=input_layer, outputs=x)

    return DiscriminatorModel


def Discriminator2(LocalGlobal):
    if LocalGlobal == 'Local':
        IMG_DIM = LOCAL_IMG_DIM
        FILTERS = LOCAL_MODEL_FILTERS
    else:
        IMG_DIM = GLOBAL_IMG_DIM
        FILTERS = GLOBAL_MODEL_FILTERS

    def conv2d(xin, filters):
        layer = Conv2D(filters, 5, (2, 2), padding='same')(xin)
        if filters != 32:
            layer = BatchNormalization(momentum=0.8)(layer)
        layer = LeakyReLU(alpha=0.2)(layer)
        return layer

    input_layer = Input(shape=(IMG_DIM, IMG_DIM, 3))

    x = input_layer

    for FILTERS in FILTERS:
        x = conv2d(x, FILTERS)

    x = Dropout(0.4)(x)
    x = Flatten()(x)
    x = Dense(1, activation='sigmoid')(x)
    DiscriminatorModel2 = Model(inputs=input_layer, outputs=x)

    return DiscriminatorModel2


# Define the global discriminator model
def global_discriminator(in_shape=(256, 256, 3)):
    init = RandomNormal(mean=0.0, stddev=0.02)
    global_disc_input = Input(in_shape)
    
    x = Conv2D(32, 5, padding='same', input_shape=in_shape, strides=(2, 2), kernel_initializer=init)\
        (global_disc_input)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = Conv2D(64, 5, padding='same', strides=(2, 2), kernel_initializer=init)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)

    x = Conv2D(128, 5, padding='same', strides=(2, 2), kernel_initializer=init)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = Conv2D(256, 5, padding='same', strides=(2, 2), kernel_initializer=init)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = Conv2D(256, 5, padding='same', strides=(2, 2), kernel_initializer=init)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = Flatten()(x)
    x = Dense(512,activation='relu')(x)
    global_disc_output = Dense(1, activation='sigmoid', kernel_initializer=init)(x)
    
    model = Model(global_disc_input, outputs=global_disc_output)
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model


def multi_patch_discriminator(in_shape=(256, 256, 3)):
    # weight initialization
    init = RandomNormal(mean=0.0, stddev=0.02)
    # source image input
    in_src_image = Input(shape=in_shape)
    # C64
    d = Conv2D(32, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(in_src_image)
    d = LeakyReLU(alpha=0.2)(d)
    # C128
    d = Conv2D(64, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C256
    d = Conv2D(128, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C512
    d = Conv2D(256, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # second last output layer
    d = Conv2D(256, (4, 4), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # patch output
    d = Conv2D(1, (4, 4), padding='same', kernel_initializer=init)(d)
    patch_out = Activation('sigmoid')(d)
    # define model
    model = Model(in_src_image, patch_out)
    return model

In [61]:
def define_gan(g_model, dpatch_model, dglobal_model, in_shape=(256,256,3)):
    # make weights in the discriminator not trainable
    for layer in dpatch_model.layers:
        if not isinstance(layer, BatchNormalization):
            layer.trainable = False

    for layer in dglobal_model.layers:
        if not isinstance(layer, BatchNormalization):
            layer.trainable = False

    # define the source image
    in_src = Input(shape=in_shape)
    
    # connect the source image to the generator input
    gen_out = g_model(in_src)
    
    # connect the source input and generator output to the discriminator input
    dispatch_out = dpatch_model(gen_out)
    disglobal_out = dglobal_model(gen_out)
    
    # src image as input, generated image and classification output
    model = Model(in_src, [dispatch_out,disglobal_out, gen_out])
    
    # compile model
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss=['binary_crossentropy','binary_crossentropy','mse'], optimizer=opt, loss_weights=[1,1,1000])
    return model

In [63]:
# Create generator network
g_model = skipConnections(256)

# Create discriminator networks
dpatch_model = multi_patch_discriminator(in_shape=(256,256,3))
dglobal_model = global_discriminator(in_shape=(256,256,3))

opt = Adam(learning_rate=1e-4, beta_1=0.5)
dpatch_model.compile(loss='binary_crossentropy', optimizer=opt)
dglobal_model.compile(loss='binary_crossentropy', optimizer=opt)


GAN1 = define_gan(g_model, dpatch_model, dglobal_model, in_shape=(256,256,3))
g_model.summary()
GAN1.summary()

Model: "Generator"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_11 (InputLayer)          [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_65 (Conv2D)             (None, 256, 256, 64  4864        ['input_11[0][0]']               
                                )                                                                 
                                                                                                  
 batch_normalization_74 (BatchN  (None, 256, 256, 64  256        ['conv2d_65[0][0]']              
 ormalization)                  )                                                         

                                                                                                  
 batch_normalization_84 (BatchN  (None, 64, 64, 256)  1024       ['conv2d_75[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_41 (Activation)     (None, 64, 64, 256)  0           ['batch_normalization_84[0][0]'] 
                                                                                                  
 conv2d_76 (Conv2D)             (None, 64, 64, 256)  590080      ['activation_41[0][0]']          
                                                                                                  
 batch_normalization_85 (BatchN  (None, 64, 64, 256)  1024       ['conv2d_76[0][0]']              
 ormalization)                                                                                    
          

In [45]:
# Create instances of the generator, discriminator, and multi-patch discriminator
g_model = generator_model(in_shape=(256,256,3))
dpatch_model = multi_patch_discriminator(in_shape=(256,256,3))
dglobal_model = global_discriminator(in_shape=(256,256,3))

opt = Adam(learning_rate=1e-4, beta_1=0.5)
dpatch_model.compile(loss='binary_crossentropy', optimizer=opt)
dglobal_model.compile(loss='binary_crossentropy', optimizer=opt)

GAN = define_gan(g_model, dpatch_model, dglobal_model, in_shape=(256,256,3))

print("\nGAN Summary:")
GAN.summary()


GAN Summary:
Model: "model_7"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_8 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 model_4 (Functional)           (None, 256, 256, 3)  54425859    ['input_8[0][0]']                
                                                                                                  
 model_5 (Functional)           (None, 16, 16, 1)    1745889     ['model_4[0][0]']                
                                                                                                  
 model_6 (Functional)           (None, 1)            11109313    ['model_4[0][

### Generating real and fake samples

Real:
Chooses a random selection of images from the unmasked images of a specified batch size, returns the images as well as the index of the images so that this can be used to select the corresponding masked images

Fake:
Using the same indexes as the real images, the corresponding fake images are passed through the GAN to generate outputs

In [64]:
def generate_real_samples(orig_dir, masked_dir, dataset_list, batch_size):
    X = dataset_list
    # Declare arrays
    original_images=[]
    damaged_images=[]
    
    # Pick random samples
    ix= random.sample(X, batch_size)
    for i in ix:
        # Read in original images
        image_real = cv2.imread(os.path.join(orig_dir, i))
        original_images.append(image_real)
        # Read in damaged versions
        image_damaged=cv2.imread(os.path.join(masked_dir, i))
        damaged_images.append(image_damaged)
    
    original_images=np.asarray(original_images)
    original_images=(original_images - 127.5) / 127.5
    
    damaged_images=np.asarray(damaged_images)
    damaged_images=(damaged_images - 127.5) / 127.5

    return original_images, damaged_images


In [65]:
def saveGeneratedSamples(original_images, recon_images, groundtruth_images, epoch):
    if not os.path.exists('Images'):
        os.mkdir('Images')

    n = 5
    plt.figure(figsize=(12, 6))
    for i in range(0, n):
        # Display original
        ax = plt.subplot(3, n, i + 1)
        plt.imshow(((original_images + 1) / 2)[i + 5][:, :, ::-1])
#         plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        # Title
        ax.title.set_text('Original Images')

        # Display reconstruction
        ax = plt.subplot(3, n, i + 1 + 5)
        plt.imshow(((recon_images + 1) / 2)[i + 5][:, :, ::-1])
#         plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        # Title
        ax.title.set_text('Generated Images')

        # Display Ground Truth
        ax = plt.subplot(3, n, i + 1 + 10)
        plt.imshow(((groundtruth_images + 1) / 2)[i + 5][:, :, ::-1])
#         plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        # Title
        ax.title.set_text('Ground Truth Images')
    ax.set_facecolor("white")
    image_format = 'png' # e.g .png, .svg, etc.
    filename = f"Images/epoch_{epoch}_images.png"
    plt.savefig(filename, format=image_format)
    plt.close()

In [66]:
def plot_confusion_matrix(Xin, Yin, DN):

    CM_test_x = Xin

    y_pred = DN.predict(CM_test_x)
    target_names = ["iceberg", "iceberg_masked"]

    cm = confusion_matrix(Yin, y_pred.round())
    print('Confusion Matrix')
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=target_names)

    disp.plot(cmap=plt.cm.Blues)
    plt.grid(False)
    plt.show()

    print('Classification Report')
    print(classification_report(Yin, y_pred.round(), target_names=target_names))

In [67]:
 def plot_discriminator(dlossreal, dlossfake, dlossreal2, dlossfake2, steps):
    plt.plot(steps, dlossreal)
    plt.plot(steps, dlossfake)
    plt.plot(steps, dlossreal2)
    plt.plot(steps, dlossfake2)
    plt.title('Discriminator Losses')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend(['Discriminator train real loss', 'Discriminator train fake loss','Discriminator train real loss2', 'Discriminator train fake loss2'], loc='bottom left')
    filename = 'Graphs/discriminator_loss_graph_testing.png'
    plt.savefig(filename)
    plt.close()
    
def plot_GAN(glossbce, glossbce2, steps):
    plt.plot(steps, glossbce)
    plt.plot(steps, glossbce2)
    plt.title('GAN Losses')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend(['GAN train BCE loss', 'GAN train BCE loss2'], loc='bottom left')
    filename = 'Graphs/GAN_loss_graph_testing.png'
    plt.savefig(filename)
    plt.close()

### Model Training

In [68]:
data_dir = 'C:/Users/chris/2023-mcm-master/src/data/masked_images_split/train'
#image_files = data_dir + '/**/*_masked.png'

orig_data_dir = 'C:/Users/chris/2023-mcm-master/src/data/dataset_split/train'
#orig_image_files = orig_data_dir + '/**/*.png'

# Define the ImageDataGenerator for training data
train_datagen = ImageDataGenerator(rescale=1./255)

# Load and preprocess training data
train_masked = train_datagen.flow_from_directory(
    data_dir,  # Pass the directory path, not the file paths
    target_size=(256, 256),  # Specify the target size of the images
    batch_size=32,
    class_mode='input',  # Use 'input' for input modeling
    shuffle=True
)

# Load and preprocess training data from orig_data_dir
train_orig = train_datagen.flow_from_directory(
    orig_data_dir,  # Pass the directory path, not the file paths
    target_size=(256, 256),
    batch_size=32,
    class_mode='input',
    shuffle=True
)

Found 7413 images belonging to 22 classes.
Found 7686 images belonging to 22 classes.


In [69]:
# Custom generator that yields both masked and original images
def combined_generator(masked_generator, orig_generator):
    while True:
        masked_images, _ = masked_generator.next()
        orig_images, _ = orig_generator.next()
        yield masked_images, orig_images

# Create the combined generator
combined_train_cgan = combined_generator(train_masked, train_orig)

In [70]:
from tqdm import tqdm

def train(dpatch_model, dglobal_model, g_model, gan_model,
          dataset_list, dataset_test_list,
          orig_dir, masked_dir,
          n_epochs=200, n_batch=32,n_batch_test=32): 
    
    bat_per_epoch = int(len(dataset_list) / n_batch)
    half_batch = int(n_batch / 2)

    # Ground Truths
    y_real1 = ones((n_batch,16,16, 1))
    y_real2 = ones((n_batch, 1))
    y_fake1 = zeros((n_batch,16,16, 1))
    y_fake2 = zeros((n_batch, 1))
    
    # prepare lists for storing stats each iteration
    dlossreal_epoch=[]
    dlossfake_epoch=[]
    dlossreal2_epoch=[]
    dlossfake2_epoch=[]
    glossbce_epoch=[]
    glossbce2_epoch=[]
    glossmae_epoch=[]
    steps=[]
    # manually enumerate epochs
    for i in range(n_epochs):
        start_time = time.time()
        dlossreal=[]
        dlossfake=[]
        dlossreal2=[]
        dlossfake2=[]
        glossbce=[]
        glossbce2=[]
        glossmae=[]
        print('>Epoch: %d' % (i+1))
            
        for batch in tqdm(range(bat_per_epoch)):
            # Generate real images and select damaged images
            X_real, damaged_images = generate_real_samples(orig_dir, masked_dir, dataset_list, n_batch)
            
            resized_images = []
            
            for image in damaged_images:
                # Convert the data type to uint8
                image = (image * 255).astype(np.uint8)
                resized_image = Image.fromarray(image).resize((256, 256))
                resized_images.append(resized_image)
                
            # Convert the resized images back to numpy array
            damaged_images_resized = np.array([np.array(image) for image in resized_images])

            # Generate fake images
            X_fake = g_model.predict(damaged_images)
            
            # update discriminator for real samples
            d_loss_real = dpatch_model.train_on_batch(X_real, y_real1)
            d_loss_real2 = dglobal_model.train_on_batch(X_real,y_real2)            
            dlossreal.append(d_loss_real)
            dlossreal2.append(d_loss_real2)
            
            # update discriminator for generated samples
            d_loss_fake = dpatch_model.train_on_batch(X_fake, y_fake1)
            d_loss_fake2 = dglobal_model.train_on_batch(X_fake,y_fake2)
            dlossfake.append(d_loss_fake)
            dlossfake2.append(d_loss_fake2)
   
            # Update Generator weights
            gloss_all, g_loss_BCE, g_loss_BCE2, g_loss_mae = gan_model.train_on_batch(damaged_images, [y_real1, y_real2, X_real])
            glossbce.append(g_loss_BCE)
            glossbce2.append(g_loss_BCE2)
            glossmae.append(g_loss_mae)
            
        # record history
        dlossreal_epoch.append(np.mean(dlossreal))
        dlossfake_epoch.append(np.mean(dlossfake))
        dlossreal2_epoch.append(np.mean(dlossreal2))
        dlossfake2_epoch.append(np.mean(dlossfake2))
        
        glossbce_epoch.append(sum(glossbce)/len(glossbce))
        glossbce2_epoch.append(sum(glossbce2)/len(glossbce2))
        glossmae_epoch.append(sum(glossmae)/len(glossmae))
        steps.append(i)
        finish_time = time.time()    
        print('d_real[%.5f] d_fake[%.5f] g_BCE[%.5f] d_real2[%.5f] d_fake2[%.5f] g_BCE2[%.5f] g_mse[%.5f]' % (dlossreal_epoch[-1], dlossfake_epoch[-1], glossbce_epoch[-1],dlossreal2_epoch[-1], dlossfake2_epoch[-1], glossbce2_epoch[-1],glossmae_epoch[-1]))        
        print('Time for epoch: %.0f sec' % (finish_time-start_time))
        if (i+1)%10 == 0:
            saveGeneratedSamples(damaged_images, X_fake, X_real, i+1)
            plot_discriminator(dlossreal_epoch, dlossfake_epoch, dlossreal2_epoch, dlossfake2_epoch, steps)
            plot_GAN(glossbce_epoch, glossbce2_epoch, steps)
    return dlossreal_epoch,dlossfake_epoch,glossbce_epoch,glossmae_epoch,dlossreal2_epoch,dlossfake2_epoch,glossbce2_epoch

In [71]:
masked_dir = 'C:/Users/chris/2023-mcm-master/src/data/masked_images_split/train'
orig_dir = 'C:/Users/chris/2023-mcm-master/src/data/dataset_split/train'

datasetList = []  # Define the datasetList variable

for root, _, files in os.walk(orig_dir):
    for file in files:
        if file.endswith('.jpg') or file.endswith('.png'):  # Add more valid image extensions if needed
            file_path = os.path.join(root, file)
            datasetList.append(file_path)

In [72]:
results=train(dpatch_model, dglobal_model, g_model, GAN, datasetList, [], 
              orig_dir, masked_dir, n_epochs=10, n_batch=16)

>Epoch: 1


  0%|          | 0/480 [00:01<?, ?it/s]


ValueError: in user code:

    File "C:\Users\chris\anaconda3\envs\snowflakes\lib\site-packages\keras\engine\training.py", line 1801, in predict_function  *
        return step_function(self, iterator)
    File "C:\Users\chris\anaconda3\envs\snowflakes\lib\site-packages\keras\engine\training.py", line 1790, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:\Users\chris\anaconda3\envs\snowflakes\lib\site-packages\keras\engine\training.py", line 1783, in run_step  **
        outputs = model.predict_step(data)
    File "C:\Users\chris\anaconda3\envs\snowflakes\lib\site-packages\keras\engine\training.py", line 1751, in predict_step
        return self(x, training=False)
    File "C:\Users\chris\anaconda3\envs\snowflakes\lib\site-packages\keras\utils\traceback_utils.py", line 67, in error_handler
        raise e.with_traceback(filtered_tb) from None
    File "C:\Users\chris\anaconda3\envs\snowflakes\lib\site-packages\keras\engine\input_spec.py", line 264, in assert_input_compatibility
        raise ValueError(f'Input {input_index} of layer "{layer_name}" is '

    ValueError: Input 0 of layer "Generator" is incompatible with the layer: expected shape=(None, 256, 256, 3), found shape=(None, 1024, 1024, 3)


In [41]:
"""
# Calculate the number of steps per epoch
steps_per_epoch = min(len(train_masked), len(train_orig))

# Create an instance of the GAN model
#gan = gan_model(generator_model(), discriminator_model(), multi_patch_discriminator())

# Train the GAN model
gan.fit(combined_train_cgan, 
        epochs=50,
        verbose=1)
"""

'\n# Calculate the number of steps per epoch\nsteps_per_epoch = min(len(train_masked), len(train_orig))\n\n# Create an instance of the GAN model\n#gan = gan_model(generator_model(), discriminator_model(), multi_patch_discriminator())\n\n# Train the GAN model\ngan.fit(combined_train_cgan, \n        epochs=50,\n        verbose=1)\n'