In [None]:
import os
import json
import openai
import requests
import pandas as pd
import time
from dotenv import load_dotenv
from typing import Dict, List, Any, Optional

# LLM user simulator config

# Load .env from the input_handling_extraction directory
dotenv_path = os.path.join('input_handling_extraction', '.env')
load_dotenv(dotenv_path=dotenv_path)

# Show full cell contents when DataFrames are printed in later cells.
pd.set_option('display.max_colwidth', None)
pd.set_option('display.width', 200)

# Treat an unset key and an empty/blank key the same way: both are unusable.
# Previously an empty value printed "Loaded ..." and then silently left the
# client unset without any error message.
api_key = (os.environ.get("OPENROUTER_API_KEY") or "").strip() or None

if api_key:
    print("Loaded OpenRouter API key.")
    # OpenRouter is reached through the OpenAI client by overriding the base URL.
    user_llm_client = openai.OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=api_key
    )
    # NOTE(review): this only builds the client object; presumably no request
    # is sent until the first completion call, so the message below does not
    # prove the key is valid.
    print("Connected to OpenRouter LLM.")
else:
    # Later cells check `user_llm_client` for truthiness before running tests.
    user_llm_client = None
    print("ERROR: OPENROUTER_API_KEY not found. Check your .env in 'input_handling_extraction'.")

# Model used to simulate the user, and the chatbot endpoint under test.
USER_LLM_MODEL = "google/gemini-2.5-flash-lite-preview-06-17"
CHAT_ENDPOINT = "http://127.0.0.1:8000/chat"

Loaded OpenRouter API key.
Connected to OpenRouter LLM.


In [None]:
# LLM-driven user

def get_llm_user_response(ground_truth: Dict[str, Any], conversation_history: List[Dict[str, str]], persona: str) -> str:
    """
    Ask an LLM to produce a natural user reply based on the goal, persona, and conversation so far.

    Args:
        ground_truth: The flight details the simulated user is trying to book;
            serialized as JSON into the system prompt.
        conversation_history: Messages exchanged so far as
            ``{"role": "assistant" | "user", "content": str}`` dicts, where
            "assistant" is the flight bot and "user" is the simulated user.
        persona: One of "direct_simple", "comprehensive", "ambiguous",
            "error_prone"; any other value falls back to a generic instruction.

    Returns:
        The simulated user's reply (never empty). On failure one of two
        sentinel strings is returned instead of raising:
        "Error: User LLM client not initialized." or "An error occurred.".
    """
    if not user_llm_client:
        return "Error: User LLM client not initialized."

    # The most recent *bot* message is what the simulated user must answer.
    last_bot_message = next(
        (
            message.get('content', '')
            for message in reversed(conversation_history)
            if message.get('role') == 'assistant'
        ),
        "Hello, how can I help you book your flight today?",
    )

    # Persona-specific behavior
    persona_instructions = {
        "direct_simple": (
            "You’re simulating a direct, straightforward user. "
            "Share exactly one requested piece of flight info at a time from the data below. "
            "Answer concisely to the bot’s last question, and only add more if asked, or if the question is broad and one more detail is the obvious next step."
        ),
        "comprehensive": (
            "You’re simulating a comprehensive user. "
            "Try to provide as much of your flight info as you can early on. "
            "Then respond with whatever’s left or clarify as needed. "
            "If you can anticipate what the bot needs next, include it proactively."
        ),
        "ambiguous": (
            "You’re simulating an ambiguous user. "
            "You’ll give vague info first when the question allows (e.g., ‘next month’, ‘a few people’, ‘around noon’). "
            "Only get specific when the bot asks directly. "
            "If no exact date is in the data, use general phrasing like ‘sometime in 2025’."
        ),
        "error_prone": (
            "You’re simulating an error-prone user. "
            "You might hesitate, change your mind, or rephrase something so the bot has to confirm. "
            "Eventually give the correct info from the data below when asked directly."
        ),
    }
    persona_instruction = persona_instructions.get(
        persona,
        "You’re a helpful user simulator. Provide flight info using the data below.",
    )

    system_prompt = f"""
    {persona_instruction}

    Required flight info (ground truth):
    {json.dumps(ground_truth, indent=2)}

    The bot’s latest message:
    "{last_bot_message}"

    Instructions:
    1) Given your persona, read the bot’s last message and your data.
    2) Reply with a concise, natural message that helps the booking move forward.
    3) Don’t invent facts. Stick to the data or the persona’s guidelines.
    4) Don’t send an empty reply. If you’re done, say: "That's all for now, thank you."
    """

    messages_for_llm = [{"role": "system", "content": system_prompt}]

    # Replay the conversation so the simulator knows what it already said.
    # Roles are flipped because, from the simulator's point of view, the flight
    # bot's messages are its input ("user") and its own earlier replies are its
    # output ("assistant").
    simulator_role_for = {"assistant": "user", "user": "assistant"}
    for message in conversation_history:
        content = message.get('content')
        if not content:
            continue
        messages_for_llm.append({
            "role": simulator_role_for.get(message.get('role'), "user"),
            "content": content,
        })

    try:
        response = user_llm_client.chat.completions.create(
            model=USER_LLM_MODEL,
            messages=messages_for_llm,
            temperature=0.7,
            max_tokens=100,
        )
        raw_reply = response.choices[0].message.content if response.choices else None
        # Strip before falling back so a whitespace-only completion cannot
        # turn into an empty user message.
        user_response = (raw_reply or "").strip()
        return user_response or "yes"
    except Exception as e:
        # Best-effort boundary: any client/API failure is reported through the
        # sentinel string so the calling test loop can stop cleanly.
        print(f"Error calling user LLM: {e}")
        return "An error occurred."


