In [None]:
"""
RQS_eval.py

Reads a specified JSON file and iterates through each object's 'turns' array to find a turn where 'isuser' is true and its following turn.

Usage:
    python RQS_eval.py outputs/llm_responses.json outputs/llm_responses_rqs.json

Arguments:
    --file_path: Path to the JSON file.
    --output_path: Path where the modified JSON file will be saved.
"""

import json
import os
import re
import sys

from tqdm import tqdm

from tools.api_request import request_gemini as request_llm
from tools.db_detail import db_getdesc
from tools.sql_execute import sqlite_execute as execute

def sql_evoke(query, db_name):
    """Execute *query* on the CoSQL SQLite database named *db_name*.

    The database file is located at
    ``datasets/cosql_dataset/database/<db_name>/<db_name>.sqlite``.
    ``execute`` returns three values; only the first (the query result) is
    returned, and the other two (presumably the execution time and an
    executability flag) are discarded.
    """
    db_file = "/".join(("datasets/cosql_dataset/database", db_name, db_name + ".sqlite"))
    query_result, _elapsed, _is_executable = execute(db_file, query)
    return query_result

def get_example(db_name, limit=3):
    """Return sample rows from every table in *db_name*, formatted as prompt text.

    For each table listed in ``sqlite_master``, the output contains a
    ``<table>:`` line followed by up to *limit* rows. Each row is rendered as
    ``(value1,value2,...)`` on its own line, using ``str()`` of every value.

    Args:
        db_name: CoSQL database name, passed through to ``sql_evoke``.
        limit: Maximum number of rows sampled per table (default 3).

    Returns:
        The concatenated text, or an empty string when the database has no tables.
    """
    table_rows = sql_evoke("SELECT name FROM sqlite_master WHERE type='table';", db_name)
    lines = []
    for row in table_rows:
        table_name = row[0]
        lines.append(table_name + ":\n")
        # Quote the identifier (doubling any embedded double quotes) so table
        # names that are SQL keywords or contain spaces/punctuation still yield
        # a valid query.
        quoted_name = '"' + table_name.replace('"', '""') + '"'
        sample_query = "SELECT * FROM " + quoted_name + " LIMIT " + str(int(limit)) + ";"
        for sample in sql_evoke(sample_query, db_name):
            lines.append("(" + ",".join(str(value) for value in sample) + ")\n")
    return "".join(lines)
    

# AnswerType values the LLM judge is allowed to return.
_RQS_ANSWER_TYPES = ("improper", "unanswerable", "ambiguous")


def _parse_llm_json(llm_response):
    """Decode the first JSON object embedded in an LLM reply.

    The reply may wrap the object in a ```json fence or surround it with prose.
    Decoding therefore starts at the first '{' and stops at the end of that
    object; trailing text is ignored. Braces inside string values (e.g. in the
    rationale) do not cut the object short. Newlines are turned into spaces
    first, so a line break inside a string value does not break decoding.

    Raises:
        ValueError: no '{' was found, or the text from it is not valid JSON
            (json.JSONDecodeError is a ValueError subclass).
        AttributeError: *llm_response* is not a string (e.g. None).
    """
    start = llm_response.find('{')
    if start == -1:
        raise ValueError("No JSON object found in LLM response.")
    candidate = llm_response[start:].replace('\n', ' ')
    parsed, end = json.JSONDecoder(strict=False).raw_decode(candidate)
    print("formatted json " + candidate[:end])
    return parsed


