In [None]:
# importation de bibliothèques:
import pickle
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist, cifar10
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, LSTM, BatchNormalization, MaxPooling2D
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.metrics import Accuracy, Precision, Recall, AUC

# Installation du drive
from google.colab import drive
# Mount Google Drive into the Colab filesystem at /content/drive.
drive.mount('/content/drive')
# Destination folder in Google Drive for the pickled training results.
# The trailing '/' matters: file names are appended to it with '+'.
base_path = '/content/drive/My Drive/initiation_recherche/final/'

# Fonctions personnelles de sauvegarde de resultats

def save_with_pickle(obj, filename):
    """Serialize ``obj`` with pickle into ``filename``, overwriting any existing file."""
    serialized = pickle.dumps(obj)
    with open(filename, 'wb') as handle:
        handle.write(serialized)

def load_pkl(filename):
    """Load and return the object pickled in ``filename``.

    Security: unpickling can execute arbitrary code, so only load files this
    notebook produced itself.
    """
    with open(filename, 'rb') as handle:
        raw = handle.read()
    return pickle.loads(raw)

Mounted at /content/drive


In [None]:


# Download MNIST and CIFAR-10 as (train, test) splits of images and integer labels.
(mnist_train_images, mnist_train_labels), (mnist_test_images, mnist_test_labels) = mnist.load_data()
(cifar_train_images, cifar_train_labels), (cifar_test_images, cifar_test_labels) = cifar10.load_data()

# Convert every image array to float32 and divide by 255 so pixels lie in [0, 1].
mnist_train_images, mnist_test_images, cifar_train_images, cifar_test_images = (
    images.astype('float32') / 255.0
    for images in (mnist_train_images, mnist_test_images, cifar_train_images, cifar_test_images)
)

# CNN inputs: MNIST gets an explicit single-channel axis -> (N, 28, 28, 1).
mnist_train_images_cnn, mnist_test_images_cnn = (
    images.reshape((-1, 28, 28, 1)) for images in (mnist_train_images, mnist_test_images)
)

# CNN inputs for CIFAR-10, laid out as (N, 32, 32, 3).
cifar_train_images_cnn, cifar_test_images_cnn = (
    images.reshape((-1, 32, 32, 3)) for images in (cifar_train_images, cifar_test_images)
)

# RNN inputs for CIFAR-10: each of the 32 image rows becomes one time step of
# 32 pixels x 3 channels = 96 features -> (N, 32, 96).
cifar_train_images_rnn, cifar_test_images_rnn = (
    images.reshape((-1, 32, 32 * 3)) for images in (cifar_train_images, cifar_test_images)
)

# One-hot encode every label array over the 10 classes.
mnist_train_labels, mnist_test_labels, cifar_train_labels, cifar_test_labels = (
    to_categorical(labels, 10)
    for labels in (mnist_train_labels, mnist_test_labels, cifar_train_labels, cifar_test_labels)
)


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


## modèles CNN

In [None]:
# CNN classifiers. MNIST and CIFAR-10 use exactly the same architecture
# (three 3x3 convolutions with 256 filters, batch normalization, two max-pooling
# stages, then a small dense head); only the input shape differs.


def build_cnn(input_shape, num_classes=10):
    """Build the (uncompiled) CNN classifier.

    Args:
        input_shape: shape of one input image, (height, width, channels).
        num_classes: number of output classes (softmax units). Defaults to 10.

    Returns:
        A ``Sequential`` Keras model.
    """
    return Sequential([
        Conv2D(256, (3, 3), activation='relu', input_shape=input_shape),
        BatchNormalization(),
        MaxPooling2D(2, 2),
        Conv2D(256, (3, 3), activation='relu'),
        BatchNormalization(),
        Conv2D(256, (3, 3), activation='relu'),
        BatchNormalization(),
        MaxPooling2D(2, 2),
        Flatten(),
        Dense(128, activation='relu'),
        BatchNormalization(),
        Dense(64, activation='relu'),
        BatchNormalization(),
        Dense(num_classes, activation='softmax'),
    ])


# MNIST: 28x28 images with a single channel (matches mnist_*_images_cnn).
model_cnn_mnist = build_cnn((28, 28, 1))

# CIFAR-10: 32x32 images with 3 channels (matches cifar_*_images_cnn).
model_cnn_cifar = build_cnn((32, 32, 3))





## modèles RNN

