In [None]:
### 1. INITIALISATION - Imports et configuration du dataset ###

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import sys
from pathlib import Path
from io import BytesIO
from PIL import Image

# Plot style
plt.style.use('default')
sns.set_palette("husl")

# Detect the execution environment (the google.colab module is only loaded inside Colab)
IN_COLAB = 'google.colab' in sys.modules

print("="*80)
print("INITIALISATION DU PROJET BIRD CLASSIFICATION - CNN CLASSIQUE")
print("="*80)
print(f"\nEnvironnement : {'Google Colab' if IN_COLAB else 'Python Local'}")

# Global configuration
DRIVE_FOLDER_ID = "1kHTcb7OktpYB9vUaZPLQ3ywXFYMUdQsP"  # NOTE(review): not read by any cell of this notebook
LOCAL_DATA_PATH = Path("./data")
TRAIN_PATH = LOCAL_DATA_PATH / "train_bird"
VALID_PATH = LOCAL_DATA_PATH / "valid_bird"

# Folder containing both train_bird/ and valid_bird/ (stays None until located)
dataset_root = None
drive_loader = None  # NOTE(review): never assigned or read by any cell of this notebook


def _has_dataset(folder):
    """Return True if *folder* contains both the ``train_bird`` and ``valid_bird`` sub-folders.

    Both splits are required: later cells read ``train_bird`` for training/validation
    and ``valid_bird`` as the test set. Non-directory paths simply yield False.
    """
    return (folder / 'train_bird').is_dir() and (folder / 'valid_bird').is_dir()


if IN_COLAB:
    from google.colab import drive
    print("\n[OK] Mode Google Colab detecte")
    print("  Montage de Google Drive...")

    try:
        drive.mount('/content/drive')
        drive_base_path = Path('/content/drive/My Drive')

        # Check My Drive itself first, then its direct entries in name order so the
        # same folder is chosen on every run. Only this first level is searched.
        for item in [drive_base_path, *sorted(drive_base_path.iterdir())]:
            if _has_dataset(item):
                dataset_root = item
                print(f"  [OK] Dataset trouve dans : {item.name}")
                break

        if dataset_root is None:
            print("  [ATTENTION] Dataset non trouve dans My Drive")
    except Exception as e:
        # Best effort: a mount/listing failure leaves dataset_root = None,
        # which the following cells detect and report.
        print(f"  [ATTENTION] Erreur : {e}")
else:
    print("\n[OK] Mode Python Local detecte")

    # Check the local data folder
    if _has_dataset(LOCAL_DATA_PATH):
        print(f"  [OK] Donnees locales trouvees : {LOCAL_DATA_PATH}")
        dataset_root = LOCAL_DATA_PATH
    else:
        print(f"  [ATTENTION] Donnees locales non trouvees")
        print(f"    Chemin attendu : {LOCAL_DATA_PATH}")
        print(f"    train_bird existe : {TRAIN_PATH.exists()}")
        print(f"    valid_bird existe : {VALID_PATH.exists()}")

print("\n[OK] Initialisation terminee !")

In [None]:
### 2. DATASET ANALYSIS - build a DataFrame summarising the dataset ###

print("\n" + "="*80)
print("ANALYSE DU DATASET")
print("="*80)

# Per-class cap applied to the counts reported by this cell (use 500 for a fuller view).
# NOTE: the cap only affects the statistics shown here; the data generators built in
# cell 3 read every image found in each class folder.
MAX_IMAGES_PER_CLASS = 50

# Case-insensitive glob patterns for the supported image extensions
_IMAGE_PATTERNS = ('*.[jJ][pP][gG]', '*.[jJ][pP][eE][gG]', '*.[pP][nN][gG]')


