In [None]:
# Enhanced Flan-T5 Training for Misinformation Detection
# Optimized for M3 Max in notebook environment
# Import the dataset class from the separate file
from dataset_utils import FeverousDataset, create_balanced_dataset
import json
import os
import numpy as np
import random
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from IPython.display import display, HTML

import torch
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)

# Balance the training split by topping up under-represented verdict classes.
def oversample_minority_classes(dataset, seed=42):
    """
    Oversample minority classes to match the majority class
    while keeping all original data.

    Examples are grouped by their ``"verdict"`` value into ``SUPPORTS``,
    ``REFUTES`` and ``NOT ENOUGH INFO``. Every one of those classes that has
    at least one example but fewer than the largest class is topped up by
    sampling its own examples with replacement. A class with no examples at
    all cannot be oversampled and is reported instead of crashing. Examples
    whose verdict is none of the three are kept unchanged (not oversampled),
    so no original data is lost.

    Args:
        dataset: Object with a mutable ``data`` list of dicts, each having a
            ``"verdict"`` key. It is modified in place.
        seed: Seed for a private RNG; the global ``random`` state is left
            untouched.

    Returns:
        The same ``dataset`` object, whose ``data`` is now the shuffled,
        oversampled list.
    """
    # A private Random(seed) produces the same draws as random.seed(seed)
    # followed by module-level calls, without resetting the caller's global RNG.
    rng = random.Random(seed)
    print("Oversampling minority classes...")

    # Group examples by verdict; anything unexpected is set aside, not dropped.
    by_verdict = {"SUPPORTS": [], "REFUTES": [], "NOT ENOUGH INFO": []}
    unrecognized = []
    for example in dataset.data:
        bucket = by_verdict.get(example["verdict"])
        if bucket is None:
            unrecognized.append(example)
        else:
            bucket.append(example)

    # Print class distribution before oversampling
    print("Class distribution before oversampling:")
    for verdict, examples in by_verdict.items():
        print(f"  {verdict}: {len(examples)} examples")
    if unrecognized:
        print(f"  (keeping {len(unrecognized)} examples with unrecognized verdicts as-is)")

    # Size of the majority class is the target for every other class.
    max_count = max(len(examples) for examples in by_verdict.values())

    balanced = []
    for verdict, examples in by_verdict.items():
        balanced.extend(examples)
        if not examples:
            # random.choices on an empty population raises IndexError.
            if max_count:
                print(f"  Warning: no {verdict} examples available to oversample")
            continue
        shortfall = max_count - len(examples)
        if shortfall > 0:
            # Sample with replacement to reach the majority-class count.
            balanced.extend(rng.choices(examples, k=shortfall))
    balanced.extend(unrecognized)

    # Shuffle so duplicated examples are not clustered together.
    rng.shuffle(balanced)

    dataset.data = balanced
    print(f"Created oversampled dataset with {len(balanced)} examples")

    return dataset

# Report the compute backend: Apple's Metal (MPS) when this PyTorch build can
# use it, CPU otherwise.
# NOTE(review): `device` is only printed here; the code visible in this file never
# passes it to the model or Trainer, which presumably choose their own device.
device = "cpu"
if torch.backends.mps.is_available():
    device = "mps"
print(f"Using device: {device}")

# Tokenizers loaded by compute_metrics when none is passed in, keyed by checkpoint
# name, so each evaluation round does not call from_pretrained again.
_METRICS_TOKENIZER_CACHE = {}


def _normalize_verdict(text):
    """Map free-form verdict text onto a canonical class name.

    Matching is by uppercase substring, checked in this order:
    "SUPPORT" -> "SUPPORTED"; "REFUT" or "FALSE" -> "REFUTED";
    "NOT ENOUGH", "INSUFFICIENT" or "UNKNOWN" -> "NOT ENOUGH INFORMATION".
    Anything else is returned stripped and uppercased.
    """
    text = text.strip().upper()
    if "SUPPORT" in text:
        return "SUPPORTED"
    if "REFUT" in text or "FALSE" in text:
        return "REFUTED"
    if "NOT ENOUGH" in text or "INSUFFICIENT" in text or "UNKNOWN" in text:
        return "NOT ENOUGH INFORMATION"
    return text


