In [None]:
# Import the functions from /tools
from tools.api_request import request_gemini as request_llm
from tools.db_detail import db_getdesc
from tools.sql_execute import sqlite_execute as execute
from tqdm import tqdm
import json

def sql_evoke(query, db_name, db_root="datasets/cosql_dataset/database"):
    """Run *query* against the SQLite database *db_name* and return its result.

    The database file is expected at ``<db_root>/<db_name>/<db_name>.sqlite``;
    the default *db_root* is the CoSQL layout this script was written for.

    Args:
        query: SQL text to execute.
        db_name: Database directory / file stem.
        db_root: Directory containing one sub-directory per database.

    Returns:
        The first element of the tuple returned by ``sqlite_execute``
        (callers in this file iterate it as a sequence of row tuples).
        The execution time and executability flag are discarded.
    """
    db_path = db_root + "/" + db_name + "/" + db_name + ".sqlite"
    result, _execution_time, _executable = execute(db_path, query)
    return result

def get_example(db_name, n_rows=3):
    """Return up to *n_rows* sample rows from every user table of *db_name*.

    The text is embedded in the system prompt, one section per table::

        <table_name>:
        (<value1>,<value2>,...)
        ...

    Args:
        db_name: Database name, resolved to a file by ``sql_evoke``.
        n_rows: Maximum number of sample rows shown per table (default 3,
            the value that used to be hard-coded).

    Returns:
        The concatenated example text; an empty string if there are no
        user tables.
    """
    table_rows = sql_evoke(
        "SELECT name FROM sqlite_master WHERE type='table';", db_name
    )
    sections = []
    for row in table_rows:
        table_name = row[0]
        # SQLite reserves the "sqlite_" prefix for its internal tables (e.g.
        # sqlite_sequence, created by AUTOINCREMENT). They are not part of the
        # user schema and would only confuse the model.
        if table_name.startswith("sqlite_"):
            continue
        # Quote the identifier (doubling embedded quotes) so table names with
        # spaces, hyphens or SQL keywords still produce valid SQL.
        quoted_name = '"' + table_name.replace('"', '""') + '"'
        sample_query = "SELECT * FROM " + quoted_name + " LIMIT " + str(int(n_rows)) + ";"
        sample_rows = sql_evoke(sample_query, db_name)

        sections.append(table_name + ":\n")
        for sample in sample_rows:
            sections.append("(" + ",".join(str(value) for value in sample) + ")\n")
    # Build once with join instead of repeated string concatenation.
    return "".join(sections)
    
def get_system(db_name):
    """Build the system prompt for questions about database *db_name*.

    The prompt has three parts:

    1. The schema description from ``db_getdesc``.
    2. Sample rows per table from ``get_example``.
    3. The task instruction. The model should flag ambiguous or unanswerable
       questions (with an explanation) and otherwise answer with a bare SQL
       query.
    """
    schema_description = db_getdesc(db_name)
    table_examples = get_example(db_name)
    # The "\n" after the examples header keeps the first table name on its own
    # line; previously it was glued to "Examples for each table:".
    prompt = (
        "Database schema:\n" + schema_description
        + "\nExamples for each table:\n" + table_examples
        + "\nBased on the provided information, if the user's question cannot be accurately answered with an SQL query, indicate whether the question is ambiguous or unanswerable and explain why. If the question is answerable, output only SQL query without additional content."
    )
    return prompt

