# Init

In [117]:
import sys
import os
import json
import time
import re
import copy

import pandas as pd
from tqdm import tqdm
import numpy as np
from sentence_transformers import SentenceTransformer
from sentence_transformers import util
import ast
from openai import OpenAI

# Directory the notebook is running in; later cells build their data/prompt/answer paths from it.
cwd = os.getcwd()
# NOTE(review): this changes into the directory we are already in, so it looks like a no-op — confirm before removing.
os.chdir(cwd)
# Put the relative 'tools' directory on the import path (presumably where chat / parse_data / sql live).
sys.path.append('tools')

import chat
import parse_data
import sql

In [2]:
# Read the raw question set and order it by the numeric suffix of each `tid`.
question_path = os.path.join(cwd, 'data', 'question-b.json')

questions = sorted(
    parse_data.read_json(question_path),
    key=lambda item: int(item['tid'].rsplit('-', 1)[-1]),
)

# NER

## TOOL

In [3]:
def sub_select_content(query: str) -> str:
    """Rewrite single-column ``SELECT <col> FROM`` clauses to ``SELECT * FROM``.

    Only a column list that consists of exactly one bare identifier (word
    characters only) directly between ``SELECT `` and `` FROM`` is replaced.
    Multi-column lists, aliases and expressions are left untouched. Every
    matching occurrence in ``query`` is rewritten (including sub-queries), and
    matching is case-sensitive.

    Args:
        query: SQL text to rewrite.

    Returns:
        The SQL text with each matching column name replaced by ``*``.
    """
    # Raw string: the previous non-raw literal relied on the invalid escape
    # sequence "\w", which Python warns about and plans to reject.
    reg_p = re.compile(r'(?<=SELECT )\w+(?= FROM)')
    res = re.sub(reg_p, '*', query)

    return res

## Craft Prompt

In [4]:
# Configuration for the NER stage: prompt version, task name and template locations.
system_prompt = ""

version = 'v2.2.1'
task = 'ner'

prompt_dir = os.path.join(cwd, 'prompt')
fname = f'{task}-stage_1-{version}.md'
ner_prompt_fpath = os.path.join(prompt_dir, fname)
ner_valid_prompt_fpath = os.path.join(prompt_dir, 'ner-validation-v1.0.0.md')

# Stage-1 extraction template, read in one go.
with open(ner_prompt_fpath, 'r') as f:
    ner_prompt_template = f.read()

# Validation template used when the extracted entities match nothing in the database.
with open(ner_valid_prompt_fpath, 'r') as f:
    ner_valid_prompt_template = f.read()

def make_ner_prompt(conversation_turn: dict) -> str:
    """Return the stage-1 NER prompt: the template followed by the turn's first question."""
    first_question = conversation_turn['team'][0]['question']
    return ner_prompt_template + first_question

def make_ner_valid_prompt(question: dict) -> str:
    """Build the NER-validation prompt for a question whose stage-1 result is attached.

    The validation template is followed by a JSON payload holding the stage-1
    entities and their SQL lookups. The raw question text is added to the
    payload only when the largest number of SQL lookups recorded for any
    single entity is exactly 3 (1 is assumed when there are no lookups).
    """
    raw_query = question['team'][0]['question']
    stage_1 = question['ner']['stage_1']
    payload = {
        "ner_result": stage_1['result'],
        "ner_sql_result": stage_1['sql'],
    }

    # The lookup count per entity decides whether the raw query is included.
    most_lookups = max(map(len, payload["ner_sql_result"].values()), default=1)
    if most_lookups == 3:
        payload = {"query": raw_query, **payload}

    return ner_valid_prompt_template + json.dumps(payload, ensure_ascii=False, indent=2)

## GLM

In [5]:
# Model identifier; this notebook uses it to name the saved answer files.
model = 'glm_4_plus'

### Test

In [19]:
# Smoke test: run the stage-1 NER prompt on one question and save the outcome.
test_idx = 51
test_question = questions[test_idx]
query = make_ner_prompt(test_question)

history = []

start_time = time.time()
message = chat.create_message(query, history=history, system_prompt=system_prompt, temperature=0.7, top_p=0.9, response_format='text')
end_time = time.time()

execution_time = end_time - start_time
usage = chat.get_token_usage(message, True)
content = chat.get_content(message, True)
history = chat.build_history(history, message=message)

# Fix: the saved record is built from the same question that produced `query`.
# Previously questions[0] was copied while questions[51] was sent to the model,
# so the file paired one question with another question's NER result.
test_record = test_question.copy()
test_record['ner_result'] = {}
# str.strip with a character set: removes the ```json fence characters from both ends.
test_record['ner_result']['stage_1'] = json.loads(content.strip('`json'))
test_record['token_usage'] = {}
test_record['token_usage']['ner-stage_1'] = usage
test_record['time_usage'] = {}
test_record['time_usage']['ner-stage_1'] = f"{execution_time:.2f}s"

saved_path = os.path.join(cwd, 'answer_tmp' + os.sep + f'stage_1-{model}-{task}-test-{version}.json')
parse_data.write_json([test_record], saved_path)

{'prompt_tokens': 924, 'completion_tokens': 105, 'total_tokens': 1029}
```json
{
    "reasoning_process_cot": "从问题中可以看出，查询的是在特定日期（2021-12-21）A股市场中创出月度新高的公司及其证券代码。虽然问题中没有直接提及具体的公司名称和代码，但可以推断出需要识别的实体类型包括上市公司名称和代码。由于问题中没有提供具体的公司名称和代码，无法直接识别出具体的实体，但可以明确需要识别的实体类型。",
    "result": []
}
```


### ALL

Bad cases:

27:28 博时基金公司成立于？用XXXX年XX月XX日回复我

40:41 嘉实致元42个月定期债券基金的管理经理是谁？

52:53 南方亨元债券A在2019年的分红次数是多少？每次分红的派现比例是多少？

63:64 博时基金公司成立于？用XXXX年XX月XX日回复我

78:79 Huazhu Group Ltd.这家公司在美股英文名称是什么？

In [21]:
# Run stage-1 NER for every question:
#   1. ask the model for entities (retrying while the reply is not valid JSON),
#   2. resolve the entities against the database via sql.process_ner_res,
#   3. if entities were found but none of their lookups returned rows, ask the
#      model for alternative SQL queries and keep the first one that returns data.
answers = []
ner_max_tried = 6

for question in tqdm(questions[:]):

    query = make_ner_prompt(question)

    # Shallow copy is enough here: only new top-level keys are assigned below.
    res = question.copy()
    res['ner'] = {}

    history = []
    tried = 0

    while tried < ner_max_tried:
        start_time = time.time()
        message = chat.create_message(query, history=history, system_prompt=system_prompt, temperature=0.1, top_p=1, response_format='text')
        end_time = time.time()

        execution_time = end_time - start_time
        usage = chat.get_token_usage(message, False)
        content = chat.get_content(message, False)

        try:
            # str.strip with a character set: removes the ```json fence characters from both ends.
            res['ner']['stage_1'] = json.loads(content.strip('`json'))
            break
        except Exception:  # not a bare `except:`, so Ctrl-C can still stop the run
            print(f"JSON parsing error: {query}")
            tried += 1

    if 'stage_1' not in res['ner']:
        # Previously this fell through to an opaque KeyError on res['ner']['stage_1'].
        raise RuntimeError(
            f"NER stage 1 returned no parseable JSON after {ner_max_tried} attempts "
            f"(tid={question.get('tid')})"
        )

    res['token_usage'] = {}
    res['token_usage']['ner-stage_1'] = usage
    res['time_usage'] = {}
    res['time_usage']['ner-stage_1'] = f"{execution_time:.2f}s"

    # obtain sql results
    res['ner']['stage_1'] = sql.process_ner_res(res['ner']['stage_1'])

    # check results
    ner = res['ner']['stage_1']
    ner_result = ner['result']

    # True when no SQL lookup attached to any entity returned rows.
    all_results_empty = not any(
        query_info.get('result')
        for queries in ner.get('sql', {}).values()
        for query_info in queries
    )

    # Entities were extracted but nothing matched in the database: try to repair.
    if all_results_empty and ner_result:
        tried = 0
        # Reset per question so a hit from an earlier question can never be reused.
        sql_res = None
        sql_query = None

        print('====NER Validation===')
        print(ner)

        while tried < ner_max_tried:
            query = make_ner_valid_prompt(res)
            history = []
            message = chat.create_message(query, history=history, system_prompt=system_prompt, temperature=0.7, top_p=0.9, response_format='text')
            content = chat.get_content(message, True)

            try:
                # JSON Format doesn't count a fail
                # NOTE(review): because of that, this loop is unbounded if the model
                # never returns valid JSON — consider a separate cap.
                tmp_res = json.loads(content.strip('`json'))
            except Exception:
                print("----JSON parsing error-----")
                continue

            # Keep only well-formed suggestions; anything else counts as a failed attempt
            # instead of crashing the whole run.
            suggestions = tmp_res if isinstance(tmp_res, list) else []
            sql_queries = [
                item['sql_query']
                for item in suggestions
                if isinstance(item, dict) and item.get('sql_query')
            ]

            # Reset per attempt: with no usable suggestion, sql_res must be falsy.
            sql_res = None
            for sql_query in sql_queries:
                sql_query = sub_select_content(sql_query)
                sql_res = sql.get_data_from_sql_query(sql_query)

                if sql_res:
                    break

            if sql_res:
                break
            tried += 1

        if sql_res:
            # NOTE(review): the repaired query is filed under the first entity key,
            # whichever entity it actually targets.
            kv = list(ner['sql'].keys())[0]
            res['ner']['stage_1']['sql'][kv].append({
                "query": sql_query,
                "result": sql_res
            })
        else:
            print('====FAILED===')
            print('====NER Validation===')

    answers.append(res)

saved_path = os.path.join(cwd, 'answer_tmp' + os.sep + f'stage_1-{model}-{task}-{version}.json')
parse_data.write_json(answers, saved_path)

 27%|██▋       | 27/100 [01:40<04:32,  3.73s/it]

====NER Validation===
{'reasoning_process_cot': "从问题中可以看出，'博时基金公司'是一个基金公司名称，因此需要识别为基金公司名称实体。问题中未提及其他实体，如上市公司名称、股票代码、基金名称或行业名称。", 'result': [{'基金公司名称': '博时基金公司'}], 'sql': {'基金公司名称:博时基金公司': [{'query': "SELECT * FROM InstitutionDB.LC_InstiArchive WHERE '博时基金公司' IN (ChiName, AbbrChiName, NameChiSpelling, EngName, AbbrEngName)", 'result': []}]}}
### **输出**

