# Fine-tuning FunctionGemma pour Home Assistant

Ce notebook permet d'entraîner **FunctionGemma-270m-it** sur Google Colab avec des optimisations avancées.

**Améliorations:**
- Métriques personnalisées (précision function calls, entity accuracy)
- Scheduler cosine avec warmup
- Early stopping intelligent
- TensorBoard logging
- LoRA rank optimisé

**Format one-step (simplifié):**
1. L'utilisateur envoie une requête accompagnée de la liste des entités disponibles
2. Le modèle appelle directement l'action avec la bonne entité

**Instructions:**
1. Activez le GPU: Runtime → Change runtime type → GPU
2. Exécutez les cellules dans l'ordre

## 1. Installation des dépendances

In [None]:
# Install the libraries this notebook depends on (-q keeps pip output quiet)
!pip install -q transformers peft accelerate datasets bitsandbytes huggingface_hub tensorboard

In [None]:
import torch
import gc

# Report the runtime: PyTorch build and whether CUDA is usable.
print(f"PyTorch version: {torch.__version__}")
print(f"GPU disponible: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU: {gpu_name}")
    print(f"VRAM: {vram_gb:.1f} GB")

    # Choose a (message, batch size, gradient accumulation) preset from the
    # GPU name / VRAM. Rules are checked in order; the first match wins.
    if "A100" in gpu_name:
        preset = ("\n✓ A100 détectée - Configuration optimale disponible", 16, 1)
    elif "V100" in gpu_name or vram_gb >= 16:
        preset = ("\n✓ GPU 16GB+ - Bonne configuration disponible", 8, 2)
    elif "T4" in gpu_name or vram_gb >= 12:
        preset = ("\n⚠ T4/12GB - Configuration conservative recommandée", 4, 4)
    else:
        preset = ("\n⚠ GPU limitée - Configuration minimale", 2, 8)

    # These two names are only defined when a GPU is present.
    preset_message, RECOMMENDED_BATCH, RECOMMENDED_GRAD_ACCUM = preset
    print(preset_message)

    print(f"   Batch recommandé: {RECOMMENDED_BATCH}")
    print(f"   Gradient accumulation: {RECOMMENDED_GRAD_ACCUM}")
    print(f"   Effective batch size: {RECOMMENDED_BATCH * RECOMMENDED_GRAD_ACCUM}")

## 2. Configuration Hugging Face

FunctionGemma est un modèle à accès restreint (« gated »). Avant de continuer :
1. Accepter les conditions sur https://huggingface.co/google/functiongemma-270m-it
2. Créer un token sur https://huggingface.co/settings/tokens

In [None]:
from huggingface_hub import login
from google.colab import userdata

# Prefer the token stored in the Colab secrets; if it cannot be read or the
# login fails, fall back to the interactive login prompt.
try:
    hf_token = userdata.get('HF_TOKEN')
    login(token=hf_token)
    print("✓ Connecté via secret Colab")
except Exception as exc:
    # `except Exception` (not a bare `except`) lets KeyboardInterrupt and
    # SystemExit propagate, so the cell can still be interrupted.
    # Only the exception type is printed, to avoid echoing anything sensitive.
    print(f"⚠ Secret Colab HF_TOKEN indisponible ({type(exc).__name__}) - connexion interactive")
    login()

## 3. Upload du dataset

In [None]:
from google.colab import files
import os

# Make sure the destination folder exists before moving anything into it
os.makedirs("data", exist_ok=True)

print("Uploadez train.jsonl et val.jsonl")
uploaded = files.upload()

# Move every uploaded file from the working directory into data/
for filename in uploaded:
    os.rename(filename, f"data/{filename}")
    print(f"  → data/{filename}")

In [None]:
import json

def count_lines(filepath):
    """Return the number of lines in the text file at *filepath*."""
    total = 0
    with open(filepath, 'r') as handle:
        # enumerate leaves `total` at the index of the last line read
        for total, _ in enumerate(handle, start=1):
            pass
    return total

def analyze_dataset(filepath):
    """Analyze the distribution of a JSONL dataset.

    Each non-blank line must be a JSON object; its ``text`` field (empty
    string when missing) is scanned for known action names and for markers
    of negative examples.

    Args:
        filepath: Path to the ``.jsonl`` file, read as UTF-8.

    Returns:
        A dict with:
          - ``"total"``: number of non-blank examples,
          - ``"actions"``: mapping action name -> number of examples that
            contain it (each action counted at most once per example),
          - ``"negative"``: number of examples containing ``error.`` or
            ``clarification_needed``.
    """
    import re  # local import: this cell runs before the notebook imports `re`

    action_names = ['turn_on', 'turn_off', 'set_temperature', 'set_hvac_mode',
                    'open_cover', 'close_cover', 'lock', 'unlock', 'activate']
    # Match an action only as a standalone token:
    #  - not preceded by a letter, so "unlock" is not also counted as "lock";
    #  - not followed by a letter, "_" or ".", so a domain prefix or entity id
    #    such as "lock.porte_entree" / "lock.unlock" is not counted as "lock".
    action_pattern = re.compile(
        r'(?<![A-Za-z])('
        + '|'.join(re.escape(name) for name in action_names)
        + r')(?![A-Za-z_.])'
    )

    stats = {
        "total": 0,
        "actions": {},
        "negative": 0,
    }

    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            # Blank lines are tolerated instead of crashing json.loads
            if not line.strip():
                continue

            example = json.loads(line)
            stats["total"] += 1
            text = example.get('text', '')

            # Count each detected action once per example
            found = set(action_pattern.findall(text))
            for action in action_names:
                if action in found:
                    stats["actions"][action] = stats["actions"].get(action, 0) + 1

            # Negative examples (error.* calls or clarification requests)
            if 'error.' in text or 'clarification_needed' in text:
                stats["negative"] += 1

    return stats

# Dataset sizes
train_count = count_lines("data/train.jsonl")
val_count = count_lines("data/val.jsonl")
total_count = train_count + val_count

# Guard the ratio so an empty upload prints 0.0% instead of raising
val_ratio = val_count / total_count * 100 if total_count else 0.0

print(f"Dataset:")
print(f"  Train: {train_count} exemples")
print(f"  Validation: {val_count} exemples")
print(f"  Ratio val: {val_ratio:.1f}%")

# Detailed analysis of the training split
print("\nAnalyse du dataset d'entraînement:")
train_stats = analyze_dataset("data/train.jsonl")
negative_ratio = (
    train_stats['negative'] / train_stats['total'] * 100 if train_stats['total'] else 0.0
)
print(f"  Exemples négatifs: {train_stats['negative']} ({negative_ratio:.1f}%)")
print(f"  Actions:")
# Most frequent actions first
for action, count in sorted(train_stats['actions'].items(), key=lambda x: -x[1]):
    print(f"    {action}: {count}")

# Preview of the first training example (truncated to 500 characters)
with open("data/train.jsonl", 'r', encoding='utf-8') as f:
    first_line = f.readline()
# An empty file yields '' here; skip the preview rather than crash in json.loads
if first_line.strip():
    example = json.loads(first_line)
    print(f"\nExemple:")
    print(example['text'][:500] + "..." if len(example['text']) > 500 else example['text'])

## 4. Configuration

### Hyperparamètres optimisés

| Paramètre | Valeur | Justification |
|-----------|--------|---------------|
| LoRA rank | 64 | Meilleure capacité d'apprentissage |
| LoRA alpha | 128 | Ratio alpha/r = 2 (standard) |
| Learning rate | 1e-4 | Optimal pour LoRA fine-tuning |
| Epochs | 5 | Balance qualité/temps |
| Scheduler | Cosine | Meilleure convergence |
| Early stopping | 3 | Évite le surapprentissage |

In [None]:
# Configuration principale
CONFIG = {
    "model_name": "google/functiongemma-270m-it",
    "max_length": 512,
    
    # LoRA - Augmenté pour meilleure capacité
    "lora_r": 64,           # Augmenté de 32 à 64
    "lora_alpha": 128,      # Ratio alpha/r = 2
    "lora_dropout": 0.05,
    "lora_target_modules": [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
    
    # Entraînement - Ajuster selon GPU (voir cellule 1)
    "batch_size": RECOMMENDED_BATCH if 'RECOMMENDED_BATCH' in dir() else 8,
    "gradient_accumulation_steps": RECOMMENDED_GRAD_ACCUM if 'RECOMMENDED_GRAD_ACCUM' in dir() else 2,
    "learning_rate": 1e-4,
    "num_epochs": 5,            # Augmenté pour meilleure convergence
    "warmup_ratio": 0.1,
    "weight_decay": 0.01,
    
    # Scheduler
    "lr_scheduler_type": "cosine",  # Nouveau: scheduler cosine
    
    # Early stopping
    "early_stopping_patience": 3,   # Nouveau: arrêt après 3 eval sans amélioration
    "early_stopping_threshold": 0.01,
    
    # Sauvegarde
    "output_dir": "./output",
    "save_steps": 50,
    "logging_steps": 10,
    "eval_steps": 50,
}

# Calcul de l'effective batch size
effective_batch = CONFIG["batch_size"] * CONFIG["gradient_accumulation_steps"]

print("Configuration:")
print(f"  Model: {CONFIG['model_name']}")
print(f"  LoRA rank: {CONFIG['lora_r']} (alpha: {CONFIG['lora_alpha']})")
print(f"  Batch size: {CONFIG['batch_size']} × {CONFIG['gradient_accumulation_steps']} = {effective_batch} effective")
print(f"  Learning rate: {CONFIG['learning_rate']} ({CONFIG['lr_scheduler_type']} scheduler)")
print(f"  Epochs: {CONFIG['num_epochs']}")
print(f"  Early stopping: patience={CONFIG['early_stopping_patience']}")

## 5. Chargement du modèle

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType

print(f"Chargement de {CONFIG['model_name']}...")

# Base model loaded in bfloat16; device placement is left to device_map="auto"
model_kwargs = {
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
    "trust_remote_code": True,
}
model = AutoModelForCausalLM.from_pretrained(CONFIG["model_name"], **model_kwargs)

tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_name"], trust_remote_code=True)

# Without a pad token, reuse EOS for padding and mirror that in the model config
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.eos_token_id

print(f"✓ Modèle chargé! Paramètres: {model.num_parameters():,}")

In [None]:
# LoRA configuration; rank, alpha, dropout and target modules all come from CONFIG
lora_config = LoraConfig(
    r=CONFIG["lora_r"],
    lora_alpha=CONFIG["lora_alpha"],
    lora_dropout=CONFIG["lora_dropout"],
    target_modules=CONFIG["lora_target_modules"],
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

# Wrap the base model with the LoRA adapters and report the trainable parameter count
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# IMPORTANT: enable input_require_grads BEFORE gradient checkpointing.
# This is required for PEFT/LoRA with gradient checkpointing.
model.enable_input_require_grads()

# Enable gradient checkpointing to save memory.
# use_reentrant=False is required for PyTorch 2.x with PEFT.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
print("✓ Gradient checkpointing activé (use_reentrant=False)")

## 6. Préparation du dataset

In [None]:
from datasets import load_dataset

# Map each split name to the JSONL file that was uploaded into data/
split_files = {
    "train": "data/train.jsonl",
    "validation": "data/val.jsonl",
}
dataset = load_dataset("json", data_files=split_files)

print(f"Train: {len(dataset['train'])} | Val: {len(dataset['validation'])}")

In [None]:
def tokenize_function(examples):
    """Tokenize a batch of examples for causal language modeling.

    The ``text`` column is truncated and padded to ``CONFIG["max_length"]``,
    and ``labels`` is set to a copy of ``input_ids``.

    NOTE(review): padded positions are not masked here; presumably the data
    collator used for training rebuilds the labels — confirm before swapping it.
    """
    encoded = tokenizer(
        examples["text"],
        max_length=CONFIG["max_length"],
        truncation=True,
        padding="max_length",
    )
    # Causal LM objective: the targets are the input tokens themselves
    encoded["labels"] = encoded["input_ids"].copy()
    return encoded

print("Tokenization...")
# Batched map over every split; the raw "text" column is dropped afterwards
map_options = {
    "batched": True,
    "remove_columns": ["text"],
    "desc": "Tokenizing",
}
tokenized_dataset = dataset.map(tokenize_function, **map_options)
print("✓ Tokenization terminée")

## 7. Métriques personnalisées

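Évaluation spécifique aux function calls. Ces métriques nécessitent une génération complète : la précision des fonctions et des entités est mesurée à la section 10, tandis que `compute_metrics` ci-dessous ne calcule que la perplexité.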
- **Function Call Accuracy**: Le modèle appelle-t-il la bonne fonction?
- **Entity Accuracy**: Le modèle sélectionne-t-il la bonne entité?
- **Negative Detection**: Le modèle détecte-t-il les requêtes impossibles?

In [None]:
import re
import numpy as np
from transformers import EvalPrediction

def extract_function_call(text):
    """Extract the function name and entity id from a FunctionGemma call.

    Expected format:
        <start_function_call>call:func_name{param:value,...}<end_function_call>

    Returns:
        A ``(func_name, entity_id)`` tuple. ``entity_id`` is ``None`` when the
        call has no ``entity_id`` parameter; both are ``None`` when *text*
        contains no call at all.
    """
    call = re.search(r'call:([a-z_\.]+)\{([^}]*)\}', text)
    if call is None:
        return None, None

    func_name, params_str = call.group(1), call.group(2)

    # entity_id may be written as entity_id:<escape>value<escape> or entity_id:value
    entity = re.search(r'entity_id:(?:<escape>)?([^<,]+)(?:<escape>)?', params_str)
    if entity is None:
        return func_name, None
    return func_name, entity.group(1).strip()

def compute_metrics(eval_pred: "EvalPrediction"):
    """Compute perplexity from next-token logits and labels.

    Function-call metrics need full generation and are evaluated separately;
    this function only derives perplexity from the logits.

    Args:
        eval_pred: A ``(predictions, labels)`` pair. ``predictions`` holds the
            logits, or a tuple whose first item is the logits, indexed as
            (..., sequence, vocab). ``labels`` holds token ids with ``-100``
            on positions to ignore. NumPy arrays and torch tensors are both
            accepted.

    Returns:
        ``{"perplexity": float}``; ``inf`` when no label position is usable.
    """
    import torch.nn.functional as F

    predictions, labels = eval_pred

    # Some models return a tuple; the logits come first
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # Work on tensors throughout: the inputs may be NumPy arrays, which have
    # no tensor-style .view()/.float() methods.
    logits = torch.as_tensor(predictions)
    targets = torch.as_tensor(labels)

    # Position t predicts token t+1: drop the last logit and the first label
    shift_logits = logits[..., :-1, :]
    shift_labels = targets[..., 1:]

    # Ignored positions are marked with -100
    valid_count = int((shift_labels != -100).sum())
    if valid_count == 0:
        perplexity = float('inf')
    else:
        # Sum the loss over valid positions only, then average over their count
        total_loss = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.shape[-1]).float(),
            shift_labels.reshape(-1).long(),
            ignore_index=-100,
            reduction='sum',
        )
        perplexity = torch.exp(total_loss / valid_count).item()

    return {
        "perplexity": perplexity,
    }

print("✓ Métriques personnalisées définies")

## 8. Configuration de l'entraînement

In [None]:
from transformers import (
    TrainingArguments, 
    Trainer, 
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
)

# Request bf16 mixed precision only when CUDA reports bfloat16 support, so the
# notebook does not hard-code bf16 on GPUs that may lack it.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

training_args = TrainingArguments(
    output_dir=CONFIG["output_dir"],

    # Epochs and batch
    num_train_epochs=CONFIG["num_epochs"],
    per_device_train_batch_size=CONFIG["batch_size"],
    per_device_eval_batch_size=CONFIG["batch_size"],
    gradient_accumulation_steps=CONFIG["gradient_accumulation_steps"],

    # Optimization
    learning_rate=CONFIG["learning_rate"],
    lr_scheduler_type=CONFIG["lr_scheduler_type"],  # cosine scheduler
    warmup_ratio=CONFIG["warmup_ratio"],
    weight_decay=CONFIG["weight_decay"],
    max_grad_norm=1.0,

    # Logging
    logging_dir="./logs",
    logging_steps=CONFIG["logging_steps"],
    report_to=["tensorboard"],  # enable TensorBoard

    # Evaluation
    eval_strategy="steps",
    eval_steps=CONFIG["eval_steps"],

    # Checkpoints: same cadence as evaluation so the best checkpoint
    # (lowest eval_loss) can be reloaded at the end
    save_strategy="steps",
    save_steps=CONFIG["save_steps"],
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,

    # Performance
    bf16=use_bf16,
    dataloader_num_workers=2,
    # Gradient checkpointing is also enabled manually on the model earlier;
    # the same settings are repeated here (use_reentrant=False for PEFT/LoRA).
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},

    # Misc
    remove_unused_columns=False,
    seed=42,
)

# Causal LM collator (mlm=False)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

# Callbacks: early stopping, with patience and threshold taken from CONFIG
callbacks = [
    EarlyStoppingCallback(
        early_stopping_patience=CONFIG["early_stopping_patience"],
        early_stopping_threshold=CONFIG["early_stopping_threshold"],
    )
]

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    data_collator=data_collator,
    callbacks=callbacks,
)

print(f"\n{'='*50}")
print(f"Configuration d'entraînement:")
print(f"  Epochs: {CONFIG['num_epochs']}")
print(f"  Batch: {CONFIG['batch_size']} × {CONFIG['gradient_accumulation_steps']} = {effective_batch}")
print(f"  LR: {CONFIG['learning_rate']} ({CONFIG['lr_scheduler_type']})")
print(f"  Early stopping: patience={CONFIG['early_stopping_patience']}")
print(f"  TensorBoard: ./logs")
print(f"  Gradient checkpointing: use_reentrant=False")
print(f"{'='*50}")

In [None]:
# Launch TensorBoard inside the notebook (optional), pointed at ./logs
%load_ext tensorboard
%tensorboard --logdir ./logs

## 9. Entraînement

In [None]:
# Announce the run and the size of each tokenized split
print("Début de l'entraînement...")
print(f"  Train: {len(tokenized_dataset['train'])} exemples")
print(f"  Val: {len(tokenized_dataset['validation'])} exemples")
print()

# Run training; the result is kept because later cells read its
# training_loss and global_step
train_result = trainer.train()

print("\n" + "="*50)
print("Entraînement terminé!")
print(f"  Training loss: {train_result.training_loss:.4f}")
print(f"  Steps: {train_result.global_step}")
print("="*50)

In [None]:
# Final evaluation on the validation split
print("Évaluation finale...")
eval_results = trainer.evaluate()

print(f"\nRésultats:")
print(f"  Eval loss: {eval_results['eval_loss']:.4f}")
# Perplexity is reported as exp(eval_loss)
print(f"  Perplexity: {np.exp(eval_results['eval_loss']):.2f}")

## 10. Évaluation détaillée des Function Calls

Test de la qualité des prédictions sur des exemples spécifiques.

In [None]:
import re
from collections import defaultdict

# Test entities per domain (simulates what would be placed in the prompt).
# IMPORTANT: the model is trained with ALL entities of ALL domains.
TEST_ENTITIES = {
    "light": ["light.salon", "light.cuisine", "light.chambre", "light.bureau"],
    "switch": ["switch.prise_salon", "switch.prise_cuisine"],
    "climate": ["climate.thermostat_salon", "climate.thermostat_bureau"],
    "scene": ["scene.cinema", "scene.nuit", "scene.romantique"],
    "cover": ["cover.volets_salon", "cover.volets_chambre", "cover.volets_cuisine"],
    "fan": ["fan.ventilateur_salon", "fan.ventilateur_chambre"],
    "lock": ["lock.porte_entree", "lock.porte_garage"],
    "person": ["person.francis", "person.noemie"],
}

# Domain order (identical to dataset_generator.py)
DOMAIN_ORDER = ['light', 'switch', 'climate', 'scene', 'cover', 'fan', 'lock', 'person']

def build_all_entities_context():
    """Build the prompt context listing ALL entities (as in production).

    Returns one line per domain of DOMAIN_ORDER that has at least one entity
    in TEST_ENTITIES, joined with newlines.
    """
    # (domain, comma-separated entities) pairs, skipping empty or unknown domains
    populated = [
        (domain, ", ".join(TEST_ENTITIES[domain]))
        for domain in DOMAIN_ORDER
        if TEST_ENTITIES.get(domain)
    ]
    return "\n".join(
        f"Entités {domain} disponibles: {entities_str}"
        for domain, entities_str in populated
    )

def generate_response_with_entities(query: str, max_tokens: int = 100):
    """Generate the model's reply to *query* with ALL entities in the prompt.

    One-step format: the user turn holds the query followed by the full entity
    context, and generation starts at the model turn.

    Args:
        query: The user request.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The text of the model turn (after the last ``<start_of_turn>model``
        and before the first ``<end_of_turn>`` that follows), stripped.
    """
    # Full context listing every entity
    entities_context = build_all_entities_context()

    text = (
        f"<start_of_turn>user\n{query}\n\n"
        f"{entities_context}<end_of_turn>\n"
        f"<start_of_turn>model\n"
    )
    encoded = tokenizer(text, return_tensors="pt").to(model.device)

    # Greedy decoding, no gradients
    generation_options = {
        "max_new_tokens": max_tokens,
        "do_sample": False,
        "pad_token_id": tokenizer.eos_token_id,
    }
    with torch.no_grad():
        generated = model.generate(**encoded, **generation_options)

    # Special tokens are kept because the turn markers are used for slicing
    decoded = tokenizer.decode(generated[0], skip_special_tokens=False)

    # rpartition/partition return the whole string when the marker is absent
    model_turn = decoded.rpartition("<start_of_turn>model")[2]
    return model_turn.partition("<end_of_turn>")[0].strip()

def parse_function_call_full(response: str) -> dict:
    """Parse a FunctionGemma function call.

    Returns ``{"name": <function name>, "params": {key: value, ...}}`` for the
    first ``call:name{...}`` found in *response*, or ``None`` when there is no
    call. Parameters are split on commas; ``<escape>`` markers are removed and
    keys and values are stripped.
    """
    found = re.search(r'call:([a-z_\.]+)\{([^}]*)\}', response)
    if found is None:
        return None

    func_name, params_str = found.groups()

    # Keep only "key:value" fragments, splitting on the first colon
    pairs = (
        fragment.split(':', 1)
        for fragment in params_str.split(',')
        if ':' in fragment
    )
    params = {
        key.strip(): value.replace('<escape>', '').strip()
        for key, value in pairs
    }

    return {"name": func_name, "params": params}

def evaluate_function_calls(test_cases):
    """Evaluate function-call accuracy (one-step format).

    Args:
        test_cases: Iterable of dicts with ``"query"``, ``"expected_function"``
            and, optionally, ``"expected_entity"``.

    Returns:
        A dict with:
          - ``"total"``: number of test cases,
          - ``"correct_function"``: cases where the predicted function matches,
          - ``"correct_entity"``: cases where the predicted entity matches,
          - ``"entity_total"``: cases that define an expected entity; use this
            as the denominator for entity accuracy,
          - ``"details"``: one record per case with the raw response, the
            prediction and the expectation.
    """
    results = {
        "total": 0,
        "correct_function": 0,
        "correct_entity": 0,
        "entity_total": 0,
        "details": []
    }

    for test in test_cases:
        query = test["query"]
        expected_func = test["expected_function"]
        expected_entity = test.get("expected_entity")

        response = generate_response_with_entities(query)
        parsed = parse_function_call_full(response)

        # An unparseable response counts as "no function, no entity"
        func_name = parsed["name"] if parsed else None
        params = parsed["params"] if parsed else {}
        entity_id = params.get("entity_id")

        results["total"] += 1

        # Check the function
        if func_name == expected_func:
            results["correct_function"] += 1

        # Check the entity, only for cases that define one, so cases without
        # an expected entity do not drag the entity accuracy down
        if expected_entity:
            results["entity_total"] += 1
            if entity_id == expected_entity:
                results["correct_entity"] += 1

        results["details"].append({
            "query": query,
            "response": response,
            "function": func_name,
            "entity": entity_id,
            "expected_func": expected_func,
            "expected_entity": expected_entity,
        })

    return results

# Test cases - one-step format (direct action with ALL entities in the prompt)
test_cases = [
    {"query": "Allume la lumière du salon", "expected_function": "light.turn_on", "expected_entity": "light.salon"},
    {"query": "Éteins la lumière de la cuisine", "expected_function": "light.turn_off", "expected_entity": "light.cuisine"},
    {"query": "Mets le chauffage à 21 degrés", "expected_function": "climate.set_temperature", "expected_entity": "climate.thermostat_salon"},
    {"query": "Ferme les volets du salon", "expected_function": "cover.close_cover", "expected_entity": "cover.volets_salon"},
    {"query": "Active la scène cinéma", "expected_function": "scene.turn_on", "expected_entity": "scene.cinema"},
    {"query": "Verrouille la porte d'entrée", "expected_function": "lock.lock", "expected_entity": "lock.porte_entree"},
    {"query": "Allume le ventilateur du salon", "expected_function": "fan.turn_on", "expected_entity": "fan.ventilateur_salon"},
]


def _percent(part, whole):
    """Return part/whole as a percentage, or 0.0 when whole is 0."""
    return part / whole * 100 if whole else 0.0


print("Évaluation des function calls (format one-step)...\n")
print("Contexte utilisé (toutes les entités):")
print("-" * 40)
print(build_all_entities_context())
print("-" * 40 + "\n")

results = evaluate_function_calls(test_cases)

# Entity accuracy is measured over the cases that define an expected entity
# when the evaluator reports that count; otherwise over all cases.
entity_total = results.get("entity_total", results["total"])

print(f"Résultats ({results['total']} tests):")
print(f"  Fonction correcte: {results['correct_function']}/{results['total']} ({_percent(results['correct_function'], results['total']):.1f}%)")
print(f"  Entité correcte: {results['correct_entity']}/{entity_total} ({_percent(results['correct_entity'], entity_total):.1f}%)")

print("\nDétails:")
for detail in results["details"]:
    func_ok = "✓" if detail["function"] == detail["expected_func"] else "✗"
    entity_ok = "✓" if detail["entity"] == detail["expected_entity"] else "✗"
    print(f"  {func_ok}{entity_ok} {detail['query'][:35]}...")
    print(f"       → {detail['function']}(entity_id={detail['entity']})")
    # Show the expected function only on a mismatch
    if detail["function"] != detail["expected_func"]:
        print(f"       ⚠ Attendu: {detail['expected_func']}")

## 11. Sauvegarde

In [None]:
# Persist the fine-tuned model and its tokenizer side by side
final_path = f"{CONFIG['output_dir']}/final"
trainer.save_model(final_path)
tokenizer.save_pretrained(final_path)

# Save the metrics next to the model
import json

final_eval_loss = eval_results['eval_loss']
total_tests = results['total']
metrics = {
    "train_loss": train_result.training_loss,
    "eval_loss": final_eval_loss,
    "perplexity": float(np.exp(final_eval_loss)),
    "config": CONFIG,
    # 0 when no function-call test was run
    "function_call_accuracy": results['correct_function'] / total_tests if total_tests > 0 else 0,
}

with open(f"{final_path}/training_metrics.json", 'w') as metrics_file:
    json.dump(metrics, metrics_file, indent=2)

print(f"✓ Modèle sauvegardé: {final_path}")
print(f"✓ Métriques sauvegardées: {final_path}/training_metrics.json")

In [None]:
import shutil
# Zip the saved model directory and pass the archive to Colab's download helper
shutil.make_archive("functiongemma-ha", 'zip', final_path)
files.download("functiongemma-ha.zip")
print("✓ Téléchargement du modèle...")

## 12. Test du modèle

**IMPORTANT:** Le format de test doit correspondre EXACTEMENT au format d'entraînement.

In [None]:
def test_model_onestep(query: str):
    """Run *query* through the model in one-step format.

    Thin wrapper around generate_response_with_entities (ALL entities are in
    the prompt); returns the model's reply.
    """
    return generate_response_with_entities(query)

# Queries covering several domains
test_queries = [
    "Allume la lumière du salon",
    "Éteins la lumière de la cuisine",
    "Mets le chauffage à 21 degrés",
    "Ferme les volets de la chambre",
    "Active la scène cinéma",
    "Verrouille la porte d'entrée",
    # With typos
    "alume la lumiere du salon",
    "etein la cuisine",
    # Quebec French usage
    "Ferme la lumière du bureau",
    "Ouvre les lumières de la chambre",
]

print("Tests du modèle fine-tuné (format one-step):\n")
print("Contexte: toutes les entités sont dans le prompt")
print("=" * 50 + "\n")

for query in test_queries:
    print(f"User: {query}")
    response = test_model_onestep(query)
    parsed = parse_function_call_full(response)

    if parsed is None:
        # No function call found: show the start of the raw reply instead
        print(f"Model: {response[:80]}...")
    else:
        func = parsed["name"]
        entity = parsed["params"].get("entity_id", "?")
        print(f"Model: {func}({entity})")
    print()

## 13. Tests supplémentaires

Tests avec différentes variations et cas limites.

In [None]:
# Robustness tests
banner = "=" * 50
print(banner)
print("Tests de robustesse")
print(banner + "\n")

robustness_tests = [
    # Phrasing variations
    "Peux-tu allumer la lumière du salon ?",
    "Je voudrais éteindre la cuisine",
    "Mets-moi 22 degrés stp",

    # Quebec French usage
    "Ferme la lumière du salon",
    "Ouvre les lumières de la chambre",

    # Typos
    "alume le slaon",
    "etein tou",

    # Natural phrasing
    "Il fait trop froid",
    "J'ai besoin de lumière dans le bureau",
    "Cache le soleil dans la chambre",

    # Scenes
    "Mets l'ambiance cinéma",
    "Mode nuit",

    # People
    "Où est Francis ?",
    "Noémie est à la maison ?",
]

for query in robustness_tests:
    response = generate_response_with_entities(query)
    parsed = parse_function_call_full(response)

    if parsed is None:
        # The reply contained no parseable function call
        print(f"✗ {query[:40]}")
        print(f"  → Pas de function call: {response[:50]}...")
    else:
        func = parsed["name"]
        entity = parsed["params"].get("entity_id", "?")
        print(f"✓ {query[:40]}")
        print(f"  → {func}({entity})")
    print()

## 14. Résumé et prochaines étapes

### Métriques finales

In [None]:
# Accuracy figures, guarded so that an empty test set prints 0.0% instead of
# raising ZeroDivisionError (the metrics export applies the same guard).
summary_total = results['total']
# Entity accuracy uses the count of cases with an expected entity when the
# evaluator reports it; otherwise it falls back to all cases.
summary_entity_total = results.get('entity_total', summary_total)
function_accuracy = results['correct_function'] / summary_total * 100 if summary_total else 0.0
entity_accuracy = results['correct_entity'] / summary_entity_total * 100 if summary_entity_total else 0.0

print("="*60)
print("RÉSUMÉ DE L'ENTRAÎNEMENT")
print("="*60)
print(f"\nModèle: {CONFIG['model_name']}")
print(f"LoRA rank: {CONFIG['lora_r']} (alpha: {CONFIG['lora_alpha']})")
print(f"\nFormat: One-step (entités dans le prompt)")
print(f"\nDataset:")
print(f"  Train: {len(tokenized_dataset['train'])} exemples")
print(f"  Validation: {len(tokenized_dataset['validation'])} exemples")
print(f"\nEntraînement:")
print(f"  Epochs: {CONFIG['num_epochs']}")
print(f"  Learning rate: {CONFIG['learning_rate']} ({CONFIG['lr_scheduler_type']})")
print(f"  Batch size: {effective_batch} (effective)")
print(f"\nRésultats:")
print(f"  Training loss: {train_result.training_loss:.4f}")
print(f"  Eval loss: {eval_results['eval_loss']:.4f}")
print(f"  Perplexity: {np.exp(eval_results['eval_loss']):.2f}")
print(f"  Function accuracy: {function_accuracy:.1f}%")
print(f"  Entity accuracy: {entity_accuracy:.1f}%")
print(f"\nFichiers:")
print(f"  Modèle: {final_path}/")
print(f"  Métriques: {final_path}/training_metrics.json")
print(f"  Logs: ./logs/")
print("="*60)