In [None]:
# Dans Anaconda Prompt
conda activate hyperpri
conda install jupyter notebook ipykernel -y

# Enregistrer le kernel
python -m ipykernel install --user --name=hyperpri --display-name "Python (HyperPRI)"
```

---

### **Étape 2 : Créer des Notebooks Jupyter**

Voici comment organiser votre projet avec Jupyter :
```
HyperPRI/
├── notebooks/                  # 📓 Tous les notebooks ici
│   ├── 01_Setup_and_Test.ipynb
│   ├── 02_Data_Exploration.ipynb
│   ├── 03_Train_UNET.ipynb
│   ├── 04_Train_SpectralUNET.ipynb
│   ├── 05_Train_CubeNET.ipynb
│   ├── 06_Evaluate_Models.ipynb
│   └── 07_Visualize_Results.ipynb
│
├── src/                        # Code source (comme avant)
├── Datasets/                   # Données
└── checkpoints/                # Modèles sauvegardés
```

In [None]:
# ============================================================================
# CELLULE 1 : Imports et v√©rification environnement
# ============================================================================

import sys
import os

# Ajouter le projet au PYTHONPATH
project_root = os.path.abspath('..')  # Remonter d'un niveau depuis notebooks/
sys.path.insert(0, project_root)

print(f"‚úì Project root: {project_root}")

# V√©rifier imports
import torch
import pytorch_lightning as pl
import numpy as np
import matplotlib.pyplot as plt

# Report library versions and, when a CUDA device is visible, its name and total memory.
print(f"‚úì PyTorch version: {torch.__version__}")
print(f"‚úì Lightning version: {pl.__version__}")
print(f"‚úì CUDA disponible: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Device 0 only; total_memory is reported in bytes, converted here to GB (1e9)
    print(f"  - GPU: {torch.cuda.get_device_name(0)}")
    print(f"  - M√©moire: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")


# ============================================================================
# CELLULE 2 : Tester imports du projet
# ============================================================================

# Import every project module the notebooks rely on, printing a progress line
# after each group so the failing import is easy to spot.
try:
    from src.Experiments.params_HyperPRI import CONFIG
    print("‚úì Configuration charg√©e")

    from src.Datasets.HyperPRI_Dataset import HyperPRIDataset
    print("‚úì Dataset charg√©")

    from src.Models.UNET import UNET
    from src.Models.SpectralUNET import SpectralUNETWrapper
    from src.Models.CubeNET import CubeNET
    print("‚úì Mod√®les charg√©s")

    from src.metrics import SegmentationMetrics
    print("‚úì M√©triques charg√©es")

    print("\nüéâ Tous les imports r√©ussis !")

except Exception as e:
    print(f"‚ùå Erreur d'import: {e}")
    print("\nSolution:")
    print("1. V√©rifier que tous les fichiers sont pr√©sents dans src/")
    print("2. V√©rifier PYTHONPATH")
    # Re-raise after printing the hints: the following cells all use CONFIG, so
    # continuing would only replace this error with an unrelated NameError and
    # hide the real traceback.
    raise


# ============================================================================
# CELLULE 3 : Afficher configuration
# ============================================================================

CONFIG.print_config()


# ============================================================================
# CELLULE 4 : V√©rifier structure des donn√©es
# ============================================================================

import os

def check_data_structure():
    """Report whether the dataset folders referenced by CONFIG exist.

    Checks, in order: the data root, the Peanut folder with its three
    sub-folders (hsi_files, rgb_files, mask_files), and the splits folder.
    Each successful check is printed immediately. Failures are collected and
    printed in a summary at the end, followed by a warning for every missing
    split_0.json .. split_4.json file. Returns None.
    """
    problems = []
    notices = []

    # Top-level data folder
    if os.path.exists(CONFIG.data_root):
        print(f"‚úì Dossier donn√©es: {CONFIG.data_root}")
    else:
        problems.append(f"‚ùå Dossier donn√©es introuvable: {CONFIG.data_root}")

    # Peanut folder and one sub-folder per modality
    if os.path.exists(CONFIG.peanut_dir):
        print(f"‚úì Dossier Peanut: {CONFIG.peanut_dir}")

        modalities = (('hsi_files', 'HSI'), ('rgb_files', 'RGB'), ('mask_files', 'Masques'))
        for folder, name in modalities:
            subdir = os.path.join(CONFIG.peanut_dir, folder)
            if not os.path.exists(subdir):
                problems.append(f"  ‚ùå Sous-dossier {name} introuvable: {subdir}")
                continue
            # Entries starting with '.' are left out of the count
            n_files = sum(1 for entry in os.listdir(subdir) if not entry.startswith('.'))
            print(f"  ‚úì {name}: {n_files} fichiers")
    else:
        problems.append(f"‚ùå Dossier Peanut introuvable: {CONFIG.peanut_dir}")

    # Splits folder: count the JSON files, then look for the five expected names
    if os.path.exists(CONFIG.splits_dir):
        present = [entry for entry in os.listdir(CONFIG.splits_dir) if entry.endswith('.json')]
        print(f"‚úì Splits: {len(present)} fichiers trouv√©s")

        for idx in range(5):
            split_file = f'split_{idx}.json'
            if split_file not in present:
                notices.append(f"  ‚ö†Ô∏è  Fichier manquant: {split_file}")
    else:
        problems.append(f"‚ùå Dossier splits introuvable: {CONFIG.splits_dir}")

    # Summary: errors (or an OK line), then warnings if any
    separator = "=" * 60
    print("\n" + separator)
    if not problems:
        print("‚úÖ Structure des donn√©es OK!")
    else:
        print("‚ùå ERREURS D√âTECT√âES:")
        for line in problems:
            print(line)

    if notices:
        print("\n‚ö†Ô∏è  AVERTISSEMENTS:")
        for line in notices:
            print(line)
    print(separator)

check_data_structure()


# ============================================================================
# CELLULE 5 : Tester chargement d'une image
# ============================================================================

from src.Datasets.data_utils import load_rgb_image, load_hsi_cube, load_mask

# Sample to load; change it to a name that exists in your dataset
test_image_name = '20220624_box33'
# Band shown in the HSI preview; clamped below to the bands actually loaded
preview_band = 100

# Load the RGB image, the HSI cube and the mask of one sample and show them side by side.
try:
    # RGB image
    rgb_path = os.path.join(CONFIG.peanut_dir, 'rgb_files', f'{test_image_name}.png')
    rgb = load_rgb_image(rgb_path)
    print(f"‚úì RGB charg√©: {rgb.shape}")

    # HSI cube, restricted to the band window configured in CONFIG
    hsi_path = os.path.join(CONFIG.peanut_dir, 'hsi_files', f'{test_image_name}.hdr')
    cube = load_hsi_cube(hsi_path, CONFIG.hsi_lo, CONFIG.hsi_hi)
    print(f"‚úì HSI charg√©: {cube.shape}")

    # Mask
    mask_path = os.path.join(CONFIG.peanut_dir, 'mask_files', f'{test_image_name}.png')
    mask = load_mask(mask_path, binary=True)
    print(f"‚úì Masque charg√©: {mask.shape}")

    # Three panels: RGB, one HSI band, mask
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].imshow(rgb)
    axes[0].set_title('RGB')
    axes[0].axis('off')

    # The cube is indexed band-last. Clamp the band index so a cube with fewer
    # than preview_band + 1 bands still displays instead of raising IndexError,
    # which the handler below would misreport as a path problem.
    band_idx = min(preview_band, cube.shape[2] - 1)
    axes[1].imshow(cube[:, :, band_idx], cmap='viridis')
    axes[1].set_title(f'HSI - Bande {band_idx}')
    axes[1].axis('off')

    axes[2].imshow(mask, cmap='gray')
    axes[2].set_title('Masque (Ground Truth)')
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

    print("\nüéâ Chargement et visualisation r√©ussis!")

except Exception as e:
    # Diagnostic cell: report the failure and the usual causes rather than stopping the notebook
    print(f"‚ùå Erreur: {e}")
    print("\nV√©rifier:")
    print(f"1. Le fichier existe: {test_image_name}")
    print(f"2. Les chemins sont corrects")

In [None]:
# ============================================================================
# CELLULE 1 : Setup
# ============================================================================

import sys
import os
sys.path.insert(0, os.path.abspath('..'))

import numpy as np
import matplotlib.pyplot as plt
from src.Experiments.params_HyperPRI import CONFIG
from src.Datasets.data_utils import *


# ============================================================================
# CELLULE 2 : Charger plusieurs images
# ============================================================================

import json

# Read split 0: a JSON document from which the 'train' and 'val' lists are taken
split_file = os.path.join(CONFIG.splits_dir, 'split_0.json')
with open(split_file, 'r') as f:
    split_data = json.load(f)

train_images = split_data['train']
val_images = split_data['val']

# Sizes of both subsets, then a few sample entries from each
print(f"Images d'entra√Ænement: {len(train_images)}")
print(f"Images de validation: {len(val_images)}")
print(f"\nExemples train: {train_images[:5]}")
print(f"Exemples val: {val_images[:3]}")


# ============================================================================
# CELLULE 3 : Analyser distribution des pixels
# ============================================================================

from tqdm.notebook import tqdm  # Barre de progression pour notebook

# Per-image pixel counts: pixels equal to 1 (roots) and total pixels
root_pixel_counts = []
total_pixel_counts = []

print("Analyse des masques...")
# Only the first 20 training images are scanned to keep the cell fast
for img_name in tqdm(train_images[:20]):
    mask_path = os.path.join(CONFIG.peanut_dir, 'mask_files', f'{img_name}.png')

    # Images without a mask file are skipped
    if not os.path.exists(mask_path):
        continue

    mask = load_mask(mask_path, binary=True)

    n_root = (mask == 1).sum()
    n_total = mask.size

    root_pixel_counts.append(n_root)
    total_pixel_counts.append(n_total)

# With no mask found, the statistics below would print NaN and then fail inside
# .min() with an unrelated "zero-size array" error; stop with a clear message instead.
if not root_pixel_counts:
    raise FileNotFoundError(
        f"Aucun masque disponible dans {os.path.join(CONFIG.peanut_dir, 'mask_files')}"
    )

# Fraction of root pixels per image
root_ratios = np.array(root_pixel_counts) / np.array(total_pixel_counts)

print(f"\nRatio pixels racines/total:")
print(f"  Moyenne: {root_ratios.mean():.3%}")
print(f"  Min: {root_ratios.min():.3%}")
print(f"  Max: {root_ratios.max():.3%}")

# Histogram of the per-image root percentage
plt.figure(figsize=(10, 4))
plt.hist(root_ratios * 100, bins=20, edgecolor='black')
plt.xlabel('% Pixels racines')
plt.ylabel('Nombre d\'images')
plt.title('Distribution du ratio racines/sol')
plt.grid(alpha=0.3)
plt.show()


# ============================================================================
# CELLULE 4 : Analyser signatures spectrales
# ============================================================================

# Load the first training image and its mask
img_name = train_images[0]
hsi_path = os.path.join(CONFIG.peanut_dir, 'hsi_files', f'{img_name}.hdr')
mask_path = os.path.join(CONFIG.peanut_dir, 'mask_files', f'{img_name}.png')

cube = load_hsi_cube(hsi_path, CONFIG.hsi_lo, CONFIG.hsi_hi)
mask = load_mask(mask_path, binary=True)

# Boolean-mask indexing flattens the spatial axes: one row per pixel, one column per band
root_pixels = cube[mask == 1]
soil_pixels = cube[mask == 0]

print(f"Pixels racines: {root_pixels.shape[0]}")
print(f"Pixels sol: {soil_pixels.shape[0]}")

# Both classes are needed below; an empty class would otherwise yield NaN curves silently
if len(root_pixels) == 0 or len(soil_pixels) == 0:
    raise ValueError(f"L'image {img_name} ne contient pas les deux classes (racines et sol)")

# Subsample (without replacement) to keep the statistics fast
n_samples = 1000
root_sample = root_pixels[np.random.choice(len(root_pixels), min(n_samples, len(root_pixels)), replace=False)]
soil_sample = soil_pixels[np.random.choice(len(soil_pixels), min(n_samples, len(soil_pixels)), replace=False)]

# Per-band mean and standard deviation of each class
root_mean = root_sample.mean(axis=0)
root_std = root_sample.std(axis=0)
soil_mean = soil_sample.mean(axis=0)
soil_std = soil_sample.std(axis=0)

# Wavelength axis sized from the loaded data rather than from CONFIG, so the
# plots cannot fail on an x/y length mismatch when the two disagree.
# NOTE(review): 450-926 nm is assumed to span exactly the loaded bands - confirm
# against hsi_lo / hsi_hi.
n_bands = root_mean.shape[0]
if n_bands != CONFIG.n_spectral_bands:
    print(f"Attention: {n_bands} bandes lues, CONFIG.n_spectral_bands = {CONFIG.n_spectral_bands}")
wavelengths = np.linspace(450, 926, n_bands)

# Left panel: mean spectrum +/- one standard deviation per class
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(wavelengths, root_mean, label='Racines', color='green', linewidth=2)
plt.fill_between(wavelengths, root_mean - root_std, root_mean + root_std, alpha=0.3, color='green')
plt.plot(wavelengths, soil_mean, label='Sol', color='brown', linewidth=2)
plt.fill_between(wavelengths, soil_mean - soil_std, soil_mean + soil_std, alpha=0.3, color='brown')
plt.xlabel('Longueur d\'onde (nm)')
plt.ylabel('R√©flectance')
plt.title('Signatures spectrales moyennes')
plt.legend()
plt.grid(alpha=0.3)

# Right panel: per-band Fisher discriminant score (labels: 1 = root, 0 = soil)
plt.subplot(1, 2, 2)
fds = compute_fisher_score(
    np.vstack([root_sample, soil_sample]),
    np.array([1]*len(root_sample) + [0]*len(soil_sample))
)
plt.plot(wavelengths, fds, color='purple', linewidth=2)
plt.xlabel('Longueur d\'onde (nm)')
plt.ylabel('Fisher Discriminant Score')
plt.title('S√©parabilit√© racines vs sol')
plt.grid(alpha=0.3)

plt.tight_layout()
plt.show()

# Ten highest-scoring bands, listed in ascending score order
print(f"\nBandes avec meilleur FDS:")
best_bands = np.argsort(fds)[-10:]
for band_idx in best_bands:
    print(f"  Bande {band_idx}: {wavelengths[band_idx]:.1f} nm, FDS={fds[band_idx]:.3f}")

In [None]:
# ============================================================================
# CELLULE 1 : Setup
# ============================================================================

import sys
import os
sys.path.insert(0, os.path.abspath('..'))

import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from src.Experiments.params_HyperPRI import CONFIG
from src.Datasets.HyperPRI_Dataset import create_dataloaders
from src.PLTrainer import SegmentationModule
from src.metrics import SegmentationMetrics

import numpy as np
import matplotlib.pyplot as plt

# Fixer seed
pl.seed_everything(CONFIG.seed)

print("‚úì Environnement configur√©")


# ============================================================================
# CELLULE 2 : Configuration du mod√®le
# ============================================================================

# Model key used for the config lookup, the checkpoint/log folders and SegmentationModule
model_type = 'unet'
split_idx = 0  # First split, for a quick test

print(f"Mod√®le: {model_type.upper()}")
print(f"Split: {split_idx}")

# Model and optimizer settings come from the shared CONFIG object
model_config = CONFIG.get_model_config(model_type)
optimizer_config = CONFIG.get_optimizer_config()

# model_config is iterated as a mapping of setting name -> value
print("\nConfiguration mod√®le:")
for key, value in model_config.items():
    print(f"  {key}: {value}")


# ============================================================================
# CELLULE 3 : Cr√©er DataLoaders
# ============================================================================

# Split file matching split_idx
split_file = os.path.join(CONFIG.splits_dir, f'split_{split_idx}.json')

# Train/validation loaders in RGB mode, without training augmentation
train_loader, val_loader = create_dataloaders(
    data_dir=CONFIG.peanut_dir,
    split_file=split_file,
    mode='RGB',
    batch_size=2,  # Adjust to the available GPU memory
    num_workers=2,  # Lower this if worker processes cause problems
    hsi_lo=CONFIG.hsi_lo,
    hsi_hi=CONFIG.hsi_hi,
    normalize_hsi=CONFIG.normalize_hsi,
    augment_train=False
)

print(f"\n‚úì DataLoaders cr√©√©s:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")

# Pull one batch to check shapes; batches are dicts with 'image' and 'mask' entries
batch = next(iter(train_loader))
print(f"\nBatch shape:")
print(f"  Images: {batch['image'].shape}")
print(f"  Masks: {batch['mask'].shape}")


# ============================================================================
# CELLULE 4 : Cr√©er mod√®le
# ============================================================================

# Lightning module wrapping the selected architecture, trained with loss_type='bce'
module = SegmentationModule(
    model_type=model_type,
    model_config=model_config,
    optimizer_config=optimizer_config,
    loss_type='bce'
)

# Count trainable parameters only (requires_grad=True)
num_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
print(f"‚úì Mod√®le cr√©√©: {num_params:,} param√®tres")


# ============================================================================
# CELLULE 5 : Configuration callbacks
# ============================================================================

# Checkpoint folder: <checkpoint_dir>/<model_type>/split_<idx>
checkpoint_dir = os.path.join(CONFIG.checkpoint_dir, model_type, f'split_{split_idx}')
os.makedirs(checkpoint_dir, exist_ok=True)

# Keep only the single checkpoint with the lowest val_loss, named 'best_model'
checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename='best_model',
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    verbose=True
)

# Early stopping on val_loss
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=50,  # Lower for quick tests
    mode='min',
    verbose=True
)

print(f"‚úì Checkpoints seront sauvegard√©s dans: {checkpoint_dir}")


# ============================================================================
# CELLULE 6 : Configuration Trainer
# ============================================================================

# TensorBoard logger writing under <log_dir>/<model_type>/split_<idx>_notebook
logger = TensorBoardLogger(
    save_dir=CONFIG.log_dir,
    name=model_type,
    version=f'split_{split_idx}_notebook'
)

# Single-device trainer using both callbacks defined in the previous cell
trainer = pl.Trainer(
    max_epochs=100,  # Lower for tests
    accelerator='auto',
    devices=1,
    callbacks=[checkpoint_callback, early_stop_callback],
    logger=logger,
    log_every_n_steps=10,
    enable_progress_bar=True,
    enable_model_summary=True
)

print("‚úì Trainer configur√©")


# ============================================================================
# CELLULE 7 : ENTRA√éNEMENT (cette cellule peut prendre du temps!)
# ============================================================================

print("\nüöÄ D√©but de l'entra√Ænement...\n")

trainer.fit(module, train_loader, val_loader)

print(f"\n‚úÖ Entra√Ænement termin√©!")
print(f"Epochs entra√Æn√©s: {trainer.current_epoch}")
print(f"Meilleur mod√®le: {checkpoint_callback.best_model_path}")


# ============================================================================
# CELLULE 8 : √âvaluation sur validation
# ============================================================================

# Reload the best checkpoint saved during training and switch it to eval mode
best_module = SegmentationModule.load_from_checkpoint(
    checkpoint_callback.best_model_path
)
best_module.eval()

print("‚úì Meilleur mod√®le charg√©")

# Run one validation pass; the first entry of the returned list holds the metrics
val_results = trainer.validate(best_module, val_loader, verbose=False)
val_metrics = val_results[0]

# The four keys read here must be logged by SegmentationModule during validation
print("\nüìä R√âSULTATS VALIDATION:")
print(f"  Val Loss: {val_metrics['val_loss']:.4f}")
print(f"  Val DICE: {val_metrics['val_dice']:.4f}")
print(f"  Val IoU: {val_metrics['val_iou']:.4f}")
print(f"  Val AP: {val_metrics['val_ap']:.4f}")


# ============================================================================
# CELLULE 9 : Visualiser pr√©dictions
# ============================================================================

# Put the best model in eval mode on the GPU when one is available, else on the CPU
best_module.eval()
best_module.to('cuda' if torch.cuda.is_available() else 'cpu')

# Take the first validation batch
val_batch = next(iter(val_loader))
images = val_batch['image'].to(best_module.device)
masks_gt = val_batch['mask'].numpy()

# Predict: logits -> sigmoid probabilities -> binary mask at the module's threshold
with torch.no_grad():
    logits = best_module(images)
    # squeeze(1) drops the channel axis of the logits
    probs = torch.sigmoid(logits).squeeze(1).cpu().numpy()
    preds = (probs > best_module.best_threshold).astype(np.uint8)

# One row per image (at most 4): input, ground truth, prediction
n_images = min(4, images.shape[0])
fig, axes = plt.subplots(n_images, 3, figsize=(15, 5*n_images))

# With a single row plt.subplots returns a 1-D array; make it 2-D for axes[i, j]
if n_images == 1:
    axes = axes.reshape(1, -1)

for i in range(n_images):
    # Input image, moved from channel-first to channel-last for imshow
    # NOTE(review): assumes the values are already in a displayable range - confirm
    img_rgb = images[i].cpu().permute(1, 2, 0).numpy()
    axes[i, 0].imshow(img_rgb)
    axes[i, 0].set_title('Image RGB')
    axes[i, 0].axis('off')

    # Ground truth
    # NOTE(review): assumes masks_gt[i] is 2-D (H, W) - confirm against the dataset
    axes[i, 1].imshow(masks_gt[i], cmap='gray')
    axes[i, 1].set_title('Ground Truth')
    axes[i, 1].axis('off')

    # Prediction
    axes[i, 2].imshow(preds[i], cmap='gray')
    axes[i, 2].set_title(f'Pr√©diction (seuil={best_module.best_threshold:.2f})')
    axes[i, 2].axis('off')

plt.tight_layout()
plt.show()


# ============================================================================
# CELLULE 10 : Visualiser avec overlay
# ============================================================================

from matplotlib.colors import ListedColormap

# RGBA colormap for the four confusion classes, indexed by the overlay codes below
colors_overlay = np.array([
    [0, 0, 0, 0],      # TN: transparent (soil correctly predicted)
    [1, 0, 0, 0.5],    # FP: red (over-segmentation)
    [0, 0, 1, 0.5],    # FN: blue (under-segmentation)
    [0, 1, 0, 0.5]     # TP: green (roots correctly predicted)
])
cmap_overlay = ListedColormap(colors_overlay)

fig, axes = plt.subplots(2, 2, figsize=(12, 12))
axes = axes.ravel()

# The grid has four panels but the batch may hold fewer images
n_shown = min(len(axes), n_images)

for i in range(n_shown):
    img_rgb = images[i].cpu().permute(1, 2, 0).numpy()
    gt = masks_gt[i]
    pred = preds[i]

    # Per-pixel confusion code: 0 = TN, 1 = FP, 2 = FN, 3 = TP
    overlay = np.zeros_like(gt, dtype=np.uint8)
    overlay[(pred == 1) & (gt == 0)] = 1  # FP
    overlay[(pred == 0) & (gt == 1)] = 2  # FN
    overlay[(pred == 1) & (gt == 1)] = 3  # TP

    axes[i].imshow(img_rgb)
    # vmin/vmax pin the code -> colour mapping even when a class is absent from the image
    axes[i].imshow(overlay, cmap=cmap_overlay, alpha=0.6, vmin=0, vmax=3)
    axes[i].set_title(f'Image {i+1}: Vert=TP, Rouge=FP, Bleu=FN')
    axes[i].axis('off')

# Hide the panels left without an image, which would otherwise show empty axes
for ax in axes[n_shown:]:
    ax.axis('off')

plt.tight_layout()
plt.show()


# ============================================================================
# CELLULE 11 : Sauvegarder r√©sultats
# ============================================================================

import json

# Summary of this run. Every numeric entry is cast to a built-in type because
# json.dump rejects tensor and NumPy scalar values.
results = {
    'model_type': model_type,
    'split': split_idx,
    'val_loss': float(val_metrics['val_loss']),
    'val_dice': float(val_metrics['val_dice']),
    'val_iou': float(val_metrics['val_iou']),
    'val_ap': float(val_metrics['val_ap']),
    # Cast like the metrics above; this value was previously passed through unconverted
    'best_threshold': float(best_module.best_threshold),
    'epochs_trained': int(trainer.current_epoch),
    'checkpoint_path': checkpoint_callback.best_model_path
}

# Written to <results_dir>/<model_type>/notebook_split<idx>_results.json
results_file = os.path.join(CONFIG.results_dir, model_type, f'notebook_split{split_idx}_results.json')
os.makedirs(os.path.dirname(results_file), exist_ok=True)

with open(results_file, 'w') as f:
    json.dump(results, f, indent=2)

print(f"‚úì R√©sultats sauvegard√©s: {results_file}")
print("\nüéâ Entra√Ænement et √©valuation termin√©s!")

In [None]:
# ============================================================================
# DIFF√âRENCES PRINCIPALES PAR RAPPORT √Ä UNET
# ============================================================================

# NOTE: this cell is a list of changes to apply to the UNET training notebook,
# not a standalone program; each fragment replaces part of the cell named above it.

# CELL 2: change model_type
model_type = 'cube'  # Instead of 'unet'

# CELL 3: change the DataLoader mode
train_loader, val_loader = create_dataloaders(
    data_dir=CONFIG.peanut_dir,
    split_file=split_file,
    mode='HSI',  # <- HSI instead of RGB
    batch_size=1,  # <- Reduced to 1 (HSI is heavier)
    num_workers=2,
    hsi_lo=CONFIG.hsi_lo,
    hsi_hi=CONFIG.hsi_hi,
    normalize_hsi=True,  # <- Important
    augment_train=False
)

# CELL 9: show one HSI band instead of the RGB image.
# These three lines use the loop index i and belong inside the
# `for i in range(n_images)` loop of that cell.
img_hsi = images[i].cpu().numpy()  # Channel-first: indexed by band below
axes[i, 0].imshow(img_hsi[100], cmap='viridis')  # Show band 100
axes[i, 0].set_title('HSI - Bande 100')

# Everything else stays the same.

In [None]:
# ============================================================================
# CELLULE 1 : Charger r√©sultats de tous les mod√®les
# ============================================================================

import sys
import os
sys.path.insert(0, os.path.abspath('..'))

import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from src.Experiments.params_HyperPRI import CONFIG

# Load <results_dir>/<model_type>/kfold_results.json for each model that has one
results_all = {}

for model_type in ['unet', 'spectral', 'cube']:
    results_file = os.path.join(CONFIG.results_dir, model_type, 'kfold_results.json')
    
    if os.path.exists(results_file):
        with open(results_file, 'r') as f:
            results_all[model_type] = json.load(f)
        print(f"‚úì {model_type.upper()}: {len(results_all[model_type])} r√©sultats")
    else:
        # A missing file is reported and skipped, not treated as an error
        print(f"‚ö†Ô∏è  {model_type.upper()}: Pas de r√©sultats (fichier introuvable)")

# Flatten into one record per result entry; each entry must provide
# 'split', 'val_dice', 'val_iou' and 'val_ap'
records = []
for model_type, results in results_all.items():
    for r in results:
        records.append({
            'model': model_type.upper(),
            'split': r['split'],
            'dice': r['val_dice'],
            'iou': r['val_iou'],
            'ap': r['val_ap']
        })

# NOTE(review): if no results file was found, df is empty and has no 'model'
# column, so the cells below will fail.
df = pd.DataFrame(records)
print(f"\n‚úì DataFrame cr√©√©: {len(df)} enregistrements")
df.head()


# ============================================================================
# CELLULE 2 : Statistiques descriptives
# ============================================================================

# Mean and standard deviation of each metric per model, rounded to 4 decimals
print("="*60)
print("STATISTIQUES PAR MOD√àLE")
print("="*60)

summary = df.groupby('model').agg({
    'dice': ['mean', 'std'],
    'iou': ['mean', 'std'],
    'ap': ['mean', 'std']
}).round(4)

print(summary)


# ============================================================================
# CELLULE 3 : Visualiser comparaison boxplots
# ============================================================================

# One boxplot per metric, comparing the models side by side
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# `metrics` and `titles` are reused by the per-split cell further down
metrics = ['dice', 'iou', 'ap']
titles = ['DICE Score', 'IoU (Racines)', 'Average Precision']

for ax, metric, title in zip(axes, metrics, titles):
    sns.boxplot(data=df, x='model', y=metric, ax=ax, palette='Set2')
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Mod√®le', fontsize=12)
    ax.set_ylabel('Score', fontsize=12)
    ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
# Saved into CONFIG.results_dir
plt.savefig(os.path.join(CONFIG.results_dir, 'model_comparison.png'), dpi=300, bbox_inches='tight')
plt.show()

# NOTE(review): the message hard-codes 'results/'; the file is actually written under CONFIG.results_dir
print("‚úì Figure sauvegard√©e: results/model_comparison.png")


# ============================================================================
# ============================================================================
# CELLULE 4 : Tableau comparatif (style article)
# ============================================================================

import numpy as np

# Build one "mean ± std" row per model, in a fixed order, skipping absent models
table_data = []

for model in ['UNET', 'SPECTRAL', 'CUBE']:
    if model in df['model'].values:
        model_df = df[df['model'] == model]
        
        dice_mean = model_df['dice'].mean()
        dice_std = model_df['dice'].std()
        
        iou_mean = model_df['iou'].mean()
        iou_std = model_df['iou'].std()
        
        ap_mean = model_df['ap'].mean()
        ap_std = model_df['ap'].std()
        
        # Values are formatted with three decimals
        table_data.append({
            'Model': model,
            'DICE': f"{dice_mean:.3f} ¬± {dice_std:.3f}",
            '+IOU': f"{iou_mean:.3f} ¬± {iou_std:.3f}",
            'AP': f"{ap_mean:.3f} ¬± {ap_std:.3f}"
        })

# DataFrame used for display here and again in the final report cell
table_df = pd.DataFrame(table_data)

print("\n" + "="*70)
print("TABLEAU COMPARATIF (Format Article)")
print("="*70)
print(table_df.to_string(index=False))
print("="*70)

# Best model = highest mean DICE
best_dice = df.groupby('model')['dice'].mean().idxmax()
print(f"\nüèÜ Meilleur mod√®le (DICE): {best_dice}")


# ============================================================================
# CELLULE 5 : Analyse par split
# ============================================================================

# One grouped bar chart per metric: splits on the x axis, one bar per model
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for ax, metric, title in zip(axes, metrics, titles):
    # pivot_table averages repeated (split, model) rows, e.g. several runs of the
    # same split, where DataFrame.pivot would raise on the duplicates. With
    # unique rows the result is unchanged. The heatmap cell already does this.
    pivot = df.pivot_table(index='split', columns='model', values=metric, aggfunc='mean')
    pivot.plot(kind='bar', ax=ax, width=0.8)

    ax.set_title(f'{title} par Split', fontsize=14, fontweight='bold')
    ax.set_xlabel('Split', fontsize=12)
    ax.set_ylabel('Score', fontsize=12)
    ax.legend(title='Mod√®le', fontsize=10)
    ax.grid(axis='y', alpha=0.3)
    # Keep the split labels horizontal
    ax.set_xticklabels(ax.get_xticklabels(), rotation=0)

plt.tight_layout()
plt.savefig(os.path.join(CONFIG.results_dir, 'model_comparison_by_split.png'), dpi=300)
plt.show()


# ============================================================================
# CELLULE 6 : Test statistique (Student t-test)
# ============================================================================

from scipy import stats

# Pairwise two-sided, unpaired t-tests between models for each metric
print("\n" + "="*70)
print("TESTS STATISTIQUES (t-test bilat√©ral non-appari√©)")
print("="*70)

models = df['model'].unique()

# Visit every unordered pair of models exactly once
for i, model1 in enumerate(models):
    for model2 in models[i+1:]:
        print(f"\n{model1} vs {model2}:")
        
        for metric in ['dice', 'iou', 'ap']:
            data1 = df[df['model'] == model1][metric]
            data2 = df[df['model'] == model2][metric]
            
            # equal_var=False selects Welch's variant (no equal-variance assumption)
            t_stat, p_value = stats.ttest_ind(data1, data2, equal_var=False)
            
            # Significance is judged at the 5% level
            significant = "‚úì Significatif" if p_value < 0.05 else "‚úó Non significatif"
            
            print(f"  {metric.upper():4s}: t={t_stat:+.3f}, p={p_value:.4f} {significant}")


# ============================================================================
# CELLULE 7 : Courbes d'apprentissage (si TensorBoard logs disponibles)
# ============================================================================

from tensorboard.backend.event_processing import event_accumulator
import glob

def load_tensorboard_logs(log_dir, tag):
    """Collect every logged value of a scalar tag found under a TensorBoard log directory.

    Searches ``log_dir`` recursively for ``events.out.tfevents.*`` files and
    concatenates the (step, value) pairs recorded for ``tag``, visiting the
    files in sorted path order so the result is deterministic. A file that
    cannot be read is skipped with a printed warning.

    Args:
        log_dir: Directory to search.
        tag: Scalar tag name, e.g. ``'val_loss'``.

    Returns:
        Tuple ``(steps, values)`` of two equally long lists; both are empty
        when no event file contains the tag.
    """
    values = []
    steps = []

    pattern = os.path.join(log_dir, '**', 'events.out.tfevents.*')

    for event_file in sorted(glob.glob(pattern, recursive=True)):
        try:
            ea = event_accumulator.EventAccumulator(event_file)
            ea.Reload()

            if tag in ea.Tags()['scalars']:
                for event in ea.Scalars(tag):
                    steps.append(event.step)
                    values.append(event.value)
        except Exception as exc:
            # Still best-effort, but the failure is now reported; unlike a bare
            # `except:`, this no longer traps KeyboardInterrupt/SystemExit.
            print(f"Avertissement: lecture impossible de {event_file} ({exc})")

    return steps, values

# Plot training loss (left) and validation loss (right), one curve per model with logs
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

for model_type in ['unet', 'spectral', 'cube']:
    log_dir = os.path.join(CONFIG.log_dir, model_type)
    
    if os.path.exists(log_dir):
        # Train loss. load_tensorboard_logs merges every event file found under
        # log_dir, so runs from several versions end up on the same curve.
        steps, train_loss = load_tensorboard_logs(log_dir, 'train_loss_epoch')
        if train_loss:
            axes[0].plot(steps, train_loss, label=model_type.upper(), linewidth=2)
        
        # Val loss
        steps, val_loss = load_tensorboard_logs(log_dir, 'val_loss')
        if val_loss:
            axes[1].plot(steps, val_loss, label=model_type.upper(), linewidth=2)

# NOTE(review): the x values are the `step` field of the logged events; the
# 'Epoch' label presumably only holds if steps were logged per epoch - confirm.
axes[0].set_title('Training Loss', fontsize=14, fontweight='bold')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].legend()
axes[0].grid(alpha=0.3)

axes[1].set_title('Validation Loss', fontsize=14, fontweight='bold')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss')
axes[1].legend()
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(CONFIG.results_dir, 'learning_curves.png'), dpi=300)
plt.show()

print("‚úì Courbes d'apprentissage sauvegard√©es")


# ============================================================================
# CELLULE 8 : Heatmap des performances
# ============================================================================

# Split x model table of DICE; repeated (split, model) rows are averaged
heatmap_data = df.pivot_table(
    values='dice',
    index='split',
    columns='model',
    aggfunc='mean'
)

# Annotated heatmap with three-decimal cell labels, saved into CONFIG.results_dir
plt.figure(figsize=(8, 6))
sns.heatmap(heatmap_data, annot=True, fmt='.3f', cmap='YlGnBu', 
            cbar_kws={'label': 'DICE Score'}, linewidths=0.5)
plt.title('DICE Score par Split et Mod√®le', fontsize=14, fontweight='bold')
plt.xlabel('Mod√®le', fontsize=12)
plt.ylabel('Split', fontsize=12)
plt.tight_layout()
plt.savefig(os.path.join(CONFIG.results_dir, 'heatmap_dice.png'), dpi=300)
plt.show()


# ============================================================================
# CELLULE 9 : Rapport final format√©
# ============================================================================

# Final text report: per-model table, ranking by mean DICE, CubeNET vs UNET change
print("\n" + "="*80)
print(" "*25 + "RAPPORT FINAL - HyperPRI")
print("="*80)

print("\nüìä R√âSULTATS PAR MOD√àLE\n")
print(table_df.to_string(index=False))

print("\n\nüèÜ CLASSEMENT (selon DICE moyen)\n")
ranking = df.groupby('model')['dice'].mean().sort_values(ascending=False)
for i, (model, dice) in enumerate(ranking.items(), 1):
    # Medals for the top three; two spaces keep the later rows aligned
    medal = {1: 'ü•á', 2: 'ü•à', 3: 'ü•â'}.get(i, '  ')
    print(f"  {medal} {i}. {model:10s}: {dice:.4f}")

print("\n\nüìà AM√âLIORATION CubeNET vs UNET\n")
if 'CUBE' in df['model'].values and 'UNET' in df['model'].values:
    cube_dice = df[df['model'] == 'CUBE']['dice'].mean()
    unet_dice = df[df['model'] == 'UNET']['dice'].mean()
    # Relative change of the mean DICE, in percent of the UNET score
    improvement = ((cube_dice - unet_dice) / unet_dice) * 100

    print(f"  DICE UNET:    {unet_dice:.4f}")
    print(f"  DICE CubeNET: {cube_dice:.4f}")
    # The '+' format flag prints the real sign; the previous literal '+' showed
    # a regression as "+-x.xx%".
    print(f"  Am√©lioration: {improvement:+.2f}%")

print("\n" + "="*80)
print("‚úÖ Analyse termin√©e!")
print("="*80)

In [None]:
# ============================================================================
# CELLULE 1 : Setup
# ============================================================================

import sys
import os
sys.path.insert(0, os.path.abspath('..'))

import torch
import numpy as np
import matplotlib.pyplot as plt

from src.Experiments.params_HyperPRI import CONFIG
from src.Datasets.HyperPRI_Dataset import HyperPRIDataset
from src.PLTrainer import SegmentationModule
from src.metrics import SegmentationMetrics

print("‚úì Imports r√©ussis")


# ============================================================================
# CELLULE 2 : Charger images de test (Box 40)
# ============================================================================

# Held-out test images come from CONFIG
test_images = CONFIG.test_images  # ['20220815_box40', '20220824_box40']

# The two labels below assume exactly this order: entry 0 is the dry image, entry 1 the wet one
print(f"Images de test: {test_images}")
print(f"  - {test_images[0]}: S√®che (dry)")
print(f"  - {test_images[1]}: Humide (wet)")


# ============================================================================
# CELLULE 3 : Cr√©er Dataset de test
# ============================================================================

from torch.utils.data import DataLoader

# Arguments shared by both datasets and both loaders.
# Loaders: one image per batch, in list order (the evaluation cell indexes
# test_images by batch position), loaded in the main process.
_dataset_common = dict(
    data_dir=CONFIG.peanut_dir,
    image_list=test_images,
    normalize_hsi=True,
)
_loader_common = dict(batch_size=1, shuffle=False, num_workers=0)

# RGB dataset for UNET
test_dataset_rgb = HyperPRIDataset(mode='RGB', **_dataset_common)
test_loader_rgb = DataLoader(test_dataset_rgb, **_loader_common)

# HSI dataset for CubeNET, restricted to the band range configured in CONFIG
test_dataset_hsi = HyperPRIDataset(
    mode='HSI',
    hsi_lo=CONFIG.hsi_lo,
    hsi_hi=CONFIG.hsi_hi,
    **_dataset_common,
)
test_loader_hsi = DataLoader(test_dataset_hsi, **_loader_common)

print(f"‚úì Datasets de test cr√©√©s")


# ============================================================================
# CELLULE 4 : Charger modèles entraînés
# ============================================================================

def load_best_model(model_type, split_idx=0, seed=0, map_location='cpu'):
    """Load the best checkpoint of one model type for a given split and seed.

    The checkpoint is looked up at
    ``<CONFIG.checkpoint_dir>/<model_type>/split_<split_idx>/seed_<seed>/best_model.ckpt``.

    Args:
        model_type: Name of the model sub-directory (this notebook uses
            'unet' and 'cube').
        split_idx: Index used in the ``split_<idx>`` directory name.
        seed: Index used in the ``seed_<seed>`` directory name. Defaults to 0,
            the value that was previously hard-coded.
        map_location: Forwarded to ``load_from_checkpoint``. Defaults to
            'cpu' so that a checkpoint written during a GPU run can still be
            opened on a CPU-only machine; callers move the module to the
            target device afterwards.

    Returns:
        The ``SegmentationModule`` switched to eval mode, or ``None`` (after
        printing a warning) when the checkpoint file does not exist.
    """
    checkpoint_path = os.path.join(
        CONFIG.checkpoint_dir,
        model_type,
        f'split_{split_idx}',
        f'seed_{seed}',
        'best_model.ckpt',
    )

    if not os.path.exists(checkpoint_path):
        print(f"‚ö†Ô∏è  Checkpoint introuvable: {checkpoint_path}")
        return None

    print(f"Chargement: {checkpoint_path}")

    module = SegmentationModule.load_from_checkpoint(
        checkpoint_path,
        map_location=map_location,
    )
    module.eval()

    return module

# Load the trained models, keyed by the name used in the later cells
models = {}

print("\nChargement des mod√®les...\n")

# (key in `models`, checkpoint sub-directory). SpectralUNET is optional:
# add ('SPECTRAL', 'spectral') to the tuple below to load it as well.
for model_key, checkpoint_subdir in (('UNET', 'unet'), ('CUBE', 'cube')):
    models[model_key] = load_best_model(checkpoint_subdir, split_idx=0)

print("\n‚úì Mod√®les charg√©s")


# ============================================================================
# CELLULE 5 : Évaluer sur images de test
# ============================================================================

metrics_calc = SegmentationMetrics()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# One slot per (model, condition). A slot stays an empty dict when its model
# could not be loaded; later cells test for that before reading it.
results_test = {
    'UNET': {'dry': {}, 'wet': {}},
    'CUBE': {'dry': {}, 'wet': {}}
}


def _evaluate_on_test_set(model_key, display_name, loader):
    """Run one model over a test loader and record its metrics in ``results_test``.

    Args:
        model_key: Key shared by ``models`` and ``results_test``
            ('UNET' or 'CUBE').
        display_name: Model name printed in the per-image report.
        loader: Loader yielding dict batches with 'image' and 'mask' entries.

    Returns nothing; does nothing when the model is missing or ``None``.
    """
    module = models.get(model_key)
    if module is None:
        return

    module.to(device)

    for i, batch in enumerate(loader):
        # Relies on batch_size=1 and shuffle=False: batch i is test_images[i];
        # the first one is the dry image, any later one is reported as wet.
        condition = 'dry' if i == 0 else 'wet'
        img_name = test_images[i]

        images = batch['image'].to(device)
        masks = batch['mask']

        with torch.no_grad():
            logits = module(images)
            # assumes logits are (batch, 1, H, W) - TODO confirm
            probs = torch.sigmoid(logits).squeeze(1).cpu()
            preds = (probs > module.best_threshold).long()

        # Compute metrics
        all_metrics = metrics_calc.compute_all(preds, masks, probs)

        # Cast to builtin floats so the stored results can be written with
        # json.dump whatever scalar type the metrics helper returns
        # (NumPy / torch scalars are not JSON-serializable).
        scores = {
            'dice': float(all_metrics['dice']),
            'iou': float(all_metrics['iou_positive']),
            'ap': float(all_metrics['ap']),
            'acc': float(all_metrics['pixel_acc']),
        }
        results_test[model_key][condition] = {'image': img_name, **scores}

        print(f"\n{display_name} - {condition.upper()} ({img_name}):")
        print(f"  DICE: {scores['dice']:.4f}")
        print(f"  IoU:  {scores['iou']:.4f}")
        print(f"  AP:   {scores['ap']:.4f}")
        print(f"  Acc:  {scores['acc']:.4f}")


print("\n" + "="*70)
print("√âVALUATION SUR TEST SET")
print("="*70)

# UNET on RGB
_evaluate_on_test_set('UNET', 'UNET', test_loader_rgb)

# CubeNET on HSI
_evaluate_on_test_set('CUBE', 'CubeNET', test_loader_hsi)

print("\n" + "="*70)


# ============================================================================
# CELLULE 6 : Tableau comparatif Test Set
# ============================================================================

import pandas as pd

# Build the DataFrame: one row per (model, condition) that has results
_metric_columns = (('DICE', 'dice'), ('IoU', 'iou'), ('AP', 'ap'), ('Acc', 'acc'))

test_records = []
for model in ('UNET', 'CUBE'):
    for condition in ('dry', 'wet'):
        entry = results_test[model][condition]
        if not entry:
            # Empty slot: this model was not evaluated on this image.
            continue
        row = {'Model': model, 'Condition': condition.upper()}
        # Metrics are stored as 4-decimal strings, so the table prints them verbatim.
        row.update({column: f"{entry[key]:.4f}" for column, key in _metric_columns})
        test_records.append(row)

test_df = pd.DataFrame(test_records)

_rule = "=" * 70
print("\n" + _rule)
print("TABLEAU R√âSULTATS TEST SET")
print(_rule)
print(test_df.to_string(index=False))
print(_rule)


# ============================================================================
# CELLULE 7 : Visualiser prédictions Test Set
# ============================================================================

from matplotlib.colors import ListedColormap

# Colormap for the overlay: row k is the RGBA colour of confusion class k
_overlay_alpha = 0.6
colors = np.array([
    [0, 0, 0, 0],                 # 0: TN - fully transparent
    [1, 0, 0, _overlay_alpha],    # 1: FP - red
    [0, 0, 1, _overlay_alpha],    # 2: FN - blue
    [0, 1, 0, _overlay_alpha],    # 3: TP - green
])
cmap_overlay = ListedColormap(colors)

# Fonction pour créer overlay
def create_overlay(pred, gt):
    """Build a per-pixel confusion map from a prediction and its ground truth.

    Args:
        pred: Array of predicted labels (compared against 0 and 1).
        gt: Array of ground-truth labels (compared against 0 and 1).

    Returns:
        A uint8 array shaped like ``gt`` holding 1 for false positives,
        2 for false negatives, 3 for true positives and 0 everywhere else.
    """
    confusion = np.zeros_like(gt, dtype=np.uint8)
    # (class code, predicted label, true label)
    for code, pred_label, gt_label in ((1, 1, 0), (2, 0, 1), (3, 1, 1)):
        confusion[(pred == pred_label) & (gt == gt_label)] = code
    return confusion

# VISUALIZATION
# One row per test condition; columns: RGB input, ground truth,
# UNET overlay, CubeNET overlay.
fig, axes = plt.subplots(2, 4, figsize=(20, 10))

conditions = ['dry', 'wet']
model_names = ['UNET', 'CUBE']

# Spectral band displayed behind the CubeNET overlay. It is clamped to the
# last available band below, in case the loaded cube has fewer bands.
HSI_DISPLAY_BAND = 100


def _prediction_title(label, model_key, condition):
    """Return '<label> (DICE=x.xxx)', or just the label when no DICE was recorded."""
    dice = results_test.get(model_key, {}).get(condition, {}).get('dice')
    if dice is None:
        return label
    return f'{label} (DICE={dice:.3f})'


# UNET: columns 0-2
unet_module = models.get('UNET')
if unet_module is None:
    # No checkpoint was loaded: blank the panels instead of failing on None.
    for ax in axes[:, :3].ravel():
        ax.axis('off')
else:
    unet_module.to(device)
    for i, (batch, condition) in enumerate(zip(test_loader_rgb, conditions)):
        images = batch['image'].to(device)
        masks = batch['mask'].numpy()[0]

        with torch.no_grad():
            logits = unet_module(images)
            probs = torch.sigmoid(logits).squeeze().cpu().numpy()
            preds = (probs > unet_module.best_threshold).astype(np.uint8)

        # RGB input, channels moved last for imshow
        img_rgb = images[0].cpu().permute(1, 2, 0).numpy()
        axes[i, 0].imshow(img_rgb)
        axes[i, 0].set_title(f'RGB - {condition.upper()}', fontsize=12, fontweight='bold')
        axes[i, 0].axis('off')

        # Ground truth
        axes[i, 1].imshow(masks, cmap='gray')
        axes[i, 1].set_title('Ground Truth', fontsize=12, fontweight='bold')
        axes[i, 1].axis('off')

        # UNET prediction: confusion overlay drawn on top of the RGB image
        overlay = create_overlay(preds, masks)
        axes[i, 2].imshow(img_rgb)
        axes[i, 2].imshow(overlay, cmap=cmap_overlay, alpha=0.7, vmin=0, vmax=3)
        axes[i, 2].set_title(_prediction_title('UNET', 'UNET', condition),
                             fontsize=12, fontweight='bold')
        axes[i, 2].axis('off')

# CubeNET: column 3
cube_module = models.get('CUBE')
if cube_module is None:
    for ax in axes[:, 3]:
        ax.axis('off')
else:
    cube_module.to(device)
    for i, (batch, condition) in enumerate(zip(test_loader_hsi, conditions)):
        images_hsi = batch['image'].to(device)
        masks = batch['mask'].numpy()[0]

        with torch.no_grad():
            logits = cube_module(images_hsi)
            probs = torch.sigmoid(logits).squeeze().cpu().numpy()
            preds = (probs > cube_module.best_threshold).astype(np.uint8)

        # Show a single HSI band as background.
        # assumes images_hsi is (batch, bands, H, W) - TODO confirm
        band_idx = min(HSI_DISPLAY_BAND, images_hsi.shape[1] - 1)
        img_hsi_band = images_hsi[0, band_idx].cpu().numpy()
        axes[i, 3].imshow(img_hsi_band, cmap='viridis')

        # CubeNET prediction overlay
        overlay = create_overlay(preds, masks)
        axes[i, 3].imshow(overlay, cmap=cmap_overlay, alpha=0.7, vmin=0, vmax=3)
        axes[i, 3].set_title(_prediction_title('CubeNET', 'CUBE', condition),
                             fontsize=12, fontweight='bold')
        axes[i, 3].axis('off')

# Legend
from matplotlib.patches import Patch
legend_elements = [
    Patch(facecolor='green', alpha=0.6, label='TP (True Positive)'),
    Patch(facecolor='red', alpha=0.6, label='FP (False Positive)'),
    Patch(facecolor='blue', alpha=0.6, label='FN (False Negative)')
]
fig.legend(handles=legend_elements, loc='lower center', ncol=3, fontsize=12)

# Leave a strip at the bottom of the figure for the legend.
plt.tight_layout(rect=[0, 0.03, 1, 1])

# savefig does not create missing folders, so make sure the target exists.
if CONFIG.results_dir:
    os.makedirs(CONFIG.results_dir, exist_ok=True)
plt.savefig(os.path.join(CONFIG.results_dir, 'test_set_predictions.png'), dpi=300, bbox_inches='tight')
plt.show()

print("‚úì Visualisation sauvegard√©e: results/test_set_predictions.png")


# ============================================================================
# CELLULE 8 : Analyse détaillée IMAGE SÈCHE
# ============================================================================

print("\n" + "="*70)
print("ANALYSE D√âTAILL√âE - IMAGE S√àCHE (Condition difficile)")
print("="*70)

# Per-model metrics on the dry image; models without a recorded result are skipped.
for model_name in ['UNET', 'CUBE']:
    if results_test[model_name]['dry']:
        r = results_test[model_name]['dry']
        print(f"\n{model_name}:")
        print(f"  DICE: {r['dice']:.4f}")
        print(f"  IoU:  {r['iou']:.4f}")
        print(f"  AP:   {r['ap']:.4f}")
        print(f"  Acc:  {r['acc']:.4f}")

# Comparison (only when both models were evaluated on the dry image)
if results_test['UNET']['dry'] and results_test['CUBE']['dry']:
    dice_unet_dry = results_test['UNET']['dry']['dice']
    dice_cube_dry = results_test['CUBE']['dry']['dice']

    print(f"\nüìä AM√âLIORATION CubeNET vs UNET (Image s√®che):")
    if dice_unet_dry > 0:
        improvement = ((dice_cube_dry - dice_unet_dry) / dice_unet_dry) * 100
        print(f"  Facteur: √ó{dice_cube_dry / dice_unet_dry:.2f}")
        # The '+' format flag prints the real sign, so a regression is shown
        # as "-5.0%" rather than "+-5.0%".
        print(f"  Am√©lioration: {improvement:+.1f}%")
    else:
        # A UNET DICE of 0 would make both ratios divide by zero.
        improvement = float('nan')
        print("  Facteur: n/a (DICE UNET nul)")
        print("  Am√©lioration: n/a (DICE UNET nul)")

    print("\nüí° INTERPR√âTATION:")
    if dice_cube_dry > dice_unet_dry:
        print("  ‚úì CubeNET surpasse UNET en conditions difficiles")
        print("  ‚úì Les informations spectrales HSI sont critiques pour")
        print("    diff√©rencier racines s√®ches du sol sec")
    else:
        print("  ‚ö†Ô∏è  R√©sultats inattendus - v√©rifier mod√®les")

print("\n" + "="*70)


# ============================================================================
# CELLULE 9 : Sauvegarder résultats test
# ============================================================================

import json

# Save as JSON
# open() does not create missing folders, so make sure the target exists.
if CONFIG.results_dir:
    os.makedirs(CONFIG.results_dir, exist_ok=True)
test_results_file = os.path.join(CONFIG.results_dir, 'test_set_results.json')

with open(test_results_file, 'w', encoding='utf-8') as f:
    # default=float converts scalar metric objects that the json module
    # cannot encode natively (e.g. NumPy or torch scalars) instead of raising.
    json.dump(results_test, f, indent=2, default=float)

print(f"\n‚úì R√©sultats test sauvegard√©s: {test_results_file}")
print("\nüéâ √âvaluation test set termin√©e!")