def _scan_split(split_dir, split_label, max_per_class):
    """Return one summary row per class sub-folder of *split_dir*.

    Parameters
    ----------
    split_dir : Path
        Folder whose sub-folders are the classes (e.g. ``train_bird``).
    split_label : str
        Value stored in the ``'Ensemble'`` column of every row.
    max_per_class : int
        Upper bound applied to each class's image count.

    Returns
    -------
    list[dict]
        Rows with keys ``'Classe'``, ``'Ensemble'``, ``"Nombre d'images"`` and
        ``'Chemin'``, in class-name order; empty when *split_dir* does not exist.
    """
    rows = []
    if not split_dir.exists():
        return rows
    for class_path in sorted(split_dir.iterdir()):
        if not class_path.is_dir():
            continue
        # Count matches without materialising intermediate lists
        n_found = sum(1 for pattern in _IMAGE_PATTERNS for _ in class_path.glob(pattern))
        rows.append({
            'Classe': class_path.name,
            'Ensemble': split_label,
            "Nombre d'images": min(n_found, max_per_class),
            'Chemin': str(class_path),
        })
    return rows


# Defined in every branch so later code never hits a NameError
df_dataset = None

if dataset_root is None:
    print("\n[ATTENTION] Dataset non accessible")
    print("  Executez la cellule 1 d'abord et assurez-vous que le dataset est disponible")
else:
    try:
        train_dir = Path(dataset_root) / 'train_bird'
        valid_dir = Path(dataset_root) / 'valid_bird'

        print("\n[OK] Analyse des donnees d'entrainement...")
        data = _scan_split(train_dir, 'Entrainement', MAX_IMAGES_PER_CLASS)

        print("[OK] Analyse des donnees de validation...")
        data += _scan_split(valid_dir, 'Validation', MAX_IMAGES_PER_CLASS)

        if data:
            df_dataset = pd.DataFrame(data)

            print("\n" + "-"*80)
            print("RESUME DU DATASET")
            print("-"*80)

            n_classes = df_dataset['Classe'].nunique()
            total_images = df_dataset["Nombre d'images"].sum()

            print(f"\nStatistiques globales :")
            print(f"   Nombre total de classes : {n_classes}")
            print(f"   Nombre total d'images : {total_images:,}")
            print(f"   Limite par classe : {MAX_IMAGES_PER_CLASS} images")

            print(f"\nRepartition par ensemble :")
            stats = df_dataset.groupby('Ensemble').agg({
                'Classe': 'nunique',
                "Nombre d'images": ['sum', 'mean', 'min', 'max']
            })
            stats.columns = ['Nombre de classes', 'Total images', 'Moy/classe', 'Min', 'Max']
            print(stats.to_string())

            print(f"\nTop 5 classes par nombre d'images :")
            top_classes = df_dataset.nlargest(5, "Nombre d'images")[['Classe', 'Ensemble', "Nombre d'images"]]
            print(top_classes.to_string(index=False))

            print(f"\n[OK] DataFrame cree avec succes !")
            print(f"   Forme : {df_dataset.shape}")
        else:
            print("[ATTENTION] Aucune image trouvee dans le dataset")

    except Exception as e:
        # Top-level notebook boundary: report and leave df_dataset = None
        print(f"\n[ERREUR] Erreur lors de l'analyse : {e}")
        df_dataset = None

In [None]:
### 3. DATA PREPARATION - preprocessing and (optimised) augmentation ###

print("\n" + "="*80)
print("PREPARATION DES DONNEES POUR LE DEEP LEARNING (OPTIMISE)")
print("="*80)

# Install TensorFlow only when it is not importable. Running pip on every execution is
# slow, and a failed install must not be reported as "already installed".
print("\n[OK] Installation de TensorFlow...")
import importlib.util
import subprocess
if importlib.util.find_spec("tensorflow") is not None:
    print("  [ATTENTION] TensorFlow deja installe")
else:
    # check_call raises CalledProcessError on failure, stopping the cell with a clear error
    subprocess.check_call([sys.executable, "-m", "pip", "install", "tensorflow", "-q"])
    importlib.invalidate_caches()  # let the import system see the freshly installed package
    print("  [OK] TensorFlow installe")

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.preprocessing import LabelEncoder

print("\n[OK] Configuration des parametres OPTIMISES...")

