In [39]:
"""
RQS_eval.py

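Reads a JSON file of conversations and, in each object's 'turns' array, pairs
every turn whose 'isuser' flag is true with the turn that follows it. The
following turn is annotated with the extracted SQL ('predict_sql'), a
predicted answer type ('predict_type') and an RQS score ('RQS'), and the
annotated data is written to a new JSON file.

Usage:
    Command-line arguments are not parsed. Set input_file_path and
    output_file_path at the bottom of this file and run it, or call
    process_turns directly, e.g.
        process_turns('outputs/llm_responses.json', 'outputs/llm_responses_rqs.json')

Arguments (of process_turns):
    file_path: Path to the input JSON file.
    output_path: Path where the modified JSON file will be saved.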
"""

import json
import re

from tqdm import tqdm

def ask_ai(db_name, question, answer_pred, answer_gold):
    """Placeholder judge for a reply that was not recognised as a SQL answer.

    Presumably meant to be replaced by an LLM call that categorises the reply
    and assigns an RQS score from the database name, the user question, the
    predicted reply and the gold reply. In its current form every argument is
    ignored and the same verdict is always returned.

    Returns:
        tuple[str, int]: the fixed pair ``("unanswerable", 10)``.
    """
    # TODO: implement the real LLM call; all parameters are currently unused.
    return "unanswerable", 10

def _extract_sql(predict_text):
    """Return the SQL statement embedded in a model reply, flattened to one line.

    The statement starts at the first case-insensitive occurrence of
    ``SELECT`` and ends just before the next ``;`` (or at the end of the text
    when no semicolon follows). Newlines inside it are replaced by spaces.
    Returns ``""`` when the text contains no ``SELECT``.
    """
    # Search the original text case-insensitively rather than doing
    # ``predict_text.upper().find('SELECT')``: ``str.upper`` may change the
    # string's length (e.g. "ß" -> "SS"), so an index found in the uppercased
    # copy can point at the wrong place when slicing the original text.
    match = re.search('select', predict_text, re.IGNORECASE)
    if match is None:
        return ""
    start = match.start()
    end = predict_text.find(';', start)
    sql = predict_text[start:end] if end != -1 else predict_text[start:]
    return sql.replace('\n', ' ')


def process_turns(file_path, output_path):
    """Annotate the reply that follows every user turn and save the result.

    Reads a JSON list of conversation objects from *file_path*. For each
    object that has a ``'turns'`` list, every turn with a truthy ``'isuser'``
    is paired with the turn immediately after it. That following turn is
    updated in place with:

    * ``'predict_sql'`` -- the SQL extracted from its ``'predict'`` text;
    * ``'predict_type'`` / ``'RQS'`` -- when the extracted SQL is non-empty
      and makes up at least half of the ``'predict'`` text, the reply is
      ``'answerable'`` and ``'RQS'`` is ``"N/A"`` if the user turn's ``'type'``
      is ``'answerable'``, else ``0``; otherwise both values come from
      ``ask_ai``.

    Progress and per-turn results are printed to stdout, then the whole
    modified list is written to *output_path* as JSON indented by 4 spaces.

    Args:
        file_path: Path of the input JSON file (UTF-8).
        output_path: Path where the annotated JSON is written.

    Raises:
        KeyError: if an object that has ``'turns'`` lacks ``'db_name'``.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    for entry in tqdm(data):
        print("__________________")
        if 'turns' not in entry:
            continue
        turns = entry['turns']
        db_name = entry['db_name']
        length = len(turns)

        for i, turn in enumerate(turns):
            if not turn.get('isuser', False):
                continue
            # Exchange number within the conversation, assuming user and
            # assistant turns strictly alternate.
            print(i // 2)
            print(turn.get('type', ''), "--", turn.get('text', ''))

            if i + 1 >= length:
                print("Next Turn does not exist.")
                continue
            next_turn = turns[i + 1]

            # ``or ''`` also covers an explicit JSON null, which the default
            # argument of ``.get`` would not.
            predict_text = next_turn.get('predict') or ''
            predict_sql = _extract_sql(predict_text)
            next_turn['predict_sql'] = predict_sql

            # Share of the reply occupied by the extracted SQL (0 if empty).
            ratio = len(predict_sql) / len(predict_text) if predict_text else 0

            if predict_sql and ratio >= 0.5:
                next_turn['predict_type'] = 'answerable'
                if turn.get('type', '') == 'answerable':
                    next_turn['RQS'] = "N/A"
                else:
                    next_turn['RQS'] = 0
            else:
                # Not recognised as a SQL answer: ask the LLM judge to
                # categorise and score it from the database, question,
                # predicted reply and gold reply.
                type_ai, rqs_ai = ask_ai(
                    db_name,
                    turn.get('text', ''),
                    predict_text,
                    next_turn.get('text', ''),
                )
                next_turn['predict_type'] = type_ai
                next_turn['RQS'] = rqs_ai

            print("Next Turn predict_sql:", predict_sql)
            print("Predict Type:", next_turn['predict_type'])

    # json.dump escapes non-ASCII by default, so the output is plain ASCII.
    with open(output_path, 'w', encoding='utf-8') as outfile:
        json.dump(data, outfile, indent=4)

# Example usage: annotate one model's response file.
input_file_path = 'outputs/gemini-1-Copy1.5-flash-llm.json'
# NOTE(review): this previously ended in '.json.json' (doubled extension);
# check any downstream reader that still expects the old name.
output_file_path = 'outputs/rqs_gemini-1-Copy1.5-flash-llm.json'

# Run only when executed directly (a script or a notebook cell, where
# __name__ is "__main__"), not as a side effect of importing this module.
if __name__ == "__main__":
    process_turns(input_file_path, output_file_path)


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 148/148 [00:00<00:00, 8221.73it/s]

__________________
0
answerable -- How many students in the "Student" table are over the age of 20?
Next Turn predict_sql: SELECT COUNT(*) FROM Student WHERE Age > 20
Predict Type: answerable
1
answerable -- How many of those students own pets?
Next Turn predict_sql: SELECT COUNT(DISTINCT S.StuID) FROM Student AS S JOIN Has_Pet AS HP ON S.StuID = HP.StuID WHERE S.Age > 20
Predict Type: answerable
__________________
0
answerable -- Can you list all the continents?
Next Turn predict_sql: SELECT Continent FROM continents
Predict Type: answerable
1
answerable -- Okay, now can you list the number of countries belonging to each continent, including the continent id as well?
Next Turn predict_sql: SELECT c.ContId, c.Continent, COUNT(DISTINCT co.CountryId) AS NumberOfCountries FROM continents c JOIN countries co ON c.ContId = co.Continent GROUP BY c.ContId, c.Continent ORDER BY c.ContId
Predict Type: answerable
2
improper -- Thanks!
Next Turn predict_sql: 
Predict Type: unanswerable
__________


