In [7]:
%run ../utils/init_env.py

In [None]:
import json
import os
import random
import re
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional

import requests

import config

# Configuration: chat-completions endpoint taken from the project model config.
# It is echoed so the notebook output shows which server the evaluation targets.
API_URL = config.get_model_config()["api_url"]
print(API_URL)

http://Adrians-Mac-mini.local:12345/v1/chat/completions


In [9]:
# Models available on the LM Studio server (the model is selected in send_prompt):
# - llama-3.2-1b
# - phi-4 (14B)
# - deepseek-r1-distill-qwen-32b

In [None]:
def normalize_sql(sql: str) -> str:
    """Extract the SQL query from a raw model response and put it in canonical form.

    Steps:
      1. Drop ``<think>...</think>`` reasoning blocks (emitted by reasoning models
         such as deepseek-r1 distills).
      2. Drop leaked chat-template end-of-turn tokens (``<|eot_id|>``, ``<|im_end|>``).
      3. If the response contains a complete fenced code block, keep only the body of
         the first one, so surrounding prose ("Here is the query: ...") is discarded.
         Otherwise just remove any stray fence markers.
      4. Collapse every whitespace run (including newlines) to a single space, so the
         query fits on one line.
      5. Strip trailing semicolons only. Semicolons inside the query (e.g. in string
         literals) are preserved.

    Args:
        sql: Raw text returned by the model. ``None`` or ``""`` yields ``""``.

    Returns:
        The normalised single-line SQL string (possibly empty).
    """
    if not sql:
        return ""

    text = re.sub(r"<think>.*?</think>", "", sql, flags=re.DOTALL | re.IGNORECASE)
    text = re.sub(r"<\|eot_id\|>|<\|im_end\|>", "", text, flags=re.IGNORECASE)

    fenced = re.search(r"```(?:sql)?(.*?)```", text, flags=re.DOTALL | re.IGNORECASE)
    if fenced:
        text = fenced.group(1)
    else:
        # Unclosed or stray fence markers: remove them, keep the rest of the text.
        text = re.sub(r"```(?:sql)?", "", text, flags=re.IGNORECASE)

    text = re.sub(r"\s+", " ", text).strip()
    return text.rstrip("; ")

def send_prompt(
    message: str,
    system_prompt: Optional[str] = None,
    *,
    model: str = "phi-4",
    temperature: float = 0.2,
    timeout: Optional[float] = 600.0,
) -> Optional[str]:
    """Send a single prompt to LMStudio and return the stripped reply text.

    Args:
        message: User message content.
        system_prompt: Optional system message, sent before the user message.
        model: Model identifier passed in the request payload.
        temperature: Sampling temperature passed in the request payload.
        timeout: Seconds to wait for the server (passed to ``requests.post``).
            ``None`` waits indefinitely. Without a timeout, a hung server would
            block a ThreadPoolExecutor worker forever.

    Returns:
        The reply content with surrounding whitespace stripped. Returns ``None`` if
        the request failed or the response did not have the expected
        ``choices[0].message.content`` string. Errors are printed rather than
        raised, so one bad request does not abort a whole batch.
    """
    headers = config.get_model_config()['headers']()

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": message})

    payload = {
        'model': model,
        'messages': messages,
        'temperature': temperature,
    }

    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()
        content = response.json()["choices"][0]["message"]["content"]
    except requests.exceptions.RequestException as e:
        print(f"Error connecting to LMStudio: {e}")
        return None
    except (ValueError, KeyError, IndexError, TypeError) as e:
        # Non-JSON body or a payload without the expected choices/message shape.
        print(f"Unexpected response from LMStudio: {e!r}")
        return None

    # Guard against a missing or non-string content field before calling .strip().
    if not isinstance(content, str):
        print(f"Unexpected response content from LMStudio: {content!r}")
        return None
    return content.strip()

