In [None]:
import pandas as pd
from datasets import Dataset, DatasetDict, load_dataset
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
import wandb
import torch
import torch.nn.functional as F
from torch import nn
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score, precision_recall_fscore_support
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
import zipfile
import os
import re
from scipy.special import softmax
import math
import gc

# --- General configuration ---
BASE_MODEL_NAME = "xlm-roberta-base"
TRAIN_CSV = "data/train_data_SMM4H_2025_Task_1.csv"
DEV_CSV = "data/dev_data_SMM4H_2025_Task_1.csv"
# WARNING: max_length=256 can be VRAM-hungry (the author targets an 8 GB GPU).
# Reduce to 192 or 128 on out-of-memory errors.
MAX_LENGTH = 256

# --- Phase 1 configuration: SimCSE pre-training ---
DO_SIMCSE_PRETRAINING = True
SIMCSE_OUTPUT_DIR = "results_simcse_xlmr_base"
SIMCSE_NUM_EPOCHS = 1
# WARNING: the SimCSE loss runs two forward passes per batch, so memory use is
# roughly doubled for a given batch size. Reduce further on out-of-memory errors.
SIMCSE_BATCH_SIZE = 8
SIMCSE_LEARNING_RATE = 3e-5
SIMCSE_TEMP = 0.05
SIMCSE_LOGGING_STEPS = 50
# Pooling option for SimCSE sentence embeddings
SIMCSE_USE_MEAN_POOLING = True # Set to False to use CLS-token pooling instead
# Learning-rate warmup ratio
SIMCSE_WARMUP_RATIO = 0.06 # ~6% of the total steps

# --- Phase 2 configuration: classification fine-tuning ---
CLASSIFICATION_OUTPUT_DIR = "results_classifier_finetuned_on_simcse"
CLASSIFICATION_NUM_EPOCHS = 8
# Effective batch size = CLASSIFICATION_BATCH_SIZE * CLASSIFICATION_GRAD_ACCUM_STEPS
# (4 * 8 = 32 with the values below). Lower the per-device batch size on OOM.
CLASSIFICATION_BATCH_SIZE = 4
CLASSIFICATION_GRAD_ACCUM_STEPS = 8
CLASSIFICATION_LEARNING_RATE = 2e-5
CLASSIFICATION_EARLY_STOPPING_PATIENCE = 3
CLASSIFICATION_LOGGING_STEPS = 50
# Learning-rate warmup ratio
CLASSIFICATION_WARMUP_RATIO = 0.06 # ~6% of the total steps

# --- Global WandB project name (shared by both phases) ---
WANDB_PROJECT_NAME = "ade-classification-simcse-finetune"

# --- Text cleaning function ---
def clean_text(text):
    """Normalize a raw post into a cleaned single-line string.

    Non-string input (e.g. NaN or None) yields an empty string. Otherwise
    ``@handle`` mentions and the dataset's placeholder tags are mapped to
    bracketed tokens, then every whitespace run is collapsed to one space
    and the result is stripped.
    """
    if not isinstance(text, str):
        return ""

    # Regex pass first, so '@name' handles are rewritten before the literal tags.
    cleaned = re.sub(r'@[\w_]+', '[USER_MENTION]', text)

    # Literal placeholder -> token, applied in this exact order.
    literal_replacements = (
        ('<user>', '[USER_MENTION]'),
        ('<tuser>', '[USER_MENTION]'),
        ('<url>', '[URL]'),
        ('<email>', '[EMAIL]'),
        ('HTTPURL________________', '[URL]'),
    )
    for placeholder, token in literal_replacements:
        cleaned = cleaned.replace(placeholder, token)

    return re.sub(r'\s+', ' ', cleaned).strip()

# --- Mean pooling helper (optional pooling mode for SimCSE) ---
def mean_pooling(hidden_state, attention_mask):
    """Average token vectors along dim 1, counting only positions where the mask is non-zero.

    ``attention_mask`` is unsqueezed on its last axis and expanded to the
    size of ``hidden_state``; masked positions contribute zero to the sum.
    The denominator is clamped to 1e-9 so a fully masked row divides safely.
    """
    token_weights = attention_mask.unsqueeze(-1).expand(hidden_state.size()).float()
    masked_total = (hidden_state * token_weights).sum(dim=1)
    token_counts = token_weights.sum(dim=1).clamp(min=1e-9)  # avoid division by zero
    return masked_total / token_counts

# --- 1. Initial data loading and preparation ---
print("--- Initial Data Loading and Cleaning ---")
try:
    # Drop rows without text, then rebuild a contiguous 0..n-1 index so that
    # later positional data (e.g. model predictions) lines up with the rows
    # even when dropna() removed something.
    train_df_full = pd.read_csv(TRAIN_CSV).dropna(subset=['text']).reset_index(drop=True)
    dev_df_full = pd.read_csv(DEV_CSV).dropna(subset=['text']).reset_index(drop=True)
    print(f"Loaded {len(train_df_full)} train and {len(dev_df_full)} dev examples.")
except FileNotFoundError as e:
    print(f"Error loading CSV files: {e}")
    # exit() is an interactive helper injected by `site` and exits with
    # status 0; raise SystemExit with a non-zero code so the failure is visible.
    raise SystemExit(1) from e

# Normalize mentions/placeholders/whitespace in both splits.
train_df_full['text'] = train_df_full['text'].apply(clean_text)
dev_df_full['text'] = dev_df_full['text'].apply(clean_text)
print("Text cleaning complete.")

# --- PHASE 1: SIMCSE PRE-TRAINING ---