# Define metrics computation
def compute_metrics(eval_preds, tokenizer=None):
    """Compute accuracy, per-class precision/recall/F1 and macro F1.

    Used as the ``compute_metrics`` callback of a Seq2SeqTrainer run with
    ``predict_with_generate=True``. Predictions and labels are decoded and
    passed through the SAME verdict normalization, so surface forms such as
    "SUPPORTS" (label) and "supported" (generation) count as one class.

    Args:
        eval_preds: ``(preds, labels)`` token-id arrays. ``preds`` may also be
            a tuple whose first element is the token-id array. ``-100``
            entries in either are treated as padding.
        tokenizer: Tokenizer used for decoding. If omitted, the tokenizer for
            the global ``model_name`` is loaded once and cached.

    Returns:
        Dict with ``accuracy``, ``macro_f1`` and ``<class>_precision`` /
        ``<class>_recall`` / ``<class>_f1`` for SUPPORTED, REFUTED and
        NOT ENOUGH INFORMATION (class names lowercased). An empty evaluation
        set yields 0 for every metric.
    """
    preds, labels = eval_preds
    # Some models return extra outputs; the generated ids come first.
    if isinstance(preds, tuple):
        preds = preds[0]

    if tokenizer is None:
        tokenizer = _METRICS_TOKENIZER_CACHE.get(model_name)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            _METRICS_TOKENIZER_CACHE[model_name] = tokenizer

    # -100 is the ignore index used for label padding; the Trainer can also use it
    # to pad generated sequences of different lengths, and decoding it fails.
    pad_id = tokenizer.pad_token_id
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Normalize both sides identically so they are comparable.
    normalized_preds = [_normalize_verdict(pred) for pred in decoded_preds]
    normalized_labels = [_normalize_verdict(label) for label in decoded_labels]
    pairs = list(zip(normalized_preds, normalized_labels))

    # Calculate accuracy (0.0 for an empty evaluation set)
    total = len(normalized_preds)
    correct = sum(1 for p, l in pairs if p == l)
    accuracy = correct / total if total else 0.0

    # Calculate per-class metrics
    classes = ["SUPPORTED", "REFUTED", "NOT ENOUGH INFORMATION"]
    per_class_metrics = {}

    for cls in classes:
        true_positives = sum(1 for p, l in pairs if p == cls and l == cls)
        false_positives = sum(1 for p, l in pairs if p == cls and l != cls)
        false_negatives = sum(1 for p, l in pairs if p != cls and l == cls)

        precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
        recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        per_class_metrics[cls] = {
            "precision": precision,
            "recall": recall,
            "f1": f1
        }

    # Macro F1: unweighted mean of the per-class F1 scores
    macro_f1 = sum(metrics["f1"] for metrics in per_class_metrics.values()) / len(per_class_metrics)

    return {
        "accuracy": accuracy,
        "macro_f1": macro_f1,
        **{f"{cls.lower()}_{metric}": value
           for cls, metrics in per_class_metrics.items()
           for metric, value in metrics.items()}
    }

# Function to visualize training results
def plot_training_results(trainer_history):
    """Plot the training-loss curve and the evaluation accuracy / macro-F1 curves.

    Args:
        trainer_history: The trainer after ``train()`` (this script passes the
            ``Seq2SeqTrainer``); its ``state.log_history`` is read. An object
            that exposes ``log_history`` directly, such as ``trainer.state``,
            is accepted too.

    Shows the figure with ``plt.show()`` and returns nothing.
    """
    # Accept either the trainer itself or its state object.
    state = getattr(trainer_history, "state", trainer_history)
    log_history = state.log_history

    # Training log entries carry "loss"; evaluation entries carry "eval_loss"
    # (plus the eval_* metrics) instead.
    train_logs = [entry for entry in log_history if "loss" in entry and "eval_loss" not in entry]
    eval_logs = [entry for entry in log_history if "eval_loss" in entry]

    plt.figure(figsize=(15, 5))

    # Left panel: training loss per logging step.
    plt.subplot(1, 2, 1)
    plt.plot([entry['step'] for entry in train_logs], [entry['loss'] for entry in train_logs])
    plt.title('Training Loss')
    plt.xlabel('Step')
    plt.ylabel('Loss')

    # Right panel: evaluation metrics. Entries lacking a metric (e.g. an
    # evaluation run without compute_metrics) are skipped instead of raising KeyError.
    plt.subplot(1, 2, 2)
    for metric_key, label in (("eval_accuracy", "Accuracy"), ("eval_macro_f1", "Macro F1")):
        logged = [entry for entry in eval_logs if metric_key in entry]
        plt.plot([entry['step'] for entry in logged], [entry[metric_key] for entry in logged], label=label)
    plt.title('Evaluation Metrics')
    plt.xlabel('Step')
    plt.ylabel('Score')
    plt.legend()

    plt.tight_layout()
    plt.show()

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------

# Model and file locations
model_name = "google/flan-t5-small"  # alternatives: google/flan-t5-base, google/flan-t5-large
data_dir = "feverous_prepared"       # folder holding the prepared FEVEROUS prompt files
output_dir = "flan-t5-feverous"      # checkpoints, logs and the final model are written here
max_length = 512                     # sequence length given to the datasets, collator and inference tokenizer

