In [None]:
import tensorflow as tf # type: ignore
import os
from tensorflow.keras.models import Model, load_model # type: ignore
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, MaxPooling2D, Flatten, Dense, Reshape, Conv2DTranspose, Add, Resizing, LeakyReLU, UpSampling2D, Dropout, Concatenate, AveragePooling2D, GlobalMaxPooling2D, Lambda, ZeroPadding2D  # type: ignore
from keras.initializers import glorot_uniform # type: ignore
from tensorflow.keras.optimizers import SGD, Adam # type: ignore
from tensorflow.keras.regularizers import l2 # type: ignore
from tensorflow.keras.optimizers.schedules import ExponentialDecay # type: ignore
from tensorflow.keras.losses import MeanSquaredError, CategoricalCrossentropy, MeanAbsoluteError # type: ignore
from tensorflow.keras.metrics import CategoricalAccuracy # type: ignore
from tensorflow.keras.preprocessing.image import ImageDataGenerator # type: ignore
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.utils import to_categorical # type: ignore
import keras_tuner as kt # type: ignore
import matplotlib.pyplot as plt # type: ignore
import numpy as np # type: ignore
import scipy

In [None]:
# Generate random binary vectors that are unique across calls
def generate_unique_vectors(num_vectors, vector_length, vectors):
    """Draw random binary vectors that are not already present in ``vectors``.

    Args:
        num_vectors: Number of new vectors to generate.
        vector_length: Number of 0/1 components in each vector.
        vectors: Set of tuples already in use. It is updated in place with
            every newly generated vector.

    Returns:
        A pair ``(vectors_list, vectors)``: the new vectors in generation
        order, and the same set object that was passed in.

    Raises:
        ValueError: If fewer than ``num_vectors`` unused binary vectors of
            length ``vector_length`` remain. Without this check the rejection
            loop below would never terminate.
    """
    # Only binary vectors of the requested length reduce the pool we draw from.
    taken = sum(
        1 for existing in vectors
        if len(existing) == vector_length and all(bit in (0, 1) for bit in existing)
    )
    remaining = 2 ** vector_length - taken
    if num_vectors > remaining:
        raise ValueError(
            f"Cannot generate {num_vectors} unique binary vectors of length "
            f"{vector_length}: only {remaining} unused vectors remain."
        )

    vectors_list = []
    while len(vectors_list) < num_vectors:
        vector = tuple(np.random.randint(0, 2, vector_length))
        # Rejection sampling: keep the draw only if it has never been used.
        if vector not in vectors:
            vectors.add(vector)
            vectors_list.append(vector)
    return vectors_list, vectors

# Convert a batch generator into NumPy arrays
def generator_to_array(generator, class_vectors):
    """Drain ``generator`` once and collect its images with their label vectors.

    The label of each image is looked up from the first path component of
    ``generator.filenames[k]``, where ``k`` is the image's running position
    (batch index * ``generator.batch_size`` + offset in the batch). This only
    pairs images with the right file name if batches arrive in
    ``generator.filenames`` order.

    Args:
        generator: Object exposing ``filenames``, ``batch_size``, ``len()``
            and ``next()``; the first element of each batch holds the images.
        class_vectors: Mapping from class (directory) name to label.

    Returns:
        ``(samples, vectors, filenames)``: the images as an array, the labels
        as an array, and ``generator.filenames`` unchanged.
    """
    filenames = generator.filenames
    file_count = len(filenames)
    images, labels = [], []

    for batch_idx in range(len(generator)):
        batch_images = next(generator)[0]
        offset = batch_idx * generator.batch_size

        for position, image in enumerate(batch_images):
            file_idx = offset + position
            if file_idx >= file_count:
                # Never index past the list of known files.
                break

            label_key = filenames[file_idx].split(os.path.sep)[0]
            images.append(image)
            labels.append(class_vectors[label_key])

    return np.array(images), np.array(labels), filenames