# Custom Trainer implementing the unsupervised SimCSE contrastive loss
class SimCSETrainer(Trainer):
    """Trainer whose loss contrasts two dropout-perturbed encodings of each input.

    Args:
        temperature: Divisor applied to the cosine-similarity logits.
        use_mean_pooling: If True, sentence embeddings come from
            ``mean_pooling`` over ``last_hidden_state``; otherwise the
            first-token vector ``last_hidden_state[:, 0]`` is used.
        *args, **kwargs: Forwarded unchanged to ``Trainer.__init__``.
    """

    def __init__(self, *args, temperature=0.05, use_mean_pooling=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.temperature = temperature
        self.use_mean_pooling = use_mean_pooling

    def _pool_sentence_embedding(self, last_hidden_state, attention_mask):
        """Reduce token states to one vector per sequence (mean or first token)."""
        if self.use_mean_pooling:
            return mean_pooling(last_hidden_state, attention_mask)
        return last_hidden_state[:, 0]

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the InfoNCE loss over the two views of the batch.

        Returns the loss alone, or ``(loss, {"embeddings1", "embeddings2"})``
        with the un-normalized pooled embeddings when ``return_outputs`` is True.
        """
        input_ids = inputs.get("input_ids")
        attention_mask = inputs.get("attention_mask")

        # Two forward passes of the same batch; in train mode dropout makes the
        # two encodings differ. Only last_hidden_state is consumed, so the
        # per-layer hidden states are not requested (saves activation memory).
        outputs1 = model(input_ids=input_ids, attention_mask=attention_mask)
        outputs2 = model(input_ids=input_ids, attention_mask=attention_mask)

        pooler_output1 = self._pool_sentence_embedding(outputs1.last_hidden_state, attention_mask)
        pooler_output2 = self._pool_sentence_embedding(outputs2.last_hidden_state, attention_mask)

        # Stack both views and L2-normalize so the dot product is a cosine similarity.
        embeddings = torch.cat([pooler_output1, pooler_output2], dim=0)
        embeddings = F.normalize(embeddings, p=2, dim=1)

        # (2*batch, 2*batch) similarity logits, scaled by the temperature.
        cos_sim = torch.mm(embeddings, embeddings.t()) / self.temperature

        # Exclude self-similarity. The fill value is the most negative number
        # representable in the logits' dtype, so it stays valid if the matrix
        # is half precision (a hard-coded -9e15 does not fit in float16).
        batch_size = pooler_output1.size(0)
        mask_diag = torch.eye(2 * batch_size, device=embeddings.device, dtype=torch.bool)
        cos_sim = cos_sim.masked_fill(mask_diag, torch.finfo(cos_sim.dtype).min)

        # Row i of the first view has its positive at column i + batch_size;
        # row i + batch_size of the second view has its positive at column i.
        indices = torch.arange(batch_size, device=embeddings.device)
        targets = torch.cat([indices + batch_size, indices], dim=0)

        # Mean over all 2*batch rows; both halves have equal size, so this
        # equals the average of the two per-view mean losses.
        loss = F.cross_entropy(cos_sim, targets)

        return (loss, {"embeddings1": pooler_output1, "embeddings2": pooler_output2}) if return_outputs else loss

# Checkpoint used by Phase 2; stays the base model unless Phase 1 runs below.
model_load_path = BASE_MODEL_NAME # Default path if SimCSE is skipped

if DO_SIMCSE_PRETRAINING:
    print("\n--- Phase 1: Starting SimCSE Pre-training ---")
    run_name_simcse = f"simcse_{'meanpool_' if SIMCSE_USE_MEAN_POOLING else 'cls_'}{BASE_MODEL_NAME}"
    # WandB is best-effort: any failure here only disables logging for this phase.
    try:
        wandb.init(project=WANDB_PROJECT_NAME, name=run_name_simcse, reinit=True)
        wandb.config.update({ # Log the SimCSE configuration
            "simcse_model": BASE_MODEL_NAME,
            "simcse_epochs": SIMCSE_NUM_EPOCHS,
            "simcse_batch_size": SIMCSE_BATCH_SIZE,
            "simcse_lr": SIMCSE_LEARNING_RATE,
            "simcse_temp": SIMCSE_TEMP,
            "simcse_pooling": "mean" if SIMCSE_USE_MEAN_POOLING else "cls",
            "simcse_warmup_ratio": SIMCSE_WARMUP_RATIO,
            "max_length": MAX_LENGTH
        })
    except Exception as e:
        print(f"WandB initialization failed for SimCSE phase: {e}")
        print("Proceeding without WandB logging for this phase.")


    # Build the SimCSE dataset from the text column of train + dev (labels are not used)
    all_texts_df = pd.concat([train_df_full[['text']], dev_df_full[['text']]], ignore_index=True)
    simcse_dataset = Dataset.from_pandas(all_texts_df)
    print(f"Created SimCSE dataset with {len(simcse_dataset)} examples.")

    # Tokenizer (loaded here for Phase 1)
    print(f"Loading tokenizer {BASE_MODEL_NAME} for SimCSE phase...")
    tokenizer_simcse = AutoTokenizer.from_pretrained(BASE_MODEL_NAME) # Phase-specific variable; deleted at the end of this phase

    def tokenize_simcse(examples):
        """Tokenize a batch of texts, padded/truncated to MAX_LENGTH."""
        # Uses tokenizer_simcse defined in the enclosing scope
        return tokenizer_simcse(examples["text"], padding="max_length", truncation=True, max_length=MAX_LENGTH)

    tokenized_simcse_dataset = simcse_dataset.map(tokenize_simcse, batched=True, remove_columns=["text"], num_proc=1) # Use more processes if available
    tokenized_simcse_dataset.set_format("torch")
    print("SimCSE dataset tokenized.")

    # Load the bare encoder (AutoModel, no task head)
    simcse_model = AutoModel.from_pretrained(BASE_MODEL_NAME)
    print("SimCSE base model loaded.")

    # SimCSE training arguments (with warmup)
    simcse_training_args = TrainingArguments(
        output_dir=SIMCSE_OUTPUT_DIR,
        num_train_epochs=SIMCSE_NUM_EPOCHS,
        per_device_train_batch_size=SIMCSE_BATCH_SIZE,
        learning_rate=SIMCSE_LEARNING_RATE,
        weight_decay=0.01,
        logging_dir=f'{SIMCSE_OUTPUT_DIR}/logs',
        logging_steps=SIMCSE_LOGGING_STEPS,
        save_strategy="epoch",
        report_to="wandb" if wandb.run is not None else "none", # Report only if a WandB run is active
        fp16=torch.cuda.is_available(), # Mixed precision whenever CUDA is present (author: required on 8 GB VRAM)
        warmup_ratio=SIMCSE_WARMUP_RATIO, # Learning-rate warmup
    )

    # Instantiate the SimCSE trainer (with the pooling option)
    simcse_trainer = SimCSETrainer(
        model=simcse_model,
        args=simcse_training_args,
        train_dataset=tokenized_simcse_dataset,
        tokenizer=tokenizer_simcse, # Pass the phase-specific tokenizer
        temperature=SIMCSE_TEMP,
        use_mean_pooling=SIMCSE_USE_MEAN_POOLING # Select mean vs. CLS pooling
    )

    # Run training
    print("Starting SimCSE training...")
    simcse_trainer.train()
    print("SimCSE training finished.")

    # Save model and tokenizer side by side so Phase 2 can load both from one directory
    simcse_trainer.save_model(SIMCSE_OUTPUT_DIR)
    tokenizer_simcse.save_pretrained(SIMCSE_OUTPUT_DIR) # Save the tokenizer that was used
    print(f"SimCSE pre-trained model and tokenizer saved to {SIMCSE_OUTPUT_DIR}")

    # Clean up: drop references and release cached GPU memory before Phase 2
    model_load_path = SIMCSE_OUTPUT_DIR # Point Phase 2 at the SimCSE checkpoint
    del simcse_model, simcse_trainer, tokenized_simcse_dataset, simcse_dataset, all_texts_df, tokenizer_simcse
    gc.collect()
    torch.cuda.empty_cache()
    if wandb.run is not None: wandb.finish()
    print("--- Phase 1: SimCSE Pre-training Complete ---")

else:
    print("\n--- Phase 1: Skipping SimCSE Pre-training ---")
    # model_load_path stays BASE_MODEL_NAME (set above)


# --- PHASE 2: CLASSIFICATION FINE-TUNING ---
print("\n--- Phase 2: Starting Classification Fine-tuning ---")
run_name_classify = f"classify_ft_on_{'simcse' if DO_SIMCSE_PRETRAINING else 'base'}_{BASE_MODEL_NAME}"
# WandB is best-effort: any failure here only disables logging for this phase.
try:
    wandb.init(project=WANDB_PROJECT_NAME, name=run_name_classify, reinit=True)
    wandb.config.update({ # Log the classification configuration
        "base_model_for_ft": model_load_path,
        "classify_epochs": CLASSIFICATION_NUM_EPOCHS,
        "classify_batch_size": CLASSIFICATION_BATCH_SIZE,
        "classify_grad_accum": CLASSIFICATION_GRAD_ACCUM_STEPS,
        "classify_effective_batch": CLASSIFICATION_BATCH_SIZE * CLASSIFICATION_GRAD_ACCUM_STEPS,
        "classify_lr": CLASSIFICATION_LEARNING_RATE,
        "classify_warmup_ratio": CLASSIFICATION_WARMUP_RATIO,
        "classify_early_stopping": CLASSIFICATION_EARLY_STOPPING_PATIENCE,
        "max_length": MAX_LENGTH
    })
except Exception as e:
        print(f"WandB initialization failed for Classification phase: {e}")
        print("Proceeding without WandB logging for this phase.")


# Build the classification datasets from the full (cleaned) dataframes
train_dataset_cls = Dataset.from_pandas(train_df_full)
dev_dataset_cls = Dataset.from_pandas(dev_df_full)
dataset_dict_cls = DatasetDict({'train': train_dataset_cls, 'validation': dev_dataset_cls})
print("Classification datasets created.")

# --- Always load the tokenizer for Phase 2 ---
# Load the tokenizer matching the checkpoint about to be fine-tuned
# (the SimCSE output directory if Phase 1 ran, otherwise the base model).
print(f"Loading tokenizer for classification phase from: {model_load_path}")
tokenizer = AutoTokenizer.from_pretrained(model_load_path)
# ----------------------------------------------------------------------

def tokenize_classification(examples):
    """Tokenize a batch of texts, padded/truncated to MAX_LENGTH."""
    # Uses the module-level 'tokenizer' defined just above
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=MAX_LENGTH)

# NOTE(review): remove_columns hard-codes the CSV schema; map() is expected to
# fail if any of these columns is absent from the input file -- confirm when
# swapping in a different training CSV.
tokenized_datasets_cls = dataset_dict_cls.map(tokenize_classification, batched=True, remove_columns=["text", "id", "file_name", "origin", "language", "split", "type"], num_proc=1)
tokenized_datasets_cls.set_format("torch")
# The custom trainer reads the targets from the "labels" key.
tokenized_datasets_cls = tokenized_datasets_cls.rename_column("label", "labels")
print("Classification datasets tokenized.")

# Compute class weights from the training label distribution
print("Computing class weights for classification...")
labels_train_cls = train_df_full['label'].values
class_weights_tensor_cls = None  # stays None when only one class is present
unique_labels_cls = np.unique(labels_train_cls)
num_distinct_labels = len(unique_labels_cls)
print(f"Detected {num_distinct_labels} distinct labels in training data: {unique_labels_cls}")

if num_distinct_labels > 1:
    class_weights_cls = compute_class_weight(class_weight='balanced', classes=unique_labels_cls, y=labels_train_cls)
    # Map each label to its weight, then read the weights back in the order of
    # unique_labels_cls.
    # NOTE(review): because the lookup iterates unique_labels_cls itself, this
    # round trip reproduces class_weights_cls unchanged. The weight at position
    # k therefore belongs to the k-th unique label, which matches logit index k
    # only if the labels are exactly 0..K-1 -- confirm for new datasets.
    ordered_weights_dict = {label: weight for label, weight in zip(unique_labels_cls, class_weights_cls)}
    # The tensor length equals the number of distinct labels
    ordered_weights_cls = np.array([ordered_weights_dict.get(i, 0) for i in unique_labels_cls]) # Weights for the labels that exist

    class_weights_tensor_cls = torch.tensor(ordered_weights_cls, dtype=torch.float).to("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Class Weights (for classes {unique_labels_cls}): {class_weights_tensor_cls.cpu().numpy()}")
    if wandb.run: wandb.config.update({"class_weights": class_weights_tensor_cls.cpu().numpy().tolist()})
else:
    print("Warning: Only one class found in training data. Cannot compute class weights.")


# Custom classification Trainer with optional per-class loss weights
class WeightedClassificationTrainer(Trainer):
    """Trainer that swaps in a (optionally class-weighted) cross-entropy loss.

    Args:
        class_weights: Optional tensor of per-class weights; moved once to
            ``self.args.device`` at construction. ``None`` means unweighted.
        *args, **kwargs: Forwarded unchanged to ``Trainer.__init__``.
    """

    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        if class_weights is None:
            self.class_weights = None
        else:
            # Move the weights to the training device a single time.
            self.class_weights = class_weights.to(self.args.device)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the cross-entropy loss, or ``(loss, outputs)`` if requested."""
        targets = inputs.get("labels")
        outputs = model(**inputs)
        scores = outputs.get("logits")

        # weight=None gives the plain, unweighted cross-entropy.
        num_labels = self.model.config.num_labels
        loss = torch.nn.functional.cross_entropy(
            scores.view(-1, num_labels),
            targets.view(-1),
            weight=self.class_weights,
        )

        if return_outputs:
            return (loss, outputs)
        return loss

# Load the model for classification from the Phase 1 checkpoint (or the base model)
print(f"Loading model for classification from: {model_load_path}")
classification_model = AutoModelForSequenceClassification.from_pretrained(
    model_load_path,
    num_labels=num_distinct_labels, # Use the number of labels detected in the training data
    ignore_mismatched_sizes=True # Author's note: crucial when loading from the headless AutoModel (SimCSE) checkpoint
)
print("Classification model loaded.")

# Metrics callback handed to the classification Trainer
def compute_metrics_cls(pred):
    """Compute evaluation metrics from a prediction object.

    Reads ``pred.label_ids`` as the gold labels and takes the argmax over the
    last axis of ``pred.predictions`` as the predicted labels.

    Returns:
        For two distinct training labels: accuracy plus F1/precision/recall of
        label 1 and F1 of label 0. Otherwise: accuracy plus macro
        precision/recall/F1 and weighted F1 over the detected labels.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    metric_labels = unique_labels_cls  # labels detected in the training data
    accuracy = accuracy_score(y_true, y_pred)

    if num_distinct_labels == 2:
        # Per-class scores for labels 0 and 1; label 1 is the positive class.
        prec, rec, f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average=None, labels=[0, 1], zero_division=0)

        def score_at(values, position):
            # Fall back to 0 when the requested position is missing.
            return values[position] if len(values) > position else 0

        return {
            'accuracy': accuracy,
            'f1_pos': score_at(f1, 1),
            'precision_pos': score_at(prec, 1),
            'recall_pos': score_at(rec, 1),
            'f1_neg': score_at(f1, 0),
        }

    # Multi-class: macro and weighted averages over the detected labels.
    macro_scores = precision_recall_fscore_support(
        y_true, y_pred, average='macro', labels=metric_labels, zero_division=0)
    weighted_scores = precision_recall_fscore_support(
        y_true, y_pred, average='weighted', labels=metric_labels, zero_division=0)
    # Per-class F1 could be added here with average=None if ever needed.
    return {
        'accuracy': accuracy,
        'f1_macro': macro_scores[2],
        'precision_macro': macro_scores[0],
        'recall_macro': macro_scores[1],
        'f1_weighted': weighted_scores[2],
    }

# Classification training arguments (with warmup)
classification_training_args = TrainingArguments(
    output_dir=CLASSIFICATION_OUTPUT_DIR,
    num_train_epochs=CLASSIFICATION_NUM_EPOCHS,
    per_device_train_batch_size=CLASSIFICATION_BATCH_SIZE,
    per_device_eval_batch_size=CLASSIFICATION_BATCH_SIZE * 2,
    gradient_accumulation_steps=CLASSIFICATION_GRAD_ACCUM_STEPS,
    learning_rate=CLASSIFICATION_LEARNING_RATE,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # Metric used to pick the best checkpoint; the key must match one returned
    # by compute_metrics_cls for the detected number of labels.
    metric_for_best_model="f1_pos" if num_distinct_labels == 2 else "f1_macro",
    greater_is_better=True,
    logging_dir=f'{CLASSIFICATION_OUTPUT_DIR}/logs',
    logging_steps=CLASSIFICATION_LOGGING_STEPS,
    report_to="wandb" if wandb.run is not None else "none", # Report only if a WandB run is active
    fp16=torch.cuda.is_available(), # Mixed precision whenever CUDA is present (author: required on 8 GB VRAM)
    warmup_ratio=CLASSIFICATION_WARMUP_RATIO,
    save_total_limit=2,
)

# Instantiate the classification trainer (class weights are passed here)
classification_trainer = WeightedClassificationTrainer(
    model=classification_model,
    args=classification_training_args,
    train_dataset=tokenized_datasets_cls["train"],
    eval_dataset=tokenized_datasets_cls["validation"],
    tokenizer=tokenizer, # Tokenizer loaded for Phase 2
    compute_metrics=compute_metrics_cls,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=CLASSIFICATION_EARLY_STOPPING_PATIENCE)],
    class_weights=class_weights_tensor_cls # Weight tensor, or None when it could not be computed
)
print("Classification Trainer configured.")

# Run the fine-tuning
print("Starting classification fine-tuning...")
classification_trainer.train()
print("Classification fine-tuning finished.")

# Explicitly save the model held by the trainer after training
# (load_best_model_at_end=True is set above, so this is intended to be the best checkpoint).
best_model_path = os.path.join(CLASSIFICATION_OUTPUT_DIR, "best_model")
classification_trainer.save_model(best_model_path)
tokenizer.save_pretrained(best_model_path) # Save the tokenizer alongside the model
print(f"Best classification model and tokenizer saved to {best_model_path}")

# --- Detailed evaluation and submission (uses the model held by the trainer) ---
print("\n--- Detailed Evaluation and Submission File Generation ---")
print("\nGenerating predictions for Threshold Adjustment and Detailed Metrics...")

predictions_output = classification_trainer.predict(tokenized_datasets_cls["validation"])
logits = predictions_output.predictions
true_labels = predictions_output.label_ids

# Sanity check: one logit column per detected label.
if logits.shape[-1] != num_distinct_labels:
    print(f"Error: Logits shape {logits.shape} unexpected for {num_distinct_labels} labels.")
    exit()

probabilities = None
predicted_labels_final = None
best_threshold = None

if num_distinct_labels == 2:
    probabilities = softmax(logits, axis=-1)[:, 1] # Probability of the positive class (index 1)
    print("\nFinding best threshold on validation set based on Overall F1-Positive...")
    # NOTE(review): the threshold is selected on the same validation set that
    # the metrics below are reported on, so those metrics are optimistic.
    best_f1 = -1
    best_threshold = 0.5 # Default
    # Scan thresholds from 0.10 up to about 0.90 in steps of 0.01.
    thresholds = np.arange(0.1, 0.91, 0.01)
    f1_scores_thresh = []
    for threshold in thresholds:
        predicted_labels_thresh = (probabilities >= threshold).astype(int)
        precision_thresh, recall_thresh, f1_thresh, _ = precision_recall_fscore_support(
            true_labels, predicted_labels_thresh, average='binary', pos_label=1, zero_division=0)
        f1_scores_thresh.append(f1_thresh)
        # Strict '>' keeps the lowest threshold among ties.
        if f1_thresh > best_f1:
            best_f1 = f1_thresh
            best_threshold = threshold

    print(f"\nBest threshold found: {best_threshold:.2f} with Overall F1-Pos: {best_f1:.4f}")
    if wandb.run: wandb.log({"eval/best_threshold": best_threshold, "eval/best_val_f1_at_threshold": best_f1})
    predicted_labels_final = (probabilities >= best_threshold).astype(int)
else:
    # Reached for any label count other than 2.
    print("Multi-class classification detected (>2). Using argmax for final predictions.")
    predicted_labels_final = logits.argmax(-1)
    # best_threshold stays None

# Build the dataframe used for the detailed evaluation.
# The index is reset to 0..n-1 so rows line up positionally with the predictions.
dev_df_eval = dev_df_full.reset_index(drop=True)
if len(dev_df_eval) == len(predicted_labels_final):
    dev_df_eval['predicted_label'] = predicted_labels_final
    if probabilities is not None:
         dev_df_eval['probability_positive'] = probabilities
else:
    print(f"Error: Length mismatch between dev_df {len(dev_df_eval)} and predictions {len(predicted_labels_final)}!")
    exit()

# Derive the language code from the id prefix (text before the first "_").
# Read the ids from dev_df_eval itself: pandas aligns assigned Series on the
# index, so using dev_df_full["id"] (original, possibly gapped index) would
# misalign or NaN-fill rows whenever dropna() removed anything.
# This overwrites any "language" column loaded from the CSV.
dev_df_eval["language"] = dev_df_eval["id"].apply(lambda x: str(x).split("_")[0] if isinstance(x, str) and "_" in x else "unknown")

# Per-language evaluation of the final predictions on the dev set.
languages = sorted(dev_df_eval['language'].unique())
language_f1_scores_pos = [] # Per-language positive-class F1 (binary case), averaged later
wandb_logs_eval = {}  # all detailed metrics are collected here and logged in one call
print(f"\n--- Detailed Evaluation on Development Set (Final Predictions) ---")
if best_threshold is not None:
     print(f"--- (Using Threshold = {best_threshold:.2f}) ---")

for lang in languages:
    # Rows whose id had no "<lang>_" prefix are excluded from the breakdown.
    if lang == "unknown": continue
    lang_mask = dev_df_eval['language'] == lang
    y_true_lang = dev_df_eval.loc[lang_mask, 'label'].tolist()
    y_pred_lang_final = dev_df_eval.loc[lang_mask, 'predicted_label'].tolist()
    if len(y_true_lang) == 0: continue

    # Use the labels detected in the training data for the metric computation
    metric_labels = unique_labels_cls
    precision_lang, recall_lang, f1_lang, support_lang = precision_recall_fscore_support(
        y_true_lang, y_pred_lang_final, average=None, labels=metric_labels, zero_division=0)
    accuracy_lang = accuracy_score(y_true_lang, y_pred_lang_final)

    print(f"\nMetrics for language: {lang.upper()} (Support: {dict(zip(metric_labels, support_lang))})")
    # WandB key depends on whether predictions came from a threshold or from argmax
    wandb_key_prefix = f"eval/{lang}" + ("/thresh" if best_threshold is not None else "/argmax")

    if num_distinct_labels == 2:
        # Position 1 is the second entry of metric_labels (label 1 when labels are [0, 1])
        f1_pos_lang = f1_lang[1]
        language_f1_scores_pos.append(f1_pos_lang)
        print(f"  Precision (Pos/1): {precision_lang[1]:.4f}")
        print(f"  Recall    (Pos/1): {recall_lang[1]:.4f}")
        print(f"  F1        (Pos/1): {f1_pos_lang:.4f}")
        print(f"  Accuracy:          {accuracy_lang:.4f}")
        wandb_logs_eval[f"{wandb_key_prefix}/precision_pos"] = precision_lang[1]
        wandb_logs_eval[f"{wandb_key_prefix}/recall_pos"] = recall_lang[1]
        wandb_logs_eval[f"{wandb_key_prefix}/f1_pos"] = f1_pos_lang
        wandb_logs_eval[f"{wandb_key_prefix}/accuracy"] = accuracy_lang
    else:
        f1_macro_lang = np.mean(f1_lang) # Unweighted mean of the per-class F1 scores
        print(f"  F1-Macro:         {f1_macro_lang:.4f}")
        print(f"  Accuracy:         {accuracy_lang:.4f}")
        wandb_logs_eval[f"{wandb_key_prefix}/f1_macro"] = f1_macro_lang
        wandb_logs_eval[f"{wandb_key_prefix}/accuracy"] = accuracy_lang
        # Also record the per-class F1 scores
        for i, label in enumerate(metric_labels):
             wandb_logs_eval[f"{wandb_key_prefix}/f1_class_{label}"] = f1_lang[i]


# Final global metrics over the whole dev set
def _confusion_matrix_table(cm, class_labels):
    """Build a wandb.Table for a confusion matrix with labelled rows.

    wandb.Table ignores its legacy ``rows`` argument whenever ``data`` is
    given, so the true-label names are stored as an explicit first column
    instead of being passed through ``rows``.
    """
    columns = ["True_label"] + [f"Pred_{l}" for l in class_labels]
    data = [[f"True_{l}"] + counts for l, counts in zip(class_labels, cm.tolist())]
    return wandb.Table(data=data, columns=columns)


cm_overall_final = confusion_matrix(true_labels, predicted_labels_final, labels=unique_labels_cls)
overall_accuracy_final = accuracy_score(true_labels, predicted_labels_final)
wandb_key_prefix_overall = "eval/overall" + ("/thresh" if best_threshold is not None else "/argmax")

print(f"\n--- Overall Evaluation Summary (Final Predictions) ---")
if num_distinct_labels == 2:
    # labels=unique_labels_cls has exactly two entries here, so the matrix is
    # always 2x2 and unpacks as TN, FP, FN, TP (rows = true, columns = predicted).
    tn, fp, fn, tp = cm_overall_final.ravel()

    # Guard every ratio against an empty denominator.
    overall_precision_pos = tp / (tp + fp) if (tp + fp) > 0 else 0
    overall_recall_pos = tp / (tp + fn) if (tp + fn) > 0 else 0
    overall_f1_pos = 2 * (overall_precision_pos * overall_recall_pos) / (overall_precision_pos + overall_recall_pos) if (overall_precision_pos + overall_recall_pos) > 0 else 0
    macro_f1_pos = np.mean(language_f1_scores_pos) if language_f1_scores_pos else 0 # Mean of the per-language positive-class F1

    print(f"Overall F1-score (Positive Class): {overall_f1_pos:.4f}  <-- Primary Metric")
    print(f"Macro F1-score (Pos Class / Lang): {macro_f1_pos:.4f}")
    print(f"Overall Precision (Positive Class):{overall_precision_pos:.4f}")
    print(f"Overall Recall (Positive Class):   {overall_recall_pos:.4f}")
    print(f"Overall Accuracy:                  {overall_accuracy_final:.4f}")

    wandb_logs_eval[f"{wandb_key_prefix_overall}/f1_pos"] = overall_f1_pos
    wandb_logs_eval[f"{wandb_key_prefix_overall}/macro_f1_pos_lang"] = macro_f1_pos
    wandb_logs_eval[f"{wandb_key_prefix_overall}/precision_pos"] = overall_precision_pos
    wandb_logs_eval[f"{wandb_key_prefix_overall}/recall_pos"] = overall_recall_pos

else: # Global metrics for the multi-class case
    overall_prec_macro, overall_recall_macro, overall_f1_macro, _ = precision_recall_fscore_support(true_labels, predicted_labels_final, average='macro', labels=unique_labels_cls, zero_division=0)
    overall_prec_weighted, overall_recall_weighted, overall_f1_weighted, _ = precision_recall_fscore_support(true_labels, predicted_labels_final, average='weighted', labels=unique_labels_cls, zero_division=0)
    print(f"Overall Accuracy:     {overall_accuracy_final:.4f}")
    print(f"Overall F1 (Macro):   {overall_f1_macro:.4f}")
    print(f"Overall F1 (Weighted):{overall_f1_weighted:.4f}")

    wandb_logs_eval[f"{wandb_key_prefix_overall}/f1_macro"] = overall_f1_macro
    wandb_logs_eval[f"{wandb_key_prefix_overall}/f1_weighted"] = overall_f1_weighted

# Shared by both cases: print the confusion matrix and record accuracy + matrix.
print("\nOverall Confusion Matrix (Final):")
print(f"Predicted Labels: {unique_labels_cls}")
print(f"True Labels")
print(cm_overall_final)
wandb_logs_eval[f"{wandb_key_prefix_overall}/accuracy"] = overall_accuracy_final
if wandb.run: wandb_logs_eval[f"{wandb_key_prefix_overall}/conf_matrix"] = _confusion_matrix_table(cm_overall_final, unique_labels_cls)

# Log all detailed eval metrics
if wandb.run: wandb.log(wandb_logs_eval)

# --- Save the submission file ---
print("\nSaving predictions for submission...")
os.makedirs(CLASSIFICATION_OUTPUT_DIR, exist_ok=True)
submission_df = dev_df_eval[['id', 'predicted_label']]
# File names encode the pre-training variant and how labels were decided
# (chosen threshold, or argmax when no threshold was searched).
suffix = "simcse_finetuned" if DO_SIMCSE_PRETRAINING else "base_finetuned"
thresh_suffix = f"_thresh{best_threshold:.2f}" if best_threshold is not None else "_argmax"
csv_filename = f"predictions_task1_{suffix}{thresh_suffix}.csv"
zip_filename = f"submission_task1_{suffix}{thresh_suffix}.zip"
csv_path = os.path.join(CLASSIFICATION_OUTPUT_DIR, csv_filename)
zip_path = os.path.join(CLASSIFICATION_OUTPUT_DIR, zip_filename)

submission_df.to_csv(csv_path, index=False)
print(f"Predictions saved to {csv_path}")

# Zipping is best-effort: a failure is reported but the CSV above is kept.
try:
    with zipfile.ZipFile(zip_path, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
        # arcname stores the CSV at the archive root, without the output directory.
        zf.write(csv_path, arcname=csv_filename)
    print(f"{csv_filename} has been zipped into {zip_path}")
except Exception as e:
    print(f"Error zipping the file: {e}")


# NOTE(review): a run whose step counter is still 0 is left unfinished by this
# condition -- confirm that is intended.
if wandb.run is not None and wandb.run.step > 0: # Check if wandb was used and logged something
    wandb.finish()
print("\nScript finished.")


--- Initial Data Loading and Cleaning ---
Loaded 31187 train and 4625 dev examples.
Text cleaning complete.

--- Phase 1: Starting SimCSE Pre-training ---


wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: caron-olivier-80 (caron-olivier-80-universit-paris-dauphine-psl) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


Created SimCSE dataset with 35812 examples.
Loading tokenizer xlm-roberta-base for SimCSE phase...


Map:   0%|          | 0/35812 [00:00<?, ? examples/s]

SimCSE dataset tokenized.
SimCSE base model loaded.


  super().__init__(*args, **kwargs)


Starting SimCSE training...




Step,Training Loss
50,2.5152
100,0.7374
150,0.1244
200,0.0279
250,0.0172
300,0.0111
350,0.012
400,0.003
450,0.0059
500,0.0069


SimCSE training finished.
SimCSE pre-trained model and tokenizer saved to results_simcse_xlmr_base


0,1
train/epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▁▁▂▂▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇███
train/grad_norm,█▁▃▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/learning_rate,█████▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▁▁
train/loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
total_flos,4711181956153344.0
train/epoch,1.0
train/global_step,4477.0
train/grad_norm,0.01307
train/learning_rate,0.0
train/loss,0.0001
train_loss,0.0395
train_runtime,1185.0057
train_samples_per_second,30.221
train_steps_per_second,3.778


--- Phase 1: SimCSE Pre-training Complete ---

--- Phase 2: Starting Classification Fine-tuning ---


Classification datasets created.
Loading tokenizer for classification phase from: results_simcse_xlmr_base


Map:   0%|          | 0/31187 [00:00<?, ? examples/s]

Map:   0%|          | 0/4625 [00:00<?, ? examples/s]

Classification datasets tokenized.
Computing class weights for classification...
Detected 2 distinct labels in training data: [0 1]
Class Weights (for classes [0 1]): [0.5426468 6.3620973]
Loading model for classification from: results_simcse_xlmr_base


Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at results_simcse_xlmr_base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Classification model loaded.


  super().__init__(*args, **kwargs)


Classification Trainer configured.
Starting classification fine-tuning...


Epoch,Training Loss,Validation Loss,Accuracy,F1 Pos,Precision Pos,Recall Pos,F1 Neg
0,0.3075,0.394511,0.922595,0.536269,0.553476,0.520101,0.957773
1,0.1762,0.336146,0.886054,0.555274,0.418043,0.826633,0.934656
2,0.1807,0.344501,0.924541,0.636837,0.543517,0.768844,0.957896
3,0.1133,0.636081,0.946595,0.673712,0.710306,0.640704,0.970917
4,0.1244,0.789361,0.942054,0.663317,0.663317,0.663317,0.968299
5,0.0697,0.773144,0.941189,0.669903,0.647887,0.693467,0.967719
6,0.0323,0.846988,0.944649,0.688564,0.667453,0.711055,0.969625
7,0.0308,0.986162,0.944865,0.675985,0.683805,0.668342,0.969869


Classification fine-tuning finished.
Best classification model and tokenizer saved to results_classifier_finetuned_on_simcse\best_model

--- Detailed Evaluation and Submission File Generation ---

Generating predictions for Threshold Adjustment and Detailed Metrics...



Finding best threshold on validation set based on Overall F1-Positive...

Best threshold found: 0.11 with Overall F1-Pos: 0.6919

--- Detailed Evaluation on Development Set (Final Predictions) ---
--- (Using Threshold = 0.11) ---

Metrics for language: DE (Support: {0: 599, 1: 35})
  Precision (Pos/1): 0.5385
  Recall    (Pos/1): 0.6000
  F1        (Pos/1): 0.5676
  Accuracy:          0.9495

Metrics for language: EN (Support: {0: 841, 1: 61})
  Precision (Pos/1): 0.7391
  Recall    (Pos/1): 0.8361
  F1        (Pos/1): 0.7846
  Accuracy:          0.9690

Metrics for language: FR (Support: {0: 389, 1: 30})
  Precision (Pos/1): 0.5610
  Recall    (Pos/1): 0.7667
  F1        (Pos/1): 0.6479
  Accuracy:          0.9403

Metrics for language: RU (Support: {0: 2398, 1: 272})
  Precision (Pos/1): 0.6600
  Recall    (Pos/1): 0.7279
  F1        (Pos/1): 0.6923
  Accuracy:          0.9341

--- Overall Evaluation Summary (Final Predictions) ---
Overall F1-score (Positive Class): 0.6919  <-- Prim

0,1
eval/accuracy,▅▁▅█▇▇██
eval/best_threshold,▁
eval/best_val_f1_at_threshold,▁
eval/de/thresh/accuracy,▁
eval/de/thresh/f1_pos,▁
eval/de/thresh/precision_pos,▁
eval/de/thresh/recall_pos,▁
eval/en/thresh/accuracy,▁
eval/en/thresh/f1_pos,▁
eval/en/thresh/precision_pos,▁

0,1
eval/accuracy,0.94486
eval/best_threshold,0.11
eval/best_val_f1_at_threshold,0.69185
eval/de/thresh/accuracy,0.94953
eval/de/thresh/f1_pos,0.56757
eval/de/thresh/precision_pos,0.53846
eval/de/thresh/recall_pos,0.6
eval/en/thresh/accuracy,0.96896
eval/en/thresh/f1_pos,0.78462
eval/en/thresh/precision_pos,0.73913



Script finished.


## AUGMENTED DATA but same model

In [1]:
import pandas as pd
from datasets import Dataset, DatasetDict, load_dataset
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
import wandb
import torch
import torch.nn.functional as F
from torch import nn
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score, precision_recall_fscore_support
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
import zipfile
import os
import re
from scipy.special import softmax
import math
import gc

# --- General configuration ---
BASE_MODEL_NAME = "xlm-roberta-base"
TRAIN_CSV = "data/train_data_augmented_no_text_clean.csv"
DEV_CSV = "data/dev_data_SMM4H_2025_Task_1.csv"
# WARNING: max_length=256 can be VRAM-hungry (8 GB limit). Lower to 192 or 128 on OOM.
MAX_LENGTH = 256

# --- Phase 1 configuration: SimCSE pre-training ---
DO_SIMCSE_PRETRAINING = True
SIMCSE_OUTPUT_DIR = "results_augmented_simcse_xlmr_base"
SIMCSE_NUM_EPOCHS = 1
# WARNING: SimCSE encodes every batch twice (two forward passes), so memory use is
# roughly doubled for a given batch size. Lower this value on OOM with 8 GB VRAM.
SIMCSE_BATCH_SIZE = 8
SIMCSE_LEARNING_RATE = 3e-5
SIMCSE_TEMP = 0.05
SIMCSE_LOGGING_STEPS = 50
# Pooling strategy used to build the SimCSE sentence embedding
SIMCSE_USE_MEAN_POOLING = True # Set to False to use CLS-token pooling instead
# Learning-rate warmup ratio
SIMCSE_WARMUP_RATIO = 0.06 # ~6% of the total steps

# --- Phase 2 configuration: classification fine-tuning ---
CLASSIFICATION_OUTPUT_DIR = "results_augmented_data_classifier_finetuned_on_simcse"
CLASSIFICATION_NUM_EPOCHS = 8
# Effective batch size = CLASSIFICATION_BATCH_SIZE * CLASSIFICATION_GRAD_ACCUM_STEPS
# = 4 * 8 = 32. Lower the per-device batch size (and raise accumulation) on OOM.
CLASSIFICATION_BATCH_SIZE = 4
CLASSIFICATION_GRAD_ACCUM_STEPS = 8
CLASSIFICATION_LEARNING_RATE = 2e-5
CLASSIFICATION_EARLY_STOPPING_PATIENCE = 3
CLASSIFICATION_LOGGING_STEPS = 50
# Learning-rate warmup ratio
CLASSIFICATION_WARMUP_RATIO = 0.06 # ~6% of the total steps

# --- WandB project name, used by the wandb.init call of each phase ---
WANDB_PROJECT_NAME = "ade-classification-augmented-simcse-finetune"

# --- Text cleaning helper ---
def clean_text(text):
    """Normalise a raw post into a single-line, placeholder-tagged string.

    Non-string input yields ``""``. ``@handle`` mentions and the literal
    placeholder tags listed below are rewritten to ``[USER_MENTION]``,
    ``[URL]`` or ``[EMAIL]``; whitespace runs are then collapsed to one
    space and the result is stripped.
    """
    if not isinstance(text, str):
        return ""

    cleaned = re.sub(r'@[\w_]+', '[USER_MENTION]', text)

    # Literal tags are substituted in this fixed order.
    tag_replacements = (
        ('<user>', '[USER_MENTION]'),
        ('<tuser>', '[USER_MENTION]'),
        ('<url>', '[URL]'),
        ('<email>', '[EMAIL]'),
        ('HTTPURL________________', '[URL]'),
    )
    for raw_tag, token in tag_replacements:
        cleaned = cleaned.replace(raw_tag, token)

    return re.sub(r'\s+', ' ', cleaned).strip()

# --- Mean pooling helper (optional pooling strategy for SimCSE) ---
def mean_pooling(hidden_state, attention_mask):
    """Average ``hidden_state`` over dim 1, weighting each position by ``attention_mask``."""
    # Broadcast the mask to the hidden-state shape so masked positions contribute zero.
    weights = attention_mask.unsqueeze(-1).expand(hidden_state.size()).float()
    masked_total = (hidden_state * weights).sum(1)
    # Clamp the denominator so a fully masked row does not divide by zero.
    token_count = weights.sum(1).clamp(min=1e-9)
    return masked_total / token_count

# --- 1. Initial data loading and preparation ---
print("--- Initial Data Loading and Cleaning ---")
try:
    # Rows without text cannot be tokenized, so they are dropped up front.
    train_df_full = pd.read_csv(TRAIN_CSV).dropna(subset=['text'])
    dev_df_full = pd.read_csv(DEV_CSV).dropna(subset=['text'])
except FileNotFoundError as e:
    print(f"Error loading CSV files: {e}")
    # Abort with a non-zero status. The bare `exit()` helper reports success
    # (status 0) and is only an interactive convenience, so it is not used here.
    raise SystemExit(1) from e
print(f"Loaded {len(train_df_full)} train and {len(dev_df_full)} dev examples.")

# Apply the same normalisation to both splits.
train_df_full['text'] = train_df_full['text'].apply(clean_text)
dev_df_full['text'] = dev_df_full['text'].apply(clean_text)
print("Text cleaning complete.")

# --- PHASE 1: SIMCSE PRE-TRAINING ---

# Custom Trainer implementing the unsupervised SimCSE objective
class SimCSETrainer(Trainer):
    """Trainer whose loss is the SimCSE contrastive (InfoNCE) objective.

    Every batch is encoded twice by the same model; the two encodings of a
    sentence form the positive pair and every other embedding in the
    concatenated 2*batch set acts as a negative.

    Args:
        temperature: Divisor applied to the cosine similarities before the
            cross-entropy.
        use_mean_pooling: If True, sentence embeddings come from
            ``mean_pooling`` over the last hidden state; otherwise the first
            token's hidden state is used.
    """

    def __init__(self, *args, temperature=0.05, use_mean_pooling=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.temperature = temperature
        self.use_mean_pooling = use_mean_pooling

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the SimCSE loss for one batch.

        Args:
            model: Encoder called with ``input_ids`` and ``attention_mask``;
                its output must expose ``last_hidden_state``.
            inputs: Mapping holding ``input_ids`` and ``attention_mask``.
            return_outputs: If True, also return both pooled embeddings.
            **kwargs: Extra arguments passed by the Trainer; ignored.

        Returns:
            ``loss`` alone, or ``(loss, {"embeddings1": ..., "embeddings2": ...})``
            when ``return_outputs`` is True.
        """
        input_ids = inputs.get("input_ids")
        attention_mask = inputs.get("attention_mask")

        # Two passes over the same batch; dropout (active in train mode) makes
        # the two views differ. Only last_hidden_state is read, so the
        # per-layer hidden states are not requested.
        outputs1 = model(input_ids=input_ids, attention_mask=attention_mask)
        outputs2 = model(input_ids=input_ids, attention_mask=attention_mask)

        if self.use_mean_pooling:
            pooler_output1 = mean_pooling(outputs1.last_hidden_state, attention_mask)
            pooler_output2 = mean_pooling(outputs2.last_hidden_state, attention_mask)
        else:  # first-token (CLS) pooling
            pooler_output1 = outputs1.last_hidden_state[:, 0]
            pooler_output2 = outputs2.last_hidden_state[:, 0]

        # L2-normalise so that the dot product below is the cosine similarity.
        embeddings = torch.cat([pooler_output1, pooler_output2], dim=0)
        embeddings = F.normalize(embeddings, p=2, dim=1)

        # Temperature-scaled similarity matrix, shape (2*batch, 2*batch).
        cos_sim = torch.mm(embeddings, embeddings.t()) / self.temperature

        # Exclude each embedding's similarity with itself. The fill value is
        # taken from the tensor's own dtype and applied after the temperature
        # division: a hard-coded -9e15 is not representable in float16.
        batch_size = pooler_output1.size(0)
        mask_diag = torch.eye(2 * batch_size, device=embeddings.device, dtype=torch.bool)
        cos_sim = cos_sim.masked_fill(mask_diag, torch.finfo(cos_sim.dtype).min)

        # Row i of the first half must pick column i + batch_size (its second
        # view); row i of the second half must pick column i (its first view).
        labels = torch.arange(batch_size, device=embeddings.device)
        labels_z1 = labels + batch_size
        labels_z2 = labels

        loss_fct = nn.CrossEntropyLoss()
        loss_z1 = loss_fct(cos_sim[:batch_size, :], labels_z1)
        loss_z2 = loss_fct(cos_sim[batch_size:, :], labels_z2)

        # Symmetric loss: average of both directions.
        loss = (loss_z1 + loss_z2) / 2

        return (loss, {"embeddings1": pooler_output1, "embeddings2": pooler_output2}) if return_outputs else loss

model_load_path = BASE_MODEL_NAME # Default path if SimCSE is skipped

# Phase 1 driver: optionally pre-train the bare encoder with SimCSE on all
# available texts, save it, and point model_load_path at the saved checkpoint.
if DO_SIMCSE_PRETRAINING:
    print("\n--- Phase 1: Starting SimCSE Pre-training ---")
    run_name_simcse = f"simcse_{'meanpool_' if SIMCSE_USE_MEAN_POOLING else 'cls_'}{BASE_MODEL_NAME}"
    # WandB is best-effort: any failure here only disables logging for this phase.
    try:
        wandb.init(project=WANDB_PROJECT_NAME, name=run_name_simcse, reinit=True)
        wandb.config.update({ # Log the SimCSE configuration
            "simcse_model": BASE_MODEL_NAME,
            "simcse_epochs": SIMCSE_NUM_EPOCHS,
            "simcse_batch_size": SIMCSE_BATCH_SIZE,
            "simcse_lr": SIMCSE_LEARNING_RATE,
            "simcse_temp": SIMCSE_TEMP,
            "simcse_pooling": "mean" if SIMCSE_USE_MEAN_POOLING else "cls",
            "simcse_warmup_ratio": SIMCSE_WARMUP_RATIO,
            "max_length": MAX_LENGTH
        })
    except Exception as e:
        print(f"WandB initialization failed for SimCSE phase: {e}")
        print("Proceeding without WandB logging for this phase.")


    # Build the SimCSE dataset from the text column of both splits (train + dev);
    # labels are not used in this phase.
    all_texts_df = pd.concat([train_df_full[['text']], dev_df_full[['text']]], ignore_index=True)
    simcse_dataset = Dataset.from_pandas(all_texts_df)
    print(f"Created SimCSE dataset with {len(simcse_dataset)} examples.")

    # Tokenizer dedicated to phase 1
    print(f"Loading tokenizer {BASE_MODEL_NAME} for SimCSE phase...")
    tokenizer_simcse = AutoTokenizer.from_pretrained(BASE_MODEL_NAME) # Phase-specific variable

    def tokenize_simcse(examples):
        """Tokenize a batch of texts, padded/truncated to MAX_LENGTH."""
        # Uses tokenizer_simcse from the enclosing scope
        return tokenizer_simcse(examples["text"], padding="max_length", truncation=True, max_length=MAX_LENGTH)

    tokenized_simcse_dataset = simcse_dataset.map(tokenize_simcse, batched=True, remove_columns=["text"], num_proc=1) # Raise num_proc if more processes are available
    tokenized_simcse_dataset.set_format("torch")
    print("SimCSE dataset tokenized.")

    # Load the bare encoder (AutoModel, no task head)
    simcse_model = AutoModel.from_pretrained(BASE_MODEL_NAME)
    print("SimCSE base model loaded.")

    # SimCSE training arguments (with warmup)
    simcse_training_args = TrainingArguments(
        output_dir=SIMCSE_OUTPUT_DIR,
        num_train_epochs=SIMCSE_NUM_EPOCHS,
        per_device_train_batch_size=SIMCSE_BATCH_SIZE,
        learning_rate=SIMCSE_LEARNING_RATE,
        weight_decay=0.01,
        logging_dir=f'{SIMCSE_OUTPUT_DIR}/logs',
        logging_steps=SIMCSE_LOGGING_STEPS,
        save_strategy="epoch",
        report_to="wandb" if wandb.run is not None else "none", # Report only if a run is active
        fp16=torch.cuda.is_available(), # Mixed precision whenever CUDA is available (needed on 8 GB VRAM)
        warmup_ratio=SIMCSE_WARMUP_RATIO, # Learning-rate warmup
    )

    # Instantiate the SimCSE trainer with the chosen temperature and pooling
    simcse_trainer = SimCSETrainer(
        model=simcse_model,
        args=simcse_training_args,
        train_dataset=tokenized_simcse_dataset,
        tokenizer=tokenizer_simcse, # Phase-specific tokenizer
        temperature=SIMCSE_TEMP,
        use_mean_pooling=SIMCSE_USE_MEAN_POOLING # Pooling option
    )

    # Run training
    print("Starting SimCSE training...")
    simcse_trainer.train()
    print("SimCSE training finished.")

    # Save model and tokenizer side by side so phase 2 can load both from one directory
    simcse_trainer.save_model(SIMCSE_OUTPUT_DIR)
    tokenizer_simcse.save_pretrained(SIMCSE_OUTPUT_DIR) # Save the tokenizer that was used
    print(f"SimCSE pre-trained model and tokenizer saved to {SIMCSE_OUTPUT_DIR}")

    # Clean up: release phase 1 objects before phase 2 allocates its own model
    model_load_path = SIMCSE_OUTPUT_DIR # Phase 2 now loads from the SimCSE checkpoint
    del simcse_model, simcse_trainer, tokenized_simcse_dataset, simcse_dataset, all_texts_df, tokenizer_simcse
    gc.collect()
    torch.cuda.empty_cache()
    if wandb.run is not None: wandb.finish()
    print("--- Phase 1: SimCSE Pre-training Complete ---")

else:
    print("\n--- Phase 1: Skipping SimCSE Pre-training ---")
    # model_load_path stays BASE_MODEL_NAME (set above)


# --- PHASE 2: CLASSIFICATION FINE-TUNING ---
print("\n--- Phase 2: Starting Classification Fine-tuning ---")
run_name_classify = f"classify_ft_on_{'simcse' if DO_SIMCSE_PRETRAINING else 'base'}_{BASE_MODEL_NAME}"
# WandB is best-effort: any failure here only disables logging for this phase.
try:
    wandb.init(project=WANDB_PROJECT_NAME, name=run_name_classify, reinit=True)
    wandb.config.update({ # Log the classification configuration
        "base_model_for_ft": model_load_path,
        "classify_epochs": CLASSIFICATION_NUM_EPOCHS,
        "classify_batch_size": CLASSIFICATION_BATCH_SIZE,
        "classify_grad_accum": CLASSIFICATION_GRAD_ACCUM_STEPS,
        "classify_effective_batch": CLASSIFICATION_BATCH_SIZE * CLASSIFICATION_GRAD_ACCUM_STEPS,
        "classify_lr": CLASSIFICATION_LEARNING_RATE,
        "classify_warmup_ratio": CLASSIFICATION_WARMUP_RATIO,
        "classify_early_stopping": CLASSIFICATION_EARLY_STOPPING_PATIENCE,
        "max_length": MAX_LENGTH
    })
except Exception as e:
        print(f"WandB initialization failed for Classification phase: {e}")
        print("Proceeding without WandB logging for this phase.")


# Build the classification datasets from the cleaned DataFrames
train_dataset_cls = Dataset.from_pandas(train_df_full)
dev_dataset_cls = Dataset.from_pandas(dev_df_full)
dataset_dict_cls = DatasetDict({'train': train_dataset_cls, 'validation': dev_dataset_cls})
print("Classification datasets created.")

# --- Always (re)load the tokenizer for Phase 2 ---
# It is loaded from model_load_path, i.e. the SimCSE output directory when
# Phase 1 ran, otherwise the base model name.
print(f"Loading tokenizer for classification phase from: {model_load_path}")
tokenizer = AutoTokenizer.from_pretrained(model_load_path)
# ----------------------------------------------------------------------

def tokenize_classification(examples):
    """Tokenize a batch of texts, padded/truncated to MAX_LENGTH."""
    # Uses the module-level `tokenizer` loaded just above
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=MAX_LENGTH)

