In [None]:
import gc
import tensorflow as tf # type: ignore
import os
from tensorflow.keras.models import Model, Sequential, load_model # type: ignore
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, MaxPooling2D, Flatten, Dense, Reshape, Conv2DTranspose, Add, LeakyReLU, UpSampling2D, Dropout, Concatenate, AveragePooling2D, GlobalMaxPooling2D, Lambda, ZeroPadding2D  # type: ignore
from keras.initializers import glorot_uniform # type: ignore
from tensorflow.keras.optimizers import SGD, Adam # type: ignore
from tensorflow.keras.regularizers import l2 # type: ignore
from tensorflow.keras.optimizers.schedules import ExponentialDecay # type: ignore
from tensorflow.keras.losses import MeanSquaredError, CategoricalCrossentropy, MeanAbsoluteError # type: ignore
from tensorflow.keras.metrics import CategoricalAccuracy # type: ignore
from tensorflow.keras.preprocessing.image import ImageDataGenerator # type: ignore
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.utils import to_categorical # type: ignore
import keras_tuner as kt # type: ignore
import matplotlib.pyplot as plt # type: ignore
import numpy as np # type: ignore
import scipy
from numba import cuda
from multiprocessing import Pool, Manager

In [None]:
def generate_unique_vectors(num_vectors, vector_length, vectors):
    """Draw random binary vectors that are not already present in ``vectors``.

    Args:
        num_vectors: Number of new vectors to draw.
        vector_length: Number of 0/1 components per vector.
        vectors: Set of tuples already in use. It is updated in place with
            every newly drawn vector.

    Returns:
        A ``(vectors_list, vectors)`` pair: the newly drawn tuples in draw
        order, and the same (now enlarged) set that was passed in.

    Raises:
        ValueError: If fewer than ``num_vectors`` unused binary vectors of
            length ``vector_length`` remain. Without this check the
            rejection-sampling loop below would never terminate.
    """
    # Count the binary vectors of this length that are already taken, so an
    # unsatisfiable request is rejected instead of looping forever.
    already_taken = sum(
        1
        for existing in vectors
        if isinstance(existing, tuple)
        and len(existing) == vector_length
        and all(bit in (0, 1) for bit in existing)
    )
    remaining = 2 ** vector_length - already_taken
    if num_vectors > remaining:
        raise ValueError(
            f"Cannot generate {num_vectors} unique binary vectors of length "
            f"{vector_length}: only {remaining} remain available."
        )

    vectors_list = []
    while len(vectors_list) < num_vectors:
        vector = tuple(np.random.randint(0, 2, vector_length))
        if vector not in vectors:
            vectors.add(vector)
            vectors_list.append(vector)
    return vectors_list, vectors

def generator_to_array(generator, class_vectors):
    """Drain a batch generator into NumPy arrays of samples and labels.

    One batch is pulled per step for ``len(generator)`` steps. The label of
    each sample is looked up from the first path component of
    ``generator.filenames[step * generator.batch_size + position]``, so the
    generator has to yield files in ``filenames`` order for the labels to
    line up (presumably it must be created with ``shuffle=False`` -- verify
    against the caller).

    Args:
        generator: Iterator exposing ``filenames``, ``batch_size`` and
            ``__len__``; each ``next()`` returns a sequence whose first item
            is the batch of samples.
        class_vectors: Mapping from class (directory) name to its label.

    Returns:
        ``(samples, vectors, filenames)``: the stacked samples, the stacked
        labels, and the generator's ``filenames`` list.
    """
    filenames = generator.filenames
    file_count = len(filenames)
    collected_samples, collected_labels = [], []

    for step in range(len(generator)):
        inputs = next(generator)[0]

        for position in range(len(inputs)):
            file_index = step * generator.batch_size + position
            if file_index >= file_count:
                # Never index past the list of known files.
                break

            collected_samples.append(inputs[position])
            class_dir = filenames[file_index].split(os.path.sep)[0]
            collected_labels.append(class_vectors[class_dir])

    return np.array(collected_samples), np.array(collected_labels), filenames

