# 4. Global Demo & Model Comparison (All 3 Models)

Ce notebook est la **Comparaison Ultime**.

Il teste en parallèle :
1.  **Baseline** (Frozen BERT + LogReg)
2.  **Fine-Tuned** (BERT entraîné)
3.  **Few-Shot** (LLM Llama-3 via Groq) 🆕

**Note** : Assurez-vous d'avoir votre clé API Groq (variable d'environnement `GROQ_API_KEY`).

In [None]:
import os
import pandas as pd
import numpy as np
import torch
import joblib
import json
import random
import re
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
from sklearn.metrics import accuracy_score
from IPython.display import display, HTML
from openai import OpenAI

# Config: dataset locations, saved model artifacts, and the LLM name sent to the Groq endpoint.
VAL_PATH = "../data/twitter_val_clean.csv"
TRAIN_PATH = "../data/twitter_train_clean.csv"
BASELINE_PATH = "../models/baseline/baseline_model.joblib"
FINETUNED_PATH = "../models/bert_finetuned"
GROQ_MODEL = "llama-3.1-8b-instant"

# Integer class id -> sentiment name, and the reverse lookup (name -> id).
LABEL_MAP = {0: "Negative", 1: "Neutral", 2: "Positive", 3: "Irrelevant"}
INV_LABELS = {v: k for k, v in LABEL_MAP.items()}

# Load the validation dataset and report how many tweets it contains.
print("‚è≥ Chargement Dataset Validation...")
val_df = pd.read_csv(VAL_PATH)
print(f"‚úÖ {len(val_df)} tweets disponibles.")

## 1. Chargement des Modèles Classiques

In [None]:
# 1. Baseline: a saved classifier plus the stock "bert-base-uncased" encoder
#    used to embed tweets before classification.
print("‚è≥ Chargement Baseline...")
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code — only load artifacts from a trusted source.
bl_clf = joblib.load(BASELINE_PATH)
bl_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
bl_bert = AutoModel.from_pretrained("bert-base-uncased")
bl_bert.eval()  # switch the encoder to evaluation mode

# 2. Fine-tuned: tokenizer and sequence-classification model loaded from FINETUNED_PATH.
print("‚è≥ Chargement Fine-Tuned...")
ft_tokenizer = AutoTokenizer.from_pretrained(FINETUNED_PATH)
ft_model = AutoModelForSequenceClassification.from_pretrained(FINETUNED_PATH)
ft_model.eval()  # switch the classifier to evaluation mode

print("‚úÖ Mod√®les classiques charg√©s.")

## 2. Configuration Few-Shot (LLM)

In [None]:
# API key configuration
# os.environ["GROQ_API_KEY"] = "VOTRE_CLE_ICI" # Uncomment to test locally if needed

# Read the key once; an empty value is treated like a missing one.
groq_api_key = os.environ.get("GROQ_API_KEY")

if not groq_api_key:
    print("‚ö†Ô∏è  ATTENTION : Pas de cl√© API. Le Few-Shot ne marchera pas.")
    print("D√©finissez la variable d'environnement GROQ_API_KEY.")
    # Do not build a client without a Groq key: OpenAI(api_key=None) either
    # falls back to OPENAI_API_KEY (which would then be sent to the Groq
    # endpoint) or raises, aborting this cell before the few-shot examples
    # below are prepared. With client=None, LLM calls fail inside the
    # prediction function's own error handler instead.
    client = None
else:
    print("‚úÖ Cl√© API d√©tect√©e.")
    client = OpenAI(
        api_key=groq_api_key,
        base_url="https://api.groq.com/openai/v1",
    )

# Few-shot examples: one training tweet per label (K=1), drawn with a fixed seed.
train_df = pd.read_csv(TRAIN_PATH)
fewshot_examples = []
for lbl in sorted(train_df["label"].unique()):
    subset = train_df[train_df["label"] == lbl].sample(n=1, random_state=42)
    for _, row in subset.iterrows():
        fewshot_examples.append((row["clean_text"], LABEL_MAP[int(lbl)]))
# Shuffle with a dedicated seeded generator so the example order in the prompt
# is reproducible from run to run, like the sampling above.
random.Random(42).shuffle(fewshot_examples)
print("‚úÖ Exemples Few-Shot charg√©s (K=1).")

## 3. Fonctions de Prédiction

