# Test evaluation task suite for AI data designer

Set up your API key to run the tests below.

In [1]:
import getpass
import os
import sys

# Get the current working directory of the notebook
notebook_dir = os.getcwd()

# Define the root directory (three levels above the notebook) and add it to the path.
# The membership check keeps re-runs of this cell from stacking duplicate entries.
root_dir = os.path.abspath(os.path.join(notebook_dir, '..', '..', '..'))
if root_dir not in sys.path:
    sys.path.insert(0, root_dir)

# 'GRETEL_PROD_API_KEY' must hold your key from https://console.gretel.ai/users/me/key
# Never paste the key into the notebook itself: notebooks are committed and shared,
# and a key stored in a cell leaks with them. Export it in the shell that starts
# Jupyter, or enter it at the hidden prompt below.
if not os.environ.get('GRETEL_PROD_API_KEY'):
    os.environ['GRETEL_PROD_API_KEY'] = getpass.getpass('Enter your Gretel API key: ')

### Evaluate Synthetic Dataset

In [2]:
# Reload packages if you've made changes to the evaluation.py file.
# Alternatively you can restart the kernel to pick up changes

from importlib import reload
import evaluation
reload(evaluation)

import pandas as pd
from pprint import pprint
from navigator_helpers.llms.llm_suite import GretelLLMSuite
from evaluation import BaseEvaluationTaskSuite, NL2SQLEvaluationTaskSuite
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# Short name -> Hugging Face Hub path of every dataset this notebook can evaluate.
# NOTE(review): "openai/gsm8k" appears to publish more than one configuration;
# load_dataset may need an explicit `name=` for it - confirm before selecting it.
datasets_dict = {
    "synthetic_text_to_sql": "gretelai/synthetic_text_to_sql",
    "gsm8k": "openai/gsm8k",
    "synthetic_gsm8k": "gretelai/synthetic-gsm8k-reflection-405b",
    "xlcost_text_to_code": "codeparrot/xlcost-text-to-code",
    "python_github_code": "angie-chen55/python-github-code"
}

# Upper bound on the number of rows copied into the pandas sample.
MAX_SAMPLE_ROWS = 1000

# Prompt user to select a dataset
print("Available datasets:")
for key in datasets_dict:
    print(f" - {key}")

selected_dataset = input("\nEnter the name of the dataset to load: ").strip()

# Fail loudly on an unknown name. Only printing here would let the following cells
# crash with a NameError or, worse, silently reuse a dataframe left over from an
# earlier run of this cell.
if selected_dataset not in datasets_dict:
    raise ValueError(
        f"Dataset '{selected_dataset}' not found. "
        f"Please enter one of: {', '.join(datasets_dict)}"
    )

# Load the selected dataset
dataset_path = datasets_dict[selected_dataset]
dataset = load_dataset(dataset_path, split="train")

# Select a subset and convert to pandas DataFrame. The range is capped at the
# dataset length so splits with fewer than MAX_SAMPLE_ROWS rows still load.
dataset_1000 = dataset.select(range(min(MAX_SAMPLE_ROWS, len(dataset))))
dataset_1000_pd = dataset_1000.to_pandas()

print(f"Loaded dataset '{selected_dataset}' successfully!")

Available datasets:
 - synthetic_text_to_sql
 - gsm8k
 - synthetic_gsm8k
 - xlcost_text_to_code
 - python_github_code


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Loaded dataset 'xlcost_text_to_code' successfully!


In [12]:
# Build the LLM suite with its default configuration (no arguments are passed).
# The same instance is handed to every evaluation task suite in the cells below.
llm_suite = GretelLLMSuite()

2024-10-08 19:33:49.683 - INFO - 🦜 Initializing LLM suite
2024-10-08 19:33:49.685 - INFO - 📖 Natural language LLM: gretelai-mistral-nemo-2407
2024-10-08 19:33:49.686 - INFO - 💻 Code LLM: gretelai-mistral-nemo-2407
2024-10-08 19:33:49.687 - INFO - ⚖️ Judge LLM: gretelai-mistral-nemo-2407


In [13]:
# Row-uniqueness check on the sampled dataframe. The returned dict includes
# 'non_semantically_unique_ids', a list of row-index pairs.
uniqueness_suite = BaseEvaluationTaskSuite(llm_suite, dataset_1000_pd)
results_1 = uniqueness_suite.row_uniqueness()
pprint(results_1)