In [None]:
# LSTM classifiers: each image is read as a sequence of rows. MNIST and CIFAR-10
# share the same architecture; only the (time steps, features) input shape differs.


def build_lstm(input_shape, num_classes=10):
    """Build the (uncompiled) two-layer LSTM classifier.

    Args:
        input_shape: (time_steps, features_per_step) of one input sequence.
        num_classes: number of output classes (softmax units). Defaults to 10.

    Returns:
        A ``Sequential`` Keras model.
    """
    return Sequential([
        LSTM(512, return_sequences=True, input_shape=input_shape),  # 512 units, full sequence out
        BatchNormalization(),
        LSTM(256),  # 256 units, last state only
        BatchNormalization(),
        Dense(256, activation='relu'),
        BatchNormalization(),
        Dense(128, activation='relu'),
        BatchNormalization(),
        Dense(num_classes, activation='softmax'),
    ])


# MNIST: 28 rows of 28 pixels each.
model_rnn_mnist = build_lstm((28, 28))

# CIFAR-10: 32 rows, each flattened to 32 pixels x 3 channels = 96 features.
model_rnn_cifar = build_lstm((32, 96))


## modèles DNN

In [None]:
# Fully connected (DNN) classifiers: flatten the image, then a stack of
# Dense(relu) + BatchNormalization pairs, then a 10-way softmax output.
# The MNIST network is deeper (five 512-unit layers instead of two).
model_dnn_mnist = Sequential(
    [Flatten(input_shape=(28, 28))]
    + [layer
       for units in [512] * 5 + [256] * 2 + [128] * 3 + [64]
       for layer in (Dense(units, activation='relu'), BatchNormalization())]
    + [Dense(10, activation='softmax')]
)

# CIFAR-10 variant: input is a 32x32 image with 3 channels.
model_dnn_cifar = Sequential(
    [Flatten(input_shape=(32, 32, 3))]
    + [layer
       for units in [512] * 2 + [256] * 2 + [128] * 3 + [64]
       for layer in (Dense(units, activation='relu'), BatchNormalization())]
    + [Dense(10, activation='softmax')]
)



In [None]:
# Group the models per dataset, then compile all six identically:
# Adam optimizer, categorical cross-entropy (labels are one-hot), accuracy metric.
models_mnist = [model_cnn_mnist, model_rnn_mnist, model_dnn_mnist]
models_cifar = [model_cnn_cifar, model_rnn_cifar, model_dnn_cifar]

for model in models_mnist + models_cifar:
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])


In [None]:
# MNIST: print the layer-by-layer summary of each architecture.
for heading, net in (("Résumé du modèle CNN mnist:", model_cnn_mnist),
                     ("\nRésumé du modèle RNN mnist:", model_rnn_mnist),
                     ("\nRésumé du modèle DNN mnist:", model_dnn_mnist)):
    print(heading)
    net.summary()


Résumé du modèle CNN mnist:
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 26, 26, 256)       2560      
                                                                 
 batch_normalization (Batch  (None, 26, 26, 256)       1024      
 Normalization)                                                  
                                                                 
 max_pooling2d (MaxPooling2  (None, 13, 13, 256)       0         
 D)                                                              
                                                                 
 conv2d_1 (Conv2D)           (None, 11, 11, 256)       590080    
                                                                 
 batch_normalization_1 (Bat  (None, 11, 11, 256)       1024      
 chNormalization)                                                
                            

In [None]:
# CIFAR-10: print the layer-by-layer summary of each architecture.
for heading, net in (("Résumé du modèle CNN cifar:", model_cnn_cifar),
                     ("\nRésumé du modèle RNN cifar:", model_rnn_cifar),
                     ("\nRésumé du modèle DNN cifar:", model_dnn_cifar)):
    print(heading)
    net.summary()


Résumé du modèle CNN cifar:
Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_3 (Conv2D)           (None, 30, 30, 256)       7168      
                                                                 
 batch_normalization_5 (Bat  (None, 30, 30, 256)       1024      
 chNormalization)                                                
                                                                 
 max_pooling2d_2 (MaxPoolin  (None, 15, 15, 256)       0         
 g2D)                                                            
                                                                 
 conv2d_4 (Conv2D)           (None, 13, 13, 256)       590080    
                                                                 
 batch_normalization_6 (Bat  (None, 13, 13, 256)       1024      
 chNormalization)                                                
                          