def preprocess(train_generator, val_generator, num_classes, vector_length, total_vectors, use_random_vectors):
    """Load both generators into arrays and attach a label to every image.

    Args:
        train_generator: Training generator; its ``class_indices`` defines
            the class names (and their order) used for both splits.
        val_generator: Validation generator.
        num_classes: Number of classes (number of label vectors / one-hot width).
        vector_length: Length of each random binary label vector.
        total_vectors: Set of label vectors already in use (random mode only).
        use_random_vectors: If true, each class gets a unique random binary
            vector; otherwise classes are numbered in iteration order and
            one-hot encoded.

    Returns:
        ``(samples_train, vectors_train, samples_val, vectors_val, total_vectors)``.
    """
    class_names = train_generator.class_indices

    if use_random_vectors:
        drawn_vectors, total_vectors = generate_unique_vectors(num_classes, vector_length, total_vectors)
        class_vectors = dict(zip(class_names, drawn_vectors))
    else:
        class_vectors = {name: position for position, name in enumerate(class_names)}

    samples_train, vectors_train, _ = generator_to_array(train_generator, class_vectors)
    samples_val, vectors_val, _ = generator_to_array(val_generator, class_vectors)

    if not use_random_vectors:
        # Integer class positions -> one-hot rows.
        vectors_train = to_categorical(vectors_train, num_classes=num_classes)
        vectors_val = to_categorical(vectors_val, num_classes=num_classes)

    return samples_train, vectors_train, samples_val, vectors_val, total_vectors

def process_ethnie(ethnie, target_size, batch_size, class_mode, shuffle, color_mode, use_random_vectors, vector_length):
    """Worker: load one group's image folder and return its train/validation arrays.

    Images are read from ``../../Datasets/VGG/<ethnie>``, rescaled by 1/255,
    stored as float16, and split 95 % / 5 % into training and validation
    subsets. A fresh, empty set of used label vectors is created on every
    call, so random label vectors are only unique within one group.

    Args:
        ethnie: Sub-directory name of the group to load.
        target_size, batch_size, class_mode, shuffle, color_mode: Forwarded
            to ``flow_from_directory`` for both subsets.
        use_random_vectors, vector_length: Forwarded to ``preprocess``.

    Returns:
        ``(ethnie, samples_train, vectors_train, samples_val, vectors_val)``.
    """
    directory = f'../../Datasets/VGG/{ethnie}'
    flow_options = dict(
        target_size=target_size,
        batch_size=batch_size,
        class_mode=class_mode,
        shuffle=shuffle,
        color_mode=color_mode,
    )

    datagen = ImageDataGenerator(rescale=1./255, validation_split=0.05, dtype='float16')
    trainset = datagen.flow_from_directory(directory, subset='training', **flow_options)
    testset = datagen.flow_from_directory(directory, subset='validation', **flow_options)

    samples_train, vectors_train, samples_val, vectors_val, _ = preprocess(
        trainset,
        testset,
        len(trainset.class_indices),
        vector_length,
        set(),
        use_random_vectors,
    )

    return ethnie, samples_train, vectors_train, samples_val, vectors_val

