# Init

In [1]:
import copy
import json
import os
import re
import sys
import time

from tqdm import tqdm

# Base directory; later cells build the answer_tmp/, prompt/ and data/ paths from it.
cwd = os.getcwd()

# Make the project-local helper modules in ./tools (chat, parse_data, sql) importable.
# Use an absolute path so a later chdir cannot break it, and skip the append when
# the cell is re-run so sys.path does not accumulate duplicates.
# (The former `os.chdir(cwd)` changed into the directory we were already in, so it was dropped.)
tools_dir = os.path.join(cwd, 'tools')
if tools_dir not in sys.path:
    sys.path.append(tools_dir)

import chat
import parse_data
import sql

In [2]:
# Input: questions annotated by the stage-3 table finder. Earlier alternatives were
# stage_1-glm_4_plus-ner-v2.0.0-sql-HF-Post.json and
# stage_1-glm_4_plus-answer_generator-test-v1.0.0.json (same answer_tmp/ folder).
question_path = os.path.join(
    cwd, 'answer_tmp', 'stage_3-glm_4_plus-table_finder-v2.7.3.json'
)

# Order numerically by the integer after the last '-' of the tid
# (e.g. 'tttt----78' -> 78), so that tid 9 sorts before tid 10.
questions = sorted(
    parse_data.read_json(question_path),
    key=lambda question: int(question['tid'].split('-')[-1]),
)

# Craft Prompt

In [3]:
prompt_dir = os.path.join(cwd, 'prompt')

sql_1_fname = 'sql_generator-stage_1-v2.0.0.md'
sql_2_fname = 'sql_generator-stage_2-v1.0.0.md'
ans_fname = 'answer_generator-v1.0.0.md'


def _read_prompt(fname: str) -> str:
    """Return the full text of the prompt template ``prompt_dir/fname``.

    The file is decoded as UTF-8 explicitly so the result does not depend on the
    platform's default locale encoding (e.g. GBK on Chinese-locale Windows).
    """
    with open(os.path.join(prompt_dir, fname), 'r', encoding='utf-8') as f:
        return f.read()


# Each template contains <...> placeholders that the make_prompt_* builders substitute.
sql_1_prompt_template = _read_prompt(sql_1_fname)
sql_2_prompt_template = _read_prompt(sql_2_fname)
ans_prompt_template = _read_prompt(ans_fname)

In [4]:
# build sql prompt 

# All placeholders are substituted with plain ``str.replace`` rather than ``re.sub``:
# the inserted text (JSON dumps, schema markdown, user questions, generated SQL and
# the repr of SQL results, which can contain sequences like '\u3000') may hold
# backslashes. ``re.sub`` would interpret those as escapes and either silently
# rewrite them or raise ``re.error: bad escape``.


def _drop_section(prompt: str, title: str) -> str:
    """Remove an unused prompt section: first the ``<title>`` placeholder line, then its ``## title`` heading line."""
    prompt = prompt.replace(f'\n<{title}>\n', '')
    return prompt.replace(f'\n## {title}\n', '')


def _fill_tables(prompt: str, data_source: list) -> str:
    """Fill the ``<Database and Table>`` and ``<Table-Column Schema>`` placeholders.

    Args:
        prompt: template text.
        data_source: table-finder output, iterated as entries carrying a 'table' key.

    Returns:
        The prompt with the table-finder result inserted as pretty-printed JSON, and
        with the concatenated contents of ``data/table-column/<table>-with_table_name.md``
        inserted, each followed by a blank line.
    """
    tables = [item['table'] for item in data_source]
    prompt = prompt.replace(
        '<Database and Table>', json.dumps(data_source, ensure_ascii=False, indent=2)
    )

    table_dir = os.path.join(cwd, 'data', 'table-column')
    schema_parts = []
    for table in tables:
        schema_fpath = os.path.join(table_dir, f'{table}-with_table_name.md')
        with open(schema_fpath, 'r', encoding='utf-8') as f:
            schema_parts.append(f.read() + '\n\n')
    return prompt.replace('<Table-Column Schema>', ''.join(schema_parts))


def _fill_ner(prompt: str, data: dict) -> str:
    """Insert the stage-1 NER hit at ``<NER Result>``, or drop the NER section when NER found nothing."""
    ner = data['ner']['stage_1']
    if not ner['result']:
        return _drop_section(prompt, 'NER Result')
    # First non-empty candidate list, then the first candidate whose 'result' is
    # non-empty; the first element of that result is shown to the model.
    candidates = [value for value in ner['sql'].values() if value][0]
    ner_res = [cand['result'] for cand in candidates if cand['result']][0][0]
    return prompt.replace('<NER Result>', json.dumps(ner_res, ensure_ascii=False, indent=2))