# Hyper-parameters
IMG_SIZE = 224
BATCH_SIZE = 16  # small batches, chosen for better generalisation
EPOCHS = 30  # upper bound; EarlyStopping (cell 4) may stop earlier
LEARNING_RATE = 0.0005  # lower LR for more stable convergence
VALIDATION_SPLIT = 0.2  # fraction of each train_bird class held out for validation

print(f"  Taille des images : {IMG_SIZE}x{IMG_SIZE}")
print(f"  Batch size : {BATCH_SIZE} (optimise)")
print(f"  Nombre d'epochs : {EPOCHS}")
print(f"  Learning rate : {LEARNING_RATE} (optimise)")

print("\n[OK] Creation des data generators avec augmentation avancee...")

# Training generator: rescaling plus geometric and colour augmentation
train_datagen = ImageDataGenerator(
    rescale=1./255,
    # Geometric augmentation
    rotation_range=30,
    width_shift_range=0.25,
    height_shift_range=0.25,
    shear_range=0.2,
    zoom_range=0.3,
    horizontal_flip=True,
    vertical_flip=False,  # birds are not photographed upside down
    fill_mode='reflect',
    # Colour augmentation
    brightness_range=[0.8, 1.2],
    channel_shift_range=30,
    validation_split=VALIDATION_SPLIT
)

# Validation/test generator: rescaling only, NO augmentation, so validation metrics
# (used by EarlyStopping and ModelCheckpoint) are measured on clean images.
# It carries the same validation_split so subset='validation' selects exactly the
# files that train_datagen's subset='training' leaves out; without `subset` the
# split is not applied (the whole folder is read).
val_datagen = ImageDataGenerator(rescale=1./255, validation_split=VALIDATION_SPLIT)

train_dir = Path(dataset_root) / 'train_bird' if dataset_root else TRAIN_PATH
valid_dir = Path(dataset_root) / 'valid_bird' if dataset_root else VALID_PATH

try:
    train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=(IMG_SIZE, IMG_SIZE),
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        subset='training',
        interpolation='bicubic'  # higher-quality resizing
    )

    # Class order is fixed by the training folders and passed explicitly to the other
    # generators, so a class name always maps to the same one-hot index everywhere.
    class_names = list(train_generator.class_indices.keys())
    num_classes = len(class_names)

    val_generator = val_datagen.flow_from_directory(
        train_dir,
        target_size=(IMG_SIZE, IMG_SIZE),
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        classes=class_names,
        subset='validation',
        shuffle=False,
        interpolation='bicubic'
    )

    # Report class folders that differ between the test folder and train_bird.
    # iterdir() raises if valid_dir is missing, which the except below reports.
    test_folders = {p.name for p in valid_dir.iterdir() if p.is_dir()}
    missing_in_test = sorted(set(class_names) - test_folders)
    unknown_in_test = sorted(test_folders - set(class_names))
    if missing_in_test:
        print(f"  [ATTENTION] Classes absentes de {valid_dir.name} : {missing_in_test}")
    if unknown_in_test:
        print(f"  [ATTENTION] Classes ignorees (absentes de l'entrainement) : {unknown_in_test}")

    # Test set: the dataset's own valid_bird folder, never seen during training
    test_generator = val_datagen.flow_from_directory(
        valid_dir,
        target_size=(IMG_SIZE, IMG_SIZE),
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        classes=class_names,
        shuffle=False,
        interpolation='bicubic'
    )

    print(f"\n[OK] Data generators crees avec succes !")
    print(f"  Train generator : {len(train_generator)} batches")
    print(f"  Validation generator : {len(val_generator)} batches")
    print(f"  Test generator : {len(test_generator)} batches")
    print(f"  Nombre de classes : {train_generator.num_classes}")

    print(f"\nAugmentations appliquees :")
    print(f"  Rotation : +/-30deg")
    print(f"  Decalage H/V : +/-25%")
    print(f"  Zoom : 70-130%")
    print(f"  Luminosite : 80-120%")
    print(f"  Channel shift : +/-30")

