In [None]:
import base64
import io
import os
import sys
from collections import defaultdict
from io import BytesIO

import anthropic
import numpy as np
import openpyxl
import pandas as pd
import requests
from PIL import Image

sys.path.append('..')
from config import ANTHROPIC_API_KEY

# Create the single Claude client shared by every API call below; refuse to
# start at all when no API key was configured.
if ANTHROPIC_API_KEY:
    client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)
else:
    raise ValueError("Please set ANTHROPIC_API_KEY environment variable")

# System prompts
# SYSTEM_PROMPT: answer-only mode. call_claude_with_image pairs it with
# max_tokens=10, so the reply must be nothing but "QuestionID, LETTER";
# parse_claude_response rejects anything else as PARSE_ERROR.
SYSTEM_PROMPT = """You must reply with NO explanations, NO headers, NO extra text.
Language: ENGLISH. Keep the output strictly in the required format.

You will receive ONE medical question with:
- Text fields: QuestionID, Question, options A..E (some may be null).
- Optionally ONE image immediately after the text (when provided).

Task: choose EXACTLY ONE correct option among A, B, C, D, E for the QuestionID.
You may consider the image when present.

STRICT output format:
QuestionID, LETTER

Example:
ES3341, B

RULES:
- Output EXACTLY ONE line in the exact format above.
- Do NOT repeat the instructions.
- Do NOT include the option text, ONLY the letter."""

# SYSTEM_PROMPT_WITH_REASONING: used (with max_tokens=500) for the
# real-image / fake-image comparisons. parse_claude_response only reads the
# first line of the reply, so "QuestionID, LETTER" must come before the
# "Reasoning:" line, as the format below requires.
# NOTE: the literal below contains a trailing space after "Task:"; keep it
# byte-identical so runs stay comparable.
SYSTEM_PROMPT_WITH_REASONING = """You will receive ONE medical question with:
- Text fields: QuestionID, Question, options A..E (some may be null).
- Optionally ONE image immediately after the text (when provided).

Task: 
1. Choose EXACTLY ONE correct option among A, B, C, D, E for the QuestionID.
2. Provide your reasoning for selecting this answer.

Output format:
QuestionID, LETTER
Reasoning: [Your detailed reasoning here]

Example:
ES3341, B
Reasoning: The image shows characteristic signs of..."""

USER_LEAD = "Below is a single record. Use ONLY the relevant information."

def build_content_like_main_script(question_id: str, question_text: str, options: dict) -> list:
    """Build the user-message content blocks for one question.

    Produces a single text block whose lines are USER_LEAD, the QuestionID,
    the question stem (only when present) and one ``"<label>) <text>"`` line
    for each label A-E. A missing option still gets its (empty) line, so the
    model always sees all five labels.

    Args:
        question_id: identifier the model is asked to echo back.
        question_text: question stem; omitted when None, NaN or "".
        options: mapping label -> option text. Absent labels and None/NaN/""
            values are rendered as an empty option.

    Returns:
        A one-element list holding an Anthropic ``{"type": "text"}`` content
        block. Callers append image blocks to a copy of this list.
    """
    def _present(value) -> bool:
        # pandas returns NaN (which is truthy) for empty Excel cells, so a
        # plain truthiness test would put the literal text "nan" in the prompt.
        if value is None:
            return False
        if isinstance(value, str):
            return value != ""
        return not pd.isna(value)

    lines = [USER_LEAD, f"QuestionID: {question_id}"]

    if _present(question_text):
        lines.append(f"Question: {question_text}")

    for label in ("A", "B", "C", "D", "E"):
        option_text = options.get(label)
        lines.append(f"{label}) {option_text if _present(option_text) else ''}")

    return [{"type": "text", "text": "\n".join(lines)}]

