<a href="https://colab.research.google.com/github/naisofly/HalluShield/blob/main/compare_llm_hallucination_recall.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers datasets pandas scikit-learn

In [None]:
# Import necessary libraries
import os
from datasets import load_dataset
import pandas as pd
from transformers import pipeline
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# ----------------------------
# GPU Setup in Google Colab
# ----------------------------

# Verify GPU availability
import torch
# Report the installed PyTorch build and whether a CUDA device is visible
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# ----------------------------
# Hugging Face Authentication
# ----------------------------

# Hugging Face authentication.
# Either paste a token on the HF_TOKEN line below, or leave the placeholder
# untouched and supply the token through the HF_TOKEN environment variable.
_HF_TOKEN_PLACEHOLDER = "ADD_YOUR_HUGGINGFACE_TOKEN_HERE"
HF_TOKEN = "ADD_YOUR_HUGGINGFACE_TOKEN_HERE"  # Get from https://huggingface.co/settings/tokens

if HF_TOKEN == _HF_TOKEN_PLACEHOLDER:
    # Placeholder was not edited: reuse a token already present in the
    # environment instead of overwriting it with the placeholder text.
    HF_TOKEN = os.environ.get("HF_TOKEN") or None

if HF_TOKEN:
    os.environ["HF_TOKEN"] = HF_TOKEN  # Set as environment variable
else:
    # Never export the placeholder: an invalid token in the environment can
    # make Hub requests fail authentication, even for public resources.
    print("Warning: no Hugging Face token configured; downloads of gated models may fail.")


### 1: Load the MedHallu dataset

In [None]:
# MedHallu ("pqa_labeled" configuration) provides, per row, a medical question
# together with a ground-truth answer and a hallucinated answer.
ds = load_dataset(path="UTAustin-AIHealth/MedHallu", name="pqa_labeled")
train_split = ds["train"]
df = train_split.to_pandas()

### 2. Create a balanced test dataset from the hard hallucinations

In [None]:
# Restrict the evaluation to rows whose difficulty level is 'hard', i.e. the
# cases where hallucinations are hardest to detect
is_hard = df['Difficulty Level'] == 'hard'
hard_hallucinations = df.loc[is_hard]
print(f"Number of hard hallucination entries: {len(hard_hallucinations)}")

# Every hard row is used twice: once paired with its ground-truth answer and
# once with its hallucinated answer, which gives a 50/50 label split.
num_samples = len(hard_hallucinations)

truthful_rows = hard_hallucinations.sample(n=num_samples, random_state=42).copy()
truthful_rows['answer'] = truthful_rows['Ground Truth']
truthful_rows['label'] = 'non-hallucination'

hallucinated_rows = hard_hallucinations.sample(n=num_samples, random_state=84).copy()
hallucinated_rows['answer'] = hallucinated_rows['Hallucinated Answer']
hallucinated_rows['label'] = 'hallucination'

# Mix the two halves with a fixed seed so the row order is reproducible
combined_rows = pd.concat([truthful_rows, hallucinated_rows])
test_df = combined_rows.sample(frac=1, random_state=126).reset_index(drop=True)

print("\nLabel counts in new dataset:")
print(test_df['label'].value_counts())

### 3: Initialize the LLMs to be evaluated

In [None]:
# Use the GPU when one is visible, otherwise fall back to CPU so this cell does
# not crash outside a GPU runtime. float16 is kept for GPU inference only;
# float32 is used on CPU.
inference_device = "cuda" if torch.cuda.is_available() else "cpu"
inference_dtype = torch.float16 if inference_device == "cuda" else torch.float32

gemma_model = pipeline(
    "text-generation",
    model="google/gemma-2-2b-it",
    token=HF_TOKEN,
    device=inference_device,
    torch_dtype=inference_dtype,
    max_new_tokens=10,  # The expected answer is a single 'Yes'/'No'; 10 tokens leaves slack
    do_sample=False,  # Greedy decoding so repeated runs give the same output
    temperature=0.0,  # Sampling parameter; not expected to matter while do_sample=False
    pad_token_id=2  # NOTE(review): confirm that id 2 is the intended pad token for this tokenizer
)

### 4: Define system and user prompts for hallucination detection

In [None]:
# ----------------------------
# Define Prompt Templates and Batch Processing Function
# ----------------------------

# Instruction text placed at the start of the user turn. No literal <bos> is
# written here: it previously ended up in the middle of the user turn, and the
# tokenizer is expected to prepend the beginning-of-sequence token itself.
system_prompt = """You are a medical hallucination detector.
Check if answers contain factual inaccuracies. Respond ONLY with 'Yes' or 'No'."""