def make_prompt_sql_1(data: dict) -> str:
    """Build the first-turn SQL-generator prompt from ``sql_1_prompt_template``.

    Fills the tables/schema from ``data['table_finder']['stage_1']``, the stage-1 NER
    result (or removes that section), and the first team question.
    """
    # NOTE: the former `del table_finder_res['question']` could never succeed, because
    # data_source is iterated as a list of entries (a list rejects a str key). It was a
    # silent no-op hidden by a bare except, so it has been dropped.
    prompt = _fill_tables(
        sql_1_prompt_template, data['table_finder']['stage_1'][0]['data_source']
    )
    prompt = _fill_ner(prompt, data)
    return prompt.replace('<Current Query>', data['team'][0]['question'])


def make_prompt_sql_2(data: dict, idx: int) -> str:
    """Build the follow-up SQL-generator prompt for turn ``idx`` (0-based).

    Like make_prompt_sql_1, but uses the table finder output of ``stage_{idx+1}`` and
    the question ``data['team'][idx]``. It also fills ``<Chat History>`` with the
    previous questions paired with the answers recorded so far in
    ``data['answer_generator']``.
    """
    prompt = _fill_tables(
        sql_2_prompt_template, data['table_finder'][f'stage_{idx+1}'][0]['data_source']
    )
    # NER results are always taken from stage_1, for every turn.
    prompt = _fill_ner(prompt, data)
    prompt = prompt.replace('<Current Query>', data['team'][idx]['question'])

    # Each answer_generator entry is a one-key dict {'stage_n': answer}; pair it with
    # the question asked at the same position.
    history = [
        {'previous_query': data['team'][n]['question'], 'response': list(turn.values())[0]}
        for n, turn in enumerate(data['answer_generator'])
    ]
    return prompt.replace('<Chat History>', json.dumps(history, ensure_ascii=False, indent=2))


def make_prompt_answer(data: dict, idx: int) -> str:
    """Build the answer-generator prompt for turn ``idx`` (0-based).

    When ``data['sql_generator']`` has an entry for ``stage_{idx+1}``, its 'sql_query'
    and str()'d 'sql_res' are inserted. Otherwise the SQL Query and SQL Result sections
    are removed. The stage-1 NER result and the turn's question are filled in as well.
    """
    stage = f'stage_{idx+1}'
    prompt = ans_prompt_template
    if stage in data['sql_generator']:
        sql_stage = data['sql_generator'][stage][0]
        prompt = prompt.replace('<SQL Query>', sql_stage['sql_query'])
        prompt = prompt.replace('<SQL Result>', str(sql_stage['sql_res']))
    else:
        prompt = _drop_section(prompt, 'SQL Query')
        prompt = _drop_section(prompt, 'SQL Result')

    prompt = _fill_ner(prompt, data)
    return prompt.replace('<Current Query>', data['team'][idx]['question'])

# Process Query

## Init LLM client (Zhipu GLM-4-Plus via its OpenAI-compatible API)

In [5]:
from openai import OpenAI

# GLM-4-Plus through Zhipu's OpenAI-compatible endpoint. To use DeepSeek instead, set
# base_url="https://api.deepseek.com" and model="deepseek-chat" in the calls.
# The key is read from the environment instead of being hard-coded. Any key that was
# previously committed to this notebook should be treated as leaked and revoked.
# `deepseek_api` keeps its historical name even though it holds the GLM key.
deepseek_api = os.environ.get('ZHIPUAI_API_KEY')
if not deepseek_api:
    raise RuntimeError('Set the ZHIPUAI_API_KEY environment variable before running this cell.')

client = OpenAI(api_key=deepseek_api, base_url="https://open.bigmodel.cn/api/paas/v4/")

## Tools

In [6]:
# Bearer token for the FinGLM2 query API (used by fetch_data). It is read from the
# environment rather than hard-coded; revoke any token previously committed here.
api_key = os.environ.get('FINGLM_API_KEY')
if not api_key:
    raise RuntimeError('Set the FINGLM_API_KEY environment variable before running this cell.')

import requests
import json

def fetch_data(data: dict, timeout: float = 60):
    """POST a query payload to the FinGLM2 query API and return the decoded JSON body.

    Args:
        data: request body, e.g. ``{'sql': '<SQL text>', 'limit': 100}``.
        timeout: seconds to wait for the server before ``requests`` gives up. Without
            a timeout, ``requests`` can block indefinitely on a stalled connection.

    Returns:
        The parsed JSON response; callers in this notebook read its 'data' key.
    """
    url = "https://comm.chatglm.cn/finglm2/api/query"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    response = requests.post(url, headers=headers, json=data, timeout=timeout)
    return response.json()

