In [72]:
%run ../utils/init_env.py

In [83]:
import json
import random
import re
from typing import Dict, List, Optional

import requests

import config

# Configuration
# URL that chat requests are POSTed to, taken from the 'api_url' entry of the
# project's model config. Printed so the active endpoint shows in the cell output.
LMSTUDIO_API_URL = config.get_model_config()['api_url']
print(LMSTUDIO_API_URL)

http://127.0.0.1:52415/v1/chat/completions


In [74]:

# Display the 'default_model' entry of the project's model config.
config.get_model_config()['default_model']

'llama-3.2-1b'

In [100]:
def normalize_sql(sql: str) -> str:
    """
    Reduce a model reply to a bare, single-line SQL string.

    Strips Markdown code-fence markers and the ``<|eot_id|>`` token
    (case-insensitively), removes every semicolon, and collapses each run of
    whitespace into one space with no leading or trailing whitespace.
    """
    without_markers = re.sub(r"```sql|```|<\|eot_id\|>", "", sql, flags=re.IGNORECASE)
    without_semicolons = without_markers.strip().replace(";", "")
    return re.sub(r"\s+", " ", without_semicolons).strip()

def send_prompt_to_lmstudio(
    messages: list,
    max_tokens: int = 500,
    model: str = 'llama-3.2-1b',
    temperature: float = 0.7,
    timeout: float = 300.0,
) -> Optional[str]:
    """
    Send one chat-completion request to the LMStudio endpoint and return the reply text.

    Args:
        messages: Chat messages forwarded unchanged as the payload's
            ``messages`` field.
        max_tokens: Forwarded as the payload's ``max_tokens`` field.
        model: Model identifier forwarded as the payload's ``model`` field.
        temperature: Sampling temperature, sent as a JSON number.
        timeout: Seconds ``requests`` waits on the server before giving up.

    Returns:
        The stripped ``content`` of the first choice's message, or ``None``
        when the request fails or the response body does not have the
        expected ``choices[0].message.content`` structure. In both failure
        cases a diagnostic line is printed.
    """
    headers = config.get_model_config()['headers']()

    payload = {
        'model': model,
        'messages': messages,
        # Sent as a number rather than the string '0.7': chat-completion
        # style APIs define temperature as numeric.
        'temperature': temperature,
        'max_tokens': max_tokens,
    }

    try:
        # Without a timeout requests waits indefinitely on a stalled server.
        response = requests.post(
            LMSTUDIO_API_URL, headers=headers, json=payload, timeout=timeout
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"].strip()
    except requests.exceptions.RequestException as e:
        print(f"Error connecting to LMStudio: {e}")
        return None
    except (KeyError, IndexError, TypeError, ValueError, AttributeError) as e:
        # The server answered, but not with the expected JSON structure
        # (e.g. an error object, no choices, or null content).
        print(f"Unexpected response from LMStudio: {e!r}")
        return None

def evaluate_prompts(prompts: List[Dict], expected_responses: List[str]) -> Dict:
    """
    Send each prompt to the model and score the replies against expected SQL.

    Args:
        prompts: One entry per test case; each entry is sent on its own as a
            single-element ``messages`` list to ``send_prompt_to_lmstudio``.
        expected_responses: Expected SQL for the prompt at the same index.

    Returns:
        A dict with:
          - ``total_prompts``: number of prompts supplied.
          - ``successful_matches``: number of replies that matched.
          - ``responses``: one record per prompt with ``prompt``,
            ``expected``, ``actual`` (normalized reply, or ``None`` if the
            request failed) and ``match``.
          - ``accuracy``: matches as a percentage of ``total_prompts``
            (``0.0`` when no prompts were supplied).

    Raises:
        ValueError: If the two input lists differ in length.
    """
    if len(prompts) != len(expected_responses):
        raise ValueError("Number of prompts must match number of expected responses")

    results = {
        "total_prompts": len(prompts),
        "successful_matches": 0,
        "responses": []
    }

    for prompt, expected in zip(prompts, expected_responses):
        raw_response = send_prompt_to_lmstudio([prompt])

        if raw_response is None:
            # A failed request counts as a miss but is still recorded, so the
            # detailed results account for every prompt.
            actual_response = None
            is_match = False
        else:
            actual_response = normalize_sql(raw_response)
            # Normalize both sides: otherwise a trailing ';' or different
            # whitespace in the expected SQL can never match the normalized
            # reply. Comparison is case-insensitive.
            is_match = actual_response.lower() == normalize_sql(expected).lower()

        if is_match:
            results["successful_matches"] += 1

        # Store detailed results
        results["responses"].append({
            "prompt": prompt,
            "expected": expected,
            "actual": actual_response,
            "match": is_match
        })

    # Calculate accuracy; guard the empty case to avoid dividing by zero.
    total = results["total_prompts"]
    results["accuracy"] = (results["successful_matches"] / total * 100) if total else 0.0
    return results

In [101]:
# Load the gold SQL answers and wrap each stored prompt as a user chat message.
# The config paths go straight to open(): wrapping a single path in
# os.path.join() adds nothing and needs an `os` import this notebook never makes.
with open(config.PREPROCESSED_JSON, 'r', encoding='utf-8') as f:
    gold_sql_queries = [e['gold_sql'] for e in json.load(f)]

with open(config.PROMPTS_JSON, 'r', encoding='utf-8') as f:
    prompts = [{'role': 'user', 'content': prompt} for prompt in json.load(f)]

# Draw a fixed-seed random subset of (prompt, gold SQL) pairs so reruns evaluate
# the same examples.
# NOTE(review): zip() silently truncates to the shorter list — confirm the two
# files are aligned index-for-index.
random.seed(42)
sample_size = 5
sampled_test_data = random.sample(list(zip(prompts, gold_sql_queries)), sample_size)

test_prompts = [sample[0] for sample in sampled_test_data]
expected_responses = [sample[1] for sample in sampled_test_data]

# Run evaluation
results = evaluate_prompts(test_prompts, expected_responses)

In [102]:
# Report the headline metrics, then one record per evaluated prompt.
print("Evaluation Results:")
print("Total Prompts: {}".format(results["total_prompts"]))
print("Successful Matches: {}".format(results["successful_matches"]))
print("Accuracy: {:.2f}%".format(results["accuracy"]))
print("\nDetailed Results:")

for result in results["responses"]:
    print("\nPrompt: {}".format(result["prompt"]))
    print("Expected: {}".format(result["expected"]))
    print("Actual: {}".format(result["actual"]))
    print("Match: {}".format(result["match"]))
    # Visual divider between records.
    print("-" * 80)

Evaluation Results:
Total Prompts: 5
Successful Matches: 1
Accuracy: 20.00%

Detailed Results:

Prompt: {'role': 'user', 'content': '### Answer the question by SQLite SQL query only and with no explanation. You must minimize SQL execution time while ensuring correctness.\n### Sqlite SQL tables, with their properties:\n#\n# Students(student_id, bio_data, student_details)\n# Transcripts(transcript_id, student_id, date_of_transcript, transcript_details)\n# Behaviour_Monitoring(behaviour_monitoring_id, student_id, behaviour_monitoring_details)\n# Addresses(address_id, address_details)\n# Ref_Event_Types(event_type_code, event_type_description)\n# Ref_Achievement_Type(achievement_type_code, achievement_type_description)\n# Ref_Address_Types(address_type_code, address_type_description)\n# Ref_Detention_Type(detention_type_code, detention_type_description)\n# Student_Events(event_id, event_type_code, student_id, event_date, other_details)\n# Teachers(teacher_id, teacher_details)\n# Student_Lo