In [None]:
from transformers import (
    AutoTokenizer, 
    AutoModelForTokenClassification, 
    Trainer, 
    TrainingArguments, 
    EarlyStoppingCallback
)
import numpy as np
import os
import torch
from sklearn.model_selection import train_test_split
from transformers import EvalPrediction
import json
from datetime import datetime
from typing import List, Dict, Tuple
from dataclasses import dataclass
from torch.utils.data import Dataset
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix
import numpy as np
from transformers import TrainerCallback
import os
import json
import numpy as np
from typing import Dict

class MetricsCallback(TrainerCallback):
    """
    Callback to save detailed metrics after each evaluation.

    After every evaluation it:
      * writes ``<output_dir>/epoch_<state.epoch>/detailed_metrics.json`` (and
        ``confusion_matrix.npy`` there when ``eval_confusion_matrix`` is present);
      * rewrites ``<output_dir>/best_metrics.json`` whenever ``eval_propaganda_f1``
        beats the best value seen so far;
      * prints a short summary.

    Args:
        output_dir: root directory for the metric files.
        id_to_label: label id -> label name; ids are expected to be ``0..len-1``.
    """

    # Keys of the per-class lists produced by compute_metrics (Trainer adds "eval_").
    _PER_CLASS_KEYS = (
        "eval_per_class_precision",
        "eval_per_class_recall",
        "eval_per_class_f1",
        "eval_support",
    )

    def __init__(self, output_dir: str, id_to_label: Dict[int, str]):
        self.output_dir = output_dir
        self.id_to_label = id_to_label
        self.best_metrics = None
        self.best_step = None

    @staticmethod
    def _write_json(path: str, payload) -> None:
        """Dump ``payload`` as indented UTF-8 JSON to ``path``."""
        with open(path, 'w', encoding='utf8') as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)

    def _per_class_metrics(self, metrics: Dict) -> Dict[str, Dict]:
        """Return ``{label name: {precision, recall, f1, support}}``.

        Returns an empty dict when the per-class lists are missing from
        ``metrics`` instead of raising KeyError mid-training.
        """
        if not all(key in metrics for key in self._PER_CLASS_KEYS):
            return {}
        precision, recall, f1, support = (metrics[key] for key in self._PER_CLASS_KEYS)
        return {
            self.id_to_label[i]: {
                "precision": precision[i],
                "recall": recall[i],
                "f1": f1[i],
                "support": support[i],
            }
            for i in range(len(self.id_to_label))
        }

    def on_evaluate(self, args, state, control, metrics, **kwargs):
        """Called after each evaluation: persist metrics, track the best, print a summary."""
        # Create epoch-specific directory (also guarantees output_dir exists)
        epoch_dir = os.path.join(self.output_dir, f"epoch_{state.epoch}")
        os.makedirs(epoch_dir, exist_ok=True)

        precision = metrics.get("eval_propaganda_precision", 0)
        recall = metrics.get("eval_propaganda_recall", 0)
        current_f1 = metrics.get("eval_propaganda_f1", 0)
        per_class = self._per_class_metrics(metrics)

        detailed_metrics = {
            "epoch": state.epoch,
            "step": state.global_step,
            "overall_propaganda_metrics": {
                "precision": precision,
                "recall": recall,
                "f1": current_f1,
            },
            "per_class_metrics": per_class,
        }

        # Track best metrics (strictly greater, so the earliest best step is kept on ties)
        if (self.best_metrics is None
                or current_f1 > self.best_metrics["overall_propaganda_metrics"]["f1"]):
            self.best_metrics = detailed_metrics
            self.best_step = state.global_step
            self._write_json(
                os.path.join(self.output_dir, "best_metrics.json"),
                {"best_step": self.best_step, "metrics": self.best_metrics},
            )

        # Save epoch metrics
        self._write_json(os.path.join(epoch_dir, 'detailed_metrics.json'), detailed_metrics)

        # Save confusion matrix if available
        if "eval_confusion_matrix" in metrics:
            confusion_path = os.path.join(epoch_dir, 'confusion_matrix.npy')
            np.save(confusion_path, np.array(metrics["eval_confusion_matrix"]))

        # Print summary of propaganda metrics
        print("\nPropaganda Detection Metrics:")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"F1: {current_f1:.4f}")

        # Print per-class F1 scores for non-O classes present in the evaluation set
        print("\nPer-class F1 scores (excluding 'O'):")
        for class_name, class_metrics in per_class.items():
            if class_name == "O" or class_metrics["support"] <= 0:
                continue
            print(f"{class_name}: {class_metrics['f1']:.4f} (support: {class_metrics['support']})")

    def on_train_end(self, args, state, control, **kwargs):
        """Called at the end of training - print best results."""
        if self.best_metrics is not None:
            print("\nBest Model Performance:")
            print(f"Step: {self.best_step}")
            best_f1 = self.best_metrics["overall_propaganda_metrics"]["f1"]
            print(f"Best Propaganda F1: {best_f1:.4f}")