def load_data(target_size=(150, 150), batch_size=112, class_mode='input', shuffle=False, color_mode='grayscale', use_random_vectors=True, vector_length=56):
    """Load every group's dataset in parallel worker processes.

    Args:
        target_size: Image size passed to each worker.
        batch_size: Generator batch size passed to each worker.
        class_mode: Generator class mode passed to each worker.
        shuffle: Whether the generators shuffle.
        color_mode: Generator color mode passed to each worker.
        use_random_vectors: Label encoding switch passed to each worker.
        vector_length: Length of random label vectors passed to each worker.

    Returns:
        Dict mapping each group name to
        ``[samples_train, vectors_train, samples_val, vectors_val]``.
    """
    ethnies = ['caucasians', 'afro_americans', 'asians']

    # One argument tuple per group, in the positional order of process_ethnie.
    args_list = [
        (ethnie, target_size, batch_size, class_mode, shuffle, color_mode, use_random_vectors, vector_length)
        for ethnie in ethnies
    ]

    # There is exactly one task per group, so one worker per group is enough.
    # The context manager releases the workers even if a worker raises.
    # starmap blocks until every result is back, so nothing is cut short.
    with Pool(processes=len(ethnies)) as pool:
        results = pool.starmap(process_ethnie, args_list)

    return {
        ethnie: [samples_train, vectors_train, samples_val, vectors_val]
        for ethnie, samples_train, vectors_train, samples_val, vectors_val in results
    }

# Data-loading parameters
target_size = (300, 300)
batch_size = 128
class_mode = 'input'
shuffle = False
color_mode = 'grayscale'
use_random_vectors = False
vector_length = 56

# Load every group's arrays; the result is keyed by group name.
ethnies = load_data(
    target_size=target_size,
    batch_size=batch_size,
    class_mode=class_mode,
    shuffle=shuffle,
    color_mode=color_mode,
    use_random_vectors=use_random_vectors,
    vector_length=vector_length,
)

In [None]:
def _slice_channels(tensor, begin, end):
    """Return channels ``begin:end`` (last axis) of a 4-D tensor."""
    return tensor[:, :, :, begin:end]


def split(inputs, cardinality):
    """Split a 4-D tensor into ``cardinality`` equal groups along the channel axis.

    Each group holds ``channels // cardinality`` consecutive channels; when
    the channel count is not a multiple of ``cardinality`` the trailing
    remainder channels are not part of any group.

    Args:
        inputs: 4-D tensor whose channel count is ``inputs.shape[3]``.
        cardinality: Number of groups to produce.

    Returns:
        List of ``cardinality`` tensors, one per channel group, in channel order.
    """
    group_size = inputs.shape[3] // cardinality
    groups = []
    for index in range(cardinality):
        begin = int(index * group_size)
        end = int((index + 1) * group_size)
        # The bounds are handed to the layer through `arguments`, not captured
        # by a closure. A lambda closing over the loop variables would read
        # their final values whenever Keras re-invokes the function when the
        # model runs, so every group would slice the last channel range.
        # Weights trained with that closure were fitted to that behaviour.
        block = Lambda(_slice_channels, arguments={"begin": begin, "end": end})(inputs)
        groups.append(block)
    return groups

def transform(groups, filters, stage, block, downsampling):
    """Run every channel group through its own two-convolution branch and concatenate.

    Each branch is: 1x1 convolution -> optional 2x2 max-pooling -> batch norm
    -> ReLU -> 2x2 'same' convolution -> batch norm -> ReLU.

    Args:
        groups: List of tensors, one per channel group.
        filters: ``(f1, f2)`` filter counts of the first and second convolution.
        stage: Kept for interface compatibility; not used to build layers.
        block: Kept for interface compatibility; not used to build layers.
        downsampling: If true, max-pool by 2 after the first convolution.

    Returns:
        The channel-wise concatenation of all branch outputs.
    """
    first_filters, second_filters = filters
    branches = []

    for group in groups:
        # First convolution of the branch (pointwise).
        branch = Conv2D(filters=first_filters, kernel_size=(1, 1), padding="valid",
                        kernel_initializer=glorot_uniform(seed=0))(group)
        if downsampling:
            branch = MaxPooling2D(2)(branch)
        branch = BatchNormalization(axis=3)(branch)
        branch = Activation('relu')(branch)

        # Second convolution of the branch.
        branch = Conv2D(filters=second_filters, kernel_size=(2, 2), padding="same",
                        kernel_initializer=glorot_uniform(seed=0))(branch)
        branch = BatchNormalization(axis=3)(branch)
        branch = Activation('relu')(branch)

        branches.append(branch)

    # Merge the branches back into a single tensor.
    return Concatenate()(branches)

