In [3]:
import time 
import torch
import psycopg2
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage
import os
import pandas as pd
from dotenv import load_dotenv
load_dotenv()

True

### Test database connection

In [4]:
def test_db_connection():
    """Test database connection and print table information."""
    try:
        conn = psycopg2.connect(
            dbname="test_db_val",
            user="postgres",
            password=os.environ.get("POSTGRES_PASSWORD"),
            host="localhost"
        )
        cur = conn.cursor()
        
        # Get counts for relevant tables
        tables = [
            'professional_details',
            'professional_locations',
            'languages_spoken_by_professional',
            'professional_availabilities',
            'professional_areas_of_interest',
            'age_group_professional_works_with',
            'type_of_funding_accepted'
        ]
        
        print("\nTable row counts:")
        for table in tables:
            try:
                cur.execute(f"SELECT COUNT(*) FROM {table}")
                count = cur.fetchone()[0]
                print(f"{table}: {count} rows")
            except Exception as e:
                print(f"Error querying {table}: {e}")
        
        # Test a simple join query
        print("\nTesting sample join query:")
        cur.execute("""
            SELECT 
                pd.persona, 
                COUNT(*) as count
            FROM professional_details pd
            GROUP BY pd.persona;
        """)
        
        results = cur.fetchall()
        print("\nProfessionals by type:")
        for persona, count in results:
            print(f"{persona}: {count}")
            
        conn.close()
        print("\nDatabase connection test successful!")
        return True
        
    except Exception as e:
        print(f"Database connection failed: {e}")
        return False

if __name__ == "__main__":
    import os
    from dotenv import load_dotenv
    import psycopg2
    
    load_dotenv()
    test_db_connection()


Table row counts:
professional_details: 7018 rows
professional_locations: 9782 rows
languages_spoken_by_professional: 6177 rows
professional_availabilities: 1507 rows
professional_areas_of_interest: 6164 rows
age_group_professional_works_with: 6164 rows
type_of_funding_accepted: 6164 rows

Testing sample join query:

Professionals by type:
Occupational Therapist: 2696
Speech pathologist: 2615
Psychologist: 1707

Database connection test successful!


In [5]:
# Add your OpenAI API key here
openai_api_key = os.environ.get("OPENAI_API_KEY")

# Updated schema context matching actual database
schema_context = """
You are a SQL expert. Generate PostgreSQL queries for a healthcare professional database with the following schema:

Tables and their key columns:
1. professional_details
   - professional_id (PK, varchar)
   - first_name (varchar)
   - last_name (varchar)
   - persona (varchar)
   - email (varchar)
   - phone_number (varchar)
   - additional_phone_numbers (text)
   - summary (text)
   - about (text)
   - telepractice_offered (boolean)
   - mobile_services_offered (boolean)
   - mobile_services_areas (text)

2. professional_locations
   - id (PK, integer)
   - professional_id (FK, varchar)
   - location (geometry)
   - city_suburb (text)
   - state_region (text)
   - postcode (text)
   - latitude (double precision)
   - longitude (double precision)
   - formatted_address (text)

3. languages_spoken_by_professional
   - professional_id (FK, varchar)
   - Mandarin (boolean)
   - English (boolean)
   - [plus many other language columns, all boolean]

4. professional_availabilities
   - id (PK, integer)
   - professional_id (FK, varchar)
   - availability_notes (text)
   - immediate (boolean)
   - within_1_month (boolean)
   - within_3_months (boolean)
   - within_6_months (boolean)
   - within_12_months (boolean)
   - more_than_12_months (boolean)
   - last_updated (timestamp with time zone)

5. professional_areas_of_interest
   - professional_id (FK, varchar)
   - Autism (boolean)
   - Disability (boolean)
   - Mental_Health (boolean)
   - Anxiety (boolean)
   - Depression (boolean)
   - [plus many other specialty columns, all boolean]

6. age_group_professional_works_with
   - professional_id (FK, varchar)
   - infants (boolean)
   - pre_school_children (boolean)
   - school_aged_children (boolean)
   - adolescents (boolean)
   - adults (boolean)
   - aged (boolean)

7. type_of_funding_accepted
   - professional_id (FK, varchar)
   - ndis_registered (boolean)
   - ndis_non_registered (boolean)
   - medicare (boolean)
   - private_health_insurance (boolean)
   - department_of_veterans_affairs (boolean)
   - workers_compensation (boolean)

Important Notes:
- All tables use professional_id as their link to professional_details
- Location queries can use either:
  * direct latitude/longitude comparison
  * OR PostGIS functions like ST_DWithin(ST_SetSRID(ST_MakePoint(long, lat), 4326), location, distance_meters)
- For name searches, use the similarity() function
- Check availability using multiple time window columns (immediate, within_1_month, etc.)
- Always include relevant contact and location details in results where appropriate
"""