In [None]:
def get_baseline_pred(texts):
    """Classify tweets with the baseline (frozen encoder + saved classifier).

    Args:
        texts: Text(s) handed directly to ``bl_tokenizer``; callers in this
            notebook pass a list of strings.

    Returns:
        Whatever ``bl_clf.predict`` returns for the batch of embeddings
        (one prediction per input text).
    """
    encoded = bl_tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    )
    # No gradients are needed: the encoder is only used as a feature extractor.
    with torch.no_grad():
        encoder_out = bl_bert(**encoded)
    # Keep the first-token vector of the last hidden state as the embedding.
    first_token_vectors = encoder_out.last_hidden_state[:, 0, :]
    return bl_clf.predict(first_token_vectors.numpy())

def get_finetuned_pred(texts):
    """Predict integer class ids for tweets with the fine-tuned classifier.

    Args:
        texts: Text(s) handed directly to ``ft_tokenizer``; callers in this
            notebook pass a list of strings.

    Returns:
        NumPy array with, for each input, the index of its most probable class.
    """
    encoded = ft_tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    )
    # Inference only, so skip gradient tracking for the forward pass.
    with torch.no_grad():
        model_out = ft_model(**encoded)
    class_probs = torch.softmax(model_out.logits, dim=1)
    return class_probs.argmax(dim=1).numpy()

# Lower-cased label name -> class id; mirrors the notebook's label mapping
# (0 Negative, 1 Neutral, 2 Positive, 3 Irrelevant).
_FEWSHOT_LABEL_IDS = {"negative": 0, "neutral": 1, "positive": 2, "irrelevant": 3}
# Class id returned when the LLM call fails or its reply cannot be interpreted.
_FEWSHOT_DEFAULT_ID = 1


def _parse_llm_label(raw):
    """Turn the raw LLM reply into a class id, defaulting to Neutral (1)."""
    if not raw:
        return _FEWSHOT_DEFAULT_ID

    # Preferred path: the prompt asks for JSON, so look for {"label": "..."}.
    json_match = re.search(r"\{.*?\}", raw, flags=re.DOTALL)
    if json_match:
        try:
            payload = json.loads(json_match.group(0))
        except ValueError:
            payload = None
        if isinstance(payload, dict):
            label = str(payload.get("label", "")).strip().lower()
            if label in _FEWSHOT_LABEL_IDS:
                return _FEWSHOT_LABEL_IDS[label]

    # Fallback: case-insensitive search; the label mentioned first wins, so
    # "Positive, not Negative" is not misread as Negative.
    lowered = raw.lower()
    hits = [
        (lowered.find(name), class_id)
        for name, class_id in _FEWSHOT_LABEL_IDS.items()
        if name in lowered
    ]
    return min(hits)[1] if hits else _FEWSHOT_DEFAULT_ID


def get_fewshot_pred_single(text):
    """Classify one tweet with the few-shot LLM prompt.

    Builds a prompt from ``fewshot_examples`` followed by the tweet, sends it
    to ``GROQ_MODEL`` through ``client``, and maps the reply to a class id.

    Args:
        text: The tweet to classify.

    Returns:
        Integer class id (0 Negative, 1 Neutral, 2 Positive, 3 Irrelevant).
        Returns 1 (Neutral) when the API call fails or the reply contains no
        recognisable label.
    """
    # Assemble the prompt: instruction, labelled examples, then the query.
    parts = [
        "You are a strict tweet sentiment classifier (Negative, Neutral, Positive, Irrelevant).\n",
        "Here are examples:\n",
    ]
    parts.extend(
        f'Tweet: "{ex_txt}"\nLabel: {ex_lbl}\n\n' for ex_txt, ex_lbl in fewshot_examples
    )
    parts.append(f'Now classify:\nTweet: "{text}"\nReturn strictly JSON: {{"label": "..."}}')
    prompt = "".join(parts)

    # Best-effort boundary: any client/network/response-shape failure is
    # reported and mapped to the default label instead of aborting the run.
    try:
        resp = client.chat.completions.create(
            model=GROQ_MODEL,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=30, temperature=0.0
        )
        raw = resp.choices[0].message.content
    except Exception as e:
        print(f"LLM Error: {e}")
        return _FEWSHOT_DEFAULT_ID

    return _parse_llm_label(raw)

## 4. Test Comparatif sur N Tweets
Attention : Le Few-Shot est plus lent (appels API).

In [None]:
N_TEST = 20  # Kept small so the (slow) LLM calls do not take too long

# Draw N_TEST validation tweets; random_state=None means a different sample on every run.
sample = val_df.sample(n=N_TEST, random_state=None).reset_index(drop=True)
texts = sample["clean_text"].astype(str).tolist()
labels = sample["label"].values