In [None]:
# Normalize values to compare them more reliably
def _normalize_value_for_comparison(value: Any) -> Any:
    """
    Map a raw slot value to a canonical form so equal values compare equal.

    Rules, in order:
      * ``None`` stays ``None``; booleans are returned unchanged.
      * Numbers become ``float`` (so ``2`` and ``2.0`` match); a numeric NaN
        is treated as missing and becomes ``None``.
      * Strings are stripped, then: "", "null", "none" (any case) -> ``None``;
        "true"/"false" (any case) -> bool; finite numeric text -> ``float``;
        anything else -> the lowercased text.
      * Any other type is returned unchanged.

    Text such as "nan", "inf" or "infinity" is deliberately kept as text:
    ``float()`` would accept it, but a float NaN never equals itself, so a
    value like the city name "Nan" could never match its ground truth.
    """
    import math  # local import keeps this helper self-contained

    if value is None:
        return None
    # bool must be checked before int because bool is a subclass of int.
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        numeric = float(value)
        return None if math.isnan(numeric) else numeric
    if isinstance(value, str):
        text = value.strip()
        lowered = text.lower()
        if lowered in ("", "null", "none"):
            return None
        if lowered == 'true':
            return True
        if lowered == 'false':
            return False
        try:
            numeric = float(text)
        except ValueError:
            return lowered
        # Only genuine numbers are converted; "nan"/"inf" spellings stay text.
        return numeric if math.isfinite(numeric) else lowered
    return value