def call_claude_with_image(content, use_reasoning=False):
    """Send one user turn to Claude and return the text of its reply.

    Args:
        content: list of Anthropic content blocks (a text block, optionally
            followed by an image block) that forms the single user message.
        use_reasoning: if True, use SYSTEM_PROMPT_WITH_REASONING with a
            500-token budget; otherwise use the answer-only SYSTEM_PROMPT
            with a 10-token budget.

    Returns:
        All text blocks of the reply joined together and stripped, or None
        when any exception occurs (the error is printed, not raised).
    """
    try:
        if use_reasoning:
            system_prompt, token_budget = SYSTEM_PROMPT_WITH_REASONING, 500
        else:
            system_prompt, token_budget = SYSTEM_PROMPT, 10

        reply = client.messages.create(
            model="claude-sonnet-4-5-20250929",
            max_tokens=token_budget,
            system=system_prompt,
            messages=[{"role": "user", "content": content}],
        )

        # Only text blocks carry the answer; anything else is ignored.
        text_parts = [part.text for part in reply.content if part.type == "text"]
        return "".join(text_parts).strip()
    except Exception as e:
        print(f"Error calling Claude API: {e}")
        return None

def parse_claude_response(text):
    """Extract the answer letter from a ``"<QuestionID>, <LETTER>"`` reply.

    Only the first line is inspected, so any trailing "Reasoning: ..." lines
    are ignored. The QuestionID part is accepted as-is and not validated.

    Returns:
        "NO_RESPONSE" if *text* is empty or None; "PARSE_ERROR" if the first
        line does not match; otherwise the letter A-E in upper case.
    """
    import re

    if not text:
        return "NO_RESPONSE"

    answer_line = re.compile(r'^\s*([^,]+)\s*,\s*([A-Ea-e])\s*$', re.UNICODE)
    match = answer_line.match(text.splitlines()[0])
    if match is None:
        return "PARSE_ERROR"
    return match.group(2).upper()

def _flatten_to_rgb(img):
    """Return *img* in RGB or L mode so it can be saved as JPEG.

    Images with transparency (RGBA, LA, PA, or palette images declaring
    transparency) are composited onto a white background. Any other mode
    that is not RGB/L (e.g. P, CMYK, I;16) is converted to RGB.
    """
    if img.mode in ("RGB", "L"):
        return img
    has_alpha = img.mode in ("RGBA", "LA", "PA") or (
        img.mode == "P" and "transparency" in img.info
    )
    if not has_alpha:
        return img.convert("RGB")
    rgba = img.convert("RGBA")
    background = Image.new("RGB", rgba.size, (255, 255, 255))
    background.paste(rgba, mask=rgba.split()[-1])
    return background


def _jpeg_image_block(img):
    """Encode *img* as base64 JPEG inside an Anthropic image content block."""
    buffer = BytesIO()
    _flatten_to_rgb(img).save(buffer, format='JPEG')
    return {
        "type": "image",
        "source": {
            "type": "base64",
            "media_type": "image/jpeg",
            "data": base64.b64encode(buffer.getvalue()).decode('utf-8'),
        },
    }


def _download_drive_image(url):
    """Download a Google Drive share link of the form ``.../d/<file_id>/...``.

    Raises IndexError for links without ``/d/``, requests.HTTPError for HTTP
    failures, and a PIL error if the payload is not a decodable image.
    """
    file_id = url.split("/d/")[1].split("/")[0]
    download_url = f"https://drive.google.com/uc?export=download&id={file_id}"
    response = requests.get(download_url, timeout=30)
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))
    img.load()  # decode now, so a corrupt payload fails inside the caller's try
    return img


def _ask_claude_with_reasoning(content, correct_answer):
    """Ask Claude with the reasoning prompt and grade the reply.

    Returns:
        (answer, raw_response, is_correct). When the call fails or returns
        no text, both answer and raw_response are "API_ERROR".
    """
    response_text = call_claude_with_image(content, use_reasoning=True)
    if not response_text:
        return "API_ERROR", "API_ERROR", False
    answer = parse_claude_response(response_text)
    return answer, response_text, answer == correct_answer