def load_local_model(model_name_or_path):
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16)
    return model, tokenizer

def generate_sql_with_local_model(model, tokenizer, prompt):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    # Combine schema context with prompt
    full_prompt = f"{schema_context}\n\nWrite a PostgreSQL query to: {prompt}\n\nReturn ONLY the SQL query, no explanations or markdown."
    
    try:
        inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
        
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=500,
                pad_token_id=tokenizer.eos_token_id,
                temperature=0.7,  # Add some randomness
                top_p=0.9,       # Nucleus sampling
                do_sample=True   # Enable sampling
            )
        
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Extract just the SQL query part
        response = response.split(full_prompt)[-1].strip()
        
        # Clean up the response
        response = response.replace('```sql', '').replace('```', '').strip()
        
        # Verify it looks like SQL
        if not any(keyword in response.lower() for keyword in ['select', 'from', 'where']):
            return "SELECT 'Error: No valid SQL generated'"
            
        model.to("cpu")
        torch.cuda.empty_cache()
        return response
        
    except Exception as e:
        return f"SELECT 'Error: {str(e)}'"

def generate_sql_with_openai_model(prompt):
    openai_model = ChatOpenAI(model="gpt-4o", api_key=openai_api_key)
    full_prompt = f"{schema_context}\n\nWrite a PostgreSQL query to: {prompt}\n\nReturn ONLY the SQL query, no explanations."
    response = openai_model([HumanMessage(content=full_prompt)])
    return response.content

def evaluate_query(query, expected_elements):
    """Evaluate a SQL query by checking syntax and expected elements."""
    if not query or not isinstance(query, str):
        return 0, ["✗ No valid query generated"]
        
    score = 0
    comments = []
    
    try:
        # Clean the query - remove any markdown code blocks if present
        query = query.strip()
        query = query.replace('```sql', '').replace('```', '')
        query = query.lower()
        
        # Check if it looks like SQL
        if not any(keyword in query for keyword in ['select', 'from', 'where']):
            return 0, ["✗ Response does not appear to be SQL"]
            
        # Check for expected elements
        total_elements = len(expected_elements)
        for element in expected_elements:
            if element.lower() in query:
                score += 1
                comments.append(f"✓ Contains {element}")
            else:
                comments.append(f"✗ Missing {element}")
        
        # Try to validate SQL syntax if sqlparse is available
        try:
            import sqlparse
            parsed = sqlparse.parse(query)
            if parsed:
                comments.append("✓ Valid SQL syntax")
            else:
                comments.append("✗ Invalid SQL syntax")
        except ImportError:
            comments.append("⚠ SQL syntax validation skipped (sqlparse not available)")
            
        return score / total_elements, comments
        
    except Exception as e:
        return 0, [f"✗ Error evaluating query: {str(e)}"]