In [7]:
# Smoke-test the query API with a known stock code (600872) before running the pipeline;
# `t` ends up holding the decoded API response.
t = fetch_data({
    "sql": "SELECT sm.ChiName AS FullName, lsa.AShareAbbr AS AShareAbbreviation, lsa.LegalRepr AS LegalRepresentative, lsa.LegalConsultant AS LegalConsultant, lsa.AccountingFirm AS AccountingFirm, lsa.SecretaryBD AS BoardSecretary FROM ConstantDB.SecuMain sm JOIN AStockBasicInfoDB.LC_StockArchives lsa ON sm.CompanyCode = lsa.CompanyCode WHERE sm.SecuCode = '600872';",
    "limit": 100,
})
t['data']

[{'FullName': '中炬高新技术实业(集团)股份有限公司',
  'AShareAbbreviation': '中炬高新',
  'LegalRepresentative': '余健华',
  'LegalConsultant': '广东卓建(中山)律师事务所',
  'AccountingFirm': '天职国际会计师事务所（特殊普通合伙）',
  'BoardSecretary': '郭毅航'}]

## Test: run the full pipeline (SQL generation → SQL execution → answer) on one question

In [None]:
# End-to-end run of one question. For every turn: generate SQL with the LLM, execute it
# via the FinGLM2 API, then generate the natural-language answer. Prompts are printed
# at the end of each turn for inspection.
idx = 9  # position in `questions` (sorted by tid) of the question to run

max_retries = 5
retry_delay = 1  # seconds to wait between retries


def _call_llm(messages: list, retries: int = max_retries, delay: float = retry_delay) -> tuple:
    """Send `messages` to glm-4-plus, retrying failed API calls.

    Returns ``(response_dict, elapsed_seconds)``, where the elapsed time spans all
    attempts including retry sleeps. Re-raises the last exception once `retries`
    attempts have failed.
    """
    start_time = time.time()
    for attempt in range(retries):
        try:
            completion = client.chat.completions.create(
                # model="deepseek-chat",
                model='glm-4-plus',
                messages=messages,
                stream=False,
                top_p=0.7,
                temperature=0.9,
            )
        except Exception as e:
            print(f"API call failed (attempt {attempt + 1}/{retries}): {e}")
            if attempt == retries - 1:
                raise
            time.sleep(delay)
        else:
            return json.loads(completion.to_json()), time.time() - start_time


def _parse_llm_json(text: str):
    """Decode the JSON payload of an LLM reply.

    Prefers the body of a fenced ```json ... ``` block, which tolerates text or
    newlines around the fence. Otherwise it trims whitespace and stray fence
    characters from both ends before decoding.
    """
    fenced = re.search(r'```(?:json)?\s*(.*?)\s*```', text, re.DOTALL)
    payload = fenced.group(1) if fenced else text.strip().strip('`json')
    return json.loads(payload)


sql_generator_history, answer_generator_history = [], []

# Deep copy: the loop writes into nested dicts (token_usage / time_usage). A shallow
# .copy() would leak those writes back into questions[idx] on every run of this cell.
data = copy.deepcopy(questions[idx])
data['sql_generator'] = {}
data['answer_generator'] = []
data.setdefault('token_usage', {})
data.setdefault('time_usage', {})

for i in tqdm(range(len(data['team']))):
    stage = f'stage_{i+1}'

    # --- SQL generator: retry until the reply parses and its SQL executes ---
    for retries in range(max_retries):
        llm_sql_query = make_prompt_sql_1(data) if i == 0 else make_prompt_sql_2(data, i)
        response, execution_time = _call_llm(
            sql_generator_history + [{'role': 'user', 'content': llm_sql_query}]
        )
        raw_content = response['choices'][0]['message']['content']

        try:
            content = _parse_llm_json(raw_content)
            sql_text = content['sql_query']  # a reply without SQL is retried, not a crash
        except Exception as e:
            print(f"Error processing JSON {data['team'][i]['id']} (attempt {retries + 1}/{max_retries}): {e}")
            continue

        data['sql_generator'][stage] = [content]
        data['token_usage'][f'sql_generator-{stage}'] = [response['usage']]
        data['time_usage'][f'sql_generator-{stage}'] = [f"{execution_time:.2f}s"]

        try:
            content['sql_res'] = fetch_data({'sql': sql_text, 'limit': 1000})['data']
        except Exception as e:
            print(f"Error fetching SQL data {data['team'][i]['id']} (attempt {retries + 1}/{max_retries}): {e}")
            continue
        break
    else:
        print(f"Failed to process question {data['team'][i]['id']} after {max_retries} attempts.")
        # A stage entry exists only if some reply parsed. Without one, make_prompt_answer
        # drops the SQL sections (previously this line raised KeyError in that case).
        if stage in data['sql_generator']:
            data['sql_generator'][stage][0]['sql_res'] = []

    sql_generator_history.extend([
        {'role': 'user', 'content': llm_sql_query},
        {'role': 'assistant', 'content': raw_content},
    ])

    # --- Answer generator: separate chat history, one user/assistant pair per turn ---
    llm_answer_query = make_prompt_answer(data, i)
    answer_generator_history.append({'role': 'user', 'content': llm_answer_query})

    response, execution_time = _call_llm(answer_generator_history)
    answer_text = response['choices'][0]['message']['content']

    data['answer_generator'].append({stage: answer_text})
    # Record this call's own usage/timing (previously the SQL call's values were reused).
    data['token_usage'][f'answer_generator-{stage}'] = [response['usage']]
    data['time_usage'][f'answer_generator-{stage}'] = [f"{execution_time:.2f}s"]

    answer_generator_history.append({'role': 'assistant', 'content': answer_text})

    print("======llm_sql_query======")
    print(llm_sql_query)
    print("======llm_sql_query======")

    print("======llm_answer_query======")
    print(llm_answer_query)
    print("======llm_answer_query======")