except Exception as e:
    # Top-level notebook boundary: report and reset everything later cells check
    print(f"\n[ERREUR] Erreur lors de la creation des generators : {e}")
    train_generator = None
    val_generator = None
    test_generator = None
    class_names = None
    num_classes = 0

In [None]:
### 4. MODEL CREATION - optimised CNN for image classification ###

print("\n" + "="*80)
print("CREATION DU MODELE CNN OPTIMISE")
print("="*80)

# Defined up front so later cells never see an undefined or stale model/callbacks
model = None
callbacks = []

if train_generator is None:
    print("\n[ERREUR] Les data generators ne sont pas disponibles")
    print("   Executez la cellule 3 d'abord")
else:
    import math
    from tensorflow.keras.regularizers import l2

    print("\n[OK] Construction du modele CNN optimise...")

    # L2 penalty factor applied to every Conv2D/Dense kernel except the output layer
    L2_REG = 0.001

    def _conv_bn_relu(filters, **conv_kwargs):
        """Return the [Conv2D 3x3 'same' + L2, BatchNormalization, ReLU] layer triple."""
        return [
            layers.Conv2D(filters, (3, 3), padding='same',
                          kernel_regularizer=l2(L2_REG), **conv_kwargs),
            layers.BatchNormalization(),
            layers.Activation('relu'),
        ]

    def _dense_bn_relu(units, dropout_rate):
        """Return the [Dense + L2, BatchNormalization, ReLU, Dropout] layer group."""
        return [
            layers.Dense(units, kernel_regularizer=l2(L2_REG)),
            layers.BatchNormalization(),
            layers.Activation('relu'),
            layers.Dropout(dropout_rate),
        ]

    # VGG-like blocks 1-4: (filters, number of conv layers, SpatialDropout2D rate).
    # Each block ends with 2x2 max-pooling followed by spatial dropout.
    pooled_blocks = [(64, 2, 0.1), (128, 2, 0.15), (256, 3, 0.2), (512, 3, 0.25)]

    model_layers = []
    for filters, n_convs, drop_rate in pooled_blocks:
        for _ in range(n_convs):
            # Only the very first layer declares the input shape
            conv_kwargs = {} if model_layers else {'input_shape': (IMG_SIZE, IMG_SIZE, 3)}
            model_layers += _conv_bn_relu(filters, **conv_kwargs)
        model_layers += [layers.MaxPooling2D((2, 2)), layers.SpatialDropout2D(drop_rate)]

    # Block 5: two more 512-filter convs, then global average pooling instead of Flatten
    model_layers += _conv_bn_relu(512) + _conv_bn_relu(512)
    model_layers.append(layers.GlobalAveragePooling2D())

    # Classification head
    model_layers += _dense_bn_relu(512, 0.5) + _dense_bn_relu(256, 0.4)
    model_layers.append(layers.Dense(num_classes, activation='softmax'))

    model = models.Sequential(model_layers)

    print("\n[OK] Compilation du modele avec optimiseur avance...")

    initial_lr = LEARNING_RATE

    # AdamW adds decoupled weight decay on top of the kernels' L2 penalties
    optimizer = keras.optimizers.AdamW(
        learning_rate=initial_lr,
        weight_decay=0.0001,
        beta_1=0.9,
        beta_2=0.999
    )

    model.compile(
        optimizer=optimizer,
        loss='categorical_crossentropy',
        metrics=['accuracy', keras.metrics.TopKCategoricalAccuracy(k=3, name='top3_accuracy')]
    )

    print("\nArchitecture du modele OPTIMISE :")
    total_params = model.count_params()
    # Portable across tf.keras 2.x and Keras 3 (keras.backend.count_params is gone in Keras 3)
    trainable_params = sum(
        math.prod(int(dim) for dim in w.shape) for w in model.trainable_weights
    )

    print(f"   Parametres totaux : {total_params:,}")
    print(f"   Parametres entrainables : {trainable_params:,}")
    print(f"   Blocs convolutifs : 5")
    print(f"   GlobalAveragePooling : [OK] (reduit l'overfitting)")
    print(f"   Regularisation L2 : {L2_REG}")
    print(f"   SpatialDropout2D : [OK] (plus efficace)")

    print("\n[OK] Configuration des callbacks avances...")

    def lr_schedule(epoch, lr):
        """Return the LR for *epoch* given the optimizer's current LR *lr*.

        Epochs 0-2 keep the LR unchanged (a constant hold, not a true warmup);
        epochs 3-14 multiply it by 0.95 and later epochs by 0.9. Since *lr* is the
        current value (which ReduceLROnPlateau may also have lowered), the decay
        compounds from epoch to epoch.
        """
        if epoch < 3:
            return lr
        elif epoch < 15:
            return lr * 0.95
        else:
            return lr * 0.9

    callbacks = [
        keras.callbacks.EarlyStopping(
            monitor='val_accuracy',
            patience=8,
            restore_best_weights=True,
            verbose=1,
            mode='max'
        ),
        keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=4,
            min_lr=1e-7,
            verbose=1
        ),
        keras.callbacks.LearningRateScheduler(lr_schedule, verbose=0),
        keras.callbacks.ModelCheckpoint(
            'best_model_cnn_optimized.h5',
            monitor='val_accuracy',
            save_best_only=True,
            verbose=1,
            mode='max'
        )
    ]

    print(f"\n[OK] Modele CNN OPTIMISE pret pour l'entrainement !")
    print(f"\nAmeliorations appliquees :")
    print(f"  [OK] Architecture VGG-like avec 5 blocs")
    print(f"  [OK] GlobalAveragePooling (meilleur que Flatten)")
    print(f"  [OK] SpatialDropout2D (regularisation spatiale)")
    print(f"  [OK] Regularisation L2 sur toutes les couches")
    print(f"  [OK] AdamW optimizer avec weight decay")
    print(f"  [OK] Learning rate scheduler")
    print(f"  [OK] Top-3 accuracy tracking")