def ask_ai(db_name, question, answer_pred, answer_gold, max_attempts=5):
    """Ask the LLM judge to classify and score a (non-SQL) system response.

    Builds a prompt from:
    - the database description (``db_getdesc``),
    - sample rows of each table (``get_example``),
    - the user question, the system response and the reference answer.

    It then calls ``request_llm`` until a reply contains a JSON object with an
    accepted AnswerType and a Score that converts to an int in [0, 10].

    Args:
        db_name: CoSQL database name.
        question: Text of the user turn.
        answer_pred: The system's predicted response text.
        answer_gold: The reference response text.
        max_attempts: Number of LLM calls before giving up (default 5).

    Returns:
        ``(answer_type, score, rationale)``, where answer_type is one of
        "improper", "unanswerable" or "ambiguous" and score is an int.
        ``("error", 0, "error")`` is returned if no attempt yields a valid reply.
        Exceptions raised by ``request_llm`` itself are not caught.
    """
    description = db_getdesc(db_name)
    column_example = get_example(db_name)
    template = """
{database_description}

{user_question}

{system_response}

{reference_answer}

Evaluate the quality of the system's response based on the following criteria. Assign 2 points directly if a criterion does not apply.
Relevance:
0 points: The response is completely irrelevant.
1 point: The response is partially relevant but misses key details.
2 points: The response is fully relevant and addresses the question adequately.
Clarity:
0 points: The response is incomprehensible.
1 point: The response is mostly clear with minor ambiguities.
2 points: The response is very clear and easy to understand.
Completeness:
0 points: The response does not address the question at all.
1 point: The response covers most aspects of the question but lacks some details.
2 points: The response thoroughly addresses all aspects of the question.
Accuracy:
0 points: The response contains factually incorrect information.
1 point: The response is partially accurate with some errors.
2 points: The response is completely accurate.
Utility:
0 points: The response does not meet the user's needs or explain the context of the question.
1 point: The response somewhat meets the user's needs and provides partial explanations.
2 points: The response excellently meets the user's needs and clearly explains the context or ambiguity of the question.
Task:
Classify the Response: Determine if the system response is 'improper'(Non-SQL based user questions), 'unanswerable'(unachievable under existing conditions), or 'ambiguous'(Lack of clarity).
Evaluate Each Criterion: Provide a detailed rationale for the score assigned to each criterion.
Calculate the Total Score: Sum the scores for all criteria.(10 points for a direct greeting alone)

Output Format:
{{
  "AnswerType": "",(text only)
  "Rationale": "",(text only)
  "Score": ""(An integer from 0 to 10)
}}
    """
    # The prompt text is kept stable so scores stay comparable across runs.
    filled_template = template.format(
        database_description="Database Description:"+ "\nDatabase schema:\n" + description + "\nExamples for each table:"+ column_example,
        user_question="User Question:" + question,
        system_response="System Response:" + answer_pred,
        reference_answer="Reference Answer:" + answer_gold
    )

    messages = [{"role": "user", "content": filled_template}]
    for _attempt in range(max_attempts):
        llm_response = request_llm(messages)
        print("LLM Response:", llm_response)
        try:
            response_data = _parse_llm_json(llm_response)
            type_ai = response_data.get("AnswerType", "")
            # Normalise to int so RQS values are consistently numeric.
            rqs_ai = int(response_data.get("Score", 0))
            rationale_ai = response_data.get("Rationale", "")
            # Check that the returned type and score are within the expected ranges.
            if type_ai not in _RQS_ANSWER_TYPES or not 0 <= rqs_ai <= 10:
                raise ValueError("Response type or score out of expected range.")
        except (ValueError, TypeError, AttributeError) as e:
            # Red console text reporting why this attempt is being retried.
            print("\033[91mRQS_eval.py::: Retry Reason: {}\033[0m".format(str(e)))
            continue
        return type_ai, rqs_ai, rationale_ai
    return "error", 0, "error"

# Case-insensitive match for the start of a SELECT statement. Searching the
# original text (rather than an upper-cased copy, whose length can differ for
# some Unicode characters) keeps match offsets valid for slicing.
_SELECT_KEYWORD = re.compile('SELECT', re.IGNORECASE)


def _extract_predicted_sql(predict_text):
    """Return the SQL statement found in a model prediction, or "" if there is none.

    The statement starts at the first case-insensitive "SELECT" and runs up to
    (not including) the next ';', or to the end of the text when no ';' follows.
    Newlines are replaced with spaces so the SQL fits on one line.
    """
    match = _SELECT_KEYWORD.search(predict_text)
    if match is None:
        return ""
    start = match.start()
    end = predict_text.find(';', start)
    sql = predict_text[start:] if end == -1 else predict_text[start:end]
    return sql.replace('\n', ' ')


