# TP 06 : Fine-tuning GPT-2 pour générer des Pokémon

**Objectif** : Fine-tuner GPT-2 français pour qu'il génère des descriptions de Pokémon

**Durée d'entraînement** : ~20 minutes

> **Pendant l'entraînement**, ouvrez le notebook **TP-06-Exploration.ipynb** pour explorer le dataset et comprendre les techniques utilisées.

---

## 1. Installation et imports

In [None]:
# Installation des d√©pendances (Colab)
!pip install -q transformers datasets accelerate

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, TrainerCallback
from datasets import load_dataset
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ok else "cpu")
print(f"Device: {device}")

# CPU training works but is very slow: explain how to enable a GPU on Colab.
if not _cuda_ok:
    for _msg in ("\n‚ö†Ô∏è  GPU non disponible ! L'entra√Ænement sera tr√®s lent.",
                 "   Sur Colab : Runtime > Change runtime type > GPU"):
        print(_msg)

---

## 2. Configuration

Ces paramètres sont pré-optimisés. Vous pouvez les modifier pour expérimenter.

In [None]:
# ------------------------------------------------------------------
#                          CONFIGURATION
# ------------------------------------------------------------------

# Model and training hyper-parameters
MODEL_SIZE = "base"        # "small": faster to train, "base": better results
NUM_EPOCHS = 10            # number of passes over the training set
MAX_LENGTH = 256           # sequences are truncated / padded to this many tokens
LEARNING_RATE = 5e-5       # optimizer learning rate

# Optimisation techniques (recommended: keep both enabled)
ADD_POKEMON_TOKENS = True   # give multi-token Pokémon names their own vocabulary entry
FREEZE_LOWER_LAYERS = True  # freeze the lower half of the transformer blocks

print("Configuration charg√©e !")

---

## 3. Chargement des données et du modèle

In [None]:
# Load the list of Pokémon names (used to filter articles and, later, as new
# vocabulary tokens).
print("Chargement des noms de Pok√©mon...")
pokemon_names_ds = load_dataset("chris-lmd/pokemon-names-fr")
POKEMON_NAMES = [item["name"] for item in pokemon_names_ds["train"]]
print(f"  {len(POKEMON_NAMES)} noms charg√©s")

# Load the full Pokepedia article dump.
print("\nChargement du dataset Pokepedia...")
dataset = load_dataset("chris-lmd/pokepedia-fr")
print(f"  {len(dataset['train']):,} articles au total")

# Keep only articles whose title is a Pokémon name (case-insensitive match)
# and whose content is non-empty.
# - `x.get('title') or ''` also covers a title that is present but None,
#   where the previous `x.get('title', '')` would call .lower() on None.
# - Articles with missing/empty content would crash the tokenizer (None) or
#   become all-padding examples that carry no training signal.
pokemon_names_set = {name.lower() for name in POKEMON_NAMES}
train_dataset = dataset['train'].filter(
    lambda x: bool(x.get('content'))
    and (x.get('title') or '').lower() in pokemon_names_set
)
print(f"  {len(train_dataset):,} articles Pok√©mon retenus")

