# In [None]:
"""### Enhanced Evaluation Model (Based on Code Pattern 2)"""
# ===================== Additional Imports =====================
import re
import numpy as np
import time
import os
import torch
import json
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from rouge import Rouge

# ===================== New Definitions =====================
# Direct cause categories.
# The position of each label in this list is the index it occupies in the
# multi-hot vectors built by create_multihot, and the labels are also joined
# with ", " to form the classification table shown in the prompt.
DIRECT_CAUSES = [
    "Insufficient vehicle operation safety awareness",
    "Non-compliance with speed and load regulations",
    "Non-compliant driving behavior in high-risk areas",
    "Heavy vehicle operation standards",
    "Vehicle stability management",
    "Safe cargo handling procedures",
    "Non-standard unloading operations",
]

# Indirect cause categories.
# As with the direct causes, list order fixes each label's index in the
# multi-hot vectors, and the labels are joined with ", " inside the prompt.
INDIRECT_CAUSES = [
    "Inadequate on-site personnel safety management",
    "Failure in transportation compliance monitoring",
    "Insufficient driver training",
    "Corporate management failure in vehicle operations",
]

# Cause extraction function (regex-based)
def extract_causes(text):
    """Extract the direct and indirect cause lists from an output text.

    The text is searched for ``Direct Cause:`` and ``Indirect Cause:`` fields
    (ASCII or full-width colon). Each field's value is split on ASCII or
    full-width commas. The two fields may appear on separate lines or on the
    same line (``Direct Cause: a, b, Indirect Cause: c.``).

    Args:
        text: Output text to parse. Non-string values (e.g. ``None``) are
            treated as containing no causes.

    Returns:
        A ``(direct_list, indirect_list)`` tuple. Each element is a list of
        stripped cause strings, empty when the field is missing.
    """
    if not isinstance(text, str):
        return [], []

    direct_match = re.search(r"Direct Cause[:：]\s*(.+)", text)
    indirect_match = re.search(r"Indirect Cause[:：]\s*(.+)", text)

    direct = direct_match.group(1) if direct_match else ""
    indirect = indirect_match.group(1) if indirect_match else ""

    # ``(.+)`` runs to the end of the line, so when both fields share a line
    # each capture would otherwise swallow the other field. The patterns are
    # case-sensitive, so "Direct Cause" never matches inside "Indirect Cause".
    direct = re.split(r"Indirect Cause[:：]", direct, maxsplit=1)[0]
    indirect = re.split(r"Direct Cause[:：]", indirect, maxsplit=1)[0]

    def split_causes(raw):
        # Drop sentence-ending punctuation so "label." still equals "label".
        cleaned = (
            part.strip().rstrip(".。;；").strip()
            for part in re.split(r"[,，]", raw)
        )
        return [cause for cause in cleaned if cause]

    return split_causes(direct), split_causes(indirect)

# Create multi-hot encoded vectors
def create_multihot(labels, class_list):
    """Encode ``labels`` as a multi-hot integer vector over ``class_list``.

    The result has one slot per entry of ``class_list``. A slot is 1 when the
    entry at that position (first occurrence) appears in ``labels`` and 0
    otherwise. Labels that are not in ``class_list`` are ignored.
    """
    encoded = np.zeros(len(class_list), dtype=int)
    for name in labels:
        try:
            position = class_list.index(name)
        except ValueError:
            # Unknown label: not one of the known classes, skip it.
            continue
        encoded[position] = 1
    return encoded