# Pair every image with the label vector of its class
def preprocess(train_generator, val_generator, num_classes, vector_length, total_vectors, use_random_vectors):
    """Materialise the train/validation generators and attach labels.

    Args:
        train_generator: Training generator; its ``class_indices`` defines the
            class names and their order.
        val_generator: Validation generator, labelled with the same mapping.
        num_classes: Number of classes (number of random vectors to draw, or
            width of the one-hot encoding).
        vector_length: Length of each random binary vector.
        total_vectors: Set of random vectors already handed out; forwarded to
            ``generate_unique_vectors``.
        use_random_vectors: If true, label each class with a unique random
            binary vector; otherwise with a one-hot vector built from the
            class position in ``class_indices``.

    Returns:
        ``(samples_train, vectors_train, samples_val, vectors_val, total_vectors)``.
    """
    class_names = train_generator.class_indices

    if use_random_vectors:
        codes, total_vectors = generate_unique_vectors(num_classes, vector_length, total_vectors)
        class_vectors = dict(zip(class_names, codes))
    else:
        class_vectors = {name: position for position, name in enumerate(class_names)}

    samples_train, vectors_train, _ = generator_to_array(train_generator, class_vectors)
    samples_val, vectors_val, _ = generator_to_array(val_generator, class_vectors)

    if not use_random_vectors:
        # Integer class positions -> one-hot rows
        vectors_train, vectors_val = [
            to_categorical(indices, num_classes=num_classes)
            for indices in (vectors_train, vectors_val)
        ]

    return samples_train, vectors_train, samples_val, vectors_val, total_vectors

# Load the images of every group from disk
def load_data(datagen, target_size=(150, 150), batch_size=112, class_mode='input', shuffle=False, color_mode='grayscale', use_random_vectors=True, vector_length=56, data_dir='../../Datasets/VGG'):
    """Load the training and validation images of each group directory.

    For each of ``caucasians``, ``afro_americans`` and ``asians`` two
    generators are created from ``{data_dir}/{group}`` (subsets ``training``
    and ``validation``) and fully read into memory by ``preprocess``.

    Args:
        datagen: Object providing ``flow_from_directory`` (an
            ``ImageDataGenerator`` in this notebook).
        target_size, batch_size, class_mode, shuffle, color_mode: Forwarded
            unchanged to ``flow_from_directory``. Labels are recovered from the
            file order in ``generator_to_array``, so ``shuffle`` should stay
            ``False`` for the labels to match the images.
        use_random_vectors: Forwarded to ``preprocess`` (random binary labels
            when true, one-hot labels otherwise).
        vector_length: Length of the random binary label vectors.
        data_dir: Root directory that contains one sub-directory per group.
            The default is the location that was previously hard-coded.

    Returns:
        Dict mapping each group name to the list
        ``[trainset, testset, samples_train, vectors_train, samples_val, vectors_val]``.
    """
    group_names = ('caucasians', 'afro_americans', 'asians')
    # Arguments shared by the training and the validation generator.
    flow_kwargs = dict(target_size=target_size, batch_size=batch_size, class_mode=class_mode,
                       shuffle=shuffle, color_mode=color_mode)

    ethnies = {}
    # Shared across groups so random label vectors stay unique over all groups.
    total_vectors = set()
    for ethnie in group_names:
        directory = f'{data_dir}/{ethnie}'
        trainset = datagen.flow_from_directory(directory, subset='training', **flow_kwargs)
        testset = datagen.flow_from_directory(directory, subset='validation', **flow_kwargs)

        # Number of classes for one-hot encoding
        num_classes = len(trainset.class_indices)

        samples_train, vectors_train, samples_val, vectors_val, total_vectors = preprocess(
            trainset, testset, num_classes, vector_length, total_vectors, use_random_vectors)

        ethnies[ethnie] = [trainset, testset, samples_train, vectors_train, samples_val, vectors_val]
    return ethnies

# Generator settings: scale pixel values by 1/255, hold out 5% of the images
# as the validation subset, and produce float16 arrays.
datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.05,
    dtype='float16',
)

# Read every group as 300x300 grayscale images with one-hot class labels.
ethnies = load_data(
    datagen,
    target_size=(300, 300),
    color_mode='grayscale',
    use_random_vectors=False,
)

