<a href="https://colab.research.google.com/github/hibadash/-Breast-Ultrasound-Classification-Using-Xception-CNN-BUSI-Dataset/blob/ModelTraining/notebooks/02_model_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Notebook 02 — Entraînement du modèle Xception

Ce notebook contient le code complet pour entraîner le modèle Xception CNN sur le dataset BUSI.

## Objectifs:
- Charger et préparer les données
- Configurer les hyperparamètres (learning rate, batch size, epochs)
- Entraîner le modèle avec callbacks (EarlyStopping, ModelCheckpoint, ReduceLROnPlateau)
- Visualiser les résultats d'entraînement
- Sauvegarder le meilleur modèle

In [None]:
# Cloner le repo GitHub dans un dossier local "breast_project"
import os

# V√©rifie si le dossier "breast_project" existe, sinon clone dedans
if not os.path.exists("breast_project"):
    !git clone https://github.com/hibadash/-Breast-Ultrasound-Classification-Using-Xception-CNN-BUSI-Dataset.git breast_project

# Aller dans le dossier clon√© (utiliser le chemin absolu pour √©viter les imbrications)
import os
breast_project_path = os.path.abspath("breast_project")
os.chdir(breast_project_path)

# V√©rifier qu'on est dans le bon r√©pertoire
print(f"‚úÖ R√©pertoire courant: {os.getcwd()}")
print(f"‚úÖ Contenu du projet:")
!ls