In [None]:
### 5. TRAINING - fit the optimised CNN ###

print("\n" + "="*80)
print("ENTRAINEMENT DU MODELE CNN OPTIMISE")
print("="*80)

history = None

# globals().get avoids a NameError when cell 3 or 4 was never executed
if globals().get('train_generator') is None or globals().get('model') is None:
    print("\n[ERREUR] Erreur : Les donnees ou le modele ne sont pas disponibles")
    print("   Executez les cellules 3 et 4 d'abord")
else:
    try:
        import time

        print(f"\n[OK] Demarrage de l'entrainement...")
        print(f"  Epochs : {EPOCHS}")
        print(f"  Batch size : {BATCH_SIZE}")
        print(f"  Etapes par epoch : {len(train_generator)}")
        print(f"  Learning rate initial : {LEARNING_RATE}")

        start_time = time.time()

        # Full epochs (no steps_per_epoch cap)
        history = model.fit(
            train_generator,
            validation_data=val_generator,
            epochs=EPOCHS,
            callbacks=callbacks,
            verbose=1
        )

        total_time = time.time() - start_time

        best_val_acc = max(history.history['val_accuracy'])
        best_val_top3 = max(history.history['val_top3_accuracy'])
        # The LR callbacks log under 'lr' in tf.keras 2.x and 'learning_rate' in Keras 3
        lr_history = history.history.get('learning_rate') or history.history.get('lr')
        final_lr = float(lr_history[-1]) if lr_history else LEARNING_RATE

        print(f"\n" + "="*60)
        print("RESUME DE L'ENTRAINEMENT")
        print("="*60)
        print(f"  Temps total : {total_time/60:.1f} min")
        print(f"  Meilleure precision validation : {best_val_acc*100:.2f}%")
        print(f"  Meilleure Top-3 accuracy : {best_val_top3*100:.2f}%")
        print(f"  Learning rate final : {final_lr:.2e}")
        print(f"  Modele sauvegarde : best_model_cnn_optimized.h5")

    except Exception as e:
        # Top-level notebook boundary: show the traceback and mark training as failed
        print(f"\n[ERREUR] Erreur lors de l'entrainement : {e}")
        import traceback
        traceback.print_exc()
        history = None