def run_llm_driven_test(record: Dict[str, Any], test_index: int, persona: str, print_conversation: bool = True) -> Dict[str, Any]:
    """
    Drive a conversation between the LLM-simulated user and the flight bot.
    Returns the final extracted info plus evaluation metrics.

    Args:
        record: Ground-truth flight details for this test case.
        test_index: Index of the test case; used to build a unique username.
        persona: Persona name forwarded to ``get_llm_user_response``.
        print_conversation: When True, print the dialogue and status messages.

    Returns:
        ``{"final_flight_info": dict, "metrics": dict}`` on success. When the
        user setup or an HTTP exchange fails, the dict additionally carries an
        ``"error"`` message and the metrics are all zero.
    """
    REPETITION_THRESHOLD = 3
    MAX_TURNS = 9
    REQUEST_TIMEOUT_SECONDS = 60
    username = f"testuser_{test_index}_{persona}"
    api_base_url = CHAT_ENDPOINT.replace('/chat', '')

    fields_expected_from_extractor = [
        "departure_city", "arrival_city", "departure_date", "passengers"
    ]

    # Exact sentinel strings get_llm_user_response returns on failure. They are
    # matched exactly: a substring test for "error" would abort the test
    # whenever the simulated user legitimately used that word.
    user_llm_failure_replies = {
        "Error: User LLM client not initialized.",
        "An error occurred.",
    }

    # Phrases that mean the bot has explicitly confirmed a booking.
    final_booking_confirmation_phrases = [
        "booking confirmed", "your booking is complete", "your e-ticket and itinerary will be sent"
    ]

    # Phrases that mean the bot has stopped collecting info and started acting.
    bot_action_trigger_phrases = [
        "I'll now search for", "searching for flights", "pull up the results",
        "check for available flights", "here’s your complete booking request",
        "Here's your confirmed booking summary", "proceed with booking this flight",
        "here’s what I’ve found for your trip", "here are some simulated example options"
    ]

    def _log(message: str) -> None:
        """Print only when the caller asked to see the conversation."""
        if print_conversation:
            print(message)

    def _contains_any(text: str, phrases: List[str]) -> bool:
        """Case-insensitive check for any of the phrases inside the text."""
        lowered_text = text.lower()
        return any(phrase.lower() in lowered_text for phrase in phrases)

    # Prepare ground truth in the same shape the extractor outputs for a fair comparison
    def _get_expected_state_for_extractor_comparison(gt_record: Dict[str, Any]) -> Dict[str, Any]:
        expected_state = {}
        for key in fields_expected_from_extractor:
            value = gt_record.get(key)
            if key == "passengers" and (value is None or value == 0):
                # No explicit total: derive it from the per-category counts.
                total_passengers = (gt_record.get("adult_passengers", 0) or 0) + \
                                   (gt_record.get("child_passengers", 0) or 0) + \
                                   (gt_record.get("infant_passengers", 0) or 0)
                expected_state[key] = total_passengers if total_passengers > 0 else None
            else:
                expected_state[key] = value
        return {k: _normalize_value_for_comparison(v) for k, v in expected_state.items()}

    expected_ground_truth_for_extractor_comparison = _get_expected_state_for_extractor_comparison(record)

    def _error_result(error_msg: str, conversation_length: int, flight_info: Dict[str, Any]) -> Dict[str, Any]:
        """Build the zero-metrics result returned when the test cannot finish."""
        return {
            "error": error_msg,
            "metrics": {
                "precision": 0,
                "recall": 0,
                "f1_score": 0,
                "conversation_length": conversation_length,
                "true_positives": 0,
                "false_positives": 0,
                "false_negatives": 0,
                "expected_for_evaluation": expected_ground_truth_for_extractor_comparison,
                "extracted_for_evaluation": {},
            },
            "final_flight_info": flight_info,
        }

    # Register is best-effort (its status is not checked, e.g. the user may
    # already exist); only a failed login aborts the test.
    try:
        requests.post(
            f"{api_base_url}/users/register",
            json={"username": username, "email": f"{username}@example.com"},
            timeout=REQUEST_TIMEOUT_SECONDS,
        )
        login_response = requests.post(
            f"{api_base_url}/users/login",
            params={"username": username},
            timeout=REQUEST_TIMEOUT_SECONDS,
        )
        login_response.raise_for_status()
    except requests.RequestException as e:
        error_msg = f"Error setting up user '{username}': {e}"
        _log(error_msg)
        return _error_result(error_msg, 0, {})
    _log(f"Signed in as '{username}'.")

    # Malformed server payloads (missing keys, non-JSON bodies) are reported
    # the same way as transport errors instead of crashing the whole run.
    chat_failures = (requests.RequestException, KeyError, TypeError, ValueError)

    final_flight_info = {}
    conversation_history = []
    last_bot_messages_for_repetition = []

    def _remember_bot_message(message: str) -> None:
        """Record a bot message, keeping only the newest REPETITION_THRESHOLD."""
        last_bot_messages_for_repetition.append(message)
        del last_bot_messages_for_repetition[:-REPETITION_THRESHOLD]

    # The context manager guarantees the session's connections are released on
    # every exit path, including the early error returns below.
    with requests.Session() as session:

        def _post_chat_message(content: str, current_session_id: Optional[str]) -> Dict[str, Any]:
            """Send one user message to the bot and return the decoded JSON."""
            response = session.post(
                CHAT_ENDPOINT,
                json={"role": "user", "content": content, "session_id": current_session_id, "username": username},
                timeout=REQUEST_TIMEOUT_SECONDS,
            )
            response.raise_for_status()
            return response.json()

        try:
            data = _post_chat_message("Hi!", None)
            session_id = data["session_id"]
            bot_response = data['response']
            # A null flight_info is treated like a missing one.
            final_flight_info = data.get("flight_info") or {}
        except chat_failures as e:
            error_msg = f"Error starting session for user {username}: {e}"
            _log(error_msg)
            return _error_result(error_msg, 0, {})

        _log(f"BOT: {bot_response}")
        conversation_history.append({"role": "assistant", "content": bot_response})
        _remember_bot_message(bot_response)

        for _ in range(MAX_TURNS):
            # At the top of every turn the newest history entry is the bot's.
            latest_bot_message = conversation_history[-1]['content']

            # Stop if the bot repeats itself N times
            if len(last_bot_messages_for_repetition) == REPETITION_THRESHOLD and \
               all(msg == last_bot_messages_for_repetition[0] for msg in last_bot_messages_for_repetition):
                _log(f"Bot repeated the same message {REPETITION_THRESHOLD} times, ending the conversation.")
                break

            # Stop if the bot explicitly confirms a booking
            if _contains_any(latest_bot_message, final_booking_confirmation_phrases):
                _log("Booking appears confirmed—ending the test.")
                break

            # Stop if we have the core info and the bot starts taking action
            is_essential_info_present = all(
                _normalize_value_for_comparison(final_flight_info.get(key)) is not None
                for key in fields_expected_from_extractor
            )
            if is_essential_info_present and _contains_any(latest_bot_message, bot_action_trigger_phrases):
                _log("Bot has enough info and is moving to the next step—ending the test.")
                break

            # Next user turn
            user_response = get_llm_user_response(record, conversation_history, persona)

            if user_response in user_llm_failure_replies:
                _log(f"User (LLM) couldn't respond: {user_response}")
                break

            _log(f"USER (LLM): {user_response}")
            conversation_history.append({"role": "user", "content": user_response})

            try:
                data = _post_chat_message(user_response, session_id)
                bot_response = data['response']
                new_flight_info = data.get("flight_info") or {}
            except chat_failures as e:
                error_msg = f"Error during conversation for user {username}: {e}"
                _log(error_msg)
                # Report the last successfully received flight info.
                return _error_result(error_msg, len(conversation_history), final_flight_info)

            final_flight_info = new_flight_info
            _log(f"BOT: {bot_response}")
            conversation_history.append({"role": "assistant", "content": bot_response})
            _remember_bot_message(bot_response)

            # Brief pause between turns.
            time.sleep(1)

    _log(f"\nConversation for {username} ended after {len(conversation_history)} messages.")

    metrics = calculate_slot_metrics(expected_ground_truth_for_extractor_comparison, final_flight_info, fields_expected_from_extractor, len(conversation_history))

    return {"final_flight_info": final_flight_info, "metrics": metrics}