In [None]:
# First 100 training images of the 'caucasians' group
# (index 2 of each entry returned by load_data is samples_train).
img_cau = ethnies['caucasians'][2][:100]

In [None]:
# Sanity check: display the shape of the selected image subset.
img_cau.shape

In [None]:
# Seed both NumPy and TensorFlow: np.random.seed alone does not control the
# TensorFlow random generator used for weight initialisation.
np.random.seed(42)
tf.random.set_seed(42)

# `Sequential` is not among the names imported at the top of the notebook,
# so it is reached through the already-imported `tf` module.
model = tf.keras.Sequential()

# Encoder: three conv + 2x2 max-pool stages (64 -> 32 -> 16 filters)
model.add(Conv2D(64, (3,3), activation='relu', padding='same', input_shape=(300,300,1)))
model.add(MaxPooling2D((2,2), padding='same'))
model.add(Conv2D(32, (3,3),activation='relu',padding='same'))
model.add(MaxPooling2D((2,2), padding='same'))
model.add(Conv2D(16, (3,3),activation='relu',padding='same'))
model.add(MaxPooling2D((2,2), padding='same'))

# Decoder: three conv + 2x upsampling stages (16 -> 32 -> 64 filters)
model.add(Conv2D(16, (3,3), activation='relu', padding='same'))
model.add(UpSampling2D((2,2)))

model.add(Conv2D(32, (3,3), activation='relu', padding='same'))
model.add(UpSampling2D((2,2)))

model.add(Conv2D(64, (3,3), activation='relu', padding='same'))
model.add(UpSampling2D((2,2)))

# Single-channel output image
model.add(Conv2D(1, (3,3), activation='relu', padding='same'))

# Three 'same'-padded poolings followed by three 2x upsamplings do not return
# exactly to 300 pixels, so the output is resized back to the input size.
model.add(Resizing(height=300, width=300, name='recon_image'))
# NOTE(review): 'accuracy' is kept from the original cell, but it is not a
# meaningful metric for a pixel-wise reconstruction trained with MSE.
model.compile(optimizer='adam',loss='mean_squared_error',metrics=['accuracy'])

model.summary()

In [None]:
# Train the autoencoder to reproduce its input: the same array is used as
# both input and target.
model.fit(img_cau, img_cau, epochs=2000, shuffle=True, verbose=1)

In [None]:
# Reconstruct the images the model was just trained on.
pred=model.predict(img_cau)

In [None]:
# Show the reconstruction at index 1 as a 300x300 single-channel image.
plt.imshow(pred[1].reshape(300,300,1))

In [None]:
# Show the first `n` reconstructions side by side on a single row.
n = 10  # Number of images to display
plt.figure(figsize=(20, 4))
for i in range(n):
    # One grayscale panel per reconstruction, without axis ticks
    ax = plt.subplot(1, n, i + 1)
    ax.imshow(pred[i].reshape(300, 300), cmap='gray')
    ax.set_title("Reconstructed")
    ax.axis('off')

plt.show()

In [None]:
def split(inputs, cardinality):
    """Split a 4-D tensor into ``cardinality`` groups along its last axis.

    Each group holds ``channels // cardinality`` consecutive channels; any
    remainder channels are not included in a group.

    Args:
        inputs: Tensor whose axis 3 is the channel axis (indexed as
            ``x[:, :, :, begin:end]``).
        cardinality: Number of groups to produce.

    Returns:
        List of ``cardinality`` tensors, one per channel group.

    Raises:
        ValueError: If ``cardinality`` exceeds the number of channels, which
            would produce empty (zero-channel) groups.
    """
    inputs_channels = inputs.shape[3]
    group_size = inputs_channels // cardinality
    if group_size == 0:
        raise ValueError(
            f"cardinality ({cardinality}) is larger than the number of input "
            f"channels ({inputs_channels}); every group would be empty."
        )

    groups = list()
    for number in range(cardinality):
        begin = number * group_size
        end = begin + group_size
        # Bind begin/end as default arguments. A plain closure would look the
        # variables up again whenever the layer is re-run, by which time they
        # hold the values of the LAST iteration, so every group would slice
        # the same channel range.
        block = Lambda(lambda x, begin=begin, end=end: x[:, :, :, begin:end])(inputs)
        groups.append(block)
    return groups