@dataclass
class TokenClassificationConfig:
    """Run configuration for sliding-window token classification.

    Besides the declared fields, ``__post_init__`` sets two derived attributes:
      * ``date_time``: the construction time formatted as ``YYYY-mm-dd-HH-MM-SS``;
      * ``full_output_dir``: ``output_dir`` joined with that timestamp, giving
        each run its own sub-directory.
    """
    model_name: str = 'microsoft/mdeberta-v3-base'
    max_length: int = 512  # tokenizer max_length; preprocess_data also uses it as the window size
    stride: int = 128      # preprocess_data advances windows by max_length - stride
    num_labels: int = 24
    output_dir: str = '/home/lgiordano/LUCA/checkthat_GITHUB/models/sliding_window'

    def __post_init__(self):
        run_stamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        self.date_time = run_stamp
        self.full_output_dir = os.path.join(self.output_dir, run_stamp)

class TokenClassificationDataset(Dataset):
    """Map-style dataset pairing tokenized features with per-token label ids.

    Args:
        tokenized_inputs: feature name -> list of per-example sequences
            (e.g. ``input_ids``, ``attention_mask``), indexed like ``labels``.
        labels: one label-id sequence per example.

    Each item is a dict of tensors: every feature plus ``"labels"``.
    """

    def __init__(self, tokenized_inputs: Dict[str, List], labels: List[List[int]]):
        self.encodings = tokenized_inputs
        self.labels = labels

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        sample = {}
        for feature_name, sequences in self.encodings.items():
            sample[feature_name] = torch.tensor(sequences[idx])
        sample["labels"] = torch.tensor(self.labels[idx])
        return sample

    def __len__(self) -> int:
        # The dataset length is defined by the label list.
        return len(self.labels)

def encode_tags(tags: List[Dict],
                token_offsets: List[Tuple[int, int]],
                label_to_id: Dict[str, int],
                max_length: int) -> List[int]:
    """Convert character-span annotations into one label id per token.

    Args:
        tags: annotations, each a dict with ``"tag"`` (label name) and
            ``"start"``/``"end"`` character offsets, treated as the half-open
            span ``[start, end)``.
        token_offsets: ``(start, end)`` character span of each token, in the
            same coordinate system as the annotation offsets.
        label_to_id: label name -> id; must contain ``"O"`` (if any real token
            is unlabelled) and every tag used.
        max_length: length of the returned list (truncated or padded to it).

    Returns:
        Exactly ``max_length`` label ids. Zero-width tokens (special tokens
        such as CLS/SEP and padding, whose offsets are ``(x, x)``) get ``-100``
        so they are ignored by the loss and by compute_metrics. Every other
        token gets the id of the last annotation overlapping it, or ``"O"``.

    Raises:
        KeyError: if a needed label name is missing from ``label_to_id``.
    """
    ignore_index = -100

    # None marks zero-width tokens: they cover no text, so they must never be
    # labelled (previously padding tokens were silently trained as "O").
    token_labels = [
        None if token_start == token_end else "O"
        for token_start, token_end in token_offsets
    ]

    for annotation in tags:
        label = annotation["tag"]
        start = annotation["start"]
        end = annotation["end"]

        # Label every real token overlapping the annotation span; a later
        # annotation overwrites an earlier one on shared tokens.
        for idx, (token_start, token_end) in enumerate(token_offsets):
            if token_labels[idx] is not None and token_start < end and token_end > start:
                token_labels[idx] = label

    # Convert string labels to ids, keeping ignored positions at -100
    label_ids = [
        ignore_index if name is None else label_to_id[name]
        for name in token_labels
    ][:max_length]

    # Pad with -100 (ignored in loss calculation) up to max_length
    label_ids += [ignore_index] * (max_length - len(label_ids))

    return label_ids