def calculate_slot_metrics(expected_data: Dict[str, Any], extracted_data: Dict[str, Any], fields_to_evaluate: List[str], conversation_length: int) -> Dict[str, Any]:
    """
    Compute slot-filling precision, recall, and F1.

    Counting rules, per field in ``fields_to_evaluate``:
      * expected value present and extracted value equal  -> true positive
      * expected value present and extracted value differs
        (missing *or* wrong)                              -> false negative
      * expected value absent but something was extracted -> false positive
      * both absent                                       -> not counted

    Note that a wrong (non-empty) extraction is counted only as a false
    negative, so precision drops only for slots the ground truth leaves empty.

    Args:
        expected_data: Ground-truth values keyed by field. They are compared
            as-is (not normalized here), so pass values already run through
            ``_normalize_value_for_comparison``. ``None`` is treated as ``{}``.
        extracted_data: Raw values produced by the bot; normalized here before
            comparison. ``None`` is treated as ``{}``.
        fields_to_evaluate: Field names to score.
        conversation_length: Passed through unchanged into the result.

    Returns:
        A dict with rounded ``precision``/``recall``/``f1_score``, the raw
        ``true_positives``/``false_positives``/``false_negatives`` counts,
        ``conversation_length``, and the compared ``expected_for_evaluation``
        and ``extracted_for_evaluation`` dicts.
    """
    # Tolerate a missing payload (e.g. the bot returned no flight info at all)
    # instead of failing on ``None.get``.
    expected_data = expected_data or {}
    extracted_data = extracted_data or {}

    true_positives = 0
    false_positives = 0
    false_negatives = 0

    extracted_data_normalized_for_metrics = {
        field: _normalize_value_for_comparison(extracted_data.get(field))
        for field in fields_to_evaluate
    }

    for field in fields_to_evaluate:
        expected_val = expected_data.get(field)
        extracted_val = extracted_data_normalized_for_metrics.get(field)

        if expected_val is not None:
            if extracted_val == expected_val:
                true_positives += 1
            else:
                false_negatives += 1
        elif extracted_val is not None:
            false_positives += 1

    # Each ratio falls back to 0.0 when its denominator is zero.
    predicted_count = true_positives + false_positives
    relevant_count = true_positives + false_negatives
    precision = true_positives / predicted_count if predicted_count > 0 else 0.0
    recall = true_positives / relevant_count if relevant_count > 0 else 0.0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

    return {
        "precision": round(precision, 4),
        "recall": round(recall, 4),
        "f1_score": round(f1_score, 4),
        "true_positives": true_positives,
        "false_positives": false_positives,
        "false_negatives": false_negatives,
        "conversation_length": conversation_length,
        "expected_for_evaluation": expected_data,
        "extracted_for_evaluation": extracted_data_normalized_for_metrics
    }