In [None]:
# Train the CNN on MNIST, time the run, then persist the per-epoch metrics.
start_time = time.perf_counter()  # monotonic clock, unaffected by system clock changes

history_cnn_mnist = model_cnn_mnist.fit(mnist_train_images_cnn,
                                        mnist_train_labels,
                                        epochs=10,
                                        batch_size=64,
                                        validation_data=(mnist_test_images_cnn, mnist_test_labels))

training_time_cnn_mnist = time.perf_counter() - start_time
print(f'Temps d\'entraînement : {training_time_cnn_mnist:.2f} secondes')

# Save only the metrics dict (loss / accuracy / val_loss / val_accuracy per epoch).
# The History object also references the trained model (history.model), so pickling
# it would serialize the whole network as well.
save_with_pickle(history_cnn_mnist.history, base_path + 'history_cnn_mnist_accuracy.pkl')


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Temps d'entraînement : 89.96 secondes


In [None]:
# Train the CNN on CIFAR-10, time the run, then persist the per-epoch metrics.
start_time = time.perf_counter()  # monotonic clock, unaffected by system clock changes

history_cnn_cifar = model_cnn_cifar.fit(cifar_train_images_cnn,
                                        cifar_train_labels,
                                        epochs=10,
                                        batch_size=64,
                                        validation_data=(cifar_test_images_cnn, cifar_test_labels))

training_time_cnn_cifar = time.perf_counter() - start_time
print(f'Temps d\'entraînement : {training_time_cnn_cifar:.2f} secondes')

# Save only the metrics dict; the History object also references the model
# (history.model), which would otherwise be pickled along with it.
save_with_pickle(history_cnn_cifar.history, base_path + 'history_cnn_cifar_accuracy.pkl')


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Temps d'entraînement : 85.20 secondes


In [None]:
# Training curves of the CNN on MNIST: loss (left) and accuracy (right), train vs.
# validation. Keras logs epochs as 1..N, so the x-axis uses the same numbering.
curves = history_cnn_mnist.history
epochs = range(1, len(curves['loss']) + 1)

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

ax_loss.plot(epochs, curves['loss'], label='Train Loss')
ax_loss.plot(epochs, curves['val_loss'], label='Validation Loss')
ax_loss.set_title('CNN MNIST - Loss Over Epochs')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()

ax_acc.plot(epochs, curves['accuracy'], label='Train Accuracy')
ax_acc.plot(epochs, curves['val_accuracy'], label='Validation Accuracy')
ax_acc.set_title('CNN MNIST - Accuracy Over Epochs')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy')
ax_acc.legend()

plt.show()

In [None]:
# Training curves of the CNN on CIFAR-10: loss (left) and accuracy (right), train vs.
# validation. Keras logs epochs as 1..N, so the x-axis uses the same numbering.
curves = history_cnn_cifar.history
epochs = range(1, len(curves['loss']) + 1)

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

ax_loss.plot(epochs, curves['loss'], label='Train Loss')
ax_loss.plot(epochs, curves['val_loss'], label='Validation Loss')
ax_loss.set_title('CNN CIFAR-10 - Loss Over Epochs')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()

ax_acc.plot(epochs, curves['accuracy'], label='Train Accuracy')
ax_acc.plot(epochs, curves['val_accuracy'], label='Validation Accuracy')
ax_acc.set_title('CNN CIFAR-10 - Accuracy Over Epochs')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy')
ax_acc.legend()

plt.show()

In [None]:
# Train the LSTM on MNIST (each image is a sequence of 28 rows), time the run,
# then persist the per-epoch metrics.
start_time = time.perf_counter()  # monotonic clock, unaffected by system clock changes

history_rnn_mnist = model_rnn_mnist.fit(mnist_train_images,
                                        mnist_train_labels,
                                        epochs=10,
                                        batch_size=64,
                                        validation_data=(mnist_test_images, mnist_test_labels))

training_time_rnn_mnist = time.perf_counter() - start_time
print(f'Temps d\'entraînement : {training_time_rnn_mnist:.2f} secondes')

# Save only the metrics dict; the History object also references the model
# (history.model), which would otherwise be pickled along with it.
save_with_pickle(history_rnn_mnist.history, base_path + 'history_rnn_mnist_accuracy.pkl')



Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Temps d'entraînement : 106.53 secondes


In [None]:
# Train the LSTM on CIFAR-10 (32 time steps of 96 features), time the run,
# then persist the per-epoch metrics.
start_time = time.perf_counter()  # monotonic clock, unaffected by system clock changes