```json
[
  {
    "potential_reason_cot_thinking": "原始 NER 结果中的 '博时基金公司' 可能存在名称不完整或拼写错误，导致在数据库中未能匹配到相关记录。考虑可能的名称变体或简称。",
    "sql_query": "SELECT * FROM InstitutionDB.LC_InstiArchive WHERE '博时基金' IN (ChiName, AbbrChiName, NameChiSpelling, EngName, AbbrEngName)",
    "sql_explanation": "调整查询条件，将 '博时基金公司' 改为 '博时基金'，以涵盖可能的简称或变体名称，从而提高匹配成功率。"
  },
  {
    "potential_reason_cot_thinking": "数据库中可能存在不同的名称格式，如全称、简称或英文名称。尝试使用模糊匹配来增加匹配范围。",
    "sql_query": "SELECT * FROM InstitutionDB.LC_InstiArchive WHERE ChiName LIKE '%博时基金%' OR AbbrChiName LIKE '%博时基金%' OR NameChiSpelling LIKE '%博时基金%' OR EngName LIKE '%Boshi Fund%' OR AbbrEngName LIKE '%Boshi

 33%|███▎      | 33/100 [03:39<09:47,  8.77s/it]

====NER Validation===
{'reasoning_process_cot': "分析当前查询内容，'在线教育' 是一个实体。接着，分析其可能的属性：'在线教育' 可以是一个行业名称，也可以是一个概念。结合查询内容，问题是在讨论 '在线教育' 属于科技概念的哪个分支，这表明 '在线教育' 在这里是作为一个行业名称被提及。因此，'在线教育' 应被识别为 '行业名称'。问题中还提到了 '科技概念' 的英文名称，但这不属于需要识别的五大类实体之一，所以无需识别。", 'result': [{'行业名称': '在线教育'}], 'sql': {'行业名称:在线教育': [{'query': "SELECT FirstIndustryCode AS 一级行业代码, SecondIndustryCode AS 二级行业代码, ThirdIndustryCode AS 三级行业代码, FourthIndustryCode AS 四级行业代码, FirstIndustryName AS 一级行业名称, SecondIndustryName AS 二级行业名称, ThirdIndustryName AS 三级行业名称, FourthIndustryName AS 四级行业名称 FROM AStockIndustryDB.LC_ExgIndustry WHERE '在线教育' IN (FirstIndustryName, SecondIndustryName, ThirdIndustryName, FourthIndustryName)", 'result': []}]}}
```json
[
  {
    "potential_reason_cot_thinking": "由于原始 NER 结果 '在线教育' 在数据库查询中未返回任何结果，可能的原因是该行业名称在数据库中使用了不同的表述或分类。需要调整 NER 结果以匹配数据库中的实际行业名称。",
    "sql_query": "SELECT FirstIndustryCode AS 一级行业代码, SecondIndustryCode AS 二级行业代码, ThirdIndustryCode AS 三级行业代码, FourthIndustryCode AS 四级行业代码, FirstIndustryName

 40%|████      | 40/100 [04:17<04:07,  4.13s/it]

====NER Validation===
{'reasoning_process_cot': '从当前查询中，可以看出涉及到一个基金名称‘嘉实致元42个月定期债券基金’，问题询问的是该基金的管理经理是谁。‘嘉实致元42个月定期债券基金’是一个基金名称，而查询中并未提及其他实体，如上市公司名称、股票代码、基金公司名称或行业名称。', 'result': [{'基金名称': '嘉实致元42个月定期债券基金'}], 'sql': {'基金名称:嘉实致元42个月定期债券基金': [{'query': "SELECT * FROM PublicFundDB.MF_FundProdName WHERE DisclName = '嘉实致元42个月定期债券基金' LIMIT 1", 'result': []}, {'query': "SELECT * FROM ConstantDB.SecuMain WHERE '嘉实致元42个月定期债券基金' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling) LIMIT 1", 'result': []}]}}
```json
[
  {
    "potential_reason_cot_thinking": "原始 NER 结果 '嘉实致元42个月定期债券基金' 在数据库中未能匹配到相应实体，可能是因为名称存在细微差异或数据库中未收录该基金。考虑可能的名称变体或缩写。",
    "sql_query": "SELECT * FROM PublicFundDB.MF_FundProdName WHERE DisclName LIKE '%嘉实致元42个月%' LIMIT 1",
    "sql_explanation": "使用 LIKE 语句模糊匹配基金名称，增加匹配灵活性，尝试找到包含 '嘉实致元42个月' 的基金名称。"
  },
  {
    "potential_reason_cot_thinking": "基金名称可能存在简称或不同表述，尝试在更广泛的证券主表中查找相似名称。",
    "sql_query": "SELECT * FROM ConstantDB.SecuMain WHERE ChiName LIKE '%嘉实致元4

 58%|█████▊    | 58/100 [05:35<02:50,  4.07s/it]

====NER Validation===
{'reasoning_process_cot': '从问题中可以看出，询问的是在2022年期间进行公司名称全称变更的公司及其代码。这里涉及到两个实体类型：上市公司名称和代码。虽然具体公司名称和代码未在问题中明确提及，但问题的意图是查找这些信息。因此，我们需要识别出问题中隐含的实体类型，即上市公司名称和代码。', 'result': [{'上市公司名称': '公司名称全称变更的公司'}, {'代码': '公司代码'}], 'sql': {'上市公司名称:公司名称全称变更的公司': [{'query': "SELECT * FROM ConstantDB.SecuMain WHERE '公司名称全称变更的公司' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.HK_SecuMain WHERE '公司名称全称变更的公司' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.US_SecuMain WHERE '公司名称全称变更的公司' IN (SecuCode, SecuAbbr, ChiSpelling, EngName, ChiName)", 'result': []}], '代码:公司代码': [{'query': None, 'result': None}]}}
```json
[
  {
    "potential_reason_cot_thinking": "原始 NER 结果中的 '公司名称全称变更的公司' 是一个模糊的描述，无法直接用于数据库查询。需要将其转换为具体的公司名称或相关字段。",
    "sql_query": "SELECT SecuCode, ChiName, ChiNameAbbr FROM ConstantDB.SecuMain WHERE SecuCode IN (SELECT SecuCode

 59%|█████▉    | 59/100 [07:19<23:16, 34.05s/it]

Request failed: 500 Server Error: Internal Server Error for url: https://comm.chatglm.cn/finglm2/api/query
====FAILED===
====NER Validation===


 63%|██████▎   | 63/100 [07:33<06:33, 10.63s/it]

====NER Validation===
{'reasoning_process_cot': "从问题中可以看出，'博时基金公司'是一个基金公司名称，因此需要识别为基金公司名称实体。问题中未提及其他实体，如上市公司名称、股票代码、基金名称或行业名称。", 'result': [{'基金公司名称': '博时基金公司'}], 'sql': {'基金公司名称:博时基金公司': [{'query': "SELECT * FROM InstitutionDB.LC_InstiArchive WHERE '博时基金公司' IN (ChiName, AbbrChiName, NameChiSpelling, EngName, AbbrEngName)", 'result': []}]}}
### **输出**

```json
[
  {
    "potential_reason_cot_thinking": "NER 结果中的 '博时基金公司' 可能存在名称不完整或简称未被识别的情况。在金融数据库中，机构名称可能以全称、简称、拼音或英文名称存在。因此，需要检查所有可能的名称形式。",
    "sql_query": "SELECT * FROM InstitutionDB.LC_InstiArchive WHERE '博时基金公司' IN (ChiName, AbbrChiName, NameChiSpelling, EngName, AbbrEngName) OR '博时' IN (ChiName, AbbrChiName, NameChiSpelling, EngName, AbbrEngName)",
    "sql_explanation": "此查询扩展了原始查询，增加了对 '博时' 这一简称的搜索，以覆盖更多可能的名称匹配情况。这样可以提高找到正确机构信息的概率。"
  },
  {
    "potential_reason_cot_thinking": "NER 结果可能存在名称拼写错误或格式不一致的情况。在金融数据库中，机构名称可能存在多种拼写或格式变体。因此，需要使用模糊匹配来提高检索准确性。",
    "sql_query": "SELECT * FROM InstitutionDB.LC_InstiArchive WHERE ChiName L

 64%|██████▍   | 64/100 [09:06<21:15, 35.42s/it]

```json
[
  {
    "potential_reason_cot_thinking": "原始 NER 结果中的 '博时基金公司' 在数据库中未找到匹配项，可能是因为名称存在差异或数据库中未收录该实体。需要调整名称以匹配数据库中的实体。",
    "sql_query": "SELECT * FROM InstitutionDB.LC_InstiArchive WHERE '博时基金管理有限公司' IN (ChiName, AbbrChiName, NameChiSpelling, EngName, AbbrEngName)",
    "sql_explanation": "将 '博时基金公司' 调整为 '博时基金管理有限公司'，因为数据库中可能使用全称。这样可以提高查询匹配的概率。"
  },
  {
    "potential_reason_cot_thinking": "原始 NER 结果中的 '博时基金公司' 可能存在简称或别称，数据库中可能使用不同的名称形式。",
    "sql_query": "SELECT * FROM InstitutionDB.LC_InstiArchive WHERE '博时基金' IN (ChiName, AbbrChiName, NameChiSpelling, EngName, AbbrEngName)",
    "sql_explanation": "尝试使用简称 '博时基金' 进行查询，以覆盖数据库中可能存在的简称或别称形式。"
  },
  {
    "potential_reason_cot_thinking": "原始 NER 结果中的 '博时基金公司' 可能存在拼音或英文简称，数据库中可能使用这些形式。",
    "sql_query": "SELECT * FROM InstitutionDB.LC_InstiArchive WHERE 'Boshi Fund' IN (ChiName, AbbrChiName, NameChiSpelling, EngName, AbbrEngName)",
    "sql_explanation": "尝试使用英文简称 'Boshi Fund' 进行查询，以覆盖数据库中可能存在的英文或拼音形式。"
  }
]
```


 78%|███████▊  | 78/100 [09:57<01:27,  3.99s/it]

====NER Validation===
{'reasoning_process_cot': "从问题中可以看出，'Huazhu Group Ltd.'是一个上市公司名称，因为它提到了'公司'并且询问的是该公司的美股英文名称。因此，需要将'Huazhu Group Ltd.'识别为上市公司名称实体。", 'result': [{'上市公司名称': 'Huazhu Group Ltd.'}], 'sql': {'上市公司名称:Huazhu Group Ltd.': [{'query': "SELECT * FROM ConstantDB.SecuMain WHERE 'Huazhu Group Ltd.' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.HK_SecuMain WHERE 'Huazhu Group Ltd.' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.US_SecuMain WHERE 'Huazhu Group Ltd.' IN (SecuCode, SecuAbbr, ChiSpelling, EngName, ChiName)", 'result': []}]}}
```json
[
  {
    "potential_reason_cot_thinking": "原始 NER 结果中的 'Huazhu Group Ltd.' 可能是中文名称或简称，而数据库中可能存储的是其英文名称或简称。因此，需要尝试不同的字段组合来匹配该实体。",
    "sql_query": "SELECT EngName FROM ConstantDB.US_SecuMain WHERE ChiName = 'Huazhu Group Ltd.' OR ChiNameAbbr = 'Huazhu Group Ltd.' OR EngNameAbbr = 'Huaz

100%|██████████| 100/100 [12:18<00:00,  7.38s/it]


## Check Results

In [25]:
# Re-open the saved stage-1 answers (the "-HF" file) and print every question
# whose extracted entities did not match anything in the database.
answer_dir = os.path.join(cwd, 'answer_tmp')
fname = f'stage_1-{model}-{task}-{version}-HF.json'
fpath = os.path.join(answer_dir, fname)

data = parse_data.read_json(fpath)

for i in data[:]:
    ner = i['ner']['stage_1']
    ner_result = ner['result']

    # True when no SQL lookup attached to any entity returned rows.
    all_results_empty = not any(
        query_info.get('result')
        for queries in ner.get('sql', {}).values()
        for query_info in queries
    )

    # Only report questions that have entities but no database hit at all.
    if not (all_results_empty and ner_result):
        continue

    print(i['team'][0])
    print(ner)
    print()

{'id': 'tttt----59----15-3-1', 'question': '2022年之间 哪些公司进行公司名称全称变更，公司代码是什么？'}
{'reasoning_process_cot': '从问题中可以看出，询问的是在2022年期间进行公司名称全称变更的公司及其代码。这里涉及到两个实体类型：上市公司名称和代码。虽然具体公司名称和代码未在问题中明确提及，但问题的意图是查找这些信息。因此，我们需要识别出问题中隐含的实体类型，即上市公司名称和代码。', 'result': [{'上市公司名称': '公司名称全称变更的公司'}, {'代码': '公司代码'}], 'sql': {'上市公司名称:公司名称全称变更的公司': [{'query': "SELECT * FROM ConstantDB.SecuMain WHERE '公司名称全称变更的公司' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.HK_SecuMain WHERE '公司名称全称变更的公司' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.US_SecuMain WHERE '公司名称全称变更的公司' IN (SecuCode, SecuAbbr, ChiSpelling, EngName, ChiName)", 'result': []}], '代码:公司代码': [{'query': None, 'result': None}]}}



# Market

In [78]:
# Continue from the saved stage-1 NER answers, ordered by the numeric suffix of each `tid`.
question_path = os.path.join(cwd, 'answer_tmp', 'stage_1-glm_4_plus-ner-v2.2.1-HF.json')

questions = sorted(
    parse_data.read_json(question_path),
    key=lambda item: int(item['tid'].rsplit('-', 1)[-1]),
)

## Tool

In [79]:
def parse_database_and_table(query: str) -> dict:
    """
    Extract the first ``database.table`` reference that follows a FROM keyword.

    Args:
    - query (str): SQL text to inspect.

    Returns:
    - dict: ``{'database': ..., 'table': ...}`` for the first match (matching
      is case-insensitive), or an empty dict when the query has no such reference.
    """
    found = re.search(r'FROM\s+([a-zA-Z0-9_]+)\.([a-zA-Z0-9_]+)', query, re.IGNORECASE)
    if found is None:
        return {}

    database, table = found.groups()
    return {'database': database, 'table': table}

## Craft Prompt

In [119]:
# Configuration for the market-classification stage: version, task name and template location.
system_prompt = ""

version = 'v1.0.0'
task = 'market_classifier'

prompt_dir = os.path.join(cwd, 'prompt')
market_fname = 'market_classifier-v1.0.0.md'
fpath = os.path.join(prompt_dir, market_fname)

# Read the whole template in one go.
with open(fpath, 'r') as f:
    market_prompt_template = f.read()

def make_prompt_market(data: dict, template: str = None) -> str:
    """Fill the market-classifier template for one question.

    ``<Current Query>`` is replaced by the question text. When stage-1 NER found
    entities, ``<给出的信息>`` is replaced by the JSON dump of their SQL lookups;
    otherwise the placeholder line and its ``## **给出的信息**`` heading are removed.

    Substitution uses ``str.replace`` so the inserted text is taken verbatim.
    The previous ``re.sub`` calls treated the replacement as a regex template,
    so backslashes in the question or in the JSON (for example an escaped
    newline) were either rewritten or raised ``re.error``.

    Args:
        data: Question record with ``team[0]['question']`` and ``ner['stage_1']``
            (``result`` and, when ``result`` is non-empty, ``sql``).
        template: Optional template text; defaults to ``market_prompt_template``.

    Returns:
        The filled-in prompt.
    """
    prompt = market_prompt_template if template is None else template

    # replace query
    query = data['team'][0]['question']
    prompt = prompt.replace('<Current Query>', query)

    # NER Result
    stage_1 = data['ner']['stage_1']
    if stage_1['result']:
        ner_res = json.dumps(stage_1['sql'], ensure_ascii=False, indent=2)
        prompt = prompt.replace('<给出的信息>', ner_res)
    else:
        # No entities: drop the placeholder line and its heading.
        prompt = prompt.replace('\n<给出的信息>\n', '')
        prompt = prompt.replace('\n## **给出的信息**\n', '')

    return prompt

In [123]:
# Dry run: build the market prompt for every question and report the ones that fail.
for i in questions:
    try:
        tmp = make_prompt_market(i)
    except Exception as exc:
        # Report why it failed as well as which question; the former bare
        # `except:` hid the reason and also swallowed Ctrl-C.
        print(f"{type(exc).__name__}: {exc}")
        print(i)

## GLM

In [127]:
# Decide the market (US / HK / CN) for every question:
#   - exactly one NER lookup returned rows -> derive the market from that lookup's table,
#   - otherwise -> ask the model with the market-classifier prompt.
answers = []
ner_max_tried = 6

for question in tqdm(questions[:]):

    query = make_prompt_market(question)

    ner_sql_result = question['ner']['stage_1']['sql']

    # Deep copy so the usage bookkeeping written below never leaks back into `questions`
    # (a shallow copy shares the nested token_usage / time_usage dicts).
    res = copy.deepcopy(question)
    res['market'] = {}

    # Reset per question. The previous locals().get(...) lookups reused the values
    # left behind by an earlier question when this one made no model call.
    usage = 0
    execution_time = 0

    # Every lookup, across all entities, that returned at least one row.
    hits = [
        lookup
        for lookups in ner_sql_result.values()
        for lookup in lookups
        if lookup.get('result')
    ]

    # if only one result, use that as market flag, others, run llm to judge
    if len(hits) == 1:
        # .get: a query without a parsable "db.table" falls into the CN default
        # instead of raising KeyError.
        table = parse_database_and_table(hits[0].get('query') or '').get('table')

        if table == 'US_SecuMain':
            market = "US"
        elif table == 'HK_SecuMain':
            market = "HK"
        else:
            market = "CN"

        res['market'] = {"query": question['team'][0]['question'],
                         "cot_thinking": None, "market": market}
    else:
        print("---NOT ONE---")
        print(ner_sql_result)
        tried = 0

        while tried < ner_max_tried:
            history = []
            start_time = time.time()
            try:
                message = chat.create_message(query, history=history, system_prompt=system_prompt, temperature=0.1, top_p=1, response_format='text')
            except Exception:
                print('API ERROR')
                tried += 1
                # Fix: retry. Falling through used a stale `message` from a previous
                # call, attaching another question's answer to this one.
                continue
            end_time = time.time()

            execution_time = f"{end_time - start_time:.2f}s"
            usage = chat.get_token_usage(message, False)
            content = chat.get_content(message, False)

            try:
                # str.strip with a character set: removes the ```json fence characters from both ends.
                res['market'] = json.loads(content.strip('`json'))
                break
            except Exception:
                print(f"JSON parsing error: {query}")
                tried += 1

    # setdefault: tolerate records that do not carry usage bookkeeping yet.
    res.setdefault('token_usage', {})['market'] = usage
    res.setdefault('time_usage', {})['market'] = execution_time

    answers.append(res)

saved_path = os.path.join(cwd, 'answer_tmp' + os.sep + f'{model}-{task}-{version}.json')
parse_data.write_json(answers, saved_path)

  0%|          | 0/100 [00:00<?, ?it/s]

---NOT ONE---
{}


  2%|▏         | 2/100 [00:06<05:04,  3.10s/it]

---NOT ONE---
{}


  7%|▋         | 7/100 [00:08<01:42,  1.10s/it]

---NOT ONE---
{}


  8%|▊         | 8/100 [00:13<02:32,  1.66s/it]

---NOT ONE---
{}


  9%|▉         | 9/100 [00:17<03:18,  2.18s/it]

---NOT ONE---
{'上市公司名称:中国人寿': [{'query': "SELECT * FROM ConstantDB.SecuMain WHERE '中国人寿' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.HK_SecuMain WHERE '中国人寿' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': [{'ID': 16808232845300, 'InnerCode': 1000872, 'CompanyCode': 4057, 'SecuCode': '02628', 'ChiName': '中国人寿保险股份有限公司', 'ChiNameAbbr': None, 'EngName': 'China Life Insurance Company Limited', 'EngNameAbbr': 'CHINA LIFE', 'SecuAbbr': '中国人寿', 'ChiSpelling': 'ZGRS', 'SecuMarket': 72, 'SecuCategory': 3, 'ListedDate': '2003-12-18 12:00:00.000', 'ListedSector': 1, 'ListedState': 1, 'XGRQ': '2018-11-29 11:24:25.720', 'JSID': 596805865723, 'DelistingDate': None, 'ISIN': 'CNE1000002L3', 'FormerName': None, 'TradingUnit': 1000.0, 'TraCurrUnit': 1100, 'InsertTime': '2005-10-12 02:23:57.983'}]}, {'query': "SELECT * FROM ConstantDB.US_SecuMain WHERE '中国人寿' IN (SecuCode, SecuAbbr, ChiSpell

 16%|█▌        | 16/100 [00:22<01:39,  1.18s/it]

---NOT ONE---
{'上市公司名称:大唐国际发电股份有限公司': [{'query': "SELECT * FROM ConstantDB.SecuMain WHERE '大唐国际发电股份有限公司' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': [{'ID': 217614716484, 'InnerCode': 4874, 'CompanyCode': 3848, 'SecuCode': '601991', 'ChiName': '大唐国际发电股份有限公司', 'ChiNameAbbr': '大唐发电', 'EngName': 'Datang International Power Generation Co., Ltd.', 'EngNameAbbr': 'Datang Power', 'SecuAbbr': '大唐发电', 'ChiSpelling': 'DTFD', 'SecuMarket': 83, 'SecuCategory': 1, 'ListedDate': '2006-12-20 12:00:00.000', 'ListedSector': 1, 'ListedState': 1, 'XGRQ': '2017-03-16 04:30:01.650', 'JSID': 542953801666, 'ISIN': 'CNE000001Q02', 'ExtendedAbbr': None, 'ExtendedSpelling': None}]}, {'query': "SELECT * FROM ConstantDB.HK_SecuMain WHERE '大唐国际发电股份有限公司' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.US_SecuMain WHERE '大唐国际发电股份有限公司' IN (SecuCode, SecuAbbr, ChiSpelling, EngName, ChiName)", 'result': [{

 20%|██        | 20/100 [00:29<01:56,  1.45s/it]

---NOT ONE---
{}


 21%|██        | 21/100 [00:33<02:17,  1.74s/it]

API ERROR
---NOT ONE---
{}


 22%|██▏       | 22/100 [00:39<02:53,  2.22s/it]

---NOT ONE---
{'上市公司名称:中国联通': [{'query': "SELECT * FROM ConstantDB.SecuMain WHERE '中国联通' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.HK_SecuMain WHERE '中国联通' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': [{'ID': 16808232655600, 'InnerCode': 1000593, 'CompanyCode': 1000593, 'SecuCode': '00762', 'ChiName': '中国联合网络通信(香港)股份有限公司', 'ChiNameAbbr': None, 'EngName': 'China Unicom (Hong Kong) Limited', 'EngNameAbbr': 'CHINA UNICOM', 'SecuAbbr': '中国联通', 'ChiSpelling': 'ZGLT', 'SecuMarket': 72, 'SecuCategory': 53, 'ListedDate': '2000-06-22 12:00:00.000', 'ListedSector': 1, 'ListedState': 1, 'XGRQ': '2019-12-12 01:03:05.373', 'JSID': 629427788469, 'DelistingDate': None, 'ISIN': 'HK0000049939', 'FormerName': None, 'TradingUnit': 2000.0, 'TraCurrUnit': 1100, 'InsertTime': '2005-10-12 02:23:28.437'}]}, {'query': "SELECT * FROM ConstantDB.US_SecuMain WHERE '中国联通' IN (SecuCode, SecuAbbr, 

 33%|███▎      | 33/100 [00:44<01:11,  1.06s/it]

---NOT ONE---
{}


 34%|███▍      | 34/100 [00:49<01:26,  1.32s/it]

---NOT ONE---
{}


 39%|███▉      | 39/100 [00:52<01:04,  1.06s/it]

---NOT ONE---
{}


 40%|████      | 40/100 [00:55<01:17,  1.29s/it]

---NOT ONE---
{}


 44%|████▍     | 44/100 [01:01<01:14,  1.33s/it]

---NOT ONE---
{}


 46%|████▌     | 46/100 [01:05<01:19,  1.47s/it]

---NOT ONE---
{'基金名称:嘉实超短债债券A': [{'query': "SELECT * FROM PublicFundDB.MF_FundProdName WHERE DisclName = '嘉实超短债债券A' LIMIT 1", 'result': [{'ID': 677756078394, 'InnerCode': 381528, 'InfoPublDate': '2021-06-23 12:00:00.000', 'InfoSource': '产品资料概要', 'InfoType': 6, 'DisclName': '嘉实超短债债券A', 'EffectiveDate': '2021-06-23 12:00:00.000', 'ExpiryDate': None, 'IfEffected': 1, 'Remark': None, 'UpdateTime': '2022-06-24 03:33:10.413', 'JSID': 709366159233, 'ChiSpelling': 'JSCDZZQA', 'TransCode': 381528, 'InsertTime': '2021-06-23 09:34:16.393'}]}, {'query': "SELECT * FROM ConstantDB.SecuMain WHERE '嘉实超短债债券A' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling) LIMIT 1", 'result': [{'ID': 677756000000, 'InnerCode': 381528, 'CompanyCode': 643820, 'SecuCode': '12773', 'ChiName': '嘉实超短债证券投资基金A类', 'ChiNameAbbr': '嘉实超短债A', 'EngName': 'Harvest Ultra Short-term Bond Fund-A', 'EngNameAbbr': None, 'SecuAbbr': '嘉实超短债债券A', 'ChiSpelling': 'JSCDZZQA', 'SecuMarket': None, 'SecuCategory': 8, 'Listed

 47%|████▋     | 47/100 [01:11<01:46,  2.01s/it]

---NOT ONE---
{}


 52%|█████▏    | 52/100 [01:15<01:09,  1.46s/it]

---NOT ONE---
{'上市公司名称:中兴通讯': [{'query': "SELECT * FROM ConstantDB.SecuMain WHERE '中兴通讯' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.HK_SecuMain WHERE '中兴通讯' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': [{'ID': 16808232656300, 'InnerCode': 1000594, 'CompanyCode': 79, 'SecuCode': '00763', 'ChiName': '中兴通讯股份有限公司', 'ChiNameAbbr': None, 'EngName': 'ZTE Corporation', 'EngNameAbbr': 'ZTE', 'SecuAbbr': '中兴通讯', 'ChiSpelling': 'ZXTX', 'SecuMarket': 72, 'SecuCategory': 3, 'ListedDate': '2004-12-09 12:00:00.000', 'ListedSector': 1, 'ListedState': 1, 'XGRQ': '2018-06-19 05:01:54.580', 'JSID': 582742914581, 'DelistingDate': None, 'ISIN': 'CNE1000004Y2', 'FormerName': None, 'TradingUnit': 200.0, 'TraCurrUnit': 1100, 'InsertTime': '2005-10-12 02:23:28.513'}]}, {'query': "SELECT * FROM ConstantDB.US_SecuMain WHERE '中兴通讯' IN (SecuCode, SecuAbbr, ChiSpelling, EngName, ChiName)", 'result'

 54%|█████▍    | 54/100 [01:18<01:07,  1.48s/it]

---NOT ONE---
{}


 57%|█████▋    | 57/100 [01:21<00:57,  1.33s/it]

---NOT ONE---
{'上市公司名称:中国长城': [{'query': "SELECT * FROM ConstantDB.SecuMain WHERE '中国长城' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': [{'ID': 316488637204, 'InnerCode': 107, 'CompanyCode': 81, 'SecuCode': '000066', 'ChiName': '中国长城科技集团股份有限公司', 'ChiNameAbbr': '中国长城', 'EngName': 'China Greatwall Technology Group Co., Ltd.', 'EngNameAbbr': 'CGT Group', 'SecuAbbr': '中国长城', 'ChiSpelling': 'ZGCC', 'SecuMarket': 90, 'SecuCategory': 1, 'ListedDate': '1997-06-26 12:00:00.000', 'ListedSector': 1, 'ListedState': 1, 'XGRQ': '2017-04-28 08:25:01.723', 'JSID': 546726301724, 'ISIN': 'CNE000000RL7', 'ExtendedAbbr': None, 'ExtendedSpelling': None}]}, {'query': "SELECT * FROM ConstantDB.HK_SecuMain WHERE '中国长城' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.US_SecuMain WHERE '中国长城' IN (SecuCode, SecuAbbr, ChiSpelling, EngName, ChiName)", 'result': []}], '代码:000066': [{'query': 'SELECT * FRO

 58%|█████▊    | 58/100 [01:27<01:20,  1.93s/it]

---NOT ONE---
{'上市公司名称:公司名称全称变更的公司': [{'query': "SELECT * FROM ConstantDB.SecuMain WHERE '公司名称全称变更的公司' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.HK_SecuMain WHERE '公司名称全称变更的公司' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)", 'result': []}, {'query': "SELECT * FROM ConstantDB.US_SecuMain WHERE '公司名称全称变更的公司' IN (SecuCode, SecuAbbr, ChiSpelling, EngName, ChiName)", 'result': []}], '代码:公司代码': [{'query': None, 'result': None}]}


 59%|█████▉    | 59/100 [01:31<01:33,  2.28s/it]

---NOT ONE---
{}


 61%|██████    | 61/100 [01:36<01:30,  2.33s/it]

---NOT ONE---
{}


 62%|██████▏   | 62/100 [01:40<01:41,  2.66s/it]

---NOT ONE---
{}


 71%|███████   | 71/100 [01:44<00:31,  1.07s/it]

---NOT ONE---
{}


 73%|███████▎  | 73/100 [01:49<00:37,  1.37s/it]

---NOT ONE---
{}


 75%|███████▌  | 75/100 [01:55<00:42,  1.68s/it]

---NOT ONE---
{}


 84%|████████▍ | 84/100 [01:59<00:15,  1.02it/s]

---NOT ONE---
{}


 85%|████████▌ | 85/100 [02:06<00:22,  1.50s/it]

---NOT ONE---
{}


 96%|█████████▌| 96/100 [02:10<00:03,  1.21it/s]

---NOT ONE---
{}


 97%|█████████▋| 97/100 [02:13<00:02,  1.03it/s]

---NOT ONE---
{}


100%|██████████| 100/100 [02:16<00:00,  1.37s/it]


# Column Embedding

## Init

In [253]:
# Read the embedding API key from the environment rather than hardcoding it.
# NOTE: the key that used to be written here must be treated as leaked and rotated.
llm_api = os.environ.get('ZHIPUAI_API_KEY')
if not llm_api:
    raise RuntimeError("Set the ZHIPUAI_API_KEY environment variable before running this cell.")

client = OpenAI(api_key=llm_api, base_url="https://open.bigmodel.cn/api/paas/v4/")

def compute_emb(text: str):
    """Return the embedding of ``text`` from the "embedding-3" model.

    Sends a single-item request through ``client`` and returns the embedding
    of that item (a list of floats).
    """
    response = client.embeddings.create(
        model="embedding-3",  # identifier of the embedding model to call
        input=[text],
    )
    return response.data[0].embedding

In [None]:
import json
import pandas as pd
from sentence_transformers import SentenceTransformer
import numpy as np
import ast

# NOTE(review): this cell repeats the client / compute_emb setup of the previous cell.
# Read the embedding API key from the environment rather than hardcoding it.
# NOTE: the key that used to be written here must be treated as leaked and rotated.
llm_api = os.environ.get('ZHIPUAI_API_KEY')
if not llm_api:
    raise RuntimeError("Set the ZHIPUAI_API_KEY environment variable before running this cell.")

client = OpenAI(api_key=llm_api, base_url="https://open.bigmodel.cn/api/paas/v4/")

# Alternative local embedding backend, currently unused:
# emb_model = SentenceTransformer('moka-ai/m3e-base')

def compute_emb(text: str):
    """Return the embedding of ``text`` from the "embedding-3" model.

    Sends a single-item request through ``client`` and returns the embedding
    of that item (a list of floats).
    """
    response = client.embeddings.create(
        model="embedding-3",  # identifier of the embedding model to call
        input=[text],
    )
    return response.data[0].embedding

# Column-level data dictionary (one row per table column).
data_dict_df = pd.read_csv('data/data_dictionary.csv')

# `with` closes the file handle (json.load(open(...)) left it open); the explicit
# encoding keeps the non-ASCII descriptions readable regardless of the platform default.
with open('data/database-table/database-with_description.json', encoding='utf-8') as f:
    db_description_json = json.load(f)

# Table metadata keyed by the English table name.
table_info_fields = (
    'database_name_zh',
    'database_name_en',
    'table_name_en',
    'table_name_zh',
    'table_description',
)
table_info_dict = {
    db['table_name_en']: {field: db[field] for field in table_info_fields}
    for db in db_description_json
}


from tqdm import tqdm


columns_info = []

# The two table-level embeddings depend only on the table, so request them once
# per table instead of once per column (every compute_emb call is an API request).
_table_emb_cache = {}

for _, row in tqdm(data_dict_df.iterrows(), total=data_dict_df.shape[0]):
    table_name_en = row['table_name']
    column_name_en = row['column_name']
    table_info = table_info_dict[table_name_en]
    # Non-string cells (e.g. missing values) are replaced by an empty string.
    column_description = row['column_description'] if isinstance(row['column_description'], str) else ''
    annotation = row['注释'] if isinstance(row['注释'], str) else ''
    table_description = table_info['table_description']
    table_name_zh = table_info['table_name_zh']

    if table_name_en not in _table_emb_cache:
        _table_emb_cache[table_name_en] = (
            compute_emb(table_name_zh),
            compute_emb(table_description),
        )
    table_name_zh_emb, table_desc_emb = _table_emb_cache[table_name_en]

    columns_info.append({
        'db_name': table_info['database_name_en'],
        'table_name_en': table_name_en,
        'column_name': column_name_en,
        'column_description': column_description,
        'table_name_zh': table_name_zh,
        'table_name_zh_emb': table_name_zh_emb,
        'table_desc_emb': table_desc_emb,
        'col_emb': compute_emb(column_description + ' ' + annotation),
        'all_emb': compute_emb(table_name_zh + ': ' + table_description + ' ' + column_description + ' ' + annotation),
    })

##
# 保存embedding为文件
##

def convert_embeddings(obj):
    """Recursively replace numpy arrays inside ``obj`` with plain lists.

    Dicts and lists are rebuilt with their values/items converted; any other
    object is returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        converted = {}
        for key, value in obj.items():
            converted[key] = convert_embeddings(value)
        return converted
    if isinstance(obj, list):
        return list(map(convert_embeddings, obj))
    return obj

# Replace numpy arrays with plain lists so the structure can be dumped as JSON
columns_info_serializable = convert_embeddings(columns_info)

with open('data/database-table/columns_emb.json', mode='w', encoding='utf-8') as out_file:
    json.dump(columns_info_serializable, out_file, indent=4, ensure_ascii=False)

##
# 读取保存的文件
##

def parse_embedding(embedding_str):
    """
    Convert an embedding stored as a string in JSON into a numpy array.
    The embedding is assumed to be stored as a string of the form "[0.1, 0.2, ...]".
    """
    parsed = ast.literal_eval(embedding_str)
    return np.array(parsed)

 25%|██▌       | 888/3489 [10:56<36:12,  1.20it/s]  

In [4]:
# Local sentence-embedding model (moka-ai/m3e-base).
# NOTE(review): find_top_similar_columns receives this model as `emb_model` but embeds
# the query with compute_emb instead — confirm whether loading it is still needed.
emb_model = SentenceTransformer('moka-ai/m3e-base')

# Read a JSON file and convert it back into Python/numpy structures
def load_json_to_dict(file_path):
    """Load ``file_path`` as JSON and turn every all-numeric list into a numpy array.

    Dicts and non-numeric lists are rebuilt recursively; other values are kept.
    """

    def _restore(node):
        # Recursively rebuild the structure; purely numeric lists become arrays.
        if isinstance(node, dict):
            return {k: _restore(v) for k, v in node.items()}
        if isinstance(node, list):
            if all(isinstance(item, (int, float)) for item in node):
                return np.array(node)
            return [_restore(item) for item in node]
        return node

    with open(file_path, 'r', encoding='utf-8') as fh:
        raw = json.load(fh)
    return _restore(raw)

# Load the saved column embeddings from the JSON file
columns_info = load_json_to_dict('data/database-table/columns_emb.json')

##
# 计算 embedding
##

def cosine_similarity(vec_a, vec_b):
    """Return the cosine similarity of two vectors.

    Returns 0.0 when either vector has zero norm, instead of the NaN (and numpy
    warning) that a 0/0 division would produce; NaN scores would make the
    similarity rankings built on this function unreliable to sort.
    """
    denom = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
    if denom == 0:
        return 0.0
    return np.dot(vec_a, vec_b) / denom

def find_top_similar_columns(word, columns_info, emb_model, top_k=20):
    """
    Rank tables and columns by cosine similarity to the given query text.

    Args:
    - word (str): The query text; it is embedded with `compute_emb`.
    - columns_info (list): list of dicts with column details and the precomputed
      embeddings `table_name_zh_emb`, `table_desc_emb`, `col_emb` and `all_emb`.
    - emb_model: Not used by this function; kept so existing callers keep working.
    - top_k (int): Number of entries kept in each ranking.

    Returns:
    - A 4-tuple of lists, each sorted by descending score and cut to `top_k`:
      1. tables: (score, db_name, table_name_en, table_name_zh), one entry per
         distinct table_name_en, scored 0.5 * name + 0.5 * description similarity
      2. columns by `col_emb`: (score, db_name, table_name_en, column_name, column_description)
      3. columns by `all_emb`: same 5-tuple layout
      4. columns by 0.3 * table score + 0.7 * column score: same 5-tuple layout
    """
    # Encode the input word
    word_embedding = compute_emb(word)

    def similarity_to(embedding):
        # A missing embedding contributes a similarity of 0.
        return cosine_similarity(word_embedding, embedding) if embedding is not None else 0

    table_similarities = []
    seen_tables = set()  # constant-time "already listed" check instead of rescanning the list
    column_similarities = []
    scores1 = []
    scores2 = []

    for col_info in columns_info:
        column_name = col_info['column_name']
        column_description = col_info.get('column_description', '')
        db_name = col_info.get('db_name', '')
        table_name_en = col_info.get('table_name_en', '')
        table_name_zh = col_info.get('table_name_zh', '')

        table_name_similarity = similarity_to(col_info.get('table_name_zh_emb'))
        table_desc_similarity = similarity_to(col_info.get('table_desc_emb'))
        col_similarity = similarity_to(col_info.get('col_emb'))
        all_similarity = similarity_to(col_info.get('all_emb'))

        # Table-only score; each table is listed once, from its first column row.
        table_similarity = 0.5 * table_name_similarity + 0.5 * table_desc_similarity
        if table_name_en not in seen_tables:
            seen_tables.add(table_name_en)
            table_similarities.append((table_similarity, db_name, table_name_en, table_name_zh))

        column_key = (db_name, table_name_en, column_name, column_description)

        # Column-only score
        column_similarities.append((col_similarity, *column_key))

        # Score 1: similarity against the combined table + column text embedding
        scores1.append((all_similarity, *column_key))

        # Score 2: weighted mix of the table and column scores
        scores2.append((0.3 * table_similarity + 0.7 * col_similarity, *column_key))

    def top(entries):
        # Highest score first, truncated to top_k
        return sorted(entries, key=lambda x: x[0], reverse=True)[:top_k]

    return top(table_similarities), top(column_similarities), top(scores1), top(scores2)

## Test

In [223]:
# Load the question file from the answer_tmp directory
question_path = os.path.join(cwd, 'answer_tmp', 'glm_4_plus-market_classifier-v1.0.0.json')

# Order the questions by the numeric suffix of their tid
questions = sorted(
    parse_data.read_json(question_path),
    key=lambda item: int(item['tid'].rsplit('-', 1)[-1]),
)

In [11]:
# Use only the third question as the sample query
entities = [i['team'][0]['question'] for i in questions][2:3]

# Collects every "database.table" seen in any of the four rankings
available_tables = set()

for entity in entities:
    print(f"{entity}: ")
    results = find_top_similar_columns(entity, columns_info, emb_model, top_k=10)

    print("- By Table Only:")
    for result in results[0]:
        print(f'  - {result[1]}.{result[2]} ({result[3]})')
        available_tables.add(f'{result[1]}.{result[2]}')

    # results[1:] share the (score, db, table, column, description) layout.
    # "By Total Scores 2" reads results[3]; it previously re-printed results[2].
    column_titles = ("- By Column Only:", "- By Total Scores 1", "- By Total Scores 2")
    for title, ranking in zip(column_titles, results[1:]):
        print(title)
        for result in ranking:
            print(f'  - {result[1]}.{result[2]}.{result[3]} ({result[4]})')
            available_tables.add(f'{result[1]}.{result[2]}')

    print()

中国宝安集团在2020-2021年间实施了几次分红方案？每次分红的合计派现金额（四舍五入保留两位小数，单位元）和实施公告日期（xxxx-xx-xx的格式）分别是多少？: 
- By Table Only:
  - AStockFinanceDB.LC_Dividend (公司分红)
  - PublicFundDB.MF_Dividend (公募基金分红)
  - AStockShareholderDB.LC_ESOPSummary (员工持股计划概况)
  - AStockShareholderDB.LC_ESOP (员工持股计划)
  - AStockOperationsDB.LC_RewardStat (公司管理层报酬统计)
  - AStockFinanceDB.LC_CapitalInvest (资金投向说明)
  - AStockFinanceDB.LC_IncomeStatementAll (利润分配表_新会计准则)
  - AStockFinanceDB.LC_ASharePlacement (A股配股)
  - AStockShareholderDB.LC_TransferPlan (股东增减持计划表)
  - AStockShareholderDB.LC_SHNumber (股东户数)
- By Column Only:
  - PublicFundDB.MF_Dividend.DividendTimesYTD (本年累计分红次数(次))
  - PublicFundDB.MF_Dividend.DividendSumYTD (本年累计分红总额(元))
  - PublicFundDB.MF_Dividend.DiviTimesSinceIncepion (历史累计分红次数(次))
  - AStockFinanceDB.LC_Dividend.DividendImplementDate (分红实施公告日)
  - PublicFundDB.MF_Dividend.DividendImplementDate (分红实施公告日)
  - PublicFundDB.MF_Dividend.DiviSumSinceInception (历史累计分红总额(元))
  - AStockFinanceDB.LC_Dividend.DiviObjectNew 

In [181]:
# NOTE(review): this cell's execution count (181) precedes that of the next cell (183),
# which assigns `df` — confirm this display still works on a fresh top-to-bottom run.
df

Unnamed: 0,表英文,表中文,表描述,数据范围,信息来源
10,AStockIndustryDB.LC_ExgIndChange,公司行业变更表,本表记录上市公司从上市至今，由于主营业务变更导致的行业变化情况，采用同一行业分类标准。,,公司公告、聚源整理
11,AStockIndustryDB.LC_IndustryValuation,行业估值指标,本表记录不同行业标准下的的衍生指标，包括市值、市盈率、市销率、市净率、股息率等指标。,2014-01-01 ~ 至今,聚源计算
12,AStockIndustryDB.LC_IndFinIndicators,行业财务指标表,本表存储行业衍生指标相关数据，反映不同行业分类标准下，各行业的成长能力、偿债能力、盈利能力和...,2014年 ~ 至今,公告披露，聚源计算
15,AStockOperationsDB.LC_SuppCustDetail,公司供应商与客户,收录A股上市公司的主要供应商、客户清单，以及交易标的、交易金额等信息。,2015年 ~ 至今,招股说明书、定报
16,AStockOperationsDB.LC_Staff,公司职工构成,从技术职称、专业、文化程度、年龄等几个方面介绍公司职工构成情况。,1999-12-31 ~ 至今,定期报告、招股说明书等
17,AStockOperationsDB.LC_RewardStat,公司管理层报酬统计,按报告期统计管理层的报酬情况，包括报酬总额、前三名董事报酬、前三名高管报酬、报酬区间统计分析等。,2001-12-31 ~ 至今,定期报告、招股说明书等
19,AStockShareholderDB.LC_MainSHListNew,主要股东名单(新),收录公司主要股东构成及持股数量比例、持股性质、股东类型、股东排行等明细资料，包括发行前和上...,1992-06-30 ~ 至今,招股说明书、上市公告书、定报、临时公告等
20,AStockShareholderDB.LC_SHNumber,股东户数,1. 反映公司全体股东、A股股东、B股东、H股东、CDR股东的持股情况及其历史变动情况等。...,1991年 ~ 至今,招股说明书、上市公告书、定报、临时公告、深交所互动易、上证e互动等
21,AStockShareholderDB.LC_Mshareholder,大股东介绍,收录上市公司及发债企业大股东的基本资料，包括直接持股和间接持股，以及持股比例、背景介绍等内容。,2004-12-31 ~ 至今,募集说明书、招股说明书、定报、临时公告等
22,AStockShareholderDB.LC_ActualController,公司实际控制人,1. 收录根据上市公司在招投说明书、定期报告、及临时公告中披露的实际控制人结构图判断的上市...,2004-12-31 ~ 至今,招股说明书、上市公告书、定报、临时公告等


In [183]:
table_fpath = 'data/database-table/database_v4.md'

# read database table
df = pd.read_table(table_fpath, sep="|", skiprows=[1], engine='python')
df = df.iloc[:, 1:-1]  # drop the extra border columns
df.columns = [col.strip() for col in df.columns]  # strip whitespace from the column names
df = df.fillna('')

# keep the constant databases (the first six rows), then append every row whose
# whitespace-stripped table name was retrieved. A boolean mask selects them in one
# step instead of growing the frame row by row with pd.concat.
is_available = df['表英文'].map(lambda name: re.sub(r'\s', '', name) in available_tables).astype(bool)
new_df = pd.concat([df[:6], df[is_available]], ignore_index=True)

# convert to Markdown
markdown_table = new_df.to_markdown(index=False)

# drop empty lines and remove double spaces used for column padding
markdown_table = '\n'.join(re.sub('  ', '', line) for line in markdown_table.splitlines() if line.strip())
# collapse the long header separator
markdown_table = re.sub(r'\|:-+\|:-+\|:-+\|:-+\|:-+\|', '|---|---|---|---|---|', markdown_table)

print(markdown_table)

| 表英文| 表中文| 表描述 | 数据范围| 信息来源|
|---|---|---|---|---|
| ConstantDB.HK_SecuMain| 港股证券主表| 记录港股单个证券品种的简称、中英文名、上市交易、上市状态所等基础信息。 | | |
| ConstantDB.US_SecuMain| 美股证券主表| 记录美国等境外市场单个证券品种的简称、中英文名、上市交易所、上市状态等基础信息。 | | |
| ConstantDB.SecuMain | 证券主表| 记录A股单个证券品种（股票、基金、债券）的代码、简称、中英文名、上市交易所、上市板块、上市状态等基础信息。| | |
| ConstantDB.CT_SystemConst | 系统常量表| 本表收录数据库中各种常量值的具体分类和常量名称描述。 | | |
| ConstantDB.LC_AreaCode| 国家城市代码表| 本表收录世界所有国家层面的数据信息和我国不同层级行政区域的划分信息。 | | |
| ConstantDB.QT_TradingDayNew | 交易日表(新)| 本表收录各个市场的交易日信息，包括给定日期是否是交易日，是否周、月、季、年最后一个交易日。 | | |
| AStockOperationsDB.LC_RewardStat| 公司管理层报酬统计| 按报告期统计管理层的报酬情况，包括报酬总额、前三名董事报酬、前三名高管报酬、报酬区间统计分析等。 | 2001-12-31 ~ 至今 | 定期报告、招股说明书等|
| AStockShareholderDB.LC_SHNumber | 股东户数| 1. 反映公司全体股东、A股股东、B股东、H股东、CDR股东的持股情况及其历史变动情况等。<br>2.指标计算公式：<br>\t1)户均持股比例＝((股本/股东总户数)/股本)*100%（公式中分子分母描述同一股票类型）。<br>\t2)相对上一期报告期户均持股比例变化＝本报告期户均持股比例-上一报告期户均持股比例。<br>\t3)户均持股数季度增长率＝(本季度户均持股数量/上一季度户均持股数量-1)*100%。<br>\t4)户均持股比例季度增长率=(本季度户均持股比例/上一季度户均持股比例-1)*100%。<br>\t5)户均持股数半年增长率=(本报告期户均持股数量/前推两

In [186]:
def rag_find_tables(query: str) -> str:
    """
    Retrieve the tables relevant to `query` and render them as a markdown table.

    input: query
    output: md format table schema — the first six rows of database_v4.md plus every
    row whose table appears in any ranking returned by find_top_similar_columns.
    """
    rankings = find_top_similar_columns(query, columns_info, emb_model, top_k=10)
    # Every ranking entry starts with (score, db_name, table_name_en, ...).
    available_tables = {
        f'{result[1]}.{result[2]}'
        for ranking in rankings
        for result in ranking
    }

    # read database table
    table_fpath = os.path.join(cwd, 'data', 'database-table', 'database_v4.md')
    df = pd.read_table(table_fpath, sep="|", skiprows=[1], engine='python')
    df = df.iloc[:, 1:-1]  # drop the extra border columns
    df = df.fillna('')
    df.columns = [col.strip() for col in df.columns]  # strip whitespace from the column names

    # keep the constant databases (the first six rows) and add the available tables;
    # a boolean mask replaces the row-by-row pd.concat loop.
    is_available = df['表英文'].map(lambda name: re.sub(r'\s', '', name) in available_tables).astype(bool)
    new_df = pd.concat([df[:6], df[is_available]], ignore_index=True)

    # convert to Markdown
    markdown_table = new_df.to_markdown(index=False)

    # drop empty lines and remove double spaces used for column padding
    markdown_table = '\n'.join(re.sub('  ', '', line) for line in markdown_table.splitlines() if line.strip())
    # collapse the long header separator
    markdown_table = re.sub(r'\|:-+\|:-+\|:-+\|:-+\|:-+\|', '|---|---|---|---|---|', markdown_table)

    return markdown_table

In [237]:
# Smoke test: print the markdown table schema retrieved for one sample question
print(rag_find_tables('天士力在2020年的最大担保金额是多少？'))

| 表英文 | 表中文 | 表描述| 数据范围| 信息来源 |
|---|---|---|---|---|
| ConstantDB.HK_SecuMain | 港股证券主表 | 记录港股单个证券品种的简称、中英文名、上市交易、上市状态所等基础信息。| ||
| ConstantDB.US_SecuMain | 美股证券主表 | 记录美国等境外市场单个证券品种的简称、中英文名、上市交易所、上市状态等基础信息。| ||
| ConstantDB.SecuMain| 证券主表 | 记录A股单个证券品种（股票、基金、债券）的代码、简称、中英文名、上市交易所、上市板块、上市状态等基础信息。 | ||
| ConstantDB.CT_SystemConst| 系统常量表 | 本表收录数据库中各种常量值的具体分类和常量名称描述。| ||
| ConstantDB.LC_AreaCode | 国家城市代码表 | 本表收录世界所有国家层面的数据信息和我国不同层级行政区域的划分信息。| ||
| ConstantDB.QT_TradingDayNew| 交易日表(新) | 本表收录各个市场的交易日信息，包括给定日期是否是交易日，是否周、月、季、年最后一个交易日。| ||
| AStockOperationsDB.LC_RewardStat | 公司管理层报酬统计 | 按报告期统计管理层的报酬情况，包括报酬总额、前三名董事报酬、前三名高管报酬、报酬区间统计分析等。| 2001-12-31 ~ 至今 | 定期报告、招股说明书等 |
| AStockShareholderDB.LC_ShareFP | 股东股权冻结和质押 | 收录股东股权的被冻结和质押及进展情况，包括被冻结质押股东、被接受股权质押方、涉及股数以及冻结质押期限起始和截止日等内容。| 1999-09-30 ~ 至今 | 股权质押公告、股权冻结公告、解除质押冻结公告等 |
| AStockShareholderDB.LC_ShareFPSta| 股东股权冻结和质押统计 | 1. 收录股东股权的质押冻结统计数据，包括股东股权累计冻结质押股数、累计占冻结质押方持股数比例和累计占总股本比例等情况。<br>2. 指标计算公式：<br>\t1)累计占冻结质押方持股数比例=股东累计冻结质押股数(股)/股东持股数。<br>\t2)累计占总股本

# WorkFlow

## Craft Prompt

In [131]:
prompt_dir = os.path.join(cwd, 'prompt')

table_finder_fname = 'table_finder-stage_1-v4.0.0.md'
rewriter_fname = 'rewrite_query-v1.0.0.md'
sql_1_fname = 'sql_generator-stage_1-v2.0.0.md'
sql_2_fname = 'sql_generator-stage_2-v1.0.0.md'
ans_fname = 'answer_generator-v2.0.0.md'


def _read_prompt(fname: str) -> str:
    """Return the full text of the prompt template `fname` inside `prompt_dir`."""
    # Explicit encoding so the read does not depend on the platform default.
    with open(os.path.join(prompt_dir, fname), 'r', encoding='utf-8') as f:
        return f.read()


table_finder_prompt_template = _read_prompt(table_finder_fname)
rewriter_prompt_template = _read_prompt(rewriter_fname)
sql_1_prompt_template = _read_prompt(sql_1_fname)
sql_2_prompt_template = _read_prompt(sql_2_fname)
ans_prompt_template = _read_prompt(ans_fname)

In [243]:
# build sql prompt 

def make_prompt_table_finder(query: str, market: str, ner: dict) -> str:
    """
    Build the table-finder prompt for one query.

    The prompt is the table-finder template followed by the query, with the
    `<Database-Table Schema>` placeholder replaced by the tables retrieved for the
    query and, when NER produced a result, a trailing JSON block describing the entity.

    Args:
    - query: the user question.
    - market: market label; it is matched against the table name parsed from the NER
      SQL, and the value 'cn' selects tables whose name does not contain it.
    - ner: content from the stage_1, with a 'result' list and a 'sql' dict whose values
      are lists of items carrying 'query' and 'result'.

    Returns:
    - the prompt string.
    """

    def parse_database_and_table(sql_query: str) -> dict:
        """Return {'database', 'table'} from the first `FROM db.table`, or {} if absent."""
        match = re.search(r'FROM\s+([a-zA-Z0-9_]+)\.([a-zA-Z0-9_]+)', sql_query, re.IGNORECASE)
        if match:
            return {'database': match.group(1), 'table': match.group(2)}
        return {}

    prompt = table_finder_prompt_template + query

    # replace table schema
    # str.replace inserts the text verbatim; re.sub would process backslash escapes in
    # the replacement (a literal "\t" becomes a tab, other sequences can raise).
    prompt = prompt.replace('<Database-Table Schema>', rag_find_tables(query))

    # ner_result is None
    if not ner['result']:
        return prompt

    ner_content = {}
    # assume there is only one NER result
    ner_content.update(ner['result'][0])
    ner_content['market'] = market

    # NER SQL lookups that returned rows, in their original order
    hits = [item for items in ner['sql'].values() for item in items if item['result']]

    if len(hits) == 1:
        # single hit: take it regardless of market
        hit = hits[0]
        ner_content.update(parse_database_and_table(hit['query']))
        ner_content['entity_information'] = hit['result']
    elif len(hits) == 2:
        # two hits: keep the one whose table matches the market; for 'cn' keep
        # tables whose name does not contain the market label. Later hits override.
        for hit in hits:
            source = parse_database_and_table(hit['query'])
            # .get: an unparsable query yields {} and must not raise KeyError
            table_name = source.get('table', '')
            if market in table_name or (market == 'cn' and market not in table_name):
                ner_content['entity_information'] = hit['result']
                ner_content.update(source)

    # add NER result
    ner_str = f"\n\n### **Name Entity Recognition Result**\n```json\n{json.dumps(ner_content, ensure_ascii=False,indent=2)}\n```"

    return prompt + ner_str

def make_prompt_sql(query: str, ner: dict, table_finder: dict) -> str:
    """
    Build the stage-1 SQL-generation prompt.

    Fills the placeholders of the SQL template (`<Database and Table>`,
    `<Table-Column Schema>`, `<NER Result>`, `<Current Query>`) and appends the
    few-shot hints from `sql_shots`.

    Args:
    - query: the user question.
    - ner: NER content with a 'result' list and a 'sql' dict whose values are lists
      of items carrying a 'result'.
    - table_finder: dict with a 'data_source' list whose entries have a 'table' key.

    Returns:
    - the prompt string.
    """
    # All placeholders are filled with str.replace: re.sub would process backslash
    # escapes in the replacement text and corrupt JSON escapes such as "\n".
    prompt = sql_1_prompt_template

    # Database-Table Pair(s)
    data_source = table_finder['data_source']
    tables = [entry['table'] for entry in data_source]
    prompt = prompt.replace(
        '<Database and Table>',
        json.dumps(data_source, ensure_ascii=False, indent=2),
    )

    # Table Schema(s)
    table_dir = os.path.join(cwd, 'data', 'table-column')
    schema_parts = []
    for table in tables:
        table_fname = f'{table}-with_instance.md'
        try:
            with open(os.path.join(table_dir, table_fname), 'r', encoding='utf-8') as f:
                schema_parts.append(f.read() + '\n\n')
        except (OSError, UnicodeError):
            # Skip only the unreadable table; schemas already collected are kept.
            print(f"Can't find {table_fname}")
    prompt = prompt.replace('<Table-Column Schema>', ''.join(schema_parts))

    # NER Result: first row of the first NER SQL lookup that returned rows
    ner_hits = []
    if ner['result']:
        ner_hits = [item['result'] for items in ner['sql'].values() for item in items if item['result']]
    if ner_hits:
        ner_res = json.dumps(ner_hits[0][0], ensure_ascii=False, indent=2)
        prompt = prompt.replace('<NER Result>', ner_res)
    else:
        # nothing to show: remove the placeholder and its heading
        prompt = prompt.replace('\n<NER Result>\n', '')
        prompt = prompt.replace('\n## NER Result\n', '')

    # replace query
    prompt = prompt.replace('<Current Query>', query)

    # add shots
    prompt += sql_shots(query)

    return prompt

def make_prompt_rewrite(query: str, history: list):
    """
    Build the query-rewriting prompt.

    Args:
    - query: the current user question, inserted at `<Current Query>`.
    - history: chat history, serialised as indented JSON and inserted at `<Chat History>`.

    Returns:
    - the prompt string.
    """
    # str.replace inserts the text verbatim; re.sub would process backslash escapes
    # in the replacement, turning JSON escapes such as "\n" into real newlines.
    history_json = json.dumps(history, ensure_ascii=False, indent=2)

    prompt = rewriter_prompt_template.replace('<Current Query>', query)
    prompt = prompt.replace('<Chat History>', history_json)

    return prompt

def make_prompt_answer(query: str, ner: dict, sql_res: list) -> str:
    """
    Build the answer-generation prompt.

    Args:
    - query: the user question, inserted at `<Current Query>`.
    - ner: NER content with a 'result' list and a 'sql' dict whose values are lists
      of items carrying a 'result'.
    - sql_res: SQL query result, serialised as indented JSON at `<SQL Result>`.

    Returns:
    - the prompt string.
    """
    # All placeholders are filled with str.replace: re.sub would process backslash
    # escapes in the replacement text and corrupt JSON escapes such as "\n".
    prompt = ans_prompt_template

    # SQL Result
    prompt = prompt.replace('<SQL Result>', json.dumps(sql_res, ensure_ascii=False, indent=2))

    # NER Result: first row of the first NER SQL lookup that returned rows
    ner_hits = []
    if ner['result']:
        ner_hits = [item['result'] for items in ner['sql'].values() for item in items if item['result']]
    if ner_hits:
        ner_res = json.dumps(ner_hits[0][0], ensure_ascii=False, indent=2)
        prompt = prompt.replace('<NER Result>', ner_res)
    else:
        # nothing to show: remove the placeholder and its heading
        prompt = prompt.replace('\n<NER Result>\n', '')
        prompt = prompt.replace('\n## NER Result\n', '')

    # replace query
    prompt = prompt.replace('<Current Query>', query)

    return prompt


def sql_shots(query: str) -> str:
    """
    Build the reference-example section appended to the SQL-generation prompt.

    Starts from a fixed heading and appends one hint per keyword rule that matches
    `query`. Each rule is applied once; the '新高' and '成交额'+'平均' rules used to
    be listed twice, which appended the same example two times.

    Args:
    - query: the user question.

    Returns:
    - the heading followed by the matching hints (heading only if nothing matches).
    """
    way_string_2 = "## **查询参考示例**\n"

    if "近一个月最高价" in query:
        way_string_2 += "查询近一个月最高价,你写的sql语句可以优先考虑表中已有字段HighPriceRM  近一月最高价(元)  "
    if "近一个月最低价" in query:
        way_string_2 += "查询近一月最低价(元),你写的sql语句直接调用已有字段LowPriceRM"
    if "行业" in query and ('多少只' in query or '几个' in query or '多少个' in query):
        way_string_2 += """查询某行业某年数量 示例sql语句:SELECT count(*) as 风电零部件_2021
            FROM AStockIndustryDB.LC_ExgIndustry
            where ThirdIndustryName like '%风电零部件%' and year(InfoPublDate)=2021 and IfPerformed = 1;"""
    if '年度报告' in query:
        way_string_2 += """特别重要一定注意，查询最新更新XXXX年年度报告，参考sql条件语句，
                            WHERE date(EndDate) = 'XXXX-12-31'"""

    if '新高' in query:
        way_string_2 += """新高 要用AStockMarketQuotesDB.CS_StockPatterns现有字段
        
        查询今天是2021年01月01日，创近半年新高的股票有几只示。示例sql语句:SELECT count(*)  FROM AStockMarketQuotesDB.CS_StockPatterns
                where  IfHighestHPriceRMSix=1 and date(TradingDay)='2021-01-01;
                判断某日 YY-MM-DD  InnerCode XXXXXX 是否创近一周的新高，查询结果1代表是,IfHighestHPriceRW字段可以根据情况灵活调整  SELECT   InnerCode,TradingDay,IfHighestHPriceRW  FROM AStockMarketQuotesDB.CS_StockPatterns
where  date(TradingDay)='2021-12-20' and InnerCode = '311490'
                
                """
    if '成交额' in query and '平均' in query:
        # NOTE(review): this example SQL counts new-high stocks rather than averaging
        # a turnover value — confirm the intended query.
        way_string_2 += """查询这家公司5日内平均成交额是多少。示例sql语句:SELECT count(*)  FROM AStockMarketQuotesDB.CS_StockPatterns
                where  IfHighestHPriceRMSix=1 and date(TradingDay)='2021-01-01"""
    if '半年度报告' in query:
        way_string_2 += """查询XXXX年半年度报告的条件为：year(EndDate) = XXXX and InfoSource='半年度报告'"""

    if '基金' in query:
        way_string_2 += """如果需要返回基金名，参考以下 sql 语句访问特定的基金名。 SELECT DisclName FROM PublicFundDB.MF_FundProdName
                where  InnerCode=XXX"""

    if '调整后' in query:
        way_string_2 += """不要使用 `IfAdjusted` 词条。"""

    return way_string_2

## Process Query

In [217]:
# Read the BigModel API key from the environment instead of committing it to the notebook.
llm_api = os.environ.get('ZHIPUAI_API_KEY')
if not llm_api:
    raise RuntimeError("Set the ZHIPUAI_API_KEY environment variable before running this cell.")

# Use the key read above: the client previously referenced `deepseek_api`, which this
# cell never assigns, while `llm_api` was left unused.
client = OpenAI(api_key=llm_api, base_url="https://open.bigmodel.cn/api/paas/v4/")

## Tools

In [61]:
# Bearer token sent by fetch_data; read from the environment instead of committing it.
api_key = os.environ.get("FINGLM_API_KEY")
if not api_key:
    raise RuntimeError("Set the FINGLM_API_KEY environment variable before running this cell.")

import requests
import json

def fetch_data(data: dict, timeout: float = 60):
    """
    POST `data` as JSON to the finglm2 query endpoint and return the decoded JSON reply.

    Args:
    - data: request payload; callers in this notebook pass 'sql' and 'limit' keys.
    - timeout: seconds to wait for the server before `requests` raises a timeout
      error; without it a stalled connection would block forever.

    Returns:
    - the response body parsed as JSON.
    """
    url = "https://comm.chatglm.cn/finglm2/api/query"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }
    response = requests.post(url, headers=headers, json=data, timeout=timeout)

    return response.json()

In [62]:
# Sample request: company details for security code 600872, capped at 100 rows
t = {
  "sql": "SELECT sm.ChiName AS FullName, lsa.AShareAbbr AS AShareAbbreviation, lsa.LegalRepr AS LegalRepresentative, lsa.LegalConsultant AS LegalConsultant, lsa.AccountingFirm AS AccountingFirm, lsa.SecretaryBD AS BoardSecretary FROM ConstantDB.SecuMain sm JOIN AStockBasicInfoDB.LC_StockArchives lsa ON sm.CompanyCode = lsa.CompanyCode WHERE sm.SecuCode = '600872';",
  "limit": 100
}

# `t` is rebound from the request payload to the decoded response; show its rows
t = fetch_data(t)
t['data']

[{'FullName': '中炬高新技术实业(集团)股份有限公司',
  'AShareAbbreviation': '中炬高新',
  'LegalRepresentative': '余健华',
  'LegalConsultant': '广东卓建(中山)律师事务所',
  'AccountingFirm': '天职国际会计师事务所（特殊普通合伙）',
  'BoardSecretary': '郭毅航'}]

## Build Agent

In [221]:
def table_finder_agent(query: str, market: str, ner: dict):
    """
    Ask the LLM which tables answer `query` and return the resulting chat history.

    Args:
    - query, market, ner: forwarded to `make_prompt_table_finder` to build the prompt.

    Returns:
    - list: the user message followed by an assistant message whose 'content' is the
      parsed JSON reply; an empty list when no reply could be parsed as JSON after
      `max_retries` rounds.

    Raises:
    - the last API error, when `max_retries` consecutive API calls fail in one round.
    """
    print("====TABLE FINDER AGENT UP====")
    retry_delay = 1
    max_retries = 6
    llm_query = make_prompt_table_finder(query, market, ner)

    history = [{"role": "user", "content": llm_query}]

    for round_no in range(1, max_retries + 1):
        # API call with its own retry budget
        for attempt in range(max_retries):
            try:
                response = client.chat.completions.create(
                    model='glm-4-plus',
                    messages=history,
                    stream=False,
                    top_p=0.7,
                    temperature=0.9
                )
                break  # success: leave the retry loop
            except Exception as e:
                print(f"API call failed (attempt {attempt + 1}/{max_retries}): {e}")
                if attempt == max_retries - 1:
                    raise  # retry budget exhausted
                time.sleep(retry_delay)  # wait before retrying

        raw_content = json.loads(response.to_json())['choices'][0]['message']['content']

        # Parse the reply as JSON, unwrapping a surrounding ``` / ```json fence.
        # str.strip('`json') removes a character set, not the fence, so it fails when
        # whitespace surrounds the fence; it is kept only as the fallback.
        try:
            text = raw_content.strip()
            fenced = re.fullmatch(r'```(?:json)?\s*(.*?)\s*```', text, re.DOTALL | re.IGNORECASE)
            content = json.loads(fenced.group(1) if fenced else text.strip('`json'))
        except Exception as e:
            print(f"Error processing JSON {query} (attempt {round_no}/{max_retries}): {e}")
            continue  # ask the model again

        print("====TABLE FINDER AGENT COMPLETED====")
        history.append({'role': 'assistant', 'content': content})
        return history

    print(f"Failed to process question ==={query}=== after {max_retries} attempts.")
    print("====TABLE FINDER AGENT DOWN====")
    return []


def sql_generator_agent(query: str, ner: dict, table_finder: dict):
    """
    Ask the LLM for a SQL query, run it through `fetch_data` and return the rows.

    Each round calls the model, parses its JSON reply, reads its 'sql_query' and
    executes it. A round that fails at parsing or querying is retried, up to
    `max_retries` rounds. If no non-empty result was obtained, the chat history is
    handed to `sql_generator_reflection_agent` and its result is returned.

    Args:
    - query, ner, table_finder: forwarded to `make_prompt_sql` to build the prompt.

    Returns:
    - the 'data' of the SQL response, or whatever the reflection agent returns.

    Raises:
    - the last API error, when `max_retries` consecutive API calls fail in one round.
    """
    print("====SQL GENERATOR AGENT UP====")
    retry_delay = 1
    max_retries = 4
    llm_query = make_prompt_sql(query, ner, table_finder)
    chat_history = [{"role": "user", "content": llm_query}]

    # Explicit default instead of probing locals() for a possibly-unset name
    sql_res = []

    for round_no in range(1, max_retries + 1):
        # API call with its own retry budget
        for attempt in range(max_retries):
            try:
                response = client.chat.completions.create(
                    model='glm-4-plus',
                    messages=chat_history,
                    stream=False,
                    top_p=0.7,
                    temperature=0.9
                )
                break  # success: leave the retry loop
            except Exception as e:
                print(f"API call failed (attempt {attempt + 1}/{max_retries}): {e}")
                if attempt == max_retries - 1:
                    raise  # retry budget exhausted
                time.sleep(retry_delay)  # wait before retrying

        raw_content = json.loads(response.to_json())['choices'][0]['message']['content']

        # Parse the reply as JSON, unwrapping a surrounding ``` / ```json fence.
        # str.strip('`json') removes a character set, not the fence; it is kept only
        # as the fallback. A reply without 'sql_query' also counts as a failed round.
        try:
            text = raw_content.strip()
            fenced = re.fullmatch(r'```(?:json)?\s*(.*?)\s*```', text, re.DOTALL | re.IGNORECASE)
            content = json.loads(fenced.group(1) if fenced else text.strip('`json'))
            sql_query = content['sql_query']
            chat_history.append({'role': 'assistant', 'content': str(content)})
        except Exception as e:
            print(f"Error processing JSON {query} (attempt {round_no}/{max_retries}): {e}")
            continue

        print("====SQL GENERATOR AGENT COMPLETED====")

        print("====SQL QUERYING UP====")
        print(sql_query)

        try:
            sql_res = fetch_data({'sql': sql_query, 'limit': 1000})['data']
            print("====SQL QUERYING COMPLETED====")
            break
        except Exception as e:
            print("====SQL QUERYING DOWN====")
            print(f"Error fetching SQL data {query} (attempt {round_no}/{max_retries}): {e}")
    else:
        # every round failed
        print(f"Failed to process question ==={query}=== after {max_retries} attempts.")
        print("====SQL GENERATOR AGENT DOWN====")

    # No rows (or no successful round): let the reflection agent try again
    if not sql_res:
        sql_res = sql_generator_reflection_agent(chat_history, ner, table_finder)

    return sql_res

def rewrite_query_agent(query: str, history: list):
    """
    Ask the LLM to rewrite `query` given the chat `history`.

    Args:
    - query, history: forwarded to `make_prompt_rewrite` to build the prompt.

    Returns:
    - the raw text content of the model reply.

    Raises:
    - the last API error, when `max_retries` consecutive API calls fail.
    """
    print("====REWRITER AGENT UP====")

    retry_delay = 1
    max_retries = 6
    llm_query = make_prompt_rewrite(query, history)

    # API call with retries
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model='glm-4-plus',
                messages=[{"role": "user", "content": llm_query}],
                stream=False,
                top_p=0.7,
                temperature=0.9
            )
            break  # success: leave the retry loop
        except Exception as e:
            print(f"API call failed (attempt {attempt + 1}/{max_retries}): {e}")
            if attempt == max_retries - 1:
                raise  # retry budget exhausted
            time.sleep(retry_delay)  # wait before retrying

    raw_content = json.loads(response.to_json())['choices'][0]['message']['content']

    print("====REWRITER AGENT COMPLETED====")

    return raw_content

def answer_generator_agent(query: str, ner: dict, sql_res: list):
    """
    Ask the LLM to phrase the final answer for `query` from the SQL result.

    Args:
    - query, ner, sql_res: forwarded to `make_prompt_answer` to build the prompt.

    Returns:
    - the raw text content of the model reply.

    Raises:
    - the last API error, when `max_retries` consecutive API calls fail.
    """
    print("====ANSWER GENERATOR UP====")

    retry_delay = 1
    max_retries = 6
    llm_query = make_prompt_answer(query, ner, sql_res)

    # API call with retries
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model='glm-4-plus',
                messages=[{"role": "user", "content": llm_query}],
                stream=False,
                top_p=0.7,
                temperature=0.9
            )
            break  # success: leave the retry loop
        except Exception as e:
            print(f"API call failed (attempt {attempt + 1}/{max_retries}): {e}")
            if attempt == max_retries - 1:
                raise  # retry budget exhausted
            time.sleep(retry_delay)  # wait before retrying

    raw_content = json.loads(response.to_json())['choices'][0]['message']['content']

    print("====ANSWER GENERATOR COMPLETED====")

    return raw_content

def table_finder_reflection_agent(chat_history: list, market: str, ner: dict):
    """Ask the LLM to reflect on, and correct, its previous table selection.

    A deep copy of ``chat_history`` gets a fixed "reflect and correct" user
    message appended, every message content is converted to ``str``, and the
    conversation is sent to ``glm-4-plus``. The reply must parse as JSON;
    otherwise the request is repeated, up to six times.

    ``market`` and ``ner`` are not read by this function.

    Returns:
        The extended history, whose last entry is the assistant message with
        the parsed JSON reply as ``content``. If no reply could be parsed, the
        original ``chat_history`` is returned unchanged.
    """
    print("====TABLE REFLECTION AGENT UP====")
    retry_delay = 1
    max_retries = 6
    query = "结果不正确。可能的原因：表找的不全面；表找错了；忽略了表之间的联系。在原本的思考路径中增加一步「反思与修正」（reflection and correct）。"

    history = copy.deepcopy(chat_history)
    history.append({"role": "user", "content": query})

    # Earlier assistant turns may hold parsed (non-string) content.
    for message in history:
        message['content'] = str(message['content'])

    for retries in range(max_retries):
        started_at = time.time()

        # Call the API, retrying on any exception.
        for attempt in range(max_retries):
            try:
                completion = client.chat.completions.create(
                    model='glm-4-plus',
                    messages=history,
                    stream=False,
                    top_p=0.7,
                    temperature=0.9
                )
            except Exception as e:
                print(f"API call failed (attempt {attempt + 1}/{max_retries}): {e}")
                if attempt == max_retries - 1:
                    raise  # out of attempts: surface the last error
                time.sleep(retry_delay)  # wait before the next attempt
            else:
                break  # success: leave the retry loop

        finished_at = time.time()

        payload = json.loads(completion.to_json())
        reply_text = payload['choices'][0]['message']['content']
        # Collected but not used any further in this function.
        token_usage = payload['usage']
        elapsed = finished_at - started_at

        # Parse the reply; an unparsable reply costs one outer attempt.
        try:
            content = json.loads(reply_text.strip('`json'))
        except Exception as e:
            print(f"Error processing JSON {query} (attempt {retries + 1}/{max_retries}): {e}")
            continue

        print("====TABLE REFLECTION AGENT COMPLETED====")
        history.append({'role': 'assistant', 'content': content})
        return history

    print(f"Failed to process question ==={query}=== after {max_retries} attempts.")
    print("====TABLE REFLECTION AGENT DOWN====")
    return chat_history
            
def sql_generator_reflection_agent(chat_history: list, ner: dict, table_finder: dict):
    """Ask the LLM to reflect on an empty SQL result and run its corrected query.

    A deep copy of ``chat_history`` is extended with a fixed "the result is
    empty, reflect and correct" user message and sent to ``glm-4-plus``. The
    reply must be JSON carrying a ``sql_query`` entry; that query is executed
    through ``fetch_data``. A reply that cannot be parsed (or lacks
    ``sql_query``) and a failing ``fetch_data`` call each consume one of four
    attempts.

    Args:
        chat_history: Conversation so far; it is deep-copied, not modified.
        ner: Not read by this function.
        table_finder: Not read by this function.

    Returns:
        The ``'data'`` entry of the ``fetch_data`` response for the executed
        query, or ``[]`` when no query could be executed.

    Raises:
        Exception: whatever the API client raised, once a single request has
            failed ``max_retries`` times in a row.
    """
    print("====SQL REFLECTION AGENT UP====")

    retry_delay = 1
    max_retries = 4

    query = "结果为空。在原本的思考路径中增加一步「反思与修正」（reflection and correct）。"

    history = copy.deepcopy(chat_history)
    history.append({"role": "user", "content": query})

    sql_res = []  # returned as-is when every attempt fails
    retries = 0

    while retries < max_retries:
        should_retry = False

        # An even length means the history currently ends with an assistant
        # reply (its SQL failed to execute); re-issue the user request.
        if len(history) % 2 == 0:
            history.append({"role": "user", "content": query})

        # Call the API, retrying on any exception.
        for attempt in range(max_retries):
            try:
                response = client.chat.completions.create(
                    model='glm-4-plus',
                    messages=history,
                    stream=False,
                    top_p=0.7,
                    temperature=0.9
                )
                break  # success: leave the retry loop
            except Exception as e:
                print(f"API call failed (attempt {attempt + 1}/{max_retries}): {e}")
                if attempt == max_retries - 1:
                    raise  # out of attempts: surface the last error
                time.sleep(retry_delay)  # wait before the next attempt

        response = json.loads(response.to_json())
        raw_content = response['choices'][0]['message']['content']

        # First guard: the reply must be JSON *and* contain the SQL text.
        # Reading 'sql_query' here turns a reply without that key into a
        # retry instead of an uncaught KeyError further down.
        try:
            content = json.loads(raw_content.strip('`json'))
            sql_query = content['sql_query']
            history.append({'role': 'assistant', 'content': str(content)})
        except Exception as e:
            print(f"Error processing JSON {query} (attempt {retries + 1}/{max_retries}): {e}")
            should_retry = True

        if not should_retry:
            print("====SQL REFLECTION AGENT COMPLETED====")

            print("====SQL QUERYING UP====")

            post_data = {
                'sql': sql_query,
                'limit': 1000
            }

            print(sql_query)

            # Second guard: executing the query.
            try:
                sql_res = fetch_data(post_data)['data']
                print("====SQL QUERYING COMPLETED====")
            except Exception as e:
                print("====SQL QUERYING DOWN====")
                print(f"Error fetching SQL data {query} (attempt {retries + 1}/{max_retries}): {e}")
                should_retry = True

        if should_retry:
            retries += 1
            if retries == max_retries:
                print(f"Failed to process question ==={query}=== after {max_retries} attempts.")
                print("====SQL REFLECTION AGENT DOWN====")
            continue

        # The query ran (its result may still be empty): stop here.
        break

    return sql_res

In [239]:
from zhipuai import ZhipuAI

# Never commit credentials to the notebook: the key is taken from the
# ZHIPUAI_API_KEY environment variable. A key that was previously pasted into
# this cell must be treated as leaked and revoked.
_zhipuai_api_key = os.environ.get("ZHIPUAI_API_KEY")
if not _zhipuai_api_key:
    raise RuntimeError("Set the ZHIPUAI_API_KEY environment variable before running this cell.")
client = ZhipuAI(api_key=_zhipuai_api_key)

def table_finder_agent(query: str, market: str, ner: dict):
    """Ask the LLM which tables are needed to answer ``query``.

    The prompt comes from ``make_prompt_table_finder(query, market, ner)`` and
    is sent to ``glm-4-plus`` with a JSON-object response format. A reply that
    does not parse as JSON triggers a new request, up to six times.

    Returns:
        A two-message history ``[user prompt, assistant reply]`` in which the
        assistant ``content`` is the parsed JSON reply, or ``[]`` when no reply
        could be parsed.
    """
    print("====TABLE FINDER AGENT UP====")
    retry_delay = 1
    max_retries = 6

    history = [{"role": "user", "content": make_prompt_table_finder(query, market, ner)}]

    for retries in range(max_retries):
        started_at = time.time()

        # Call the API, retrying on any exception.
        for attempt in range(max_retries):
            try:
                completion = client.chat.completions.create(
                    model="glm-4-plus",
                    messages=history,
                    response_format={'type': 'json_object'},
                    stream=False,
                    top_p=0.7,
                    temperature=0.9,
                    max_tokens=4000
                )
            except Exception as e:
                print(f"API call failed (attempt {attempt + 1}/{max_retries}): {e}")
                if attempt == max_retries - 1:
                    raise  # out of attempts: surface the last error
                time.sleep(retry_delay)  # wait before the next attempt
            else:
                break  # success: leave the retry loop

        finished_at = time.time()

        payload = json.loads(completion.to_json())
        reply_text = payload['choices'][0]['message']['content']
        # Collected but not used any further in this function.
        token_usage = payload['usage']
        elapsed = finished_at - started_at

        # Parse the reply; an unparsable reply costs one outer attempt.
        try:
            content = json.loads(reply_text.strip('`json'))
        except Exception as e:
            print(f"Error processing JSON {query} (attempt {retries + 1}/{max_retries}): {e}")
            continue

        print("====TABLE FINDER AGENT COMPLETED====")
        history.append({'role': 'assistant', 'content': content})
        return history

    print(f"Failed to process question ==={query}=== after {max_retries} attempts.")
    print("====TABLE FINDER AGENT DOWN====")
    return []


def sql_generator_agent(query: str, ner: dict, table_finder: dict):
    """Have the LLM write a SQL query for ``query`` and execute it.

    The prompt comes from ``make_prompt_sql(query, ner, table_finder)``. The
    reply must be JSON carrying a ``sql_query`` entry; that query is executed
    through ``fetch_data``. A reply that cannot be parsed (or lacks
    ``sql_query``) and a failing ``fetch_data`` call each consume one of four
    attempts. When no rows were obtained, the conversation is handed to
    ``sql_generator_reflection_agent`` for another try.

    Returns:
        The ``'data'`` entry of the ``fetch_data`` response, or whatever the
        reflection agent returned when the first pass produced nothing.

    Raises:
        Exception: whatever the API client raised, once a single request has
            failed ``max_retries`` times in a row.
    """
    print("====SQL GENERATOR AGENT UP====")
    retry_delay = 1
    max_retries = 4
    llm_query = make_prompt_sql(query, ner, table_finder)
    prompt_message = {"role": "user", "content": llm_query}

    # History handed to the reflection agent: the prompt plus, once a reply
    # has been parsed, the most recent assistant reply only.
    chat_history = [prompt_message]

    sql_res = []  # stays empty when every attempt fails
    retries = 0

    while retries < max_retries:
        should_retry = False

        # Call the API, retrying on any exception. Every attempt sends just
        # the prompt: a retry must not end with (or pile up) earlier assistant
        # replies, which would break the user/assistant alternation.
        for attempt in range(max_retries):
            try:
                response = client.chat.completions.create(
                    model="glm-4-plus",
                    messages=[prompt_message],
                    response_format = {
                        'type': 'json_object'
                    },
                    stream=False,
                    top_p=0.7,
                    temperature=0.9,
                    max_tokens=4000
                )
                break  # success: leave the retry loop
            except Exception as e:
                print(f"API call failed (attempt {attempt + 1}/{max_retries}): {e}")
                if attempt == max_retries - 1:
                    raise  # out of attempts: surface the last error
                time.sleep(retry_delay)  # wait before the next attempt

        response = json.loads(response.to_json())
        raw_content = response['choices'][0]['message']['content']

        # First guard: the reply must be JSON *and* contain the SQL text.
        # Reading 'sql_query' here turns a reply without that key into a
        # retry instead of an uncaught KeyError further down.
        try:
            content = json.loads(raw_content.strip('`json'))
            sql_query = content['sql_query']
            chat_history = [prompt_message, {'role': 'assistant', 'content': str(content)}]
        except Exception as e:
            print(f"Error processing JSON {query} (attempt {retries + 1}/{max_retries}): {e}")
            should_retry = True

        if not should_retry:
            print("====SQL GENERATOR AGENT COMPLETED====")

            print("====SQL QUERYING UP====")

            post_data = {
                'sql': sql_query,
                'limit': 1000
            }

            print(sql_query)

            # Second guard: executing the query.
            try:
                sql_res = fetch_data(post_data)['data']
                print("====SQL QUERYING COMPLETED====")
            except Exception as e:
                print("====SQL QUERYING DOWN====")
                print(f"Error fetching SQL data {query} (attempt {retries + 1}/{max_retries}): {e}")
                should_retry = True

        if should_retry:
            retries += 1
            if retries == max_retries:
                print(f"Failed to process question ==={query}=== after {max_retries} attempts.")
                print("====SQL GENERATOR AGENT DOWN====")
            continue

        break

    # Nothing usable came back: let the reflection agent try a corrected query.
    if not sql_res:
        sql_res = sql_generator_reflection_agent(chat_history, ner, table_finder)

    return sql_res

def rewrite_query_agent(query: str, history: list):
    """Ask the LLM to rewrite ``query`` given the conversation ``history``.

    The prompt comes from ``make_prompt_rewrite(query, history)`` and is sent
    as a single user message to ``glm-4-plus``. A failing API call is retried
    (six attempts, one second apart); the error raised by the last attempt
    propagates to the caller.

    Returns:
        The ``content`` of the first choice in the API response.
    """
    print("====REWRITER AGENT UP====")

    retry_delay = 1
    max_retries = 6
    messages = [{"role": "user", "content": make_prompt_rewrite(query, history)}]

    started_at = time.time()

    # Call the API, retrying on any exception.
    attempt = 0
    while True:
        try:
            completion = client.chat.completions.create(
                model='glm-4-plus',
                messages=messages,
                stream=False,
                top_p=0.7,
                temperature=0.9
            )
        except Exception as e:
            print(f"API call failed (attempt {attempt + 1}/{max_retries}): {e}")
            if attempt == max_retries - 1:
                raise  # out of attempts: surface the last error
            time.sleep(retry_delay)  # wait before the next attempt
            attempt += 1
        else:
            break  # success: leave the retry loop

    finished_at = time.time()

    payload = json.loads(completion.to_json())
    reply_text = payload['choices'][0]['message']['content']
    # Collected but not used any further in this function.
    token_usage = payload['usage']
    elapsed = finished_at - started_at

    print("====REWRITER AGENT COMPLETED====")

    return reply_text

def answer_generator_agent(query: str, ner: dict, sql_res: list):
    """Ask the LLM to phrase the final answer for one question.

    The prompt comes from ``make_prompt_answer(query, ner, sql_res)`` and is
    sent as a single user message to ``glm-4-plus``. A failing API call is
    retried (six attempts, one second apart); the error raised by the last
    attempt propagates to the caller.

    Returns:
        The ``content`` of the first choice in the API response.
    """
    print("====ANSWER GENERATOR UP====")

    retry_delay = 1
    max_retries = 6
    messages = [{"role": "user", "content": make_prompt_answer(query, ner, sql_res)}]

    started_at = time.time()

    # Call the API, retrying on any exception.
    attempt = 0
    while True:
        try:
            completion = client.chat.completions.create(
                model='glm-4-plus',
                messages=messages,
                stream=False,
                top_p=0.7,
                temperature=0.9
            )
        except Exception as e:
            print(f"API call failed (attempt {attempt + 1}/{max_retries}): {e}")
            if attempt == max_retries - 1:
                raise  # out of attempts: surface the last error
            time.sleep(retry_delay)  # wait before the next attempt
            attempt += 1
        else:
            break  # success: leave the retry loop

    finished_at = time.time()

    payload = json.loads(completion.to_json())
    reply_text = payload['choices'][0]['message']['content']
    # Collected but not used any further in this function.
    token_usage = payload['usage']
    elapsed = finished_at - started_at

    print("====ANSWER GENERATOR COMPLETED====")

    return reply_text

def table_finder_reflection_agent(chat_history: list, market: str, ner: dict):
    """Ask the LLM to reflect on, and correct, its previous table selection.

    A deep copy of ``chat_history`` gets a fixed "reflect and correct" user
    message appended, every message content is converted to ``str``, and the
    conversation is sent to ``glm-4-plus`` with a JSON-object response format.
    The reply must parse as JSON; otherwise the request is repeated, up to six
    times.

    ``market`` and ``ner`` are not read by this function.

    Returns:
        The extended history, whose last entry is the assistant message with
        the parsed JSON reply as ``content``. If no reply could be parsed, the
        original ``chat_history`` is returned unchanged.
    """
    print("====TABLE REFLECTION AGENT UP====")
    retry_delay = 1
    max_retries = 6
    query = "结果不正确。可能的原因：表找的不全面；表找错了；忽略了表之间的联系。在原本的思考路径中增加一步「反思与修正」（reflection and correct）。"

    history = copy.deepcopy(chat_history)
    history.append({"role": "user", "content": query})

    # Earlier assistant turns may hold parsed (non-string) content.
    for message in history:
        message['content'] = str(message['content'])

    for retries in range(max_retries):
        started_at = time.time()

        # Call the API, retrying on any exception.
        for attempt in range(max_retries):
            try:
                completion = client.chat.completions.create(
                    model="glm-4-plus",
                    messages=history,
                    response_format={'type': 'json_object'},
                    stream=False,
                    top_p=0.7,
                    temperature=0.9,
                    max_tokens=4000
                )
            except Exception as e:
                print(f"API call failed (attempt {attempt + 1}/{max_retries}): {e}")
                if attempt == max_retries - 1:
                    raise  # out of attempts: surface the last error
                time.sleep(retry_delay)  # wait before the next attempt
            else:
                break  # success: leave the retry loop

        finished_at = time.time()

        payload = json.loads(completion.to_json())
        reply_text = payload['choices'][0]['message']['content']
        # Collected but not used any further in this function.
        token_usage = payload['usage']
        elapsed = finished_at - started_at

        # Parse the reply; an unparsable reply costs one outer attempt.
        try:
            content = json.loads(reply_text.strip('`json'))
        except Exception as e:
            print(f"Error processing JSON {query} (attempt {retries + 1}/{max_retries}): {e}")
            continue

        print("====TABLE REFLECTION AGENT COMPLETED====")
        history.append({'role': 'assistant', 'content': content})
        return history

    print(f"Failed to process question ==={query}=== after {max_retries} attempts.")
    print("====TABLE REFLECTION AGENT DOWN====")
    return chat_history
            
def sql_generator_reflection_agent(chat_history: list, ner: dict, table_finder: dict):
    """Ask the LLM to reflect on an empty SQL result and run its corrected query.

    A deep copy of ``chat_history`` is extended with a fixed "the result is
    empty, reflect and correct" user message and sent to ``glm-4-plus`` with a
    JSON-object response format. The reply must be JSON carrying a
    ``sql_query`` entry; that query is executed through ``fetch_data``. A
    reply that cannot be parsed (or lacks ``sql_query``) and a failing
    ``fetch_data`` call each consume one of four attempts.

    Args:
        chat_history: Conversation so far; it is deep-copied, not modified.
        ner: Not read by this function.
        table_finder: Not read by this function.

    Returns:
        The ``'data'`` entry of the ``fetch_data`` response for the executed
        query, or ``[]`` when no query could be executed.

    Raises:
        Exception: whatever the API client raised, once a single request has
            failed ``max_retries`` times in a row.
    """
    print("====SQL REFLECTION AGENT UP====")

    retry_delay = 1
    max_retries = 4

    query = "结果为空。在原本的思考路径中增加一步「反思与修正」（reflection and correct）。"

    history = copy.deepcopy(chat_history)
    history.append({"role": "user", "content": query})

    sql_res = []  # returned as-is when every attempt fails
    retries = 0

    while retries < max_retries:
        should_retry = False

        # An even length means the history currently ends with an assistant
        # reply (its SQL failed to execute); re-issue the user request.
        if len(history) % 2 == 0:
            history.append({"role": "user", "content": query})

        # Call the API, retrying on any exception.
        for attempt in range(max_retries):
            try:
                response = client.chat.completions.create(
                    model="glm-4-plus",
                    messages=history,
                    response_format = {
                        'type': 'json_object'
                    },
                    stream=False,
                    top_p=0.7,
                    temperature=0.9,
                    max_tokens=4000
                )
                break  # success: leave the retry loop
            except Exception as e:
                print(f"API call failed (attempt {attempt + 1}/{max_retries}): {e}")
                if attempt == max_retries - 1:
                    raise  # out of attempts: surface the last error
                time.sleep(retry_delay)  # wait before the next attempt

        response = json.loads(response.to_json())
        raw_content = response['choices'][0]['message']['content']

        # First guard: the reply must be JSON *and* contain the SQL text.
        # Reading 'sql_query' here turns a reply without that key into a
        # retry instead of an uncaught KeyError further down.
        try:
            content = json.loads(raw_content.strip('`json'))
            sql_query = content['sql_query']
            history.append({'role': 'assistant', 'content': str(content)})
        except Exception as e:
            print(f"Error processing JSON {query} (attempt {retries + 1}/{max_retries}): {e}")
            should_retry = True

        if not should_retry:
            print("====SQL REFLECTION AGENT COMPLETED====")

            print("====SQL QUERYING UP====")

            post_data = {
                'sql': sql_query,
                'limit': 1000
            }

            print(sql_query)

            # Second guard: executing the query.
            try:
                sql_res = fetch_data(post_data)['data']
                print("====SQL QUERYING COMPLETED====")
            except Exception as e:
                print("====SQL QUERYING DOWN====")
                print(f"Error fetching SQL data {query} (attempt {retries + 1}/{max_retries}): {e}")
                should_retry = True

        if should_retry:
            retries += 1
            if retries == max_retries:
                print(f"Failed to process question ==={query}=== after {max_retries} attempts.")
                print("====SQL REFLECTION AGENT DOWN====")
            continue

        # The query ran (its result may still be empty): stop here.
        break

    return sql_res

In [240]:
def main_workflow(data: dict):
    """
    Answer every question of one question team (only for one question team).

    For each question: follow-up questions (index >= 1) are first rewritten
    with ``rewrite_query_agent`` using the turns answered so far; then the
    table finder, the SQL generator and the answer generator run in sequence.
    While the answer contains "信息不足", the table selection is reflected on
    and the SQL/answer steps are repeated, at most ``max_retries`` times.

    Args:
        data: One question team; the keys read here are ``'tid'``, ``'team'``
            (list of dicts with a ``'question'`` entry), ``'market'`` and
            ``'ner'``.

    Returns:
        ``{'tid': ..., 'team': ...}`` where each team entry has gained an
        ``'answer'`` key. ``'team'`` is the same list object as
        ``data['team']``, so the answers are also visible on the input.
    """

    market = data['market']['market']
    ner = data['ner']['stage_1']
    chat_history = []
    max_retries = 3
    result = {}
    result['tid'] = data['tid']
    result['team'] = data['team']

    def answer_from_tables(query, table_finder_history):
        """Run SQL generation (when tables were found) and answer generation."""
        if table_finder_history:
            table_finder = table_finder_history[-1]['content']
            print(table_finder)
            sql_res = sql_generator_agent(query, ner, table_finder)
        else:
            # table_finder_agent returns [] when it gives up; indexing [-1]
            # would raise IndexError, so answer from an empty result instead.
            sql_res = []
        print(sql_res)
        answer = answer_generator_agent(query, ner, sql_res)
        print(answer)
        return answer

    for i in range(len(data['team'])):

        query = data['team'][i]['question']

        if i >= 1:
            query = rewrite_query_agent(query, chat_history)
            print(query)

        table_finder_history = table_finder_agent(query, market, ner)
        answer = answer_from_tables(query, table_finder_history)

        retries = 0
        while "信息不足" in answer and retries < max_retries:
            if table_finder_history:
                table_finder_history = table_finder_reflection_agent(table_finder_history, market, ner)
            else:
                # Nothing to reflect on yet: run the table finder again.
                table_finder_history = table_finder_agent(query, market, ner)
            answer = answer_from_tables(query, table_finder_history)
            retries += 1

        chat_history.append({"user": query, "assistant": answer})
        result['team'][i]['answer'] = answer

    return result

In [248]:
# Run the workflow on a single question team: the entry at position 38 of
# `questions` (written as 39-1; presumably the team with tid 39 after the sort
# by tid — TODO confirm the tids have no gaps), then display the result.
res = main_workflow(questions[39-1])
res

====TABLE FINDER AGENT UP====
====TABLE FINDER AGENT COMPLETED====
{'raw_question': '2019-05-31，当日收盘价最高的是？(以下都回答简称)', 'data_source_reasoning': [{'step_1': '提取问题的信息意图', 'restate_question_abstractly': '特定日期下，收盘价最高的股票及其简称。', 'information_intention': ['特定日期的收盘价最高的股票', '特定日期的收盘价最高的股票的简称']}, {'step_2': '定位其相关的所有表格（需求1：特定日期的收盘价排名）', 'list_all_related_tables': "查看 Database-Table Schema，A股收盘价数据与'日行情表' (AStockMarketQuotesDB.QT_DailyQuote) 表相关。该表收录A股当日行情数据，包括昨收盘、今开盘、最高价、最低价、收盘价、成交量、成交金额、成交笔数这些行情指标，因此适合用于查询特定日期的收盘价排名。", 'cot_thinking': 'QT_DailyQuote 表专注于当日数据，包含收盘价字段，适合用于筛选特定日期的股价。'}, {'step_3': '定位其相关的所有表格（需求2：A股股票简称）', 'list_all_related_tables': "查看 Database-Table Schema，A股股票简称与 '证券主表' (ConstantDB.SecuMain) 表相关。该表记录A股单个证券品种的简称、中英文名称、上市交易所、上市状态等基础信息，因此适合用于查询A股简称。"}, {'step_4': '结论', 'conclusion': '因此，要同时满足两个需求，首先需要查询 AStockMarketQuotesDB.QT_DailyQuote 表，筛选出 2019 年 5 月 31 日的A股收盘价数据，并找出收盘价最高的A股。然后，再查询 ConstantDB.SecuMain 表，获取该A股的简称。'}], 'data_source': [{'database': 'AStockMarketQuotesDB', 'table': 

KeyboardInterrupt: 

In [224]:
# Run the workflow for a hand-picked list of question teams and collect the
# results in `tmp`. Each number is used as a 1-based position into `questions`
# (presumably equal to the tid — TODO confirm the tids have no gaps).
tmp = []

for i in [66, 78, 85, 88, 90, 97]:
    res = main_workflow(questions[i-1])
    tmp.append(res)

====TABLE FINDER AGENT UP====
====TABLE FINDER AGENT COMPLETED====
{'raw_question': '2022年之间 哪些公司进行公司名称全称变更，公司代码是什么？', 'data_source_reasoning': [{'step_1': '提取问题的信息意图', 'restate_question_abstractly': '在特定年份内，哪些公司进行了名称全称变更，以及这些公司的代码。', 'information_intention': ['特定年份内公司名称全称变更的公司', '这些公司的代码']}, {'step_2': '定位其相关的所有表格（需求1：公司名称全称变更）', 'list_all_related_tables': "查看 Database-Table Schema，与公司名称变更相关的表格是 '公司名称更改状况' (AStockBasicInfoDB.LC_NameChange) 表。该表收录公司名称历次变更情况，包括中英文名称、中英文缩写名称、更改日期等内容，因此适用于查询特定年份内公司名称全称变更的情况。", 'cot_thinking': 'LC_NameChange 表详细记录了公司名称的变更情况，包括变更日期，非常适合用于筛选特定年份内的名称变更记录。'}, {'step_3': '定位其相关的所有表格（需求2：公司代码）', 'list_all_related_tables': "查看 Database-Table Schema，公司代码信息与 '证券主表' (ConstantDB.SecuMain) 表相关。该表记录A股单个证券品种的代码、简称、中英文名、上市交易所、上市状态等基础信息，因此适合用于查询公司代码。", 'cot_thinking': '虽然 LC_NameChange 表中可能包含公司代码信息，但为了确保数据的准确性和完整性，最好结合 SecuMain 表来获取公司代码。'}, {'step_4': '结论', 'conclusion': '综上所述，要回答这个问题，首先需要查询 AStockBasicInfoDB.LC_NameChange 表，筛选出2022年期间进行名称全称变更的公司记录。然后，再查询 ConstantDB.SecuM

{'tid': 'tttt----59',
 'team': [{'id': 'tttt----59----15-3-1',
   'question': '2022年之间 哪些公司进行公司名称全称变更，公司代码是什么？',
   'answer': '2022年之间进行公司名称全称变更的公司及其公司代码如下：\n\n1. 广东宏大控股集团股份有限公司，公司代码：75253\n2. 安井食品集团股份有限公司，公司代码：187102\n\n请注意，这些信息是基于已知数据提供的，具体变更时间和详细情况可能需要进一步核实。'},
  {'id': 'tttt----59----15-3-2',
   'question': '这些公司A股证券代码分别是什么？',
   'answer': '广东宏大控股集团股份有限公司的A股证券代码是002683，安井食品集团股份有限公司的A股证券代码是603345。'},
  {'id': 'tttt----59----15-3-3',
   'question': '这些公司的证券内部编码是多少？',
   'answer': '广东宏大控股集团股份有限公司的A股证券代码是002683，安井食品集团股份有限公司的A股证券代码是603345。'}]}

In [238]:
# Collect the question teams whose tid number (the part after the last '-')
# is listed in `missed`, keeping the order of `questions`.
missed = [66, 78, 85, 88, 90, 97]
unsolved = []

for q in questions:
    tid_number = q['tid'].split('-')[-1]
    for i in missed:
        if tid_number == str(i):
            unsolved.append(q)

unsolved

[{'tid': 'tttt----66',
  'team': [{'id': 'tttt----66----9-3-1',
    'question': '上海家化2020年年度研发投入合计是多少元？研发人员数量为多少人？（合并报表调整后的，金额保留2位小数）'},
   {'id': 'tttt----66----9-3-2', 'question': '费用化研发占比为？资本化研发占比为？'},
   {'id': 'tttt----66----9-3-3', 'question': '上一年度研发投入最高的三家公司是？（回答中文简称）'}],
  'ner': {'stage_1': {'reasoning_process_cot': "分析当前查询内容，'上海家化' 是一个上市公司名称，因为问题中提到了该公司的年度研发投入和研发人员数量，这通常是针对上市公司的财务和运营数据。查询中没有提到股票代码、基金名称、基金公司名称或行业名称，因此只需识别'上海家化'为上市公司名称。",
    'result': [{'上市公司名称': '上海家化'}],
    'sql': {'上市公司名称:上海家化': [{'query': "SELECT * FROM ConstantDB.SecuMain WHERE '上海家化' IN (ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr, ChiSpelling)",
       'result': [{'ID': 1768742157516,
         'InnerCode': 1437,
         'CompanyCode': 1303,
         'SecuCode': '600315',
         'ChiName': '上海家化联合股份有限公司',
         'ChiNameAbbr': '上海家化',
         'EngName': 'Shanghai Jahwa United Co., Ltd.',
         'EngNameAbbr': 'Shanghai Jahwa',
         'SecuAbbr': '上海家化',
         'ChiSpelling': 'SHJH'