<a href="https://colab.research.google.com/gist/maclandrol/ab9a6ec3c96162e39c65c34e75596095/07_TorchXRayVision_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Instructor:** Emmanuel Noutahi, PhD

# Tutorial 2: Advanced Chest X-Ray Classification with TorchXRayVision
## Multi-Pathology Classification and Model Comparison

---

## 🎯 Objectives of This Tutorial

In this second tutorial, you will learn to work with:

1. **Multi-pathology classification** with several models
2. **Performance comparison** between architectures
3. **Analysis of complex clinical cases**
4. **Advanced evaluation metrics**
5. **Clinical interpretation** of the results
6. **An integrated hospital workflow**

---

## 🏥 Prerequisites

This tutorial follows on from **Tutorial 1**. Make sure you have:
- ✅ Completed Tutorial 1 (Introduction)
- ✅ Installed and configured TorchXRayVision
- ✅ An understanding of the preprocessing steps
- ✅ The basics of interpreting AI output

---

## 🔬 Advanced Medical Introduction

### Challenges of Multi-Pathology Classification

In real clinical practice, radiographs often show:
- **Several pathologies at once**
- **Varying degrees of severity**
- **Atypical presentations**
- **Artifacts and technical limitations**

### 📊 Clinical Stakes

#### **Sensitivity vs. Specificity:**
- **High sensitivity**: few missed pathologies (few false negatives)
- **High specificity**: few false alarms (few false positives)
- **Clinical trade-off**: the right balance depends on the context (emergency care vs. screening)
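
To make the trade-off concrete, here is a minimal sketch that computes sensitivity and specificity from binary labels and predictions. The function name and the toy labels are illustrative only and are not part of the tutorial's code.

```python
def sensitivity_specificity(y_true, y_pred):
    """Return (sensitivity, specificity) for binary labels (1 = pathology present)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)  # pathology correctly flagged
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)  # pathology missed
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)  # healthy correctly cleared
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)  # false alarm
    sensitivity = tp / (tp + fn) if (tp + fn) else float("nan")
    specificity = tn / (tn + fp) if (tn + fp) else float("nan")
    return sensitivity, specificity

# Toy example with made-up labels: 4 patients with a pathology, 4 without
y_true = [1, 1, 1, 1, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 0, 0, 1, 1]
sens, spec = sensitivity_specificity(y_true, y_pred)
print(f"Sensitivity: {sens:.2f}, specificity: {spec:.2f}")
```

Lowering the decision threshold typically raises sensitivity at the cost of specificity, which is why the choice depends on the clinical setting.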

#### **Impact of Model Choice:**
- **Generalist models**: balanced overall performance
- **Specialized models**: strong results on targeted pathologies
- **Model ensembles**: greater robustness and reliability

---

## 🔧 Advanced Configuration

In [None]:
# Advanced configuration for multi-model classification
import sys
import warnings

import numpy as np
import torch

# Silence library warnings to keep the notebook output readable
warnings.filterwarnings('ignore')

print("🏥 TUTORIAL 2: ADVANCED CLASSIFICATION WITH TORCHXRAYVISION")
print("=" * 65)

# System checks
print("\n🔧 System configuration:")
if 'google.colab' in sys.modules:
    print("   • Environment: Google Colab ✅")
    IN_COLAB = True
else:
    print("   • Environment: local")
    IN_COLAB = False

# GPU configuration
if torch.cuda.is_available():
    device = "cuda"
    print(f"   • GPU: {torch.cuda.get_device_name(0)} ✅")
    print(f"   • Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # cuDNN autotuning favors speed; GPU results may vary slightly between runs
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    print("   • cuDNN autotuning: ✅")
else:
    device = "cpu"
    print("   • GPU: not available")
    print("   • Mode: CPU (slower with several models)")

# Seeds for reproducibility (NumPy drives the synthetic-image noise used later)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

print(f"\n🎯 Selected device: {device}")
print("🔄 Random seeds set (seed=42)")

print("\n✅ Advanced configuration complete!")

## 📦 Imports and Checks

In [None]:
# Full set of imports for advanced classification
print("📦 IMPORTING LIBRARIES")
print("=" * 45)

try:
    # Core libraries
    import torchxrayvision as xrv
    import torch
    import torchvision
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    from PIL import Image
    import skimage
    import skimage.draw  # explicit submodule import, needed on older scikit-image versions

    # Advanced analysis
    from sklearn.metrics import (
        roc_curve, auc, confusion_matrix,
        classification_report, precision_recall_curve
    )
    from scipy import stats, ndimage
    from datetime import datetime
    import os
    import json
    import time

    print("✅ Core libraries imported")
    print("✅ Evaluation metrics loaded")
    print("✅ Statistical analysis tools ready")

except ImportError as e:
    print(f"❌ Import error: {e}")
    print("💡 Install the missing dependencies, then rerun this cell")
    raise  # the rest of the notebook cannot run without these imports

# Matplotlib configuration
plt.style.use('default')
plt.rcParams['figure.facecolor'] = 'white'
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 10

# Color palette for clinical displays
medical_colors = {
    'positive': '#FF4444',    # red: pathology
    'negative': '#44FF44',    # green: normal
    'uncertainty': '#FFAA44', # orange: uncertain
    'neutral': '#4444FF',     # blue: neutral
    'background': '#F0F0F0'   # light gray: background
}

print("\n🎨 Display configuration applied")
print(f"📊 TorchXRayVision version: {xrv.__version__}")

# Session directory setup
if IN_COLAB:
    try:
        from google.colab import drive
        drive.mount('/content/drive')
        base_dir = '/content/drive/MyDrive/TorchXRayVision_Tutorials/'
        tutorial_dir = f"{base_dir}Tutorial_2_Classification/"
        session_dir = f"{tutorial_dir}Session_{datetime.now().strftime('%Y%m%d_%H%M%S')}/"
        os.makedirs(session_dir, exist_ok=True)
        print(f"📁 Drive mounted: {session_dir}")
    except Exception as e:
        # Fall back to a local folder if Drive cannot be mounted
        session_dir = './tutorial_2_results/'
        os.makedirs(session_dir, exist_ok=True)
        print(f"⚠️ Drive not available ({e}); using local folder: {session_dir}")
else:
    session_dir = './tutorial_2_results/'
    os.makedirs(session_dir, exist_ok=True)
    print(f"📁 Local folder: {session_dir}")

print("\n🚀 Ready for advanced classification!")

## 🧠 Loading Multiple Models

Let's load several models so we can compare their performance:

In [None]:
print("🧠 LOADING MULTIPLE MODELS")
print("=" * 38)

def load_multiple_models(device, detailed_info=True):
    """
    Load several TorchXRayVision models for comparison.

    Working with several models allows:
    - performance comparison
    - cross-checking of results between models
    - consensus analysis
    - more robust diagnostic suggestions
    """

    models = {}
    model_info = {}

    if detailed_info:
        print("\n📚 Models to load:")
        print("   1️⃣ Universal model (all datasets)")
        print("   2️⃣ CheXpert model (Stanford)")
        print("   3️⃣ NIH model (National Institutes of Health)")
        print("   4️⃣ MIMIC-CXR model (MIT)")
    # 1. Universal model (trained on all datasets)
    print("\n🔄 Loading the universal model...")
    try:
        models['universel'] = xrv.models.DenseNet(weights="densenet121-res224-all")
        models['universel'].eval().to(device)

        model_info['universel'] = {
            'name': 'Universal Model',
            'architecture': 'DenseNet121',
            'training_data': 'All datasets combined',
            'specialty': 'Generalist',
            'params': sum(p.numel() for p in models['universel'].parameters())
        }
        print("   ✅ Universal model loaded")
    except Exception as e:
        print(f"   ❌ Universal model error: {e}")
    
    # 2. CheXpert model (Stanford)
    print("\n🎓 Loading the CheXpert model...")
    try:
        models['chexpert'] = xrv.models.DenseNet(weights="densenet121-res224-chex")
        models['chexpert'].eval().to(device)

        model_info['chexpert'] = {
            'name': 'CheXpert (Stanford)',
            'architecture': 'DenseNet121',
            'training_data': 'CheXpert dataset (224k images)',
            'specialty': 'Varied thoracic pathologies',
            'params': sum(p.numel() for p in models['chexpert'].parameters())
        }
        print("   ✅ CheXpert model loaded")
    except Exception as e:
        print(f"   ❌ CheXpert model error: {e}")
    
    # 3. NIH model
    print("\n🏛️ Loading the NIH model...")
    try:
        models['nih'] = xrv.models.DenseNet(weights="densenet121-res224-nih")
        models['nih'].eval().to(device)

        model_info['nih'] = {
            'name': 'NIH ChestX-ray',
            'architecture': 'DenseNet121',
            'training_data': 'NIH ChestX-ray14 (112k images)',
            'specialty': '14 thoracic pathologies',
            'params': sum(p.numel() for p in models['nih'].parameters())
        }
        print("   ✅ NIH model loaded")
    except Exception as e:
        print(f"   ❌ NIH model error: {e}")
    
    # 4. MIMIC model (if available)
    print("\n🏥 Trying to load the MIMIC model...")
    try:
        models['mimic'] = xrv.models.DenseNet(weights="densenet121-res224-mimic_ch")
        models['mimic'].eval().to(device)

        model_info['mimic'] = {
            'name': 'MIMIC-CXR (MIT)',
            'architecture': 'DenseNet121',
            'training_data': 'MIMIC-CXR (377k images)',
            'specialty': 'Real-world hospital data',
            'params': sum(p.numel() for p in models['mimic'].parameters())
        }
        print("   ✅ MIMIC model loaded")
    except Exception as e:
        print(f"   ⚠️ MIMIC model not available: {e}")
    
    if detailed_info:
        print("\n📊 Summary of loaded models:")
        print(f"   • Total: {len(models)}")

        for model_key, info in model_info.items():
            if model_key in models:
                params_millions = info['params'] / 1e6
                print(f"   • {info['name']}: {params_millions:.1f}M parameters")

    return models, model_info

# Load the models
print("🚀 Starting multi-model loading...")
start_time = time.time()

models_dict, models_info = load_multiple_models(device, detailed_info=True)

loading_time = time.time() - start_time
print(f"\n⏱️ Loading time: {loading_time:.2f} seconds")
print(f"✅ {len(models_dict)} models ready for comparative analysis!")

# Check which pathologies can be analyzed
if models_dict:
    first_model = list(models_dict.values())[0]
    pathologies = first_model.pathologies
    print(f"\n🩺 Pathologies available for analysis: {len(pathologies)}")

    # Grouping by clinical category
    infectious = [p for p in pathologies if any(term in p.lower() for term in ['pneumonia', 'tuberculosis'])]
    structural = [p for p in pathologies if any(term in p.lower() for term in ['pneumothorax', 'cardiomegaly', 'effusion'])]

    print(f"   • Infectious diseases: {len(infectious)}")
    print(f"   • Structural abnormalities: {len(structural)}")
    print(f"   • Other pathologies: {len(pathologies) - len(infectious) - len(structural)}")

print("\n🎯 Ready for comparative classification!")
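
Models trained on a single dataset do not necessarily cover every pathology. The sketch below assumes, as a working hypothesis, that each model exposes a list of label names in which an empty string marks an output it was not trained for. It pairs names with scores, drops the unnamed outputs, and finds the labels that all models share. The helper names, label lists, and scores are made up for illustration.

```python
def named_predictions(labels, scores):
    """Pair label names with scores, skipping outputs that have no label name."""
    if len(labels) != len(scores):
        raise ValueError("labels and scores must have the same length")
    return {name: float(score) for name, score in zip(labels, scores) if name}

def common_labels(predictions_by_model):
    """Return the labels that every model can score, in sorted order."""
    label_sets = [set(preds) for preds in predictions_by_model.values()]
    return sorted(set.intersection(*label_sets)) if label_sets else []

# Hypothetical outputs from two models that share the same output layout
preds = {
    "model_a": named_predictions(["Pneumonia", "Effusion", "Cardiomegaly"], [0.2, 0.7, 0.6]),
    "model_b": named_predictions(["Pneumonia", "", "Cardiomegaly"], [0.3, 0.5, 0.8]),
}
shared = common_labels(preds)
print(shared)
```

Restricting a comparison to the shared labels avoids scoring a model on a pathology it never learned.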

## 🖼️ Creating Varied Clinical Cases

Let's create several synthetic radiographs representing different clinical cases:

In [None]:
print("🖼️ CREATING VARIED CLINICAL CASES")
print("=" * 38)

def create_clinical_case_xray(case_type, size=224, severity='moderate'):
    """
    Create synthetic radiographs representing different clinical cases.

    Case types:
    - 'normal': normal radiograph
    - 'pneumonia': pneumonia (consolidation)
    - 'cardiomegaly': cardiomegaly
    - 'pneumothorax': pneumothorax
    - 'pleural_effusion': pleural effusion
    - 'complex': complex multi-pathology case
    """

    print(f"   🏗️ Creating case {case_type} (severity: {severity})...")
    
    # Base image
    img = np.zeros((size, size), dtype=np.float32)
    center_x, center_y = size // 2, size // 2

    # Basic anatomical structures

    # 1. Normal lungs
    lung_intensity = 0.3

    # Right lung (left side of the image)
    right_lung_x = int(0.3 * size)
    right_lung_y = int(0.4 * size)
    right_lung_h = int(0.25 * size)
    right_lung_w = int(0.12 * size)
    
    rr1, cc1 = skimage.draw.ellipse(right_lung_y, right_lung_x, 
                                   right_lung_h, right_lung_w, shape=img.shape)
    img[rr1, cc1] = lung_intensity
    
    # Left lung (right side of the image)
    left_lung_x = int(0.7 * size)
    left_lung_y = int(0.4 * size)
    left_lung_h = int(0.23 * size)
    left_lung_w = int(0.12 * size)
    
    rr2, cc2 = skimage.draw.ellipse(left_lung_y, left_lung_x,
                                   left_lung_h, left_lung_w, shape=img.shape)
    img[rr2, cc2] = lung_intensity
    
    # 2. Heart (size depends on the pathology)
    if case_type == 'cardiomegaly':
        if severity == 'mild':
            heart_size_factor = 1.2
        elif severity == 'moderate':
            heart_size_factor = 1.5
        else:  # severe
            heart_size_factor = 1.8
    elif case_type == 'complex':
        # The complex case includes cardiomegaly (factor 1.3)
        heart_size_factor = 1.3
    else:
        heart_size_factor = 1.0
    
    heart_w = int(0.1 * size * heart_size_factor)
    heart_h = int(0.08 * size * heart_size_factor)
    heart_x = int(0.48 * size)
    heart_y = int(0.6 * size)
    
    rr3, cc3 = skimage.draw.ellipse(heart_y, heart_x, heart_h, heart_w, shape=img.shape)
    img[rr3, cc3] = 0.6
    
    # 3. Spine
    spine_width = int(0.025 * size)
    spine_start = int(0.1 * size)
    spine_end = int(0.9 * size)

    for y in range(spine_start, spine_end):
        img[y, center_x-spine_width:center_x+spine_width] = 0.8

    # 4. Ribs
    for rib_level in range(8):
        rib_y = int(0.15 * size) + rib_level * int(0.07 * size)
        rib_curve = int(0.03 * size)

        # Right and left ribs
        for x in range(int(0.05 * size), int(0.95 * size)):
            if x < center_x:
                curve_offset = int(rib_curve * np.sin(np.pi * (x - 0.05 * size) / (0.45 * size)))
            else:
                curve_offset = int(rib_curve * np.sin(np.pi * (0.95 * size - x) / (0.45 * size)))
            
            y_rib = rib_y + curve_offset
            if 0 <= y_rib < size:
                img[y_rib:y_rib+2, x:x+1] = 0.7
    
    # 5. Case-specific pathologies

    if case_type == 'pneumonia':
        # Pneumonic consolidation
        pneumo_intensity = 0.75 if severity == 'mild' else 0.85
        pneumo_size = 0.08 if severity == 'mild' else 0.12
        
        pneumo_y = int(0.35 * size)
        pneumo_x = int(0.25 * size)
        pneumo_h = int(pneumo_size * size)
        pneumo_w = int(pneumo_size * 0.8 * size)
        
        rr_pneumo, cc_pneumo = skimage.draw.ellipse(pneumo_y, pneumo_x,
                                                   pneumo_h, pneumo_w, shape=img.shape)
        img[rr_pneumo, cc_pneumo] = pneumo_intensity
        
        # Air bronchogram (characteristic sign)
        for i in range(3):
            broncho_y = pneumo_y + (i-1) * 5
            broncho_x_start = pneumo_x - 10
            broncho_x_end = pneumo_x + 10
            if 0 <= broncho_y < size:
                img[broncho_y, max(0, broncho_x_start):min(size, broncho_x_end)] = 0.4
    
    elif case_type == 'pneumothorax':
        # Visible pleural line
        pneumo_line_x = int(0.15 * size) if severity != 'mild' else int(0.18 * size)
        pneumo_start_y = int(0.2 * size)
        pneumo_end_y = int(0.7 * size)
        
        # Line of pleural separation
        img[pneumo_start_y:pneumo_end_y, pneumo_line_x:pneumo_line_x+2] = 0.9

        # Pleural air space (hyperlucent area)
        img[pneumo_start_y:pneumo_end_y, int(0.05 * size):pneumo_line_x] = 0.1

        # Lung collapse
        collapse_factor = 0.7 if severity == 'mild' else 0.5
        affected_lung = img[pneumo_start_y:pneumo_end_y, pneumo_line_x:int(0.5 * size)]
        img[pneumo_start_y:pneumo_end_y, pneumo_line_x:int(0.5 * size)] = affected_lung * collapse_factor
    
    elif case_type == 'pleural_effusion':
        # Pleural effusion (basal opacity)
        effusion_height = int(0.15 * size) if severity == 'mild' else int(0.25 * size)
        effusion_y_start = int(0.8 * size) - effusion_height
        effusion_y_end = int(0.8 * size)

        # Right side (more common)
        effusion_x_start = int(0.1 * size)
        effusion_x_end = int(0.5 * size)

        # Fluid opacity with the Damoiseau line
        for y in range(effusion_y_start, effusion_y_end):
            # Characteristic concave upper border
            curve_offset = int(0.05 * size * np.sin(np.pi * (y - effusion_y_start) / effusion_height))
            x_limit = effusion_x_end - curve_offset
            img[y, effusion_x_start:min(size, x_limit)] = 0.8
    
    elif case_type == 'complex':
        # Complex case: cardiomegaly + effusion + consolidation
        # Cardiomegaly comes from the heart size factor of 1.3

        # Small effusion
        small_effusion_height = int(0.1 * size)
        effusion_start = int(0.75 * size)
        img[effusion_start:effusion_start+small_effusion_height, int(0.1*size):int(0.4*size)] = 0.75

        # Small consolidation
        small_consolidation_y = int(0.3 * size)
        small_consolidation_x = int(0.7 * size)
        rr_small, cc_small = skimage.draw.ellipse(small_consolidation_y, small_consolidation_x,
                                                 int(0.05 * size), int(0.04 * size), shape=img.shape)
        img[rr_small, cc_small] = 0.7
    
    # 6. Post-processing for realism

    # Radiographic noise
    noise = np.random.normal(0, 0.04, img.shape)
    img = img + noise

    # Gaussian smoothing (radiographic blur)
    img = ndimage.gaussian_filter(img, sigma=0.8)

    # Final clipping to [0, 1]
    img = np.clip(img, 0, 1)

    print(f"   ✅ Case {case_type} created")
    
    return img

# Build the collection of clinical cases
print("\n🏗️ Building the clinical case library...")

clinical_cases = {
    'normal': {
        'image': create_clinical_case_xray('normal'),
        'description': 'Normal chest radiograph',
        'expected_findings': ['Normal'],
        'clinical_context': 'Routine check-up, asymptomatic patient'
    },
    'pneumonie_moderee': {
        'image': create_clinical_case_xray('pneumonia', severity='moderate'),
        'description': 'Moderate right lobar pneumonia',
        'expected_findings': ['Pneumonia', 'Infiltration'],
        'clinical_context': 'Fever, productive cough, dyspnea'
    },
    'cardiomegalie': {
        'image': create_clinical_case_xray('cardiomegaly', severity='moderate'),
        'description': 'Moderate cardiomegaly',
        'expected_findings': ['Cardiomegaly', 'Enlarged Cardiomediastinum'],
        'clinical_context': 'Exertional dyspnea, lower-limb edema'
    },
    'pneumothorax': {
        'image': create_clinical_case_xray('pneumothorax', severity='moderate'),
        'description': 'Moderate spontaneous pneumothorax',
        'expected_findings': ['Pneumothorax'],
        'clinical_context': 'Sudden chest pain, dyspnea'
    },
    'epanchement_pleural': {
        'image': create_clinical_case_xray('pleural_effusion', severity='moderate'),
        'description': 'Moderate right pleural effusion',
        'expected_findings': ['Pleural Effusion'],
        'clinical_context': 'Progressive dyspnea, pleuritic pain'
    },
    'cas_complexe': {
        'image': create_clinical_case_xray('complex'),
        'description': 'Complex multi-pathology case',
        'expected_findings': ['Cardiomegaly', 'Pleural Effusion', 'Infiltration'],
        'clinical_context': 'Decompensated heart failure'
    }
}

print(f"\n✅ {len(clinical_cases)} clinical cases created!")

# Display the case gallery
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
fig.suptitle("Clinical Case Gallery - Tutorial 2", fontsize=18, fontweight='bold')

axes = axes.flatten()

for i, (case_name, case_data) in enumerate(clinical_cases.items()):
    if i < len(axes):
        axes[i].imshow(case_data['image'], cmap='gray')
        axes[i].set_title(f"{case_data['description']}\n({case_name})", 
                         fontweight='bold', fontsize=11)
        axes[i].axis('off')
        
        # Clinical context as an annotation
        axes[i].text(0.02, 0.02, case_data['clinical_context'], 
                    transform=axes[i].transAxes, fontsize=9,
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8),
                    verticalalignment='bottom')

plt.tight_layout()
plt.show()

print("\n🎯 Clinical cases ready for comparative classification!")
print("📚 Next step: multi-model analysis")
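
The synthetic images above hold values in [0, 1], whereas TorchXRayVision models are generally described as expecting inputs scaled to roughly [-1024, 1024]. The sketch below shows that linear rescaling on a flat list of pixel values. The function name is hypothetical, and the formula is an assumption about the intended mapping rather than a copy of the library's code.

```python
def rescale_to_model_range(pixels, maxval):
    """Linearly map values from [0, maxval] to [-1024, 1024]."""
    if max(pixels) > maxval:
        raise ValueError(f"pixel value above maxval={maxval}")
    return [(2.0 * (p / maxval) - 1.0) * 1024.0 for p in pixels]

# Black, mid-gray and white pixels from a [0, 1] image
rescaled = rescale_to_model_range([0.0, 0.5, 1.0], maxval=1)
print(rescaled)

# The same three gray levels stored as 8-bit values
rescaled_8bit = rescale_to_model_range([0, 255], maxval=255)
print(rescaled_8bit)
```

Passing the wrong `maxval` (for example 255 for a [0, 1] image) would squeeze every pixel close to -1024, so the value range of the input should be checked before rescaling.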

## 🔬 Multi-Model Classification

Let's analyze our clinical cases with every available model:

In [None]:
print("🔬 MULTI-MODEL CLASSIFICATION")
print("=" * 34)

def preprocess_for_models(image_array):
    """
    Standardized preprocessing shared by all models.
    """
    # Rescale to the input range the models expect (roughly [-1024, 1024]).
    # Values above 1 are treated as 8-bit (0-255), otherwise as floats in [0, 1].
    maxval = 255 if image_array.max() > 1 else 1
    img_norm = xrv.datasets.normalize(image_array, maxval)

    # Add the channel dimension: (H, W) -> (1, H, W)
    img_channel = img_norm[None, ...]

    # TorchXRayVision transforms
    transform = torchvision.transforms.Compose([
        xrv.datasets.XRayCenterCrop(),
        xrv.datasets.XRayResizer(224)
    ])

    img_transformed = transform(img_channel)

    # PyTorch tensor with a batch dimension: (1, 1, 224, 224)
    img_tensor = torch.from_numpy(img_transformed).float()
    img_batch = img_tensor.unsqueeze(0)

    return img_batch, img_transformed

def analyze_with_all_models(image_tensor, models_dict, pathologies):
    """
    Run one image through every available model.
    """
    results = {}
    image_tensor = image_tensor.to(device)

    for model_name, model in models_dict.items():
        model.eval()
        # Use each model's own label list when available (fallback: the shared list)
        model_pathologies = list(getattr(model, 'pathologies', pathologies))

        with torch.no_grad():
            # Prediction. The pretrained TorchXRayVision DenseNet weights are expected
            # to return calibrated scores already in [0, 1], so no extra sigmoid is applied.
            outputs = model(image_tensor)
            probabilities_np = outputs.cpu().numpy().squeeze()

            # Store the results; outputs without a label name are skipped
            results[model_name] = {
                'probabilities': probabilities_np,
                'pathologies': model_pathologies,
                'positive_findings': [
                    (model_pathologies[i], prob)
                    for i, prob in enumerate(probabilities_np)
                    if model_pathologies[i] and prob > 0.5
                ]
            }

    return results

def create_comparative_analysis(case_results, case_info):
    """
    Print a comparative analysis of the multi-model results.
    """
    print(f"\n📊 COMPARATIVE ANALYSIS: {case_info['description']}")
    print("-" * 70)

    # Expected pathologies
    expected = case_info['expected_findings']
    print(f"🎯 Expected pathologies: {', '.join(expected)}")
    print(f"🏥 Clinical context: {case_info['clinical_context']}")

    # Per-model analysis
    print("\n📋 Results by model:")

    all_model_results = {}

    for model_name, results in case_results.items():
        model_info = models_info.get(model_name, {'name': model_name})
        print(f"\n   🧠 {model_info['name']}:")

        positive_findings = results['positive_findings']

        if positive_findings:
            print(f"      🚨 Positive detections ({len(positive_findings)}):")
            for pathology, prob in sorted(positive_findings, key=lambda x: x[1], reverse=True):
                print(f"        • {pathology}: {prob:.3f} ({prob*100:.1f}%)")
        else:
            print("      ✅ No pathology detected (50% threshold)")

        # Check the expected pathologies (case-insensitive substring match)
        expected_detected = []
        for exp_pathology in expected:
            for detected_pathology, prob in positive_findings:
                # An empty label would match every expected name, so skip it
                if detected_pathology and (
                        exp_pathology.lower() in detected_pathology.lower()
                        or detected_pathology.lower() in exp_pathology.lower()):
                    expected_detected.append((exp_pathology, detected_pathology, prob))

        if expected_detected:
            print(f"      ✅ Expected pathologies detected: {len(expected_detected)}")
            for exp, det, prob in expected_detected:
                print(f"        ✓ {exp} → {det} ({prob:.3f})")
        else:
            print("      ⚠️ Expected pathologies not detected")

        # Store for the global analysis
        all_model_results[model_name] = results

    return all_model_results

# Analyze every case with every model
print("\n🚀 Starting the multi-model comparative analysis...")
print(f"📊 {len(clinical_cases)} cases × {len(models_dict)} models = {len(clinical_cases) * len(models_dict)} analyses")

comprehensive_results = {}

if models_dict:  # Make sure at least one model is loaded
    # Pathology list taken from the first model
    first_model = list(models_dict.values())[0]
    pathologies = first_model.pathologies

    for case_name, case_data in clinical_cases.items():
        print(f"\n🔄 Analyzing case: {case_name}...")

        # Preprocessing
        processed_tensor, processed_display = preprocess_for_models(case_data['image'])

        # Analysis with all models
        case_results = analyze_with_all_models(processed_tensor, models_dict, pathologies)

        # Comparative analysis
        comparative_analysis = create_comparative_analysis(case_results, case_data)

        # Storage
        comprehensive_results[case_name] = {
            'case_info': case_data,
            'processed_image': processed_display,
            'model_results': comparative_analysis
        }

        print(f"   ✅ Case {case_name} analyzed")

else:
    print("❌ No model available for the analysis")
    print("💡 Check your internet connection and reload the models")

print("\n✅ Multi-model analysis complete!")
print(f"📊 {len(comprehensive_results)} cases analyzed")
print("📚 Next step: advanced comparative visualization")
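
One simple way to combine the per-model findings is a vote count: a pathology is retained only if enough models report it. The sketch below is one possible consensus rule, not the tutorial's own method. The function name, the `min_votes` parameter, and the findings are hypothetical.

```python
def consensus_findings(findings_by_model, min_votes):
    """Return {pathology: votes} for findings reported by at least `min_votes` models."""
    votes = {}
    for findings in findings_by_model.values():
        for pathology in set(findings):  # each model votes at most once per pathology
            votes[pathology] = votes.get(pathology, 0) + 1
    return {name: count for name, count in votes.items() if count >= min_votes}

# Made-up positive findings for a single case
findings = {
    "universel": ["Effusion", "Cardiomegaly"],
    "chexpert": ["Effusion"],
    "nih": ["Effusion", "Infiltration"],
}
agreed = consensus_findings(findings, min_votes=2)
print(agreed)
```

A higher `min_votes` favors specificity, while `min_votes=1` keeps every finding and favors sensitivity.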

## 📊 Advanced Comparative Visualization

Let's build richer visualizations to compare the models' outputs:

In [None]:
print("📊 ADVANCED COMPARATIVE VISUALIZATION")
print("=" * 40)

def create_advanced_comparative_visualization(comprehensive_results):
    """
    Build advanced visualizations for multi-model comparison.
    """

    if not comprehensive_results:
        print("❌ No results to visualize")
        return

    # 1. Overview: heatmap of predictions
    print("\n📈 Building the prediction heatmap...")

    # Collect data for the heatmap
    model_names = list(models_dict.keys()) if models_dict else []
    case_names = list(comprehensive_results.keys())
    
    if model_names and case_names:
        # Create the overall figure
        fig = plt.figure(figsize=(20, 16))
        gs = fig.add_gridspec(4, 3, height_ratios=[1, 1, 1, 0.3], hspace=0.4, wspace=0.3)

        # Overall heatmap of positive detections
        ax_heatmap = fig.add_subplot(gs[0, :])

        detection_matrix = np.zeros((len(case_names), len(model_names)))

        for i, case_name in enumerate(case_names):
            case_results = comprehensive_results[case_name]['model_results']
            for j, model_name in enumerate(model_names):
                if model_name in case_results:
                    positive_count = len(case_results[model_name]['positive_findings'])
                    detection_matrix[i, j] = positive_count

        # Draw the heatmap
        im = ax_heatmap.imshow(detection_matrix, cmap='Reds', aspect='auto')

        # Axis configuration
        ax_heatmap.set_xticks(range(len(model_names)))
        ax_heatmap.set_xticklabels([models_info.get(name, {'name': name})['name'] for name in model_names],
                                  rotation=45, ha='right')
        ax_heatmap.set_yticks(range(len(case_names)))
        ax_heatmap.set_yticklabels([case.replace('_', ' ').title() for case in case_names])

        # Write the value in each cell (dark text on light cells, white text on dark cells)
        half_max = detection_matrix.max() / 2
        for i in range(len(case_names)):
            for j in range(len(model_names)):
                cell_color = "white" if detection_matrix[i, j] > half_max else "black"
                ax_heatmap.text(j, i, int(detection_matrix[i, j]),
                                ha="center", va="center", color=cell_color, fontweight='bold')

        ax_heatmap.set_title("Number of Positive Detections per Case and Model",
                           fontsize=14, fontweight='bold')

        # Colorbar
        cbar = plt.colorbar(im, ax=ax_heatmap, orientation='horizontal', pad=0.1)
        cbar.set_label('Number of pathologies detected', fontsize=12)
        
        # 2. Consensus analysis per case
        for idx, (case_name, case_data) in enumerate(list(comprehensive_results.items())[:6]):
            row = (idx // 3) + 1
            col = idx % 3
            ax_case = fig.add_subplot(gs[row, col])

            # Collect the highest probability reported by each model
            model_max_probs = []
            model_labels = []

            for model_name in model_names:
                if model_name in case_data['model_results']:
                    probs = case_data['model_results'][model_name]['probabilities']
                    max_prob = np.max(probs)
                    model_max_probs.append(max_prob)
                    model_labels.append(models_info.get(model_name, {'name': model_name})['name'][:8])

            # Bar chart
            colors = ['skyblue', 'lightcoral', 'lightgreen', 'gold'][:len(model_max_probs)]
            bars = ax_case.bar(range(len(model_max_probs)), model_max_probs,
                              color=colors, alpha=0.7, edgecolor='black')

            ax_case.set_xticks(range(len(model_labels)))
            ax_case.set_xticklabels(model_labels, rotation=45, ha='right', fontsize=10)
            ax_case.set_ylabel('Max prob.', fontsize=10)
            ax_case.set_title(f"{case_name.replace('_', ' ').title()}", fontsize=11, fontweight='bold')
            ax_case.axhline(y=0.5, color='red', linestyle='--', alpha=0.7, linewidth=1)
            ax_case.set_ylim(0, 1)
            ax_case.grid(True, alpha=0.3)

            # Values above the bars
            for bar, prob in zip(bars, model_max_probs):
                height = bar.get_height()
                ax_case.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                           f'{prob:.2f}', ha='center', va='bottom', fontsize=9)
        
        # 3. Legend and global statistics
        ax_legend = fig.add_subplot(gs[3, :])
        ax_legend.axis('off')

        # Global statistics
        total_analyses = len(case_names) * len(model_names)
        total_detections = np.sum(detection_matrix)
        avg_detections_per_case = np.mean(np.sum(detection_matrix, axis=1))
        avg_detections_per_model = np.mean(np.sum(detection_matrix, axis=0))

        stats_text = f"""
📊 GLOBAL STATISTICS OF THE COMPARATIVE ANALYSIS

• Total analyses run: {total_analyses}
• Total positive detections: {int(total_detections)}
• Mean detections per case: {avg_detections_per_case:.1f}
• Mean detections per model: {avg_detections_per_model:.1f}

📈 INTERPRETATION:
• Red line: detection threshold (50%)
• Heatmap: darker red = more detections
• High consensus = same finding from several models
• Divergence = diagnostic uncertainty
        """

        ax_legend.text(0.05, 0.95, stats_text, transform=ax_legend.transAxes,
                      fontsize=11, verticalalignment='top', fontfamily='monospace',
                      bbox=dict(boxstyle="round,pad=0.5", facecolor="lightgray", alpha=0.8))

        plt.suptitle("TorchXRayVision Multi-Model Comparative Analysis",
                    fontsize=18, fontweight='bold', y=0.98)
        plt.tight_layout()
        plt.show()

        print("✅ Comparative visualization created!")
    else:
        print("❌ Not enough data for the visualization")

# Cr√©ation des visualisations avanc√©es
print("\nüé® G√©n√©ration des visualisations avanc√©es...")

if comprehensive_results:
    create_advanced_comparative_visualization(comprehensive_results)
else:
    print("‚ö†Ô∏è Aucun r√©sultat √† visualiser")
    print("üí° Assurez-vous que les mod√®les sont charg√©s correctement")

print("\nüéØ Visualisation comparative termin√©e !")

## 📈 Detailed Performance Analysis

Let's analyze the performance of each model in detail:
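
Before running the full analysis, it can help to see the metric arithmetic in isolation. The sketch below is a simplified illustration, not the notebook's implementation: `count_matches` and `precision_recall_f1` are hypothetical helper names, the pathology lists are made-up values, and matching is exact and case-insensitive, whereas the analysis cell uses approximate name matching.

```python
def count_matches(expected, detected):
    """Return (TP, FP, FN) using exact, case-insensitive name matching."""
    exp = {name.lower() for name in expected}
    det = {name.lower() for name in detected}
    tp = len(exp & det)   # expected and detected
    fp = len(det - exp)   # detected but not expected
    fn = len(exp - det)   # expected but missed
    return tp, fp, fn


def precision_recall_f1(tp, fp, fn):
    """Return (precision, recall, F1); a ratio with a zero denominator is reported as 0.0."""
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) > 0 else 0.0)
    return precision, recall, f1


# Hypothetical case: two expected findings, three detections above threshold
expected = ["Pneumonia", "Effusion"]
detected = ["Effusion", "Cardiomegaly", "Atelectasis"]

tp, fp, fn = count_matches(expected, detected)
precision, recall, f1 = precision_recall_f1(tp, fp, fn)
print(f"TP={tp} FP={fp} FN={fn}")
print(f"precision={precision:.3f} recall={recall:.3f} F1={f1:.3f}")
```

Because no true negatives are counted, these metrics describe precision and recall (sensitivity), not specificity.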

In [None]:
print("📈 DETAILED PERFORMANCE ANALYSIS")
print("=" * 40)

def calculate_model_performance_metrics(comprehensive_results):
    """
    Compute detailed performance metrics for each model
    """
    
    if not comprehensive_results or not models_dict:
        print("❌ No data available to compute the metrics")
        return {}
    
    print("\n📊 Computing performance metrics...")
    
    model_names = list(models_dict.keys())
    performance_metrics = {}
    
    for model_name in model_names:
        print(f"\n🔍 Analyzing model: {models_info.get(model_name, {'name': model_name})['name']}")
        
        # Collect the data for this model
        all_probabilities = []
        all_expected_pathologies = []
        all_detected_pathologies = []
        case_performances = []
        
        for case_name, case_data in comprehensive_results.items():
            if model_name in case_data['model_results']:
                model_result = case_data['model_results'][model_name]
                expected = case_data['case_info']['expected_findings']
                detected = [finding[0] for finding in model_result['positive_findings']]
                
                # Store for the global analysis
                all_probabilities.extend(model_result['probabilities'])
                all_expected_pathologies.extend(expected)
                all_detected_pathologies.extend(detected)
                
                # Per-case performance (cast to float so the value is JSON-serializable)
                case_performance = {
                    'case_name': case_name,
                    'expected': expected,
                    'detected': detected,
                    'max_probability': float(np.max(model_result['probabilities'])),
                    'num_detections': len(detected),
                    'true_positives': 0,
                    'false_positives': len(detected),
                    'false_negatives': len(expected)
                }
                
                # Count true positives (approximate name matching).
                # Each detection is consumed once matched, so it cannot be
                # counted against several expected findings.
                unmatched = list(detected)
                for exp_path in expected:
                    for det_path in unmatched:
                        if (exp_path.lower() in det_path.lower() or 
                            det_path.lower() in exp_path.lower() or
                            any(word in det_path.lower() for word in exp_path.lower().split())):
                            case_performance['true_positives'] += 1
                            case_performance['false_positives'] -= 1
                            case_performance['false_negatives'] -= 1
                            unmatched.remove(det_path)
                            break
                
                case_performances.append(case_performance)
        # Skip models with no analyzed case (np.max/np.min fail on empty lists)
        if not case_performances or not all_probabilities:
            print("   ⚠️ No analyzed case for this model, metrics skipped")
            continue
        
        # Global metrics
        total_tp = sum(cp['true_positives'] for cp in case_performances)
        total_fp = sum(max(0, cp['false_positives']) for cp in case_performances)
        total_fn = sum(max(0, cp['false_negatives']) for cp in case_performances)
        
        # Derived metrics (0 when the denominator is zero)
        precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
        recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
        f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        
        # Probability statistics (cast to float for JSON serialization)
        prob_stats = {
            'mean': float(np.mean(all_probabilities)),
            'std': float(np.std(all_probabilities)),
            'max': float(np.max(all_probabilities)),
            'min': float(np.min(all_probabilities)),
            'median': float(np.median(all_probabilities))
        }
        
        avg_detections = float(np.mean([len(cp['detected']) for cp in case_performances]))
        
        # Store the metrics
        performance_metrics[model_name] = {
            'precision': precision,
            'recall': recall,
            'f1_score': f1_score,
            'true_positives': total_tp,
            'false_positives': total_fp,
            'false_negatives': total_fn,
            'probability_stats': prob_stats,
            'case_performances': case_performances,
            'total_cases': len(case_performances),
            'avg_detections_per_case': avg_detections
        }
        
        print(f"   ✅ Metrics computed:")
        print(f"      • Precision: {precision:.3f}")
        print(f"      • Recall: {recall:.3f}")
        print(f"      • F1-score: {f1_score:.3f}")
        print(f"      • Detections/case: {avg_detections:.1f}")
    
    return performance_metrics

def create_performance_comparison_chart(performance_metrics):
    """
    Create a chart comparing model performance
    """
    if not performance_metrics:
        return
    
    print("\n📊 Creating the performance comparison chart...")
    
    # Data for the plots
    model_names = list(performance_metrics.keys())
    model_labels = [models_info.get(name, {'name': name})['name'] for name in model_names]
    
    precisions = [performance_metrics[name]['precision'] for name in model_names]
    recalls = [performance_metrics[name]['recall'] for name in model_names]
    f1_scores = [performance_metrics[name]['f1_score'] for name in model_names]
    avg_detections = [performance_metrics[name]['avg_detections_per_case'] for name in model_names]
    
    # Create the figure
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('Performance Comparison - TorchXRayVision Models', 
                fontsize=16, fontweight='bold')
    
    # 1. Precision, recall, F1-score
    ax1 = axes[0, 0]
    x_pos = np.arange(len(model_labels))
    width = 0.25
    
    bars1 = ax1.bar(x_pos - width, precisions, width, label='Precision', 
                    color='lightblue', alpha=0.8, edgecolor='black')
    bars2 = ax1.bar(x_pos, recalls, width, label='Recall', 
                    color='lightgreen', alpha=0.8, edgecolor='black')
    bars3 = ax1.bar(x_pos + width, f1_scores, width, label='F1-score', 
                    color='lightcoral', alpha=0.8, edgecolor='black')
    
    ax1.set_xticks(x_pos)
    ax1.set_xticklabels(model_labels, rotation=45, ha='right')
    ax1.set_ylabel('Score', fontweight='bold')
    ax1.set_title('Classification Metrics', fontweight='bold')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax1.set_ylim(0, 1)
    
    # Add the values on top of the bars
    for bars in [bars1, bars2, bars3]:
        for bar in bars:
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                    f'{height:.2f}', ha='center', va='bottom', fontsize=9)
    
    # 2. Average detections per case
    ax2 = axes[0, 1]
    bars_det = ax2.bar(model_labels, avg_detections, 
                      color='gold', alpha=0.8, edgecolor='black')
    ax2.set_title('Average Detections per Case', fontweight='bold')
    ax2.set_ylabel('Number of Detections', fontweight='bold')
    ax2.tick_params(axis='x', rotation=45)
    ax2.grid(True, alpha=0.3)
    
    for bar in bars_det:
        height = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width()/2., height + 0.05,
                f'{height:.1f}', ha='center', va='bottom', fontsize=10, fontweight='bold')
    
    # 3. Probability distribution
    ax3 = axes[1, 0]
    
    for i, model_name in enumerate(model_names):
        prob_stats = performance_metrics[model_name]['probability_stats']
        
        # Simplified summary: mean as a bar, standard deviation as an error bar
        model_label = model_labels[i]
        mean_prob = prob_stats['mean']
        std_prob = prob_stats['std']
        
        ax3.bar(i, mean_prob, yerr=std_prob, capsize=5, 
               color=f'C{i}', alpha=0.7, edgecolor='black',
               label=f'{model_label}: μ={mean_prob:.3f}')
    
    ax3.set_xticks(range(len(model_labels)))
    ax3.set_xticklabels(model_labels, rotation=45, ha='right')
    ax3.set_ylabel('Mean Probability', fontweight='bold')
    ax3.set_title('Probability Distribution (mean ± std)', fontweight='bold')
    ax3.axhline(y=0.5, color='red', linestyle='--', alpha=0.7, label='50% threshold')
    ax3.grid(True, alpha=0.3)
    ax3.legend()
    
    # 4. Aggregated TP / FP / FN counts per model
    ax4 = axes[1, 1]
    
    # Collect the counts for each model (true negatives are not tracked)
    confusion_data = []
    confusion_labels = []
    
    for model_name in model_names:
        metrics = performance_metrics[model_name]
        tp = metrics['true_positives']
        fp = metrics['false_positives']
        fn = metrics['false_negatives']
        
        confusion_data.append([tp, fp, fn])
        confusion_labels.append(models_info.get(model_name, {'name': model_name})['name'])
    
    confusion_array = np.array(confusion_data)
    
    # Heatmap of the counts (rows = count type, columns = models)
    im = ax4.imshow(confusion_array.T, cmap='Blues', aspect='auto')
    
    ax4.set_xticks(range(len(confusion_labels)))
    ax4.set_xticklabels(confusion_labels, rotation=45, ha='right')
    ax4.set_yticks(range(3))
    ax4.set_yticklabels(['True Positives', 'False Positives', 'False Negatives'])
    ax4.set_title('Aggregated Confusion Counts', fontweight='bold')
    
    # Annotate each cell; switch the text color so values stay readable on light cells
    max_count = confusion_array.max()
    for i in range(len(confusion_labels)):
        for j in range(3):
            value = int(confusion_array[i, j])
            text_color = "white" if value > max_count / 2 else "black"
            ax4.text(i, j, str(value), ha="center", va="center",
                     color=text_color, fontweight='bold')
    
    plt.tight_layout()
    plt.show()
    
    print("✅ Comparison charts created!")

# Compute and visualize the performance metrics
if comprehensive_results:
    print("\n🧮 Computing performance metrics...")
    performance_metrics = calculate_model_performance_metrics(comprehensive_results)
    
    if performance_metrics:
        create_performance_comparison_chart(performance_metrics)
        
        # Text summary of the performance metrics
        print(f"\n📋 COMPARATIVE PERFORMANCE SUMMARY")
        print(f"=" * 50)
        
        best_precision = max(performance_metrics.items(), key=lambda x: x[1]['precision'])
        best_recall = max(performance_metrics.items(), key=lambda x: x[1]['recall'])
        best_f1 = max(performance_metrics.items(), key=lambda x: x[1]['f1_score'])
        
        print(f"🥇 Best precision: {models_info.get(best_precision[0], {'name': best_precision[0]})['name']} ({best_precision[1]['precision']:.3f})")
        print(f"🥇 Best recall: {models_info.get(best_recall[0], {'name': best_recall[0]})['name']} ({best_recall[1]['recall']:.3f})")
        print(f"🥇 Best F1-score: {models_info.get(best_f1[0], {'name': best_f1[0]})['name']} ({best_f1[1]['f1_score']:.3f})")
        
        print(f"\n💡 Clinical recommendations:")
        print(f"• Most precise model, to limit false positives: {models_info.get(best_precision[0], {'name': best_precision[0]})['name']}")
        print(f"• Most sensitive model, for screening: {models_info.get(best_recall[0], {'name': best_recall[0]})['name']}")
        print(f"• Most balanced model: {models_info.get(best_f1[0], {'name': best_f1[0]})['name']}")
    
else:
    print("⚠️ No performance data available")

print("\n🎯 Detailed performance analysis complete!")

## 💾 Saving and Full Report
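
One practical detail when saving metrics as JSON: some values produced by NumPy, such as `float32` scalars or arrays, are rejected by `json.dump` as-is. The sketch below shows one possible conversion hook passed through the `default` parameter. `to_builtin` is a hypothetical helper name, and the standard-library `array` type stands in for a NumPy array so that the snippet has no external dependency (both expose a `.tolist()` method).

```python
import json
from array import array


def to_builtin(obj):
    """Fallback for json.dumps: convert objects exposing .tolist() to built-in types."""
    if hasattr(obj, "tolist"):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# array('f') is a stand-in for a float32 array of model probabilities (made-up values)
metrics = {"model": "demo", "probabilities": array("f", [0.25, 0.75])}

encoded = json.dumps(metrics, default=to_builtin)
print(encoded)
```

Raising `TypeError` for unknown types keeps unexpected objects visible instead of silently writing them as strings.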

In [None]:
print("💾 GENERATING THE FULL REPORT")
print("=" * 35)

def generate_comprehensive_report(comprehensive_results, performance_metrics, session_dir):
    """
    Generate a complete report of the multi-model analysis
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    
    print(f"\n📝 Generating the full report...")
    
    # Detailed text report
    report_path = f"{session_dir}rapport_comparatif_{timestamp}.txt"
    
    report_content = f"""
TORCHXRAYVISION MULTI-MODEL COMPARATIVE REPORT
==============================================

Analysis date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
Tutorial: 2 - Advanced Classification
TorchXRayVision version: {xrv.__version__}
Analysis device: {device}

MODELS ANALYZED:
================
"""
    
    for model_name, info in models_info.items():
        if model_name in models_dict:
            report_content += f"""
• {info['name']}:
  - Architecture: {info['architecture']}
  - Training data: {info['training_data']}
  - Specialty: {info['specialty']}
  - Parameters: {info['params']:,}
"""
    
    report_content += f"""

CLINICAL CASES ANALYZED:
========================
"""
    
    for case_name, case_data in clinical_cases.items():
        expected = case_data['expected_findings']
        context = case_data['clinical_context']
        description = case_data['description']
        
        report_content += f"""
• {case_name.upper()}:
  - Description: {description}
  - Expected pathologies: {', '.join(expected)}
  - Clinical context: {context}
"""
    
    if performance_metrics:
        report_content += f"""

COMPARATIVE PERFORMANCE:
========================
"""
        
        for model_name, metrics in performance_metrics.items():
            model_info = models_info.get(model_name, {'name': model_name})
            report_content += f"""
• {model_info['name']}:
  - Precision: {metrics['precision']:.3f}
  - Recall (sensitivity): {metrics['recall']:.3f}
  - F1-score: {metrics['f1_score']:.3f}
  - True positives: {metrics['true_positives']}
  - False positives: {metrics['false_positives']}
  - False negatives: {metrics['false_negatives']}
  - Detections/case: {metrics['avg_detections_per_case']:.1f}
  - Mean probability: {metrics['probability_stats']['mean']:.3f}
"""
    
    report_content += f"""

DETAILED ANALYSIS BY CASE:
==========================
"""
    
    for case_name, case_result in comprehensive_results.items():
        case_info = case_result['case_info']
        model_results = case_result['model_results']
        
        report_content += f"""
CASE: {case_name.upper()}
Description: {case_info['description']}
Context: {case_info['clinical_context']}
Expected pathologies: {', '.join(case_info['expected_findings'])}

Results by model:
"""
        
        for model_name, result in model_results.items():
            model_info = models_info.get(model_name, {'name': model_name})
            positive_findings = result['positive_findings']
            max_prob = np.max(result['probabilities'])
            
            report_content += f"""
  • {model_info['name']}:
    - Maximum probability: {max_prob:.3f}
    - Positive detections: {len(positive_findings)}
"""
            
            if positive_findings:
                report_content += "    - Detected pathologies:\n"
                for pathology, prob in sorted(positive_findings, key=lambda x: x[1], reverse=True):
                    report_content += f"      • {pathology}: {prob:.3f}\n"
            else:
                report_content += "    - No pathology detected (50% threshold)\n"
    
    # Conclusions and recommendations
    if performance_metrics:
        best_models = {
            'precision': max(performance_metrics.items(), key=lambda x: x[1]['precision']),
            'recall': max(performance_metrics.items(), key=lambda x: x[1]['recall']),
            'f1': max(performance_metrics.items(), key=lambda x: x[1]['f1_score'])
        }
        
        report_content += f"""

CLINICAL CONCLUSIONS AND RECOMMENDATIONS:
=========================================

BEST PERFORMANCE:
• Best precision: {models_info.get(best_models['precision'][0], {'name': best_models['precision'][0]})['name']} ({best_models['precision'][1]['precision']:.3f})
• Best sensitivity: {models_info.get(best_models['recall'][0], {'name': best_models['recall'][0]})['name']} ({best_models['recall'][1]['recall']:.3f})
• Best balance: {models_info.get(best_models['f1'][0], {'name': best_models['f1'][0]})['name']} ({best_models['f1'][1]['f1_score']:.3f})

RECOMMENDATIONS FOR CLINICAL USE:

1. SCREENING (priority on sensitivity):
   → Use {models_info.get(best_models['recall'][0], {'name': best_models['recall'][0]})['name']}
   → Minimizes false negatives
   → Suited to mass screening

2. CONFIRMATORY DIAGNOSIS (priority on precision):
   → Use {models_info.get(best_models['precision'][0], {'name': best_models['precision'][0]})['name']}
   → Minimizes false positives
   → Suited to avoiding overtreatment

3. GENERAL USE (best balance):
   → Use {models_info.get(best_models['f1'][0], {'name': best_models['f1'][0]})['name']}
   → Good precision/recall trade-off
   → Suited to routine consultation

LIMITATIONS AND CONSIDERATIONS:

• Validated on synthetic data only: real-world performance may differ
• Inter-observer variability is not taken into account
• Validation on real clinical cohorts is required
• AI remains a decision-support tool: the final decision belongs to the physician
• Ongoing training on these technologies is necessary

PERSPECTIVES FOR IMPROVEMENT:

• Model ensembles for increased robustness
• Specialization by type of pathology
• Integration of contextual clinical data
• Development of explainable models (XAI)

---
Report generated by TorchXRayVision Tutorial 2
Advanced Classification for Medical Training
"""
    
    # Save the report
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(report_content)
    
    print(f"   ✅ Text report: {report_path}")
    
    # Save the data as JSON
    json_path = f"{session_dir}donnees_comparatives_{timestamp}.json"
    
    json_data = {
        'timestamp': timestamp,
        'models_info': models_info,
        'performance_metrics': performance_metrics if performance_metrics else {},
        'clinical_cases': {name: {'description': info['description'], 
                                 'expected_findings': info['expected_findings']}
                          for name, info in clinical_cases.items()}
    }
    
    # NumPy scalars and arrays are not all JSON-serializable by default;
    # convert them to built-in Python types through their .tolist() method
    def _to_builtin(obj):
        if hasattr(obj, 'tolist'):
            return obj.tolist()
        return str(obj)
    
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(json_data, f, indent=2, ensure_ascii=False, default=_to_builtin)
    
    print(f"   ✅ JSON data: {json_path}")
    
    return report_path, json_path

# Generate the final report
if comprehensive_results:
    print("\n📊 Finalizing the comparative analysis...")
    
    report_files = generate_comprehensive_report(
        comprehensive_results, 
        performance_metrics if 'performance_metrics' in locals() else {},
        session_dir
    )
    
    print(f"\n✅ Full report generated!")
    print(f"\n📁 Files saved in: {session_dir}")
    print(f"📄 Types of files created:")
    print(f"   • Detailed text report")
    print(f"   • JSON data for reuse")
    print(f"   • PNG visualizations (if generated)")
    
else:
    print("⚠️ No results to save")

print(f"\n🎯 Tutorial 2 - Advanced Classification complete!")

## 🎉 Tutorial 2 Conclusion

### 🏆 Congratulations!

You have now worked through **advanced multi-model classification** with TorchXRayVision!

### ✅ Skills Acquired:

#### **Advanced Techniques:**
1. **🧠 Multi-model loading**: DenseNet121 trained on different datasets
2. **🖼️ Building complex cases**: simulating varied pathologies
3. **🔬 Comparative analysis**: cross-model performance
4. **📊 Advanced metrics**: precision, recall, F1-score
5. **📈 Richer visualizations**: heatmaps, comparison charts
6. **📋 Automated reports**: complete documentation

#### **Clinical:**
- **Model selection** according to clinical context
- **Multi-pathology interpretation**
- **Handling diagnostic uncertainty**
- **Clinical recommendations** informed by AI output

### 📊 Performance Results

You learned how to evaluate:
- **Sensitivity vs. specificity** depending on clinical use
- **Consensus between models** for robustness
- **Clinical trade-offs** between false positives and false negatives
- **Context-dependent adaptation** of detection thresholds
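
To make the threshold trade-off concrete, here is a minimal sketch that sweeps the detection threshold for a single pathology and reports sensitivity and specificity at each value. The probabilities and labels are made-up illustrative values, `sensitivity_specificity` is a hypothetical helper name, and the `>=` comparison is an assumption about how the threshold is applied.

```python
def sensitivity_specificity(probs, labels, threshold):
    """Return (sensitivity, specificity) when predicting positive for prob >= threshold."""
    tp = sum(1 for p, y in zip(probs, labels) if p >= threshold and y == 1)
    fn = sum(1 for p, y in zip(probs, labels) if p < threshold and y == 1)
    tn = sum(1 for p, y in zip(probs, labels) if p < threshold and y == 0)
    fp = sum(1 for p, y in zip(probs, labels) if p >= threshold and y == 0)
    sensitivity = tp / (tp + fn) if (tp + fn) else 0.0
    specificity = tn / (tn + fp) if (tn + fp) else 0.0
    return sensitivity, specificity


# Hypothetical model outputs for one pathology: 4 diseased cases, 4 healthy cases
probs = [0.92, 0.71, 0.58, 0.38, 0.64, 0.22, 0.10, 0.41]
labels = [1, 1, 1, 1, 0, 0, 0, 0]

results = {t: sensitivity_specificity(probs, labels, t) for t in (0.3, 0.5, 0.7)}
for t, (sens, spec) in results.items():
    print(f"threshold={t:.1f}  sensitivity={sens:.2f}  specificity={spec:.2f}")
```

On this toy data, lowering the threshold raises sensitivity at the cost of specificity, which is the direction a screening workflow would favor; raising it does the opposite, which suits confirmatory use.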

### 🏥 Clinical Applications Covered

#### **Mass Screening:**
- Priority on **high sensitivity**
- Models optimized for **early detection**
- **Automated triage** workflow

#### **Confirmatory Diagnosis:**
- Priority on **high specificity**
- Avoiding **overdiagnosis**
- **Maximum precision** required

#### **General Consultation:**
- **Optimal balance** between sensitivity and specificity
- Diagnostic **versatility**
- **Ease of interpretation**

### 🎯 Next Steps

#### **Tutorial 3**: Anatomical Segmentation
- Automatic delineation of structures
- Precise morphometric measurements
- Quantitative organ analysis
- Applications in radiotherapy

#### **Tutorial 4**: Detection and Localization
- Precise localization of pathologies
- Activation maps (Grad-CAM)
- Explainable AI (XAI)
- Confidence and uncertainty

#### **Tutorial 5**: Multi-Architecture Comparison
- ResNet vs DenseNet vs Vision Transformers
- Emerging architectures
- Performance optimization
- Full benchmarking

### 💡 Key Takeaways

#### **Model Selection:**
- The **training dataset** determines the specialization
- **Model size** vs **inference speed**
- The **clinical context** guides the selection

#### **Interpreting Results:**
- **Probabilities ≠ diagnostic certainties**
- **Multi-model consensus** = increased robustness
- **Clinical correlation** is always required
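
The consensus idea can be sketched as a simple vote: a pathology is retained only when enough models report it above the threshold. This is an illustrative sketch under stated assumptions, not the notebook's method: `consensus_findings` is a hypothetical helper, the model names and probabilities are made up, and the 0.5 threshold and two-vote minimum are arbitrary choices.

```python
def consensus_findings(model_probs, threshold=0.5, min_votes=2):
    """Return {pathology: vote count} for pathologies flagged by at least min_votes models."""
    votes = {}
    for probs in model_probs.values():
        for pathology, p in probs.items():
            if p >= threshold:
                votes[pathology] = votes.get(pathology, 0) + 1
    return {name: n for name, n in votes.items() if n >= min_votes}


# Hypothetical per-model probabilities for a single radiograph
model_probs = {
    "model_a": {"Effusion": 0.81, "Pneumonia": 0.62, "Cardiomegaly": 0.20},
    "model_b": {"Effusion": 0.74, "Pneumonia": 0.41, "Cardiomegaly": 0.55},
    "model_c": {"Effusion": 0.66, "Pneumonia": 0.58, "Cardiomegaly": 0.12},
}

consensus = consensus_findings(model_probs)
print(consensus)
```

A finding flagged by a single model (here, Cardiomegaly) is dropped from the consensus; such disagreements are the cases that most warrant a closer clinical look.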

#### **Limitations to Consider:**
- **Biases in the training datasets**
- **Variability of clinical presentations**
- **Rapid technological change**

### 🌟 Final Message

**Multi-model classification** is a widely used approach in radiology AI. You now have hands-on experience with tools that are changing medical imaging practice.

#### **Your Expertise:**
- Rigorous **comparative analysis**
- **Informed choices** of models
- Advanced **clinical interpretation**
- **Professional documentation**

**Ready for Tutorial 3?** Anatomical segmentation awaits! 🔬✨

---

*You now have a solid grasp of the fundamentals of AI classification in radiology!*