history_rnn_cifar = model_rnn_cifar.fit(cifar_train_images_rnn,
                                        cifar_train_labels,
                                        epochs=10,
                                        batch_size=64,
                                        validation_data=(cifar_test_images_rnn, cifar_test_labels))

training_time_rnn_cifar = time.perf_counter() - start_time
print(f'Temps d\'entraînement : {training_time_rnn_cifar:.2f} secondes')

# Save only the metrics dict; the History object also references the model
# (history.model), which would otherwise be pickled along with it.
save_with_pickle(history_rnn_cifar.history, base_path + 'history_rnn_cifar_accuracy.pkl')


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Temps d'entraînement : 97.75 secondes


In [None]:
# Training curves of the LSTM on MNIST: loss (left) and accuracy (right), train vs.
# validation. Keras logs epochs as 1..N, so the x-axis uses the same numbering.
curves = history_rnn_mnist.history
epochs = range(1, len(curves['loss']) + 1)

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

ax_loss.plot(epochs, curves['loss'], label='Train Loss')
ax_loss.plot(epochs, curves['val_loss'], label='Validation Loss')
ax_loss.set_title('RNN MNIST - Loss Over Epochs')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()

ax_acc.plot(epochs, curves['accuracy'], label='Train Accuracy')
ax_acc.plot(epochs, curves['val_accuracy'], label='Validation Accuracy')
ax_acc.set_title('RNN MNIST - Accuracy Over Epochs')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy')
ax_acc.legend()

plt.show()

In [None]:
# Training curves of the LSTM on CIFAR-10: loss (left) and accuracy (right), train vs.
# validation. Keras logs epochs as 1..N, so the x-axis uses the same numbering.
curves = history_rnn_cifar.history
epochs = range(1, len(curves['loss']) + 1)

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

ax_loss.plot(epochs, curves['loss'], label='Train Loss')
ax_loss.plot(epochs, curves['val_loss'], label='Validation Loss')
ax_loss.set_title('RNN CIFAR-10 - Loss Over Epochs')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()

ax_acc.plot(epochs, curves['accuracy'], label='Train Accuracy')
ax_acc.plot(epochs, curves['val_accuracy'], label='Validation Accuracy')
ax_acc.set_title('RNN CIFAR-10 - Accuracy Over Epochs')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy')
ax_acc.legend()

plt.show()

In [None]:
# Train the fully connected network on MNIST, time the run, then persist the
# per-epoch metrics.
start_time = time.perf_counter()  # monotonic clock, unaffected by system clock changes

history_dnn_mnist = model_dnn_mnist.fit(mnist_train_images,
                                        mnist_train_labels,
                                        epochs=10,
                                        batch_size=64,
                                        validation_data=(mnist_test_images, mnist_test_labels))

training_time_dnn_mnist = time.perf_counter() - start_time
print(f'Temps d\'entraînement : {training_time_dnn_mnist:.2f} secondes')

# Save only the metrics dict; the History object also references the model
# (history.model), which would otherwise be pickled along with it.
save_with_pickle(history_dnn_mnist.history, base_path + 'history_dnn_mnist_accuracy.pkl')


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Temps d'entraînement : 126.46 secondes


In [None]:
# Train the fully connected network on CIFAR-10, time the run, then persist the
# per-epoch metrics.
start_time = time.perf_counter()  # monotonic clock, unaffected by system clock changes

history_dnn_cifar = model_dnn_cifar.fit(cifar_train_images,
                                        cifar_train_labels,
                                        epochs=10,
                                        batch_size=64,
                                        validation_data=(cifar_test_images, cifar_test_labels))

training_time_dnn_cifar = time.perf_counter() - start_time
print(f'Temps d\'entraînement : {training_time_dnn_cifar:.2f} secondes')

# Save only the metrics dict; the History object also references the model
# (history.model), which would otherwise be pickled along with it.
save_with_pickle(history_dnn_cifar.history, base_path + 'history_dnn_cifar_accuracy.pkl')



Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Temps d'entraînement : 85.94 secondes


In [None]:
# Training curves of the DNN on MNIST: loss (left) and accuracy (right), train vs.
# validation. Keras logs epochs as 1..N, so the x-axis uses the same numbering.
curves = history_dnn_mnist.history
epochs = range(1, len(curves['loss']) + 1)

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