def process_json(input_file, output_file):
    """Query the LLM on every multi-turn dialogue in *input_file*; save predictions.

    For each item, a system prompt is built from ``item['db_name']`` via
    ``get_system``. Each user turn's ``text`` is appended as a user message.
    If another turn follows, the following happens:

    - The LLM is queried and its reply is stored as ``predict`` on that
      following turn.
    - The gold reply from that turn is appended as the assistant message.
      This is its ``text``, or its ``query`` when ``text`` is empty or
      missing.

    Later turns are therefore conditioned on the reference conversation,
    not on the model's own answers.

    The full dataset, including predictions made so far, is rewritten to
    *output_file* after every item. Partial progress therefore survives an
    interruption.

    Args:
        input_file: Path to the input JSON list of dialogue items.
        output_file: Path the annotated JSON is written to.
    """
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    for item_index, item in enumerate(tqdm(data)):
        # Initialize the conversation with the database-specific system prompt.
        print("__________" + str(item_index) + "___________")
        messages = [{"role": "system", "content": get_system(item['db_name'])}]
        turns = item['turns']
        for index, turn in enumerate(turns):
            if not turn['isuser']:
                continue
            user_question = turn['text']
            print(str(index) + " type " + turn['type'] + " Q: " + user_question)
            messages.append({"role": "user", "content": user_question})
            if index + 1 >= len(turns):
                # Trailing user turn: there is no reply turn to attach a prediction to.
                continue
            # NOTE(review): assumes the turn after a user turn is the system
            # reply; consecutive user turns would be mis-attributed. Confirm
            # against the dataset.
            reply_turn = turns[index + 1]

            print(messages)
            llm_response = request_llm(messages)
            print("\nLLM Response:")
            print(llm_response)
            reply_turn['predict'] = llm_response

            # Feed the gold reply back so the next turn sees the reference history.
            gold_answer = reply_turn.get('text') or reply_turn['query']
            messages.append({"role": "assistant", "content": gold_answer})

        # Checkpoint after every item (each item may take minutes of LLM calls).
        _write_json(data, output_file)

    # Ensure the output file exists even when the input dataset is empty.
    if not data:
        _write_json(data, output_file)


def _write_json(data, output_file):
    """Write *data* to *output_file* as indented JSON (overwriting it)."""
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)

# Example usage: annotate the MMSQL test split with LLM predictions.
input_file, output_file = 'MMSQL_test.json', 'MMSQL_test_pred.json'
process_json(input_file, output_file)

  0%|                                                                                                                                                                                                                                                                                                  | 0/282 [00:00<?, ?it/s]

__________0___________
0 type unanswerable Q: What is the most popular car color?
[{'role': 'system', 'content': "Database schema:\ncontinents(ContId:cont id type:number PRIMARY KEY|Continent:continent type:text|)\ncountries(CountryId:country id type:number PRIMARY KEY|CountryName:country name type:text|Continent:continent type:number|)\ncar_makers(Id:id type:number PRIMARY KEY|Maker:maker type:text|FullName:full name type:text|Country:country type:text|)\nmodel_list(ModelId:model id type:number PRIMARY KEY|Maker:maker type:number|Model:model type:text|)\ncar_names(MakeId:make id type:number PRIMARY KEY|Model:model type:text|Make:make type:text|)\ncars_data(Id:id type:number PRIMARY KEY|MPG:mpg type:text|Cylinders:cylinders type:number|Edispl:edispl type:number|Horsepower:horsepower type:text|Weight:weight type:number|Accelerate:accelerate type:number|Year:year type:number|)\n\nExamples for each table:continents:\n(1,america)\n(2,europe)\n(3,asia)\ncountries:\n(1,usa,1)\n(2,germany,2)\

  0%|▉                                                                                                                                                                                                                                                                                       | 1/282 [01:32<7:11:39, 92.17s/it]


LLM Response:
You're welcome! Let me know if you have any other questions about the database. 😊 

__________1___________
0 type answerable Q: Hi!  Can you tell me how many unique template IDs there are?
[{'role': 'system', 'content': "Database schema:\ncontinents(ContId:cont id type:number PRIMARY KEY|Continent:continent type:text|)\ncountries(CountryId:country id type:number PRIMARY KEY|CountryName:country name type:text|Continent:continent type:number|)\ncar_makers(Id:id type:number PRIMARY KEY|Maker:maker type:text|FullName:full name type:text|Country:country type:text|)\nmodel_list(ModelId:model id type:number PRIMARY KEY|Maker:maker type:number|Model:model type:text|)\ncar_names(MakeId:make id type:number PRIMARY KEY|Model:model type:text|Make:make type:text|)\ncars_data(Id:id type:number PRIMARY KEY|MPG:mpg type:text|Cylinders:cylinders type:number|Edispl:edispl type:number|Horsepower:horsepower type:text|Weight:weight type:number|Accelerate:accelerate type:number|Year:year type