print(f"üöÄ Test sur {N_TEST} tweets...")

# The two local models predict the whole batch at once.
p_bl = get_baseline_pred(texts)
p_ft = get_finetuned_pred(texts)

# The few-shot model is queried one tweet at a time (one API call per tweet).
print("   ... appel LLM (Few-Shot)...")
p_fs = [get_fewshot_pred_single(t) for t in texts]

print("‚úÖ Termin√©.")

# Compute and display the accuracy of each model on the sample
acc_bl = accuracy_score(labels, p_bl)
acc_ft = accuracy_score(labels, p_ft)
acc_fs = accuracy_score(labels, p_fs)

print(f"\nüìä Scores d'Accuracy sur {N_TEST} tweets al√©atoires :")
print(f"--------------------------------------")
print(f"ü§ñ Baseline   : {acc_bl*100:.2f}%")
print(f"üöÄ Fine-Tuned : {acc_ft*100:.2f}%")
print(f"üß† Few-Shot   : {acc_fs*100:.2f}%")
print(f"--------------------------------------")

# Results table: one row per tweet, with the true label and each model's
# prediction converted from class id to label name.
results = pd.DataFrame({
    "Tweet": texts,
    "True": [LABEL_MAP[l] for l in labels],
    "Baseline": [LABEL_MAP[l] for l in p_bl],
    "FineTuned": [LABEL_MAP[l] for l in p_ft],
    "FewShot": [LABEL_MAP[l] for l in p_fs]
})

def color_cells(row):
    """Build the per-cell CSS for one row of the results table.

    Each model column ('Baseline', 'FineTuned', 'FewShot') is coloured green
    when its value equals the row's 'True' value and red otherwise. The
    styles are written at positions 2, 3 and 4, so the row is expected to be
    ordered Tweet, True, Baseline, FineTuned, FewShot.

    Args:
        row: Mapping-like row supporting ``len`` and lookup by column name.

    Returns:
        List of CSS strings, one per cell; cells that are not model
        predictions get an empty string.
    """
    match_css = 'background-color: #d4edda; color: #155724'  # correct prediction
    miss_css = 'background-color: #f8d7da; color: #721c24'  # wrong prediction

    expected = row['True']
    css = [''] * len(row)
    position = 2  # index of the first model column
    for model_col in ('Baseline', 'FineTuned', 'FewShot'):
        css[position] = match_css if row[model_col] == expected else miss_css
        position += 1
    return css

# Render the results table, applying color_cells to each row (axis=1).
display(results.style.apply(color_cells, axis=1))

## 5. Test Manuel (3 Modèles)

In [None]:
def clean_text_manual(text):
    """Normalise a raw tweet to lowercase ASCII letters separated by single spaces.

    Every character that is neither a-z nor whitespace (digits, punctuation,
    accented letters, emoji, ...) is replaced by a space, then runs of
    whitespace are collapsed and the ends trimmed.

    Args:
        text: Raw tweet text.

    Returns:
        The cleaned string; empty if no a-z letter remains.
    """
    letters_and_spaces = re.sub(r"[^a-z\s]", " ", text.lower())
    # Splitting on whitespace and re-joining collapses runs and trims both ends.
    return " ".join(letters_and_spaces.split())

# Ask for a tweet interactively; an empty answer skips the whole test.
user_text = input("Tweet √† tester (anglais) : ")

if user_text:
    # Clean the input with clean_text_manual before passing it to the models.
    clean = clean_text_manual(user_text)
    print(f"Texte : {clean}")
    
    # The batch predictors take a list, so wrap the text and keep the first result.
    val_bl = get_baseline_pred([clean])[0]
    val_ft = get_finetuned_pred([clean])[0]
    val_fs = get_fewshot_pred_single(clean)
    
    # Small HTML card listing the label name predicted by each model.
    html = f"""
    <div style='padding:10px; border:1px solid #ccc; border-radius:8px;'>
        <h3>üîÆ Pr√©dictions :</h3>
        <ul>
            <li>ü§ñ <b>Baseline :</b> {LABEL_MAP[val_bl]}</li>
            <li>üöÄ <b>Fine-Tuned :</b> {LABEL_MAP[val_ft]}</li>
            <li>üß† <b>Few-Shot (LLM) :</b> {LABEL_MAP[val_fs]}</li>
        </ul>
    </div>
    """
    display(HTML(html))