# NOTE(review): remove_columns assumes every listed column exists in both the
# train and dev CSVs - confirm against the data files.
tokenized_datasets_cls = dataset_dict_cls.map(tokenize_classification, batched=True, remove_columns=["text", "id", "file_name", "origin", "language", "split", "type"], num_proc=1)
tokenized_datasets_cls.set_format("torch")
# Rename the target column from "label" to "labels", the key read by the custom trainer's compute_loss.
tokenized_datasets_cls = tokenized_datasets_cls.rename_column("label", "labels")
print("Classification datasets tokenized.")

# Compute balanced class weights from the training labels
print("Computing class weights for classification...")
labels_train_cls = train_df_full['label'].values
class_weights_tensor_cls = None  # stays None when only one class is present
unique_labels_cls = np.unique(labels_train_cls)
num_distinct_labels = len(unique_labels_cls)
print(f"Detected {num_distinct_labels} distinct labels in training data: {unique_labels_cls}")

if num_distinct_labels > 1:
    class_weights_cls = compute_class_weight(class_weight='balanced', classes=unique_labels_cls, y=labels_train_cls)
    # Map each label to its weight, then read the weights back in unique_labels_cls order.
    # NOTE(review): because the lookup iterates over unique_labels_cls itself, this
    # round trip reproduces class_weights_cls unchanged. The weight at position k
    # only matches class index k if the labels are exactly 0..K-1 - confirm.
    ordered_weights_dict = {label: weight for label, weight in zip(unique_labels_cls, class_weights_cls)}
    ordered_weights_cls = np.array([ordered_weights_dict.get(i, 0) for i in unique_labels_cls]) # One weight per existing label

    class_weights_tensor_cls = torch.tensor(ordered_weights_cls, dtype=torch.float).to("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Class Weights (for classes {unique_labels_cls}): {class_weights_tensor_cls.cpu().numpy()}")
    if wandb.run: wandb.config.update({"class_weights": class_weights_tensor_cls.cpu().numpy().tolist()})