def _annotate_response(db_name, user_turn, next_turn):
    """Classify and score the system turn that answers *user_turn* (mutates *next_turn*).

    Sets next_turn['predict_sql'], next_turn['predict_type'] and next_turn['RQS'].
    When the prediction is not mostly SQL, the LLM judge (``ask_ai``) also
    supplies next_turn['RQS_Rationale'].
    """
    # A missing or null 'predict' is treated as an empty prediction.
    predict_text = next_turn.get('predict') or ''
    predict_sql = _extract_predicted_sql(predict_text)
    next_turn['predict_sql'] = predict_sql
    # Share of the prediction taken up by the extracted SQL; 0 for an empty prediction.
    ratio = len(predict_sql) / len(predict_text) if predict_text else 0

    if predict_sql and ratio >= 0.5:
        # Mostly SQL: the system attempted an answer. No RQS is given when the gold
        # turn type is 'answerable'; otherwise answering with SQL scores 0.
        next_turn['predict_type'] = 'answerable'
        next_turn['RQS'] = "N/A" if user_turn.get('type', '') == 'answerable' else 0
    else:
        # Ask the LLM to categorise and RQS-score the response against the
        # database, the question and the gold answer.
        type_ai, rqs_ai, rationale_ai = ask_ai(
            db_name, user_turn.get('text', ''), predict_text, next_turn.get('text', ''))
        next_turn['predict_type'] = type_ai
        next_turn['RQS'] = rqs_ai
        next_turn['RQS_Rationale'] = rationale_ai
    print("Next Turn predict_sql:", predict_sql)
    print("Predict Type:", next_turn['predict_type'])