def evaluate_single_prompt(prompt: str, expected: str) -> Dict:
    """Query the model with one prompt and check its SQL against the gold SQL.

    Both the model output and the gold query go through ``normalize_sql`` before
    the case-insensitive comparison. Otherwise purely cosmetic differences in the
    gold query (trailing ``;``, extra whitespace) could never count as a match.

    Args:
        prompt: Prompt text sent to the model.
        expected: Gold SQL for this prompt.

    Returns:
        A dict with keys ``prompt``, ``expected`` (the raw gold SQL, unmodified),
        ``actual`` (the normalised model SQL, or ``None`` if the request failed)
        and ``match`` (bool).
    """
    raw_response = send_prompt(prompt)

    if raw_response is None:
        return {
            "prompt": prompt,
            "expected": expected,
            "actual": None,
            "match": False
        }

    actual_response = normalize_sql(raw_response)
    is_match = actual_response.lower() == normalize_sql(expected).lower()

    return {
        "prompt": prompt,
        "expected": expected,
        "actual": actual_response,
        "match": is_match
    }

def evaluate_prompts(
    prompts: List[str],
    expected_responses: List[str],
    max_workers: int = 4,
) -> Dict:
    """Evaluate many prompts concurrently and aggregate exact-match accuracy.

    Args:
        prompts: Prompt texts to send to the model.
        expected_responses: Gold SQL for each prompt, in the same order.
        max_workers: Number of concurrent requests to the model server.

    Returns:
        A dict with ``total_prompts``, ``successful_matches``, ``responses`` (the
        per-prompt result dicts, in input order) and ``accuracy`` (a percentage;
        ``0.0`` for an empty input).

    Raises:
        ValueError: If the two lists differ in length.
    """
    if len(prompts) != len(expected_responses):
        raise ValueError("Number of prompts must match number of expected responses")

    total = len(prompts)
    if total == 0:
        # Nothing to evaluate; also avoids the division by zero in the accuracy.
        return {"total_prompts": 0, "successful_matches": 0, "responses": [], "accuracy": 0.0}

    # executor.map yields results in input order, so responses[i] lines up with prompts[i].
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        responses = list(executor.map(evaluate_single_prompt, prompts, expected_responses))

    matches = sum(1 for res in responses if res["match"])
    return {
        "total_prompts": total,
        "successful_matches": matches,
        "responses": responses,
        "accuracy": matches / total * 100,
    }

In [31]:
# Load the evaluation prompts (records with 'prompt' and 'gold_sql' keys; 'db' is read when saving).
with open(os.path.join(config.DATA_DIR, "prompts.json"), 'r', encoding='utf-8') as f:
    prompts = json.load(f)

random.seed(42)
sample_size = 100  # set to None to evaluate on the full prompt set
if sample_size is None:
    sampled_test_data = prompts
else:
    # Clamp so a prompts.json with fewer entries doesn't make random.sample raise ValueError.
    sampled_test_data = random.sample(prompts, min(sample_size, len(prompts)))

test_prompts = [sample['prompt'] for sample in sampled_test_data]
expected_responses = [sample['gold_sql'] for sample in sampled_test_data]

In [32]:
# Send every sampled prompt to the model and score it against its gold SQL.
results = evaluate_prompts(prompts=test_prompts, expected_responses=expected_responses)

In [33]:
# Split the per-prompt records into parallel lists of predicted and gold SQL.
predicted_sql, gold_sql = [], []
for row in results['responses']:
    predicted_sql.append(row['actual'])
    gold_sql.append(row['expected'])

In [34]:
# Save output
with open(config.PREDICTED_SQL, "w", encoding="utf-8") as f:
    for query in predicted_sql:
        f.write(query + "\n")

with open(config.GOLD_SQL, "w", encoding="utf-8") as f:
    for query, entry in zip(gold_sql, sampled_test_data):
        f.write(f'{query}\t{entry['db']}\n')

In [None]:
# # Display results
# print(f"Evaluation Results:")
# print(f"Total Prompts: {results['total_prompts']}")
# print(f"Successful Matches: {results['successful_matches']}")
# print(f"Accuracy: {results['accuracy']:.2f}%")
# print("\nDetailed Results:")

# for result in results['responses']:
#     print(f"\nPrompt: {json.dumps(result['prompt'])}")
#     print(f"Expected: {result['expected']}")
#     print(f"Actual: {result['actual']}")
#     print(f"Match: {result['match']}")
#     print("-" * 80)