In [None]:
# Single test runner

TEST_INDEX = 0
SELECTED_PERSONA = "comprehensive"  # options: "direct_simple", "comprehensive", "ambiguous", "error_prone"

# Load the ground truth only if an earlier cell has not already done so.
if 'ground_truth_data' not in globals() or not ground_truth_data:
    try:
        # JSON is read as UTF-8 explicitly so the result does not depend on
        # the platform's default locale encoding.
        with open("flight_test_data.json", 'r', encoding="utf-8") as f:
            ground_truth_data = json.load(f)
        print(f"Loaded {len(ground_truth_data)} records from 'flight_test_data.json'.")
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file. ValueError: invalid JSON or bytes.
        print(f"ERROR: Could not load test data: {e}")
        ground_truth_data = []

if not user_llm_client:
    print("Can't run: LLM client isn't configured (missing API key).")
elif not ground_truth_data:
    print("Can't run: Ground truth data isn't loaded.")
elif not 0 <= TEST_INDEX < len(ground_truth_data):
    # Negative indices are rejected too; they would otherwise silently select
    # a record counted from the end of the dataset.
    print(f"Can't run: test index {TEST_INDEX} is out of bounds (dataset has {len(ground_truth_data)} records).")
else:
    record = ground_truth_data[TEST_INDEX]

    print(f"\n\nRunning single LLM-driven test (index {TEST_INDEX}, persona: {SELECTED_PERSONA})\n")
    print("Ground truth for this test")
    print(json.dumps(record, indent=2))
    print("\nStarting conversation\n")

    test_result = run_llm_driven_test(record, TEST_INDEX, SELECTED_PERSONA, print_conversation=True)

    if "error" in test_result:
        print(f"Result: Error - {test_result['error']}")
        print(f"Metrics at time of failure: {json.dumps(test_result.get('metrics', {}), indent=2)}")
    else:
        metrics = test_result["metrics"]
        print(f"\nFinal test analysis (index {TEST_INDEX}, persona: {SELECTED_PERSONA})\n")

        expected_data_for_comparison = metrics['expected_for_evaluation']
        extracted_data_for_comparison = metrics['extracted_for_evaluation']
        # A pass requires every expected slot to match and nothing spurious.
        is_perfect_match = (metrics['f1_score'] == 1.0 and metrics['false_positives'] == 0)
        verdict = 'PASS' if is_perfect_match else 'FAIL'

        print("\nExpected data (ground truth)")
        print(json.dumps(expected_data_for_comparison, indent=2))
        print("\nExtracted data (from conversation)")
        print(json.dumps(extracted_data_for_comparison, indent=2))
        print(f"\nResult: {verdict}")

        if not is_perfect_match:
            print("\nMismatches:")
            for key in expected_data_for_comparison:
                expected_value = expected_data_for_comparison.get(key)
                extracted_value = extracted_data_for_comparison.get(key)
                if expected_value != extracted_value:
                    print(f"  - Field: '{key}' | Expected: {expected_value} | Extracted: {extracted_value}")

        print(f"\nMetrics:")
        print(f"  Precision: {metrics['precision']:.4f}")
        print(f"  Recall: {metrics['recall']:.4f}")
        print(f"  F1-Score: {metrics['f1_score']:.4f}")
        print(f"  Conversation Length (messages): {metrics['conversation_length']}")
        print(f"  True Positives: {metrics['true_positives']}")
        print(f"  False Positives: {metrics['false_positives']}")
        print(f"  False Negatives: {metrics['false_negatives']}")