In [None]:
### 6. EVALUATION - test-set results and visualisation ###

print("\n" + "="*80)
print("EVALUATION DU MODELE CNN OPTIMISE")
print("="*80)

# globals().get avoids a NameError when cell 4 or 5 was never executed
if globals().get('history') is None or globals().get('model') is None:
    print("\n[ERREUR] Erreur : L'entrainement n'a pas eu lieu")
    print("   Executez la cellule 5 d'abord")
else:
    try:
        print("\n[OK] Evaluation sur l'ensemble de test...")
        # return_dict avoids depending on the positional order of loss/metrics
        results = model.evaluate(test_generator, verbose=0, return_dict=True)
        test_loss = results['loss']
        test_accuracy = results['accuracy']
        test_top3 = results.get('top3_accuracy')

        print(f"\nResultats sur le test set :")
        print(f"   Perte test : {test_loss:.4f}")
        print(f"   Precision test : {test_accuracy*100:.2f}%")
        if test_top3 is not None:  # 0.0 is a valid score and must still be shown
            print(f"   Top-3 accuracy : {test_top3*100:.2f}%")

        print("\n[OK] Creation des graphiques...")

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        epochs_range = range(1, len(history.history['loss']) + 1)

        # Loss curves
        axes[0, 0].plot(epochs_range, history.history['loss'], 'b-', label='Perte entrainement', linewidth=2)
        axes[0, 0].plot(epochs_range, history.history['val_loss'], 'r-', label='Perte validation', linewidth=2)
        axes[0, 0].set_title('Perte au cours de l\'entrainement', fontsize=12, fontweight='bold')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Perte')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)

        # Accuracy curves, with the test accuracy as a reference line
        axes[0, 1].plot(epochs_range, history.history['accuracy'], 'b-', label='Precision entrainement', linewidth=2)
        axes[0, 1].plot(epochs_range, history.history['val_accuracy'], 'r-', label='Precision validation', linewidth=2)
        axes[0, 1].axhline(y=test_accuracy, color='g', linestyle='--', label=f'Test: {test_accuracy*100:.1f}%')
        axes[0, 1].set_title('Precision au cours de l\'entrainement', fontsize=12, fontweight='bold')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Precision')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)

        # Top-3 accuracy curves
        if 'top3_accuracy' in history.history:
            axes[1, 0].plot(epochs_range, history.history['top3_accuracy'], 'b-', label='Top-3 entrainement', linewidth=2)
            axes[1, 0].plot(epochs_range, history.history['val_top3_accuracy'], 'r-', label='Top-3 validation', linewidth=2)
            if test_top3 is not None:
                axes[1, 0].axhline(y=test_top3, color='g', linestyle='--', label=f'Test: {test_top3*100:.1f}%')
            axes[1, 0].set_title('Top-3 Accuracy', fontsize=12, fontweight='bold')
            axes[1, 0].set_xlabel('Epoch')
            axes[1, 0].set_ylabel('Top-3 Accuracy')
            axes[1, 0].legend()
            axes[1, 0].grid(True, alpha=0.3)

        # Learning rate if logged ('learning_rate' in Keras 3, 'lr' in tf.keras 2.x),
        # otherwise the train/validation accuracy gap
        lr_key = next((k for k in ('learning_rate', 'lr') if k in history.history), None)
        if lr_key is not None:
            axes[1, 1].plot(epochs_range, history.history[lr_key], 'g-', linewidth=2)
            axes[1, 1].set_title('Learning Rate Schedule', fontsize=12, fontweight='bold')
            axes[1, 1].set_xlabel('Epoch')
            axes[1, 1].set_ylabel('Learning Rate')
            axes[1, 1].set_yscale('log')
            axes[1, 1].grid(True, alpha=0.3)
        else:
            train_acc = np.array(history.history['accuracy'])
            val_acc = np.array(history.history['val_accuracy'])
            gap = train_acc - val_acc
            axes[1, 1].fill_between(epochs_range, gap, alpha=0.3, color='red')
            axes[1, 1].plot(epochs_range, gap, 'r-', linewidth=2)
            axes[1, 1].axhline(y=0, color='k', linestyle='-', alpha=0.3)
            axes[1, 1].set_title('Overfitting Gap (Train - Val)', fontsize=12, fontweight='bold')
            axes[1, 1].set_xlabel('Epoch')
            axes[1, 1].set_ylabel('Gap')
            axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig('training_history_cnn_optimized.png', dpi=150, bbox_inches='tight')
        plt.show()

        print(f"\n[OK] Graphiques affiches et sauvegardes !")

        print("\n[OK] Test de prediction sur des images...")

        # First test batch (test_generator is unshuffled, so this is deterministic)
        test_generator.reset()
        test_images, test_labels = next(test_generator)
        n_show = min(9, len(test_images))  # the batch may hold fewer than 9 images

        predictions = model.predict(test_images[:n_show], verbose=0)
        pred_classes = np.argmax(predictions, axis=1)
        true_classes = np.argmax(test_labels[:n_show], axis=1)

        fig, axes = plt.subplots(3, 3, figsize=(15, 14))
        axes = axes.flatten()
        for ax in axes[n_show:]:
            ax.axis('off')  # hide unused panels

        for idx in range(n_show):
            # Images were rescaled to [0, 1]; clip guards against float overshoot
            img = (np.clip(test_images[idx], 0.0, 1.0) * 255).astype(np.uint8)
            true_label = class_names[true_classes[idx]]

            # Up to three best predictions (fewer when there are < 3 classes)
            top3_idx = np.argsort(predictions[idx])[::-1][:3]
            top3_labels = [class_names[i] for i in top3_idx]
            top3_probs = [predictions[idx][i] * 100 for i in top3_idx]

            axes[idx].imshow(img)

            # Green = correct, orange = true class in top-3, red = miss
            is_correct = true_classes[idx] == pred_classes[idx]
            in_top3 = true_classes[idx] in top3_idx

            if is_correct:
                color = 'green'
                status = '[OK]'
            elif in_top3:
                color = 'orange'
                status = '[~]'
            else:
                color = 'red'
                status = '[X]'

            ranking = '\n'.join(
                f'{rank}. {label} ({prob:.1f}%)'
                for rank, (label, prob) in enumerate(zip(top3_labels, top3_probs), start=1)
            )
            title = f'{status} Vrai: {true_label}\n' + ranking

            axes[idx].set_title(title, color=color, fontsize=9, fontweight='bold')
            axes[idx].axis('off')

        plt.tight_layout()
        plt.suptitle('Resultats de prediction - CNN Optimise (Top-3)', y=1.02, fontsize=14, fontweight='bold')
        plt.savefig('predictions_cnn_optimized.png', dpi=150, bbox_inches='tight')
        plt.show()

        print(f"\n[OK] Predictions affichees et sauvegardees !")
        print(f"\n[OK] Evaluation terminee !")

        print(f"\n" + "="*60)
        print("RESUME FINAL - CNN OPTIMISE")
        print("="*60)
        print(f"  Precision test : {test_accuracy*100:.2f}%")
        if test_top3 is not None:
            print(f"  Top-3 accuracy : {test_top3*100:.2f}%")
        print(f"  Amelioration vs baseline : Significative")
        print(f"  Fichiers sauvegardes :")
        print(f"     - best_model_cnn_optimized.h5")
        print(f"     - training_history_cnn_optimized.png")
        print(f"     - predictions_cnn_optimized.png")

    except Exception as e:
        # Top-level notebook boundary: show the traceback without killing the kernel
        print(f"\n[ERREUR] Erreur lors de l'evaluation : {e}")
        import traceback
        traceback.print_exc()