def compare_models(models, test_cases):
    results = {}
    results_data = []
    
    for model_name, (model, tokenizer) in models.items():
        print(f"\nTesting model: {model_name}")
        
        for test_case in test_cases:
            print(f"\nRunning test case: {test_case['name']}")
            start_time = time.time()
            
            try:
                response = generate_sql_with_local_model(model, tokenizer, test_case['prompt'])
                response_time = time.time() - start_time
                score, comments = evaluate_query(response, test_case['expected_elements'])
                
                results_data.append({
                    'Model': model_name,
                    'Test Case': test_case['name'],
                    'Response Time': f"{response_time:.2f}s",
                    'Score': f"{score:.2%}",
                    'Comments': '\n'.join(comments),
                    'Generated Query': response
                })
                
                print(f"Generated SQL:\n{response}")
                print(f"Score: {score:.2%}")
                print(f"Response time: {response_time:.2f} seconds")
                print("Comments:\n" + '\n'.join(comments))
            except Exception as e:
                print(f"Error testing model {model_name} on case {test_case['name']}: {str(e)}")

    # Test OpenAI model
    print("\nTesting OpenAI model")
    for test_case in test_cases:
        print(f"\nRunning test case: {test_case['name']}")
        start_time = time.time()
        
        try:
            openai_response = generate_sql_with_openai_model(test_case['prompt'])
            openai_response_time = time.time() - start_time
            score, comments = evaluate_query(openai_response, test_case['expected_elements'])
            
            results_data.append({
                'Model': 'OpenAI GPT-4',
                'Test Case': test_case['name'],
                'Response Time': f"{openai_response_time:.2f}s",
                'Score': f"{score:.2%}",
                'Comments': '\n'.join(comments),
                'Generated Query': openai_response
            })
            
            print(f"Generated SQL:\n{openai_response}")
            print(f"Score: {score:.2%}")
            print(f"Response time: {openai_response_time:.2f} seconds")
            print("Comments:\n" + '\n'.join(comments))
        except Exception as e:
            print(f"Error testing OpenAI model on case {test_case['name']}: {str(e)}")

    # Convert results to DataFrame and calculate summary
    results_df = pd.DataFrame(results_data)
    results_df['Score'] = results_df['Score'].str.rstrip('%').astype(float) / 100

    # Display summary
    print("\nResults Summary:")
    summary_df = results_df.pivot_table(
        index='Model',
        values='Score',
        columns='Test Case',
        aggfunc='first'
    ).round(2)

    print("\nAverage Scores by Model:")
    print(summary_df.mean(axis=1).sort_values(ascending=False))

    # Save detailed results
    results_df.to_csv('sql_model_comparison_results.csv', index=False)
    
    return results_df