def transition(inputs, filters, stage, block):
    """Apply a 1x1 convolution, batch normalisation and ReLU.

    Args:
        inputs: Input tensor.
        filters: Number of output channels of the 1x1 convolution.
        stage: Kept for interface compatibility; not used to build layers.
        block: Kept for interface compatibility; not used to build layers.

    Returns:
        The activated output tensor.
    """
    projected = Conv2D(
        filters=filters,
        kernel_size=(1, 1),
        padding="valid",
        kernel_initializer=glorot_uniform(seed=0),
    )(inputs)
    normalized = BatchNormalization(axis=3)(projected)
    return Activation('relu')(normalized)

def downsampling(inputs, filters, cardinality, strides, stage, block):
    """Grouped residual block that halves the spatial resolution.

    Main path: split into groups -> per-group transform (with 2x2 max-pooling)
    -> 1x1 transition to ``f3`` channels. Shortcut path: 1x1 convolution to
    ``f3`` channels -> 2x2 max-pooling -> batch norm. Both paths are added and
    passed through ReLU.

    Args:
        inputs: Input tensor.
        filters: ``(f1, f2, f3)``; ``f1`` and ``f2`` are divided by
            ``cardinality`` to get the per-group filter counts.
        cardinality: Number of channel groups.
        strides: Accepted but not used; pooling is fixed to a factor of 2.
        stage: Forwarded to the helpers; not used to build layers here.
        block: Forwarded to the helpers; not used to build layers here.

    Returns:
        The output tensor of the block.
    """
    total_f1, total_f2, out_filters = filters

    # Main path: grouped convolutions over channel slices of the input.
    channel_groups = split(inputs=inputs, cardinality=cardinality)
    per_group_filters = (int(total_f1 / cardinality), int(total_f2 / cardinality))
    main = transform(groups=channel_groups, filters=per_group_filters, stage=stage, block=block, downsampling=True)

    # 1x1 transition to the block's output width.
    main = transition(inputs=main, filters=out_filters, stage=stage, block=block)

    # Projection shortcut so that channels and resolution match the main path.
    shortcut = Conv2D(filters=out_filters, kernel_size=(1, 1), padding="valid",
                      kernel_initializer=glorot_uniform(seed=0))(inputs)
    shortcut = MaxPooling2D(2)(shortcut)
    shortcut = BatchNormalization(axis=3)(shortcut)

    # Residual sum followed by the activation.
    merged = Add()([main, shortcut])
    return Activation('relu')(merged)

def transform_decoder(groups, filters, strides, stage, block):
    """Run every channel group through two transposed convolutions and concatenate.

    Each branch is: 1x1 transposed convolution (stride 1) -> batch norm ->
    ReLU -> 2x2 'same' transposed convolution with ``strides`` -> batch norm
    -> ReLU.

    Args:
        groups: List of tensors, one per channel group.
        filters: ``(f1, f2)`` filter counts of the two transposed convolutions.
        strides: Strides of the second transposed convolution.
        stage: Kept for interface compatibility; not used to build layers.
        block: Kept for interface compatibility; not used to build layers.

    Returns:
        The channel-wise concatenation of all branch outputs.
    """
    first_filters, second_filters = filters
    branches = []

    for group in groups:
        # First transposed convolution of the branch (pointwise).
        branch = Conv2DTranspose(filters=first_filters, kernel_size=(1, 1), strides=(1, 1), padding="same",
                                 kernel_initializer=glorot_uniform(seed=0))(group)
        branch = BatchNormalization(axis=3)(branch)
        branch = Activation('relu')(branch)

        # Second transposed convolution of the branch; this one applies `strides`.
        branch = Conv2DTranspose(filters=second_filters, kernel_size=(2, 2), strides=strides, padding="same",
                                 kernel_initializer=glorot_uniform(seed=0))(branch)
        branch = BatchNormalization(axis=3)(branch)
        branch = Activation('relu')(branch)

        branches.append(branch)

    return Concatenate()(branches)

