# Init

In [2]:
import copy
import json
import os
import re
import sys
import time

from tqdm import tqdm

# Project root = the directory the notebook is launched from. The project
# modules imported below (chat, parse_data, sql) are resolved from ./tools.
cwd = os.getcwd()
tools_dir = os.path.join(cwd, 'tools')
# Absolute path, so imports keep working even if the working directory changes
# later; the membership check avoids duplicate entries when the cell is re-run.
if tools_dir not in sys.path:
    sys.path.append(tools_dir)

import chat
import parse_data
import sql

# Craft Prompt

In [5]:
# No system prompt: the whole instruction is carried by the filled template,
# which is sent as the user message.
system_prompt = ""

# Stage-1 SQL-generator prompt template. `version` also tags the output files
# written further down.
prompt_dir = os.path.join(cwd, 'prompt')
version = 'v1.1.0'
fname = f'sql_generator-stage_1-{version}.md'
prompt_fpath = os.path.join(prompt_dir, fname)

# Decode explicitly as UTF-8 instead of relying on the platform's locale
# encoding (which differs e.g. on Windows).
with open(prompt_fpath, 'r', encoding='utf-8') as f:
    prompt_template = f.read()

def make_prompt(data: dict, template=None, table_dir=None) -> str:
    """Fill the stage-1 SQL-generator prompt template for one question record.

    Every occurrence of each placeholder is replaced literally:
      <Database and Table>   -> JSON of the first table-finder data source,
                                without its 'question' key
      <Table-Column Schema>  -> contents of '<table>-with_table_name.md'
      <Background Knowledge> -> JSON of the first non-empty NER result; when
                                NER reports no result, the placeholder and the
                                '## Background Knowledge' heading are removed
      <Current Query>        -> data['team'][0]['question']

    Args:
        data: one record from the table-finder output; must contain the
            'table_finder', 'ner' and 'team' keys read below. It is not
            modified.
        template: prompt template text. Defaults to the notebook-level
            ``prompt_template``.
        table_dir: directory holding the table schema markdown files.
            Defaults to ``<cwd>/data/table-column``.

    Returns:
        The filled prompt string.
    """
    prompt = prompt_template if template is None else template
    if table_dir is None:
        table_dir = os.path.join(cwd, 'data', 'table-column')

    # Database/table selected by the table finder. Work on a shallow copy so
    # removing 'question' does not alter the caller's record (it is saved later).
    data_source = dict(data['table_finder']['stage_1'][0]['data_source'][0])
    data_source.pop('question', None)
    table = data_source['table']
    # str.replace rather than re.sub: the replacement text (JSON, markdown,
    # user input) may contain backslashes, which re.sub would interpret as
    # escapes or group references (or reject with "bad escape").
    prompt = prompt.replace(
        '<Database and Table>',
        json.dumps(data_source, ensure_ascii=False, indent=2),
    )

    # Column schema of the selected table.
    table_fpath = os.path.join(table_dir, f'{table}-with_table_name.md')
    with open(table_fpath, 'r', encoding='utf-8') as f:
        table_schema = f.read()
    prompt = prompt.replace('<Table-Column Schema>', table_schema)

    # Background knowledge from the NER stage, or drop the section entirely.
    ner_stage = data['ner']['stage_1']
    if ner_stage['result']:
        # First non-empty entry of ner_stage['sql'], then the first element
        # of its first non-empty 'result'.
        ner_items = [i for i in ner_stage['sql'].values() if i][0]
        ner_res = [i['result'] for i in ner_items if i['result']][0][0]
        prompt = prompt.replace(
            '<Background Knowledge>',
            json.dumps(ner_res, ensure_ascii=False, indent=2),
        )
    else:
        prompt = prompt.replace('<Background Knowledge>', '')
        prompt = prompt.replace('## Background Knowledge', '')

    # The user's question.
    prompt = prompt.replace('<Current Query>', data['team'][0]['question'])

    return prompt

In [3]:
# Input records: the stage-1 table-finder output produced by GLM-4-Plus (v2.6.2).
question_path = os.path.join(
    cwd,
    'answer_tmp',
    'stage_1-glm_4_plus-table_finder-v2.6.2.json',
)
questions = parse_data.read_json(question_path)

# GLM-4-Plus

In [4]:
model = "glm_4_plus"  # model tag used in the output file names below

## Test