{'non_semantically_unique_ids': [(2, 4),
                                 (6, 480),
                                 (6, 505),
                                 (6, 582),
                                 (6, 693),
                                 (6, 699),
                                 (6, 705),
                                 (32, 61),
                                 (32, 632),
                                 (32, 921),
                                 (32, 986),
                                 (37, 131),
                                 (61, 986),
                                 (73, 141),
                                 (73, 311),
                                 (73, 614),
                                 (73, 654),
                                 (73, 680),
                                 (73, 716),
                                 (73, 723),
                                 (73, 884),
                                 (77, 101),
                                 (79, 141)

In [14]:

# Feature-cardinality check on the sampled dataframe; yields one value per column.
cardinality_suite = BaseEvaluationTaskSuite(llm_suite, dataset_1000_pd)
results_2 = cardinality_suite.feature_cardinality()
pprint(results_2)

{'code': 0.918, 'text': 0.806}


In [15]:

# Feature-distribution check on the sampled dataframe. The result starts with a
# per-column mapping of distinct value -> count, so the printed output is long.
distribution_suite = BaseEvaluationTaskSuite(llm_suite, dataset_1000_pd)
results_3 = distribution_suite.feature_distribution()
pprint(results_3)

({'code': {'A , B = 967 , 679 NEW_LINE if ( check ( A , B ) ) : NEW_LINE INDENT print ( " Yes " ) NEW_LINE DEDENT else : NEW_LINE INDENT print ( " No " ) NEW_LINE DEDENT': 1,
           'A = 10 ; B = 4 ; NEW_LINE': 1,
           'A = 10 NEW_LINE B = 3 NEW_LINE movesRequired ( A , B ) NEW_LINE': 1,
           'A = A - firstdigit * power NEW_LINE A = A * 10 + firstdigit NEW_LINE': 1,
           'A = [ 0 ] * N NEW_LINE': 1,
           'A = [ 2 , - 1 , 4 , - 5 ] NEW_LINE B = [ 4 , - 3 , 12 , 4 , - 3 ] NEW_LINE print ( maxPresum ( A , B ) ) NEW_LINE': 1,
           'A = [ 23 , 14 , 15 , 14 , 56 , 29 , 14 ] NEW_LINE': 1,
           'A [ i ] = A [ ind ] NEW_LINE': 1,
           'AM = ArithmeticMean ( A , B ) NEW_LINE HM = HarmonicMean ( A , B ) NEW_LINE': 1,
           'B = [ 1 for i in range ( N ) ] NEW_LINE': 1,
           'D = len ( digits ) NEW_LINE for i in range ( 1 , D + 1 , 1 ) : NEW_LINE': 1,
           'Day1 = s - Day2 NEW_LINE': 1,
           'GCD = gcd ( final_numerator , final_de

In [16]:

# Word-count statistics for the sampled dataframe: an overall
# 'average_words_per_record' plus 'word_counts_per_column'.
word_count_suite = BaseEvaluationTaskSuite(llm_suite, dataset_1000_pd)
results_4 = word_count_suite.num_words_per_record()
pprint(results_4)

{'average_words_per_record': 12.3675,
 'word_counts_per_column': {'code': 17.219, 'text': 7.516}}


### Testing SQL Validation

In [4]:
# Load the SQL queries together with their per-dialect validation results
# (is_valid_<dialect> / error_msg_<dialect> columns are read by the cells below).
sql_linter_csv_path = "/mnt/foundation-shared/nina_xu_gretel_ai/datasets/sqlqueries_1200_validated_092524.csv"
sql_linter_results = pd.read_csv(sql_linter_csv_path)

In [None]:
# Collapse the five per-dialect result columns into a single column each for the
# validity flag and the error message. For every row the first non-null value wins,
# checked in this dialect order.
dialect_order = ['sqlite', 'mysql', 'postgresql', 'sqlserver', 'googlesql']

for target_col, prefix in (('is_valid_sql', 'is_valid_'), ('error_msg', 'error_msg_')):
    coalesced = sql_linter_results[prefix + dialect_order[0]]
    for dialect in dialect_order[1:]:
        # Only rows still missing a value are filled from the next dialect.
        coalesced = coalesced.fillna(sql_linter_results[prefix + dialect])
    sql_linter_results[target_col] = coalesced

print(sql_linter_results.is_valid_sql.value_counts())
def is_type_error(error_msg):
    """Return True when ``error_msg`` matches a known data-type or environment error.

    These errors come from dialect-specific misuse of data types (plus a few
    setup failures). The prompt is currently not aware of the dialect, so the
    LLM is not expected to catch them.

    Args:
        error_msg: Error text from the SQL validator. Any value is accepted; it
            is converted with ``str()`` and compared case-insensitively.

    Returns:
        bool: True if the message matches one of the known patterns, else False.
    """
    import re

    normalized = str(error_msg).lower()

    # e.g. type "number" does not exist
    if re.search(r'type "\w+" does not exis', normalized) is not None:
        return True

    known_fragments = (
        "type not found",
        "cannot find data type",
        "error creating tables",
        "'CREATE VIEW' must be the first statement in a query batch",
        "login failed",
    )
    # Fragments are lower-cased before comparison because the message already is.
    return any(fragment.lower() in normalized for fragment in known_fragments)
# Flag rows whose error message matches one of the patterns in is_type_error.
sql_linter_results['is_type_error'] = sql_linter_results.error_msg.apply(is_type_error)

# Explicit `== True` / `== False` comparisons are deliberate: rows where
# is_valid_sql is missing match neither mask and are left out of both groups.
valid_rows = sql_linter_results[sql_linter_results.is_valid_sql == True]

# Positive examples: up to 10 valid queries. The fixed seed makes every run of this
# cell draw the same rows; the cap avoids a ValueError when fewer than 10 exist.
p = valid_rows.sample(n=min(10, len(valid_rows)), random_state=42)

# Negative examples: every invalid query whose error is not a type error (not sampled).
n = sql_linter_results[(sql_linter_results.is_valid_sql == False) & (sql_linter_results.is_type_error == False)]

sql_linter_results_10 = pd.concat([p, n])
sql_linter_results_10.shape


In [None]:
# Collapse the five per-dialect result columns into a single column each for the
# validity flag and the error message. For every row the first non-null value wins,
# checked in this dialect order.
dialect_order = ['sqlite', 'mysql', 'postgresql', 'sqlserver', 'googlesql']

for target_col, prefix in (('is_valid_sql', 'is_valid_'), ('error_msg', 'error_msg_')):
    coalesced = sql_linter_results[prefix + dialect_order[0]]
    for dialect in dialect_order[1:]:
        # Only rows still missing a value are filled from the next dialect.
        coalesced = coalesced.fillna(sql_linter_results[prefix + dialect])
    sql_linter_results[target_col] = coalesced

print(sql_linter_results.is_valid_sql.value_counts())
def is_type_error(error_msg):
    """Return True when ``error_msg`` matches a known data-type or environment error.

    These errors come from dialect-specific misuse of data types (plus a few
    setup failures). The prompt is currently not aware of the dialect, so the
    LLM is not expected to catch them.

    Args:
        error_msg: Error text from the SQL validator. Any value is accepted; it
            is converted with ``str()`` and compared case-insensitively.

    Returns:
        bool: True if the message matches one of the known patterns, else False.
    """
    import re

    normalized = str(error_msg).lower()

    # e.g. type "number" does not exist
    if re.search(r'type "\w+" does not exis', normalized) is not None:
        return True

    known_fragments = (
        "type not found",
        "cannot find data type",
        "error creating tables",
        "'CREATE VIEW' must be the first statement in a query batch",
        "login failed",
    )
    # Fragments are lower-cased before comparison because the message already is.
    return any(fragment.lower() in normalized for fragment in known_fragments)
# Flag rows whose error message matches one of the patterns in is_type_error.
sql_linter_results['is_type_error'] = sql_linter_results.error_msg.apply(is_type_error)

# Explicit `== True` / `== False` comparisons are deliberate: rows where
# is_valid_sql is missing match neither mask and are left out of both groups.
valid_rows = sql_linter_results[sql_linter_results.is_valid_sql == True]

# Positive examples: up to 10 valid queries. The fixed seed makes every run of this
# cell draw the same rows; the cap avoids a ValueError when fewer than 10 exist.
p = valid_rows.sample(n=min(10, len(valid_rows)), random_state=42)

# Negative examples: every invalid query whose error is not a type error (not sampled).
n = sql_linter_results[(sql_linter_results.is_valid_sql == False) & (sql_linter_results.is_type_error == False)]

sql_linter_results_10 = pd.concat([p, n])
sql_linter_results_10.shape


In [None]:
# Run the LLM-as-a-critic evaluation over the sampled SQL records.
task_5 = NL2SQLEvaluationTaskSuite(llm_suite=llm_suite, dataset=sql_linter_results_10, code_lang="sql")

# Names of the dataset columns holding the instruction, the code and the context.
critic_columns = {
    "instruction_col_name": "Natural Language Prompt",
    "code_col_name": "SQL Query",
    "context_col_name": "Context",
}
results_5 = task_5.llm_as_a_critic_evaluation(**critic_columns)
table5 = task_5.output_dataset

# Optional second run with the base suite (disabled; `dataset_10_pd` is not
# defined anywhere in this notebook):
# task_6 = BaseEvaluationTaskSuite(llm_suite, dataset_10_pd)
# results_6 = task_6.llm_as_a_critic_evaluation(
#     instruction_col_name="sql_prompt", code_col_name="sql"
# )

In [None]:

# Show the value returned by the critic evaluation of task_5.
print(results_5)
# print(results_6)

# review specific records
# print(dataset_10_pd.loc[results_1['non_semantically_unique_ids']])

In [34]:
# Lift the 'correctness_score' entry out of each row's `scores` value into its own column.
table5['correctness_score'] = table5['scores'].map(lambda score_entry: score_entry['correctness_score'])

In [None]:
# Display the prompt, query, context and dialect next to the validity flag and the critic scores.
table5[['Natural Language Prompt', 'SQL Query', 'Context', 'Dialect', 'is_valid_sql', 'scores', 'overall_score', 'correctness_score']]

In [None]:
# Average correctness score for each value of is_valid_sql, to compare how the
# critic scores valid and invalid queries.
table5.groupby('is_valid_sql')['correctness_score'].mean()

In [38]:
# Persist the scored records (without the index column) for later analysis.
critic_output_csv_path = "/mnt/foundation-shared/nina_xu_gretel_ai/datasets/sql_linter_results_35.csv"
table5.to_csv(critic_output_csv_path, index=False)

In [52]:
# Position of the next invalid record to review. The review cell below reads it
# and increments it on every run; re-run this cell to start over from the first record.
count = 0

In [None]:
# Step through the invalid-SQL records one at a time: each run of this cell prints
# the record at position `count` and then advances `count`.
indices = table5[table5.is_valid_sql == False].index

if count >= len(indices):
    # Without this guard, running the cell once more after the last record
    # raises an IndexError.
    print(f'No more invalid records to review (count = {count}, total = {len(indices)}). Reset count to 0 to start over.')
else:
    ind = indices[count]
    print(f'ind = {ind}')
    for review_col in ('error_msg', 'Natural Language Prompt', 'Context', 'SQL Query', 'scores'):
        print('\n', table5[review_col].loc[ind])

    count += 1


In [None]:
# Index labels of all rows where is_valid_sql == False (computed by the review cell above).
print(indices)

In [None]:
# Manually inspect one record by its index label.
# Candidate labels, grouped by the dialect column that holds their error message:
#   error_msg_mysql:      851, 893, 169
#   error_msg_sqlite:     247, 212
#   error_msg_googlesql:  499, 1085, 1011                 (Type not found)
#   error_msg_sqlserver:  876, 928, 985, 813, 885, 308, 51 (Cannot find data type NUMBER)
#   error_msg_postgresql: 298, 846                        (type "number" does not exist)
# To look at another record, change `ind` and the error column printed first.
# NOTE(review): `.loc` is label-based, so `ind` must be a label present in table5.
ind = 51
print(table5['error_msg_sqlserver'].loc[ind])

for review_col in ('Natural Language Prompt', 'Context', 'SQL Query', 'scores'):
    print('\n', table5[review_col].loc[ind])


In [None]:
# Print the critic scores of the row labelled 0.
# NOTE(review): `.loc[0]` looks up the index label 0, not the first row; confirm
# table5 actually contains that label (use `.iloc[0]` if the first row is intended).
print(table5['scores'].loc[0])


In [None]:
# `task_6` only exists when the optional BaseEvaluationTaskSuite critic run (currently
# commented out in the task_5 cell) has been enabled and executed. Guard on it so a
# fresh Restart & Run All does not stop here with a NameError.
if 'task_6' in globals():
    table6 = task_6.output_dataset
    print(table6['scores'].loc[0])
else:
    print("Skipping table6: task_6 is not defined (its evaluation cell is commented out).")