def generate_prompt(row):
    """Build the Gemma chat-formatted prompt for one test row.

    Args:
        row: Mapping-like object (e.g. a pandas Series) providing the keys
            'Ground Truth', 'Question' and 'answer'.

    Returns:
        A single string consisting of one user turn (instructions, context,
        question and candidate answer) followed by the opening of the model
        turn, so that generation continues with the model's verdict.
    """
    # NOTE(review): the context shown to the model is the ground-truth answer
    # itself, which equals `answer` for every 'non-hallucination' row. Confirm
    # this is the intended evaluation setting and not label leakage.
    #
    # The pieces are joined explicitly (rather than with an indented
    # triple-quoted string) so no stray leading whitespace reaches the model
    # and each control token sits on its own line as the chat format expects.
    return (
        "<start_of_turn>user\n"
        f"{system_prompt}\n"
        f"Context: {row['Ground Truth']}\n"
        f"Question: {row['Question']}\n"
        f"Answer: {row['answer']}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )

def parse_response(response):
    """Normalise a raw model reply to 'yes', 'no' or an 'unexpected_*' marker.

    Args:
        response: Raw text generated by the model.

    Returns:
        'yes' or 'no' when the reply contains one of those as a whole word
        (the first such word wins, since the model is asked to lead with its
        verdict); 'unexpected_empty_response' for an empty/whitespace-only
        reply; otherwise 'unexpected_' followed by the first 20 characters of
        the lower-cased reply.
    """
    import re  # Local import so this cell does not depend on the import cell

    response = response.lower().strip()

    # Handle empty responses
    if not response:
        return "unexpected_empty_response"

    # Match whole words only: a plain substring test would read 'unknown',
    # 'cannot' or 'not sure' as a "no" verdict.
    verdict = re.search(r"\b(yes|no)\b", response)
    if verdict:
        return verdict.group(1)

    return f"unexpected_{response[:20]}"  # Truncate long unexpected responses

def evaluate_model(test_df, model, batch_size=32):
    """Run the detector over the test set and collect one result row per example.

    Args:
        test_df: DataFrame whose rows provide 'Question', 'answer' and 'label'
            (plus whatever `generate_prompt` reads from a row).
        model: Callable invoked with a list of prompt strings. For each prompt
            it must return a list whose first element is a dict, normally
            containing the key 'generated_text'.
        batch_size: Number of prompts handed to `model` per call.

    Returns:
        DataFrame with the columns 'Question', 'Answer', 'Label',
        'Model Response' ('Yes'/'No'), 'Raw Response', 'Correct' and
        'Parsed Response'. The columns are present even when `test_df` is
        empty. Replies that could not be parsed still count as 'No' in
        'Model Response'; 'Parsed Response' keeps the 'unexpected_*' marker so
        they can be told apart from genuine "no" answers.

    Raises:
        RuntimeError: If `model` returns a different number of responses than
            prompts it was given.
    """
    result_columns = [
        "Question",
        "Answer",
        "Label",
        "Model Response",
        "Raw Response",
        "Correct",
        "Parsed Response",
    ]
    results = []
    for start in range(0, len(test_df), batch_size):
        batch = test_df.iloc[start:start + batch_size]
        # Walk the batch once and reuse the rows for prompts and results
        rows = [row for _, row in batch.iterrows()]
        prompts = [generate_prompt(row) for row in rows]

        # Generate responses
        responses = model(
            prompts,
            do_sample=False,
            temperature=0.0,
            return_full_text=False  # Get only generated text
        )

        if len(responses) != len(prompts):
            raise RuntimeError(
                f"Model returned {len(responses)} responses for {len(prompts)} prompts"
            )

        for row, prompt, response in zip(rows, prompts, responses):
            # Extract raw response and validate structure
            first_candidate = response[0]
            if 'generated_text' in first_candidate:
                raw_response = first_candidate['generated_text'].strip()
            else:
                print(f"\nError: 'generated_text' key not found in response for prompt:\n{prompt}")
                raw_response = "INVALID RESPONSE"

            # Parse and validate response
            parsed_response = parse_response(raw_response)

            # Anything other than an explicit "yes" is scored as 'No'
            model_response = 'Yes' if parsed_response == 'yes' else 'No'
            is_correct = (model_response == 'Yes') == (row['label'] == 'hallucination')

            results.append({
                "Question": row["Question"],
                "Answer": row["answer"],
                "Label": row["label"],
                "Model Response": model_response,
                "Raw Response": raw_response,
                "Correct": is_correct,
                "Parsed Response": parsed_response,
            })

    return pd.DataFrame(results, columns=result_columns)

### 5. Evaluate Model on Hard Hallucinations

In [None]:
# ----------------------------
# Evaluate Model Using Batches
# ----------------------------

# Score every row of the balanced test set with the Gemma pipeline
results_df = evaluate_model(test_df=test_df, model=gemma_model)

### 6. Calculate Precision and Recall

In [None]:
# ----------------------------
# Performance Analysis and Results Saving
# ----------------------------

print("\n6. Calculating metrics...")
# Encode both the reference labels and the model verdicts as 1 = hallucination
true_labels = results_df['Label'].map({'hallucination': 1, 'non-hallucination': 0})
predicted_labels = results_df['Model Response'].map({'Yes': 1, 'No': 0})

# Pin the label order so the matrix is always 2x2 (rows = true, columns =
# predicted), even if only one class occurs in the data; without it the
# indexing below would fail on a smaller matrix.
cm = confusion_matrix(true_labels, predicted_labels, labels=[0, 1])
disp = ConfusionMatrixDisplay(cm, display_labels=['Non-Hallucination', 'Hallucination'])
disp.plot(cmap='Blues', values_format='d')
disp.ax_.set_title("Confusion Matrix\n(1=Hallucination, 0=Non-Hallucination)")

# Row-major flattening of the 2x2 matrix yields TN, FP, FN, TP in that order
tn, fp, fn, tp = cm.ravel()

print("\nConfusion Matrix Breakdown:")
print(f"True Positives (TP): {tp}")  # Correctly identified hallucinations
print(f"False Positives (FP): {fp}") # Non-hallucinations flagged as hallucinations
print(f"False Negatives (FN): {fn}") # Missed hallucinations
print(f"True Negatives (TN): {tn}")  # Correctly identified non-hallucinations

# Guard against division by zero when nothing was flagged / nothing was positive
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0

print(f"\nPrecision: {precision:.2f} (How many flagged hallucinations were correct)")
print(f"Recall: {recall:.2f} (How many actual hallucinations were detected)")

# Save results
results_df.to_csv("hallucination_evaluation_results.csv", index=False)
print("\nResults saved to 'hallucination_evaluation_results.csv'")