else:
    print("Warning: Only one class found in training data. Cannot compute class weights.")


# Custom classification Trainer with per-class loss weights
class WeightedClassificationTrainer(Trainer):
    """Trainer whose loss is cross-entropy, optionally weighted per class.

    Args:
        class_weights: Optional tensor of per-class weights; it is moved to
            ``self.args.device`` once at construction. ``None`` means an
            unweighted loss.
    """

    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        if class_weights is None:
            self.class_weights = None
        else:
            # Move the weights to the training device a single time.
            self.class_weights = class_weights.to(self.args.device)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the (optionally weighted) cross-entropy loss, plus the model outputs if requested."""
        targets = inputs.get("labels")
        outputs = model(**inputs)
        scores = outputs.get("logits")

        # weight=None is the default of CrossEntropyLoss, i.e. the unweighted loss.
        criterion = torch.nn.CrossEntropyLoss(weight=self.class_weights)
        flat_scores = scores.view(-1, self.model.config.num_labels)
        loss = criterion(flat_scores, targets.view(-1))

        if return_outputs:
            return (loss, outputs)
        return loss

# Load the sequence-classification model on top of the chosen encoder checkpoint
print(f"Loading model for classification from: {model_load_path}")
classification_model = AutoModelForSequenceClassification.from_pretrained(
    model_load_path,
    num_labels=num_distinct_labels, # Number of labels detected in the training data
    ignore_mismatched_sizes=True # Original author's note: needed when loading from an AutoModel (SimCSE) checkpoint
)
print("Classification model loaded.")

# Metric callback passed to the classification Trainer
def compute_metrics_cls(pred):
    """Compute evaluation metrics from a prediction object.

    Reads the module-level ``num_distinct_labels`` and ``unique_labels_cls``.

    Args:
        pred: Object exposing ``label_ids`` (true labels) and ``predictions``
            (per-class scores; the arg-max over the last axis is the
            predicted label).

    Returns:
        dict: For two labels, ``accuracy``, ``f1_pos``, ``precision_pos``,
        ``recall_pos`` (label 1) and ``f1_neg`` (label 0). Otherwise
        ``accuracy``, ``f1_macro``, ``precision_macro``, ``recall_macro``
        and ``f1_weighted``.
    """
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    acc = accuracy_score(labels, preds)

    if num_distinct_labels == 2:
        # With labels=[0, 1] and average=None the returned arrays always have
        # exactly two entries (index 0 = negative, index 1 = positive), so no
        # length guard is needed before indexing.
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, preds, average=None, labels=[0, 1], zero_division=0)
        return {
            'accuracy': acc,
            'f1_pos': f1[1],
            'precision_pos': precision[1],
            'recall_pos': recall[1],
            'f1_neg': f1[0],
        }

    # Multi-class: macro and weighted averages over the labels seen in training.
    metric_labels = unique_labels_cls
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        labels, preds, average='macro', labels=metric_labels, zero_division=0)
    _, _, f1_weighted, _ = precision_recall_fscore_support(
        labels, preds, average='weighted', labels=metric_labels, zero_division=0)
    return {
        'accuracy': acc,
        'f1_macro': f1_macro,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'f1_weighted': f1_weighted,
    }

# Classification training arguments (with warmup)
classification_training_args = TrainingArguments(
    output_dir=CLASSIFICATION_OUTPUT_DIR,
    num_train_epochs=CLASSIFICATION_NUM_EPOCHS,
    per_device_train_batch_size=CLASSIFICATION_BATCH_SIZE,
    per_device_eval_batch_size=CLASSIFICATION_BATCH_SIZE * 2,
    gradient_accumulation_steps=CLASSIFICATION_GRAD_ACCUM_STEPS,
    learning_rate=CLASSIFICATION_LEARNING_RATE,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # Metric used to pick the best checkpoint; must be a key returned by compute_metrics_cls
    metric_for_best_model="f1_pos" if num_distinct_labels == 2 else "f1_macro",
    greater_is_better=True,
    logging_dir=f'{CLASSIFICATION_OUTPUT_DIR}/logs',
    logging_steps=CLASSIFICATION_LOGGING_STEPS,
    report_to="wandb" if wandb.run is not None else "none", # Report only if a run is active
    fp16=torch.cuda.is_available(), # Mixed precision whenever CUDA is available (needed on 8 GB VRAM)
    warmup_ratio=CLASSIFICATION_WARMUP_RATIO,
    save_total_limit=2,
)

# Instantiate the classification trainer; the class weights are passed here
classification_trainer = WeightedClassificationTrainer(
    model=classification_model,
    args=classification_training_args,
    train_dataset=tokenized_datasets_cls["train"],
    eval_dataset=tokenized_datasets_cls["validation"],
    tokenizer=tokenizer, # Tokenizer loaded for phase 2
    compute_metrics=compute_metrics_cls,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=CLASSIFICATION_EARLY_STOPPING_PATIENCE)],
    class_weights=class_weights_tensor_cls # Weight tensor, or None for an unweighted loss
)
print("Classification Trainer configured.")

# Run the fine-tuning
print("Starting classification fine-tuning...")
classification_trainer.train()
print("Classification fine-tuning finished.")

# Explicitly save the trainer's current model (load_best_model_at_end=True is set
# in the training arguments) together with its tokenizer in one directory
best_model_path = os.path.join(CLASSIFICATION_OUTPUT_DIR, "best_model")
classification_trainer.save_model(best_model_path)
tokenizer.save_pretrained(best_model_path) # Keep the tokenizer next to the model
print(f"Best classification model and tokenizer saved to {best_model_path}")

# --- Detailed evaluation and submission (uses the model held by the trainer after training) ---
print("\n--- Detailed Evaluation and Submission File Generation ---")
print("\nGenerating predictions for Threshold Adjustment and Detailed Metrics...")

predictions_output = classification_trainer.predict(tokenized_datasets_cls["validation"])
logits = predictions_output.predictions
true_labels = predictions_output.label_ids

if logits.shape[-1] != num_distinct_labels:
    print(f"Error: Logits shape {logits.shape} unexpected for {num_distinct_labels} labels.")
    # Abort with a non-zero status; the interactive `exit()` helper reports success (status 0).
    raise SystemExit(1)

probabilities = None
predicted_labels_final = None
best_threshold = None

if num_distinct_labels == 2:
    # Probability of the positive class (column index 1).
    probabilities = softmax(logits, axis=-1)[:, 1]
    print("\nFinding best threshold on validation set based on Overall F1-Positive...")
    # NOTE(review): the threshold is selected on the same validation set that is
    # used for the detailed report below, so the reported scores are optimistic.
    best_f1 = -1
    best_threshold = 0.5 # Default, kept only if no candidate beats best_f1
    thresholds = np.arange(0.1, 0.91, 0.01)
    f1_scores_thresh = []
    for threshold in thresholds:
        predicted_labels_thresh = (probabilities >= threshold).astype(int)
        precision_thresh, recall_thresh, f1_thresh, _ = precision_recall_fscore_support(
            true_labels, predicted_labels_thresh, average='binary', pos_label=1, zero_division=0)
        f1_scores_thresh.append(f1_thresh)
        # Strict '>' keeps the lowest threshold among ties.
        if f1_thresh > best_f1:
            best_f1 = f1_thresh
            best_threshold = threshold

    print(f"\nBest threshold found: {best_threshold:.2f} with Overall F1-Pos: {best_f1:.4f}")
    if wandb.run: wandb.log({"eval/best_threshold": best_threshold, "eval/best_val_f1_at_threshold": best_f1})
    predicted_labels_final = (probabilities >= best_threshold).astype(int)
else:
    print("Multi-class classification detected (>2). Using argmax for final predictions.")
    predicted_labels_final = logits.argmax(-1)
    # best_threshold stays None

# Build the evaluation frame: one row per dev example, positionally aligned
# with the prediction arrays thanks to the reset index.
dev_df_eval = dev_df_full.reset_index(drop=True)
if len(dev_df_eval) != len(predicted_labels_final):
    print(f"Error: Length mismatch between dev_df {len(dev_df_eval)} and predictions {len(predicted_labels_final)}!")
    # Abort with a non-zero status; the interactive `exit()` helper reports success (status 0).
    raise SystemExit(1)
dev_df_eval['predicted_label'] = predicted_labels_final
if probabilities is not None:
    dev_df_eval['probability_positive'] = probabilities

# The language code is the part of the id before the first underscore.
# It is derived from dev_df_eval's own "id" column: dev_df_full keeps its
# pre-reset index, so assigning a Series built from it would be aligned by
# index label and end up misaligned/NaN whenever dropna removed a row.
dev_df_eval["language"] = dev_df_eval["id"].apply(lambda x: str(x).split("_")[0] if isinstance(x, str) and "_" in x else "unknown")

languages = sorted(dev_df_eval['language'].unique())
language_f1_scores_pos = [] # Per-language positive-class F1, averaged later into a macro score
wandb_logs_eval = {}
print(f"\n--- Detailed Evaluation on Development Set (Final Predictions) ---")
if best_threshold is not None:
     print(f"--- (Using Threshold = {best_threshold:.2f}) ---")

for lang in languages:
    if lang == "unknown": continue
    lang_mask = dev_df_eval['language'] == lang
    y_true_lang = dev_df_eval.loc[lang_mask, 'label'].tolist()
    y_pred_lang_final = dev_df_eval.loc[lang_mask, 'predicted_label'].tolist()
    if len(y_true_lang) == 0: continue

    # Score against the labels detected in the training data so that every
    # language reports the same classes in the same order.
    metric_labels = unique_labels_cls
    precision_lang, recall_lang, f1_lang, support_lang = precision_recall_fscore_support(
        y_true_lang, y_pred_lang_final, average=None, labels=metric_labels, zero_division=0)
    accuracy_lang = accuracy_score(y_true_lang, y_pred_lang_final)

    print(f"\nMetrics for language: {lang.upper()} (Support: {dict(zip(metric_labels, support_lang))})")
    # WandB key records whether predictions came from the tuned threshold or from argmax
    wandb_key_prefix = f"eval/{lang}" + ("/thresh" if best_threshold is not None else "/argmax")

    if num_distinct_labels == 2:
        f1_pos_lang = f1_lang[1] # Position 1 is the second detected label (the positive class)
        language_f1_scores_pos.append(f1_pos_lang)
        print(f"  Precision (Pos/1): {precision_lang[1]:.4f}")
        print(f"  Recall    (Pos/1): {recall_lang[1]:.4f}")
        print(f"  F1        (Pos/1): {f1_pos_lang:.4f}")
        print(f"  Accuracy:          {accuracy_lang:.4f}")
        wandb_logs_eval[f"{wandb_key_prefix}/precision_pos"] = precision_lang[1]
        wandb_logs_eval[f"{wandb_key_prefix}/recall_pos"] = recall_lang[1]
        wandb_logs_eval[f"{wandb_key_prefix}/f1_pos"] = f1_pos_lang
        wandb_logs_eval[f"{wandb_key_prefix}/accuracy"] = accuracy_lang
    else:
        f1_macro_lang = np.mean(f1_lang) # Unweighted mean of the per-class F1 scores
        print(f"  F1-Macro:         {f1_macro_lang:.4f}")
        print(f"  Accuracy:         {accuracy_lang:.4f}")
        wandb_logs_eval[f"{wandb_key_prefix}/f1_macro"] = f1_macro_lang
        wandb_logs_eval[f"{wandb_key_prefix}/accuracy"] = accuracy_lang
        # Also log the per-class F1
        for i, label in enumerate(metric_labels):
             wandb_logs_eval[f"{wandb_key_prefix}/f1_class_{label}"] = f1_lang[i]

# Final overall metrics over the whole development set
cm_overall_final = confusion_matrix(true_labels, predicted_labels_final, labels=unique_labels_cls)
overall_accuracy_final = accuracy_score(true_labels, predicted_labels_final)
wandb_key_prefix_overall = "eval/overall" + ("/thresh" if best_threshold is not None else "/argmax")

print(f"\n--- Overall Evaluation Summary (Final Predictions) ---")
if num_distinct_labels == 2:
    # `labels=` holds exactly two entries here, so the matrix is always 2x2
    # and unpacks as (tn, fp, fn, tp) with the second label as the positive class.
    tn, fp, fn, tp = cm_overall_final.ravel()

    overall_precision_pos = tp / (tp + fp) if (tp + fp) > 0 else 0
    overall_recall_pos = tp / (tp + fn) if (tp + fn) > 0 else 0
    overall_f1_pos = 2 * (overall_precision_pos * overall_recall_pos) / (overall_precision_pos + overall_recall_pos) if (overall_precision_pos + overall_recall_pos) > 0 else 0
    macro_f1_pos = np.mean(language_f1_scores_pos) if language_f1_scores_pos else 0 # Mean of the per-language positive-class F1

    print(f"Overall F1-score (Positive Class): {overall_f1_pos:.4f}  <-- Primary Metric")
    print(f"Macro F1-score (Pos Class / Lang): {macro_f1_pos:.4f}")
    print(f"Overall Precision (Positive Class):{overall_precision_pos:.4f}")
    print(f"Overall Recall (Positive Class):   {overall_recall_pos:.4f}")
    print(f"Overall Accuracy:                  {overall_accuracy_final:.4f}")

    wandb_logs_eval[f"{wandb_key_prefix_overall}/f1_pos"] = overall_f1_pos
    wandb_logs_eval[f"{wandb_key_prefix_overall}/macro_f1_pos_lang"] = macro_f1_pos
    wandb_logs_eval[f"{wandb_key_prefix_overall}/precision_pos"] = overall_precision_pos
    wandb_logs_eval[f"{wandb_key_prefix_overall}/recall_pos"] = overall_recall_pos
    wandb_logs_eval[f"{wandb_key_prefix_overall}/accuracy"] = overall_accuracy_final
else: # Overall metrics for the multi-class case
    overall_prec_macro, overall_recall_macro, overall_f1_macro, _ = precision_recall_fscore_support(true_labels, predicted_labels_final, average='macro', labels=unique_labels_cls, zero_division=0)
    overall_prec_weighted, overall_recall_weighted, overall_f1_weighted, _ = precision_recall_fscore_support(true_labels, predicted_labels_final, average='weighted', labels=unique_labels_cls, zero_division=0)
    print(f"Overall Accuracy:     {overall_accuracy_final:.4f}")
    print(f"Overall F1 (Macro):   {overall_f1_macro:.4f}")
    print(f"Overall F1 (Weighted):{overall_f1_weighted:.4f}")

    wandb_logs_eval[f"{wandb_key_prefix_overall}/accuracy"] = overall_accuracy_final
    wandb_logs_eval[f"{wandb_key_prefix_overall}/f1_macro"] = overall_f1_macro
    wandb_logs_eval[f"{wandb_key_prefix_overall}/f1_weighted"] = overall_f1_weighted

# Confusion matrix output is identical for both cases (rows = true, columns = predicted).
print("\nOverall Confusion Matrix (Final):")
print(f"Predicted Labels: {unique_labels_cls}")
print(f"True Labels")
print(cm_overall_final)
if wandb.run:
    # wandb.Table has no row-label argument (`rows` is a legacy alias for the
    # table data), so the true-label names are stored in an explicit first column.
    wandb_logs_eval[f"{wandb_key_prefix_overall}/conf_matrix"] = wandb.Table(
        columns=["true_label"] + [f"Pred_{l}" for l in unique_labels_cls],
        data=[[f"True_{l}"] + row for l, row in zip(unique_labels_cls, cm_overall_final.tolist())],
    )

# Log all detailed eval metrics in a single call
if wandb.run: wandb.log(wandb_logs_eval)

# --- Save the submission file ---
print("\nSaving predictions for submission...")
os.makedirs(CLASSIFICATION_OUTPUT_DIR, exist_ok=True)
submission_df = dev_df_eval[['id', 'predicted_label']]
# File names encode the pipeline variant and how the final labels were produced.
suffix = "simcse_finetuned" if DO_SIMCSE_PRETRAINING else "base_finetuned"
thresh_suffix = f"_thresh{best_threshold:.2f}" if best_threshold is not None else "_argmax"
csv_filename = f"predictions_task1_{suffix}{thresh_suffix}.csv"
zip_filename = f"submission_task1_{suffix}{thresh_suffix}.zip"
csv_path = os.path.join(CLASSIFICATION_OUTPUT_DIR, csv_filename)
zip_path = os.path.join(CLASSIFICATION_OUTPUT_DIR, zip_filename)

submission_df.to_csv(csv_path, index=False)
print(f"Predictions saved to {csv_path}")

# Zipping is best-effort: on failure the CSV written above is still available.
try:
    with zipfile.ZipFile(zip_path, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
        # arcname stores the CSV at the archive root, without the output directory.
        zf.write(csv_path, arcname=csv_filename)
    print(f"{csv_filename} has been zipped into {zip_path}")
except Exception as e:
    print(f"Error zipping the file: {e}")


if wandb.run is not None and wandb.run.step > 0: # Finish only if a run is active and has logged at least one step
    wandb.finish()
print("\nScript finished.")


--- Initial Data Loading and Cleaning ---
Loaded 33482 train and 4625 dev examples.
Text cleaning complete.

--- Phase 1: Starting SimCSE Pre-training ---


wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: caron-olivier-80 (caron-olivier-80-universit-paris-dauphine-psl) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


Created SimCSE dataset with 38107 examples.
Loading tokenizer xlm-roberta-base for SimCSE phase...


Map:   0%|          | 0/38107 [00:00<?, ? examples/s]

SimCSE dataset tokenized.
SimCSE base model loaded.


  super().__init__(*args, **kwargs)


Starting SimCSE training...




Step,Training Loss
50,2.5365
100,0.7395
150,0.1314
200,0.0325
250,0.0089
300,0.0153
350,0.0128
400,0.0011
450,0.0015
500,0.0012


SimCSE training finished.
SimCSE pre-trained model and tokenizer saved to results_augmented_simcse_xlmr_base


0,1
train/epoch,▁▁▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇███
train/grad_norm,█▅▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/learning_rate,▂███▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▅▄▄▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▁▁
train/loss,█▃▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
total_flos,5013096470544384.0
train/epoch,1.0
train/global_step,4764.0
train/grad_norm,0.00779
train/learning_rate,0.0
train/loss,0.0
train_loss,0.03722
train_runtime,1375.759
train_samples_per_second,27.699
train_steps_per_second,3.463


--- Phase 1: SimCSE Pre-training Complete ---

--- Phase 2: Starting Classification Fine-tuning ---


Classification datasets created.
Loading tokenizer for classification phase from: results_augmented_simcse_xlmr_base


Map:   0%|          | 0/33482 [00:00<?, ? examples/s]

Map:   0%|          | 0/4625 [00:00<?, ? examples/s]

Classification datasets tokenized.
Computing class weights for classification...
Detected 2 distinct labels in training data: [0 1]
Class Weights (for classes [0 1]): [0.5825793 3.5273914]
Loading model for classification from: results_augmented_simcse_xlmr_base


Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at results_augmented_simcse_xlmr_base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Classification model loaded.


  super().__init__(*args, **kwargs)


Classification Trainer configured.
Starting classification fine-tuning...


Epoch,Training Loss,Validation Loss,Accuracy,F1 Pos,Precision Pos,Recall Pos,F1 Neg
0,0.2506,0.384392,0.850811,0.48737,0.345992,0.824121,0.912702
1,0.2133,0.435832,0.929946,0.625866,0.57906,0.680905,0.961355
2,0.1212,0.356035,0.940108,0.666667,0.639723,0.69598,0.967098
3,0.0798,0.508999,0.939676,0.654275,0.645477,0.663317,0.966955
4,0.0653,0.521542,0.932108,0.647982,0.58502,0.726131,0.962431
5,0.0582,0.630815,0.946162,0.690683,0.683047,0.698492,0.970515
6,0.0256,0.778755,0.942703,0.679565,0.655012,0.70603,0.968539
7,0.0147,0.827851,0.940541,0.671446,0.640091,0.70603,0.967312


Classification fine-tuning finished.
Best classification model and tokenizer saved to results_augmented_data_classifier_finetuned_on_simcse\best_model

--- Detailed Evaluation and Submission File Generation ---

Generating predictions for Threshold Adjustment and Detailed Metrics...



Finding best threshold on validation set based on Overall F1-Positive...

Best threshold found: 0.52 with Overall F1-Pos: 0.6933

--- Detailed Evaluation on Development Set (Final Predictions) ---
--- (Using Threshold = 0.52) ---

Metrics for language: DE (Support: {0: 599, 1: 35})
  Precision (Pos/1): 0.5227
  Recall    (Pos/1): 0.6571
  F1        (Pos/1): 0.5823
  Accuracy:          0.9479

Metrics for language: EN (Support: {0: 841, 1: 61})
  Precision (Pos/1): 0.8065
  Recall    (Pos/1): 0.8197
  F1        (Pos/1): 0.8130
  Accuracy:          0.9745

Metrics for language: FR (Support: {0: 389, 1: 30})
  Precision (Pos/1): 0.4898
  Recall    (Pos/1): 0.8000
  F1        (Pos/1): 0.6076
  Accuracy:          0.9260

Metrics for language: RU (Support: {0: 2398, 1: 272})
  Precision (Pos/1): 0.7269
  Recall    (Pos/1): 0.6654
  F1        (Pos/1): 0.6948
  Accuracy:          0.9404

--- Overall Evaluation Summary (Final Predictions) ---
Overall F1-score (Positive Class): 0.6933  <-- Prim

0,1
eval/accuracy,▁▇██▇███
eval/best_threshold,▁
eval/best_val_f1_at_threshold,▁
eval/de/thresh/accuracy,▁
eval/de/thresh/f1_pos,▁
eval/de/thresh/precision_pos,▁
eval/de/thresh/recall_pos,▁
eval/en/thresh/accuracy,▁
eval/en/thresh/f1_pos,▁
eval/en/thresh/precision_pos,▁

0,1
eval/accuracy,0.94054
eval/best_threshold,0.52
eval/best_val_f1_at_threshold,0.69327
eval/de/thresh/accuracy,0.94795
eval/de/thresh/f1_pos,0.58228
eval/de/thresh/precision_pos,0.52273
eval/de/thresh/recall_pos,0.65714
eval/en/thresh/accuracy,0.9745
eval/en/thresh/f1_pos,0.81301
eval/en/thresh/precision_pos,0.80645



Script finished.


## NO SimCSE

In [1]:
import pandas as pd
from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback # Importer le callback pour l'arrêt précoce
)
# Assurez-vous que wandb est initialisé si vous l'utilisez, sinon commentez/supprimez les lignes wandb.log
import wandb
wandb.init(project="ade-classification-xlmr") # Exemple d'initialisation

# CORRECTION ICI: Ajout de precision_recall_fscore_support
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score, precision_recall_fscore_support
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
import torch
import zipfile # Pour la sauvegarde finale
import os # Pour la sauvegarde finale

# --- Configuration ---
# Hyper-parameters and paths for the baseline run (no SimCSE pre-training).
MODEL_NAME = "xlm-roberta-base"
TRAIN_CSV = "data/train_data_augmented_no_text_clean.csv" # New (augmented) training file
DEV_CSV = "data/dev_data_SMM4H_2025_Task_1.csv"
OUTPUT_DIR = "results_xlmr_augmented_longer_train" # New output directory
NUM_EPOCHS = 8 # Increased number of epochs
BATCH_SIZE = 8 # Kept small to fit in GPU memory
LEARNING_RATE = 2e-5
GRADIENT_ACCUMULATION_STEPS = 4 # Accumulate gradients (effective batch = 8*4 = 32)
EARLY_STOPPING_PATIENCE = 3 # Patience (in evaluations) for early stopping

# --- 1. Load Data ---
print("Loading data...")
try:
    # Rows without text cannot be tokenized, so they are dropped up front.
    train_df = pd.read_csv(TRAIN_CSV).dropna(subset=['text'])
    dev_df = pd.read_csv(DEV_CSV).dropna(subset=['text'])
except FileNotFoundError as e:
    print(f"Error loading CSV files: {e}")
    # A bare exit() reports success (status 0) and, inside a notebook, shuts
    # the kernel down. Raising SystemExit(1) signals the failure and only
    # aborts the current cell / script.
    raise SystemExit(1) from e

# Convert to Hugging Face Dataset format.
# preserve_index=False: after dropna() the pandas index may have gaps, in which
# case from_pandas() would otherwise store it as an extra "__index_level_0__" column.
train_dataset = Dataset.from_pandas(train_df, preserve_index=False)
dev_dataset = Dataset.from_pandas(dev_df, preserve_index=False)
dataset_dict = DatasetDict({'train': train_dataset, 'validation': dev_dataset})
print("Data loaded.")

# --- 2. Tokenization ---
print("Tokenizing data...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def tokenize_function(examples):
    """Tokenize the "text" field of a batch, padding/truncating to 256 tokens."""
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=256)

# batched=True hands tokenize_function a batch of examples per call.
tokenized_datasets = dataset_dict.map(tokenize_function, batched=True)
print("Tokenization complete.")

# --- Clean up columns ---
# Drop the metadata columns so that only the tokenizer outputs and the label remain.
print("Cleaning dataset columns...")
print("Columns before removal:", tokenized_datasets['train'].column_names)
columns_to_remove = ["text", "id", "file_name", "origin", "language", "split", "type"]
actual_columns_to_remove = [col for col in columns_to_remove if col in tokenized_datasets['train'].column_names]
print("Removing columns:", actual_columns_to_remove)
# Filter the list against each split separately: remove_columns() raises when a
# column is missing, and the validation CSV is not guaranteed to carry every
# column the (augmented) training CSV has.
tokenized_datasets = DatasetDict({
    split_name: split_ds.remove_columns(
        [col for col in columns_to_remove if col in split_ds.column_names]
    )
    for split_name, split_ds in tokenized_datasets.items()
})
tokenized_datasets.set_format("torch")

# Rename 'label' to 'labels'
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
print("Columns after cleaning and rename:", tokenized_datasets['train'].column_names)

# --- 3. Compute Class Weights ---
# "balanced" weights counteract the class imbalance in the training labels.
print("Computing class weights...")
labels_train = train_df['label'].values
_train_classes = np.unique(labels_train)
if len(_train_classes) <= 1:
    # A single class gives nothing to balance; the loss falls back to unweighted.
    print("Warning: Only one class found in training data. Cannot compute class weights.")
    class_weights_tensor = None # Handle this case in the loss function if needed
else:
    class_weights = compute_class_weight(class_weight='balanced', classes=_train_classes, y=labels_train)
    _weights_device = "cuda" if torch.cuda.is_available() else "cpu"
    class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(_weights_device)
    print(f"Class Weights: {class_weights_tensor}")

# --- Custom Trainer for Weighted Loss ---
class WeightedTrainer(Trainer):
    """Trainer whose loss is a class-weighted cross-entropy.

    The weights are read from the module-level ``class_weights_tensor`` each
    time the loss is computed; when it is ``None`` the loss is unweighted.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the cross-entropy loss for one batch.

        Args:
            model: the model being trained.
            inputs: batch dict; must contain "labels" plus the model inputs.
            return_outputs: when True, return ``(loss, outputs)``.
            **kwargs: extra arguments passed by the Trainer; ignored here.
        """
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        weight = None
        if class_weights_tensor is not None:
            # The global tensor was placed on a device chosen at creation time;
            # align it with the logits so the loss never mixes devices.
            weight = class_weights_tensor.to(logits.device)
        loss_fct = torch.nn.CrossEntropyLoss(weight=weight)
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

# --- 4. Model & Metrics ---
print("Loading model...")
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
print("Model loaded.")

# Metric callback for the Trainer (uses the imported precision_recall_fscore_support)
def compute_metrics(pred):
    """Compute accuracy and per-class precision/recall/F1 from a prediction object.

    ``pred.predictions`` holds the logits and ``pred.label_ids`` the gold
    labels. Class 1 is reported as the positive class.
    """
    gold = pred.label_ids
    predicted = pred.predictions.argmax(-1)
    # average=None returns one value per entry of labels=[0, 1].
    per_class_precision, per_class_recall, per_class_f1, _support = precision_recall_fscore_support(
        gold, predicted, average=None, labels=[0, 1], zero_division=0
    )
    scores = {'accuracy': accuracy_score(gold, predicted)}
    scores['f1_pos'] = per_class_f1[1]          # F1 for class 1
    scores['precision_pos'] = per_class_precision[1]
    scores['recall_pos'] = per_class_recall[1]
    scores['f1_neg'] = per_class_f1[0]          # F1 for class 0 (informational)
    return scores

# --- 5. Training Arguments ---
print("Setting training arguments...")
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE*2,     # Evaluation can usually afford a larger batch
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,                # Required by EarlyStoppingCallback
    metric_for_best_model="f1_pos",             # Metric monitored for best model / early stopping
    greater_is_better=True,                     # F1 should be maximized
    # Early-stopping patience is configured on the EarlyStoppingCallback, not here.
    logging_dir='./logs',
    logging_steps=50,
    # `"wandb" in locals()` is always true once the module is imported; what
    # matters is whether a run has actually been initialized.
    report_to="wandb" if wandb.run is not None else "none",
    fp16=torch.cuda.is_available(),
    save_total_limit=2, # Cap on the number of checkpoints kept on disk
)

# --- 6. Trainer ---
# WeightedTrainer applies the class-weighted loss defined above.
trainer = WeightedTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    # Early stopping is enabled through this callback
    callbacks=[EarlyStoppingCallback(early_stopping_patience=EARLY_STOPPING_PATIENCE)]
)
print("Trainer configured.")

# --- 7. Train ---
print("Starting Training...")
train_result = trainer.train()
print("Training finished.")

# Log training metrics
# trainer.log_metrics("train", train_result.metrics) # Uncomment if needed
# trainer.save_metrics("train", train_result.metrics) # Uncomment if needed
# trainer.save_state() # Saves the Trainer state

# --- 8. Evaluate on Dev Set (using the best model loaded) ---
# load_best_model_at_end=True is set in training_args, so this evaluates the best checkpoint.
print("\nEvaluating on Development Set (Best Model)...")
eval_results = trainer.evaluate()
print("Evaluation Results:")
print(eval_results)
# trainer.log_metrics("eval", eval_results) # Uncomment if needed
# trainer.save_metrics("eval", eval_results) # Uncomment if needed

# --- 9. Detailed Evaluation and Submission File Generation ---
print("\nGenerating predictions and detailed metrics for Dev Set...")

# Generate predictions (argmax over the class logits)
predictions = trainer.predict(tokenized_datasets["validation"])
predicted_labels = predictions.predictions.argmax(-1)

# dev_df was filtered with dropna(), so its index may have gaps; a fresh
# 0..n-1 index makes the positional assignment of predictions line up.
dev_df_eval = dev_df.reset_index(drop=True)
if len(dev_df_eval) != len(predicted_labels):
    print(f"Error: Length mismatch! Dev DF has {len(dev_df_eval)} rows, Predictions have {len(predicted_labels)} entries.")
    # A bare exit() would report success (status 0); signal the failure instead.
    raise SystemExit(1)
dev_df_eval['predicted_label'] = predicted_labels


# Re-extract language (assuming ID format 'lang_...') - Be careful if IDs differ
dev_df_eval["language"] = dev_df_eval["id"].apply(lambda x: str(x).split("_")[0] if isinstance(x, str) and "_" in x else "unknown")

# --- Calculate Per-Language and Overall Metrics ---
languages = sorted(dev_df_eval['language'].unique())
per_language_metrics = {}
language_f1_scores = []
# Labels accumulated over the languages processed below; they feed the overall metrics.
all_true_labels_eval = []
all_pred_labels_eval = []

print("\n--- Detailed Evaluation on Development Set ---")
wandb_logs = {} # Collect logs for wandb

for lang in languages:
    if lang == "unknown": continue # Skip if language couldn't be extracted
    lang_mask = dev_df_eval['language'] == lang
    y_true_lang = dev_df_eval.loc[lang_mask, 'label']
    y_pred_lang = dev_df_eval.loc[lang_mask, 'predicted_label']

    if len(y_true_lang) == 0: continue # Skip if no data for this language

    all_true_labels_eval.extend(y_true_lang.tolist())
    all_pred_labels_eval.extend(y_pred_lang.tolist())

    # average=None returns one value per class in labels=[0, 1]; index 1 is the positive class
    precision, recall, f1, _ = precision_recall_fscore_support(y_true_lang, y_pred_lang, average=None, labels=[0, 1], zero_division=0)
    accuracy_lang = accuracy_score(y_true_lang, y_pred_lang) # Named accuracy_lang to avoid clashing with the accuracy_score function

    per_language_metrics[lang] = {'precision': precision[1], 'recall': recall[1], 'f1': f1[1], 'accuracy': accuracy_lang}
    language_f1_scores.append(f1[1]) # Store the F1 of the positive class (1)

    print(f"\nMetrics for language: {lang.upper()}")
    print(f"  Precision-{lang} (Pos): {precision[1]:.4f}")
    print(f"  Recall-{lang}    (Pos): {recall[1]:.4f}")
    print(f"  F1-{lang}        (Pos): {f1[1]:.4f}")
    print(f"  Accuracy-{lang}:        {accuracy_lang:.4f}")

    # Prepare logs for wandb
    wandb_logs[f"{lang}/precision_pos"] = precision[1]
    wandb_logs[f"{lang}/recall_pos"] = recall[1]
    wandb_logs[f"{lang}/f1_pos"] = f1[1]
    wandb_logs[f"{lang}/accuracy"] = accuracy_lang


# Calculate Overall Metrics from the accumulated label lists.
# NOTE: these lists only contain rows whose language could be extracted;
# rows tagged "unknown" were skipped by the per-language loop.
cm_overall = confusion_matrix(all_true_labels_eval, all_pred_labels_eval, labels=[0, 1])
tn, fp, fn, tp = cm_overall.ravel() if cm_overall.size == 4 else (0, 0, 0, 0) # Handle cases with missing classes

# Guard each ratio against a zero denominator.
overall_precision_pos = tp / (tp + fp) if (tp + fp) > 0 else 0
overall_recall_pos = tp / (tp + fn) if (tp + fn) > 0 else 0
# Overall F1 (Primary Metric for Positive Class)
overall_f1_pos = 2 * (overall_precision_pos * overall_recall_pos) / (overall_precision_pos + overall_recall_pos) if (overall_precision_pos + overall_recall_pos) > 0 else 0
# Macro F1 (Average of per-language F1s for Positive Class)
macro_f1_pos = np.mean(language_f1_scores) if language_f1_scores else 0
# Overall Accuracy
overall_accuracy = accuracy_score(all_true_labels_eval, all_pred_labels_eval)

print("\n--- Overall Evaluation Summary (Positive Class Focus) ---")
print(f"F1-score across all languages (Positive Class): {overall_f1_pos:.4f}  <-- Primary Metric")
print(f"Macro F1-score across all languages (Pos Class):{macro_f1_pos:.4f}")
print(f"Overall Precision (Positive Class):             {overall_precision_pos:.4f}")
print(f"Overall Recall (Positive Class):                {overall_recall_pos:.4f}")
print(f"Overall Accuracy across all languages:          {overall_accuracy:.4f}")

print("\nOverall Confusion Matrix (All Languages):")
print(f"[[TN={tn}  FP={fp}]")
print(f" [FN={fn}  TP={tp}]]")

# Add overall metrics to wandb logs
wandb_logs["overall/f1_pos"] = overall_f1_pos
wandb_logs["overall/macro_f1_pos"] = macro_f1_pos
wandb_logs["overall/precision_pos"] = overall_precision_pos
wandb_logs["overall/recall_pos"] = overall_recall_pos
wandb_logs["overall/accuracy"] = overall_accuracy
wandb_logs["overall/TP"] = tp
wandb_logs["overall/FP"] = fp
wandb_logs["overall/FN"] = fn
wandb_logs["overall/TN"] = tn

# Log the confusion matrix and all collected metrics to wandb.
# `"wandb" in locals()` is always true once the module is imported, so it does
# not tell whether logging can work; an active run (wandb.run) is the real test.
_wandb_run_active = wandb.run is not None

# Confusion-matrix plot (optional); skipped when no predictions were collected
if _wandb_run_active and tp+fp+fn+tn > 0:
    try:
        wandb_logs["confusion_matrix"] = wandb.plot.confusion_matrix(
            probs=None,
            y_true=all_true_labels_eval,
            preds=all_pred_labels_eval,
            class_names=["Negative (0)", "Positive (1)"]
        )
    except Exception as e:
        # Best effort: a plotting failure must not prevent the scalar metrics from being logged.
        print(f"Could not log confusion matrix to wandb: {e}")


# Log all collected metrics to wandb (if initialized)
if _wandb_run_active:
    try:
        wandb.log(wandb_logs)
        print("Metrics logged to WandB.")
    except Exception as e:
        # Best effort: logging problems must not abort the submission step that follows.
        print(f"Could not log metrics to wandb: {e}")
else:
    print("No active WandB run; skipping metric logging.")


# --- 10. Save Submission File ---
# Writes the dev-set predictions to a CSV and wraps that CSV in a zip archive.
print("\nSaving predictions for submission...")
# Ensure results directory exists
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Prepare submission dataframe
submission_df = dev_df_eval[['id', 'predicted_label']]

# Define CSV and ZIP paths
csv_filename = "predictions_task1.csv"
zip_filename = "submission_task1.zip"
csv_path = os.path.join(OUTPUT_DIR, csv_filename)
zip_path = os.path.join(OUTPUT_DIR, zip_filename)

# Save CSV
submission_df.to_csv(csv_path, index=False)
print(f"Predictions saved to {csv_path}")

# Zip the CSV file; arcname stores it at the archive root rather than under OUTPUT_DIR
with zipfile.ZipFile(zip_path, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
    zf.write(csv_path, arcname=csv_filename)
print(f"{csv_filename} has been zipped into {zip_path}")


print("\nScript finished.")




wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: caron-olivier-80 (caron-olivier-80-universit-paris-dauphine-psl) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


Loading data...
Data loaded.
Tokenizing data...


Map:   0%|          | 0/33482 [00:00<?, ? examples/s]

Map:   0%|          | 0/4625 [00:00<?, ? examples/s]

Tokenization complete.
Cleaning dataset columns...
Columns before removal: ['id', 'text', 'label', 'file_name', 'origin', 'type', 'language', 'split', 'input_ids', 'attention_mask']
Removing columns: ['text', 'id', 'file_name', 'origin', 'language', 'split', 'type']
Columns after cleaning and rename: ['labels', 'input_ids', 'attention_mask']
Computing class weights...
Class Weights: tensor([0.5826, 3.5274], device='cuda:0')
Loading model...


Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded.
Setting training arguments...


  trainer = WeightedTrainer(


Trainer configured.
Starting Training...




Epoch,Training Loss,Validation Loss,Accuracy,F1 Pos,Precision Pos,Recall Pos,F1 Neg
1,0.3129,0.340643,0.917189,0.577729,0.514735,0.658291,0.954093
2,0.2455,0.422801,0.905514,0.588124,0.470588,0.78392,0.946636
3,0.1488,0.507462,0.94227,0.633745,0.697885,0.580402,0.968666
4,0.1074,0.477409,0.924757,0.640496,0.54386,0.778894,0.957981
5,0.0949,0.63399,0.940324,0.658416,0.64878,0.668342,0.967306
6,0.0413,0.87442,0.945946,0.671053,0.70442,0.640704,0.970554
7,0.0222,0.924574,0.941838,0.672351,0.652482,0.693467,0.968086


Training finished.

Evaluating on Development Set (Best Model)...


Evaluation Results:
{'eval_loss': 0.9245741367340088, 'eval_accuracy': 0.9418378378378378, 'eval_f1_pos': 0.6723507917174177, 'eval_precision_pos': 0.6524822695035462, 'eval_recall_pos': 0.6934673366834171, 'eval_f1_neg': 0.9680863684897378, 'eval_runtime': 10.1814, 'eval_samples_per_second': 454.259, 'eval_steps_per_second': 28.483, 'epoch': 7.992833253702819}

Generating predictions and detailed metrics for Dev Set...

--- Detailed Evaluation on Development Set ---

Metrics for language: DE
  Precision-de (Pos): 0.4889
  Recall-de    (Pos): 0.6286
  F1-de        (Pos): 0.5500
  Accuracy-de:        0.9432

Metrics for language: EN
  Precision-en (Pos): 0.7385
  Recall-en    (Pos): 0.7869
  F1-en        (Pos): 0.7619
  Accuracy-en:        0.9667

Metrics for language: FR
  Precision-fr (Pos): 0.4894
  Recall-fr    (Pos): 0.7667
  F1-fr        (Pos): 0.5974
  Accuracy-fr:        0.9260

Metrics for language: RU
  Precision-ru (Pos): 0.6880
  Recall-ru    (Pos): 0.6728
  F1-ru        (Po

## Augmented Data with LLM translation + paraphrase + SimCSE


In [1]:
import pandas as pd
from datasets import Dataset, DatasetDict, load_dataset
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
import wandb
import torch
import torch.nn.functional as F
from torch import nn
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score, precision_recall_fscore_support
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
import zipfile
import os
import re
from scipy.special import softmax
import math
import gc

# --- General configuration ---
BASE_MODEL_NAME = "xlm-roberta-base"
TRAIN_CSV = "data/train_data_augmented_strict_prompt.csv" # New training file
DEV_CSV = "data/dev_data_SMM4H_2025_Task_1.csv"
# WARNING: long sequences are VRAM-hungry (8 GB limit). Reduce further (e.g. to 128) on OOM.
MAX_LENGTH = 196 # Max tokenization length (reduced from 256 to avoid OOM)

# --- Phase 1 configuration: SimCSE pre-training ---
DO_SIMCSE_PRETRAINING = True
SIMCSE_OUTPUT_DIR = "results_LLMTRADPARAPHRASE_simcse_xlmr_base"
SIMCSE_NUM_EPOCHS = 1
# WARNING: SimCSE runs two forward passes per batch, doubling activation memory.
# Reduce the batch size further on OOM (8 GB VRAM).
SIMCSE_BATCH_SIZE = 8
SIMCSE_LEARNING_RATE = 3e-5
SIMCSE_TEMP = 0.05
SIMCSE_LOGGING_STEPS = 50
# SimCSE pooling option
SIMCSE_USE_MEAN_POOLING = True # Set to False to pool with the CLS token instead
# Learning-rate warmup ratio
SIMCSE_WARMUP_RATIO = 0.06 # ~6% of total steps

# --- Phase 2 configuration: classification fine-tuning ---
CLASSIFICATION_OUTPUT_DIR = "results_augmented_data_classifier_finetuned_on_simcse"
CLASSIFICATION_NUM_EPOCHS = 8
# NOTE: batch_size=4 * grad_accum=8 => effective batch of 32. Reduce on OOM.
CLASSIFICATION_BATCH_SIZE = 4
CLASSIFICATION_GRAD_ACCUM_STEPS = 8
CLASSIFICATION_LEARNING_RATE = 2e-5
CLASSIFICATION_EARLY_STOPPING_PATIENCE = 3
CLASSIFICATION_LOGGING_STEPS = 50
# Learning-rate warmup ratio
CLASSIFICATION_WARMUP_RATIO = 0.06 # ~6% of total steps

# --- Global WandB settings ---
WANDB_PROJECT_NAME = "ade-classification-LLMaugmented-simcse-finetune"

# --- Text cleaning helper ---
def clean_text(text):
    """Normalize mention/URL/e-mail markers and whitespace in a raw post.

    Non-string input (e.g. NaN) yields an empty string. ``@handle`` mentions
    and the dataset's own placeholder markers are mapped to bracketed tokens,
    then every whitespace run is collapsed to a single space and the result
    is stripped.
    """
    if not isinstance(text, str):
        return ""

    cleaned = re.sub(r'@[\w_]+', '[USER_MENTION]', text)

    # Literal markers and their replacement tokens, applied in this order.
    marker_to_token = (
        ('<user>', '[USER_MENTION]'),
        ('<tuser>', '[USER_MENTION]'),
        ('<url>', '[URL]'),
        ('<email>', '[EMAIL]'),
        ('HTTPURL________________', '[URL]'),
    )
    for marker, token in marker_to_token:
        cleaned = cleaned.replace(marker, token)

    collapsed = re.sub(r'\s+', ' ', cleaned)
    return collapsed.strip()

# --- Mean pooling helper (optional SimCSE pooling strategy) ---
def mean_pooling(hidden_state, attention_mask):
    """Average ``hidden_state`` along dim 1, weighting each position by ``attention_mask``.

    The mask is broadcast over the last dimension, so positions whose mask is
    0 contribute nothing. The denominator is clamped to 1e-9 so a fully masked
    row does not divide by zero.
    """
    weights = attention_mask.unsqueeze(-1).expand(hidden_state.size()).float()
    weighted_sum = (hidden_state * weights).sum(1)
    token_count = weights.sum(1).clamp(min=1e-9)  # avoid division by zero
    return weighted_sum / token_count

# --- 1. Initial data loading and preparation ---
print("--- Initial Data Loading and Cleaning ---")
try:
    # Rows without text cannot be tokenized, so they are dropped up front.
    train_df_full = pd.read_csv(TRAIN_CSV).dropna(subset=['text'])
    dev_df_full = pd.read_csv(DEV_CSV).dropna(subset=['text'])
    print(f"Loaded {len(train_df_full)} train and {len(dev_df_full)} dev examples.")
except FileNotFoundError as e:
    print(f"Error loading CSV files: {e}")
    # A bare exit() reports success (status 0) and, inside a notebook, shuts
    # the kernel down. Raising SystemExit(1) signals the failure and only
    # aborts the current cell / script.
    raise SystemExit(1) from e

# Apply the same normalization to both splits so train and dev text share one format.
train_df_full['text'] = train_df_full['text'].apply(clean_text)
dev_df_full['text'] = dev_df_full['text'].apply(clean_text)
print("Text cleaning complete.")

# --- PHASE 1: SIMCSE PRE-TRAINING ---

# Custom Trainer implementing the unsupervised SimCSE objective
class SimCSETrainer(Trainer):
    """Trainer that optimizes a contrastive (InfoNCE) loss over dropout-augmented pairs.

    Each batch is encoded twice; because the model is in training mode the two
    passes use different dropout masks, so every sentence yields two slightly
    different embeddings. The two embeddings of one sentence form the positive
    pair; every other embedding in the batch acts as a negative.

    Args:
        temperature: divisor applied to the cosine similarities.
        use_mean_pooling: pool with the attention-masked mean of the last
            hidden state when True, with the first token otherwise.
    """

    def __init__(self, *args, temperature=0.05, use_mean_pooling=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.temperature = temperature
        self.use_mean_pooling = use_mean_pooling

    def _pool(self, last_hidden_state, attention_mask):
        """Reduce token states to one sentence embedding per example."""
        if self.use_mean_pooling:
            return mean_pooling(last_hidden_state, attention_mask)
        return last_hidden_state[:, 0]

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the symmetric InfoNCE loss for one batch.

        When ``return_outputs`` is True, a dict with the two pooled embedding
        tensors ("embeddings1", "embeddings2") is returned alongside the loss.
        """
        input_ids = inputs.get("input_ids")
        attention_mask = inputs.get("attention_mask")

        # Two passes over the same inputs. Only last_hidden_state is used, so
        # the per-layer hidden states are not requested.
        outputs1 = model(input_ids=input_ids, attention_mask=attention_mask)
        outputs2 = model(input_ids=input_ids, attention_mask=attention_mask)

        pooler_output1 = self._pool(outputs1.last_hidden_state, attention_mask)
        pooler_output2 = self._pool(outputs2.last_hidden_state, attention_mask)

        # Stack both views and L2-normalize so the dot product is cosine similarity.
        embeddings = torch.cat([pooler_output1, pooler_output2], dim=0)
        embeddings = F.normalize(embeddings, p=2, dim=1)

        # (2*batch, 2*batch) similarity matrix, scaled by the temperature.
        logits = torch.mm(embeddings, embeddings.t()) / self.temperature

        # Exclude each embedding's similarity with itself. -inf is representable
        # in every floating dtype (a large finite constant such as -9e15 is not
        # in fp16) and gives exactly zero probability after softmax.
        batch_size = pooler_output1.size(0)
        self_mask = torch.eye(2 * batch_size, device=embeddings.device, dtype=torch.bool)
        logits = logits.masked_fill(self_mask, float("-inf"))

        # Row i (first view) must pick column i + batch (second view), and vice versa.
        positions = torch.arange(batch_size, device=embeddings.device)
        targets = torch.cat([positions + batch_size, positions], dim=0)

        # Both halves have the same number of rows, so the mean over all
        # 2*batch rows equals the average of the two per-view losses.
        loss = F.cross_entropy(logits, targets)

        return (loss, {"embeddings1": pooler_output1, "embeddings2": pooler_output2}) if return_outputs else loss