def transform(groups, filters, strides, stage, block, downsampling):
    """Apply the same two-convolution branch to every group and concatenate.

    Each branch is: 1x1 convolution (``filters[0]``), optional 2x2 max-pooling,
    batch-norm, ReLU, then a 3x3 'same' convolution (``filters[1]``),
    batch-norm, ReLU.

    Args:
        groups: List of tensors, one per channel group.
        filters: Pair ``(f1, f2)`` with the filter counts of the two convolutions.
        strides: Accepted for signature compatibility; not used here.
        stage: Not used (layers are left with auto-generated names).
        block: Not used (layers are left with auto-generated names).
        downsampling: If true, max-pool by 2 after the first convolution.

    Returns:
        Concatenation of all branch outputs.
    """
    reduce_filters, conv_filters = filters

    def _branch(tensor):
        # First convolution of the transformation phase
        out = Conv2D(filters=reduce_filters, kernel_size=(1, 1), padding="valid",
                     kernel_initializer=glorot_uniform(seed=0))(tensor)
        if downsampling:
            out = MaxPooling2D(2)(out)
        out = BatchNormalization(axis=3)(out)
        out = Activation('relu')(out)

        # Second convolution of the transformation phase
        out = Conv2D(filters=conv_filters, kernel_size=(3, 3), strides=(1, 1), padding="same",
                     kernel_initializer=glorot_uniform(seed=0))(out)
        out = BatchNormalization(axis=3)(out)
        return Activation('relu')(out)

    branches = [_branch(group) for group in groups]

    # Merge the per-group results back into one tensor
    return Concatenate()(branches)

def transition(inputs, filters, stage, block):
    """Project ``inputs`` to ``filters`` channels: 1x1 conv, batch-norm, ReLU.

    ``stage`` and ``block`` are accepted but not used (layers keep their
    auto-generated names).
    """
    projected = Conv2D(filters=filters, kernel_size=(1, 1), strides=(1, 1), padding="valid",
                       kernel_initializer=glorot_uniform(seed=0))(inputs)
    normalized = BatchNormalization(axis=3)(projected)
    return Activation('relu')(normalized)

In [None]:
def identity_block(inputs, filters, cardinality, stage, block, strides=(1,1)):
    """Grouped residual block that keeps the spatial size of its input.

    The input is split into ``cardinality`` channel groups, each group is
    transformed (without pooling), the result is projected to ``filters[2]``
    channels and added to the unmodified input before a final ReLU.

    Args:
        inputs: Input tensor; it is added to the block output as the shortcut.
        filters: Triple ``(f1, f2, f3)``. ``f1`` and ``f2`` are divided by
            ``cardinality`` to get the per-group filter counts; ``f3`` is the
            output channel count.
        cardinality: Number of channel groups.
        stage, block: Forwarded to the helpers, which do not use them.
        strides: Forwarded to ``transform``.

    Returns:
        Output tensor of the block. Its shape is printed as a side effect.
    """
    f1, f2, f3 = filters

    # One tensor per channel group
    groups = split(inputs=inputs, cardinality=cardinality)

    # Per-group filter counts for the two convolutions of each branch
    branch_filters = (int(f1 / cardinality), int(f2 / cardinality))
    main = transform(groups=groups, filters=branch_filters, strides=strides,
                     stage=stage, block=block, downsampling=False)

    # 1x1 projection to f3 channels
    main = transition(inputs=main, filters=f3, stage=stage, block=block)

    # Residual connection with the untouched input, then ReLU
    merged = Add()([main, inputs])
    result = Activation('relu')(merged)
    print(result.shape)
    return result