def preprocess_data(texts: List[str],
                    annotations: List[Dict],
                    tokenizer,
                    config: TokenClassificationConfig,
                    label_to_id: Dict[str, int]) -> Tuple[Dict[str, List], List[List[int]]]:
    """Split documents into overlapping windows, tokenize them and align labels.

    Each document is cut into windows of ``config.max_length`` *characters*,
    starting every ``config.max_length - config.stride`` characters (so
    consecutive windows overlap by ``config.stride`` characters). Each window
    is tokenized to exactly ``config.max_length`` tokens: padded, and truncated
    if the window needs more tokens (the excess tokens are dropped for that
    window).

    Args:
        texts: one string per document.
        annotations: per-document list of span annotations (see encode_tags).
        tokenizer: callable tokenizer that must support
            ``return_offsets_mapping=True`` (HF fast tokenizers do).
        config: provides ``max_length`` and ``stride``.
        label_to_id: label name -> id.

    Returns:
        ``({"input_ids": [...], "attention_mask": [...]}, labels)`` with one
        entry per window.

    Raises:
        ValueError: if ``config.stride >= config.max_length``; the window
            would never advance (previously an infinite loop).
    """
    step = config.max_length - config.stride
    if step <= 0:
        raise ValueError(
            f"stride ({config.stride}) must be smaller than max_length ({config.max_length})"
        )

    input_ids = []
    attention_masks = []
    labels = []

    for text, tags in zip(texts, annotations):
        for start_idx in range(0, len(text), step):
            # Tokenize text chunk with overlap
            encoded_chunk = tokenizer(
                text[start_idx:start_idx + config.max_length],
                padding="max_length",
                truncation=True,
                max_length=config.max_length,
                return_offsets_mapping=True
            )

            input_ids.append(encoded_chunk["input_ids"])
            attention_masks.append(encoded_chunk["attention_mask"])

            # Offsets are relative to the chunk, but each document's annotations
            # are assumed to use document-level character offsets, so shift real
            # tokens by start_idx. Zero-width (special/padding) spans are left
            # unshifted so they cannot fall inside an annotation span.
            token_offsets = [
                (tok_start + start_idx, tok_end + start_idx) if tok_end > tok_start
                else (tok_start, tok_end)
                for tok_start, tok_end in encoded_chunk["offset_mapping"]
            ]

            labels.append(encode_tags(tags, token_offsets, label_to_id, config.max_length))

    return {"input_ids": input_ids, "attention_mask": attention_masks}, labels

def compute_metrics(pred: EvalPrediction) -> Dict[str, float]:
    """
    Compute token-level metrics for the Trainer.

    Positions whose gold label is -100 (special/padding tokens) are ignored.
    Label id 0 is assumed to be the 'O' (no technique) class.

    Returns a dict with:
      * ``propaganda_precision`` / ``propaganda_recall`` / ``propaganda_f1``:
        micro-averaged over technique classes (every id except 0), computed on
        all non-ignored tokens. Technique predictions on 'O' tokens count as
        false positives, and 'O' predictions on technique tokens count as
        false negatives.
      * ``per_class_precision`` / ``per_class_recall`` / ``per_class_f1`` /
        ``support``: lists indexed by label id.
      * ``confusion_matrix``: nested list, rows = gold id, cols = predicted id.

    Note: despite the ``Dict[str, float]`` annotation, several values are lists.
    """
    logits = pred.predictions
    if isinstance(logits, tuple):  # some models return (logits, extra outputs)
        logits = logits[0]

    # Derive the class count from the model output instead of hard-coding it.
    num_labels = logits.shape[-1]
    all_ids = list(range(num_labels))
    technique_ids = all_ids[1:]  # exclude 'O'

    labels = pred.label_ids.flatten()
    preds = np.argmax(logits, axis=-1).flatten()

    # Filter out ignored tokens (-100)
    mask = labels != -100
    labels = labels[mask]
    preds = preds[mask]

    # Passing `labels=technique_ids` drops 'O' from the micro average while
    # still evaluating every token. Filtering out gold-'O' tokens first, as
    # before, hid all false positives on 'O' tokens and inflated precision.
    propaganda_precision, propaganda_recall, propaganda_f1, _ = precision_recall_fscore_support(
        labels,
        preds,
        labels=technique_ids,
        average="micro",
        zero_division=0
    )

    # Calculate per-class metrics (all classes, including 'O')
    per_class_precision, per_class_recall, per_class_f1, support = precision_recall_fscore_support(
        labels,
        preds,
        labels=all_ids,
        zero_division=0
    )

    confusion = confusion_matrix(labels, preds, labels=all_ids)

    return {
        "propaganda_precision": propaganda_precision,
        "propaganda_recall": propaganda_recall,
        "propaganda_f1": propaganda_f1,
        "per_class_precision": per_class_precision.tolist(),
        "per_class_recall": per_class_recall.tolist(),
        "per_class_f1": per_class_f1.tolist(),
        "support": support.tolist(),
        "confusion_matrix": confusion.tolist(),
    }