def process_single_repetition(df, worksheet, picture_link_col, fake_image_path, repetition_num):
    """Ask Claude every valid question in *df* once and collect graded results.

    Each row's answer key (``correct_option``) is validated first. Rows whose
    key is missing or not one of A-E are recorded as skipped.

    If the row's ``picture_link`` cell carries a hyperlink, the linked Google
    Drive image is downloaded. The question is then asked twice with the
    reasoning prompt: once with the real image and once with the image at
    *fake_image_path*. Every other question is asked once, text only, with
    the answer-only prompt.

    A question whose linked image cannot be downloaded or decoded is skipped
    for this repetition. Asking it without the image would record a
    text-only answer for an image question, which mixes conditions across
    repetitions in the aggregated statistics.

    Args:
        df: question sheet loaded with pandas. Uses the columns questionID,
            question_text, correct_option, option_a..option_e and (optionally)
            picture_link. Assumes the default RangeIndex, so row ``index``
            sits on Excel row ``index + 2``.
        worksheet: openpyxl worksheet for the same sheet; the hyperlinks are
            read from it.
        picture_link_col: 1-based column number of picture_link in *worksheet*.
        fake_image_path: path of the placeholder image used for the "fake" arm.
        repetition_num: repetition counter; copied into every record and log line.

    Returns:
        (results, skipped_questions): two lists of dicts. Image questions
        carry claude_answer_real/claude_response_real/is_correct_real and the
        matching *_fake keys. Text-only questions carry claude_answer and
        is_correct.
    """
    print(f"\n{'#'*70}")
    print(f"### STARTING REPETITION {repetition_num} ###")
    print(f"{'#'*70}\n")

    results = []
    skipped_questions = []

    # Every image question gets the same placeholder, so decode and encode it
    # once per repetition. A failure is remembered and reported on each image
    # question that needed it.
    fake_image_block = None
    fake_image_error = None
    try:
        with Image.open(fake_image_path) as fake_img:
            fake_image_block = _jpeg_image_block(fake_img)
    except Exception as e:
        fake_image_error = e

    for index, row in df.iterrows():
        question_id = row['questionID']
        question_text = row['question_text']
        correct_answer = row['correct_option']
        # +1 for openpyxl's 1-based rows, +1 for the header row.
        excel_row = index + 2

        # Validate the answer key before any network work, so skipped
        # questions cost no image download.
        if pd.isna(correct_answer):
            print(f"[Rep {repetition_num}] Skipping {question_id}: correct_option is NaN/None")
            skipped_questions.append({
                "repetition": repetition_num,
                "question_id": question_id,
                "reason": "correct_option is NaN/None"
            })
            continue

        # Grade against the normalised key, not the raw cell, so that a value
        # like " c " still matches Claude's "C".
        correct_answer_str = str(correct_answer).strip().upper()
        if correct_answer_str not in ('A', 'B', 'C', 'D', 'E'):
            print(f"[Rep {repetition_num}] Skipping {question_id}: Invalid correct_option '{correct_answer}'")
            skipped_questions.append({
                "repetition": repetition_num,
                "question_id": question_id,
                "reason": f"Invalid correct_option: {correct_answer}"
            })
            continue

        real_image_block = None
        picture_link = row.get('picture_link', 'N/A')
        if picture_link != 'N/A' and pd.notna(picture_link):
            # pandas exposes only the cell text; the URL lives in the hyperlink.
            cell = worksheet.cell(row=excel_row, column=picture_link_col)
            if cell.hyperlink and cell.hyperlink.target:
                try:
                    real_image = _download_drive_image(cell.hyperlink.target)
                    real_image_block = _jpeg_image_block(real_image)
                except Exception as e:
                    print(f"[Rep {repetition_num}] Failed to load image for {question_id}: {e}")
                    skipped_questions.append({
                        "repetition": repetition_num,
                        "question_id": question_id,
                        "reason": f"Failed to load image: {e}"
                    })
                    continue
                print(f"[Rep {repetition_num}] Loaded image for {question_id}")
        has_image = real_image_block is not None

        options = {
            "A": row['option_a'],
            "B": row['option_b'],
            "C": row['option_c'],
            "D": row['option_d'],
            "E": row['option_e']
        }

        result = {
            "repetition": repetition_num,
            "question_id": question_id,
            "question": question_text,
            "correct_answer": correct_answer_str,
            "has_image": has_image
        }

        content_base = build_content_like_main_script(question_id, question_text, options)

        if has_image:
            print(f"\n[Rep {repetition_num}] Processing question with IMAGE: {question_id}")

            # List concatenation builds new lists, leaving content_base intact.
            answer, raw, ok = _ask_claude_with_reasoning(
                content_base + [real_image_block], correct_answer_str
            )
            result.update(claude_answer_real=answer, claude_response_real=raw, is_correct_real=ok)

            if fake_image_block is not None:
                answer, raw, ok = _ask_claude_with_reasoning(
                    content_base + [fake_image_block], correct_answer_str
                )
                result.update(claude_answer_fake=answer, claude_response_fake=raw, is_correct_fake=ok)
            else:
                print(f"[Rep {repetition_num}] Error loading fake image: {fake_image_error}")
                result["claude_answer_fake"] = "FAKE_IMAGE_ERROR"
                result["claude_response_fake"] = f"Error: {fake_image_error}"
                result["is_correct_fake"] = False

            print(f"[Rep {repetition_num}] {question_id}: Correct={correct_answer_str}, "
                  f"Real={result['claude_answer_real']}{'✓' if result['is_correct_real'] else '✗'}, "
                  f"Fake={result['claude_answer_fake']}{'✓' if result['is_correct_fake'] else '✗'}")
        else:
            print(f"[Rep {repetition_num}] Processing question WITHOUT image: {question_id}")
            response_text = call_claude_with_image(content_base, use_reasoning=False)

            if response_text:
                claude_answer = parse_claude_response(response_text)
                result["claude_answer"] = claude_answer
                result["is_correct"] = (claude_answer == correct_answer_str)
            else:
                result["claude_answer"] = "API_ERROR"
                result["is_correct"] = False

        results.append(result)

    return results, skipped_questions