ax_loss.plot(epochs, curves['loss'], label='Train Loss')
ax_loss.plot(epochs, curves['val_loss'], label='Validation Loss')
ax_loss.set_title('DNN MNIST - Loss Over Epochs')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()

ax_acc.plot(epochs, curves['accuracy'], label='Train Accuracy')
ax_acc.plot(epochs, curves['val_accuracy'], label='Validation Accuracy')
ax_acc.set_title('DNN MNIST - Accuracy Over Epochs')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy')
ax_acc.legend()

plt.show()

In [None]:
# Training curves of the DNN on CIFAR-10: loss (left) and accuracy (right), train vs.
# validation. Keras logs epochs as 1..N, so the x-axis uses the same numbering.
curves = history_dnn_cifar.history
epochs = range(1, len(curves['loss']) + 1)

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

ax_loss.plot(epochs, curves['loss'], label='Train Loss')
ax_loss.plot(epochs, curves['val_loss'], label='Validation Loss')
ax_loss.set_title('DNN CIFAR-10 - Loss Over Epochs')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()

ax_acc.plot(epochs, curves['accuracy'], label='Train Accuracy')
ax_acc.plot(epochs, curves['val_accuracy'], label='Validation Accuracy')
ax_acc.set_title('DNN CIFAR-10 - Accuracy Over Epochs')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy')
ax_acc.legend()

plt.show()

In [None]:
# Rapports de classification

# fonction d'impression d e rapport de classification :
from sklearn.metrics import classification_report
import numpy as np

def print_classification_report(model, X_test, y_test, model_name):
    """Print scikit-learn's per-class precision / recall / F1 report for ``model``.

    Args:
        model: trained Keras model whose ``predict`` returns one score per class.
        X_test: test inputs, shaped exactly as ``model`` expects.
        y_test: true labels, either one-hot (n_samples, n_classes) or integer
            class ids of shape (n_samples,) / (n_samples, 1).
        model_name: label printed in the report header.
    """
    predictions = np.argmax(model.predict(X_test), axis=1)
    y_true = np.asarray(y_test)
    if y_true.ndim > 1 and y_true.shape[-1] > 1:
        y_true = np.argmax(y_true, axis=-1)  # one-hot -> class ids
    else:
        y_true = y_true.reshape(-1)  # already class ids
    report = classification_report(y_true, predictions)
    print(f"Rapport de classification pour {model_name}:\n")
    print(report)
    print("\n" + "-"*80 + "\n")


# Evaluate every model on its test split, feeding the same array layout it was
# trained on: CNNs get the channel-augmented *_cnn arrays, the CIFAR LSTM the
# (N, 32, 96) *_rnn sequences, and the MNIST LSTM and both DNNs the plain
# normalized images.
print_classification_report(model_cnn_mnist, mnist_test_images_cnn, mnist_test_labels, "CNN sur MNIST")
print_classification_report(model_cnn_cifar, cifar_test_images_cnn, cifar_test_labels, "CNN sur CIFAR-10")
print_classification_report(model_rnn_mnist, mnist_test_images, mnist_test_labels, "RNN sur MNIST")
print_classification_report(model_rnn_cifar, cifar_test_images_rnn, cifar_test_labels, "RNN sur CIFAR-10")
print_classification_report(model_dnn_mnist, mnist_test_images, mnist_test_labels, "DNN sur MNIST")
print_classification_report(model_dnn_cifar, cifar_test_images, cifar_test_labels, "DNN sur CIFAR-10")


# Comparatifs :

In [None]:
# comparatif :


import matplotlib.pyplot as plt

# Training histories of each model. They are in memory after the training cells;
# to compare without retraining, reload the pickles saved there instead, e.g.
# history_cnn_mnist = load_pkl(base_path + 'history_cnn_mnist_accuracy.pkl')