def save_detailed_metrics(metrics: Dict[str, float],
                         output_dir: str,
                         id_to_label: Dict[int, str]):
    """Write a per-class metrics report and the confusion matrix to ``output_dir``.

    ``metrics`` must use the unprefixed keys returned by compute_metrics
    (``propaganda_*``, ``per_class_*``, ``support``, ``confusion_matrix``).
    Creates ``output_dir`` if needed. Writes ``detailed_results.json`` and
    ``confusion_matrix.npy`` there.
    """
    os.makedirs(output_dir, exist_ok=True)

    overall = {
        "precision": metrics["propaganda_precision"],
        "recall": metrics["propaganda_recall"],
        "f1": metrics["propaganda_f1"],
    }
    per_class = {
        id_to_label[class_id]: {
            "precision": metrics["per_class_precision"][class_id],
            "recall": metrics["per_class_recall"][class_id],
            "f1": metrics["per_class_f1"][class_id],
            "support": metrics["support"][class_id],
        }
        for class_id in range(len(id_to_label))
    }
    report = {"overall_propaganda_metrics": overall, "per_class_metrics": per_class}

    report_path = os.path.join(output_dir, 'detailed_results.json')
    with open(report_path, 'w', encoding='utf8') as handle:
        json.dump(report, handle, ensure_ascii=False, indent=2)

    # The confusion matrix goes to its own .npy file rather than the JSON report.
    matrix = np.array(metrics["confusion_matrix"])
    np.save(os.path.join(output_dir, 'confusion_matrix.npy'), matrix)

def main():
    """Fine-tune a token-classification model for propaganda-technique spans.

    Loads the training JSON, splits it 80/20 at document level, cuts each
    document into sliding windows, and trains with early stopping on
    ``propaganda_f1``. Per-evaluation metrics are written by MetricsCallback
    under ``config.full_output_dir``.

    Raises:
        ValueError: if ``config.num_labels`` disagrees with the label map.
    """
    # Configuration
    config = TokenClassificationConfig()

    # Load data: a JSON list of documents, each with 'text' and 'annotations'
    with open('/home/lgiordano/LUCA/checkthat_GITHUB/data/formatted/train.json', 'r', encoding='utf8') as f:
        dataset = json.load(f)

    texts = [item['text'] for item in dataset]
    annotations = [item['annotations'] for item in dataset]

    # Prepare label mappings (id 0 must stay 'O': compute_metrics treats it as the negative class)
    label_to_id = {
        "O": 0, "Appeal_to_Authority": 1, "Appeal_to_Popularity": 2, "Appeal_to_Values": 3,
        "Appeal_to_Fear-Prejudice": 4, "Flag_Waving": 5, "Causal_Oversimplification": 6,
        "False_Dilemma-No_Choice": 7, "Consequential_Oversimplification": 8, "Straw_Man": 9,
        "Red_Herring": 10, "Whataboutism": 11, "Slogans": 12, "Appeal_to_Time": 13,
        "Conversation_Killer": 14, "Loaded_Language": 15, "Repetition": 16,
        "Exaggeration-Minimisation": 17, "Obfuscation-Vagueness-Confusion": 18,
        "Name_Calling-Labeling": 19, "Doubt": 20, "Guilt_by_Association": 21,
        "Appeal_to_Hypocrisy": 22, "Questioning_the_Reputation": 23,
    }
    id_to_label = {v: k for k, v in label_to_id.items()}

    # Fail fast if the classifier head size and the label map drift apart.
    if len(label_to_id) != config.num_labels:
        raise ValueError(
            f"config.num_labels={config.num_labels} does not match the "
            f"{len(label_to_id)} labels defined in label_to_id"
        )

    # Initialize model and tokenizer; store the label names in the model config
    # so saved checkpoints are self-describing.
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    model = AutoModelForTokenClassification.from_pretrained(
        config.model_name,
        num_labels=config.num_labels,
        id2label=id_to_label,
        label2id=label_to_id,
    )

    # Split at document level so windows of one document never leak across splits
    texts_train, texts_val, annotations_train, annotations_val = train_test_split(
        texts, annotations, test_size=0.2, random_state=42
    )

    # Preprocess data
    train_inputs, train_labels = preprocess_data(
        texts_train, annotations_train, tokenizer, config, label_to_id
    )
    val_inputs, val_labels = preprocess_data(
        texts_val, annotations_val, tokenizer, config, label_to_id
    )

    # Create datasets
    train_dataset = TokenClassificationDataset(train_inputs, train_labels)
    val_dataset = TokenClassificationDataset(val_inputs, val_labels)

    # Training setup
    training_args = TrainingArguments(
        output_dir=config.full_output_dir,
        save_total_limit=2,
        save_strategy='epoch',
        load_best_model_at_end=True,
        save_only_model=True,
        metric_for_best_model='propaganda_f1',
        logging_strategy='epoch',
        # NOTE(review): newer transformers releases rename this argument to
        # `eval_strategy`; confirm against the installed version.
        evaluation_strategy="epoch",
        learning_rate=5e-5,
        optim='adamw_torch',
        num_train_epochs=10,
        weight_decay=0.01,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        callbacks=[
            EarlyStoppingCallback(early_stopping_patience=2),
            MetricsCallback(output_dir=config.full_output_dir, id_to_label=id_to_label)
        ],
        compute_metrics=compute_metrics
    )

    # Train; per-evaluation metrics are written by MetricsCallback.
    trainer.train()


if __name__ == "__main__":
    main()