def compute_aggregated_statistics(all_results, confidence=0.95):
    """Aggregate per-repetition results into per-question accuracy with CIs.

    Results are grouped by ``question_id``, in first-seen order. For image
    questions the real-image and fake-image arms are summarised separately;
    text-only questions get a single accuracy.

    Every accuracy carries a two-sided Wilson score interval for a binomial
    proportion. Wilson is used rather than a plug-in binomial quantile range
    because it keeps a non-zero width when the observed accuracy is exactly
    0 or 1. That is the common case with only a handful of repetitions per
    question.

    Args:
        all_results: result dicts as produced per repetition. Reads
            question_id, correct_answer and has_image, plus is_correct_real /
            is_correct_fake (image) or is_correct (text-only). Missing flags
            count as incorrect. A question's has_image is taken from its last
            result.
        confidence: two-sided confidence level of the intervals (default 0.95).

    Returns:
        DataFrame with one row per question. Image rows: question_id,
        correct_answer, has_image, num_repetitions, real_correct_count,
        real_accuracy, real_ci_lower, real_ci_upper, and the matching fake_*
        columns. Text-only rows: correct_count, accuracy, ci_lower, ci_upper.
    """
    import math
    from scipy import stats

    z = stats.norm.ppf(0.5 + confidence / 2)

    def _wilson(successes, trials):
        # Wilson score interval for successes/trials; (0.0, 0.0) when trials == 0.
        if trials == 0:
            return 0.0, 0.0
        p_hat = successes / trials
        z2_over_n = z * z / trials
        denom = 1 + z2_over_n
        centre = (p_hat + z2_over_n / 2) / denom
        half_width = z * math.sqrt(p_hat * (1 - p_hat) / trials + z2_over_n / (4 * trials)) / denom
        # Clamp floating-point overshoot at the 0/1 boundaries.
        return max(0.0, centre - half_width), min(1.0, centre + half_width)

    question_results = defaultdict(lambda: {
        'correct_answer': None,
        'has_image': False,
        'real_correct': [],
        'fake_correct': [],
        'no_image_correct': [],
    })

    for result in all_results:
        data = question_results[result['question_id']]
        data['correct_answer'] = result['correct_answer']
        data['has_image'] = result.get('has_image', False)
        if data['has_image']:
            data['real_correct'].append(result.get('is_correct_real', False))
            data['fake_correct'].append(result.get('is_correct_fake', False))
        else:
            data['no_image_correct'].append(result.get('is_correct', False))

    aggregated = []
    for qid, data in question_results.items():
        if data['has_image']:
            # Real and fake answers are appended together, so both arms share n.
            n = len(data['real_correct'])
            real_correct_count = sum(data['real_correct'])
            fake_correct_count = sum(data['fake_correct'])
            real_ci_lower, real_ci_upper = _wilson(real_correct_count, n)
            fake_ci_lower, fake_ci_upper = _wilson(fake_correct_count, n)
            aggregated.append({
                'question_id': qid,
                'correct_answer': data['correct_answer'],
                'has_image': True,
                'num_repetitions': n,
                # Real image stats
                'real_correct_count': real_correct_count,
                'real_accuracy': real_correct_count / n if n > 0 else 0,
                'real_ci_lower': real_ci_lower,
                'real_ci_upper': real_ci_upper,
                # Fake image stats
                'fake_correct_count': fake_correct_count,
                'fake_accuracy': fake_correct_count / n if n > 0 else 0,
                'fake_ci_lower': fake_ci_lower,
                'fake_ci_upper': fake_ci_upper,
            })
        else:
            n = len(data['no_image_correct'])
            correct_count = sum(data['no_image_correct'])
            ci_lower, ci_upper = _wilson(correct_count, n)
            aggregated.append({
                'question_id': qid,
                'correct_answer': data['correct_answer'],
                'has_image': False,
                'num_repetitions': n,
                'correct_count': correct_count,
                'accuracy': correct_count / n if n > 0 else 0,
                'ci_lower': ci_lower,
                'ci_upper': ci_upper,
            })

    return pd.DataFrame(aggregated)