def downsampling(inputs, filters, cardinality, strides, stage, block):
    """Grouped residual block whose branches and shortcut are max-pooled by 2.

    Main path: split into ``cardinality`` groups, transform each group with
    pooling enabled, project to ``filters[2]`` channels. Shortcut path: 1x1
    convolution to ``filters[2]`` channels, 2x2 max-pooling, batch-norm. Both
    paths are added and passed through ReLU.

    Args:
        inputs: Input tensor.
        filters: Triple ``(f1, f2, f3)``. ``f1`` and ``f2`` are divided by
            ``cardinality`` for the per-group convolutions; ``f3`` is the
            output channel count.
        cardinality: Number of channel groups.
        strides: Forwarded to ``transform``.
        stage, block: Forwarded to the helpers, which do not use them.

    Returns:
        Output tensor of the block. Intermediate shapes are printed.
    """
    f1, f2, f3 = filters

    # One tensor per channel group
    groups = split(inputs=inputs, cardinality=cardinality)

    # Per-group filter counts for the two convolutions of each branch
    branch_filters = (int(f1 / cardinality), int(f2 / cardinality))
    main = transform(groups=groups, filters=branch_filters, strides=strides,
                     stage=stage, block=block, downsampling=True)
    print(main.shape, 'transfo downsampling')

    # 1x1 projection to f3 channels
    main = transition(inputs=main, filters=f3, stage=stage, block=block)
    print(main.shape, 'transi downsampling')

    # Projection shortcut so that both paths have matching dimensions
    shortcut = Conv2D(filters=f3, kernel_size=(1, 1), padding="valid",
                      kernel_initializer=glorot_uniform(seed=0))(inputs)
    shortcut = MaxPooling2D(2)(shortcut)
    shortcut = BatchNormalization(axis=3)(shortcut)

    # Residual sum followed by ReLU
    merged = Add()([main, shortcut])
    return Activation('relu')(merged)

In [None]:
def transform_decoder(groups, filters, strides, stage, block):
    """Apply the same two-transposed-convolution branch to every group.

    Each branch is: 1x1 transposed convolution with stride 1 (``filters[0]``),
    batch-norm, ReLU, then a 3x3 'same' transposed convolution with
    ``strides`` (``filters[1]``), batch-norm, ReLU.

    Args:
        groups: List of tensors, one per channel group.
        filters: Pair ``(f1, f2)`` with the filter counts of the two layers.
        strides: Strides of the second transposed convolution.
        stage, block: Not used (layers keep their auto-generated names).

    Returns:
        Concatenation of all branch outputs.
    """
    first_filters, second_filters = filters

    def _branch(tensor):
        # First transposed convolution of the transformation phase
        out = Conv2DTranspose(filters=first_filters, kernel_size=(1, 1), strides=(1, 1), padding="same",
                              kernel_initializer=glorot_uniform(seed=0))(tensor)
        out = BatchNormalization(axis=3)(out)
        out = Activation('relu')(out)

        # Second transposed convolution of the transformation phase
        out = Conv2DTranspose(filters=second_filters, kernel_size=(3, 3), strides=strides, padding="same",
                              kernel_initializer=glorot_uniform(seed=0))(out)
        out = BatchNormalization(axis=3)(out)
        return Activation('relu')(out)

    branches = [_branch(group) for group in groups]

    return Concatenate()(branches)

def transition_decoder(inputs, filters, stage, block):
    """Project ``inputs`` to ``filters`` channels: 1x1 transposed conv, batch-norm, ReLU.

    ``stage`` and ``block`` are accepted but not used (layers keep their
    auto-generated names).
    """
    projected = Conv2DTranspose(filters=filters, kernel_size=(1, 1), strides=(1, 1), padding="valid",
                                kernel_initializer=glorot_uniform(seed=0))(inputs)
    normalized = BatchNormalization(axis=3)(projected)
    return Activation('relu')(normalized)

