# Init

In [1]:
import sys
import os
import json
import time
import re

from tqdm import tqdm

# Resolve project paths from the directory the notebook was started in and make
# the project helper modules in ./tools (presumably chat, parse_data, sql)
# importable. An absolute path keeps imports working even if the working
# directory changes later, and the membership check keeps re-runs of this cell
# from appending duplicate sys.path entries.
cwd = os.getcwd()
tools_dir = os.path.join(cwd, 'tools')
if tools_dir not in sys.path:
    sys.path.append(tools_dir)

import chat
import parse_data
import sql

In [2]:
question_path = os.path.join(cwd, 'data', 'question-b.json')

# Process questions in ascending order of the integer after the last '-' in their tid.
questions = sorted(
    parse_data.read_json(question_path),
    key=lambda record: int(record['tid'].rsplit('-', 1)[-1]),
)

# NER

## TOOL

In [3]:
# Raw string: '\w' in a normal literal is an invalid escape sequence
# (DeprecationWarning, SyntaxWarning on Python 3.12+).
_SINGLE_COLUMN_SELECT = re.compile(r'(?<=SELECT )\w+(?= FROM)')


def sub_select_content(query: str) -> str:
    """Widen a single-column SELECT to ``SELECT *``.

    ``"SELECT EngName FROM db.t WHERE ..."`` becomes
    ``"SELECT * FROM db.t WHERE ..."`` so that a validation query returns whole
    rows. Only a lone word between ``"SELECT "`` and ``" FROM"`` is rewritten.
    Multi-column lists (``SELECT a, b FROM``), aliases and ``SELECT *`` are left
    untouched. Every matching occurrence is replaced, including inside subqueries.

    Args:
        query: SQL text.

    Returns:
        The rewritten SQL text.
    """
    return _SINGLE_COLUMN_SELECT.sub('*', query)

## Craft Prompt

In [4]:
system_prompt = ""

version = 'v2.2.1'
task = 'ner'

prompt_dir = os.path.join(cwd, 'prompt')
fname = f'{task}-stage_1-{version}.md'
ner_prompt_fpath = os.path.join(prompt_dir, fname)
ner_valid_prompt_fpath = os.path.join(prompt_dir, 'ner-validation-v1.0.0.md')

# The templates contain Chinese text: decode them as UTF-8 explicitly so the
# result does not depend on the platform's default locale encoding.
with open(ner_prompt_fpath, 'r', encoding='utf-8') as f:
    ner_prompt_template = f.read()

with open(ner_valid_prompt_fpath, 'r', encoding='utf-8') as f:
    ner_valid_prompt_template = f.read()

def make_ner_prompt(conversation_turn: dict) -> str:
    """Build the stage-1 NER prompt for a question record.

    Args:
        conversation_turn: question record; the text at
            ``team[0]['question']`` is appended after the NER template.

    Returns:
        ``ner_prompt_template`` immediately followed by that question text.
    """
    opening_question = conversation_turn['team'][0]['question']
    return ''.join((ner_prompt_template, opening_question))

def make_ner_valid_prompt(question: dict) -> str:
    """Build the NER-validation prompt for a record whose NER lookups found nothing.

    The validation template is followed by a pretty-printed JSON payload that
    holds the stage-1 entities (``ner_result``) and the lookup queries with
    their results (``ner_sql_result``). The raw user question is included
    first, as ``query``, only when the longest per-entity query list has
    exactly 3 entries.

    Args:
        question: record carrying ``team[0]['question']`` and
            ``ner['stage_1']['result']`` / ``ner['stage_1']['sql']``.

    Returns:
        The complete prompt string.
    """
    first_question = question['team'][0]['question']
    stage_1 = question['ner']['stage_1']
    entities = stage_1['result']
    lookups = stage_1['sql']

    # NOTE(review): 3 presumably matches an entity looked up in three security
    # tables (SecuMain / HK_SecuMain / US_SecuMain in the logged output);
    # confirm against sql.process_ner_res.
    longest = max(map(len, lookups.values()), default=1)

    payload = {"query": first_question} if longest == 3 else {}
    payload["ner_result"] = entities
    payload["ner_sql_result"] = lookups

    return ner_valid_prompt_template + json.dumps(payload, ensure_ascii=False, indent=2)

## GLM

In [5]:
# Model tag. In this notebook it is only used to build result-file names under
# answer_tmp/; it is not passed to chat.create_message.
model = "glm_4_plus"

### Test

In [19]:
# Smoke test: run stage-1 NER on one sample question and save the result.
sample_idx = 51
sample_question = questions[sample_idx]
query = make_ner_prompt(sample_question)

history = []

start_time = time.time()
message = chat.create_message(query, history=history, system_prompt=system_prompt, temperature=0.7, top_p=0.9, response_format='text')
end_time = time.time()

execution_time = end_time - start_time
usage = chat.get_token_usage(message, True)
content = chat.get_content(message, True)
history = chat.build_history(history, message=message)

# Attach the reply to the question that was actually sent. This previously
# copied questions[0] regardless of which question was queried.
t = sample_question.copy()
t['ner_result'] = {'stage_1': json.loads(content.strip('`json'))}
t['token_usage'] = {'ner-stage_1': usage}
t['time_usage'] = {'ner-stage_1': f"{execution_time:.2f}s"}
t = [t]

saved_path = os.path.join(cwd, 'answer_tmp' + os.sep + f'stage_1-{model}-{task}-test-{version}.json')
parse_data.write_json(t, saved_path)

{'prompt_tokens': 924, 'completion_tokens': 105, 'total_tokens': 1029}
```json
{
    "reasoning_process_cot": "从问题中可以看出，查询的是在特定日期（2021-12-21）A股市场中创出月度新高的公司及其证券代码。虽然问题中没有直接提及具体的公司名称和代码，但可以推断出需要识别的实体类型包括上市公司名称和代码。由于问题中没有提供具体的公司名称和代码，无法直接识别出具体的实体，但可以明确需要识别的实体类型。",
    "result": []
}
```


### ALL

Bad cases:

27:28 博时基金公司成立于？用XXXX年XX月XX日回复我

40:41 嘉实致元42个月定期债券基金的管理经理是谁？

52:53 南方亨元债券A在2019年的分红次数是多少？每次分红的派现比例是多少？

63:64 博时基金公司成立于？用XXXX年XX月XX日回复我

78:79 Huazhu Group Ltd.这家公司在美股英文名称是什么？

In [21]:
answers = []
ner_max_tried = 6

# First fenced ``` / ```json block anywhere in an LLM reply.
_JSON_FENCE = re.compile(r'```(?:json)?\s*(.*?)\s*```', re.DOTALL)


def _load_llm_json(content):
    """Parse the JSON payload of an LLM reply; return None if it cannot be parsed.

    Prefers the first fenced code block, so replies that put a heading such as
    "### **输出**" before the fence still parse. Otherwise falls back to
    stripping backticks and a ``json`` tag from both ends of the reply.
    Non-string content (e.g. from a failed request) yields None.
    """
    if not isinstance(content, str):
        return None
    fenced = _JSON_FENCE.search(content)
    payload = fenced.group(1) if fenced else content.strip('`json')
    try:
        return json.loads(payload)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        return None


def _run_ner_stage_1(query, max_tried):
    """Call the LLM for stage-1 NER, retrying while the reply is not a JSON object.

    Returns:
        (parsed_reply_or_None, token_usage, execution_time_seconds), where usage
        and time come from the last attempt.
    """
    history = []
    parsed, usage, execution_time = None, None, 0.0
    for _ in range(max_tried):
        start_time = time.time()
        message = chat.create_message(query, history=history, system_prompt=system_prompt, temperature=0.1, top_p=1, response_format='text')
        execution_time = time.time() - start_time
        usage = chat.get_token_usage(message, False)
        parsed = _load_llm_json(chat.get_content(message, False))
        if isinstance(parsed, dict):
            return parsed, usage, execution_time
        print(f"JSON parsing error: {query}")
    return None, usage, execution_time