In [5]:
# Sanity check: build the prompt for a single record and send one request.
query = make_prompt(questions[1])
history = []  # single-turn request, no prior conversation

start_time = time.time()
message = chat.create_message(
    query,
    history=history,
    system_prompt=system_prompt,
    temperature=0.7,
    top_p=0.9,
    response_format='text',
)
end_time = time.time()
execution_time = end_time - start_time

# NOTE(review): the second argument appears to toggle printing (the cell output
# shows the usage dict and the reply) — confirm in tools/chat.py.
usage = chat.get_token_usage(message, True)
content = chat.get_content(message, True)
history = chat.build_history(history, message=message)

{'prompt_tokens': 3926, 'completion_tokens': 226, 'total_tokens': 4152}
```json
{
    "query": "今天是2021年12月24日，创近半年新高的股票有几只？",
    "sql_cot_reasoning": "首先，我们需要确定今天的日期是2021年12月24日，因此我们需要在查询中使用这个日期。接着，我们要找到创近半年新高的股票，这意味着我们需要查看`IfHighestHPriceRMSix`字段是否为1。根据表结构，这个字段表示指定日期的最高价是否大于最近半年的最高价。最后，我们需要统计符合条件的股票数量，因此使用`COUNT`函数。综上所述，我们的查询将涉及选择`IfHighestHPriceRMSix`字段，并且过滤条件是`TradingDay`为2021年12月24日且`IfHighestHPriceRMSix`为1，最后使用`COUNT`函数统计数量。",
    "sql_query": "SELECT COUNT(*) FROM AStockMarketQuotesDB.CS_StockPatterns WHERE TradingDay LIKE '2021-12-24%' AND IfHighestHPriceRMSix = 1"
}
```


In [6]:
# Save the single test answer. Work on a deep copy so this cell does not
# overwrite questions[1], which the full run below processes again.
test_answer = copy.deepcopy(questions[1])
# strip('`json') removes the surrounding ```json fence characters.
test_answer['sql_generator'] = {'stage_1': [json.loads(content.strip('`json'))]}
# Merge into usage recorded by earlier stages instead of resetting it.
test_answer.setdefault('token_usage', {})['sql_generator-stage_1'] = [usage]
test_answer.setdefault('time_usage', {})['sql_generator-stage_1'] = [f"{execution_time:.2f}s"]

saved_path = os.path.join(cwd, 'answer_tmp', f'stage_1-{model}-sql_generator-test-{version}.json')
parse_data.write_json([test_answer], saved_path)

## ALL

In [7]:
# Run the stage-1 SQL generator with GLM-4-Plus over every question.
# Each answer is an independent deep copy of its question record, so the
# DeepSeek run below cannot overwrite these results through shared dicts.
answers = []
failed_tids = []

for question in tqdm(questions):
    try:
        query = make_prompt(question)

        history = []  # single-turn request

        start_time = time.time()
        message = chat.create_message(query, history=history, system_prompt=system_prompt, temperature=0.7, top_p=0.9, response_format='text')
        end_time = time.time()

        execution_time = end_time - start_time
        usage = chat.get_token_usage(message, False)
        content = chat.get_content(message, False)

        res = copy.deepcopy(question)
        # strip('`json') removes the surrounding ```json fence characters.
        res['sql_generator'] = {'stage_1': [json.loads(content.strip('`json'))]}
        # setdefault keeps usage from earlier stages and tolerates a missing key.
        res.setdefault('token_usage', {})['sql_generator-stage_1'] = [usage]
        res.setdefault('time_usage', {})['sql_generator-stage_1'] = [f"{execution_time:.2f}s"]

        answers.append(res)
    except Exception as e:
        # Best effort per question (API error, malformed JSON, ...), but record
        # which one failed and why instead of swallowing the error silently.
        failed_tids.append(question.get('tid'))
        print(question.get('tid'), repr(e))

saved_path = os.path.join(cwd, 'answer_tmp', f'stage_1-{model}-sql_generator-{version}.json')
parse_data.write_json(answers, saved_path)

  0%|          | 0/101 [00:00<?, ?it/s]

 13%|█▎        | 13/101 [01:48<12:57,  8.84s/it]

tttt----13


 57%|█████▋    | 58/101 [07:42<05:46,  8.05s/it]

tttt----58


 58%|█████▊    | 59/101 [08:04<08:32, 12.20s/it]

tttt----59


 62%|██████▏   | 63/101 [08:50<07:34, 11.97s/it]