def plot_history(histories, title):
    """Overlay train and validation loss / accuracy curves of several models.

    Args:
        histories: iterable of ``(name, history)`` pairs. ``history`` may be the
            Keras ``History`` object returned by ``fit`` or its ``.history`` dict
            (e.g. reloaded with ``load_pkl``); it must provide the keys 'loss',
            'val_loss', 'accuracy' and 'val_accuracy'.
        title: text appended to both subplot titles.
    """
    # Unwrap History objects to their metrics dict (plain dicts pass through), and
    # materialize the pairs so a one-shot iterable is not exhausted by the first subplot.
    curves = [(name, getattr(history, 'history', history)) for name, history in histories]

    plt.figure(figsize=(14, 5))

    # Loss curves. Keras logs epochs as 1..N, so the x-axis uses the same numbering.
    plt.subplot(1, 2, 1)
    for name, metrics in curves:
        epochs = range(1, len(metrics['loss']) + 1)
        plt.plot(epochs, metrics['loss'], label=f'{name} Train')
        plt.plot(epochs, metrics['val_loss'], label=f'{name} Validation')
    plt.title('Loss - ' + title)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    # Accuracy curves
    plt.subplot(1, 2, 2)
    for name, metrics in curves:
        epochs = range(1, len(metrics['accuracy']) + 1)
        plt.plot(epochs, metrics['accuracy'], label=f'{name} Train')
        plt.plot(epochs, metrics['val_accuracy'], label=f'{name} Validation')
    plt.title('Accuracy - ' + title)
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.show()


# Compare the three architectures separately on each dataset.
plot_history([('CNN', history_cnn_mnist), ('RNN', history_rnn_mnist), ('DNN', history_dnn_mnist)],
             'Model Comparison - MNIST')
plot_history([('CNN', history_cnn_cifar), ('RNN', history_rnn_cifar), ('DNN', history_dnn_cifar)],
             'Model Comparison - CIFAR-10')


In [None]:
# matrice de confusion

from sklearn.metrics import confusion_matrix
import seaborn as sns

def plot_confusion_matrix(model, X_test, y_test, title):
    predictions = model.predict(X_test)
    predictions = np.argmax(predictions, axis=1)
    y_test_argmax = np.argmax(y_test, axis=1) # Si les étiquettes sont en one-hot encoding
    cm = confusion_matrix(y_test_argmax, predictions)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt="d")
    plt.title('Confusion Matrix - ' + title)
    plt.ylabel('Actual Label')
    plt.xlabel('Predicted Label')
    plt.show()

# Exemple d'utilisation
# plot_confusion_matrix(model_cnn, mnist_test_images_cnn, mnist_test_labels, 'CNN')


In [None]:
# rapport de classification
from sklearn.metrics import classification_report

def print_classification_report(model, X_test, y_test, model_name=None):
    """Print scikit-learn's per-class precision / recall / F1 report for ``model``.

    This redefines ``print_classification_report``. ``model_name`` is optional so
    both the three-argument form and the four-argument form used earlier in this
    notebook keep working.

    Args:
        model: trained model whose ``predict`` returns one score per class.
        X_test: test inputs, shaped as ``model`` expects.
        y_test: true labels, one-hot (n_samples, n_classes) or integer class ids.
        model_name: optional label; when given, a header and a separator line are
            printed around the report.
    """
    predictions = np.argmax(model.predict(X_test), axis=1)
    y_true = np.asarray(y_test)
    if y_true.ndim > 1 and y_true.shape[-1] > 1:
        y_true = np.argmax(y_true, axis=-1)  # one-hot -> class ids
    else:
        y_true = y_true.reshape(-1)  # already class ids
    report = classification_report(y_true, predictions)
    if model_name is not None:
        print(f"Rapport de classification pour {model_name}:\n")
    print(report)
    if model_name is not None:
        print("\n" + "-"*80 + "\n")

# Example usage
# print_classification_report(model_cnn_mnist, mnist_test_images_cnn, mnist_test_labels)


In [None]:
# courbe roc et auc :
from sklearn.metrics import roc_curve, auc
from tensorflow.keras.utils import to_categorical

def plot_roc_curve(model, X_test, y_test, num_classes):
    # Prédictions et One-hot encoding si nécessaire
    y_pred = model.predict(X_test)
    y_test_cat = to_categorical(y_test, num_classes)

    # Calcul de la courbe ROC pour chaque classe
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(num_classes):
        fpr[i], tpr[i], _ = roc_curve(y_test_cat[:, i], y_pred[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Tracer toutes les courbes ROC
    plt.figure(figsize=(12, 8))
    for i in range(num_classes):
        plt.plot(fpr[i], tpr[i], label=f'Class {i} (area = {roc_auc[i]:0.2f})')
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve')
    plt.legend(loc="lower right")
    plt.show()

# Exemple d'utilisation pour une classification multi-classes
# plot_roc_curve(model_cnn, mnist_test_images_cnn, mnist_test_labels, num_classes=10)