In [6]:
# Updated test cases based on actual schema
test_cases = [
    {
        'name': 'Geolocation Search',
        'prompt': 'find speech therapists within 5000 meters of Northcote, Melbourne (coordinates: -37.7692917, 144.9990291). Include their names, contact details, and distance. Sort by distance.',
        'expected_elements': [
            'professional_details',
            'professional_locations',
            'latitude',
            'longitude',
            'persona = \'Speech pathologist\'',
            'JOIN',
            'ORDER BY'
        ]
    },
    {
        'name': 'Immediate Availability',
        'prompt': 'find therapists who are immediately available or within 1 month AND offer telepractice. Include their names, contact details, and availability notes.',
        'expected_elements': [
            'professional_details',
            'professional_availabilities',
            'immediate = TRUE',
            'within_1_month = TRUE',
            'telepractice_offered = TRUE',
            'availability_notes',
            'JOIN'
        ]
    },
    {
        'name': 'Language and Mental Health',
        'prompt': 'find all psychologists who speak Mandarin AND specialize in anxiety or depression. Include their contact details and areas of interest.',
        'expected_elements': [
            'professional_details',
            'languages_spoken_by_professional',
            'professional_areas_of_interest',
            'Mandarin = TRUE',
            'anxiety = TRUE',
            'depression = TRUE',
            'persona = \'Psychologist\'',
            'JOIN'
        ]
    },
    {
        'name': 'Complex Age and Funding',
        'prompt': 'find occupational therapists who work with school-aged children AND accept both NDIS and Medicare funding AND offer mobile services. Include their location and contact details.',
        'expected_elements': [
            'professional_details',
            'age_group_professional_works_with',
            'type_of_funding_accepted',
            'school_aged_children = TRUE',
            'ndis_registered = TRUE',
            'medicare = TRUE',
            'mobile_services_offered = TRUE',
            'JOIN'
        ]
    },
    {
        'name': 'Full Profile Search',
        'prompt': 'get the complete profile for therapist "John Smith", including their contact details, availability, languages spoken, areas of interest, age groups they work with, and funding types accepted.',
        'expected_elements': [
            'professional_details',
            'professional_availabilities',
            'languages_spoken_by_professional',
            'professional_areas_of_interest',
            'age_group_professional_works_with',
            'type_of_funding_accepted',
            'similarity',
            'LEFT JOIN'
        ]
    },
    {
        'name': 'Advanced Availability Search',
        'prompt': 'find speech pathologists in Melbourne who work with children, are available within 3 months, and either offer telepractice or mobile services. Sort by earliest availability.',
        'expected_elements': [
            'professional_details',
            'professional_locations',
            'professional_availabilities',
            'age_group_professional_works_with',
            'within_3_months = TRUE',
            'telepractice_offered = TRUE',
            'mobile_services_offered = TRUE',
            'ORDER BY'
        ]
    }
]

# Define your models
local_models = {
    "NumbersStation/nsql-llama-2-7B": load_local_model("NumbersStation/nsql-llama-2-7B"),
    "defog/llama-3-sqlcoder-8B": load_local_model("defog/llama-3-sqlcoder-8B"),
    "defog/sqlcoder-7b-2": load_local_model("defog/sqlcoder-7b-2"),
    "Llama-3.2-1B-Instruct": load_local_model("meta-llama/Llama-3.2-1B-Instruct"),
    "Llama-3.2-3B-Instruct": load_local_model("meta-llama/Llama-3.2-3B-Instruct")
}

# Run comparison and display results
results_df = compare_models(local_models, test_cases)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


Testing model: NumbersStation/nsql-llama-2-7B

Running test case: Geolocation Search
Generated SQL:
SELECT 'Error: No valid SQL generated'
Score: 0.00%
Response time: 3.18 seconds
Comments:
✗ Missing professional_details
✗ Missing professional_locations
✗ Missing latitude
✗ Missing longitude
✗ Missing persona = 'Speech pathologist'
✗ Missing JOIN
✗ Missing ORDER BY
✓ Valid SQL syntax

Running test case: Immediate Availability
Generated SQL:
SELECT 'Error: No valid SQL generated'
Score: 0.00%
Response time: 0.49 seconds
Comments:
✗ Missing professional_details
✗ Missing professional_availabilities
✗ Missing immediate = TRUE
✗ Missing within_1_month = TRUE
✗ Missing telepractice_offered = TRUE
✗ Missing availability_notes
✗ Missing JOIN
✓ Valid SQL syntax

Running test case: Language and Mental Health
Generated SQL:
SELECT 'Error: No valid SQL generated'
Score: 0.00%
Response time: 0.49 seconds
Comments:
✗ Missing professional_details
✗ Missing languages_spoken_by_professional
✗ Missing

  openai_model = ChatOpenAI(model="gpt-4o", api_key=openai_api_key)
  response = openai_model([HumanMessage(content=full_prompt)])