def transition_decoder(inputs, filters, stage, block):
    """Apply a 1x1 transposed convolution, batch normalisation and ReLU.

    Args:
        inputs: Input tensor.
        filters: Number of output channels of the 1x1 transposed convolution.
        stage: Kept for interface compatibility; not used to build layers.
        block: Kept for interface compatibility; not used to build layers.

    Returns:
        The activated output tensor.
    """
    projected = Conv2DTranspose(
        filters=filters,
        kernel_size=(1, 1),
        padding="valid",
        kernel_initializer=glorot_uniform(seed=0),
    )(inputs)
    normalized = BatchNormalization(axis=3)(projected)
    return Activation('relu')(normalized)

def upsampling(inputs, filters, cardinality, strides, stage, block):
    """Grouped residual block built from transposed convolutions.

    Main path: split into groups -> per-group transposed-convolution transform
    (using ``strides``) -> 1x1 transposed-convolution transition to ``f3``
    channels. Shortcut path: 2x2 'valid' transposed convolution with
    ``strides`` to ``f3`` channels -> batch norm. Both paths are added and
    passed through ReLU.

    Args:
        inputs: Input tensor.
        filters: ``(f1, f2, f3)``; ``f1`` and ``f2`` are divided by
            ``cardinality`` to get the per-group filter counts.
        cardinality: Number of channel groups.
        strides: Strides used on both the main and the shortcut path.
        stage: Forwarded to the helpers; not used to build layers here.
        block: Forwarded to the helpers; not used to build layers here.

    Returns:
        The output tensor of the block.
    """
    total_f1, total_f2, out_filters = filters

    # Main path: grouped transposed convolutions over channel slices.
    channel_groups = split(inputs=inputs, cardinality=cardinality)
    per_group_filters = (int(total_f1 / cardinality), int(total_f2 / cardinality))
    main = transform_decoder(groups=channel_groups, filters=per_group_filters, strides=strides, stage=stage, block=block)

    # 1x1 transposed-convolution transition to the block's output width.
    main = transition_decoder(inputs=main, filters=out_filters, stage=stage, block=block)

    # Projection shortcut so that channels and resolution match the main path.
    shortcut = Conv2DTranspose(filters=out_filters, kernel_size=(2, 2), strides=strides, padding="valid",
                               kernel_initializer=glorot_uniform(seed=0))(inputs)
    shortcut = BatchNormalization(axis=3)(shortcut)

    # Residual sum followed by the activation.
    merged = Add()([main, shortcut])
    return Activation('relu')(merged)


In [None]:
# Enable on-demand memory growth on every GPU that TensorFlow can see.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# Distribution strategy under whose scope the model is built later on.
strategy = tf.distribute.MirroredStrategy()