# Optimisation hyperparameters (chosen for an M3 Max)
batch_size = 16                      # suggested: 16 for small, 8 for base, 4 for large
gradient_accumulation_steps = 1      # raise to 2 or 4 when using a larger model
learning_rate = 5e-5
weight_decay = 0.01                  # regularisation strength
num_epochs = 3
warmup_ratio = 0.1                   # fraction of all optimiser steps spent warming up

# Dataset size and balancing
max_train_samples = -1               # -1 keeps everything; a positive N keeps the first N examples
max_eval_samples = -1                # -1 keeps everything; a positive N keeps the first N examples
balance_dataset = True               # oversample minority verdicts in the training split
max_per_class = 10000                # NOTE(review): not referenced anywhere else in the visible code

# Runtime / performance knobs
use_fp16 = False                     # left off; the author notes MPS lacks fp16 training support
gradient_checkpointing = False       # trade compute for memory on bigger checkpoints
num_workers = 4                      # DataLoader worker processes
logging_steps = 100                  # steps between training-loss log entries
seed = 42                            # shared seed for every RNG below and the Trainer

# Seed each random number generator the script relies on, for repeatable runs.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
del _seed_fn

# Make sure the checkpoint / log directory exists before anything writes to it.
os.makedirs(output_dir, exist_ok=True)

# Load the pretrained tokenizer and seq2seq model for `model_name`.
print(f"Loading model: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Optional: recompute activations during backprop to lower memory use.
if gradient_checkpointing:
    print("Enabling gradient checkpointing")
    model.gradient_checkpointing_enable()

# Build the train and dev splits from the prepared prompt files.
print(f"Loading datasets from {data_dir}")
train_prompts_path = os.path.join(data_dir, "feverous_train_prompts.json")
dev_prompts_path = os.path.join(data_dir, "feverous_dev_prompts.json")
train_dataset = FeverousDataset(train_prompts_path, tokenizer, max_length=max_length)
eval_dataset = FeverousDataset(dev_prompts_path, tokenizer, max_length=max_length)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Eval dataset size: {len(eval_dataset)}")

# Optional truncation to the first N examples for quick experiments.
if max_train_samples > 0:
    train_dataset.data = train_dataset.data[:max_train_samples]
    print(f"Using {len(train_dataset)} training examples")

if max_eval_samples > 0:
    eval_dataset.data = eval_dataset.data[:max_eval_samples]
    print(f"Using {len(eval_dataset)} evaluation examples")

# Only the training split is oversampled; evaluation keeps its original distribution.
if balance_dataset:
    train_dataset = oversample_minority_classes(train_dataset, seed=seed)

# Data collator: pads every batch to `max_length` (padding="max_length").
# Label padding uses the collator's label pad id (-100 by default in
# transformers), which compute_metrics maps back to the pad token before decoding.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding="max_length",
    max_length=max_length
)

# Count optimiser steps the way the Trainer does. The last partial batch of an
# epoch is still used (dataloader_drop_last is left at its default of False), so
# batches per epoch is a ceiling division. Each optimiser step consumes
# `gradient_accumulation_steps` batches, and the Trainer never counts fewer than
# one update step per epoch.
batches_per_epoch = (len(train_dataset) + batch_size - 1) // batch_size
num_update_steps_per_epoch = max(batches_per_epoch // gradient_accumulation_steps, 1)
max_train_steps = num_epochs * num_update_steps_per_epoch

# Warm the learning rate up over the first `warmup_ratio` fraction of all steps.
warmup_steps = int(warmup_ratio * max_train_steps)
print(f"Warmup steps: {warmup_steps} of {max_train_steps} total steps")

# transformers renamed two arguments in newer releases: `evaluation_strategy`
# became `eval_strategy` (the deprecation notice scheduled the old name's removal
# for v4.46), and the Trainer's `tokenizer` became `processing_class`. Detect which
# spelling the installed version accepts so this cell runs on old and new releases.
import dataclasses
import inspect

_training_arg_fields = {field.name for field in dataclasses.fields(Seq2SeqTrainingArguments)}
_eval_strategy_kwarg = (
    "eval_strategy" if "eval_strategy" in _training_arg_fields else "evaluation_strategy"
)

# Training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    # Save and evaluate per epoch; load_best_model_at_end requires the two strategies to match.
    save_strategy="epoch",
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    weight_decay=weight_decay,
    save_total_limit=3,
    num_train_epochs=num_epochs,
    # Evaluate by generating text so compute_metrics can decode verdict strings.
    predict_with_generate=True,
    generation_max_length=10,
    fp16=use_fp16,  # MPS doesn't support fp16
    load_best_model_at_end=True,
    metric_for_best_model="macro_f1",
    greater_is_better=True,
    report_to="tensorboard",
    warmup_steps=warmup_steps,
    logging_steps=logging_steps,
    logging_dir=os.path.join(output_dir, "logs"),
    seed=seed,
    dataloader_num_workers=num_workers,  # worker processes for data loading
    **{_eval_strategy_kwarg: "epoch"},
)