tttt----63


 63%|██████▎   | 64/101 [09:06<08:15, 13.39s/it]

tttt----64


 71%|███████▏  | 72/101 [10:21<04:33,  9.41s/it]

tttt----72


 72%|███████▏  | 73/101 [10:26<03:54,  8.37s/it]

tttt----74


 85%|████████▌ | 86/101 [12:37<02:30, 10.06s/it]

tttt----87


100%|██████████| 101/101 [15:07<00:00,  8.98s/it]


# DeepSeek-V3 (`deepseek-chat`)

In [8]:
model = "deepseek_v3"  # model tag used in the output file names below

In [9]:
# DeepSeek API key, read from the environment so no secret is stored in the
# notebook. A key that was ever committed here should be revoked and rotated.
deepseek_api = os.environ.get('DEEPSEEK_API_KEY')
if not deepseek_api:
    print('DEEPSEEK_API_KEY is not set; the DeepSeek cells below will fail to authenticate.')

## Test

In [10]:
from openai import OpenAI

# Sanity check: build the prompt for a single record and send one request.
query = make_prompt(questions[1])

# DeepSeek is called through the OpenAI client pointed at its base URL.
client = OpenAI(api_key=deepseek_api, base_url="https://api.deepseek.com")

start_time = time.time()
response = client.chat.completions.create(
    model="deepseek-chat",
    messages=[
        {"role": "user", "content": query},
    ],
    stream=False,
    # Same sampling settings as the GLM-4-Plus cells (temperature=0.7,
    # top_p=0.9) so the two models are compared like for like.
    temperature=0.7,
    top_p=0.9,
)
end_time = time.time()

response = json.loads(response.to_json())
content = response['choices'][0]['message']['content']

# strip('`json') removes the surrounding ```json fence characters.
content = content.strip('`json')
usage = response['usage']
execution_time = end_time - start_time

In [11]:
# Save the single test answer. Work on a deep copy so this cell does not
# overwrite questions[1], which the full run below processes again.
test_answer = copy.deepcopy(questions[1])
# strip('`json') removes the surrounding ```json fence characters.
test_answer['sql_generator'] = {'stage_1': [json.loads(content.strip('`json'))]}
# Merge into usage recorded by earlier stages instead of resetting it.
test_answer.setdefault('token_usage', {})['sql_generator-stage_1'] = [usage]
test_answer.setdefault('time_usage', {})['sql_generator-stage_1'] = [f"{execution_time:.2f}s"]

saved_path = os.path.join(cwd, 'answer_tmp', f'stage_1-{model}-sql_generator-test-{version}.json')
parse_data.write_json([test_answer], saved_path)

## ALL

In [12]:
# Run the stage-1 SQL generator with DeepSeek over every question.
# Uses `client`, created in the DeepSeek test cell above.
# Each answer is an independent deep copy of its question record, so it shares
# no dicts with `questions` or with the GLM answers.
answers = []
failed_tids = []

for question in tqdm(questions):
    try:
        query = make_prompt(question)

        start_time = time.time()
        response = client.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "user", "content": query},
            ],
            stream=False,
            # Same sampling settings as the GLM-4-Plus run.
            temperature=0.7,
            top_p=0.9,
        )
        end_time = time.time()

        response = json.loads(response.to_json())
        content = response['choices'][0]['message']['content']

        # strip('`json') removes the surrounding ```json fence characters.
        content = content.strip('`json')
        usage = response['usage']
        execution_time = end_time - start_time

        res = copy.deepcopy(question)
        res['sql_generator'] = {'stage_1': [json.loads(content)]}
        # setdefault keeps usage from earlier stages and tolerates a missing key.
        res.setdefault('token_usage', {})['sql_generator-stage_1'] = [usage]
        res.setdefault('time_usage', {})['sql_generator-stage_1'] = [f"{execution_time:.2f}s"]

        answers.append(res)
    except Exception as e:
        # Best effort per question (API error, malformed JSON, ...), but record
        # which one failed and why instead of swallowing the error silently.
        failed_tids.append(question.get('tid'))
        print(question.get('tid'), repr(e))

saved_path = os.path.join(cwd, 'answer_tmp', f'stage_1-{model}-sql_generator-{version}.json')
parse_data.write_json(answers, saved_path)

 30%|██▉       | 30/101 [03:13<08:15,  6.98s/it]

tttt----30


 59%|█████▉    | 60/101 [06:20<06:23,  9.35s/it]