In [None]:
# Load the French GPT-2 checkpoint matching MODEL_SIZE ("small" or "base").
model_name = f"asi/gpt-fr-cased-{MODEL_SIZE}"
print(f"Chargement de {model_name}...")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Reuse the end-of-sequence token for padding (GPT-2 tokenizers usually define
# no pad token) and record that id in the model config as well.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id

param_count = sum(p.numel() for p in model.parameters())
layer_count = model.config.n_layer
print(f"\n‚úÖ Mod√®le charg√© !")
print(f"   Param√®tres : {param_count:,}")
print(f"   Couches : {layer_count}")

---

## 4. Techniques d'optimisation

In [None]:
# =====================================================================
# TECHNIQUE 1: add the Pokémon names to the vocabulary
# =====================================================================

if ADD_POKEMON_TOKENS:
    print("Ajout des tokens Pok√©mon...")

    # Pick an existing single-token word whose embedding seeds the new rows.
    # The candidates are real UTF-8 words (a mis-encoded "Pokémon" could never
    # be a single token). reference_id stays None if none of them qualifies.
    reference_id = None
    for ref in ["Pokémon", "Pokemon", "animal"]:
        tokens = tokenizer.encode(ref, add_special_tokens=False)
        if len(tokens) == 1:
            reference_id = tokens[0]
            print(f"  Token de r√©f√©rence : '{ref}'")
            break

    # Only names the tokenizer currently splits into several pieces need a
    # dedicated token.
    new_tokens = [name for name in POKEMON_NAMES
                  if len(tokenizer.encode(name, add_special_tokens=False)) > 1]

    # Snapshot the seed embedding before resizing the matrix.
    with torch.no_grad():
        if reference_id is not None:
            ref_embedding = model.transformer.wte.weight[reference_id].clone()
        else:
            # No single-token candidate: fall back to the mean embedding
            # instead of failing with a NameError further down.
            print("  Aucune référence mono-token trouvée : moyenne des embeddings")
            ref_embedding = model.transformer.wte.weight.mean(dim=0)

    # Add the tokens, then resize the embedding matrix to the new vocabulary.
    num_added = tokenizer.add_tokens(new_tokens)
    model.resize_token_embeddings(len(tokenizer), mean_resizing=False)

    # After resizing to len(tokenizer), the added tokens occupy the last
    # num_added rows. Index from the tokenizer length, not from the old
    # matrix size: the two can differ when a checkpoint's embedding matrix
    # is larger than its tokenizer vocabulary.
    first_new_id = len(tokenizer) - num_added
    with torch.no_grad():
        for token_id in range(first_new_id, first_new_id + num_added):
            # Small noise so the new rows are not exact copies of each other.
            noise = torch.randn_like(ref_embedding) * 0.01
            model.transformer.wte.weight[token_id] = ref_embedding + noise

    print(f"  ‚úÖ {num_added} tokens ajout√©s")

In [None]:
# =====================================================================
# TECHNIQUE 2: freeze the lower layers
# =====================================================================

if FREEZE_LOWER_LAYERS:
    print("Freezing des couches basses...")

    num_to_freeze = model.config.n_layer // 2

    # Position embeddings are frozen outright.
    for param in model.transformer.wpe.parameters():
        param.requires_grad = False

    # Token embeddings. If TECHNIQUE 1 added Pokémon tokens, their rows were
    # initialised as near-copies of a single reference vector. Freezing the
    # whole matrix would keep every new name (nearly) indistinguishable, both
    # as input and, when the LM head is tied to these embeddings (the GPT-2
    # default), as a prediction target. Instead, keep the matrix trainable
    # and zero the gradient of the pre-existing rows so only the new rows learn.
    num_new_tokens = num_added if ADD_POKEMON_TOKENS else 0
    wte_weight = model.transformer.wte.weight
    if num_new_tokens > 0:
        first_new_id = len(tokenizer) - num_new_tokens

        def _only_new_token_grads(grad, start=first_new_id):
            # Zero the gradient of the original vocabulary rows [0, start).
            grad = grad.clone()
            grad[:start] = 0
            return grad

        wte_weight.register_hook(_only_new_token_grads)
        # NOTE(review): AdamW weight decay (weight_decay=0.01 in the training
        # arguments) may still shrink the masked rows very slightly even
        # though their gradient is zero.
    else:
        for param in model.transformer.wte.parameters():
            param.requires_grad = False

    # Freeze the first num_to_freeze transformer blocks.
    for block in model.transformer.h[:num_to_freeze]:
        for param in block.parameters():
            param.requires_grad = False

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if num_new_tokens > 0:
        # The original vocabulary rows have requires_grad=True but never get
        # a non-zero gradient, so do not count them as trainable.
        trainable -= first_new_id * wte_weight.shape[1]
    total = sum(p.numel() for p in model.parameters())
    print(f"  Couches fig√©es : {num_to_freeze}/{model.config.n_layer}")
    print(f"  ‚úÖ Param√®tres entra√Ænables : {trainable:,} ({100*trainable/total:.0f}%)")

In [None]:
# Move the model to the selected device. nn.Module.to() moves the parameters
# in place and returns the same module, so no re-assignment is needed.
model.to(device)
print(f"Mod√®le sur : {device}")

---

## 5. Préparation des données

In [None]:
# Tokenization
def tokenize_function(examples):
    """Tokenize a batch of articles into fixed-length sequences.

    Each ``content`` string is truncated to MAX_LENGTH tokens and padded up
    to exactly MAX_LENGTH. The pad token is the EOS token (set when the model
    was loaded).
    """
    return tokenizer(
        examples['content'],
        truncation=True,
        max_length=MAX_LENGTH,
        padding='max_length'
    )


def _labels_ignoring_padding(examples):
    """Build causal-LM labels from ``input_ids``, masking padding with -100.

    Positions where ``attention_mask`` is 0 (padding) get label -100, the
    index ignored by the model's cross-entropy loss. Without this mask, the
    loss would also train the model to predict EOS after EOS across the whole
    padded tail, because the pad token is the EOS token.
    """
    return {
        'labels': [
            [token if attended else -100 for token, attended in zip(ids, mask)]
            for ids, mask in zip(examples['input_ids'], examples['attention_mask'])
        ]
    }


print("Tokenization...")
tokenized_dataset = train_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=train_dataset.column_names
)