def identity_block_decoder(inputs, filters, cardinality, stage, block, strides=(1,1)):
    """Decoder counterpart of the grouped residual block with an identity shortcut.

    The input is split into ``cardinality`` channel groups, each group goes
    through transposed convolutions, the result is projected to ``filters[2]``
    channels and added to the unmodified input before a final ReLU.

    Args:
        inputs: Input tensor; it is added to the block output as the shortcut.
        filters: Triple ``(f1, f2, f3)``. ``f1`` and ``f2`` are divided by
            ``cardinality`` to get the per-group filter counts; ``f3`` is the
            output channel count.
        cardinality: Number of channel groups.
        stage, block: Forwarded to the helpers, which do not use them.
        strides: Strides of the second transposed convolution in each branch.

    Returns:
        Output tensor of the block.
    """
    f1, f2, f3 = filters

    # Split the input channels into `cardinality` groups
    groups = split(inputs=inputs, cardinality=cardinality)

    # Transform every group with transposed convolutions and concatenate
    branch_filters = (int(f1 / cardinality), int(f2 / cardinality))
    main = transform_decoder(groups=groups, filters=branch_filters, strides=strides,
                             stage=stage, block=block)

    # 1x1 transposed-convolution projection to f3 channels
    main = transition_decoder(inputs=main, filters=f3, stage=stage, block=block)

    # Residual connection with the untouched input, then ReLU
    merged = Add()([main, inputs])
    return Activation('relu')(merged)

def upsampling(inputs, filters, cardinality, strides, stage, block):
    """Grouped residual block built from strided transposed convolutions.

    Main path: split into ``cardinality`` groups, transform each group with
    transposed convolutions using ``strides``, project to ``filters[2]``
    channels. Shortcut path: 1x1 transposed convolution with the same
    ``strides`` to ``filters[2]`` channels, then batch-norm. Both paths are
    added and passed through ReLU.

    Args:
        inputs: Input tensor.
        filters: Triple ``(f1, f2, f3)``. ``f1`` and ``f2`` are divided by
            ``cardinality`` for the per-group layers; ``f3`` is the output
            channel count.
        cardinality: Number of channel groups.
        strides: Strides applied in the branches and in the shortcut.
        stage, block: Forwarded to the helpers, which do not use them.

    Returns:
        Output tensor of the block. Intermediate shapes are printed.
    """
    f1, f2, f3 = filters

    # Split the input channels into `cardinality` groups
    groups = split(inputs=inputs, cardinality=cardinality)

    # Transform every group with transposed convolutions and concatenate
    branch_filters = (int(f1 / cardinality), int(f2 / cardinality))
    main = transform_decoder(groups=groups, filters=branch_filters, strides=strides,
                             stage=stage, block=block)
    print(main.shape, 'transfo')

    # 1x1 transposed-convolution projection to f3 channels
    main = transition_decoder(inputs=main, filters=f3, stage=stage, block=block)
    print(main.shape, 'transi')

    # Projection shortcut so that both paths have matching dimensions
    shortcut = Conv2DTranspose(filters=f3, kernel_size=(1, 1), strides=strides, padding="valid",
                               kernel_initializer=glorot_uniform(seed=0))(inputs)
    shortcut = BatchNormalization(axis=3)(shortcut)

    # Residual sum followed by ReLU
    merged = Add()([main, shortcut])
    return Activation('relu')(merged)