In [None]:
def autoencoder(input_shape, input_latent):
    """Build and compile a plain convolutional autoencoder.

    Encoder: three Conv2D(3x3) -> LeakyReLU -> BatchNorm -> MaxPool(2) stages
    with 128, 256 and 512 filters. Latent: a Dense layer applied without
    flattening, i.e. on the last (channel) axis, followed by batch norm.
    Decoder: three Conv2D(3x3) -> LeakyReLU -> BatchNorm -> UpSampling(2)
    stages with 512, 256 and 128 filters, a 1-channel sigmoid convolution,
    and a final resize back to the input height and width.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
        input_latent: Number of units of the latent Dense layer.

    Returns:
        The compiled ``Sequential`` model (Adam, mean absolute error).
    """
    filters = [128, 256, 512]
    autoencoder = Sequential()

    # Declare the input so the model is built as layers are added. Without it,
    # reading `output_shape` below fails because no shape is known yet.
    autoencoder.add(Input(shape=input_shape))

    # Encoder
    for i, n_filters in enumerate(filters):
        # No activation on the convolution: a ReLU here would make the
        # following LeakyReLU a no-op (its input would never be negative).
        autoencoder.add(Conv2D(n_filters, (3, 3), padding='same'))
        autoencoder.add(LeakyReLU(alpha=0.1))
        autoencoder.add(BatchNormalization())
        autoencoder.add(MaxPooling2D((2, 2), padding='same'))
        print(f'After {i+1} Conv2D: {autoencoder.output_shape}')

    # Latent space (no Flatten: Dense acts on the channel axis)
    autoencoder.add(Dense(input_latent, activation='relu', name='latent_space'))
    autoencoder.add(BatchNormalization(name='latent_space_norm'))

    # Decoder
    for i, n_filters in enumerate(reversed(filters)):
        autoencoder.add(Conv2D(n_filters, (3, 3), padding='same'))
        autoencoder.add(LeakyReLU(alpha=0.1))
        autoencoder.add(BatchNormalization())
        autoencoder.add(UpSampling2D((2, 2)))
        print(f'After {i+1} UpSampling2D: {autoencoder.output_shape}')

    autoencoder.add(Conv2D(1, (3, 3), activation='sigmoid', padding='same'))
    print(f'Decoded shape before resizing: {autoencoder.output_shape}')

    # Resize to match the input height/width. `Resizing` is not among the
    # names imported from tensorflow.keras.layers, so use the qualified path.
    autoencoder.add(tf.keras.layers.Resizing(height=input_shape[0], width=input_shape[1], name='final_output'))
    print(f'Decoded shape after resizing: {autoencoder.output_shape}')

    # NOTE(review): 'accuracy' is kept from the original compile call, but it
    # is probably not meaningful for pixel reconstruction -- confirm.
    autoencoder.compile(optimizer=Adam(learning_rate=0.001), loss=MeanAbsoluteError(), metrics=['accuracy'])

    # Display the model summary
    autoencoder.summary()

    return autoencoder