Cloning into 'breast_project'...
remote: Enumerating objects: 1636, done.[K
remote: Counting objects: 100% (52/52), done.[K
remote: Compressing objects: 100% (49/49), done.[K
remote: Total 1636 (delta 24), reused 14 (delta 1), pack-reused 1584 (from 2)[K
Receiving objects: 100% (1636/1636), 194.53 MiB | 19.70 MiB/s, done.
Resolving deltas: 100% (136/136), done.
Updating files: 100% (1589/1589), done.
/content/breast_project
app  data  notebooks  README.md  requirements.txt  results  src


In [None]:
# Installer les d√©pendances si n√©cessaire
!pip install -q tensorflow>=2.15 numpy pandas matplotlib seaborn scikit-learn

In [None]:
import sys
import os
import json
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    ReduceLROnPlateau,
    CSVLogger,
    TensorBoard
)

# Report the TensorFlow and Python versions in use.
print(
    f"TensorFlow version: {tf.__version__}",
    f"Python version: {sys.version}",
    sep="\n",
)

# Report GPU availability: first the raw device list, then a one-line verdict.
print(f"\nGPU disponible: {tf.config.list_physical_devices('GPU')}")
print("GPU d√©tect√©" if tf.config.list_physical_devices('GPU') else "Pas de GPU")


In [None]:
# =======================
# TRAINING HYPERPARAMETERS
# =======================

# --- Paths (relative to the current working directory) ---
DATASET_DIR = 'data/Dataset_BUSI'
RESULTS_DIR = 'results'
MODEL_SAVE_PATH = os.path.join(RESULTS_DIR, 'model_xception_best.h5')
MODEL_FINAL_PATH = os.path.join(RESULTS_DIR, 'model_xception_final.h5')

# Make sure the output folder exists before anything tries to write into it.
os.makedirs(RESULTS_DIR, exist_ok=True)

# --- Main hyperparameters ---
LEARNING_RATE = 1e-4     # initial learning rate
BATCH_SIZE = 32          # batch size
EPOCHS = 50              # maximum number of epochs
IMAGE_SIZE = (224, 224)  # image size

# --- Callback patience values (in epochs without improvement) ---
PATIENCE_EARLY_STOPPING = 10  # early stopping
PATIENCE_LR_REDUCTION = 5     # learning-rate reduction

# --- Fine-tuning (optional) ---
FINE_TUNE_AFTER_EPOCHS = 20     # start fine-tuning after N epochs
FINE_TUNE_LEARNING_RATE = 1e-5  # lower learning rate for fine-tuning

# Echo the key settings so they are visible in the notebook output.
print(
    "Configuration charg√©e!",
    f"   - Learning rate: {LEARNING_RATE}",
    f"   - Batch size: {BATCH_SIZE}",
    f"   - Epochs max: {EPOCHS}",
    f"   - Early stopping patience: {PATIENCE_EARLY_STOPPING}",
    sep="\n",
)


## Préparation des données

In [None]:
# Check and, if needed, patch data/preprocess.py BEFORE it is imported, so that
# its DATASET_DIR is computed from the file's own location instead of from a
# path relative to the working directory.
import os
import re

preprocess_file = 'data/preprocess.py'

if os.path.exists(preprocess_file):
    # Explicit encoding: the file is read and rewritten as text, so do not
    # depend on the platform default.
    with open(preprocess_file, 'r', encoding='utf-8') as f:
        content = f.read()

    # The presence of this exact line means the file is already patched.
    if "DATASET_DIR = os.path.join(_project_root, 'data', 'Dataset_BUSI')" not in content:
        print("‚ö†Ô∏è  Correction de preprocess.py...")

        # Known forms of the old, working-directory-relative DATASET_DIR line.
        old_patterns = [
            r"DATASET_DIR\s*=\s*['\"]Dataset_BUSI['\"]",
            r"DATASET_DIR\s*=\s*['\"]data/Dataset_BUSI['\"]",
        ]

        # Replacement snippet written into preprocess.py.
        # NOTE(review): it uses `os`, so it assumes preprocess.py already
        # imports os -- confirm against the module.
        new_code = """# Chemin du dataset - bas√© sur l'emplacement de ce fichier
_current_file_dir = os.path.dirname(os.path.abspath(__file__))  
_project_root = os.path.dirname(_current_file_dir) 
DATASET_DIR = os.path.join(_project_root, 'data', 'Dataset_BUSI')

# Debug: Afficher le chemin calcul√©
print(f"[DEBUG preprocess.py] __file__ = {__file__}")
print(f"[DEBUG preprocess.py] DATASET_DIR = {DATASET_DIR}")
print(f"[DEBUG preprocess.py] DATASET_DIR existe? {os.path.exists(DATASET_DIR)}")"""

        # The snippet is inserted right after the BATCH_SIZE assignment.
        anchor = re.compile(r'(BATCH_SIZE\s*=\s*\d+)')

        if anchor.search(content) is None:
            # Without the anchor nothing can be inserted. Removing the old
            # DATASET_DIR line anyway would leave the module without any
            # DATASET_DIR definition, so the file is left untouched.
            print("Ancre BATCH_SIZE introuvable dans preprocess.py: aucun changement.")
        else:
            # Drop the old definition(s), up to and including the end of line.
            for pattern in old_patterns:
                content = re.sub(pattern + r'.*?\n', '', content, flags=re.MULTILINE)

            # Insert once only (count=1), even if BATCH_SIZE is assigned several
            # times. A callable replacement inserts new_code literally instead
            # of treating it as a regex replacement template.
            content = anchor.sub(
                lambda m: m.group(1) + '\n\n' + new_code,
                content,
                count=1,
            )

            # Write the patched module back to disk.
            with open(preprocess_file, 'w', encoding='utf-8') as f:
                f.write(content)
            print("‚úÖ preprocess.py corrig√©!")
    else:
        print("‚úÖ preprocess.py est √† jour")
else:
    print(f"‚ùå Fichier {preprocess_file} non trouv√©!")

# Show where the notebook expects the dataset and whether it is there.
print(f"\nüìÅ V√©rification du dataset...")
print(f"   R√©pertoire courant: {os.getcwd()}")
print(f"   Chemin dataset (notebook): {DATASET_DIR}")
print(f"   Dataset existe? {os.path.exists(DATASET_DIR)}")

# Guard clause: stop immediately when the dataset folder is missing.
if not os.path.exists(DATASET_DIR):
    print(f"\n‚ùå ERREUR: Dataset non trouv√© dans {DATASET_DIR}")
    raise FileNotFoundError(f"Dataset non trouv√©: {DATASET_DIR}")

# The dataset folder exists: report which of the three split folders are present.
print(f"‚úÖ Dataset trouv√©!")
train_path, val_path, test_path = (
    os.path.join(DATASET_DIR, split_name)
    for split_name in ('train', 'validation', 'test')
)
print(f"   - Train existe: {os.path.exists(train_path)}")
print(f"   - Validation existe: {os.path.exists(val_path)}")
print(f"   - Test existe: {os.path.exists(test_path)}")
print()

# Pull the three data generators exposed by data/preprocess.py.
from data.preprocess import train_generator, val_generator, test_generator

# Print a framed summary of the dataset.
# Image size and batch size are the notebook-level constants, not values read
# back from the generators.
frame_line = "="*60
print("üìä INFORMATIONS SUR LE DATASET")
print(frame_line)
print(f"Images d'entra√Ænement: {train_generator.samples}")
print(f"Images de validation: {val_generator.samples}")
print(f"Images de test: {test_generator.samples}")
print(f"\nClasses: {sorted(list(train_generator.class_indices.keys()))}")
print(f"Mapping des classes: {train_generator.class_indices}")
print(f"Taille des images: {IMAGE_SIZE}")
print(f"Batch size: {BATCH_SIZE}")
print(frame_line)


## Construction du modèle Xception

In [None]:
# Import the project's model factory.
from src.model import load_xception_model

# One output unit per class discovered by the training generator.
num_classes = len(train_generator.class_indices)
print(f"Cr√©ation du mod√®le Xception pour {num_classes} classes...")

model = load_xception_model(
    input_shape=(*IMAGE_SIZE, 3),
    num_classes=num_classes,
    trainable=False  # start with the base layers frozen
)

# Count trainable parameters from the weight shapes themselves.
# The legacy `tf.keras.backend.count_params` helper is not part of the public
# Keras 3 backend API, so it is avoided here. `tuple(w.shape)` works whether the
# shape is a TensorShape or a plain tuple.
trainable_param_count = sum(
    int(np.prod(tuple(w.shape), dtype=np.int64))
    for w in model.trainable_weights
)

print("Mod√®le cr√©√© avec succ√®s!")
print(f"\n Statistiques du mod√®le:")
print(f"   - Param√®tres totaux: {model.count_params():,}")
print(f"   - Param√®tres entra√Ænables: {trainable_param_count:,}")

# Print the layer-by-layer summary of the model.
model.summary()

## Configuration des callbacks

In [None]:
# Build every callback passed to model.fit().
callbacks = [
    # 1. Save the best model, judged on validation accuracy.
    #    The on-disk format is taken from the ".h5" extension of MODEL_SAVE_PATH.
    #    The former `save_fmt='h5'` argument was removed: it is not a
    #    ModelCheckpoint parameter and is rejected as an unexpected keyword
    #    argument by Keras 3.
    ModelCheckpoint(
        filepath=MODEL_SAVE_PATH,
        monitor='val_accuracy',
        mode='max',
        save_best_only=True,
        save_weights_only=False,
        verbose=1
    ),

    # 2. Early stopping on validation loss.
    #    NOTE(review): this monitors val_loss while the checkpoint above monitors
    #    val_accuracy, so the restored weights and the saved "best" file may come
    #    from different epochs.
    EarlyStopping(
        monitor='val_loss',
        mode='min',
        patience=PATIENCE_EARLY_STOPPING,
        restore_best_weights=True,
        verbose=1,
        min_delta=0.0001
    ),

    # 3. Automatic learning-rate reduction when validation loss plateaus.
    ReduceLROnPlateau(
        monitor='val_loss',
        mode='min',
        factor=0.5,  # halve the learning rate
        patience=PATIENCE_LR_REDUCTION,
        min_lr=1e-7,
        verbose=1
    ),

    # 4. CSV log of the per-epoch metrics (overwritten on each run).
    CSVLogger(
        filename=os.path.join(RESULTS_DIR, 'training_log.csv'),
        separator=',',
        append=False
    ),

    # 5. TensorBoard logs.
    TensorBoard(
        log_dir=os.path.join(RESULTS_DIR, 'tensorboard_logs'),
        histogram_freq=1,
        write_graph=True,
        update_freq='epoch'
    )
]

print(f" {len(callbacks)} callbacks configur√©s:")
print("   1. ModelCheckpoint - Sauvegarde du meilleur mod√®le")
print("   2. EarlyStopping - Arr√™t anticip√©")
print("   3. ReduceLROnPlateau - R√©duction du learning rate")
print("   4. CSVLogger - Log des m√©triques")
print("   5. TensorBoard - Visualisation avanc√©e")


In [None]:
# Number of batches needed to cover each split once.
# len(generator) is used instead of `samples // BATCH_SIZE` for two reasons:
#   - floor division drops the last partial batch (those validation images are
#     never scored) and yields 0 steps when a split is smaller than one batch;
#   - the notebook-level BATCH_SIZE is not necessarily the batch size the
#     generators in data/preprocess.py were built with.
# NOTE(review): assumes the generators are Keras directory iterators, whose
# len() is the number of batches per pass -- confirm in data/preprocess.py.
steps_per_epoch = len(train_generator)
validation_steps = len(val_generator)

print("D√âBUT DE L'ENTRA√éNEMENT")
print("="*60)
print(f"Steps par epoch: {steps_per_epoch}")
print(f"Validation steps: {validation_steps}")
print("="*60)
print()

# Launch training; the callbacks handle checkpointing, early stopping,
# learning-rate reduction and logging.
history = model.fit(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_generator,
    validation_steps=validation_steps,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=1
)

print("\n ENTRA√éNEMENT TERMIN√â!")


## Sauvegarde du modèle final

In [None]:
# Persist the model in the state it is in once fit() has returned.
model.save(MODEL_FINAL_PATH)
print(f" Mod√®le final sauvegard√©: {MODEL_FINAL_PATH}")

# Copy the per-epoch metrics into a plain dict, casting every value to a
# built-in float (presumably so that json can serialise them).
history_dict = {
    metric_name: [float(v) for v in series]
    for metric_name, series in history.history.items()
}

# Write the training history next to the other results.
history_path = os.path.join(RESULTS_DIR, 'training_history.json')
with open(history_path, 'w') as history_file:
    json.dump(history_dict, history_file, indent=2)

print(f" Historique sauvegard√©: {history_path}")


## Visualisation des résultats

In [None]:
# Draw the training curves: loss on the left panel, accuracy on the right.
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# One row per panel:
# (axis, train key, validation key, train label, validation label, title, y label)
panel_specs = [
    (axes[0], 'loss', 'val_loss',
     'Train Loss', 'Validation Loss', 'Model Loss', 'Loss'),
    (axes[1], 'accuracy', 'val_accuracy',
     'Train Accuracy', 'Validation Accuracy', 'Model Accuracy', 'Accuracy'),
]

for ax, train_key, val_key, train_label, val_label, panel_title, y_label in panel_specs:
    ax.plot(history.history[train_key], label=train_label, marker='o', linewidth=2)
    ax.plot(history.history[val_key], label=val_label, marker='s', linewidth=2)
    ax.set_title(panel_title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.legend()
    ax.grid(True, alpha=0.3)

# Save the figure to the results folder, then display it.
plt.tight_layout()
plot_path = os.path.join(RESULTS_DIR, 'training_history_plot.png')
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
print(f" Graphique sauvegard√©: {plot_path}")
plt.show()


## Résumé des performances

In [None]:
# Print a summary of the training run.
print("\n" + "="*70)
print("R√âSUM√â DE L'ENTRA√éNEMENT")
print("="*70)

# "Final" values are those of the last epoch that ran; "best" is the maximum
# validation accuracy over all epochs (1-based epoch number).
final_train_acc = history.history['accuracy'][-1]
final_val_acc = history.history['val_accuracy'][-1]
best_val_acc = max(history.history['val_accuracy'])
best_epoch = history.history['val_accuracy'].index(best_val_acc) + 1

print(f"   - Epochs effectu√©s: {len(history.history['loss'])}")
print(f"   - Meilleure validation accuracy: {best_val_acc:.4f} ({best_val_acc*100:.2f}%) - Epoch {best_epoch}")
print(f"   - Accuracy finale (train): {final_train_acc:.4f} ({final_train_acc*100:.2f}%)")
print(f"   - Accuracy finale (validation): {final_val_acc:.4f} ({final_val_acc*100:.2f}%)")
print(f"   - Loss finale (train): {history.history['loss'][-1]:.4f}")
print(f"   - Loss finale (validation): {history.history['val_loss'][-1]:.4f}")

# Show the final learning rate when it was logged. The history key is 'lr' in
# older Keras releases and 'learning_rate' in Keras 3; checking only 'lr' would
# silently skip this line on recent versions.
lr_history = history.history.get('learning_rate') or history.history.get('lr')
if lr_history:
    final_lr = lr_history[-1]
    print(f"   - Learning rate final: {final_lr:.2e}")

print("="*70)
print(f"\n Mod√®le sauvegard√© dans: {MODEL_SAVE_PATH}")