Generated SQL:
```sql
SELECT 
    pd.first_name,
    pd.last_name,
    pd.email,
    pd.phone_number,
    ST_Distance(
        ST_SetSRID(ST_MakePoint(144.9990291, -37.7692917), 4326)::geography, 
        pl.location::geography
    ) AS distance
FROM 
    professional_details pd
JOIN 
    professional_locations pl ON pd.professional_id = pl.professional_id
WHERE 
    pd.persona = 'Speech Therapist' AND
    ST_DWithin(
        ST_SetSRID(ST_MakePoint(144.9990291, -37.7692917), 4326)::geography, 
        pl.location::geography, 
        5000
    )
ORDER BY 
    distance;
```
Score: 57.14%
Response time: 5.01 seconds
Comments:
✓ Contains professional_details
✓ Contains professional_locations
✗ Missing latitude
✗ Missing longitude
✗ Missing persona = 'Speech pathologist'
✓ Contains JOIN
✓ Contains ORDER BY
✓ Valid SQL syntax

Running test case: Immediate Availability
Generated SQL:
```sql
SELECT 
    pd.first_name,
    pd.last_name,
    pd.email,
    pd.phone_number,
    pa.availability_no

### Summary of Model Performance

| Model Name                       | Query Accuracy & Completeness                                                                                                              | Response Time (seconds) | Evaluation Summary                                 |
|----------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------|--------------------------|---------------------------------------------------|
| **NumbersStation/nsql-llama-2-7B** | Incorrect; generated no meaningful SQL queries and failed to include required elements like `JOIN` and `latitude/longitude`.             | 3.18 - 7.66             | Fast but entirely inaccurate, no usable outputs. |
| **defog/llama-3-sqlcoder-8B**      | Highly accurate; produced valid SQL queries covering key elements such as `JOIN`, `latitude/longitude`, and filtering logic.              | 133.11 - 400.56         | Best accuracy; slower for complex queries.       |
| **defog/sqlcoder-7b-2**            | Incorrect; failed to generate meaningful queries, similar to `NumbersStation/nsql-llama-2-7B`.                                           | 15.46 - 25.84           | Inaccurate and missing critical SQL components.  |
| **Llama-3.2-1B-Instruct**          | Partially correct; included basic logic but often lacked essential filters and precise SQL structure.                                    | 84.50 - 412.20          | Moderate accuracy, much slower for large queries.|
| **Llama-3.2-3B-Instruct**          | Partially correct; handled complex cases better than `1B-Instruct`, with appropriate filtering and ordering but some missing elements.   | 168.09 - 1108.10        | Better completeness but slow.                   |
| **OpenAI GPT-4**                   | Highly accurate; queries included filtering, ordering, and all essential elements while maintaining speed.                              | 2.53 - 10.52            | Fastest and most reliable for accurate SQL.      |

### Evaluation Summary

1. **Best Accuracy**:
   - **defog/llama-3-sqlcoder-8B** and **OpenAI GPT-4** excelled in generating precise SQL queries. Both handled complex filtering, joins, and correct syntax, with GPT-4 providing slightly faster results.

2. **Fastest Model with Good Accuracy**:
   - **OpenAI GPT-4** balanced accuracy and response time, delivering highly accurate results in under 11 seconds.

3. **Incomplete or Incorrect Queries**:
   - **NumbersStation/nsql-llama-2-7B** and **defog/sqlcoder-7b-2** failed to generate meaningful queries, consistently omitting critical elements.

4. **Detailed but Slow**:
   - **Llama-3.2-3B-Instruct** provided detailed outputs but had slower response times, particularly for complex queries involving multiple filters and joins.

### Recommendations
- **For Best Accuracy and Speed**: **OpenAI GPT-4** is the top choice for its reliability and efficiency in generating SQL queries.
- **For Detailed and Accurate Results**: Use **defog/llama-3-sqlcoder-8B** for scenarios where slower response times are acceptable in exchange for high precision.
- **For Simplicity**: Avoid **NumbersStation/nsql-llama-2-7B** and **defog/sqlcoder-7b-2**, as they fail to meet basic query requirements.