# Path the classification phase loads its encoder from; overwritten below when SimCSE runs.
model_load_path = BASE_MODEL_NAME # Default path if SimCSE is skipped

if DO_SIMCSE_PRETRAINING:
    print("\n--- Phase 1: Starting SimCSE Pre-training ---")
    run_name_simcse = f"simcse_{'meanpool_' if SIMCSE_USE_MEAN_POOLING else 'cls_'}{BASE_MODEL_NAME}"
    try:
        wandb.init(project=WANDB_PROJECT_NAME, name=run_name_simcse, reinit=True)
        wandb.config.update({ # Log the SimCSE configuration
            "simcse_model": BASE_MODEL_NAME,
            "simcse_epochs": SIMCSE_NUM_EPOCHS,
            "simcse_batch_size": SIMCSE_BATCH_SIZE,
            "simcse_lr": SIMCSE_LEARNING_RATE,
            "simcse_temp": SIMCSE_TEMP,
            "simcse_pooling": "mean" if SIMCSE_USE_MEAN_POOLING else "cls",
            "simcse_warmup_ratio": SIMCSE_WARMUP_RATIO,
            "max_length": MAX_LENGTH
        })
    except Exception as e:
        # WandB is optional: a failed init only disables logging for this phase.
        print(f"WandB initialization failed for SimCSE phase: {e}")
        print("Proceeding without WandB logging for this phase.")


    # Build the SimCSE corpus from the text column of both splits (train + dev).
    # NOTE(review): dev texts (without labels) take part in this pre-training;
    # confirm that this is acceptable for the reported dev-set scores.
    all_texts_df = pd.concat([train_df_full[['text']], dev_df_full[['text']]], ignore_index=True)
    simcse_dataset = Dataset.from_pandas(all_texts_df)
    print(f"Created SimCSE dataset with {len(simcse_dataset)} examples.")

    # Tokenizer (loaded here for phase 1)
    print(f"Loading tokenizer {BASE_MODEL_NAME} for SimCSE phase...")
    tokenizer_simcse = AutoTokenizer.from_pretrained(BASE_MODEL_NAME) # Phase-specific variable

    def tokenize_simcse(examples):
        """Tokenize a batch of texts, padded/truncated to MAX_LENGTH."""
        # Uses tokenizer_simcse defined in the enclosing scope
        return tokenizer_simcse(examples["text"], padding="max_length", truncation=True, max_length=MAX_LENGTH)

    tokenized_simcse_dataset = simcse_dataset.map(tokenize_simcse, batched=True, remove_columns=["text"], num_proc=1) # Use more processes if possible
    tokenized_simcse_dataset.set_format("torch")
    print("SimCSE dataset tokenized.")

    # Load the bare encoder (AutoModel, no task head)
    simcse_model = AutoModel.from_pretrained(BASE_MODEL_NAME)
    print("SimCSE base model loaded.")

    # SimCSE training arguments (with warmup)
    simcse_training_args = TrainingArguments(
        output_dir=SIMCSE_OUTPUT_DIR,
        num_train_epochs=SIMCSE_NUM_EPOCHS,
        per_device_train_batch_size=SIMCSE_BATCH_SIZE,
        learning_rate=SIMCSE_LEARNING_RATE,
        weight_decay=0.01,
        logging_dir=f'{SIMCSE_OUTPUT_DIR}/logs',
        logging_steps=SIMCSE_LOGGING_STEPS,
        save_strategy="epoch",
        report_to="wandb" if wandb.run is not None else "none", # Only report when a run is active
        fp16=torch.cuda.is_available(), # Needed to fit in 8 GB of VRAM
        warmup_ratio=SIMCSE_WARMUP_RATIO, # Learning-rate warmup
    )

    # Instantiate the SimCSE Trainer (with the pooling option)
    simcse_trainer = SimCSETrainer(
        model=simcse_model,
        args=simcse_training_args,
        train_dataset=tokenized_simcse_dataset,
        tokenizer=tokenizer_simcse, # Pass the phase-specific tokenizer
        temperature=SIMCSE_TEMP,
        use_mean_pooling=SIMCSE_USE_MEAN_POOLING # Forward the pooling option
    )

    # Run training
    print("Starting SimCSE training...")
    simcse_trainer.train()
    print("SimCSE training finished.")

    # Save model and tokenizer side by side so phase 2 can load both from one directory
    simcse_trainer.save_model(SIMCSE_OUTPUT_DIR)
    tokenizer_simcse.save_pretrained(SIMCSE_OUTPUT_DIR) # Save the tokenizer that was used
    print(f"SimCSE pre-trained model and tokenizer saved to {SIMCSE_OUTPUT_DIR}")

    # Clean up: release phase-1 objects and cached GPU memory before phase 2
    model_load_path = SIMCSE_OUTPUT_DIR # Point phase 2 at the pre-trained encoder
    del simcse_model, simcse_trainer, tokenized_simcse_dataset, simcse_dataset, all_texts_df, tokenizer_simcse
    gc.collect()
    torch.cuda.empty_cache()
    if wandb.run is not None: wandb.finish()
    print("--- Phase 1: SimCSE Pre-training Complete ---")

