In [1]:
import sys
import os
import matplotlib.pyplot as plt
import numpy as np
import gc
import importlib

# Add the 'config' and 'src' folders to the import path.
# NOTE: these paths are relative to the current working directory (the
# notebook's folder), not absolute paths.
sys.path.append('../config')
sys.path.append('../src')
# Import the project configuration module
import config
# Reload it so edits made to the config module are picked up without
# restarting the kernel
importlib.reload(config)
# Configuration mapping used throughout this notebook (indexed by key, e.g. cfg['TRAIN_DIR'])
cfg = config.cfg


# Importer
from config import cfg
from model_unet import build_model_UNET
from common_utils import load_normalize_images, load_rgb2mask_labels

from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint


import tensorflow as tf


def display_loss_acc(res_list):
    """Plot the loss and accuracy curves of one or more training runs.

    Params:
        res_list: list of dicts such as the one returned by ``train_model``,
            each with a ``'loss'`` and an ``'accuracy'`` entry. The items of
            those entries may be scalars or lists of values (``train_model``
            stores each ``History.history`` list as one item); they are
            flattened and concatenated across all dicts, in order, before
            plotting. An empty ``res_list`` produces empty plots.
    """
    loss = []
    accuracy = []

    # Concatenate the metrics of every run (previously only the last run was
    # plotted, and an empty list raised NameError on the leaked loop variable).
    for res in res_list:
        # np.ravel turns a scalar into a 1-element array and a list into a
        # flat array, so both storage formats yield a 1-D curve.
        for value in res['loss']:
            loss.extend(np.ravel(value).tolist())
        for value in res['accuracy']:
            accuracy.extend(np.ravel(value).tolist())

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.plot(loss)
    plt.title('Perte (Loss)')
    plt.xlabel('Époque')
    plt.ylabel('Perte')

    plt.subplot(1, 2, 2)
    plt.plot(accuracy)
    plt.title('Précision (Accuracy)')
    plt.xlabel('Époque')
    plt.ylabel('Précision')
    plt.show()

def _count_files(directory):
    """Return the number of regular files (not sub-directories) directly inside *directory*."""
    return sum(
        1 for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
    )


def _load_slice(directory, start, end):
    """Load items [start, end) of *directory*: normalized images and 3-class one-hot masks, both float32 arrays."""
    images = np.asarray(load_normalize_images(directory, start, end), dtype=np.float32)
    labels = to_categorical(np.array(load_rgb2mask_labels(directory, start, end)), num_classes=3)
    return images, np.asarray(labels, dtype=np.float32)


def train_model(model):
    """Train a UNET model slice by slice over the training/validation folders.

    For every epoch, the training images are loaded in consecutive slices of
    ``cfg['NB_IMG_BY_TRAIN_FIT']`` images. Each slice is fed to ``model.fit``
    for a single epoch, together with a proportional slice
    (``cfg['SPLIT_VALID']`` percent of the slice size) of the validation set.

    Params:
        model: Keras UNET model (e.g. from ``build_model_UNET``). It must
            report an ``accuracy`` metric, since
            ``result.history['accuracy']`` is read after every fit.

    Returns:
        A one-element list ``[{'loss': [...], 'accuracy': [...]}]``. Each
        list holds one ``History.history`` entry (a one-value list, since
        ``epochs=1``) per slice trained, in training order.

    Raises:
        ValueError: if the training image directory contains no files.
    """
    # NOTE(review): every model.fit call below runs a single epoch, and
    # EarlyStopping resets its wait counter at the start of each fit, so
    # patience=5 can never trigger. ModelCheckpoint's "best" val_loss is also
    # compared across *different* validation slices. Confirm this is intended.
    callbacks = [
        EarlyStopping(patience=5, verbose=1),
        ModelCheckpoint('../model/' + cfg['MODEL_NAME'], verbose=1, save_best_only=True),
    ]

    slice_size = cfg['NB_IMG_BY_TRAIN_FIT']  # number of training images per fit call
    num_epochs = cfg['UNET_EPOCHS']
    validation_split = cfg['SPLIT_VALID'] / 100

    train_image_dir = cfg['TRAIN_DIR'] + cfg['OUTPUT_DIR_IMAGE']
    total_train_images = _count_files(train_image_dir)
    total_val_images = _count_files(cfg['VALID_DIR'] + cfg['OUTPUT_DIR_IMAGE'])
    if total_train_images == 0:
        # Previously this crashed later with UnboundLocalError on `result`.
        raise ValueError(f"No training images found in {train_image_dir}")

    # Number of validation images paired with each training slice
    val_slice_size = int(slice_size * validation_split)

    epoch_loss = []
    epoch_accuracy = []

    for epoch in range(num_epochs):
        for slice_index, start in enumerate(range(0, total_train_images, slice_size)):
            end = min(start + slice_size, total_train_images)
            print(f"Epoch {epoch+1}/{num_epochs} - tranche {start+1}-{end}/{total_train_images}")

            train_images, train_labels = _load_slice(cfg['TRAIN_DIR'], start, end)

            # Matching validation slice. Skip validation when it would be
            # empty (split too small, or validation set exhausted) instead of
            # passing empty arrays to fit.
            val_start_index = slice_index * val_slice_size
            val_end_index = min(val_start_index + val_slice_size, total_val_images)
            validation_data = None
            if val_end_index > val_start_index:
                validation_data = _load_slice(cfg['VALID_DIR'], val_start_index, val_end_index)

            result = model.fit(
                train_images, train_labels,
                batch_size=cfg['UNET_BATCH_SIZE'],
                epochs=1,
                verbose=1,
                callbacks=callbacks,
                validation_data=validation_data,
            )

            # Record the metrics of *every* slice (previously only the last
            # slice of each epoch was kept).
            epoch_loss.append(result.history['loss'])
            epoch_accuracy.append(result.history['accuracy'])

            # Release this slice's arrays before the next slice is loaded,
            # to keep peak memory down.
            del train_images, train_labels, validation_data, result
            gc.collect()

        # NOTE(review): clear_session resets Keras' global state while `model`
        # is still being trained. Kept from the original, but confirm it is needed.
        tf.keras.backend.clear_session()
        gc.collect()

    results = [{'loss': epoch_loss, 'accuracy': epoch_accuracy}]
    tf.keras.backend.clear_session()
    gc.collect()

    return results



# Free memory left over from earlier runs of this notebook *before* building
# the model. tf.keras.backend.clear_session() resets Keras' global state and
# is meant to be called between model creations, not right after creating
# the model that is about to be trained.
tf.keras.backend.clear_session()
gc.collect()

# Build the 2D UNET model
model = build_model_UNET()

# Train the model, then plot its loss/accuracy curves
results = train_model(model)

display_loss_acc(results)


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 240, 240, 1  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 240, 240, 64  640         ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 tf.nn.relu (TFOpLambda)        (None, 240, 240, 64  0           ['conv2d[0][0]']                 
                                )                                                             

KeyboardInterrupt: 