# Add the labels used for causal language modeling (padding excluded).
tokenized_dataset = tokenized_dataset.map(
    _labels_ignoring_padding,
    batched=True
)

print(f"‚úÖ {len(tokenized_dataset)} exemples pr√™ts")

---

## 6. Entraînement

**⏱️ Durée estimée : ~20 minutes**

> Pendant ce temps, ouvrez **TP-06-Exploration.ipynb** pour explorer le dataset et les techniques !

In [None]:
# Callback that keeps the training-loss history for plotting after training.
class LossCallback(TrainerCallback):
    """Collect ``(global_step, loss)`` pairs from the Trainer's log events.

    Only log records that contain a ``"loss"`` entry are kept.

    Attributes:
        losses: logged training-loss values, in logging order.
        steps: ``state.global_step`` at which each entry of ``losses`` was logged.
    """

    def __init__(self):
        self.steps = []
        self.losses = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Ignore empty log events and records without a training loss.
        if not logs or "loss" not in logs:
            return
        self.steps.append(state.global_step)
        self.losses.append(logs["loss"])

loss_callback = LossCallback()

# Trainer configuration. Both model sizes use an effective batch of 16
# sequences per device: the larger model trades per-device batch size for
# more gradient-accumulation steps.
if MODEL_SIZE == "base":
    batch_size, grad_accum = 2, 8
else:
    batch_size, grad_accum = 4, 4

training_args = TrainingArguments(
    output_dir="./gpt2-pokemon",
    overwrite_output_dir=True,
    report_to="none",
    # optimisation schedule
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    warmup_steps=100,
    weight_decay=0.01,
    # batching
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=grad_accum,
    # mixed precision only when a GPU is available
    fp16=torch.cuda.is_available(),
    # logging and checkpointing
    logging_steps=10,
    save_steps=500,
    save_total_limit=2,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    callbacks=[loss_callback],
)

print("‚úÖ Pr√™t pour l'entra√Ænement !")

In [None]:
# Summary box. Each row is padded with ljust() to the inner width so the right
# border lines up whatever the value lengths. The previous hand-tuned padding
# ({'':<20}, {'':<38}, ...) overflowed the box on most rows. Box-drawing and
# accented characters are written as proper Unicode so that each one takes a
# single column.
box_width = 50
summary_rows = [
    f"  Modèle : {model_name}",
    f"  Dataset : {len(train_dataset):,} articles Pokémon",
    f"  Epochs : {NUM_EPOCHS}",
    f"  Tokens ajoutés : {ADD_POKEMON_TOKENS}",
    f"  Couches figées : {FREEZE_LOWER_LAYERS}",
]
print("╔" + "═" * box_width + "╗")
print("║" + " RÉSUMÉ ".center(box_width) + "║")
print("╠" + "═" * box_width + "╣")
for row in summary_rows:
    print("║" + row.ljust(box_width) + "║")