def _validate_ner(res, max_tried):
    """Ask the LLM for corrected lookup queries; return the first (sql_query, rows) with data.

    Unparseable replies do not count as failed attempts. They are capped
    separately at ``max_tried``, so a persistently broken reply or API cannot
    loop forever. Returns (None, None) when no candidate query returns rows.
    """
    prompt = make_ner_valid_prompt(res)
    failed, format_errors = 0, 0
    while failed < max_tried and format_errors < max_tried:
        message = chat.create_message(prompt, history=[], system_prompt=system_prompt, temperature=0.7, top_p=0.9, response_format='text')
        candidates = _load_llm_json(chat.get_content(message, True))
        if not isinstance(candidates, list):
            print("----JSON parsing error-----")
            format_errors += 1
            continue

        for candidate in candidates:
            sql_query = candidate.get('sql_query') if isinstance(candidate, dict) else None
            if not sql_query:
                continue
            # Widen single-column SELECTs so the stored hit holds the whole row.
            sql_query = sub_select_content(sql_query)
            sql_res = sql.get_data_from_sql_query(sql_query)
            if sql_res:
                return sql_query, sql_res
        failed += 1
    return None, None


for question in tqdm(questions):
    query = make_ner_prompt(question)
    stage_1, usage, execution_time = _run_ner_stage_1(query, ner_max_tried)

    res = question.copy()
    res['ner'] = {}
    res['token_usage'] = {'ner-stage_1': usage}
    res['time_usage'] = {'ner-stage_1': f"{execution_time:.2f}s"}

    if stage_1 is None:
        # Every attempt failed to parse. Record an empty NER result, which
        # downstream cells already handle, instead of crashing the whole run.
        print(f"====NER stage_1 FAILED after {ner_max_tried} attempts: {question.get('tid')}")
        res['ner']['stage_1'] = {'result': [], 'sql': {}}
        answers.append(res)
        continue

    # Run the entity lookups against the database.
    ner = sql.process_ner_res(stage_1)
    res['ner']['stage_1'] = ner
    ner_result = ner.get('result')

    # True when no lookup query for any entity returned rows.
    all_results_empty = not any(
        query_info.get('result')
        for queries in ner.get('sql', {}).values()
        for query_info in queries
    )

    # Entities were recognised but none could be found: ask the LLM to repair them.
    if all_results_empty and ner_result:
        print('====NER Validation===')
        print(ner)

        sql_query, sql_res = _validate_ner(res, ner_max_tried)
        if sql_res:
            first_key = next(iter(ner['sql']))
            ner['sql'][first_key].append({
                "query": sql_query,
                "result": sql_res
            })
        else:
            print('====FAILED===')
            print('====NER Validation===')

    answers.append(res)

saved_path = os.path.join(cwd, 'answer_tmp' + os.sep + f'stage_1-{model}-{task}-{version}.json')
parse_data.write_json(answers, saved_path)

 27%|██▋       | 27/100 [01:40<04:32,  3.73s/it]

====NER Validation===
{'reasoning_process_cot': "从问题中可以看出，'博时基金公司'是一个基金公司名称，因此需要识别为基金公司名称实体。问题中未提及其他实体，如上市公司名称、股票代码、基金名称或行业名称。", 'result': [{'基金公司名称': '博时基金公司'}], 'sql': {'基金公司名称:博时基金公司': [{'query': "SELECT * FROM InstitutionDB.LC_InstiArchive WHERE '博时基金公司' IN (ChiName, AbbrChiName, NameChiSpelling, EngName, AbbrEngName)", 'result': []}]}}
### **输出**

