# Test evaluation task suite for AI data designer

Setup API key to run tests below

In [12]:
import sys
import os

# Get the current working directory of the notebook
notebook_dir = os.getcwd()

# Define the root directory and add it to the path
root_dir = os.path.abspath(os.path.join(notebook_dir, '..', '..', '..'))
sys.path.insert(0, root_dir)

# set environment variable 'GRETEL_PROD_API_KEY' from https://console.gretel.ai/users/me/key
os.environ['GRETEL_PROD_API_KEY'] = 'GRETEL_PROD_API_KEY'

### Evaluate Synthetic Dataset

In [19]:
# Reload packages if you've made changes to the evaluation.py file.
# Alternatively you can restart the kernel to pick up changes

from importlib import reload
import evaluation
reload(evaluation)

import pandas as pd
from pprint import pprint
from navigator_helpers.llms.llm_suite import GretelLLMSuite
from evaluation import BaseEvaluationTaskSuite, NL2SQLEvaluationTaskSuite
from datasets import load_dataset

In [15]:
## TODO: add all datasets to test in a dictionary
dataset = load_dataset("gretelai/synthetic_text_to_sql", split="train")
dataset_1000 = dataset.select(range(1000))
dataset_1000_pd = dataset_1000.to_pandas()

In [16]:
llm_suite = GretelLLMSuite()

2024-10-08 18:20:29.317 - INFO - 🦜 Initializing LLM suite
2024-10-08 18:20:29.319 - INFO - 📖 Natural language LLM: gretelai-mistral-nemo-2407
2024-10-08 18:20:29.320 - INFO - 💻 Code LLM: gretelai-mistral-nemo-2407
2024-10-08 18:20:29.321 - INFO - ⚖️ Judge LLM: gretelai-mistral-nemo-2407


In [20]:
results_1 = BaseEvaluationTaskSuite(llm_suite, dataset_1000_pd).row_uniqueness()
pprint(results_1)

{'non_semantically_unique_ids': [],
 'non_unique_ids': [],
 'percent_semantically_unique': 100.0,
 'percent_unique': 100.0}


In [22]:

results_2 = BaseEvaluationTaskSuite(llm_suite, dataset_1000_pd).feature_cardinality()
pprint(results_2)

{'domain': 0.6,
 'domain_description': 0.6,
 'id': 1.0,
 'sql': 1.0,
 'sql_complexity': 0.06,
 'sql_complexity_description': 0.06,
 'sql_context': 1.0,
 'sql_explanation': 1.0,
 'sql_prompt': 1.0,
 'sql_task_type': 0.04,
 'sql_task_type_description': 0.04}


In [23]:

results_3 = BaseEvaluationTaskSuite(llm_suite, dataset_1000_pd).feature_distribution()
pprint(results_3)

