In [None]:
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage
import os

# Read the OpenAI API key from the OPENAI_API_KEY environment variable.
# The value is None when the variable is not set.
openai_api_key = os.environ.get("OPENAI_API_KEY")

def load_local_model(model_name):
    """Load the tokenizer and causal language model for ``model_name``.

    The model is requested in half precision (``torch.float16``) with
    ``device_map="auto"``.

    Args:
        model_name: Checkpoint identifier handed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model_options = {"torch_dtype": torch.float16, "device_map": "auto"}
    text_tokenizer = AutoTokenizer.from_pretrained(model_name)
    causal_lm = AutoModelForCausalLM.from_pretrained(model_name, **model_options)
    return causal_lm, text_tokenizer

def generate_sql_with_local_model(model, tokenizer, prompt, max_new_tokens=100):
    """Generate a completion for ``prompt`` with a local causal language model.

    Args:
        model: Model exposing ``generate``, ``to`` and ``device``.
        tokenizer: Tokenizer matching ``model``; also supplies the EOS token id
            that is used as the padding token during generation.
        prompt: Text fed to the model.
        max_new_tokens: Upper bound on the number of generated tokens. Defaults
            to 100, the value that used to be hard-coded.

    Returns:
        The first generated sequence decoded with special tokens removed.
    """
    # A model loaded with ``device_map`` records its placement in
    # ``hf_device_map``. Such a model is already on its devices, and calling
    # ``.to()`` on it can fail or undo that placement, so only models loaded
    # without a device map are moved here.
    if getattr(model, "hf_device_map", None) is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)

    # Put the inputs on the device the model reports, rather than guessing
    # from CUDA availability.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def generate_sql_with_openai_model(prompt):
    """Send ``prompt`` to the ``gpt-4o`` chat model and return the reply text.

    A new ``ChatOpenAI`` client is built on every call, using the module-level
    ``openai_api_key``.

    Args:
        prompt: Text sent as a single human message.

    Returns:
        The ``content`` of the message returned by the model.
    """
    openai_model = ChatOpenAI(model="gpt-4o", api_key=openai_api_key)
    # Use ``invoke`` rather than calling the model object directly; the
    # direct-call form is deprecated in LangChain.
    response = openai_model.invoke([HumanMessage(content=prompt)])
    return response.content

def _timed_call(func, *args):
    """Call ``func(*args)`` and return ``(result, elapsed_seconds)``."""
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # intervals; time.time() is wall-clock time and can jump if the system
    # clock is adjusted.
    started = time.perf_counter()
    result = func(*args)
    return result, time.perf_counter() - started


def compare_models(models, prompt, *, include_openai=True):
    """Run ``prompt`` through every model and collect responses and timings.

    Args:
        models: Mapping of model name to a ``(model, tokenizer)`` pair.
        prompt: Prompt sent unchanged to every model.
        include_openai: When true (the default), the OpenAI model is queried
            after the local models and stored under the ``"OpenAI"`` key.
            Pass ``False`` to compare local models only.

    Returns:
        A dict mapping each model name to
        ``{"response": <text>, "response_time": <seconds>}``.
    """
    results = {}
    for model_name, (model, tokenizer) in models.items():
        print(f"Testing model: {model_name}")
        response, response_time = _timed_call(
            generate_sql_with_local_model, model, tokenizer, prompt
        )
        results[model_name] = {
            "response": response,
            "response_time": response_time,
        }
        print(f"Response from {model_name}: {response}")
        print(f"Response time: {response_time:.2f} seconds\n")

    if include_openai:
        print("Testing OpenAI model")
        openai_response, openai_response_time = _timed_call(
            generate_sql_with_openai_model, prompt
        )
        results["OpenAI"] = {
            "response": openai_response,
            "response_time": openai_response_time,
        }
        print(f"Response from OpenAI model: {openai_response}")
        print(f"Response time: {openai_response_time:.2f} seconds\n")

    return results

# Checkpoints to benchmark. Every checkpoint is loaded here, before the
# comparison starts, so all of them are held at the same time.
local_models = {
    checkpoint: load_local_model(checkpoint)
    for checkpoint in (
        "NumbersStation/nsql-llama-2-7B",
        "defog/llama-3-sqlcoder-8B",
        "defog/sqlcoder-7b-2",
    )
}

# Prompt sent unchanged to every model (leading and trailing newline included).
test_prompt = (
    "\n"
    "Write an SQL query to find the top 5 products with the highest sales in each category.\n"
)

# Run the comparison across the local models and the OpenAI model.
results = compare_models(local_models, test_prompt)




Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Testing model: NumbersStation/nsql-llama-2-7B
Response from NumbersStation/nsql-llama-2-7B: 
Write an SQL query to find the top 5 products with the highest sales in each category.
SELECT * FROM Products ORDER BY Sales DESC LIMIT 5
Response time: 3.12 seconds

Testing model: defog/llama-3-sqlcoder-8B
Response from defog/llama-3-sqlcoder-8B: 
Write an SQL query to find the top 5 products with the highest sales in each category.
SELECT p.product_name, p.category, SUM(o.quantity) AS total_sales FROM Products p JOIN Orders o ON p.product_id = o.product_id GROUP BY p.product_name, p.category ORDER BY total_sales DESC, p.category ASC LIMIT 5;
Response time: 102.63 seconds

Testing model: defog/sqlcoder-7b-2
Response from defog/sqlcoder-7b-2: 
Write an SQL query to find the top 5 products with the highest sales in each category.

 SELECT p.Category, p.Name, SUM(ps.Quantity) AS TotalQuantity FROM Products p JOIN ProductSales ps ON p.ID = ps.ProductID GROUP BY p.Category, p.Name ORDER BY TotalQu

Here's an analysis of how each model performed based on the response content and response time:

### 1. **NumbersStation/nsql-llama-2-7B**
   - **Generated SQL**: `SELECT * FROM Products ORDER BY Sales DESC LIMIT 5`
   - **Response Time**: 3.12 seconds
   - **Analysis**: 
     - The query is incorrect because it simply retrieves the top 5 products across all categories rather than finding the top 5 products *within each category*. 
     - This is a basic query that lacks the necessary grouping by category and doesn't meet the prompt requirements.

### 2. **defog/llama-3-sqlcoder-8B**
   - **Generated SQL**:
     ```sql
     SELECT p.product_name, p.category, SUM(o.quantity) AS total_sales 
     FROM Products p 
     JOIN Orders o ON p.product_id = o.product_id 
     GROUP BY p.product_name, p.category 
     ORDER BY total_sales DESC, p.category ASC LIMIT 5;
     ```
   - **Response Time**: 102.63 seconds
   - **Analysis**:
     - This query is more complex, joining tables and summing sales per product within categories.
     - However, it still fails to partition by category and doesn't retrieve the top 5 products *within each category*.
     - Response time was significantly slower than the other models.

### 3. **defog/sqlcoder-7b-2**
   - **Generated SQL**:
     ```sql
     SELECT p.Category, p.Name, SUM(ps.Quantity) AS TotalQuantity 
     FROM Products p 
     JOIN ProductSales ps ON p.ID = ps.ProductID 
     GROUP BY p.Category, p.Name 
     ORDER BY TotalQuantity DESC LIMIT 5;
     ```
   - **Response Time**: 137.13 seconds
   - **Analysis**:
     - This query includes joins and aggregation by category and product name.
     - Like the previous model, it does not partition by category to retrieve the top 5 in each category, so it doesn’t fully meet the prompt requirements.
     - This model took the longest time to respond.

### 4. **OpenAI GPT-4o (via `ChatOpenAI`)**
   - **Generated SQL**:
     ```sql
     WITH RankedProducts AS (
         SELECT
             p.product_id,
             p.product_name,
             p.category_id,
             c.category_name,
             SUM(s.sales_amount) AS total_sales,
             RANK() OVER (PARTITION BY p.category_id ORDER BY SUM(s.sales_amount) DESC) AS sales_rank
         FROM
             Products p
         JOIN
             Sales s ON p.product_id = s.product_id
         JOIN
             Categories c ON p.category_id = c.category_id
         GROUP BY
             p.product_id, p.product_name, p.category_id, c.category_name
     )
     SELECT
         product_id,
         product_name,
         category_id,
         category_name,
         total_sales
     FROM
         RankedProducts
     WHERE
         sales_rank <= 5
     ORDER BY
         category_id,
         sales_rank;
     ```
   - **Response Time**: 9.11 seconds
   - **Analysis**:
     - This response meets the prompt's requirements accurately. It uses a Common Table Expression (CTE) with a `RANK()` window function to find the top 5 products in each category.
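     - The query is both structurally and logically correct for the task. Note that `RANK()` can return more than five rows for a category when sales totals tie; `ROW_NUMBER()` would cap the result at exactly five.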
     - The response time was significantly faster than the `defog` models, although slower than `nsql-llama-2-7B`.

### **Performance Summary**

| Model                            | Accuracy                               | Response Time (seconds) | Comment                                                                 |
|----------------------------------|----------------------------------------|--------------------------|-------------------------------------------------------------------------|
| **NumbersStation/nsql-llama-2-7B** | Incorrect SQL query                   | 3.12                     | Fast response but incorrect query.                                      |
| **defog/llama-3-sqlcoder-8B**      | Partially correct SQL query           | 102.63                   | Complex but incorrect query, very slow response.                        |
| **defog/sqlcoder-7b-2**            | Partially correct SQL query           | 137.13                   | Complex but incorrect query, slowest response.                          |
| **OpenAI GPT-4o**                  | Correct SQL query with explanations   | 9.11                     | Accurate SQL generation with reasonable response time and explanation.  |

### Conclusion
- **OpenAI GPT-4o** performed the best in terms of accuracy and query structure, meeting the requirements effectively and providing a detailed explanation. Its response time was reasonable, especially considering the additional explanation.
- **NumbersStation/nsql-llama-2-7B** was the fastest, but it generated an overly simplistic and incorrect query, which doesn’t fulfill the requirements.
- **defog models** (`llama-3-sqlcoder-8B` and `sqlcoder-7b-2`) generated partially correct queries but were very slow, taking over 100 seconds each.

**Recommendation**: For applications requiring accurate and complex SQL generation, OpenAI’s GPT-4o seems more reliable. If response time is more critical and accuracy can be sacrificed, `NumbersStation/nsql-llama-2-7B` might be a better choice.