else:
    print("\n--- Phase 1: Skipping SimCSE Pre-training ---")
    # model_load_path stays BASE_MODEL_NAME (set above)


# --- PHASE 2: CLASSIFICATION FINE-TUNING ---
print("\n--- Phase 2: Starting Classification Fine-tuning ---")
run_name_classify = f"classify_ft_on_{'simcse' if DO_SIMCSE_PRETRAINING else 'base'}_{BASE_MODEL_NAME}"
try:
    # reinit=True is passed because phase 1 may already have used a run in this process.
    wandb.init(project=WANDB_PROJECT_NAME, name=run_name_classify, reinit=True)
    wandb.config.update({ # Log the classification configuration
        "base_model_for_ft": model_load_path,
        "classify_epochs": CLASSIFICATION_NUM_EPOCHS,
        "classify_batch_size": CLASSIFICATION_BATCH_SIZE,
        "classify_grad_accum": CLASSIFICATION_GRAD_ACCUM_STEPS,
        "classify_effective_batch": CLASSIFICATION_BATCH_SIZE * CLASSIFICATION_GRAD_ACCUM_STEPS,
        "classify_lr": CLASSIFICATION_LEARNING_RATE,
        "classify_warmup_ratio": CLASSIFICATION_WARMUP_RATIO,
        "classify_early_stopping": CLASSIFICATION_EARLY_STOPPING_PATIENCE,
        "max_length": MAX_LENGTH
    })
