In [5]:
# Run the shared environment setup script before anything else in the notebook.
# NOTE(review): presumably this is what makes the local `config` module importable — confirm.
%run ../utils/init_env.py

In [6]:
import requests
import config
import json
import random
import re
import os
from typing import List, Dict

# Configuration
# Endpoint that every model request in this notebook is posted to, read once
# from the project config.
LMSTUDIO_API_URL = config.get_model_config()['api_url']
# Echo the endpoint so the cell output records which server was targeted.
print(LMSTUDIO_API_URL)

http://1.1.1.1:52415/v1/chat/completions


In [7]:

# Display the model name the project config marks as the default (nothing is assigned).
config.get_model_config()['default_model']

'llama-3.2-1b'

In [8]:
def normalize_sql(sql: str) -> str:
    """
    Reduce raw model output to a single-line SQL string.

    Markdown code fences and the `<|eot_id|>` / `<|im_end|>` end-of-turn
    markers are removed (case-insensitively), every semicolon is dropped, and
    each run of whitespace is collapsed to one space.

    Args:
        sql: Text returned by the model.

    Returns:
        The cleaned query with no leading or trailing whitespace.
    """
    wrapper_tokens = r"```sql|```|<\|eot_id\|>|<\|im_end\|>"
    without_wrappers = re.sub(wrapper_tokens, "", sql, flags=re.IGNORECASE)
    # Trim before removing semicolons so the order of operations is unchanged.
    without_semicolons = without_wrappers.strip().replace(";", "")
    return re.sub(r"\s+", " ", without_semicolons).strip()

def send_prompt_to_lmstudio(
    messages: list,
    max_tokens: int = 500,
    model: str = 'phi-4',
    temperature: float = 0.1,
    timeout: float = 300.0,
) -> str | None:
    """
    Send one chat-completion request to the LMStudio endpoint and return the reply text.

    Args:
        messages: Chat messages, forwarded unchanged as the payload's 'messages' list.
        max_tokens: Upper bound on generated tokens, sent as the payload's 'max_tokens'.
        model: Model identifier sent as the payload's 'model'.
        temperature: Sampling temperature, sent as a number.
        timeout: Seconds to wait for the server before giving up; without a
            timeout `requests` can block indefinitely on a stalled server.

    Returns:
        The stripped 'content' of the first choice's message, or None when the
        request fails (connection error, timeout, non-2xx status) or the
        response body does not have the expected shape. A message is printed
        in either failure case.
    """
    headers = config.get_model_config()['headers']()

    payload = {
        'model': model,
        'messages': messages,
        'max_tokens': max_tokens,
        'temperature': temperature,
    }

    try:
        response = requests.post(
            LMSTUDIO_API_URL,
            headers=headers,
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"].strip()
    except requests.exceptions.RequestException as e:
        print(f"Error connecting to LMStudio: {e}")
        return None
    except (KeyError, IndexError, TypeError, ValueError, AttributeError) as e:
        # A 2xx reply whose body is not JSON, lacks choices, or has null content
        # is treated like any other failed request so callers only check for None.
        print(f"Unexpected response from LMStudio: {e!r}")
        return None

def evaluate_prompts(prompts: list[dict], expected_responses: list[str]) -> dict:
    """
    Send each prompt to the model and score the normalized reply against its expected SQL.

    Each prompt is sent on its own as a single-element message list. The reply
    and the expected SQL are both passed through `normalize_sql` and compared
    case-insensitively, so a trailing semicolon or different spacing in the
    expected query does not prevent a match.

    Args:
        prompts: One chat message per test case.
        expected_responses: Expected SQL for each prompt, in the same order.

    Returns:
        A dict with:
            "total_prompts": number of prompts given,
            "successful_matches": number of replies that matched,
            "responses": one record per prompt, in input order, with keys
                "prompt", "expected" (as given), "actual" (normalized reply,
                or "" when the request failed) and "match",
            "accuracy": matches as a percentage of all prompts (0.0 when
                no prompts were given).

    Raises:
        ValueError: If the two input lists differ in length.
    """
    if len(prompts) != len(expected_responses):
        raise ValueError("Number of prompts must match number of expected responses")

    results = {
        "total_prompts": len(prompts),
        "successful_matches": 0,
        "responses": []
    }

    for prompt, expected in zip(prompts, expected_responses):
        actual_response = send_prompt_to_lmstudio([prompt])

        if actual_response is None:
            # Record the failure instead of skipping it: "responses" then stays
            # index-aligned with the inputs, which positional pairing relies on.
            actual_response = ""
            is_match = False
        else:
            actual_response = normalize_sql(actual_response)
            # Normalize the expected SQL the same way, for comparison only.
            is_match = actual_response.lower() == normalize_sql(expected).lower()

        if is_match:
            results["successful_matches"] += 1

        # Store detailed results
        results["responses"].append({
            "prompt": prompt,
            "expected": expected,
            "actual": actual_response,
            "match": is_match
        })

    # Calculate accuracy; an empty run scores 0.0 rather than dividing by zero.
    total = results["total_prompts"]
    results["accuracy"] = results["successful_matches"] / total * 100 if total else 0.0
    return results

In [9]:
# Read every available prompt record from the JSON file named in the project config.
with open(config.PROMPTS_JSON, mode='r', encoding='utf-8') as f:
    prompts = json.load(f)

# Fixed seed so the same records are drawn on every run.
random.seed(42)
sample_size = 5
sampled_test_data = random.sample(prompts, k=sample_size)

# One single-turn user message per sampled record, with its gold SQL kept in a parallel list.
test_prompts = [dict(role='user', content=record['prompt']) for record in sampled_test_data]
expected_responses = [record['gold_sql'] for record in sampled_test_data]

# Query the model for each sampled prompt and score the replies.
results = evaluate_prompts(test_prompts, expected_responses)

In [10]:
# results

In [11]:
# Split the per-prompt records into two parallel lists: the model's SQL and the gold SQL.
predicted_sql, gold_sql = [], []
for record in results['responses']:
    predicted_sql.append(record['actual'])
    gold_sql.append(record['expected'])

In [12]:
# Save output
# Gold queries are paired with `sampled_test_data` by position to look up each
# sample's 'db' value. If any response is missing, that pairing would attach the
# wrong 'db', so stop before writing either file.
if len(gold_sql) != len(sampled_test_data):
    raise ValueError(
        f"{len(gold_sql)} evaluated responses but {len(sampled_test_data)} sampled prompts; "
        "cannot pair gold SQL with its 'db' value by position"
    )

# One predicted query per line.
with open(config.PREDICTED_SQL, "w", encoding="utf-8") as f:
    for query in predicted_sql:
        f.write(query + "\n")

# One gold query per line, followed by a tab and the sample's 'db' value.
with open(config.GOLD_SQL, "w", encoding="utf-8") as f:
    for query, entry in zip(gold_sql, sampled_test_data):
        # Outer double quotes keep this f-string valid on Python versions before 3.12.
        f.write(f"{query}\t{entry['db']}\n")

In [13]:
# # Display results
# print(f"Evaluation Results:")
# print(f"Total Prompts: {results['total_prompts']}")
# print(f"Successful Matches: {results['successful_matches']}")
# print(f"Accuracy: {results['accuracy']:.2f}%")
# print("\nDetailed Results:")

# for result in results['responses']:
#     print(f"\nPrompt: {json.dumps(result['prompt'])}")
#     print(f"Expected: {result['expected']}")
#     print(f"Actual: {result['actual']}")
#     print(f"Match: {result['match']}")
#     print("-" * 80)