{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# CNN model improvement challenge\n",
        "\n",
        "## BTS SIO - Session 2: Network types and applications\n",
        "\n",
        "This notebook guides you through a challenge: improving a CNN model that classifies images of clothing (Fashion MNIST). You will start from a deliberately suboptimal baseline model and explore different strategies to improve its performance.\n",
        "\n",
        "### Learning objectives:\n",
        "- Diagnose the weaknesses of a deep learning model\n",
        "- Experiment with different architectures and hyperparameters\n",
        "- Apply optimization techniques (dropout, batch normalization, etc.)\n",
        "- Measure and compare improvements quantitatively\n",
        "- Document your changes and their impact methodically\n",
        "\n",
        "### Prerequisites:\n",
        "- Basic knowledge of TensorFlow/Keras\n",
        "- Understanding of how CNNs work\n",
        "- Completion of the first part of the CNN lab"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 1. Environment setup\n",
        "\n",
        "Let's start by importing the required libraries."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "from tensorflow.keras.models import Sequential, load_model\n",
        "from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization, Activation\n",
        "from tensorflow.keras.optimizers import Adam, RMSprop, SGD\n",
        "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
        "from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau\n",
        "from tensorflow.keras.datasets import fashion_mnist\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import pandas as pd\n",
        "import time\n",
        "import os\n",
        "import seaborn as sns\n",
        "from sklearn.metrics import confusion_matrix\n",
        "\n",
        "# Set the random seeds for reproducibility\n",
        "np.random.seed(42)\n",
        "tf.random.set_seed(42)\n",
        "\n",
        "# Check the TensorFlow version\n",
        "print(f\"TensorFlow version: {tf.__version__}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 2. Loading the Fashion MNIST dataset\n",
        "\n",
        "Fashion MNIST is designed as a drop-in replacement for the original MNIST dataset. It contains 70,000 grayscale 28x28 images in 10 classes (60,000 for training, 10,000 for testing). The images show items of clothing instead of handwritten digits. It is an excellent dataset for testing computer vision models."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "print(\"Loading the Fashion MNIST dataset...\")\n",
        "(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()\n",
        "\n",
        "# Add a channel dimension, (n, 28, 28) -> (n, 28, 28, 1), as expected by Conv2D,\n",
        "# and scale the pixel values from [0, 255] to [0, 1]\n",
        "x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.0\n",
        "x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.0\n",
        "\n",
        "# Class names for display (list index = label)\n",
        "class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',\n",
        "               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']\n",
        "\n",
        "print(f\"Training data shape: {x_train.shape}\")\n",
        "print(f\"Test data shape: {x_test.shape}\")\n",
        "print(f\"Number of classes: {len(class_names)}\")"
      ]
    },
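    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Here is a quick sanity check on the prepared data. This minimal sketch only uses the `x_train`, `y_train`, `x_test` and `y_test` arrays defined above. After scaling, the pixel values should lie in [0, 1]. `np.bincount` counts the images in each class. Fashion MNIST is balanced, so every class should appear 6,000 times in the training set and 1,000 times in the test set."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Pixel range after scaling (expected: 0.0 and 1.0)\n",
        "print(f\"Pixel range: [{x_train.min():.1f}, {x_train.max():.1f}]\")\n",
        "\n",
        "# Number of images per class (index = label, see class_names)\n",
        "print(\"Training images per class:\", np.bincount(y_train))\n",
        "print(\"Test images per class:\", np.bincount(y_test))\n",
        "\n",
        "# Fail loudly if the preprocessing did not produce the expected data\n",
        "assert x_train.shape[1:] == (28, 28, 1)\n",
        "assert 0.0 <= x_train.min() and x_train.max() <= 1.0"
      ]
    },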
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Visualizing a few examples\n",
        "\n",
        "Let's look at a few images from the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "plt.figure(figsize=(10, 10))\n",
        "for i in range(25):\n",
        "    plt.subplot(5, 5, i+1)\n",
        "    plt.xticks([])\n",
        "    plt.yticks([])\n",
        "    plt.grid(False)\n",
        "    plt.imshow(x_train[i].reshape(28, 28), cmap='gray')\n",
        "    plt.xlabel(class_names[y_train[i]])\n",
        "plt.tight_layout()\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 3. Results dashboard\n",
        "\n",
        "Let's create a class that records and compares the performance of the different models we are going to test."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "class ModelImprovementDashboard:\n",
        "    \"\"\"Tracks and displays the results of the successive improvements\"\"\"\n",
        "    \n",
        "    def __init__(self):\n",
        "        self.results = []\n",
        "    \n",
        "    def add_result(self, model_name, metrics, notes=\"\"):\n",
        "        \"\"\"Adds a result to the dashboard\"\"\"\n",
        "        result = {\n",
        "            'model_name': model_name,\n",
        "            'accuracy': metrics['test_accuracy'],\n",
        "            'loss': metrics['test_loss'],\n",
        "            'training_time': metrics['training_time'],\n",
        "            'epochs': metrics['epochs_completed'],\n",
        "            'notes': notes\n",
        "        }\n",
        "        self.results.append(result)\n",
        "    \n",
        "    def show_results(self):\n",
        "        \"\"\"Prints a comparison table of the results\"\"\"\n",
        "        if not self.results:\n",
        "            print(\"No results to display.\")\n",
        "            return\n",
        "        \n",
        "        # Build a DataFrame (self.results keeps the raw numeric values)\n",
        "        df = pd.DataFrame(self.results)\n",
        "        \n",
        "        # Sort by accuracy (descending)\n",
        "        df = df.sort_values(by='accuracy', ascending=False)\n",
        "        \n",
        "        # Format the columns for display\n",
        "        df['accuracy'] = df['accuracy'].apply(lambda x: f\"{x:.2f}%\")\n",
        "        df['loss'] = df['loss'].apply(lambda x: f\"{x:.4f}\")\n",
        "        df['training_time'] = df['training_time'].apply(lambda x: f\"{x:.2f}s\")\n",
        "        \n",
        "        print(\"\\n=== MODEL COMPARISON TABLE ===\")\n",
        "        print(df)\n",
        "        \n",
        "        return df\n",
        "    \n",
        "    def plot_comparison(self):\n",
        "        \"\"\"Plots a side-by-side comparison of the models\"\"\"\n",
        "        if not self.results:\n",
        "            print(\"No results to display.\")\n",
        "            return\n",
        "        \n",
        "        # Prepare the data (self.results stores floats, not the formatted strings of show_results)\n",
        "        models = [r['model_name'] for r in self.results]\n",
        "        accuracies = [r['accuracy'] for r in self.results]\n",
        "        times = [r['training_time'] for r in self.results]\n",
        "        \n",
        "        # Create the figure\n",
        "        plt.figure(figsize=(12, 6))\n",
        "        \n",
        "        # Accuracy chart\n",
        "        plt.subplot(1, 2, 1)\n",
        "        bars = plt.bar(models, accuracies, color='skyblue')\n",
        "        plt.title('Accuracy comparison')\n",
        "        plt.xlabel('Model')\n",
        "        plt.ylabel('Accuracy (%)')\n",
        "        plt.xticks(rotation=45, ha='right')\n",
        "        \n",
        "        # Add the values on top of the bars\n",
        "        for bar in bars:\n",
        "            height = bar.get_height()\n",
        "            plt.text(bar.get_x() + bar.get_width()/2., height,\n",
        "                     f'{height:.2f}%',\n",
        "                     ha='center', va='bottom')\n",
        "        \n",
        "        # Training time chart\n",
        "        plt.subplot(1, 2, 2)\n",
        "        bars = plt.bar(models, times, color='salmon')\n",
        "        plt.title('Training time comparison')\n",
        "        plt.xlabel('Model')\n",
        "        plt.ylabel('Time (seconds)')\n",
        "        plt.xticks(rotation=45, ha='right')\n",
        "        \n",
        "        # Add the values on top of the bars\n",
        "        for bar in bars:\n",
        "            height = bar.get_height()\n",
        "            plt.text(bar.get_x() + bar.get_width()/2., height,\n",
        "                     f'{height:.2f}s',\n",
        "                     ha='center', va='bottom')\n",
        "        \n",
        "        plt.tight_layout()\n",
        "        plt.show()\n",
        "\n",
        "# Initialize the dashboard\n",
        "dashboard = ModelImprovementDashboard()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 4. Model evaluation functions\n",
        "\n",
        "Let's define functions that train, evaluate and visualize every model in the same way."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def evaluate_model(model, x_train, y_train, x_test, y_test, epochs=5, batch_size=128, data_augmentation=False):\n",
        "    \"\"\"Trains and evaluates a model, and returns its performance metrics\"\"\"\n",
        "    \n",
        "    # Data augmentation setup (if enabled)\n",
        "    if data_augmentation:\n",
        "        train_datagen = ImageDataGenerator(\n",
        "            rotation_range=10,\n",
        "            width_shift_range=0.1,\n",
        "            height_shift_range=0.1,\n",
        "            zoom_range=0.1,\n",
        "        )\n",
        "        train_generator = train_datagen.flow(x_train, y_train, batch_size=batch_size)\n",
        "    \n",
        "    # Callbacks to improve training (only used for runs longer than 5 epochs)\n",
        "    callbacks = []\n",
        "    if epochs > 5:\n",
        "        callbacks = [\n",
        "            EarlyStopping(patience=5, restore_best_weights=True),\n",
        "            ReduceLROnPlateau(factor=0.2, patience=3, min_lr=0.0001)\n",
        "        ]\n",
        "    \n",
        "    # Measure the training time\n",
        "    start_time = time.time()\n",
        "    \n",
        "    # Train the model\n",
        "    # Note: the test set doubles as validation data, so the callbacks react to test\n",
        "    # performance; a separate validation split would give an unbiased final score\n",
        "    if data_augmentation:\n",
        "        history = model.fit(\n",
        "            train_generator,\n",
        "            epochs=epochs,\n",
        "            steps_per_epoch=len(x_train) // batch_size,\n",
        "            validation_data=(x_test, y_test),\n",
        "            callbacks=callbacks,\n",
        "            verbose=1\n",
        "        )\n",
        "    else:\n",
        "        history = model.fit(\n",
        "            x_train, y_train,\n",
        "            batch_size=batch_size,\n",
        "            epochs=epochs,\n",
        "            validation_data=(x_test, y_test),\n",
        "            callbacks=callbacks,\n",
        "            verbose=1\n",
        "        )\n",
        "    \n",
        "    training_time = time.time() - start_time\n",
        "    \n",
        "    # Evaluate the model\n",
        "    test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)\n",
        "    \n",
        "    # Collect the metrics\n",
        "    metrics = {\n",
        "        'test_accuracy': test_acc * 100,\n",
        "        'test_loss': test_loss,\n",
        "        'training_time': training_time,\n",
        "        'epochs_completed': len(history.history['loss']),\n",
        "        'history': history\n",
        "    }\n",
        "    \n",
        "    return metrics\n",
        "\n",
        "def plot_training_history(history):\n",
        "    \"\"\"Plots the training history\"\"\"\n",
        "    plt.figure(figsize=(12, 5))\n",
        "    \n",
        "    # Accuracy chart\n",
        "    plt.subplot(1, 2, 1)\n",
        "    plt.plot(history.history['accuracy'], label='Training')\n",
        "    plt.plot(history.history['val_accuracy'], label='Validation')\n",
        "    plt.title('Accuracy over epochs')\n",
        "    plt.xlabel('Epoch')\n",
        "    plt.ylabel('Accuracy')\n",
        "    plt.legend()\n",
        "    plt.grid(True, linestyle='--', alpha=0.6)\n",
        "    \n",
        "    # Loss chart\n",
        "    plt.subplot(1, 2, 2)\n",
        "    plt.plot(history.history['loss'], label='Training')\n",
        "    plt.plot(history.history['val_loss'], label='Validation')\n",
        "    plt.title('Loss over epochs')\n",
        "    plt.xlabel('Epoch')\n",
        "    plt.ylabel('Loss')\n",
        "    plt.legend()\n",
        "    plt.grid(True, linestyle='--', alpha=0.6)\n",
        "    \n",
        "    plt.tight_layout()\n",
        "    plt.show()\n",
        "\n",
        "def plot_confusion_matrix(model, x_test, y_test):\n",
        "    \"\"\"Plots the model's confusion matrix\"\"\"\n",
        "    # Predictions\n",
        "    y_pred = model.predict(x_test)\n",
        "    y_pred_classes = np.argmax(y_pred, axis=1)\n",
        "    \n",
        "    # Compute the confusion matrix\n",
        "    conf_mat = confusion_matrix(y_test, y_pred_classes)\n",
        "    \n",
        "    # Visualization\n",
        "    plt.figure(figsize=(10, 8))\n",
        "    sns.heatmap(conf_mat, annot=True, fmt='d', cmap='Blues',\n",
        "                xticklabels=class_names,\n",
        "                yticklabels=class_names)\n",
        "    plt.xlabel('Predicted')\n",
        "    plt.ylabel('True')\n",
        "    plt.title('Confusion matrix')\n",
        "    plt.xticks(rotation=45, ha='right')\n",
        "    plt.tight_layout()\n",
        "    plt.show()\n",
        "\n",
        "def show_misclassified_examples(model, x_test, y_test, n=10):\n",
        "    \"\"\"Displays examples of misclassified images\"\"\"\n",
        "    predictions = model.predict(x_test)\n",
        "    predicted_classes = np.argmax(predictions, axis=1)\n",
        "    \n",
        "    # Find the errors\n",
        "    errors = (predicted_classes != y_test)\n",
        "    error_indices = np.where(errors)[0]\n",
        "    \n",
        "    if len(error_indices) == 0:\n",
        "        print(\"No errors found!\")\n",
        "        return\n",
        "    \n",
        "    # Pick a random sample of errors\n",
        "    sample_size = min(n, len(error_indices))\n",
        "    sample_indices = np.random.choice(error_indices, size=sample_size, replace=False)\n",
        "    \n",
        "    # Display the examples, 5 per row (ceiling division gives the number of rows)\n",
        "    n_rows = (sample_size + 4) // 5\n",
        "    plt.figure(figsize=(15, 3 * n_rows))\n",
        "    for i, idx in enumerate(sample_indices):\n",
        "        plt.subplot(n_rows, 5, i+1)\n",
        "        plt.imshow(x_test[idx].reshape(28, 28), cmap='gray')\n",
        "        plt.title(f\"True: {class_names[y_test[idx]]}\\nPredicted: {class_names[predicted_classes[idx]]}\")\n",
        "        plt.axis('off')\n",
        "    plt.tight_layout()\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 5. Baseline model (suboptimal)\n",
        "\n",
        "Let's start by creating and evaluating a basic, deliberately suboptimal CNN model. It will serve as the reference point for our improvements."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def create_baseline_model():\n",
        "    \"\"\"Creates a deliberately underperforming baseline CNN model\"\"\"\n",
        "    model = Sequential([\n",
        "        Conv2D(8, (3, 3), activation='relu', input_shape=(28, 28, 1)),\n",
        "        MaxPooling2D((2, 2)),\n",
        "        Flatten(),\n",
        "        Dense(16, activation='relu'),\n",
        "        Dense(10, activation='softmax')\n",
        "    ])\n",
        "    \n",
        "    model.compile(\n",
        "        optimizer=Adam(learning_rate=0.01),  # Learning rate too high (Adam's default is 0.001)\n",
        "        loss='sparse_categorical_crossentropy',\n",
        "        metrics=['accuracy']\n",
        "    )\n",
        "    \n",
        "    return model\n",
        "\n",
        "# Create and display the baseline model\n",
        "baseline_model = create_baseline_model()\n",
        "baseline_model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Training and evaluating the baseline model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "print(\"\\n--- Baseline model ---\")\n",
        "baseline_metrics = evaluate_model(baseline_model, x_train, y_train, x_test, y_test, epochs=5)\n",
        "print(f\"Baseline model accuracy: {baseline_metrics['test_accuracy']:.2f}%\")\n",
        "print(f\"Training time: {baseline_metrics['training_time']:.2f} seconds\")\n",
        "\n",
        "# Plot the training history\n",
        "plot_training_history(baseline_metrics['history'])\n",
        "\n",
        "# Add to the dashboard\n",
        "dashboard.add_result(\"Baseline model\", baseline_metrics, \n",
        "                     \"Simple CNN, few filters, high learning rate\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Error analysis of the baseline model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Display the confusion matrix\n",
        "plot_confusion_matrix(baseline_model, x_test, y_test)\n",
        "\n",
        "# Display examples of errors\n",
        "print(\"\\nExamples of classification errors made by the baseline model:\")\n",
        "show_misclassified_examples(baseline_model, x_test, y_test)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### 🔍 Diagnosing the baseline model\n",
        "\n",
        "Before moving on to improvements, let's analyze the problems of the baseline model:\n",
        "\n",
        "1. **Architecture too simple**: \n",
        "   - Only 8 filters in the convolutional layer\n",
        "   - A single convolutional layer\n",
        "   - Only 16 neurons in the dense layer\n",
        "   \n",
        "2. **Problematic optimization**:\n",
        "   - Learning rate too high (0.01)\n",
        "   - No regularization (dropout, etc.)\n",
        "   - Possibly too few epochs (only 5)\n",
        "   \n",
        "3. **No generalization or stabilization techniques**:\n",
        "   - No data augmentation\n",
        "   - No batch normalization\n",
        "\n",
        "These observations will guide our improvement attempts."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 6. First improvement: a deeper architecture\n",
        "\n",
        "For our first improvement, we will:\n",
        "- Increase the number of filters\n",
        "- Add a second convolutional layer\n",
        "- Increase the number of neurons in the dense layer\n",
        "- Lower the learning rate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def create_improved_model_1():\n",
        "    \"\"\"First improvement example: a deeper architecture\"\"\"\n",
        "    model = Sequential([\n",
        "        Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),\n",
        "        MaxPooling2D((2, 2)),\n",
        "        Conv2D(64, (3, 3), activation='relu'),\n",
        "        MaxPooling2D((2, 2)),\n",
        "        Flatten(),\n",
        "        Dense(128, activation='relu'),\n",
        "        Dense(10, activation='softmax')\n",
        "    ])\n",
        "    \n",
        "    model.compile(\n",
        "        optimizer=Adam(learning_rate=0.001),  # Lower learning rate\n",
        "        loss='sparse_categorical_crossentropy',\n",
        "        metrics=['accuracy']\n",
        "    )\n",
        "    \n",
        "    return model\n",
        "\n",
        "# Create and display improved model 1\n",
        "improved_model_1 = create_improved_model_1()\n",
        "improved_model_1.summary()"
      ]
    },
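    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see how much capacity the first improvement adds, you can recompute by hand the parameter counts reported by `summary()`. This sketch assumes the standard formulas for layers with biases:\n",
        "\n",
        "- A `Conv2D` layer with `f` filters of size `k x k` over `c` input channels has `f * (k*k*c + 1)` parameters.\n",
        "- A `Dense` layer with `n` inputs and `u` units has `u * (n + 1)` parameters.\n",
        "\n",
        "With the default 'valid' padding, a 3x3 convolution shrinks each side by 2 pixels. A 2x2 max pooling halves each side, rounding down. The helpers `conv2d_params` and `dense_params` are illustrative names, not Keras functions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def conv2d_params(kernel_size, in_channels, filters):\n",
        "    # Each filter: kernel_size * kernel_size * in_channels weights + 1 bias\n",
        "    return filters * (kernel_size * kernel_size * in_channels + 1)\n",
        "\n",
        "def dense_params(n_inputs, n_units):\n",
        "    # Each unit: one weight per input + 1 bias\n",
        "    return n_units * (n_inputs + 1)\n",
        "\n",
        "# Baseline: 28x28x1 -> Conv2D(8) -> 26x26x8 -> MaxPooling -> 13x13x8 -> Dense(16) -> Dense(10)\n",
        "baseline_params = (conv2d_params(3, 1, 8)\n",
        "                   + dense_params(13 * 13 * 8, 16)\n",
        "                   + dense_params(16, 10))\n",
        "\n",
        "# Improved 1: 28x28x1 -> Conv2D(32) -> 26x26x32 -> 13x13x32 -> Conv2D(64) -> 11x11x64\n",
        "#             -> MaxPooling -> 5x5x64 -> Dense(128) -> Dense(10)\n",
        "improved_1_params = (conv2d_params(3, 1, 32)\n",
        "                     + conv2d_params(3, 32, 64)\n",
        "                     + dense_params(5 * 5 * 64, 128)\n",
        "                     + dense_params(128, 10))\n",
        "\n",
        "print(f\"Baseline model: {baseline_params} parameters\")      # 21898\n",
        "print(f\"Improved model 1: {improved_1_params} parameters\")  # 225034\n",
        "\n",
        "# Cross-check against Keras (both models were created in the cells above)\n",
        "print(baseline_model.count_params(), improved_model_1.count_params())"
      ]
    },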
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Training and evaluating improved model 1"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "print(\"\\n--- Improved model 1 ---\")\n",
        "improved_metrics_1 = evaluate_model(improved_model_1, x_train, y_train, x_test, y_test, epochs=10)\n",
        "print(f\"Improved model 1 accuracy: {improved_metrics_1['test_accuracy']:.2f}%\")\n",
        "print(f\"Training time: {improved_metrics_1['training_time']:.2f} seconds\")\n",
        "\n",
        "# Plot the training history\n",
        "plot_training_history(improved_metrics_1['history'])\n",
        "\n",
        "# Add to the dashboard\n",
        "dashboard.add_result(\"Improved model 1\", improved_metrics_1, \n",
        "                     \"More filters, extra convolutional layer, lower learning rate\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Analyzing the results of improved model 1"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Display the results of the improvement\n",
        "print(\"Comparison of the models so far:\")\n",
        "dashboard.show_results()\n",
        "\n",
        "# Look at the new errors\n",
        "print(\"\\nExamples of errors after the first improvement:\")\n",
        "show_misclassified_examples(improved_model_1, x_test, y_test)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 7. Second improvement: regularization and data augmentation\n",
        "\n",
        "For our second improvement, we will:\n",
        "- Add dropout to reduce overfitting\n",
        "- Add batch normalization\n",
        "- Use data augmentation to improve generalization\n",
        "\n",
        "### Architecture of improved model 2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def create_improved_model_2():\n",
        "    \"\"\"Second improvement example: adding dropout and batch normalization\"\"\"\n",
        "    model = Sequential([\n",
        "        # First convolutional block with batch normalization\n",
        "        Conv2D(32, (3, 3), padding='same', input_shape=(28, 28, 1)),\n",
        "        BatchNormalization(),\n",
        "        Activation('relu'),\n",
        "        MaxPooling2D((2, 2)),\n",
        "        \n",
        "        # Second convolutional block with batch normalization\n",
        "        Conv2D(64, (3, 3), padding='same'),\n",
        "        BatchNormalization(),\n",
        "        Activation('relu'),\n",
        "        MaxPooling2D((2, 2)),\n",
        "        \n",
        "        # Flatten\n",
        "        Flatten(),\n",
        "        \n",
        "        # Dense layer with batch normalization and dropout\n",
        "        Dense(128),\n",
        "        BatchNormalization(),\n",
        "        Activation('relu'),\n",
        "        Dropout(0.5),  # 50% dropout for regularization\n",
        "        \n",
        "        # Output layer\n",
        "        Dense(10, activation='softmax')\n",
        "    ])\n",
        "    \n",
        "    model.compile(\n",
        "        optimizer=Adam(learning_rate=0.001),\n",
        "        loss='sparse_categorical_crossentropy',\n",
        "        metrics=['accuracy']\n",
        "    )\n",
        "    \n",
        "    return model\n",
        "\n",
        "# Create and display improved model 2\n",
        "improved_model_2 = create_improved_model_2()\n",
        "improved_model_2.summary()"
      ]
    },
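    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Dropout only acts during training. This small sketch is independent of the models above. It applies a `Dropout(0.5)` layer to a vector of ones:\n",
        "\n",
        "- In inference mode, the layer returns its input unchanged.\n",
        "- In training mode, Keras zeroes roughly half of the values at random. It scales the surviving values by 1 / (1 - 0.5) = 2, so the expected sum of the activations stays the same.\n",
        "\n",
        "The variable names are illustrative."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "\n",
        "dropout_demo = tf.keras.layers.Dropout(0.5, seed=0)\n",
        "ones = np.ones((1, 10), dtype='float32')\n",
        "\n",
        "# Inference mode: the layer is a no-op\n",
        "print(dropout_demo(ones, training=False).numpy())\n",
        "\n",
        "# Training mode: each value becomes either 0.0 (dropped) or 2.0 (kept and rescaled)\n",
        "print(dropout_demo(ones, training=True).numpy())"
      ]
    },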
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Training with data augmentation\n",
        "\n",
        "For this improvement, we will also use data augmentation. It artificially generates extra training examples by applying random transformations to the existing images (small rotations, shifts and zooms, as configured in `evaluate_model`). This makes the model more robust to the variations it may encounter in real-world conditions."
      ]
    },
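    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before launching the training, it helps to see what augmented images look like. This sketch reuses the transformation ranges from `evaluate_model` and applies `ImageDataGenerator.random_transform` to the first training image. It assumes `x_train` from section 2 is still in memory. The name `preview_datagen` is only illustrative."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Same transformation ranges as in evaluate_model()\n",
        "preview_datagen = ImageDataGenerator(\n",
        "    rotation_range=10,\n",
        "    width_shift_range=0.1,\n",
        "    height_shift_range=0.1,\n",
        "    zoom_range=0.1,\n",
        ")\n",
        "\n",
        "original = x_train[0]  # shape (28, 28, 1)\n",
        "\n",
        "plt.figure(figsize=(12, 3))\n",
        "plt.subplot(1, 6, 1)\n",
        "plt.imshow(original.reshape(28, 28), cmap='gray')\n",
        "plt.title('Original')\n",
        "plt.axis('off')\n",
        "\n",
        "# Five random variants of the same image (seeded for reproducibility)\n",
        "for i in range(5):\n",
        "    augmented = preview_datagen.random_transform(original, seed=i)\n",
        "    plt.subplot(1, 6, i + 2)\n",
        "    plt.imshow(augmented.reshape(28, 28), cmap='gray')\n",
        "    plt.title(f'Variant {i + 1}')\n",
        "    plt.axis('off')\n",
        "\n",
        "plt.tight_layout()\n",
        "plt.show()"
      ]
    },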
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "print(\"\\n--- Improved model 2 (with data augmentation) ---\")\n",
        "improved_metrics_2 = evaluate_model(improved_model_2, x_train, y_train, x_test, y_test, \n",
        "                                   epochs=15, data_augmentation=True)\n",
        "print(f\"Improved model 2 accuracy: {improved_metrics_2['test_accuracy']:.2f}%\")\n",
        "print(f\"Training time: {improved_metrics_2['training_time']:.2f} seconds\")\n",
        "\n",
        "# Plot the training history\n",
        "plot_training_history(improved_metrics_2['history'])\n",
        "\n",
        "# Add to the dashboard\n",
        "dashboard.add_result(\"Improved model 2\", improved_metrics_2, \n",
        "                     \"Dropout, BatchNorm, data augmentation\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Analyzing the results of improved model 2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Display the confusion matrix\n",
        "plot_confusion_matrix(improved_model_2, x_test, y_test)\n",
        "\n",
        "# Display examples of errors\n",
        "print(\"\\nExamples of errors after the second improvement:\")\n",
        "show_misclassified_examples(improved_model_2, x_test, y_test)\n",
        "\n",
        "# Compare all the models\n",
        "dashboard.show_results()\n",
        "dashboard.plot_comparison()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 8. Building your own improved model\n",
        "\n",
        "Now it's your turn to design your own improvement! You can explore different architectures, optimization techniques, or combinations of approaches.\n",
        "\n",
        "Here are a few possible directions:\n",
        "- Try different architectures (more or fewer layers, filters, etc.)\n",
        "- Experiment with other optimizers (RMSprop, SGD with momentum, etc.)\n",
        "- Test different regularization techniques\n",
        "- Tune the data augmentation parameters\n",
        "- Use residual connections (as in ResNet architectures)\n",
        "- Combine the best practices from the previous models"
      ]
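    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def create_your_improved_model():\n",
        "    \"\"\"Your own improved model: complete this template\"\"\"\n",
        "    model = Sequential([\n",
        "        # Add your layers here (the first layer needs input_shape=(28, 28, 1)).\n",
        "        \n",
        "        # Output layer\n",
        "        Dense(10, activation='softmax')\n",
        "    ])\n",
        "    \n",
        "    # Compile the model\n",
        "    model.compile(\n",
        "        optimizer='adam',  # Change this to suit your experiments\n",
        "        loss='sparse_categorical_crossentropy',\n",
        "        metrics=['accuracy']\n",
        "    )\n",
        "    \n",
        "    return model\n",
        "\n",
        "# When your model is ready to test, uncomment the following lines\n",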
        "#your_model = create_your_improved_model()\n",
        "#your_model.summary()"
      ]
    },
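    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If you want to try the residual-connection idea from the list above, here is one possible starting point. It is a minimal sketch, not a reference solution.\n",
        "\n",
        "- It uses the Keras functional API (`Input`, `Add`, `Model`), because a `Sequential` model cannot express a skip connection.\n",
        "- When the number of channels changes, the shortcut is projected with a 1x1 convolution so that the two branches can be added.\n",
        "\n",
        "The names `residual_block` and `create_residual_model_sketch` are hypothetical, and so are the filter counts."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "from tensorflow.keras import layers, Model\n",
        "\n",
        "def residual_block(x, filters):\n",
        "    \"\"\"Two 3x3 convolutions whose output is added to the block input\"\"\"\n",
        "    shortcut = x\n",
        "    y = layers.Conv2D(filters, (3, 3), padding='same')(x)\n",
        "    y = layers.BatchNormalization()(y)\n",
        "    y = layers.Activation('relu')(y)\n",
        "    y = layers.Conv2D(filters, (3, 3), padding='same')(y)\n",
        "    y = layers.BatchNormalization()(y)\n",
        "    # 1x1 projection so the shortcut has the same number of channels as y\n",
        "    if shortcut.shape[-1] != filters:\n",
        "        shortcut = layers.Conv2D(filters, (1, 1), padding='same')(shortcut)\n",
        "    y = layers.Add()([shortcut, y])  # the skip connection\n",
        "    return layers.Activation('relu')(y)\n",
        "\n",
        "def create_residual_model_sketch():\n",
        "    inputs = layers.Input(shape=(28, 28, 1))\n",
        "    x = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(inputs)\n",
        "    x = residual_block(x, 32)           # 28x28x32, identity shortcut\n",
        "    x = layers.MaxPooling2D((2, 2))(x)  # 14x14x32\n",
        "    x = residual_block(x, 64)           # 14x14x64, projected shortcut\n",
        "    x = layers.MaxPooling2D((2, 2))(x)  # 7x7x64\n",
        "    x = layers.GlobalAveragePooling2D()(x)\n",
        "    x = layers.Dropout(0.5)(x)\n",
        "    outputs = layers.Dense(10, activation='softmax')(x)\n",
        "    model = Model(inputs, outputs)\n",
        "    model.compile(\n",
        "        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),\n",
        "        loss='sparse_categorical_crossentropy',\n",
        "        metrics=['accuracy']\n",
        "    )\n",
        "    return model\n",
        "\n",
        "residual_model = create_residual_model_sketch()\n",
        "residual_model.summary()"
      ]
    },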
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
   