except Exception as e:
        # WandB is optional: a failed init only disables logging for this phase.
        print(f"WandB initialization failed for Classification phase: {e}")
        print("Proceeding without WandB logging for this phase.")


# Build the classification datasets.
# preserve_index=False: after dropna() the pandas index may have gaps, in which
# case from_pandas() would otherwise store it as an extra "__index_level_0__" column.
train_dataset_cls = Dataset.from_pandas(train_df_full, preserve_index=False)
dev_dataset_cls = Dataset.from_pandas(dev_df_full, preserve_index=False)
dataset_dict_cls = DatasetDict({'train': train_dataset_cls, 'validation': dev_dataset_cls})
print("Classification datasets created.")

# Always (re)load the tokenizer for phase 2 from the location of the model being
# fine-tuned: the SimCSE output directory when phase 1 ran, the base model otherwise.
print(f"Loading tokenizer for classification phase from: {model_load_path}")
tokenizer = AutoTokenizer.from_pretrained(model_load_path)

def tokenize_classification(examples):
    """Tokenize a batch of texts, padded/truncated to MAX_LENGTH."""
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=MAX_LENGTH)

# Metadata columns to drop once the text has been tokenized.
_metadata_columns = ["text", "id", "file_name", "origin", "language", "split", "type"]

# Map each split on its own and drop only the columns it actually has:
# map(remove_columns=...) raises on a missing column, and the dev CSV is not
# guaranteed to carry every column of the augmented training CSV.
tokenized_datasets_cls = DatasetDict({
    split_name: split_ds.map(
        tokenize_classification,
        batched=True,
        remove_columns=[col for col in _metadata_columns if col in split_ds.column_names],
        num_proc=1,
    )
    for split_name, split_ds in dataset_dict_cls.items()
})
tokenized_datasets_cls.set_format("torch")
tokenized_datasets_cls = tokenized_datasets_cls.rename_column("label", "labels")
print("Classification datasets tokenized.")

# Compute class weights
print("Computing class weights for classification...")
labels_train_cls = train_df_full['label'].values
class_weights_tensor_cls = None  # stays None when only one class is present
unique_labels_cls = np.unique(labels_train_cls)
num_distinct_labels = len(unique_labels_cls)
print(f"Detected {num_distinct_labels} distinct labels in training data: {unique_labels_cls}")

if num_distinct_labels > 1:
    class_weights_cls = compute_class_weight(class_weight='balanced', classes=unique_labels_cls, y=labels_train_cls)
    # Map each label to its weight, then read the weights back in the order of
    # unique_labels_cls. This round trip reproduces the order of class_weights_cls.
    # NOTE(review): position i of the tensor lines up with logit i only if the
    # labels are exactly 0..n-1 - confirm for any new dataset.
    ordered_weights_dict = {label: weight for label, weight in zip(unique_labels_cls, class_weights_cls)}
    # One weight per detected label, so the tensor has num_distinct_labels entries
    ordered_weights_cls = np.array([ordered_weights_dict.get(i, 0) for i in unique_labels_cls]) # Weights for the labels that exist

    class_weights_tensor_cls = torch.tensor(ordered_weights_cls, dtype=torch.float).to("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Class Weights (for classes {unique_labels_cls}): {class_weights_tensor_cls.cpu().numpy()}")
    if wandb.run: wandb.config.update({"class_weights": class_weights_tensor_cls.cpu().numpy().tolist()})
else:
    print("Warning: Only one class found in training data. Cannot compute class weights.")