```json
[
  {
    "potential_reason_cot_thinking": "原始 NER 结果中的 '博时基金公司' 可能存在名称不完整或拼写错误，导致在数据库中未能匹配到相关记录。考虑可能的名称变体或简称。",
    "sql_query": "SELECT * FROM InstitutionDB.LC_InstiArchive WHERE '博时基金' IN (ChiName, AbbrChiName, NameChiSpelling, EngName, AbbrEngName)",
    "sql_explanation": "调整查询条件，将 '博时基金公司' 改为 '博时基金'，以涵盖可能的简称或变体名称，从而提高匹配成功率。"
  },
  {
    "potential_reason_cot_thinking": "数据库中可能存在不同的名称格式，如全称、简称或英文名称。尝试使用模糊匹配来增加匹配范围。",
    "sql_query": "SELECT * FROM InstitutionDB.LC_InstiArchive WHERE ChiName LIKE '%博时基金%' OR AbbrChiName LIKE '%博时基金%' OR NameChiSpelling LIKE '%博时基金%' OR EngName LIKE '%Boshi Fund%' OR AbbrEngName LIKE '%Boshi

 33%|███▎      | 33/100 [03:39<09:47,  8.77s/it]

====NER Validation===
{'reasoning_process_cot': "分析当前查询内容，'在线教育' 是一个实体。接着，分析其可能的属性：'在线教育' 可以是一个行业名称，也可以是一个概念。结合查询内容，问题是在讨论 '在线教育' 属于科技概念的哪个分支，这表明 '在线教育' 在这里是作为一个行业名称被提及。因此，'在线教育' 应被识别为 '行业名称'。问题中还提到了 '科技概念' 的英文名称，但这不属于需要识别的五大类实体之一，所以无需识别。", 'result': [{'行业名称': '在线教育'}], 'sql': {'行业名称:在线教育': [{'query': "SELECT FirstIndustryCode AS 一级行业代码, SecondIndustryCode AS 二级行业代码, ThirdIndustryCode AS 三级行业代码, FourthIndustryCode AS 四级行业代码, FirstIndustryName AS 一级行业名称, SecondIndustryName AS 二级行业名称, ThirdIndustryName AS 三级行业名称, FourthIndustryName AS 四级行业名称 FROM AStockIndustryDB.LC_ExgIndustry WHERE '在线教育' IN (FirstIndustryName, SecondIndustryName, ThirdIndustryName, FourthIndustryName)", 'result': []}]}}
```json
[
  {
    "potential_reason_cot_thinking": "由于原始 NER 结果 '在线教育' 在数据库查询中未返回任何结果，可能的原因是该行业名称在数据库中使用了不同的表述或分类。需要调整 NER 结果以匹配数据库中的实际行业名称。",
    "sql_query": "SELECT FirstIndustryCode AS 一级行业代码, SecondIndustryCode AS 二级行业代码, ThirdIndustryCode AS 三级行业代码, FourthIndustryCode AS 四级行业代码, FirstIndustryName

 40%|████      | 40/100 [04:17<04:07,  4.13s/it]

====NER Validation===
{'reasoning_process_cot': '从当前查询中，可以看出涉及到一个基金名称‘嘉实致元42个月定期债券基金’，问题询问的是该基金的管理经理是谁。‘嘉实致元42个月定期债券基金’是一个基金名称，而查询中并未提及其他实体，如上市公司名称、股票代码、基金公司名称或行业名称。', 'result': [{'基金名称': '嘉实致元42个月定期债券基金'}], 'sql': {'基金名称:嘉实致元42个月定期债券基金': [{'query': "SELECT * FROM PublicFundDB.MF_FundProdName WHERE DisclName = '嘉实致元42个月定期债券基金' LIMIT 1", 'result': []}, {'query': "SELECT * FROM ConstantDB.SecuMain WHERE '嘉实致元42个月定期债券基金' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling) LIMIT 1", 'result': []}]}}
```json
[
  {
    "potential_reason_cot_thinking": "原始 NER 结果 '嘉实致元42个月定期债券基金' 在数据库中未能匹配到相应实体，可能是因为名称存在细微差异或数据库中未收录该基金。考虑可能的名称变体或缩写。",
    "sql_query": "SELECT * FROM PublicFundDB.MF_FundProdName WHERE DisclName LIKE '%嘉实致元42个月%' LIMIT 1",
    "sql_explanation": "使用 LIKE 语句模糊匹配基金名称，增加匹配灵活性，尝试找到包含 '嘉实致元42个月' 的基金名称。"
  },
  {
    "potential_reason_cot_thinking": "基金名称可能存在简称或不同表述，尝试在更广泛的证券主表中查找相似名称。",
    "sql_query": "SELECT * FROM ConstantDB.SecuMain WHERE ChiName LIKE '%嘉实致元4

 58%|█████▊    | 58/100 [05:35<02:50,  4.07s/it]

====NER Validation===
{'reasoning_process_cot': '从问题中可以看出，询问的是在2022年期间进行公司名称全称变更的公司及其代码。这里涉及到两个实体类型：上市公司名称和代码。虽然具体公司名称和代码未在问题中明确提及，但问题的意图是查找这些信息。因此，我们需要识别出问题中隐含的实体类型，即上市公司名称和代码。', 'result': [{'上市公司名称': '公司名称全称变更的公司'}, {'代码': '公司代码'}], 'sql': {'上市公司名称:公司名称全称变更的公司': [{'query': "SELECT * FROM ConstantDB.SecuMain WHERE '公司名称全称变更的公司' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.HK_SecuMain WHERE '公司名称全称变更的公司' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.US_SecuMain WHERE '公司名称全称变更的公司' IN (SecuCode, SecuAbbr, ChiSpelling, EngName, ChiName)", 'result': []}], '代码:公司代码': [{'query': None, 'result': None}]}}
```json
[
  {
    "potential_reason_cot_thinking": "原始 NER 结果中的 '公司名称全称变更的公司' 是一个模糊的描述，无法直接用于数据库查询。需要将其转换为具体的公司名称或相关字段。",
    "sql_query": "SELECT SecuCode, ChiName, ChiNameAbbr FROM ConstantDB.SecuMain WHERE SecuCode IN (SELECT SecuCode

 59%|█████▉    | 59/100 [07:19<23:16, 34.05s/it]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query
====FAILED===
====NER Validation===


 63%|██████▎   | 63/100 [07:33<06:33, 10.63s/it]

====NER Validation===
{'reasoning_process_cot': "从问题中可以看出，'博时基金公司'是一个基金公司名称，因此需要识别为基金公司名称实体。问题中未提及其他实体，如上市公司名称、股票代码、基金名称或行业名称。", 'result': [{'基金公司名称': '博时基金公司'}], 'sql': {'基金公司名称:博时基金公司': [{'query': "SELECT * FROM InstitutionDB.LC_InstiArchive WHERE '博时基金公司' IN (ChiName, AbbrChiName, NameChiSpelling, EngName, AbbrEngName)", 'result': []}]}}
### **输出**

```json
[
  {
    "potential_reason_cot_thinking": "NER 结果中的 '博时基金公司' 可能存在名称不完整或简称未被识别的情况。在金融数据库中，机构名称可能以全称、简称、拼音或英文名称存在。因此，需要检查所有可能的名称形式。",
    "sql_query": "SELECT * FROM InstitutionDB.LC_InstiArchive WHERE '博时基金公司' IN (ChiName, AbbrChiName, NameChiSpelling, EngName, AbbrEngName) OR '博时' IN (ChiName, AbbrChiName, NameChiSpelling, EngName, AbbrEngName)",
    "sql_explanation": "此查询扩展了原始查询，增加了对 '博时' 这一简称的搜索，以覆盖更多可能的名称匹配情况。这样可以提高找到正确机构信息的概率。"
  },
  {
    "potential_reason_cot_thinking": "NER 结果可能存在名称拼写错误或格式不一致的情况。在金融数据库中，机构名称可能存在多种拼写或格式变体。因此，需要使用模糊匹配来提高检索准确性。",
    "sql_query": "SELECT * FROM InstitutionDB.LC_InstiArchive WHERE ChiName L

 64%|██████▍   | 64/100 [09:06<21:15, 35.42s/it]

```json
[
  {
    "potential_reason_cot_thinking": "原始 NER 结果中的 '博时基金公司' 在数据库中未找到匹配项，可能是因为名称存在差异或数据库中未收录该实体。需要调整名称以匹配数据库中的实体。",
    "sql_query": "SELECT * FROM InstitutionDB.LC_InstiArchive WHERE '博时基金管理有限公司' IN (ChiName, AbbrChiName, NameChiSpelling, EngName, AbbrEngName)",
    "sql_explanation": "将 '博时基金公司' 调整为 '博时基金管理有限公司'，因为数据库中可能使用全称。这样可以提高查询匹配的概率。"
  },
  {
    "potential_reason_cot_thinking": "原始 NER 结果中的 '博时基金公司' 可能存在简称或别称，数据库中可能使用不同的名称形式。",
    "sql_query": "SELECT * FROM InstitutionDB.LC_InstiArchive WHERE '博时基金' IN (ChiName, AbbrChiName, NameChiSpelling, EngName, AbbrEngName)",
    "sql_explanation": "尝试使用简称 '博时基金' 进行查询，以覆盖数据库中可能存在的简称或别称形式。"
  },
  {
    "potential_reason_cot_thinking": "原始 NER 结果中的 '博时基金公司' 可能存在拼音或英文简称，数据库中可能使用这些形式。",
    "sql_query": "SELECT * FROM InstitutionDB.LC_InstiArchive WHERE 'Boshi Fund' IN (ChiName, AbbrChiName, NameChiSpelling, EngName, AbbrEngName)",
    "sql_explanation": "尝试使用英文简称 'Boshi Fund' 进行查询，以覆盖数据库中可能存在的英文或拼音形式。"
  }
]
```


 78%|███████▊  | 78/100 [09:57<01:27,  3.99s/it]

====NER Validation===
{'reasoning_process_cot': "从问题中可以看出，'Huazhu Group Ltd.'是一个上市公司名称，因为它提到了'公司'并且询问的是该公司的美股英文名称。因此，需要将'Huazhu Group Ltd.'识别为上市公司名称实体。", 'result': [{'上市公司名称': 'Huazhu Group Ltd.'}], 'sql': {'上市公司名称:Huazhu Group Ltd.': [{'query': "SELECT * FROM ConstantDB.SecuMain WHERE 'Huazhu Group Ltd.' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.HK_SecuMain WHERE 'Huazhu Group Ltd.' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.US_SecuMain WHERE 'Huazhu Group Ltd.' IN (SecuCode, SecuAbbr, ChiSpelling, EngName, ChiName)", 'result': []}]}}
```json
[
  {
    "potential_reason_cot_thinking": "原始 NER 结果中的 'Huazhu Group Ltd.' 可能是中文名称或简称，而数据库中可能存储的是其英文名称或简称。因此，需要尝试不同的字段组合来匹配该实体。",
    "sql_query": "SELECT EngName FROM ConstantDB.US_SecuMain WHERE ChiName = 'Huazhu Group Ltd.' OR ChiNameAbbr = 'Huazhu Group Ltd.' OR EngNameAbbr = 'Huaz

100%|██████████| 100/100 [12:18<00:00,  7.38s/it]


## Check Results

In [25]:
answer_dir = os.path.join(cwd, 'answer_tmp')
fname = f'stage_1-{model}-{task}-{version}-HF.json'
fpath = os.path.join(answer_dir, fname)

data = parse_data.read_json(fpath)

# Report every question whose NER step recognised entities but none of whose
# lookup queries returned any rows.
for i in data:
    ner = i['ner']['stage_1']
    ner_result = ner['result']

    all_results_empty = not any(
        query_info.get('result')
        for queries in ner.get('sql', {}).values()
        for query_info in queries
    )

    if all_results_empty and ner_result:
        print(i['team'][0])  # first turn of the conversation
        print(ner)
        print()  # blank separator line

{'id': 'tttt----59----15-3-1', 'question': '2022年之间 哪些公司进行公司名称全称变更，公司代码是什么？'}
{'reasoning_process_cot': '从问题中可以看出，询问的是在2022年期间进行公司名称全称变更的公司及其代码。这里涉及到两个实体类型：上市公司名称和代码。虽然具体公司名称和代码未在问题中明确提及，但问题的意图是查找这些信息。因此，我们需要识别出问题中隐含的实体类型，即上市公司名称和代码。', 'result': [{'上市公司名称': '公司名称全称变更的公司'}, {'代码': '公司代码'}], 'sql': {'上市公司名称:公司名称全称变更的公司': [{'query': "SELECT * FROM ConstantDB.SecuMain WHERE '公司名称全称变更的公司' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.HK_SecuMain WHERE '公司名称全称变更的公司' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.US_SecuMain WHERE '公司名称全称变更的公司' IN (SecuCode, SecuAbbr, ChiSpelling, EngName, ChiName)", 'result': []}], '代码:公司代码': [{'query': None, 'result': None}]}}



# Market

In [78]:
question_path = os.path.join(cwd, 'answer_tmp', 'stage_1-glm_4_plus-ner-v2.2.1-HF.json')

# Process questions in ascending order of the integer after the last '-' in their tid.
questions = sorted(
    parse_data.read_json(question_path),
    key=lambda record: int(record['tid'].rsplit('-', 1)[-1]),
)

## Tool

In [79]:
def parse_database_and_table(query: str) -> dict:
    """Extract the first ``FROM database.table`` reference in a SQL query.

    Args:
        query: SQL text, e.g. ``"SELECT * FROM ConstantDB.SecuMain WHERE ..."``.

    Returns:
        ``{'database': ..., 'table': ...}`` for the first qualified
        ``FROM db.table`` found (the keyword is matched case-insensitively), or
        an empty dict when there is none.
    """
    found = re.search(r'FROM\s+([a-zA-Z0-9_]+)\.([a-zA-Z0-9_]+)', query, re.IGNORECASE)
    if found is None:
        return {}
    database, table = found.groups()
    return {'database': database, 'table': table}

## Craft Prompt

In [119]:
system_prompt = ""

version = 'v1.0.0'
task = 'market_classifier'

prompt_dir = os.path.join(cwd, 'prompt')
market_fname = 'market_classifier-v1.0.0.md'
fpath = os.path.join(prompt_dir, market_fname)

# The template contains Chinese text: decode it as UTF-8 regardless of the
# platform's default locale encoding.
with open(fpath, 'r', encoding='utf-8') as f:
    market_prompt_template = f.read()

def make_prompt_market(data: dict, template: str = None) -> str:
    """Build the market-classifier prompt for one question record.

    Args:
        data: question record; reads ``team[0]['question']`` and
            ``ner['stage_1']['result']`` / ``ner['stage_1']['sql']``.
        template: template to fill. Defaults to the module-level
            ``market_prompt_template`` loaded from the prompt directory.

    Returns:
        The template with ``<Current Query>`` replaced by the question text.
        When NER found entities, ``<给出的信息>`` is replaced by the
        pretty-printed lookup results. Otherwise that placeholder line and its
        ``## **给出的信息**`` header line are removed.
    """
    prompt = market_prompt_template if template is None else template

    # Plain str.replace, not re.sub: the inserted text (user question, JSON with
    # escapes like "\n") must not be parsed as a regex replacement template.
    # In such a template backslashes are special, and e.g. "\d" raises re.error.
    prompt = prompt.replace('<Current Query>', data['team'][0]['question'])

    ner_stage_1 = data['ner']['stage_1']
    if ner_stage_1['result']:
        ner_res = json.dumps(ner_stage_1['sql'], ensure_ascii=False, indent=2)
        prompt = prompt.replace('<给出的信息>', ner_res)
    else:
        # No entities: drop the placeholder line and its section header.
        prompt = prompt.replace('\n<给出的信息>\n', '')
        prompt = prompt.replace('\n## **给出的信息**\n', '')

    return prompt

In [123]:
# Dry-run the market prompt builder over every record so malformed records
# surface before any API calls are made. Report the error instead of hiding it,
# and let KeyboardInterrupt/SystemExit through (a bare except caught them).
for i in questions:
    try:
        make_prompt_market(i)
    except Exception as exc:
        print(f'make_prompt_market failed: {exc!r}')
        print(i)

## GLM

In [127]:
answers = []
ner_max_tried = 6

# Tables whose single lookup hit decides the market. A single hit in any other
# table is labelled mainland China ("CN").
MARKET_BY_TABLE = {'US_SecuMain': 'US', 'HK_SecuMain': 'HK'}


def _count_sql_hits(ner_sql_result: dict) -> int:
    """Number of NER lookup queries, across all entities, that returned rows."""
    return sum(
        1
        for queries in ner_sql_result.values()
        for query_info in queries
        if query_info.get('result')
    )


def _market_from_single_hit(question: dict, ner_sql_result: dict) -> dict:
    """Label the market from the table of the lookup query that returned rows."""
    market = {}
    for queries in ner_sql_result.values():
        for query_info in queries:
            if not query_info.get('result'):
                continue
            table = parse_database_and_table(query_info.get('query') or '').get('table')
            market = {"query": question['team'][0]['question'],
                      "cot_thinking": None,
                      "market": MARKET_BY_TABLE.get(table, 'CN')}
    return market


def _classify_market_with_llm(prompt: str, max_tried: int):
    """Ask the LLM for the market label, retrying on API or JSON errors.

    Each failed API call or unparseable reply consumes one attempt.

    Returns:
        (market_dict, token_usage, execution_time). market_dict is {} and usage
        and time stay 0 if no call succeeded.
    """
    usage, execution_time = 0, 0
    for _ in range(max_tried):
        start_time = time.time()
        try:
            message = chat.create_message(prompt, history=[], system_prompt=system_prompt, temperature=0.1, top_p=1, response_format='text')
        except Exception as exc:
            # Retry instead of falling through with an undefined or stale `message`.
            print(f'API ERROR: {exc!r}')
            continue
        execution_time = f"{time.time() - start_time:.2f}s"
        usage = chat.get_token_usage(message, False)
        content = chat.get_content(message, False)
        try:
            return json.loads(content.strip('`json')), usage, execution_time
        except (AttributeError, TypeError, ValueError):
            print(f"JSON parsing error: {prompt}")
    return {}, usage, execution_time


for question in tqdm(questions):
    query = make_prompt_market(question)
    ner_sql_result = question['ner']['stage_1']['sql']

    res = question.copy()
    # Per-question defaults, so a question answered without an LLM call never
    # inherits usage/time from a previous iteration (or an earlier cell).
    usage, execution_time = 0, 0

    if _count_sql_hits(ner_sql_result) == 1:
        # Exactly one lookup hit: its table decides the market, no LLM needed.
        res['market'] = _market_from_single_hit(question, ner_sql_result)
    else:
        print("---NOT ONE---")
        print(ner_sql_result)
        res['market'], usage, execution_time = _classify_market_with_llm(query, ner_max_tried)

    # question.copy() is shallow: build new nested dicts rather than writing
    # into the ones shared with the records in `questions`.
    res['token_usage'] = {**question.get('token_usage', {}), 'market': usage}
    res['time_usage'] = {**question.get('time_usage', {}), 'market': execution_time}

    answers.append(res)

saved_path = os.path.join(cwd, 'answer_tmp' + os.sep + f'{model}-{task}-{version}.json')
parse_data.write_json(answers, saved_path)

  0%|          | 0/100 [00:00<?, ?it/s]

---NOT ONE---
{}


  2%|▏         | 2/100 [00:06<05:04,  3.10s/it]

---NOT ONE---
{}


  7%|▋         | 7/100 [00:08<01:42,  1.10s/it]

---NOT ONE---
{}


  8%|▊         | 8/100 [00:13<02:32,  1.66s/it]

---NOT ONE---
{}


  9%|▉         | 9/100 [00:17<03:18,  2.18s/it]

---NOT ONE---
{'上市公司名称:中国人寿': [{'query': "SELECT * FROM ConstantDB.SecuMain WHERE '中国人寿' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.HK_SecuMain WHERE '中国人寿' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': [{'ID': 16808232845300, 'InnerCode': 1000872, 'CompanyCode': 4057, 'SecuCode': '02628', 'ChiName': '中国人寿保险股份有限公司', 'ChiNameAbbr': None, 'EngName': 'China Life Insurance Company Limited', 'EngNameAbbr': 'CHINA LIFE', 'SecuAbbr': '中国人寿', 'ChiSpelling': 'ZGRS', 'SecuMarket': 72, 'SecuCategory': 3, 'ListedDate': '2003-12-18 12:00:00.000', 'ListedSector': 1, 'ListedState': 1, 'XGRQ': '2018-11-29 11:24:25.720', 'JSID': 596805865723, 'DelistingDate': None, 'ISIN': 'CNE1000002L3', 'FormerName': None, 'TradingUnit': 1000.0, 'TraCurrUnit': 1100, 'InsertTime': '2005-10-12 02:23:57.983'}]}, {'query': "SELECT * FROM ConstantDB.US_SecuMain WHERE '中国人寿' IN (SecuCode, SecuAbbr, ChiSpell

 16%|█▌        | 16/100 [00:22<01:39,  1.18s/it]

---NOT ONE---
{'上市公司名称:大唐国际发电股份有限公司': [{'query': "SELECT * FROM ConstantDB.SecuMain WHERE '大唐国际发电股份有限公司' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': [{'ID': 217614716484, 'InnerCode': 4874, 'CompanyCode': 3848, 'SecuCode': '601991', 'ChiName': '大唐国际发电股份有限公司', 'ChiNameAbbr': '大唐发电', 'EngName': 'Datang International Power Generation Co., Ltd.', 'EngNameAbbr': 'Datang Power', 'SecuAbbr': '大唐发电', 'ChiSpelling': 'DTFD', 'SecuMarket': 83, 'SecuCategory': 1, 'ListedDate': '2006-12-20 12:00:00.000', 'ListedSector': 1, 'ListedState': 1, 'XGRQ': '2017-03-16 04:30:01.650', 'JSID': 542953801666, 'ISIN': 'CNE000001Q02', 'ExtendedAbbr': None, 'ExtendedSpelling': None}]}, {'query': "SELECT * FROM ConstantDB.HK_SecuMain WHERE '大唐国际发电股份有限公司' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.US_SecuMain WHERE '大唐国际发电股份有限公司' IN (SecuCode, SecuAbbr, ChiSpelling, EngName, ChiName)", 'result': [{

 20%|██        | 20/100 [00:29<01:56,  1.45s/it]

---NOT ONE---
{}


 21%|██        | 21/100 [00:33<02:17,  1.74s/it]

API ERROR
---NOT ONE---
{}


 22%|██▏       | 22/100 [00:39<02:53,  2.22s/it]

---NOT ONE---
{'上市公司名称:中国联通': [{'query': "SELECT * FROM ConstantDB.SecuMain WHERE '中国联通' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.HK_SecuMain WHERE '中国联通' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': [{'ID': 16808232655600, 'InnerCode': 1000593, 'CompanyCode': 1000593, 'SecuCode': '00762', 'ChiName': '中国联合网络通信(香港)股份有限公司', 'ChiNameAbbr': None, 'EngName': 'China Unicom (Hong Kong) Limited', 'EngNameAbbr': 'CHINA UNICOM', 'SecuAbbr': '中国联通', 'ChiSpelling': 'ZGLT', 'SecuMarket': 72, 'SecuCategory': 53, 'ListedDate': '2000-06-22 12:00:00.000', 'ListedSector': 1, 'ListedState': 1, 'XGRQ': '2019-12-12 01:03:05.373', 'JSID': 629427788469, 'DelistingDate': None, 'ISIN': 'HK0000049939', 'FormerName': None, 'TradingUnit': 2000.0, 'TraCurrUnit': 1100, 'InsertTime': '2005-10-12 02:23:28.437'}]}, {'query': "SELECT * FROM ConstantDB.US_SecuMain WHERE '中国联通' IN (SecuCode, SecuAbbr, 

 33%|███▎      | 33/100 [00:44<01:11,  1.06s/it]

---NOT ONE---
{}


 34%|███▍      | 34/100 [00:49<01:26,  1.32s/it]

---NOT ONE---
{}


 39%|███▉      | 39/100 [00:52<01:04,  1.06s/it]

---NOT ONE---
{}


 40%|████      | 40/100 [00:55<01:17,  1.29s/it]

---NOT ONE---
{}


 44%|████▍     | 44/100 [01:01<01:14,  1.33s/it]

---NOT ONE---
{}


 46%|████▌     | 46/100 [01:05<01:19,  1.47s/it]

---NOT ONE---
{'基金名称:嘉实超短债债券A': [{'query': "SELECT * FROM PublicFundDB.MF_FundProdName WHERE DisclName = '嘉实超短债债券A' LIMIT 1", 'result': [{'ID': 677756078394, 'InnerCode': 381528, 'InfoPublDate': '2021-06-23 12:00:00.000', 'InfoSource': '产品资料概要', 'InfoType': 6, 'DisclName': '嘉实超短债债券A', 'EffectiveDate': '2021-06-23 12:00:00.000', 'ExpiryDate': None, 'IfEffected': 1, 'Remark': None, 'UpdateTime': '2022-06-24 03:33:10.413', 'JSID': 709366159233, 'ChiSpelling': 'JSCDZZQA', 'TransCode': 381528, 'InsertTime': '2021-06-23 09:34:16.393'}]}, {'query': "SELECT * FROM ConstantDB.SecuMain WHERE '嘉实超短债债券A' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling) LIMIT 1", 'result': [{'ID': 677756000000, 'InnerCode': 381528, 'CompanyCode': 643820, 'SecuCode': '12773', 'ChiName': '嘉实超短债证券投资基金A类', 'ChiNameAbbr': '嘉实超短债A', 'EngName': 'Harvest Ultra Short-term Bond Fund-A', 'EngNameAbbr': None, 'SecuAbbr': '嘉实超短债债券A', 'ChiSpelling': 'JSCDZZQA', 'SecuMarket': None, 'SecuCategory': 8, 'Listed

 47%|████▋     | 47/100 [01:11<01:46,  2.01s/it]

---NOT ONE---
{}


 52%|█████▏    | 52/100 [01:15<01:09,  1.46s/it]

---NOT ONE---
{'上市公司名称:中兴通讯': [{'query': "SELECT * FROM ConstantDB.SecuMain WHERE '中兴通讯' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.HK_SecuMain WHERE '中兴通讯' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': [{'ID': 16808232656300, 'InnerCode': 1000594, 'CompanyCode': 79, 'SecuCode': '00763', 'ChiName': '中兴通讯股份有限公司', 'ChiNameAbbr': None, 'EngName': 'ZTE Corporation', 'EngNameAbbr': 'ZTE', 'SecuAbbr': '中兴通讯', 'ChiSpelling': 'ZXTX', 'SecuMarket': 72, 'SecuCategory': 3, 'ListedDate': '2004-12-09 12:00:00.000', 'ListedSector': 1, 'ListedState': 1, 'XGRQ': '2018-06-19 05:01:54.580', 'JSID': 582742914581, 'DelistingDate': None, 'ISIN': 'CNE1000004Y2', 'FormerName': None, 'TradingUnit': 200.0, 'TraCurrUnit': 1100, 'InsertTime': '2005-10-12 02:23:28.513'}]}, {'query': "SELECT * FROM ConstantDB.US_SecuMain WHERE '中兴通讯' IN (SecuCode, SecuAbbr, ChiSpelling, EngName, ChiName)", 'result'

 54%|█████▍    | 54/100 [01:18<01:07,  1.48s/it]

---NOT ONE---
{}


 57%|█████▋    | 57/100 [01:21<00:57,  1.33s/it]

---NOT ONE---
{'上市公司名称:中国长城': [{'query': "SELECT * FROM ConstantDB.SecuMain WHERE '中国长城' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': [{'ID': 316488637204, 'InnerCode': 107, 'CompanyCode': 81, 'SecuCode': '000066', 'ChiName': '中国长城科技集团股份有限公司', 'ChiNameAbbr': '中国长城', 'EngName': 'China Greatwall Technology Group Co., Ltd.', 'EngNameAbbr': 'CGT Group', 'SecuAbbr': '中国长城', 'ChiSpelling': 'ZGCC', 'SecuMarket': 90, 'SecuCategory': 1, 'ListedDate': '1997-06-26 12:00:00.000', 'ListedSector': 1, 'ListedState': 1, 'XGRQ': '2017-04-28 08:25:01.723', 'JSID': 546726301724, 'ISIN': 'CNE000000RL7', 'ExtendedAbbr': None, 'ExtendedSpelling': None}]}, {'query': "SELECT * FROM ConstantDB.HK_SecuMain WHERE '中国长城' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.US_SecuMain WHERE '中国长城' IN (SecuCode, SecuAbbr, ChiSpelling, EngName, ChiName)", 'result': []}], '代码:000066': [{'query': 'SELECT * FRO

 58%|█████▊    | 58/100 [01:27<01:20,  1.93s/it]

---NOT ONE---
{'上市公司名称:公司名称全称变更的公司': [{'query': "SELECT * FROM ConstantDB.SecuMain WHERE '公司名称全称变更的公司' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.HK_SecuMain WHERE '公司名称全称变更的公司' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.US_SecuMain WHERE '公司名称全称变更的公司' IN (SecuCode, SecuAbbr, ChiSpelling, EngName, ChiName)", 'result': []}], '代码:公司代码': [{'query': None, 'result': None}]}


 59%|█████▉    | 59/100 [01:31<01:33,  2.28s/it]

---NOT ONE---
{}


 61%|██████    | 61/100 [01:36<01:30,  2.33s/it]

---NOT ONE---
{}


 62%|██████▏   | 62/100 [01:40<01:41,  2.66s/it]

---NOT ONE---
{}


 71%|███████   | 71/100 [01:44<00:31,  1.07s/it]

---NOT ONE---
{}


 73%|███████▎  | 73/100 [01:49<00:37,  1.37s/it]

---NOT ONE---
{}


 75%|███████▌  | 75/100 [01:55<00:42,  1.68s/it]

---NOT ONE---
{}


 84%|████████▍ | 84/100 [01:59<00:15,  1.02it/s]

---NOT ONE---
{}


 85%|████████▌ | 85/100 [02:06<00:22,  1.50s/it]

---NOT ONE---
{}


 96%|█████████▌| 96/100 [02:10<00:03,  1.21it/s]

---NOT ONE---
{}


 97%|█████████▋| 97/100 [02:13<00:02,  1.03it/s]

---NOT ONE---
{}


100%|██████████| 100/100 [02:16<00:00,  1.37s/it]


# WorkFlow

## Craft Prompt

In [3]:
prompt_dir = os.path.join(cwd, 'prompt')

sql_1_fname = 'sql_generator-stage_1-v2.0.0.md'
sql_2_fname = 'sql_generator-stage_2-v1.0.0.md'
ans_fname = 'answer_generator-v1.0.0.md'

# The templates contain Chinese text: decode them as UTF-8 regardless of the
# platform's default locale encoding.
with open(os.path.join(prompt_dir, sql_1_fname), 'r', encoding='utf-8') as f:
    sql_1_prompt_template = f.read()

with open(os.path.join(prompt_dir, sql_2_fname), 'r', encoding='utf-8') as f:
    sql_2_prompt_template = f.read()

with open(os.path.join(prompt_dir, ans_fname), 'r', encoding='utf-8') as f:
    ans_prompt_template = f.read()

In [4]:
# build sql prompt 

def _first_ner_row(ner_sql: dict):
    """Return the first row returned by any NER lookup query, or None if none hit.

    Entities are scanned in insertion order, and each entity's queries in list
    order.
    """
    for queries in ner_sql.values():
        for query_info in queries or []:
            rows = query_info.get('result')
            if rows:
                return rows[0]
    return None


def make_prompt_sql_1(data: dict, template: str = None) -> str:
    """Build the stage-1 SQL-generation prompt for a question record.

    Fills, in this order:
      * ``<Database and Table>``: JSON of ``table_finder.stage_1[0].data_source``.
      * ``<Table-Column Schema>``: each table's
        ``data/table-column/<table>-with_table_name.md`` file, each followed by
        a blank line.
      * ``<NER Result>``: JSON of the first row any NER lookup returned. If NER
        found no entities, or no lookup returned rows, the placeholder line and
        its ``## NER Result`` header line are removed instead.
      * ``<Current Query>``: the first question of the conversation.

    Args:
        data: question record with ``table_finder``, ``ner`` and ``team`` keys.
        template: template to fill. Defaults to the module-level
            ``sql_1_prompt_template``.

    Returns:
        The filled prompt.
    """
    prompt = sql_1_prompt_template if template is None else template

    # All substitutions use str.replace, not re.sub: schema markdown, JSON and
    # the question may contain backslashes, which re.sub would interpret as
    # replacement-template escapes.

    # Database-Table Pair(s)
    data_source = data['table_finder']['stage_1'][0]['data_source']
    tables = [i['table'] for i in data_source]
    prompt = prompt.replace('<Database and Table>', json.dumps(data_source, ensure_ascii=False, indent=2))

    # Table Schema(s)
    schema_parts = []
    for table in tables:
        table_fpath = os.path.join(cwd, 'data', 'table-column', f'{table}-with_table_name.md')
        with open(table_fpath, 'r', encoding='utf-8') as f:
            schema_parts.append(f.read() + '\n\n')
    prompt = prompt.replace('<Table-Column Schema>', ''.join(schema_parts))

    # NER Result. Take the first hit across all entities; this previously raised
    # IndexError when the first entity's lookups all came back empty.
    ner_stage_1 = data['ner']['stage_1']
    ner_row = _first_ner_row(ner_stage_1.get('sql') or {}) if ner_stage_1['result'] else None
    if ner_row is not None:
        prompt = prompt.replace('<NER Result>', json.dumps(ner_row, ensure_ascii=False, indent=2))
    else:
        prompt = prompt.replace('\n<NER Result>\n', '')
        prompt = prompt.replace('\n## NER Result\n', '')

    # Current query
    prompt = prompt.replace('<Current Query>', data['team'][0]['question'])

    return prompt

def make_prompt_sql_2(data: dict, idx: int) -> str:
    
    prompt = sql_2_prompt_template

    # Database-Table Pair(s)
    table_finder_res = data['table_finder'][f'stage_{idx+1}'][0]['data_source']
    tables = [i['table'] for i in table_finder_res]
    try:
        del table_finder_res['question']
    except:
        pass
    table_finder_res = json.dumps(table_finder_res, ensure_ascii=False, indent=2)
    reg_p = re.compile('<Database and Table>')
    prompt = re.sub(reg_p, table_finder_res, prompt)
    
    # Table Schema(s)
    table_schema = ''
    for table in tables:
        table_fname = f'{table}-with_table_name.md'
        table_dir = os.path.join(cwd, 'data' + os.sep + 'table-column')
        table_fpath = os.path.join(table_dir, table_fname)
        with open(table_fpath,'r') as f:
            table_schema += ''.join(f.readlines())
            table_schema += '\n\n'
    reg_p = re.compile('<Table-Column Schema>')
    prompt = re.sub(reg_p, table_schema, prompt)

    # NER Result
    if data['ner']['stage_1']['result']:
        ner_res = [i for i in data['ner']['stage_1']['sql'].values() if i][0]
        ner_res = [i['result'] for i in ner_res if i['result']][0][0]
        ner_res = json.dumps(ner_res, ensure_ascii=False, indent=2)
        reg_p = re.compile('<NER Result>')
        prompt = re.sub(reg_p, ner_res, prompt)
    else:
        reg_p = re.compile('\n<NER Result>\n')
        prompt = re.sub(reg_p, '', prompt)
        reg_p = re.compile('\n## NER Result\n')
        prompt = re.sub(reg_p, '', prompt)

    # Query
    query = data['team'][idx]['question']
    reg_p = re.compile('<Current Query>')
    prompt = re.sub(reg_p, query, prompt)

    # Answers
    history = []
    answers = data['answer_generator']
    for i in range(len(answers)):
        ans = answers[i] # {'stage_n': ans}
        ans = list(ans.values())[0]
        query = data['team'][i]['question']
        history.append({'previous_query': query, "response": ans})
    history = json.dumps(history, ensure_ascii=False, indent=2)
    reg_p = re.compile('<Chat History>')
    prompt = re.sub(reg_p, history, prompt) 

    return prompt

def make_prompt_answer(data: dict, idx: int) -> str:

    prompt = ans_prompt_template

    # SQL Query
    if f'stage_{idx+1}' in data['sql_generator']:
        sql_query = data['sql_generator'][f'stage_{idx+1}'][0]['sql_query']
        reg_p = re.compile('<SQL Query>')
        prompt = re.sub(reg_p, sql_query, prompt)
    else:
        reg_p = re.compile('\n<SQL Query>\n')
        prompt = re.sub(reg_p, '', prompt)
        reg_p = re.compile('\n## SQL Query\n')
        prompt = re.sub(reg_p, '', prompt)

    # SQL Result
    if f'stage_{idx+1}' in data['sql_generator']:
        sql_res = str(data['sql_generator'][f'stage_{idx+1}'][0]['sql_res'])
        reg_p = re.compile('<SQL Result>')
        prompt = re.sub(reg_p, sql_res, prompt)
    else:
        reg_p = re.compile('\n<SQL Result>\n')
        prompt = re.sub(reg_p, '', prompt)
        reg_p = re.compile('\n## SQL Result\n')
        prompt = re.sub(reg_p, '', prompt)

    # NER Result
    if data['ner']['stage_1']['result']:
        ner_res = [i for i in data['ner']['stage_1']['sql'].values() if i][0]
        ner_res = [i['result'] for i in ner_res if i['result']][0][0]
        ner_res = json.dumps(ner_res, ensure_ascii=False, indent=2)
        reg_p = re.compile('<NER Result>')
        prompt = re.sub(reg_p, ner_res, prompt)
    else:
        reg_p = re.compile('\n<NER Result>\n')
        prompt = re.sub(reg_p, '', prompt)
        reg_p = re.compile('\n## NER Result\n')
        prompt = re.sub(reg_p, '', prompt)

    # replace query
    query = data['team'][idx]['question']
    reg_p = re.compile('<Current Query>')
    prompt = re.sub(reg_p, query, prompt)

    return prompt

## Process Query

### Init LLM client (Zhipu GLM via its OpenAI-compatible API)

In [5]:
from openai import OpenAI

# Credentials are read from the environment instead of being hard-coded in the
# notebook. Keys that were previously committed here must be treated as leaked
# and rotated.
#
# Active backend: Zhipu GLM through its OpenAI-compatible endpoint. To use
# DeepSeek instead, set base_url="https://api.deepseek.com", supply a DeepSeek
# key, and use model "deepseek-chat" in the completion calls.
#
# The variable keeps its historical name `deepseek_api` in case other cells
# reference it.
deepseek_api = os.environ.get('ZHIPUAI_API_KEY', '')
if not deepseek_api:
    raise RuntimeError('Set the ZHIPUAI_API_KEY environment variable to your open.bigmodel.cn API key.')

client = OpenAI(api_key=deepseek_api, base_url="https://open.bigmodel.cn/api/paas/v4/")

## Tools

In [6]:
import os

import requests
import json

# Bearer token for the FinGLM data API, read from the environment instead of
# being hard-coded (the previously committed token must be rotated).
api_key = os.environ.get('FINGLM_API_KEY', '')
if not api_key:
    raise RuntimeError('Set the FINGLM_API_KEY environment variable to your FinGLM data API token.')


def fetch_data(data: dict, timeout: float = 120):
    """Run a SQL query against the FinGLM data API.

    Args:
        data: request payload, e.g. ``{'sql': 'SELECT ...', 'limit': 100}``.
        timeout: seconds to wait for the HTTP response. Without a timeout,
            ``requests`` can block indefinitely on a stalled connection.

    Returns:
        The decoded JSON response body; callers read the rows from ``'data'``.

    Raises:
        requests.RequestException: on connection errors or timeout.
        ValueError: if the response body is not valid JSON.
    """
    url = "https://comm.chatglm.cn/finglm2/api/query"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }
    response = requests.post(url, headers=headers, json=data, timeout=timeout)

    return response.json()

In [7]:
# Smoke-test the data API with a known query for security code 600872.
sample_request = dict(
    sql="SELECT sm.ChiName AS FullName, lsa.AShareAbbr AS AShareAbbreviation, lsa.LegalRepr AS LegalRepresentative, lsa.LegalConsultant AS LegalConsultant, lsa.AccountingFirm AS AccountingFirm, lsa.SecretaryBD AS BoardSecretary FROM ConstantDB.SecuMain sm JOIN AStockBasicInfoDB.LC_StockArchives lsa ON sm.CompanyCode = lsa.CompanyCode WHERE sm.SecuCode = '600872';",
    limit=100,
)

t = fetch_data(sample_request)
t['data']

[{'FullName': '中炬高新技术实业(集团)股份有限公司',
  'AShareAbbreviation': '中炬高新',
  'LegalRepresentative': '余健华',
  'LegalConsultant': '广东卓建(中山)律师事务所',
  'AccountingFirm': '天职国际会计师事务所（特殊普通合伙）',
  'BoardSecretary': '郭毅航'}]

## Test: run the SQL and answer generators on a single question

In [None]:
idx = 9

sql_generator_history, answer_generator_history = [], []

data = questions[idx].copy()
data['sql_generator'] = {}
data['answer_generator'] = []
# `.copy()` is shallow: give this run its own usage dicts so the metrics
# recorded below do not leak into the shared `questions[idx]` entry.
data['token_usage'] = dict(data.get('token_usage', {}))
data['time_usage'] = dict(data.get('time_usage', {}))

max_retries = 5
retry_delay = 1  # seconds to wait between failed API calls


def call_llm_with_retry(messages: list) -> tuple:
    """Send `messages` to the chat model, retrying failed API calls.

    Makes up to `max_retries` attempts, sleeping `retry_delay` seconds between
    them, and re-raises the last error if every attempt fails.

    Returns:
        (response, elapsed_seconds): the response converted to a plain dict via
        `to_json()`, and the wall time from the first attempt until success
        (retries included).
    """
    start_time = time.time()
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model='glm-4-plus',  # use "deepseek-chat" with the DeepSeek endpoint
                messages=messages,
                stream=False,
                top_p=0.7,
                temperature=0.9
            )
            break
        except Exception as e:
            print(f"API call failed (attempt {attempt + 1}/{max_retries}): {e}")
            if attempt == max_retries - 1:
                raise
            time.sleep(retry_delay)
    elapsed = time.time() - start_time
    return json.loads(response.to_json()), elapsed


for i in tqdm(range(len(data['team']))):
    stage = f'stage_{i+1}'

    # ---- SQL generator -------------------------------------------------------
    # The prompt reads nothing the retry loop writes, so build it once per turn.
    llm_sql_query = make_prompt_sql_1(data) if i == 0 else make_prompt_sql_2(data, i)
    sql_messages = sql_generator_history + [{'role': 'user', 'content': llm_sql_query}]

    # Retry when the reply is not usable JSON or the generated SQL cannot be fetched.
    for retries in range(max_retries):
        response, execution_time = call_llm_with_retry(sql_messages)
        raw_content = response['choices'][0]['message']['content']

        try:
            # Trim whitespace first so a trailing newline after the closing
            # ``` fence does not prevent the fence from being stripped.
            content = json.loads(raw_content.strip().strip('`json'))
            sql_query = content['sql_query']
        except Exception as e:
            print(f"Error processing JSON {data['team'][i]['id']} (attempt {retries + 1}/{max_retries}): {e}")
            continue

        data['sql_generator'][stage] = [content]
        data['token_usage'][f'sql_generator-{stage}'] = [response['usage']]
        data['time_usage'][f'sql_generator-{stage}'] = [f"{execution_time:.2f}s"]

        try:
            content['sql_res'] = fetch_data({'sql': sql_query, 'limit': 1000})['data']
        except Exception as e:
            print(f"Error fetching SQL data {data['team'][i]['id']} (attempt {retries + 1}/{max_retries}): {e}")
            continue

        break  # query generated and executed successfully
    else:
        print(f"Failed to process question {data['team'][i]['id']} after {max_retries} attempts.")
        # Only record an empty result when some attempt produced a parseable query;
        # otherwise the stage is absent and make_prompt_answer drops the SQL sections.
        if stage in data['sql_generator']:
            data['sql_generator'][stage][0]['sql_res'] = []

    sql_generator_history.extend([
        {'role': 'user', 'content': llm_sql_query},
        {'role': 'assistant', 'content': raw_content}
    ])

    # ---- Answer generator ----------------------------------------------------
    llm_answer_query = make_prompt_answer(data, i)
    answer_messages = answer_generator_history + [{'role': 'user', 'content': llm_answer_query}]

    response, execution_time = call_llm_with_retry(answer_messages)
    content = response['choices'][0]['message']['content']

    data['answer_generator'].append({stage: content})
    # Record this call's own usage and timing (not the SQL generator's).
    data['token_usage'][f'answer_generator-{stage}'] = [response['usage']]
    data['time_usage'][f'answer_generator-{stage}'] = [f"{execution_time:.2f}s"]

    answer_generator_history.extend([
        {'role': 'user', 'content': llm_answer_query},
        {'role': 'assistant', 'content': content}
    ])

    print("======llm_sql_query======")
    print(llm_sql_query)
    print("======llm_sql_query======")

    print("======llm_answer_query======")
    print(llm_answer_query)
    print("======llm_answer_query======")

In [14]:
# Raw assistant reply of the first SQL-generator turn (entry 0 is the user prompt).
first_sql_reply = sql_generator_history[1]['content']
print(first_sql_reply)

```json
{
  "query": "000958公司2021年主营业务产品有哪些？（合并报表调整后的，金额保留2位小数）",
  "sql_cot_reasoning": "To find the main business products of company 000958 in 2021, we need to select the 'Project' column from the 'LC_MainOperIncome' table in the 'AStockFinanceDB' database. We also need to ensure that the data is from the adjusted consolidated financial statements, so we will include conditions for 'IfMerged' and 'IfAdjusted'. Additionally, we need to filter the data for the year 2021, so we will use the 'EndDate' column with the 'LIKE' operator. Finally, we will round the 'MainOperIncome' to 2 decimal places using the 'ROUND' function.",
  "sql_query": "SELECT Project, ROUND(MainOperIncome, 2) AS MainOperIncome FROM AStockFinanceDB.LC_MainOperIncome WHERE CompanyCode = '000958' AND IfMerged = 1 AND IfAdjusted = 1 AND EndDate LIKE '2021%'",
  "sql_explanation": "This SQL query retrieves the main business products and their corresponding income for company 000958 in 2021 from the 'LC_MainOperIncome'

# Generate Final Answers

In [193]:
src_fname = 'glm-answer_generator-0127.json'   # pipeline output with per-stage answers
saved_fname = 'submit-stage_1-250127.json'     # submission file to write
template_fname = 'submit_example.json'         # submission template (tid / team layout)

src_fpath, saved_fpath, template_fpath = (
    os.path.join(subdir, name)
    for subdir, name in (
        ('answer_tmp', src_fname),
        ('answer', saved_fname),
        ('data', template_fname),
    )
)

# Source order is kept as-is: answers are matched to template entries by tid.
raw_data, answers = (parse_data.read_json(path) for path in (src_fpath, template_fpath))

In [199]:
# Per-stage answers stored for the second record of raw_data.
sample_answer_stages = raw_data[1]['answer_generator']
sample_answer_stages

{'stage_1': '美亚光电在2021年的减持计划中，最大可减持股份数量与最小可减持股份数量的差距是0。',
 'stage_2': '美亚光电在2021年的减持计划中涉及了1名股东。',
 'stage_3': '美亚光电在2021年的减持计划中，股东张建军的最大减持比例最高，为0.007027%。'}

In [201]:
# Index template entries by tid once, instead of rescanning the whole template
# for every turn of every question. A list per tid keeps every matching entry.
template_by_tid = {}
for entry in answers:
    template_by_tid.setdefault(entry['tid'], []).append(entry)

for item in raw_data:
    tid = item['tid']
    try:
        for num in range(len(item['team'])):
            answer = item['answer_generator'][f'stage_{num+1}']
            for entry in template_by_tid.get(tid, []):
                entry['team'][num]['answer'] = answer
    except (KeyError, IndexError, TypeError) as e:
        # Incomplete pipeline output (e.g. a missing stage) is reported and skipped;
        # turns already filled before the error are kept.
        print(f'{tid}: {type(e).__name__}: {e}')

parse_data.write_json(answers, saved_fpath)

tttt----78
tttt----84