In [None]:
# Batch test runner

# Runs a slice of test cases for multiple personas and writes results to CSV.

import os

TEST_CASE_RANGE_STR = "0:20"  # "start:end" (end exclusive, either side may be empty) or "all"
PERSONAS_TO_TEST = ["direct_simple", "comprehensive", "ambiguous", "error_prone"]
RESULTS_CSV_FILE = "llm_test_results.csv"

# Load the ground truth only if an earlier cell has not already done so.
if 'ground_truth_data' not in globals() or not ground_truth_data:
    try:
        # JSON is read as UTF-8 explicitly so the result does not depend on
        # the platform's default locale encoding.
        with open("flight_test_data.json", 'r', encoding="utf-8") as f:
            ground_truth_data = json.load(f)
        print(f"Loaded {len(ground_truth_data)} records from 'flight_test_data.json'.")
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file. ValueError: invalid JSON or bytes.
        print(f"ERROR: Could not load test data: {e}")
        ground_truth_data = []

if not user_llm_client:
    print("Can't run tests: LLM client isn't configured (missing API key).")
elif not ground_truth_data:
    print("Can't run tests: Ground truth data isn't loaded.")
else:
    # Load existing results if the CSV file exists
    if os.path.exists(RESULTS_CSV_FILE):
        try:
            existing_df = pd.read_csv(RESULTS_CSV_FILE)
            results_summary = existing_df.to_dict(orient='records')
            print(f"Loaded {len(results_summary)} previous results from '{RESULTS_CSV_FILE}'.")
        except pd.errors.EmptyDataError:
            print(f"Existing results file '{RESULTS_CSV_FILE}' is empty. Starting fresh.")
            results_summary = []
        except Exception as e:
            # Best-effort: an unreadable results file must not block a new run.
            print(f"Error loading existing results from '{RESULTS_CSV_FILE}': {e}. Starting fresh.")
            results_summary = []
    else:
        results_summary = []
        print("No existing results found—starting fresh.")

    fields_expected_from_extractor = [
        "departure_city", "arrival_city", "departure_date", "passengers"
    ]

    # Determine the actual range of test indices to run
    if TEST_CASE_RANGE_STR.lower() == "all":
        start_index = 0
        end_index = len(ground_truth_data)
    else:
        try:
            parts = TEST_CASE_RANGE_STR.split(':')
            start_part = parts[0].strip()
            end_part = parts[1].strip()
            start_index = int(start_part) if start_part else 0
            end_index = int(end_part) if end_part else len(ground_truth_data)
            # Ensure indices are within bounds
            start_index = max(0, min(start_index, len(ground_truth_data)))
            end_index = max(0, min(end_index, len(ground_truth_data)))
            if start_index >= end_index:
                print(f"Invalid test case range '{TEST_CASE_RANGE_STR}'. No tests will run.")
                start_index = 0
                end_index = 0
        except (ValueError, IndexError) as e:
            # ValueError: non-integer bound. IndexError: no ':' separator.
            print(f"Couldn't parse test case range '{TEST_CASE_RANGE_STR}': {e}. Running all tests instead.")
            start_index = 0
            end_index = len(ground_truth_data)

    # Filter ground_truth_data to the specified range
    test_cases_to_run = ground_truth_data[start_index:end_index]
    total_test_cases_in_range = len(test_cases_to_run)

    # Initialize overall metrics accumulators per persona
    overall_metrics_by_persona = {
        p: {'tp': 0, 'fp': 0, 'fn': 0, 'length': 0, 'tests_run': 0} for p in PERSONAS_TO_TEST
    }

    total_tests_run_overall = 0

    # Index completed (non-ERROR) results by (persona, test index) so the
    # resume check is a dict lookup instead of a scan of every stored row.
    # setdefault keeps the first matching row, as the former scan did.
    completed_results = {}
    for existing_result in results_summary:
        if existing_result.get('Verdict') != 'ERROR':
            completed_results.setdefault(
                (existing_result.get('Persona'), existing_result.get('Index')),
                existing_result,
            )

    for persona_name in PERSONAS_TO_TEST:
        print(f"\nRunning tests for persona: {persona_name}")
        if total_test_cases_in_range == 0:
            print("  No test cases to run for this range.")
            continue

        persona_totals = overall_metrics_by_persona[persona_name]

        for i, record in enumerate(test_cases_to_run):
            original_test_index = start_index + i
            result_key = (persona_name, original_test_index)

            # Reuse a stored result unless it was an ERROR (those are retried).
            existing_result = completed_results.get(result_key)
            if existing_result is not None:
                if 'TP' in existing_result:
                    persona_totals['tp'] += existing_result['TP']
                    persona_totals['fp'] += existing_result['FP']
                    persona_totals['fn'] += existing_result['FN']
                    persona_totals['length'] += existing_result['Length']
                    persona_totals['tests_run'] += 1
                    total_tests_run_overall += 1
                print(f"  Skipping test case {i+1}/{total_test_cases_in_range} (original index: {original_test_index}) for {persona_name} (already completed).")
                continue

            print(f"  Running test case {i+1}/{total_test_cases_in_range} (original index: {original_test_index}) for {persona_name}...")
            test_result = run_llm_driven_test(record, original_test_index, persona_name, print_conversation=False)

            metrics = test_result.get('metrics', {})

            # Accumulate overall metrics for this persona (newly run tests).
            # Errored runs are included with the zeroed metrics they report.
            persona_totals['tp'] += metrics.get('true_positives', 0)
            persona_totals['fp'] += metrics.get('false_positives', 0)
            persona_totals['fn'] += metrics.get('false_negatives', 0)
            persona_totals['length'] += metrics.get('conversation_length', 0)
            persona_totals['tests_run'] += 1
            total_tests_run_overall += 1

            if "error" in test_result:
                verdict = 'ERROR'
            elif metrics.get('f1_score', 0) == 1.0 and metrics.get('false_positives', 0) == 0:
                verdict = 'PASS'
            else:
                verdict = 'FAIL'

            result_row = {
                'Persona': persona_name,
                'Index': original_test_index,
                'Verdict': verdict,
                'Precision': metrics.get('precision', 0.0),
                'Recall': metrics.get('recall', 0.0),
                'F1-Score': metrics.get('f1_score', 0.0),
                'Length': metrics.get('conversation_length', 0),
                'TP': metrics.get('true_positives', 0),
                'FP': metrics.get('false_positives', 0),
                'FN': metrics.get('false_negatives', 0),
                'Expected': json.dumps(metrics.get('expected_for_evaluation', {})),
                'Extracted': json.dumps(metrics.get('extracted_for_evaluation', {}))
            }

            # Reaching this point means any stored rows for this persona/index
            # are stale ERROR rows from earlier attempts; replace them so the
            # CSV keeps a single, current row per test.
            results_summary[:] = [
                row for row in results_summary
                if not (row.get('Persona') == persona_name and row.get('Index') == original_test_index)
            ]
            results_summary.append(result_row)
            if verdict != 'ERROR':
                completed_results[result_key] = result_row

            # Save results after each test so an interrupted batch can resume.
            current_summary_df = pd.DataFrame(results_summary)
            current_summary_df.to_csv(RESULTS_CSV_FILE, index=False)

    print("\n\nBatch test summary")
    summary_df = pd.DataFrame(results_summary)
    pd.set_option('display.max_columns', None)
    pd.set_option('display.width', 1000)
    if not summary_df.empty:
        print(summary_df.to_string())
    else:
        print("No tests were run or loaded.")

    print(f"\nOverall performance across selected tests (total {total_tests_run_overall} runs):\n")
    if total_tests_run_overall == 0:
        print("  No data to calculate overall metrics.")
    else:
        # Verdict counts are restricted to the selected index range so they
        # add up to the per-persona "Total tests" figure, which only counts
        # in-range tests (the CSV may hold rows from other ranges).
        in_selected_range = summary_df['Index'].between(start_index, end_index - 1)

        for p_name, p_metrics in overall_metrics_by_persona.items():
            if p_metrics['tests_run'] > 0:
                tp, fp, fn = p_metrics['tp'], p_metrics['fp'], p_metrics['fn']
                # Micro-averaged over all slots of all tests for this persona.
                overall_precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
                overall_recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
                overall_f1_score = 2 * (overall_precision * overall_recall) / (overall_precision + overall_recall) if (overall_precision + overall_recall) > 0 else 0.0
                avg_length = p_metrics['length'] / p_metrics['tests_run']

                persona_verdicts = summary_df.loc[(summary_df['Persona'] == p_name) & in_selected_range, 'Verdict']
                persona_pass_count = int((persona_verdicts == 'PASS').sum())
                persona_fail_count = int((persona_verdicts == 'FAIL').sum())
                persona_error_count = int((persona_verdicts == 'ERROR').sum())

                print(f"Persona: {p_name}")
                print(f"  Total tests: {p_metrics['tests_run']}")
                print(f"  Passed: {persona_pass_count}")
                print(f"  Failed: {persona_fail_count}")
                print(f"  Errors: {persona_error_count}")
                print(f"  Precision: {overall_precision:.4f}")
                print(f"  Recall: {overall_recall:.4f}")
                print(f"  F1-Score: {overall_f1_score:.4f}")
                print(f"  Average conversation length: {avg_length:.2f} messages")
            else:
                print(f"Persona: {p_name}")
                print("  No tests run for this persona in the selected range.")