In [14]:
# History is appended in user/assistant pairs, so index 1 is the first SQL-generator reply.
first_sql_reply = sql_generator_history[1]['content']
print(first_sql_reply)

```json
{
  "query": "000958公司2021年主营业务产品有哪些？（合并报表调整后的，金额保留2位小数）",
  "sql_cot_reasoning": "To find the main business products of company 000958 in 2021, we need to select the 'Project' column from the 'LC_MainOperIncome' table in the 'AStockFinanceDB' database. We also need to ensure that the data is from the adjusted consolidated financial statements, so we will include conditions for 'IfMerged' and 'IfAdjusted'. Additionally, we need to filter the data for the year 2021, so we will use the 'EndDate' column with the 'LIKE' operator. Finally, we will round the 'MainOperIncome' to 2 decimal places using the 'ROUND' function.",
  "sql_query": "SELECT Project, ROUND(MainOperIncome, 2) AS MainOperIncome FROM AStockFinanceDB.LC_MainOperIncome WHERE CompanyCode = '000958' AND IfMerged = 1 AND IfAdjusted = 1 AND EndDate LIKE '2021%'",
  "sql_explanation": "This SQL query retrieves the main business products and their corresponding income for company 000958 in 2021 from the 'LC_MainOperIncome'

# Generate Final Answer 

In [193]:
# Inputs/outputs for building the submission file. Paths are relative to the current
# working directory: generated answers, submission output, official template.
src_fname = 'glm-answer_generator-0127.json'
saved_fname = 'submit-stage_1-250127.json'
template_fname = 'submit_example.json'

src_fpath, saved_fpath, template_fpath = (
    os.path.join(folder, fname)
    for folder, fname in (
        ('answer_tmp', src_fname),
        ('answer', saved_fname),
        ('data', template_fname),
    )
)

raw_data = parse_data.read_json(src_fpath)
answers = parse_data.read_json(template_fpath)

# To process records in tid order, uncomment:
# raw_data = sorted(raw_data, key=lambda x: int(x['tid'].split('-')[-1]))

In [199]:
# Peek at one record's per-stage answers (a dict keyed 'stage_1', 'stage_2', ...).
second_record = raw_data[1]
second_record['answer_generator']

{'stage_1': '美亚光电在2021年的减持计划中，最大可减持股份数量与最小可减持股份数量的差距是0。',
 'stage_2': '美亚光电在2021年的减持计划中涉及了1名股东。',
 'stage_3': '美亚光电在2021年的减持计划中，股东张建军的最大减持比例最高，为0.007027%。'}

In [201]:
# Copy each generated stage answer into the matching submission-template entry, then save.
# Template entries are indexed by tid once, instead of rescanning `answers` for every stage.
template_by_tid = {}
for entry in answers:
    template_by_tid.setdefault(entry['tid'], []).append(entry)

for record in raw_data:
    tid = record['tid']
    try:
        for num in range(len(record['team'])):
            answer = record['answer_generator'][f'stage_{num+1}']
            for entry in template_by_tid.get(tid, []):
                entry['team'][num]['answer'] = answer
    except (KeyError, IndexError, TypeError) as e:
        # e.g. a stage missing from answer_generator. Stages written before the
        # failure stay filled; report the tid and the reason instead of swallowing it.
        print(tid, repr(e))

parse_data.write_json(answers, saved_fpath)

tttt----78
tttt----84