_trainer_init_params = inspect.signature(Seq2SeqTrainer.__init__).parameters
_tokenizer_kwarg = "processing_class" if "processing_class" in _trainer_init_params else "tokenizer"

# Initialize Trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    **{_tokenizer_kwarg: tokenizer},
)

# Run fine-tuning, then plot the logged loss and metric curves.
print("Starting training...")
train_result = trainer.train()
plot_training_results(trainer)

# Persist the weights the trainer now holds (with load_best_model_at_end=True this
# should be the best checkpoint by macro F1) next to the tokenizer.
final_model_dir = os.path.join(output_dir, "final_model")
trainer.save_model(final_model_dir)
tokenizer.save_pretrained(final_model_dir)

# Score the final model on the evaluation split and echo the raw metrics dict.
print("Evaluating final model...")
results = trainer.evaluate()
print(results)

# Keep a JSON copy of the metrics alongside the checkpoints.
with open(os.path.join(output_dir, "eval_results.json"), "w") as results_file:
    json.dump(results, results_file, indent=4)

# Render the metrics as a small HTML table in the notebook.
display(HTML("<h3>Evaluation Results</h3>"))
table_body = "".join(
    f"<tr><td>{metric}</td><td>{value:.4f}</td></tr>" for metric, value in results.items()
)
results_table = "<table><tr><th>Metric</th><th>Value</th></tr>" + table_body + "</table>"
display(HTML(results_table))

# (model, tokenizer) pairs already loaded by predict_verdict, keyed by model_path,
# so repeated calls do not reload the checkpoint from disk.
_VERDICT_MODEL_CACHE = {}


# Inference function for testing the model
def predict_verdict(claim, evidence, model_path=os.path.join(output_dir, "final_model"), use_cache=True):
    """Predict a verdict for one claim / evidence pair with the fine-tuned model.

    Args:
        claim: Claim text.
        evidence: Evidence text.
        model_path: Directory holding the fine-tuned model and tokenizer. The
            default, the ``final_model`` folder this script saves, is computed
            once when the function is defined.
        use_cache: Reuse a model/tokenizer previously loaded for ``model_path``.
            Pass False to force a reload, e.g. after the checkpoint on disk changed.

    Returns:
        "SUPPORTED", "REFUTED" or "NOT ENOUGH INFORMATION" when the generated
        text matches one of those classes; otherwise the generation itself,
        stripped and uppercased.
    """
    # Load the fine-tuned model and tokenizer (once per path unless use_cache=False)
    loaded = _VERDICT_MODEL_CACHE.get(model_path) if use_cache else None
    if loaded is None:
        loaded = (
            AutoModelForSeq2SeqLM.from_pretrained(model_path),
            AutoTokenizer.from_pretrained(model_path),
        )
        _VERDICT_MODEL_CACHE[model_path] = loaded
    verdict_model, verdict_tokenizer = loaded

    # NOTE(review): this prompt should match the format of the training prompts
    # (built in dataset_utils, not visible here); confirm they agree.
    prompt = f"Claim: {claim}\n\nEvidence: {evidence}\n\nIs the claim supported, refuted, or is there not enough information?"

    # Tokenize and place the tensors on the same device as the model.
    inputs = verdict_tokenizer(prompt, return_tensors="pt", max_length=max_length, truncation=True)
    inputs = inputs.to(verdict_model.device)

    # Generate prediction; pass the attention mask explicitly rather than letting generate infer it.
    outputs = verdict_model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
        max_length=10,
        num_beams=4,
        early_stopping=True,
    )

    # Decode the prediction
    prediction = verdict_tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Normalize the prediction onto the canonical class names
    prediction = prediction.strip().upper()
    if "SUPPORT" in prediction:
        normalized_pred = "SUPPORTED"
    elif "REFUT" in prediction or "FALSE" in prediction:
        normalized_pred = "REFUTED"
    elif "NOT ENOUGH" in prediction or "INSUFFICIENT" in prediction or "UNKNOWN" in prediction:
        normalized_pred = "NOT ENOUGH INFORMATION"
    else:
        normalized_pred = prediction

    return normalized_pred

# Example usage:
# prediction = predict_verdict(
#     "Paris is the capital of France.",
#     "Paris is the capital and most populous city of France."
# )
# print(f"Prediction: {prediction}")