# Custom classification Trainer with a class-weighted loss
class WeightedClassificationTrainer(Trainer):
    """Trainer computing cross-entropy with optional per-class weights.

    Args:
        class_weights: tensor of per-class weights, or None for an unweighted
            loss. It is moved to the trainer's device once, at construction.
    """

    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        if class_weights is None:
            self.class_weights = None
        else:
            # Move once here so compute_loss does not transfer on every step.
            self.class_weights = class_weights.to(self.args.device)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the (optionally weighted) cross-entropy loss for one batch."""
        targets = inputs.get("labels")
        model_outputs = model(**inputs)
        scores = model_outputs.get("logits")

        # weight=None is CrossEntropyLoss's default, i.e. the unweighted loss.
        criterion = torch.nn.CrossEntropyLoss(weight=self.class_weights)
        flat_scores = scores.view(-1, self.model.config.num_labels)
        loss = criterion(flat_scores, targets.view(-1))

        if return_outputs:
            return loss, model_outputs
        return loss

# Load the model for classification (encoder weights come from model_load_path)
print(f"Loading model for classification from: {model_load_path}")
classification_model = AutoModelForSequenceClassification.from_pretrained(
    model_load_path,
    num_labels=num_distinct_labels, # Use the number of labels detected in the training data
    ignore_mismatched_sizes=True # Intended for loading from the head-less SimCSE (AutoModel) checkpoint
)
print("Classification model loaded.")

# Metric callback for the classification Trainer
def compute_metrics_cls(pred):
    """Compute evaluation metrics from a prediction object.

    With exactly two training labels, reports accuracy plus precision/recall/F1
    of the positive class (1) and F1 of the negative class (0). Otherwise
    reports accuracy with macro- and weighted-averaged scores over the labels
    detected in the training data (``unique_labels_cls``).
    """
    gold = pred.label_ids
    predicted = pred.predictions.argmax(-1)
    metric_labels = unique_labels_cls  # labels detected in the training data

    if num_distinct_labels != 2:
        # Multi-class case: macro and weighted averages
        precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
            gold, predicted, average='macro', labels=metric_labels, zero_division=0
        )
        _, _, f1_weighted, _ = precision_recall_fscore_support(
            gold, predicted, average='weighted', labels=metric_labels, zero_division=0
        )
        # Per-class F1 can be added with average=None over metric_labels if needed.
        return {
            'accuracy': accuracy_score(gold, predicted),
            'f1_macro': f1_macro,
            'precision_macro': precision_macro,
            'recall_macro': recall_macro,
            'f1_weighted': f1_weighted,
        }

    # Binary case (positive class = 1): one value per entry of labels=[0, 1]
    precision, recall, f1, _ = precision_recall_fscore_support(
        gold, predicted, average=None, labels=[0, 1], zero_division=0
    )
    has_pos = len(f1) > 1
    return {
        'accuracy': accuracy_score(gold, predicted),
        'f1_pos': f1[1] if has_pos else 0,
        'precision_pos': precision[1] if len(precision) > 1 else 0,
        'recall_pos': recall[1] if len(recall) > 1 else 0,
        'f1_neg': f1[0] if len(f1) > 0 else 0,
    }

# Training arguments for the classification phase (with learning-rate warmup).
# Evaluation and checkpointing both happen once per epoch so that the best
# checkpoint can be reloaded at the end of training.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` — confirm against the installed version before upgrading.
classification_training_args = TrainingArguments(
    output_dir=CLASSIFICATION_OUTPUT_DIR,
    num_train_epochs=CLASSIFICATION_NUM_EPOCHS,
    per_device_train_batch_size=CLASSIFICATION_BATCH_SIZE,
    per_device_eval_batch_size=CLASSIFICATION_BATCH_SIZE * 2,
    gradient_accumulation_steps=CLASSIFICATION_GRAD_ACCUM_STEPS,
    learning_rate=CLASSIFICATION_LEARNING_RATE,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # Pick the metric that load_best_model_at_end should rank checkpoints by
    metric_for_best_model="f1_pos" if num_distinct_labels == 2 else "f1_macro",
    greater_is_better=True,
    logging_dir=f'{CLASSIFICATION_OUTPUT_DIR}/logs',
    logging_steps=CLASSIFICATION_LOGGING_STEPS,
    report_to="wandb" if wandb.run is not None else "none", # Only report to wandb when a run is active
    fp16=torch.cuda.is_available(), # Essential on 8 GB of VRAM
    warmup_ratio=CLASSIFICATION_WARMUP_RATIO,
    save_total_limit=2,
)

# Instantiate the classification Trainer (the class weights are passed here).
# NOTE(review): in the lines visible above, class_weights_tensor_cls is only
# assigned when more than one class is found — confirm it is initialised
# (e.g. to None) earlier, otherwise the single-class path fails here.
classification_trainer = WeightedClassificationTrainer(
    model=classification_model,
    args=classification_training_args,
    train_dataset=tokenized_datasets_cls["train"],
    eval_dataset=tokenized_datasets_cls["validation"],
    tokenizer=tokenizer, # Use the tokenizer loaded for phase 2
    compute_metrics=compute_metrics_cls,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=CLASSIFICATION_EARLY_STOPPING_PATIENCE)],
    class_weights=class_weights_tensor_cls # Pass the class-weight tensor
)
print("Classification Trainer configured.")

# Launch the fine-tuning
print("Starting classification fine-tuning...")
classification_trainer.train()
print("Classification fine-tuning finished.")

# Save the best model explicitly, in a fixed location next to the checkpoints
best_model_path = os.path.join(CLASSIFICATION_OUTPUT_DIR, "best_model")
classification_trainer.save_model(best_model_path)
tokenizer.save_pretrained(best_model_path) # Save the tokenizer alongside the best model
print(f"Best classification model and tokenizer saved to {best_model_path}")

# --- Detailed evaluation and submission (uses the model held by the Trainer after training) ---
print("\n--- Detailed Evaluation and Submission File Generation ---")
print("\nGenerating predictions for Threshold Adjustment and Detailed Metrics...")

predictions_output = classification_trainer.predict(tokenized_datasets_cls["validation"])
logits = predictions_output.predictions
true_labels = predictions_output.label_ids

if logits.shape[-1] != num_distinct_labels:
    # Raise instead of calling exit(): exit() is the interactive-shell helper
    # and would end the process with a *success* status on this error path.
    raise ValueError(f"Logits shape {logits.shape} unexpected for {num_distinct_labels} labels.")

probabilities = None
predicted_labels_final = None
best_threshold = None

if num_distinct_labels == 2:
    probabilities = softmax(logits, axis=-1)[:, 1] # Probability of the positive class (index 1)
    print("\nFinding best threshold on validation set based on Overall F1-Positive...")
    # NOTE: the threshold is selected on the same validation split that the
    # metrics below are reported on, so those numbers are optimistic.
    best_f1 = -1
    best_threshold = 0.5 # Default
    thresholds = np.arange(0.1, 0.91, 0.01)
    f1_scores_thresh = []
    for threshold in thresholds:
        predicted_labels_thresh = (probabilities >= threshold).astype(int)
        _, _, f1_thresh, _ = precision_recall_fscore_support(
            true_labels, predicted_labels_thresh, average='binary', pos_label=1, zero_division=0)
        f1_scores_thresh.append(f1_thresh)
        # Strict '>' keeps the lowest threshold among ties.
        if f1_thresh > best_f1:
            best_f1 = f1_thresh
            best_threshold = threshold

    print(f"\nBest threshold found: {best_threshold:.2f} with Overall F1-Pos: {best_f1:.4f}")
    if wandb.run: wandb.log({"eval/best_threshold": best_threshold, "eval/best_val_f1_at_threshold": best_f1})
    predicted_labels_final = (probabilities >= best_threshold).astype(int)
else:
    print("Multi-class classification detected (>2). Using argmax for final predictions.")
    predicted_labels_final = logits.argmax(-1)
    # best_threshold stays None: downstream code uses that to tell argmax from thresholding
# Build the DataFrame used for the detailed evaluation.
# reset_index gives a fresh 0..n-1 index so the positional prediction arrays line up.
dev_df_eval = dev_df_full.reset_index(drop=True)
if len(dev_df_eval) != len(predicted_labels_final):
    # Raise instead of calling exit(): a silent, success-status exit would hide the problem.
    raise ValueError(
        f"Length mismatch between dev_df {len(dev_df_eval)} and predictions {len(predicted_labels_final)}!"
    )

dev_df_eval['predicted_label'] = predicted_labels_final
if probabilities is not None:
    dev_df_eval['probability_positive'] = probabilities

# Derive the language from dev_df_eval's OWN "id" column. Assigning a Series
# built from dev_df_full would be aligned on dev_df_full's original index,
# which no longer matches after reset_index if that index was not 0..n-1.
dev_df_eval["language"] = dev_df_eval["id"].apply(
    lambda x: x.split("_")[0] if isinstance(x, str) and "_" in x else "unknown"
)

languages = sorted(dev_df_eval['language'].unique())
language_f1_scores_pos = []  # Positive-class F1 per language (binary case only)
wandb_logs_eval = {}
print("\n--- Detailed Evaluation on Development Set (Final Predictions) ---")
if best_threshold is not None:
    print(f"--- (Using Threshold = {best_threshold:.2f}) ---")

# Loop invariants: label set used for the metrics and how predictions were
# produced (tuned threshold vs. argmax), which becomes part of the wandb keys.
metric_labels = unique_labels_cls
decision_mode = "thresh" if best_threshold is not None else "argmax"

for lang in languages:
    if lang == "unknown":
        continue

    lang_rows = dev_df_eval[dev_df_eval['language'] == lang]
    y_true_lang = lang_rows['label'].tolist()
    y_pred_lang_final = lang_rows['predicted_label'].tolist()
    if not y_true_lang:
        continue

    # Per-class scores restricted to the detected labels
    precision_lang, recall_lang, f1_lang, support_lang = precision_recall_fscore_support(
        y_true_lang, y_pred_lang_final, average=None, labels=metric_labels, zero_division=0)
    accuracy_lang = accuracy_score(y_true_lang, y_pred_lang_final)

    print(f"\nMetrics for language: {lang.upper()} (Support: {dict(zip(metric_labels, support_lang))})")
    wandb_key_prefix = f"eval/{lang}/{decision_mode}"

    if num_distinct_labels == 2:
        # Position 1 in the per-class arrays is the second label (the positive class)
        f1_pos_lang = f1_lang[1]
        language_f1_scores_pos.append(f1_pos_lang)
        print(f"  Precision (Pos/1): {precision_lang[1]:.4f}")
        print(f"  Recall    (Pos/1): {recall_lang[1]:.4f}")
        print(f"  F1        (Pos/1): {f1_pos_lang:.4f}")
        print(f"  Accuracy:          {accuracy_lang:.4f}")
        wandb_logs_eval.update({
            f"{wandb_key_prefix}/precision_pos": precision_lang[1],
            f"{wandb_key_prefix}/recall_pos": recall_lang[1],
            f"{wandb_key_prefix}/f1_pos": f1_pos_lang,
            f"{wandb_key_prefix}/accuracy": accuracy_lang,
        })
    else:
        f1_macro_lang = np.mean(f1_lang)  # Plain unweighted mean of the per-class F1
        print(f"  F1-Macro:         {f1_macro_lang:.4f}")
        print(f"  Accuracy:         {accuracy_lang:.4f}")
        wandb_logs_eval.update({
            f"{wandb_key_prefix}/f1_macro": f1_macro_lang,
            f"{wandb_key_prefix}/accuracy": accuracy_lang,
        })
        # Also log the F1 of every individual class
        wandb_logs_eval.update({
            f"{wandb_key_prefix}/f1_class_{label}": class_f1
            for label, class_f1 in zip(metric_labels, f1_lang)
        })


# Final overall metrics over the whole evaluation set
cm_overall_final = confusion_matrix(true_labels, predicted_labels_final, labels=unique_labels_cls)
overall_accuracy_final = accuracy_score(true_labels, predicted_labels_final)
wandb_key_prefix_overall = "eval/overall" + ("/thresh" if best_threshold is not None else "/argmax")

print("\n--- Overall Evaluation Summary (Final Predictions) ---")
if num_distinct_labels == 2:
    # In the binary case the matrix is expected to hold exactly 4 cells
    if cm_overall_final.size == 4:
        tn, fp, fn, tp = cm_overall_final.ravel()
    else:
        # Defensive fallback for an unexpected matrix shape
        tn, fp, fn, tp = 0, 0, 0, 0
        print("Warning: Confusion matrix size indicates potential missing classes in evaluation.")
        if 0 in unique_labels_cls and 1 in unique_labels_cls:
            tn = cm_overall_final[0, 0]
            fp = cm_overall_final[0, 1]
            fn = cm_overall_final[1, 0]
            tp = cm_overall_final[1, 1]

    overall_precision_pos = tp / (tp + fp) if (tp + fp) > 0 else 0
    overall_recall_pos = tp / (tp + fn) if (tp + fn) > 0 else 0
    overall_f1_pos = 2 * (overall_precision_pos * overall_recall_pos) / (overall_precision_pos + overall_recall_pos) if (overall_precision_pos + overall_recall_pos) > 0 else 0
    # Unweighted mean of the per-language positive-class F1 scores
    macro_f1_pos = np.mean(language_f1_scores_pos) if language_f1_scores_pos else 0

    print(f"Overall F1-score (Positive Class): {overall_f1_pos:.4f}  <-- Primary Metric")
    print(f"Macro F1-score (Pos Class / Lang): {macro_f1_pos:.4f}")
    print(f"Overall Precision (Positive Class):{overall_precision_pos:.4f}")
    print(f"Overall Recall (Positive Class):   {overall_recall_pos:.4f}")
    print(f"Overall Accuracy:                  {overall_accuracy_final:.4f}")

    wandb_logs_eval[f"{wandb_key_prefix_overall}/f1_pos"] = overall_f1_pos
    wandb_logs_eval[f"{wandb_key_prefix_overall}/macro_f1_pos_lang"] = macro_f1_pos
    wandb_logs_eval[f"{wandb_key_prefix_overall}/precision_pos"] = overall_precision_pos
    wandb_logs_eval[f"{wandb_key_prefix_overall}/recall_pos"] = overall_recall_pos
    wandb_logs_eval[f"{wandb_key_prefix_overall}/accuracy"] = overall_accuracy_final
else:
    # Multi-class: macro and weighted averages over the detected labels
    overall_prec_macro, overall_recall_macro, overall_f1_macro, _ = precision_recall_fscore_support(true_labels, predicted_labels_final, average='macro', labels=unique_labels_cls, zero_division=0)
    overall_prec_weighted, overall_recall_weighted, overall_f1_weighted, _ = precision_recall_fscore_support(true_labels, predicted_labels_final, average='weighted', labels=unique_labels_cls, zero_division=0)
    print(f"Overall Accuracy:     {overall_accuracy_final:.4f}")
    print(f"Overall F1 (Macro):   {overall_f1_macro:.4f}")
    print(f"Overall F1 (Weighted):{overall_f1_weighted:.4f}")

    wandb_logs_eval[f"{wandb_key_prefix_overall}/accuracy"] = overall_accuracy_final
    wandb_logs_eval[f"{wandb_key_prefix_overall}/f1_macro"] = overall_f1_macro
    wandb_logs_eval[f"{wandb_key_prefix_overall}/f1_weighted"] = overall_f1_weighted

# The confusion matrix is reported the same way for both cases
print("\nOverall Confusion Matrix (Final):")
print(f"Predicted Labels: {unique_labels_cls}")
print("True Labels")
print(cm_overall_final)

if wandb.run:
    # wandb.Table ignores `rows` whenever `data` is given (it is a legacy alias
    # for `data`), so the true-label names are stored as an explicit first column.
    cm_table_columns = ["True_Label"] + [f"Pred_{l}" for l in unique_labels_cls]
    cm_table_data = [
        [f"True_{label}", *counts]
        for label, counts in zip(unique_labels_cls, cm_overall_final.tolist())
    ]
    wandb_logs_eval[f"{wandb_key_prefix_overall}/conf_matrix"] = wandb.Table(columns=cm_table_columns, data=cm_table_data)

# Log all detailed eval metrics in a single call
if wandb.run: wandb.log(wandb_logs_eval)

# --- Save the submission file ---
# Writes the (id, predicted_label) pairs to a CSV inside CLASSIFICATION_OUTPUT_DIR
# and then wraps that CSV in a zip archive in the same directory.
print("\nSaving predictions for submission...")
os.makedirs(CLASSIFICATION_OUTPUT_DIR, exist_ok=True)
submission_df = dev_df_eval[['id', 'predicted_label']]
# File names encode how the model was trained and how predictions were decided
suffix = "simcse_finetuned" if DO_SIMCSE_PRETRAINING else "base_finetuned"
thresh_suffix = f"_thresh{best_threshold:.2f}" if best_threshold is not None else "_argmax"
csv_filename = f"predictions_task1_{suffix}{thresh_suffix}.csv"
zip_filename = f"submission_task1_{suffix}{thresh_suffix}.zip"
csv_path = os.path.join(CLASSIFICATION_OUTPUT_DIR, csv_filename)
zip_path = os.path.join(CLASSIFICATION_OUTPUT_DIR, zip_filename)

submission_df.to_csv(csv_path, index=False)
print(f"Predictions saved to {csv_path}")

# Zipping is best-effort: a failure is reported but does not stop the script,
# the CSV written above is still available.
try:
    with zipfile.ZipFile(zip_path, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
        # arcname keeps only the bare file name inside the archive (no directory prefix)
        zf.write(csv_path, arcname=csv_filename)
    print(f"{csv_filename} has been zipped into {zip_path}")
except Exception as e:
    print(f"Error zipping the file: {e}")


if wandb.run is not None and wandb.run.step > 0: # Check if wandb was used and logged something
    wandb.finish()
print("\nScript finished.")


--- Initial Data Loading and Cleaning ---
Loaded 33482 train and 4625 dev examples.
Text cleaning complete.

--- Phase 1: Starting SimCSE Pre-training ---


wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: caron-olivier-80 (caron-olivier-80-universit-paris-dauphine-psl) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


Created SimCSE dataset with 38107 examples.
Loading tokenizer xlm-roberta-base for SimCSE phase...


Map:   0%|          | 0/38107 [00:00<?, ? examples/s]

SimCSE dataset tokenized.
SimCSE base model loaded.


  super().__init__(*args, **kwargs)


Starting SimCSE training...




Step,Training Loss
50,2.5285
100,0.7105
150,0.1193
200,0.0314
250,0.0274
300,0.0147
350,0.0081
400,0.0053
450,0.0012
500,0.0004


SimCSE training finished.
SimCSE pre-trained model and tokenizer saved to results_LLMTRADPARAPHRASE_simcse_xlmr_base


0,1
train/epoch,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇█
train/global_step,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇███
train/grad_norm,▃█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/learning_rate,▂▃▅▆▇█████▇▇▇▇▇▆▆▆▆▅▅▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▁▁▁▁
train/loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
total_flos,3838151985260544.0
train/epoch,1.0
train/global_step,4764.0
train/grad_norm,0.01793
train/learning_rate,0.0
train/loss,0.0
train_loss,0.03691
train_runtime,974.7083
train_samples_per_second,39.096
train_steps_per_second,4.888


--- Phase 1: SimCSE Pre-training Complete ---

--- Phase 2: Starting Classification Fine-tuning ---


Classification datasets created.
Loading tokenizer for classification phase from: results_LLMTRADPARAPHRASE_simcse_xlmr_base


Map:   0%|          | 0/33482 [00:00<?, ? examples/s]

Map:   0%|          | 0/4625 [00:00<?, ? examples/s]

Classification datasets tokenized.
Computing class weights for classification...
Detected 2 distinct labels in training data: [0 1]
Class Weights (for classes [0 1]): [0.5825793 3.5273914]
Loading model for classification from: results_LLMTRADPARAPHRASE_simcse_xlmr_base


Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at results_LLMTRADPARAPHRASE_simcse_xlmr_base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Classification model loaded.


  super().__init__(*args, **kwargs)


Classification Trainer configured.
Starting classification fine-tuning...


Epoch,Training Loss,Validation Loss,Accuracy,F1 Pos,Precision Pos,Recall Pos,F1 Neg
0,0.2805,0.293364,0.914595,0.596527,0.502582,0.733668,0.952243
1,0.2009,0.484181,0.939459,0.559748,0.747899,0.447236,0.967495
2,0.1724,0.349853,0.932973,0.640371,0.594828,0.693467,0.963042
3,0.0869,0.675381,0.938162,0.619681,0.658192,0.585427,0.966345
4,0.0871,0.764224,0.940108,0.628188,0.674352,0.58794,0.967431
5,0.0485,0.905898,0.943351,0.633053,0.71519,0.567839,0.969306


Classification fine-tuning finished.
Best classification model and tokenizer saved to results_augmented_data_classifier_finetuned_on_simcse\best_model

--- Detailed Evaluation and Submission File Generation ---

Generating predictions for Threshold Adjustment and Detailed Metrics...



Finding best threshold on validation set based on Overall F1-Positive...

Best threshold found: 0.74 with Overall F1-Pos: 0.6591

--- Detailed Evaluation on Development Set (Final Predictions) ---
--- (Using Threshold = 0.74) ---

Metrics for language: DE (Support: {0: 599, 1: 35})
  Precision (Pos/1): 0.4043
  Recall    (Pos/1): 0.5429
  F1        (Pos/1): 0.4634
  Accuracy:          0.9306

Metrics for language: EN (Support: {0: 841, 1: 61})
  Precision (Pos/1): 0.7925
  Recall    (Pos/1): 0.6885
  F1        (Pos/1): 0.7368
  Accuracy:          0.9667

Metrics for language: FR (Support: {0: 389, 1: 30})
  Precision (Pos/1): 0.5000
  Recall    (Pos/1): 0.7000
  F1        (Pos/1): 0.5833
  Accuracy:          0.9284

Metrics for language: RU (Support: {0: 2398, 1: 272})
  Precision (Pos/1): 0.7016
  Recall    (Pos/1): 0.6654
  F1        (Pos/1): 0.6830
  Accuracy:          0.9371

--- Overall Evaluation Summary (Final Predictions) ---
Overall F1-score (Positive Class): 0.6591  <-- Prim

0,1
eval/accuracy,▁▇▅▇▇█
eval/best_threshold,▁
eval/best_val_f1_at_threshold,▁
eval/de/thresh/accuracy,▁
eval/de/thresh/f1_pos,▁
eval/de/thresh/precision_pos,▁
eval/de/thresh/recall_pos,▁
eval/en/thresh/accuracy,▁
eval/en/thresh/f1_pos,▁
eval/en/thresh/precision_pos,▁

0,1
eval/accuracy,0.94335
eval/best_threshold,0.74
eval/best_val_f1_at_threshold,0.65915
eval/de/thresh/accuracy,0.9306
eval/de/thresh/f1_pos,0.46341
eval/de/thresh/precision_pos,0.40426
eval/de/thresh/recall_pos,0.54286
eval/en/thresh/accuracy,0.96674
eval/en/thresh/f1_pos,0.73684
eval/en/thresh/precision_pos,0.79245



Script finished.