# ==== MAIN EXECUTION ====

# Configuration
NUM_REPETITIONS = 10  # Change this to desired number of repetitions
FAKE_IMAGE_PATH = "../data/Fake_Image_path/image.png"
RESULTS_DIR = "results"  # output folder for the CSV files; created if missing


def _pooled_wilson_ci(successes, trials, confidence=0.95):
    """Wilson score interval, as fractions, for successes/trials pooled over questions."""
    from scipy import stats

    if trials == 0:
        return 0.0, 0.0
    z = stats.norm.ppf(0.5 + confidence / 2)
    p_hat = successes / trials
    z2_over_n = z * z / trials
    denom = 1 + z2_over_n
    centre = (p_hat + z2_over_n / 2) / denom
    half_width = z * ((p_hat * (1 - p_hat) / trials + z2_over_n / (4 * trials)) ** 0.5) / denom
    return max(0.0, centre - half_width), min(1.0, centre + half_width)


# Load your Excel data. pandas gives only cell values, so the same sheet is
# also opened with openpyxl to read the picture_link hyperlinks. openpyxl
# column numbers are 1-based, hence the +1.
df = pd.read_excel("../data/subset_with_images.xlsx", sheet_name="SSM_Q_ITA")
picture_link_col = df.columns.get_loc('picture_link') + 1
workbook = openpyxl.load_workbook("../data/subset_with_images.xlsx")
worksheet = workbook["SSM_Q_ITA"]

# Store all results across repetitions
all_results = []
all_skipped = []

# Run multiple repetitions
for rep in range(1, NUM_REPETITIONS + 1):
    results, skipped = process_single_repetition(
        df, worksheet, picture_link_col, FAKE_IMAGE_PATH, rep
    )
    all_results.extend(results)
    all_skipped.extend(skipped)

# DataFrame.to_csv does not create missing directories.
os.makedirs(RESULTS_DIR, exist_ok=True)

# Save individual repetition results
detailed_path = os.path.join(RESULTS_DIR, "all_repetitions_detailed.csv")
all_results_df = pd.DataFrame(all_results)
all_results_df.to_csv(detailed_path, index=False)
print(f"\n✓ Saved all repetitions to {detailed_path}")

# Compute and save aggregated statistics
aggregated_path = os.path.join(RESULTS_DIR, "aggregated_statistics.csv")
aggregated_df = compute_aggregated_statistics(all_results)
aggregated_df.to_csv(aggregated_path, index=False)
print(f"✓ Saved aggregated statistics to {aggregated_path}")

# Save skipped questions
if all_skipped:
    skipped_path = os.path.join(RESULTS_DIR, "skipped_questions.csv")
    pd.DataFrame(all_skipped).to_csv(skipped_path, index=False)
    print(f"✓ Saved skipped questions to {skipped_path}")

# Print summary statistics
print(f"\n{'='*70}")
print(f"=== OVERALL SUMMARY (across {NUM_REPETITIONS} repetitions) ===")
print(f"{'='*70}")

# A run with no evaluated questions yields a column-less DataFrame, so only
# index 'has_image' when there is something to split.
if aggregated_df.empty:
    img_df = no_img_df = aggregated_df
else:
    image_mask = aggregated_df['has_image'].astype(bool)
    img_df = aggregated_df[image_mask]
    no_img_df = aggregated_df[~image_mask]

total_questions = len(aggregated_df)
questions_with_images = len(img_df)
questions_without_images = len(no_img_df)