In [None]:
def ResNeXt50_AutoEncoder(input_shape, input_latent, cardinality = 64):
    """Build and compile the grouped-convolution (ResNeXt-style) autoencoder.

    Encoder: zero-padding, a 5x5 convolution stem with 64 filters, 2x2
    max-pooling, then one ``downsampling`` block per entry of
    ``[128, 128, 256, 256, 512, 1024, 2048]`` (each producing twice that many
    channels). A 2x2 convolution, Flatten, a bias-free Dense layer and batch
    norm form the latent code. Decoder: the code is reshaped to
    ``(1, 1, input_latent)`` and passed through one ``upsampling`` block per
    entry of the same list prefixed with 64, in reverse order, then a
    1-channel sigmoid transposed convolution and a resize to the input
    height and width.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
        input_latent: Size of the latent code.
        cardinality: Number of channel groups in every block.

    Returns:
        The compiled model (SGD with momentum 0.9, exponentially decaying
        learning rate, mean squared error).
    """
    # Stem
    x_input = Input(input_shape, dtype='float16')
    encoded = ZeroPadding2D((3, 3))(x_input)
    encoded = Conv2D(filters=64, kernel_size=(5, 5), kernel_initializer='glorot_uniform')(encoded)
    encoded = BatchNormalization(axis=3, name='batchnorm_1')(encoded)
    encoded = Activation('relu')(encoded)
    encoded = MaxPooling2D((2, 2))(encoded)
    print(encoded.shape)

    encoder_widths = [128, 128, 256, 256, 512, 1024, 2048]

    # Encoder: every block is created with stage=1, block="a".
    for width in encoder_widths:
        encoded = downsampling(inputs=encoded, filters=(width, width, width*2), cardinality=cardinality, strides=(2, 2), stage=1, block="a")
        print(encoded.shape)

    # Latent code
    encoded = Conv2D(input_latent, 2, padding='same', use_bias=False)(encoded)
    encoded = Flatten()(encoded)
    latent_space_layer = Dense(input_latent, activation='relu', use_bias=False)(encoded)
    latent_space_layer_norm = BatchNormalization(name='latent_space_layer_norm')(latent_space_layer)

    # Decoder entry: back to a 1x1 feature map.
    reshape_layer = Reshape(target_shape=(1, 1, input_latent))(latent_space_layer_norm)
    x_recon = Conv2DTranspose(input_latent, 3, strides=1, padding='same', use_bias=False)(reshape_layer)

    # Decoder: mirror the encoder widths, with one extra 64-wide block last.
    decoder_widths = [64] + encoder_widths
    for width in decoder_widths[::-1]:
        x_recon = upsampling(inputs=x_recon, filters=(width*2, width*2, width), cardinality=cardinality, strides=(2, 2), stage=5, block="a")
        print(x_recon.shape)

    # Output head
    x_recon = Conv2DTranspose(1, (1, 1), padding='same', activation='sigmoid')(x_recon)
    x_recon = tf.keras.layers.Resizing(height=input_shape[0], width=input_shape[1], name='recon_image')(x_recon)

    model = Model(inputs=x_input, outputs=x_recon, name="resnext50_autoencoder")

    lr_schedule = ExponentialDecay(initial_learning_rate=0.001, decay_steps=10000, decay_rate=0.9, staircase=True)
    model.compile(optimizer=SGD(learning_rate=lr_schedule, momentum=0.9), loss=MeanSquaredError())
    return model

# Model configuration
input_shape = (300, 300, 1)
input_latent = 512
cardinality = 32

# Build the model inside the distribution strategy's scope.
with strategy.scope():
    model = ResNeXt50_AutoEncoder(
        input_shape=input_shape,
        input_latent=input_latent,
        cardinality=cardinality,
    )

In [None]:
# Restore previously saved weights.
# NOTE(review): the file name ends in "_64" while the model above is built
# with cardinality = 32 -- confirm this file matches the current architecture.
model.load_weights('Results/weights/Resnext_best_weights_64.h5')

# Rough model size estimate at 2 bytes (16 bits) per parameter.
# NOTE(review): this assumes float16 storage for the weights -- TODO confirm.
total_params = model.count_params()
total_bytes = (16 * total_params) // 8
print(f"Taille totale du modèle en octets : {total_bytes} octets")

# Data preparation: pick the group to train on. Per group, index 0 holds the
# training samples and index 2 the validation samples.
ethnie = 'caucasians'
x_train = ethnies[ethnie][0]
x_val = ethnies[ethnie][2]

# Pick up to 10 random validation images for the reconstruction previews.
# Indexing with a random permutation copies only the selected images. The
# previous copy-and-shuffle of the whole array kept that full copy alive
# through the 10-image view.
x_print = x_val[np.random.permutation(len(x_val))[:10]]

def create_dataset(data, batch_size):
    """Wrap an array in a shuffled, batched, prefetched autoencoder dataset.

    Args:
        data: Array of samples; cast to float16 before slicing.
        batch_size: Number of samples per batch.

    Returns:
        A ``tf.data.Dataset`` yielding ``(input, target)`` batches in which
        the target is the input itself.
    """
    as_half = tf.cast(data, tf.float16)
    # Input and target are the same tensor: the model learns to reconstruct.
    pairs = tf.data.Dataset.from_tensor_slices((as_half, as_half))
    shuffled = pairs.shuffle(buffer_size=1024)
    batched = shuffled.batch(batch_size)
    return batched.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