# ===================== Refactored Evaluation Function =====================
def calculate_metrics(true_outputs, pred_outputs):
    """Compute classification and text-overlap metrics for cause analysis.

    Each reference/prediction pair is parsed with ``extract_causes`` and
    encoded with ``create_multihot`` against ``DIRECT_CAUSES`` and
    ``INDIRECT_CAUSES``. The texts are also whitespace-tokenized for BLEU
    and ROUGE.

    Args:
        true_outputs: Iterable of reference output strings.
        pred_outputs: Iterable of predicted output strings, aligned with
            ``true_outputs``. Empty strings are allowed.

    Returns:
        A dict with macro-averaged ``direct_*`` / ``indirect_*`` precision,
        recall and f1, corpus-level ``bleu``, and the mean F-score for
        ``rouge-1``, ``rouge-2`` and ``rouge-l``. When there are no samples
        every metric is ``0.0``.
    """
    metric_names = (
        "direct_precision", "direct_recall", "direct_f1",
        "indirect_precision", "indirect_recall", "indirect_f1",
        "bleu", "rouge-1", "rouge-2", "rouge-l",
    )

    samples = list(zip(true_outputs, pred_outputs))
    if not samples:
        # Nothing to score: return a well-formed result instead of letting
        # the metric libraries fail on empty input.
        return dict.fromkeys(metric_names, 0.0)

    true_direct_vectors = []
    pred_direct_vectors = []
    true_indirect_vectors = []
    pred_indirect_vectors = []
    references = []
    hypotheses = []

    for true, pred in samples:
        true_direct, true_indirect = extract_causes(true)
        pred_direct, pred_indirect = extract_causes(pred)

        true_direct_vectors.append(create_multihot(true_direct, DIRECT_CAUSES))
        pred_direct_vectors.append(create_multihot(pred_direct, DIRECT_CAUSES))
        true_indirect_vectors.append(create_multihot(true_indirect, INDIRECT_CAUSES))
        pred_indirect_vectors.append(create_multihot(pred_indirect, INDIRECT_CAUSES))

        # Whitespace tokens feed both BLEU and ROUGE.
        references.append(true.split())
        hypotheses.append(pred.split())

    # Classification metrics (macro average over cause classes).
    metrics = {}
    label_sets = (
        ("direct", true_direct_vectors, pred_direct_vectors),
        ("indirect", true_indirect_vectors, pred_indirect_vectors),
    )
    scorers = (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )
    for prefix, y_true, y_pred in label_sets:
        for suffix, scorer in scorers:
            metrics[f"{prefix}_{suffix}"] = scorer(
                y_true, y_pred, average="macro", zero_division=0
            )

    # BLEU (corpus-level, one reference per hypothesis).
    smoothie = SmoothingFunction().method1
    metrics["bleu"] = corpus_bleu(
        [[ref] for ref in references],
        hypotheses,
        smoothing_function=smoothie,
    )

    # ROUGE, scored per sample and averaged over ALL samples.
    # The rouge package raises ValueError when a hypothesis or reference is
    # empty, which happens whenever a generation failed and "" was recorded.
    # Such samples contribute 0 instead of aborting the whole evaluation.
    rouge = Rouge()
    rouge_keys = ("rouge-1", "rouge-2", "rouge-l")
    rouge_totals = dict.fromkeys(rouge_keys, 0.0)
    for hyp_tokens, ref_tokens in zip(hypotheses, references):
        if not hyp_tokens or not ref_tokens:
            continue
        try:
            sample_scores = rouge.get_scores(
                [" ".join(hyp_tokens)], [" ".join(ref_tokens)]
            )[0]
        except ValueError:
            # Text without any scorable sentence (e.g. only periods).
            continue
        for key in rouge_keys:
            rouge_totals[key] += sample_scores[key]["f"]
    metrics.update({key: rouge_totals[key] / len(samples) for key in rouge_keys})

    return metrics

# ===================== Refactored Evaluation Workflow ===================== 
# Evaluation workflow: generate one prediction per test input, score the
# predictions against the references, print the metrics and save everything
# to JSON. Assumes test_dataset, model and tokenizer are defined elsewhere.
print("Starting model evaluation...")
start_time = time.time()

test_inputs = test_dataset["input"]
true_outputs = test_dataset["output"]
pred_outputs = []

# Sampling-based generation parameters
generation_kwargs = {
    "max_new_tokens": 2048,
    "temperature": 0.6,
    "top_p": 0.95,
    "top_k": 20,
    "do_sample": True,
}

device = "cuda" if torch.cuda.is_available() else "cpu"