def process_turns(file_path, output_path):
    """Annotate each system response in a JSON dialogue file and save the result.

    *file_path* must contain a JSON array of entries. Entries with a 'turns' list
    must also carry 'db_name'. Every turn with a truthy 'isuser' is paired with
    the turn that follows it, and that following turn is annotated in place by
    ``_annotate_response``. The whole (modified) array is then written to
    *output_path* with 4-space indentation.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    for entry in tqdm(data):
        print("__________________")
        if 'turns' not in entry:
            continue
        turns = entry['turns']
        db_name = entry['db_name']
        for i, turn in enumerate(turns):
            if not turn.get('isuser', False):
                continue
            # NOTE(review): i // 2 is the exchange index only if turns strictly
            # alternate user/system — confirm against the dataset.
            print(i // 2)
            print(turn.get('type', ''), "--", turn.get('text', ''))
            if i + 1 < len(turns):
                _annotate_response(db_name, turn, turns[i + 1])
            else:
                print("Next Turn does not exist.")

    with open(output_path, 'w', encoding='utf-8') as outfile:
        json.dump(data, outfile, indent=4)

# Example usage. Passing exactly two positional paths ("<input.json> <output.json>")
# on the command line overrides the defaults below. The override is skipped when
# either argument starts with '-', so launchers that inject options into sys.argv
# (e.g. a Jupyter kernel's '-f <connection file>') fall back to the defaults.
if __name__ == "__main__":
    input_file_path = 'outputs/gemini-1-Copy1.5-flash-llm.json'
    # Save next to the input with an 'rqs_' prefix
    # (the previous hard-coded name ended in '.json.json').
    output_file_path = os.path.join(os.path.dirname(input_file_path),
                                    'rqs_' + os.path.basename(input_file_path))
    cli_args = sys.argv[1:]
    if len(cli_args) == 2 and not any(arg.startswith('-') for arg in cli_args):
        input_file_path, output_file_path = cli_args
    process_turns(input_file_path, output_file_path)


  0%|                                                                                                                                                                                                                                                                             | 0/148 [00:00<?, ?it/s]

__________________
0
answerable -- How many students in the "Student" table are over the age of 20?
Next Turn predict_sql: SELECT COUNT(*) FROM Student WHERE Age > 20
Predict Type: answerable
1
answerable -- How many of those students own pets?
Next Turn predict_sql: SELECT COUNT(DISTINCT S.StuID) FROM Student AS S JOIN Has_Pet AS HP ON S.StuID = HP.StuID WHERE S.Age > 20
Predict Type: answerable
__________________
0
answerable -- Can you list all the continents?
Next Turn predict_sql: SELECT Continent FROM continents
Predict Type: answerable
1
answerable -- Okay, now can you list the number of countries belonging to each continent, including the continent id as well?
Next Turn predict_sql: SELECT c.ContId, c.Continent, COUNT(DISTINCT co.CountryId) AS NumberOfCountries FROM continents c JOIN countries co ON c.ContId = co.Continent GROUP BY c.ContId, c.Continent ORDER BY c.ContId
Predict Type: answerable
2
improper -- Thanks!
request gemini-1.5-flash


  1%|███▌                                                                                                                                                                                                                                                                 | 2/148 [00:02<03:04,  1.26s/it]

[92mSuccess: Operation completed after 0 retries[0m
LLM Response: ```json
{
  "AnswerType": "improper",
  "Rationale": "The user's question is a simple greeting and does not require a database response. The system response is appropriate for a conversational context but irrelevant to the database information provided.",
  "Score": 2
}
``` 

formatted json {   "AnswerType": "improper",   "Rationale": "The user's question is a simple greeting and does not require a database response. The system response is appropriate for a conversational context but irrelevant to the database information provided.",   "Score": 2 }
Next Turn predict_sql: 
Predict Type: improper
__________________
0
unanswerable -- How many customers are there now?
Next Turn predict_sql: SELECT COUNT(*) FROM visitor
Predict Type: answerable
1
answerable -- How many visitors are ther?
Next Turn predict_sql: SELECT COUNT(DISTINCT ID) FROM visitor
Predict Type: answerable
2
answerable -- How many of them have membership le

  2%|█████▎                                                                                                                                                                                                                                                               | 3/148 [00:05<04:33,  1.88s/it]

[92mSuccess: Operation completed after 0 retries[0m
LLM Response: ```json
{
  "AnswerType": "improper",
  "Rationale": "The user question is a simple greeting and does not require a database response. The system response is appropriate for a conversational context but irrelevant to the database task.  The system response is clear, complete, and accurate, but it is not useful in the context of a database interaction.",
  "Score": 6
}
``` 

formatted json {   "AnswerType": "improper",   "Rationale": "The user question is a simple greeting and does not require a database response. The system response is appropriate for a conversational context but irrelevant to the database task.  The system response is clear, complete, and accurate, but it is not useful in the context of a database interaction.",   "Score": 6 }
Next Turn predict_sql: 
Predict Type: improper
__________________
0
ambiguous -- Which cause has the least enrollment?
Next Turn predict_sql: SELECT c.course_name FROM Courses c

  3%|███████                                                                                                                                                                                                                                                              | 4/148 [00:07<05:13,  2.18s/it]

[92mSuccess: Operation completed after 0 retries[0m
LLM Response: ```json
{
  "AnswerType": "improper",
  "Rationale": "The user's question is a simple 'Thanks!' which is not a question related to SQL or database operations. The system response is a polite acknowledgement, which is appropriate but not relevant to the database context.",
  "Score": 10
}
``` 

formatted json {   "AnswerType": "improper",   "Rationale": "The user's question is a simple 'Thanks!' which is not a question related to SQL or database operations. The system response is a polite acknowledgement, which is appropriate but not relevant to the database context.",   "Score": 10 }
Next Turn predict_sql: 
Predict Type: improper
__________________
0
answerable -- Can you show the birth date of player id 200002?
Next Turn predict_sql: SELECT birth_date FROM players WHERE player_id = 200002
Predict Type: answerable
1
answerable -- What is the first name of player id 200001?
Next Turn predict_sql: SELECT first_name FROM 

  3%|████████▊                                                                                                                                                                                                                                                            | 5/148 [00:11<06:41,  2.81s/it]

[92mSuccess: Operation completed after 0 retries[0m
LLM Response: {
  "AnswerType": "improper",
  "Rationale": "The user's question is a simple expression of gratitude and does not require a database response. The system response is appropriate and polite, but it is not relevant to the database context. The response is clear, complete, and accurate, but it lacks utility in the context of the database interaction.",
  "Score": 6
} 

formatted json {   "AnswerType": "improper",   "Rationale": "The user's question is a simple expression of gratitude and does not require a database response. The system response is appropriate and polite, but it is not relevant to the database context. The response is clear, complete, and accurate, but it lacks utility in the context of the database interaction.",   "Score": 6 }
Next Turn predict_sql: 
Predict Type: improper
__________________
0
answerable -- What is the average age of the losers from all matches?
request gemini-1.5-flash
[92mSuccess: Op