tttt----60


 72%|███████▏  | 73/101 [07:46<03:09,  6.77s/it]

tttt----74


 85%|████████▌ | 86/101 [09:13<01:44,  6.95s/it]

tttt----87


100%|██████████| 101/101 [11:07<00:00,  6.60s/it]


# Obtain SQL Results

## GLM-4-Plus

In [17]:
model = "glm_4_plus"  # which generator output to execute below

In [18]:
# Execute the generated SQL for every answer and save the records.
fname = f'stage_1-{model}-sql_generator-{version}.json'
fpath = os.path.join(cwd, 'answer_tmp', fname)
data = parse_data.read_json(fpath)

for record in tqdm(data):
    # NOTE(review): the return value is discarded and `data` is saved as-is, so
    # this presumably stores the query result on the record in place — TODO
    # confirm in tools/sql.py; otherwise the result must be assigned here.
    sql.process_sql_generator_res(record['sql_generator']['stage_1'][0])

fname = f'stage_1-{model}-sql_generator-{version}-sql.json'
fpath = os.path.join(cwd, 'answer_tmp', fname)
# Do not rebind `data` to the writer's return value; keep the records.
parse_data.write_json(data, fpath)

  4%|▍         | 4/93 [00:03<01:06,  1.34it/s]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


 28%|██▊       | 26/93 [00:17<00:32,  2.08it/s]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query
Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


 37%|███▋      | 34/93 [00:19<00:18,  3.27it/s]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


 41%|████      | 38/93 [00:21<00:18,  3.02it/s]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


 47%|████▋     | 44/93 [00:22<00:11,  4.26it/s]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query
Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


 53%|█████▎    | 49/93 [00:23<00:09,  4.79it/s]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


 61%|██████▏   | 57/93 [01:26<10:58, 18.30s/it]

Request failed: 504 Server Error: Gateway Time-out for url: https://comm.chatglm.cn/finglm2/api/query


 74%|███████▍  | 69/93 [01:30<00:12,  1.94it/s]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


 75%|███████▌  | 70/93 [01:30<00:09,  2.36it/s]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


 82%|████████▏ | 76/93 [01:32<00:04,  3.96it/s]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query
Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


 88%|████████▊ | 82/93 [01:33<00:02,  4.08it/s]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


 90%|█████████ | 84/93 [01:34<00:02,  4.00it/s]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


 92%|█████████▏| 86/93 [01:34<00:01,  4.36it/s]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


 97%|█████████▋| 90/93 [01:35<00:01,  2.96it/s]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


100%|██████████| 93/93 [01:36<00:00,  1.04s/it]


## Deepseek-v3

In [14]:
model = "deepseek_v3"  # which generator output to execute below

In [16]:
# Execute the generated SQL for every answer and save the records.
fname = f'stage_1-{model}-sql_generator-{version}.json'
fpath = os.path.join(cwd, 'answer_tmp', fname)
data = parse_data.read_json(fpath)

for record in tqdm(data):
    # NOTE(review): the return value is discarded and `data` is saved as-is, so
    # this presumably stores the query result on the record in place — TODO
    # confirm in tools/sql.py; otherwise the result must be assigned here.
    sql.process_sql_generator_res(record['sql_generator']['stage_1'][0])

fname = f'stage_1-{model}-sql_generator-{version}-sql.json'
fpath = os.path.join(cwd, 'answer_tmp', fname)
# Do not rebind `data` to the writer's return value; keep the records.
parse_data.write_json(data, fpath)

 28%|██▊       | 27/97 [00:06<00:15,  4.47it/s]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


 36%|███▌      | 35/97 [00:13<01:14,  1.20s/it]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


 44%|████▍     | 43/97 [00:23<00:47,  1.13it/s]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


 49%|████▉     | 48/97 [00:26<00:27,  1.76it/s]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


 63%|██████▎   | 61/97 [00:37<00:34,  1.05it/s]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


 64%|██████▍   | 62/97 [00:38<00:33,  1.06it/s]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


 72%|███████▏  | 70/97 [00:43<00:14,  1.91it/s]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


 91%|█████████ | 88/97 [00:54<00:04,  2.06it/s]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


 93%|█████████▎| 90/97 [00:55<00:03,  2.11it/s]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


 97%|█████████▋| 94/97 [00:59<00:02,  1.29it/s]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query


100%|██████████| 97/97 [01:01<00:00,  1.59it/s]