print("╚" + "═" * box_width + "╝")

In [None]:
# Launch fine-tuning.
# trainer.train() runs the full schedule configured in training_args
# (NUM_EPOCHS epochs, checkpoints every save_steps into ./gpt2-pokemon) and
# returns only when training is over; LossCallback records the logged losses.
print("üöÄ Fine-tuning en cours...")
print("\n‚è±Ô∏è  Allez explorer TP-06-Exploration.ipynb pendant ce temps !\n")

trainer.train()

print("\n‚úÖ Fine-tuning termin√© !")

In [None]:
# Training-loss curve recorded by LossCallback (one point per log event,
# i.e. every logging_steps optimizer steps).
plt.figure(figsize=(10, 4))
plt.plot(loss_callback.steps, loss_callback.losses, 'b-', alpha=0.7)
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True, alpha=0.3)
plt.show()

# If training ran fewer steps than logging_steps, nothing was logged and
# indexing [0] / [-1] would raise IndexError.
if loss_callback.losses:
    first_loss = loss_callback.losses[0]
    last_loss = loss_callback.losses[-1]
    print(f"Loss initiale : {first_loss:.3f}")
    print(f"Loss finale   : {last_loss:.3f}")
    print(f"R√©duction     : {(1 - last_loss/first_loss)*100:.0f}%")
else:
    print("Aucune loss enregistrée (entraînement plus court que logging_steps).")

In [None]:
# Save the fine-tuned model and its tokenizer (including any tokens added to
# the vocabulary) in the same directory so they can be reloaded together.
final_dir = "./gpt2-pokemon-final"
trainer.save_model(final_dir)
tokenizer.save_pretrained(final_dir)
print("‚úÖ Mod√®le sauvegard√© dans ./gpt2-pokemon-final")

---

## 7. Test de génération

In [None]:
def generate(prompt, temperature=0.7, max_length=150):
    """Sample a continuation of ``prompt`` with the fine-tuned model.

    Args:
        prompt: text to continue.
        temperature: sampling temperature (lower = more conservative).
        max_length: maximum total length in tokens, prompt included.

    Returns:
        The decoded text (prompt + continuation), special tokens removed.
    """
    # Trainer.train() puts the model in training mode and does not switch it
    # back, so dropout would still be active while sampling. Switch to
    # inference mode first.
    model.eval()
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            temperature=temperature,
            top_k=50,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            repetition_penalty=1.2,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [None]:
# Test: Pikachu (known Pokémon, low temperature for a conservative completion).
# The prompt is model input, so it must use the correctly encoded "Pokémon".
prompt = "Pikachu est un Pokémon de type"
print(f"Prompt: {prompt}")
print(f"\nG√©n√©ration:")
print(generate(prompt, temperature=0.5))

In [None]:
# Test: Dracaufeu (known Pokémon, low temperature).
# The prompt is model input, so it must use the correctly encoded "Pokémon".
prompt = "Dracaufeu est un Pokémon de type"
print(f"Prompt: {prompt}")
print(f"\nG√©n√©ration:")
print(generate(prompt, temperature=0.5))

In [None]:
# Test: invented Pokémon (higher temperature for a more creative completion).
# The prompt is model input, so it must use the correctly encoded "Pokémon".
prompt = "Aqualis est un Pokémon de type Eau. Il"
print(f"Prompt: {prompt}")
print(f"\nG√©n√©ration:")
print(generate(prompt, temperature=0.7))

In [None]:
# Try your own prompts here!
# The prompt is model input, so accented words must be correctly encoded.
mon_prompt = "Flamador est un Pokémon légendaire de type Feu et Dragon. Ce"

print(f"Prompt: {mon_prompt}")
print(f"\nG√©n√©ration:")
print(generate(mon_prompt, temperature=0.7))

---

## 8. Pour aller plus loin

**Expérimentations suggérées :**
- Changer `temperature` : 0.3 (conservateur) → 1.0 (créatif)
- Essayer différents prompts
- Comparer avec `MODEL_SIZE = "small"`

**Voir aussi :**
- `TP-06-Exploration.ipynb` : Comprendre les techniques en détail