# Training and validation pipelines, both with batches of 128 samples.
train_dataset = create_dataset(data=x_train, batch_size=128)
val_dataset = create_dataset(data=x_val, batch_size=128)

In [None]:
class CustomCallback(tf.keras.callbacks.Callback):
    """Periodic reporting during training.

    At the end of an epoch (0-based index ``epoch``) it can:
      * save a figure comparing test images with their reconstructions,
        every ``display_recon_interval`` epochs;
      * print the training loss, every ``display_loss_interval`` epochs;
      * hand over to an internal ``ModelCheckpoint`` (best ``val_loss``,
        weights only), every ``save_weight_interval`` epochs.

    Args:
        test_data: Array of images used for the reconstruction previews.
        display_loss_interval: Epoch period of the loss printout.
        display_recon_interval: Epoch period of the reconstruction figure.
        save_weight_interval: Epoch period at which the checkpoint is consulted.
    """

    def __init__(self, test_data, display_loss_interval=10, display_recon_interval=20, save_weight_interval=20):
        super().__init__()
        self.test_data = test_data
        self.display_loss_interval = display_loss_interval
        self.display_recon_interval = display_recon_interval
        self.save_weight_interval = save_weight_interval

        # Inner checkpoint callback; it is only invoked from on_epoch_end below.
        self.checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath='Results/weights/Resnext_best_weights_32.h5',
            monitor='val_loss',
            save_best_only=True,
            mode='min',
            save_weights_only=True,
            save_freq='epoch'
        )

    def on_epoch_end(self, epoch, logs=None):
        """Run whichever periodic actions are due for this epoch."""
        logs = logs or {}
        if epoch % self.display_recon_interval == 0:
            reconstructions = self.model.predict(self.test_data, verbose=0)
            self.display_reconstruction(epoch+1, self.test_data, reconstructions)
        if epoch % self.display_loss_interval == 0:
            loss = logs.get('loss')
            if loss is not None:
                print(f"Epoch {epoch+1}, Loss: {loss:.4g}")
        if epoch % self.save_weight_interval == 0:
            # Delegate to the inner ModelCheckpoint.
            self.checkpoint_callback.set_model(self.model)
            self.checkpoint_callback.on_epoch_end(epoch, logs=logs)

    @staticmethod
    def _as_image(array):
        """Drop singleton axes (e.g. the channel axis) and convert to float32 for plotting."""
        return np.squeeze(np.asarray(array)).astype(np.float32)

    def display_reconstruction(self, epoch, originals, reconstructions):
        """Save a two-row figure: originals on top, reconstructions below.

        Args:
            epoch: Epoch number used in the title and in the file name.
            originals: Input images.
            reconstructions: Model outputs for the same images.
        """
        # Show at most 10 pairs, and never more than are available.
        n = min(10, len(originals), len(reconstructions))
        if n == 0:
            return

        fig = plt.figure(figsize=(20, 4))
        for i in range(n):
            # Original image. The size comes from the data itself; a
            # hard-coded 200x200 reshape fails for any other resolution.
            plt.subplot(2, n, i + 1)
            plt.imshow(self._as_image(originals[i]), cmap='gray')
            plt.title("Original")
            plt.axis('off')

            # Reconstruction
            plt.subplot(2, n, i + 1 + n)
            plt.imshow(self._as_image(reconstructions[i]), cmap='gray')
            plt.title("Reconstructed")
            plt.axis('off')
        plt.suptitle(f'Epoch {epoch}')

        # Make sure the output directory exists before saving.
        os.makedirs('Results/MSE', exist_ok=True)
        plt.savefig(f'Results/MSE/Resnext2_{epoch}.png')
        plt.close(fig)

# Create the callback instance; x_print holds the preview images.
callback = CustomCallback(x_print)

# `callbacks` is documented as a list of callbacks, so pass a list.
model.fit(train_dataset, epochs=801, validation_data=val_dataset, callbacks=[callback], verbose=1)