# Visualisation des Courbes d'Entraînement

Ce notebook permet de visualiser les métriques d'entraînement des modèles Graph Transformer.


In [None]:
import json
import os
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

# Configuration: root directory that holds the per-model training logs.
# Switch BASE_PATH to "/content/drive/MyDrive/data" when running on Colab.
BASE_PATH = "data"

# Display name -> path of that model's JSON training log.
MODELS = {
    label: f"{BASE_PATH}/{subdir}/training_logs.json"
    for label, subdir in (
        ("GT (MSE)", "GT"),
        ("GT Contrast (Improved)", "GT_Contrast"),
    )
}

# Display name -> line colour used for that model in every plot.
COLORS = {
    "GT (MSE)": "#1f77b4",
    "GT Contrast (Improved)": "#ff7f0e",
}


In [None]:
def load_logs(log_path):
    """Load training logs from a JSON file.

    Args:
        log_path: Path (str or path-like) of the JSON log file.

    Returns:
        The decoded JSON content, or None when the file does not exist
        (a warning is printed in that case).

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    try:
        # Decode as UTF-8 explicitly so the result does not depend on the
        # platform's default encoding.
        with open(log_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        # Try-then-handle avoids the exists()/open() race of a prior check.
        print(f"⚠️  Fichier non trouvé: {log_path}")
        return None

# Load every configured log; models whose log file is missing are skipped.
all_logs = {}
for model_name, log_path in MODELS.items():
    logs = load_logs(log_path)
    if logs is None:
        print(f"❌ {model_name}: logs non disponibles")
        continue
    all_logs[model_name] = logs
    print(f"✅ {model_name}: {len(logs['epochs'])} epochs chargés")

# Tell the user explicitly when nothing could be loaded.
if not all_logs:
    print("\n⚠️  Aucun log trouvé. Assurez-vous d'avoir entraîné au moins un modèle.")


## 1. Courbes de Loss (Train vs Validation)


In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Left panel: training loss, one curve per model.
for model_name, logs in all_logs.items():
    epochs = [e["epoch"] for e in logs["epochs"]]
    train_losses = [e["train_loss"] for e in logs["epochs"]]

    ax1.plot(epochs, train_losses,
             label=model_name,
             color=COLORS.get(model_name, "gray"),
             linewidth=2,
             marker='o',
             markersize=4)

ax1.set_xlabel("Epoch", fontsize=12)
ax1.set_ylabel("Training Loss", fontsize=12)
ax1.set_title("Training Loss", fontsize=14, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# Right panel: validation loss, restricted to epochs with a positive val_loss.
for model_name, logs in all_logs.items():
    # `or 0` maps a missing or null val_loss to 0 so the epoch is skipped
    # instead of raising TypeError on `None > 0`.
    val_points = [
        (e["epoch"], e["val_loss"])
        for e in logs["epochs"]
        if (e.get("val_loss") or 0) > 0
    ]

    if val_points:
        epochs, val_losses = zip(*val_points)
        ax2.plot(epochs, val_losses,
                 label=model_name,
                 color=COLORS.get(model_name, "gray"),
                 linewidth=2,
                 marker='s',
                 markersize=4,
                 linestyle='--')

ax2.set_xlabel("Epoch", fontsize=12)
ax2.set_ylabel("Validation Loss", fontsize=12)
ax2.set_title("Validation Loss", fontsize=14, fontweight='bold')
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Combined train vs. validation plot, used to spot overfitting.
plt.figure(figsize=(12, 6))
for model_name, logs in all_logs.items():
    color = COLORS.get(model_name, "gray")
    epochs = [e["epoch"] for e in logs["epochs"]]
    train_losses = [e["train_loss"] for e in logs["epochs"]]

    # Keep only the epochs where val_loss exists and is positive; epoch and
    # value are collected together so the two series cannot get out of step.
    val_points = [
        (e["epoch"], e["val_loss"])
        for e in logs["epochs"]
        if e.get("val_loss") is not None and e["val_loss"] > 0
    ]

    plt.plot(epochs, train_losses,
             label=f"{model_name} (Train)",
             color=color,
             linewidth=2,
             marker='o',
             markersize=4)

    if val_points:
        valid_epochs, valid_val_losses = zip(*val_points)
        # Same colour as the training curve; dashed and lighter to tell them apart.
        plt.plot(valid_epochs, valid_val_losses,
                 label=f"{model_name} (Val)",
                 color=color,
                 linewidth=2,
                 marker='s',
                 markersize=4,
                 linestyle='--',
                 alpha=0.7)

plt.xlabel("Epoch", fontsize=12)
plt.ylabel("Loss", fontsize=12)
plt.title("Train vs Validation Loss (Détection d'Overfitting)", fontsize=14, fontweight='bold')
plt.legend(fontsize=10)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()


## 2. Métriques de Validation (MRR, R@1, R@5, R@10)


In [None]:
# 2x2 grid: one panel per validation ranking metric.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
axes = axes.flatten()

metrics = ["val_mrr", "val_r1", "val_r5", "val_r10"]
metric_names = ["MRR (Mean Reciprocal Rank)", "R@1 (Recall@1)", "R@5 (Recall@5)", "R@10 (Recall@10)"]

for ax, metric, metric_name in zip(axes, metrics, metric_names):
    for model_name, logs in all_logs.items():
        # Only epochs with a positive value are drawn. `or 0` maps a missing
        # or null metric to 0 so it is skipped instead of raising TypeError
        # on `None > 0`.
        points = [
            (e["epoch"], e[metric])
            for e in logs["epochs"]
            if (e.get(metric) or 0) > 0
        ]

        if points:
            epochs, values = zip(*points)
            ax.plot(epochs, values,
                    label=model_name,
                    color=COLORS.get(model_name, "gray"),
                    linewidth=2,
                    marker='o',
                    markersize=4)

    ax.set_xlabel("Epoch", fontsize=10)
    ax.set_ylabel(metric_name, fontsize=10)
    ax.set_title(metric_name, fontsize=11, fontweight='bold')
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)
    # Fixed [0, 1] y-range so all four panels share the same scale.
    ax.set_ylim([0, 1])

plt.tight_layout()
plt.show()


## 3. Évolution du Learning Rate


In [None]:
# Learning-rate schedule of every model, one curve each.
plt.figure(figsize=(12, 6))

for model_name, logs in all_logs.items():
    epochs = [e["epoch"] for e in logs["epochs"]]
    learning_rates = [e["learning_rate"] for e in logs["epochs"]]

    plt.plot(epochs, learning_rates,
             label=model_name,
             color=COLORS.get(model_name, "gray"),
             linewidth=2,
             marker='s',
             markersize=4)

plt.xlabel("Epoch", fontsize=12)
plt.ylabel("Learning Rate", fontsize=12)
plt.title("Évolution du Learning Rate", fontsize=14, fontweight='bold')
plt.legend(fontsize=10)
plt.grid(True, alpha=0.3)
# Logarithmic y-axis for the learning-rate values.
plt.yscale('log')
plt.tight_layout()
plt.show()


## 4. Tableau Récapitulatif


In [None]:
import pandas as pd


def _fmt4(value):
    """Format a metric with 4 decimals; a null (None) metric is shown as "N/A"."""
    return "N/A" if value is None else f"{value:.4f}"


# Build the summary table: one row per model, using the metrics of its
# last logged epoch.
summary_data = []

for model_name, logs in all_logs.items():
    if not logs["epochs"]:
        continue  # nothing logged for this model, so no row

    config = logs.get("config", {})
    last_epoch = logs["epochs"][-1]
    summary_data.append({
        "Modèle": model_name,
        "Epochs": config.get("epochs", "N/A"),
        "LR": config.get("lr", "N/A"),
        "Batch Size": config.get("batch_size", "N/A"),
        # Missing keys still fall back to 0.0 (shown as "0.0000"); only an
        # explicit null is rendered as "N/A" instead of crashing the format.
        "Best MRR": _fmt4(logs.get("best_mrr", 0.0)),
        "Last MRR": _fmt4(last_epoch.get("val_mrr", 0.0)),
        "Last R@1": _fmt4(last_epoch.get("val_r1", 0.0)),
        "Last R@5": _fmt4(last_epoch.get("val_r5", 0.0)),
        "Last R@10": _fmt4(last_epoch.get("val_r10", 0.0)),
        "Train Loss": _fmt4(last_epoch.get("train_loss", 0.0)),
        "Val Loss": _fmt4(last_epoch.get("val_loss", 0.0)),
    })

if summary_data:
    df = pd.DataFrame(summary_data)
    print("\n📊 Récapitulatif des Modèles:\n")
    print(df.to_string(index=False))
else:
    print("Aucune donnée disponible.")


## 5. Comparaison Côte à Côte (Loss vs MRR)


In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Left panel: training loss, one curve per model.
for model_name, logs in all_logs.items():
    epochs = [e["epoch"] for e in logs["epochs"]]
    train_losses = [e["train_loss"] for e in logs["epochs"]]

    ax1.plot(epochs, train_losses,
             label=model_name,
             color=COLORS.get(model_name, "gray"),
             linewidth=2,
             marker='o',
             markersize=4)

ax1.set_xlabel("Epoch", fontsize=12)
ax1.set_ylabel("Training Loss", fontsize=12)
ax1.set_title("Training Loss", fontsize=13, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# Right panel: validation MRR, restricted to epochs with a positive val_mrr.
for model_name, logs in all_logs.items():
    # `or 0` maps a missing or null val_mrr to 0 so the epoch is skipped
    # instead of raising TypeError on `None > 0`.
    mrr_points = [
        (e["epoch"], e["val_mrr"])
        for e in logs["epochs"]
        if (e.get("val_mrr") or 0) > 0
    ]

    if mrr_points:
        epochs, mrr_values = zip(*mrr_points)
        ax2.plot(epochs, mrr_values,
                 label=model_name,
                 color=COLORS.get(model_name, "gray"),
                 linewidth=2,
                 marker='o',
                 markersize=4)

ax2.set_xlabel("Epoch", fontsize=12)
ax2.set_ylabel("MRR", fontsize=12)
ax2.set_title("Validation MRR", fontsize=13, fontweight='bold')
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)
# Fixed [0, 1] y-range for the MRR panel.
ax2.set_ylim([0, 1])

plt.tight_layout()
plt.show()


## 6. Analyse Détaillée d'un Modèle

Sélectionnez un modèle pour voir ses détails complets.


In [None]:
# Pick the model to inspect: defaults to the first loaded one
# (replace with another key of all_logs if needed).
selected_model = next(iter(all_logs), None)

if selected_model is not None:
    logs = all_logs[selected_model]

    print(f"\n📈 Analyse détaillée: {selected_model}\n")
    print("Configuration:")
    for key, value in logs.get("config", {}).items():
        print(f"  - {key}: {value}")

    # `or 0.0` keeps the format from crashing when best_mrr is stored as null.
    print(f"\nMeilleur MRR: {(logs.get('best_mrr') or 0.0):.4f}")

    # Show the last 5 epochs (fewer if the log is shorter).
    print("\n📊 5 Dernières Epochs:")
    for epoch in logs["epochs"][-5:]:
        print(f"\n  Epoch {epoch['epoch']}:")
        print(f"    Train Loss: {epoch['train_loss']:.4f}")
        # Validation values are printed only when positive; `or 0` maps a
        # missing or null value to 0 instead of raising on `None > 0`.
        if (epoch.get('val_loss') or 0) > 0:
            print(f"    Val Loss: {epoch['val_loss']:.4f}")
        print(f"    LR: {epoch['learning_rate']:.6f}")
        if (epoch.get('val_mrr') or 0) > 0:
            print(f"    MRR: {epoch['val_mrr']:.4f}")
            print(f"    R@1: {epoch['val_r1']:.4f}")
            print(f"    R@5: {epoch['val_r5']:.4f}")
            print(f"    R@10: {epoch['val_r10']:.4f}")
else:
    print("Aucun modèle disponible.")