print(f"\nTotal unique questions: {total_questions}")
print(f"Questions with images: {questions_with_images}")
print(f"Questions without images: {questions_without_images}")
print(f"Total evaluations: {len(all_results)}")

if questions_with_images > 0:
    # "±" is the spread of per-question accuracy across questions.
    avg_real_accuracy = img_df['real_accuracy'].mean() * 100
    avg_fake_accuracy = img_df['fake_accuracy'].mean() * 100
    std_real_accuracy = img_df['real_accuracy'].std() * 100
    std_fake_accuracy = img_df['fake_accuracy'].std() * 100

    # Pooled CI over all image evaluations. Both arms ask each image question
    # once per repetition, so they share the same number of trials.
    image_trials = img_df['num_repetitions'].sum()
    real_overall_ci = _pooled_wilson_ci(img_df['real_correct_count'].sum(), image_trials)
    fake_overall_ci = _pooled_wilson_ci(img_df['fake_correct_count'].sum(), image_trials)

    print("\n--- REAL Images ---")
    print(f"Overall accuracy: {avg_real_accuracy:.1f}% ± {std_real_accuracy:.1f}%")
    print(f"95% CI: [{real_overall_ci[0]*100:.1f}%, {real_overall_ci[1]*100:.1f}%]")
    print(f"Questions always correct: {len(img_df[img_df['real_accuracy'] == 1.0])}/{questions_with_images}")
    print(f"Questions never correct: {len(img_df[img_df['real_accuracy'] == 0.0])}/{questions_with_images}")

    print("\n--- FAKE Images ---")
    print(f"Overall accuracy: {avg_fake_accuracy:.1f}% ± {std_fake_accuracy:.1f}%")
    print(f"95% CI: [{fake_overall_ci[0]*100:.1f}%, {fake_overall_ci[1]*100:.1f}%]")
    print(f"Questions always correct: {len(img_df[img_df['fake_accuracy'] == 1.0])}/{questions_with_images}")
    print(f"Questions never correct: {len(img_df[img_df['fake_accuracy'] == 0.0])}/{questions_with_images}")

if questions_without_images > 0:
    avg_accuracy = no_img_df['accuracy'].mean() * 100
    std_accuracy = no_img_df['accuracy'].std() * 100

    # Pooled CI over all text-only evaluations
    overall_ci = _pooled_wilson_ci(no_img_df['correct_count'].sum(), no_img_df['num_repetitions'].sum())

    print("\n--- WITHOUT Images ---")
    print(f"Overall accuracy: {avg_accuracy:.1f}% ± {std_accuracy:.1f}%")
    print(f"95% CI: [{overall_ci[0]*100:.1f}%, {overall_ci[1]*100:.1f}%]")
    print(f"Questions always correct: {len(no_img_df[no_img_df['accuracy'] == 1.0])}/{questions_without_images}")
    print(f"Questions never correct: {len(no_img_df[no_img_df['accuracy'] == 0.0])}/{questions_without_images}")

print(f"\n{'='*70}")
print("Analysis complete!")
print(f"{'='*70}\n")


######################################################################
### STARTING REPETITION 1 ###
######################################################################

[Rep 1] Loaded image for IT0006

[Rep 1] Processing question with IMAGE: IT0006
[Rep 1] IT0006: Correct=C, Real=C✓, Fake=D✗
[Rep 1] Loaded image for IT0007

[Rep 1] Processing question with IMAGE: IT0007
[Rep 1] IT0007: Correct=C, Real=C✓, Fake=D✗
[Rep 1] Loaded image for IT0031

[Rep 1] Processing question with IMAGE: IT0031
[Rep 1] IT0031: Correct=A, Real=A✓, Fake=A✓
[Rep 1] Loaded image for IT0032

[Rep 1] Processing question with IMAGE: IT0032
[Rep 1] IT0032: Correct=B, Real=B✓, Fake=B✓
[Rep 1] Loaded image for IT0053

[Rep 1] Processing question with IMAGE: IT0053
[Rep 1] IT0053: Correct=C, Real=C✓, Fake=C✓
[Rep 1] Loaded image for IT0054

[Rep 1] Processing question with IMAGE: IT0054
[Rep 1] IT0054: Correct=A, Real=A✓, Fake=A✓
[Rep 1] Loaded image for IT0063

[Rep 1] Processing question with IMAGE: IT0063
[R