In [None]:
def ResNeXt50_AutoEncoder(input_shape, input_latent):
    """Build and compile a grouped-convolution (ResNeXt-style) autoencoder.

    Encoder: zero-padding, 5x5 convolution and max-pooling, five
    ``downsampling`` blocks, two 1x1 conv + max-pool stages, then a
    convolution, ``Flatten`` and a ``Dense`` latent layer followed by batch
    normalisation. Decoder: reshape to ``(1, 1, input_latent)``, three
    transposed convolutions, six ``upsampling`` blocks, a 1-channel sigmoid
    transposed convolution and a resize to the input height and width.

    Note: the latent ``BatchNormalization`` is applied to the ``Dense`` output
    and therefore normalises across the batch axis; train with a batch size
    greater than 1 so those statistics are meaningful.

    Args:
        input_shape: Shape of one input image; ``input_shape[0]`` and
            ``input_shape[1]`` are used as output height and width.
        input_latent: Number of units of the latent ``Dense`` layer.

    Returns:
        A ``Model`` named ``resnext50_autoencoder``, compiled with SGD
        (momentum 0.9, exponentially decaying learning rate) and MSE loss.
        Intermediate tensor shapes are printed while the graph is built.
    """
    # Transform input to a tensor of shape input_shape
    x_input = Input(input_shape)

    # Add zero padding
    x = ZeroPadding2D((3, 3))(x_input)

    # Initial Stage (Encoder)
    x = Conv2D(filters=128, kernel_size=(5, 5), kernel_initializer=glorot_uniform(seed=0))(x)
    x = BatchNormalization(axis=3, name='batchnorm_1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((2, 2))(x)
    print(x.shape)

    filters = [256, 256, 512, 512, 1024]

    # `n_filters` instead of `filter`, which would shadow the Python builtin.
    for n_filters in filters:
        x = downsampling(inputs=x, filters=(n_filters, n_filters, n_filters*2), cardinality=128, strides=(2, 2), stage=1, block="a")
        print(x.shape, " en dehors de la f down")
        # (identity blocks after each downsampling block are currently disabled)

    x = Conv2D(filters=1024, kernel_size=(1, 1), kernel_initializer=glorot_uniform(seed=0))(x)
    x = BatchNormalization(axis=3)(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((2, 2))(x)
    x = Conv2D(filters=2048, kernel_size=(1, 1), kernel_initializer=glorot_uniform(seed=0))(x)
    x = BatchNormalization(axis=3)(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(2)(x)
    print(x.shape, 'dernière reduc')

    # Bottleneck
    x = Conv2D(input_latent, 2, padding='same', use_bias=False)(x)
    x = Flatten()(x)
    print(x.shape)

    # Latent space
    latent_space_layer = Dense(input_latent, activation='relu', use_bias=False)(x)
    latent_space_layer_norm = BatchNormalization(name='latent_space_layer_norm')(latent_space_layer)

    # Reconstruction
    reshape_layer = Reshape(target_shape=(1, 1, input_latent))(latent_space_layer_norm)
    x_recon = Conv2DTranspose(input_latent, 3, strides=1, padding='same', use_bias=False)(reshape_layer)

    print(x_recon.shape)

    # Decoder entry: two un-padded 3x3 transposed convolutions
    x_recon = Conv2DTranspose(filters=2048, kernel_size=(3,3), strides=1)(x_recon)
    x_recon = BatchNormalization(axis=3)(x_recon)
    x_recon = Activation('relu')(x_recon)
    x_recon = Conv2DTranspose(filters=1024, kernel_size=(3, 3), strides=1)(x_recon)
    x_recon = BatchNormalization(axis=3)(x_recon)
    x_recon = Activation('relu')(x_recon)

    for n_filters in reversed(filters):
        x_recon = upsampling(inputs=x_recon, filters=(n_filters*2, n_filters*2, n_filters), cardinality=128, strides=(2, 2), stage=5, block="a")
        print(x_recon.shape)
        # (identity decoder blocks after each upsampling block are currently disabled)

    # One extra upsampling block with the smallest filter count. This used to
    # rely on the loop variable leaking out of the loop above (its last value
    # is filters[0]); the dependency is now explicit.
    last_filters = filters[0]
    x_recon = upsampling(inputs=x_recon, filters=(last_filters*2, last_filters*2, last_filters), cardinality=128, strides=(2, 2), stage=5, block="a")
    print(x_recon.shape)
    # Final Convolution to reconstruct the image
    x_recon = Conv2DTranspose(1, (1, 1), padding='same', activation='sigmoid')(x_recon)
    x_recon = tf.keras.layers.Resizing(height=input_shape[0], width=input_shape[1], name='recon_image')(x_recon)

    # Create the model
    model = Model(inputs=x_input, outputs=x_recon, name="resnext50_autoencoder")

    lr_schedule = ExponentialDecay(initial_learning_rate=0.001, decay_steps=10000, decay_rate=0.9, staircase=True)
    optimizer = SGD(learning_rate=lr_schedule, momentum=0.9)
    model.compile(optimizer=optimizer, loss=MeanSquaredError())
    return model

In [None]:
# Build the ResNeXt-style autoencoder for 300x300 single-channel images with
# a 512-unit latent layer. This rebinds `model`, replacing the Sequential
# autoencoder defined in an earlier cell.
model = ResNeXt50_AutoEncoder(input_shape=(300, 300, 1), input_latent=512)

In [None]:
# Data preparation: select the images of one group.
# In each entry returned by load_data, index 2 is samples_train and index 4
# is samples_val.
ethnie = 'caucasians'
x_train = ethnies[ethnie][2]
x_val = ethnies[ethnie][4]
# Shuffle a copy of the validation images (x_val itself is left untouched)
# and keep at most 10 of them for the reconstruction previews.
x_print = x_val.copy()
np.random.shuffle(x_print)
x_print = x_print[:10]

In [None]:
# Custom Callback Definition
class CustomCallback(tf.keras.callbacks.Callback):
    """Periodically print the training loss and save reconstruction previews.

    Args:
        test_data: Images passed to ``model.predict`` to produce the previews.
        display_loss_interval: The loss is printed when the 0-based epoch
            index is a multiple of this value.
        display_recon_interval: A preview figure is saved when the 0-based
            epoch index is a multiple of this value.
    """

    def __init__(self, test_data, display_loss_interval=10, display_recon_interval=20):
        super().__init__()
        self.test_data = test_data
        self.display_loss_interval = display_loss_interval
        self.display_recon_interval = display_recon_interval

    def on_epoch_end(self, epoch, logs=None):
        """Save a preview and/or print the loss; epochs are reported 1-based."""
        # `logs` defaults to None; fall back to an empty dict so the lookup
        # below cannot fail with a TypeError.
        logs = logs or {}
        if epoch % self.display_recon_interval == 0:
            reconstructions = self.model.predict(self.test_data, verbose=0)
            self.display_reconstruction(epoch+1, self.test_data, reconstructions)
        if epoch % self.display_loss_interval == 0 and 'loss' in logs:
            print(f"Epoch {epoch+1}, Loss: {logs['loss']:.4g}")

    def display_reconstruction(self, epoch, originals, reconstructions):
        """Save originals (top row) and reconstructions (bottom row) to a PNG.

        At most 10 image pairs are drawn. Nothing is written when there are
        no images. The figure goes to ``Results/MSE/Resnext2_{epoch}.png``;
        the directory is created if it does not exist.
        """
        # Never index past the available images.
        n = min(10, len(originals), len(reconstructions))
        if n == 0:
            return

        output_path = f'Results/MSE/Resnext2_{epoch}.png'
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        plt.figure(figsize=(20, 4))
        for i in range(n):
            # Use each image's own height/width instead of a fixed 300x300,
            # dropping the trailing single-channel axis.
            original = originals[i]
            reconstruction = reconstructions[i]

            # Display original
            plt.subplot(2, n, i + 1)
            plt.imshow(original.reshape(original.shape[:2]), cmap='gray')
            plt.title("Original")
            plt.axis('off')

            # Display reconstruction
            plt.subplot(2, n, i + 1 + n)
            plt.imshow(reconstruction.reshape(reconstruction.shape[:2]), cmap='gray')
            plt.title("Reconstructed")
            plt.axis('off')
        plt.suptitle(f'Epoch {epoch}')
        plt.savefig(output_path)
        plt.close()

# Define the model checkpoint callback: keep only the weights with the lowest
# validation loss.
# save_freq must be 'epoch' here. An integer save_freq is counted in batches,
# and `val_loss` is only available at the end of an epoch, so combined with
# save_best_only=True an integer value would never save anything.
checkpoint_callback = ModelCheckpoint(
    filepath='Results/weights/Resnext_best_weights.h5',
    monitor='val_loss',
    save_best_only=True,
    mode='min',
    save_weights_only=True,
    save_freq='epoch',
)

# Model Training
# batch_size must be greater than 1: the model applies BatchNormalization to
# its flattened latent vector, which normalises across the batch axis. With a
# single sample per batch the variance is zero and the latent output no longer
# depends on the input during training. Lower this value only as far as
# memory requires, and keep it above 1.
model.fit(x_train, x_train, epochs=1, shuffle=True, batch_size=4, validation_data=(x_val, x_val), callbacks=[CustomCallback(x_print), checkpoint_callback], verbose=1)