({'domain': {'agriculture': 1,
             'aquaculture': 1,
             'arts operations and management': 1,
             'automotive': 6,
             'beauty industry': 1,
             'biotechnology': 2,
             'chemicals': 2,
             'civil engineering': 1,
             'construction': 3,
             'cybersecurity': 3,
             'defense contractors': 1,
             'defense industry': 3,
             'defense operations': 4,
             'defense security': 3,
             'disability services': 1,
             'energy': 2,
             'entertainment industry': 1,
             'ethical fashion': 2,
             'finance': 1,
             'financial services': 1,
             'fine arts': 1,
             'food industry': 2,
             'forestry': 2,
             'gaming technology': 1,
             'government policy': 1,
             'government services': 1,
             'healthcare': 2,
             'hospitality technology': 2,
             'justice': 1,
 

In [24]:

results_4 = BaseEvaluationTaskSuite(llm_suite, dataset_1000_pd).num_words_per_record()
pprint(results_4)

{'average_words_per_record': 12.88,
 'word_counts_per_column': {'domain': 1.63,
                            'domain_description': 13.3,
                            'sql': 14.83,
                            'sql_complexity': 1.71,
                            'sql_complexity_description': 8.01,
                            'sql_context': 31.05,
                            'sql_explanation': 35.5,
                            'sql_prompt': 13.88,
                            'sql_task_type': 2.93,
                            'sql_task_type_description': 5.96}}


## Testing SQL Validation

In [4]:
sql_linter_results = pd.read_csv("/mnt/foundation-shared/nina_xu_gretel_ai/datasets/sqlqueries_1200_validated_092524.csv")

In [None]:
# Coalesce the results of the five dialects into a single column
sql_linter_results['is_valid_sql'] = sql_linter_results['is_valid_sqlite'].fillna(
    sql_linter_results['is_valid_mysql'].fillna(
        sql_linter_results['is_valid_postgresql'].fillna(
            sql_linter_results['is_valid_sqlserver'].fillna(
                sql_linter_results['is_valid_googlesql']
            )
        )
    )
)

sql_linter_results['error_msg'] = sql_linter_results['error_msg_sqlite'].fillna(
    sql_linter_results['error_msg_mysql'].fillna(
        sql_linter_results['error_msg_postgresql'].fillna(
            sql_linter_results['error_msg_sqlserver'].fillna(
                sql_linter_results['error_msg_googlesql']
            )
        )
    )
)

print(sql_linter_results.is_valid_sql.value_counts())
def is_type_error(error_msg):
    # These errors are because of misuse of data types, and they are dialect-specific
    # Because the prompt as of now is not aware of the dialect, it makes sense if the LLM cannot identify these errors
    import re
    error_msg = str(error_msg).lower()

    pattern = r'type "\w+" does not exis'
    # Check if the pattern exists in the string
    if re.search(pattern, error_msg):
        return True
    
    patterns = ["type not found", "cannot find data type", "error creating tables", "'CREATE VIEW' must be the first statement in a query batch",
                "login failed"]
    for pattern in patterns:
        if pattern.lower() in error_msg:
            return True

    return False
sql_linter_results['is_type_error'] = sql_linter_results.error_msg.apply(is_type_error)
p = sql_linter_results[sql_linter_results.is_valid_sql == True].sample(10)
n = sql_linter_results[(sql_linter_results.is_valid_sql == False) & (sql_linter_results.is_type_error == False)]#.sample(10)

sql_linter_results_10 = pd.concat([p, n])
sql_linter_results_10.shape


In [None]:
task_5 = NL2SQLEvaluationTaskSuite(
    llm_suite=llm_suite, dataset=sql_linter_results_10, code_lang="sql"
    )
results_5 = task_5.llm_as_a_critic_evaluation(
    instruction_col_name="Natural Language Prompt", code_col_name="SQL Query", context_col_name="Context"
)
table5 = task_5.output_dataset

# task_6 = BaseEvaluationTaskSuite(llm_suite, dataset_10_pd)
# results_6 = task_6.llm_as_a_critic_evaluation(
#     instruction_col_name="sql_prompt", code_col_name="sql"
# )

In [33]:

print(results_5)
# print(results_6)

# review specific records
# print(dataset_10_pd.loc[results_1['non_semantically_unique_ids']])

{'llm_as_a_critic_score': 3.5764705882352943}


In [34]:
table5['correctness_score'] = table5['scores'].apply(lambda x: x['correctness_score'])

In [35]:
table5[['Natural Language Prompt', 'SQL Query', 'Context', 'Dialect', 'is_valid_sql', 'scores', 'overall_score', 'correctness_score']]

Unnamed: 0,Natural Language Prompt,SQL Query,Context,Dialect,is_valid_sql,scores,overall_score,correctness_score
927,List all patients who have been admitted to th...,SELECT p.name FROM Patients p JOIN ( SELEC...,CREATE TABLE Patients ( patient_id VARCHAR...,MySQL,True,"{'relevance_score': 4, 'relevance_reason': 'Th...",3.6,4
52,What is the average latency for each node in t...,"SELECT node_name, AVG(latency) AS avg_latency ...",CREATE TABLE network_nodes (node_id SERIAL PRI...,PostgreSQL,True,"{'relevance_score': 4, 'relevance_reason': 'Th...",3.6,4
159,Find the names of patients who have had teleme...,SELECT p.Name FROM Patients p JOIN Consultatio...,CREATE TABLE Patients (PatientID INTEGER PRIMA...,SQLite,True,"{'relevance_score': 4, 'relevance_reason': 'Th...",3.6,4
715,What is the average broadband speed for each c...,"SELECT City, AVG(Speed) AS AverageSpeed FROM B...","CREATE TABLE BroadbandData (City VARCHAR(50), ...",SQL Server,True,"{'relevance_score': 4, 'relevance_reason': 'Th...",3.6,4
827,What is the total number of defective products...,SELECT COUNT(*) AS TotalDefective FROM Defects...,CREATE TABLE Products (ProductID INT PRIMARY K...,SQL Server,True,"{'relevance_score': 4, 'relevance_reason': 'Th...",3.6,4
1131,What is the average market value of homes owne...,SELECT AVG(MarketValue) as AvgMarketValue FROM...,"CREATE TABLE Owners ( OwnerID STRING, ...",GoogleSQL,True,"{'relevance_score': 4, 'relevance_reason': 'Th...",3.6,4
394,Show me the average rating given by customers ...,"SELECT p.name, AVG(r.rating) AS average_rating...","CREATE TABLE products (id SERIAL PRIMARY KEY, ...",PostgreSQL,True,"{'relevance_score': 4, 'relevance_reason': 'Th...",3.6,4
621,How many customer service tickets were resolve...,SELECT COUNT(*) AS ResolvedTickets FROM Custom...,CREATE TABLE CustomerServiceTickets (TicketID ...,SQL Server,True,"{'relevance_score': 4, 'relevance_reason': 'Th...",3.6,4
954,How many items were sold in each category last...,"SELECT c.name AS category_name, SUM(s.quantity...","CREATE TABLE sales (id SERIAL PRIMARY KEY, ite...",PostgreSQL,True,"{'relevance_score': 4, 'relevance_reason': 'Th...",3.6,4
502,Show me the list of products in the shopping c...,"SELECT p.product_name, p.price FROM products p...","CREATE TABLE users (user_id INT PRIMARY KEY, u...",MySQL,True,"{'relevance_score': 4, 'relevance_reason': 'Th...",3.4,4


In [36]:
# calculate the average correctness scores grouped by is_valid_sql
table5.groupby('is_valid_sql')['correctness_score'].mean()

is_valid_sql
False    3.97561
True     4.00000
Name: correctness_score, dtype: float64

In [38]:
table5.to_csv("/mnt/foundation-shared/nina_xu_gretel_ai/datasets/sql_linter_results_35.csv", index=False)

In [52]:
count = 0

In [95]:
indices = table5[table5.is_valid_sql == False].index
ind = indices[count]
print(f'ind = {ind}')
print('\n', table5['error_msg'].loc[ind])
print('\n', table5['Natural Language Prompt'].loc[ind])
print('\n', table5['Context'].loc[ind])
print('\n', table5['SQL Query'].loc[ind])
print('\n', table5['scores'].loc[ind])

count += 1


IndexError: index 41 is out of bounds for axis 0 with size 41

In [54]:
print(indices)

Index([  51,   77,  144,  155,  168,  169,  186,  212,  247,  261,  308,  398,
        409,  469,  528,  555,  605,  625,  627,  645,  719,  722,  726,  775,
        783,  846,  851,  869,  890,  893,  902, 1011, 1014, 1054, 1085, 1101,
       1114, 1116, 1124, 1149, 1172],
      dtype='int64')


In [39]:
# ind = 851
# ind = 893
# ind = 169
# print(table5['error_msg_mysql'].loc[ind])

# ind = 247
# ind = 212
# print(table5['error_msg_sqlite'].loc[ind])

# ind = 499
# ind = 1085
# ind = 1011
# print(table5['error_msg_googlesql'].loc[ind]) # Type not found

# ind = 876
# ind = 928
# ind = 985
# ind = 813
# ind = 885
# ind = 308
ind = 51
print(table5['error_msg_sqlserver'].loc[ind]) # Cannot find data type NUMBER

# ind = 298
# ind = 846
# print(table5['error_msg_postgresql'].loc[ind]) # type "number" does not exist


print('\n', table5['Natural Language Prompt'].loc[ind])
print('\n', table5['Context'].loc[ind])
print('\n', table5['SQL Query'].loc[ind])
print('\n', table5['scores'].loc[ind])


(pyodbc.ProgrammingError) ('42000', "[42000] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]'ADD_MONTHS' is not a recognized built-in function name. (195) (SQLExecDirectW)")
[SQL: SELECT r.review_text FROM reviews r JOIN orders o ON r.customer_id = o.customer_id JOIN products p ON o.product_id = p.product_id WHERE p.category = 'Electronics' AND o.purchase_date >= ADD_MONTHS(SYSDATE, -1);]
(Background on this error at: http://sqlalche.me/e/f405)

 Show me all reviews made by customers who purchased electronics in the last month.

 CREATE TABLE customers (     customer_id VARCHAR(255) PRIMARY KEY,     name VARCHAR(255),     email VARCHAR(255) );  CREATE TABLE orders (     order_id VARCHAR(255) PRIMARY KEY,     customer_id VARCHAR(255),     purchase_date DATE,     product_id VARCHAR(255),     FOREIGN KEY (customer_id) REFERENCES customers(customer_id) );  CREATE TABLE products (     product_id VARCHAR(255) PRIMARY KEY,     product_name VARCHAR(255),     category VARCHAR(255) );  CR

In [None]:
print(table5['scores'].loc[0])


In [None]:
table6 = task_6.output_dataset
print(table6['scores'].loc[0])