response_marker = "### Response:"
# Decoder-only models return prompt + completion from generate();
# encoder-decoder models return only the completion.
# NOTE(review): relies on the Hugging Face `config.is_encoder_decoder` flag; confirm for this model.
is_encoder_decoder = bool(
    getattr(getattr(model, "config", None), "is_encoder_decoder", False)
)

# Prediction generation loop (tqdm advances once per sample, including failures)
for i, input_text in enumerate(
    tqdm(test_inputs, desc="Generating predictions", unit="sample")
):
    try:
        # Construct prompt (Alpaca format)
        prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
You are an accident causation expert. Based on the accident process in the input, infer the direct and indirect causes from the provided classification tables. Output format must be: Reasoning chain. Direct Cause: list of direct causes, Indirect Cause: list of indirect causes.

Direct Cause Classification Table:
{', '.join(DIRECT_CAUSES)}

Indirect Cause Classification Table:
{', '.join(INDIRECT_CAUSES)}

### Input:
{input_text}

### Response:
"""

        # Generate prediction
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024).to(device)
        outputs = model.generate(**inputs, **generation_kwargs)

        if is_encoder_decoder:
            pred_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            if response_marker in pred_text:
                pred_text = pred_text.split(response_marker)[1].strip()
        else:
            # Decode only the newly generated tokens. Splitting the full text
            # on the marker breaks when truncation cut the marker off the
            # prompt: the instruction's own "Direct Cause: ..." wording would
            # then be parsed as the prediction.
            prompt_length = inputs["input_ids"].shape[-1]
            pred_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
            # Keep only the first response if the model starts another turn.
            pred_text = pred_text.split(response_marker)[0].strip()

        pred_outputs.append(pred_text)
    except Exception as e:
        # Best effort: record an empty prediction so outputs stay aligned
        # with the inputs, then continue with the next sample.
        print(f"Error generating prediction: {e}")
        pred_outputs.append("")

    # Periodically release cached GPU memory
    if (i + 1) % 5 == 0 and torch.cuda.is_available():
        torch.cuda.empty_cache()

# Calculate evaluation metrics
print("Computing evaluation metrics...")
metrics = calculate_metrics(true_outputs, pred_outputs)

# Print comprehensive results
print("\n===== Evaluation Results =====")
print(f"Direct Cause Precision: {metrics['direct_precision']:.4f}")
print(f"Direct Cause Recall: {metrics['direct_recall']:.4f}")
print(f"Direct Cause F1: {metrics['direct_f1']:.4f}")
print(f"Indirect Cause Precision: {metrics['indirect_precision']:.4f}")
print(f"Indirect Cause Recall: {metrics['indirect_recall']:.4f}")
print(f"Indirect Cause F1: {metrics['indirect_f1']:.4f}")
print(f"BLEU Score: {metrics['bleu']:.4f}")
print(f"ROUGE-1 F1: {metrics['rouge-1']:.4f}")
print(f"ROUGE-2 F1: {metrics['rouge-2']:.4f}")
print(f"ROUGE-L F1: {metrics['rouge-l']:.4f}")
print(f"Total Time: {time.time()-start_time:.2f} seconds")

# Save structured results
output_dir = "evaluation_results"
os.makedirs(output_dir, exist_ok=True)
results_path = os.path.join(output_dir, "advanced_evaluation_results.json")

# Parse each text once rather than once per saved field.
predictions = []
for inp, true, pred in zip(test_inputs, true_outputs, pred_outputs):
    true_direct, true_indirect = extract_causes(true)
    pred_direct, pred_indirect = extract_causes(pred)
    predictions.append(
        {
            "input": inp,
            "true_output": true,
            "pred_output": pred,
            "true_direct": true_direct,
            "true_indirect": true_indirect,
            "pred_direct": pred_direct,
            "pred_indirect": pred_indirect,
        }
    )

results_data = {
    "metrics": metrics,
    "predictions": predictions,
}

with open(results_path, 'w', encoding='utf-8') as f:
    json.dump(results_data, f, ensure_ascii=False, indent=2)

print(f"Detailed evaluation results saved to: {results_path}")
