# Init

In [13]:
import copy
import json
import os
import re
import sys
import time

from tqdm import tqdm

cwd = os.getcwd()
os.chdir(cwd)
sys.path.append('tools')

import chat
import parse_data
import embedding

In [14]:
# Version tag interpolated into the prompt file name and into the names of the
# result files built in later cells.
version = 'v2.7.1'

In [39]:
# Empty system prompt passed to chat.create_message; the instructions come
# from the template file loaded below.
system_prompt = ""

prompt_dir = os.path.join(cwd, 'prompt')
fname = f'table_finder-stage_1-{version}.md'
prompt_fpath = os.path.join(prompt_dir, fname)

# Decode explicitly as UTF-8 (as the result files are read further down) so the
# template does not depend on the platform's default encoding.
with open(prompt_fpath, 'r', encoding='utf-8') as f:
    prompt_template = f.read()


def parse_database_and_table(query: str) -> dict:
    """
    Extract the database and table that follow the first ``FROM`` in a SQL query.

    Args:
    - query (str): SQL text expected to contain ``FROM <database>.<table>``.

    Returns:
    - dict: ``{'database': ..., 'table': ...}`` for the first match, or an
      empty dict when the query has no ``FROM <database>.<table>`` reference.
    """

    found = re.search(r'FROM\s+([a-zA-Z0-9_]+)\.([a-zA-Z0-9_]+)', query, re.IGNORECASE)
    if found is None:
        return {}

    db_name, table_name = found.groups()
    return {'database': db_name, 'table': table_name}


def make_prompt(query: str, ner: dict) -> str:
    """
    Build the stage-1 table-finder prompt for one question.

    The prompt is ``prompt_template`` followed by the question. When the NER
    stage recognised an entity, a fenced JSON block describing that entity
    (plus the matching lookup rows) is appended.

    Args:
    - query (str): The user question.
    - ner (dict): NER stage output. Reads ``ner['result']`` (list of entity
      dicts; only the first one is used) and ``ner['sql']`` (mapping whose
      values are lists of ``{'query': ..., 'result': ...}`` lookups).

    Returns:
    - str: The complete prompt text.
    """

    prompt = prompt_template + query

    # Nothing was recognised: send the bare question.
    if not ner.get('result'):
        return prompt

    # Assume there is only one NER result.
    ner_content = dict(ner['result'][0])

    # Tables that identify a non-default market; any other table maps to 'CN'.
    market_by_table = {'US_SecuMain': 'US', 'HK_SecuMain': 'HK'}

    for lookups in ner.get('sql', {}).values():
        for lookup in lookups:
            # Skip lookups that returned no rows.
            if not lookup['result']:
                continue

            # Add database and table parsed from the lookup SQL.
            ner_content.update(parse_database_and_table(lookup['query']))

            # The SQL may not contain a parsable `FROM db.table`; only derive
            # the market when a table name is actually known.
            table = ner_content.get('table')
            if table is not None:
                ner_content['market'] = market_by_table.get(table, 'CN')

            ner_content['data_from_table'] = lookup['result']

    # Append the NER result as a fenced JSON block.
    ner_json = json.dumps(ner_content, ensure_ascii=False, indent=2)
    prompt += f"\n\n### **Name Entity Recognition Result**\n```json\n{ner_json}\n```"

    return prompt

In [41]:
# Preview the prompt built for one sample question.
# NOTE(review): `questions` is loaded in a cell further down the notebook, so
# this cell fails on a fresh top-to-bottom run until that cell has been run.
print(make_prompt(questions[72]['team'][0]['question'], questions[72]['ner']['stage_1']))

软通动力在2019年报酬总额和领取报酬的管理层人数是多少？

### **Name Entity Recognition Result**
```json
{
  "上市公司名称": "软通动力",
  "database": "ConstantDB",
  "table": "US_SecuMain",
  "market": "US",
  "data_from_table": [
    {
      "ID": 680237631934,
      "InnerCode": 7003343,
      "SecuCode": "ISS",
      "SecuAbbr": "软通动力",
      "ChiSpelling": "RTDL",
      "SecuCategory": 75,
      "SecuMarket": 78,
      "ListedSector": null,
      "ListedDate": "2010-12-14 12:00:00.000",
      "ListedState": 5,
      "ISIN": null,
      "CompanyCode": 102335737,
      "UpdateTime": "2022-12-30 03:54:01.817",
      "JSID": 694795441367,
      "DelistingDate": "2014-09-02 12:00:00.000",
      "InsertTime": "2021-09-07 01:36:57.043",
      "EngName": "iSoftStone Holdings Ltd. Sponsored ADR",
      "ChiName": "软通动力信息技术（集团）有限公司"
    }
  ]
}
```


In [16]:
# Load the question list (including the stage-1 NER output read by make_prompt).
question_path = os.path.join(cwd, 'answer_tmp', 'stage_1-glm_4_plus-ner-sql-HF.json')

questions = parse_data.read_json(question_path)

# GLM-4-Plus

In [4]:
# Model tag used in the names of the result files written by the cells below.
model = 'glm_4_plus'

## Test

In [5]:
# Smoke test: send the prompt for a single question and inspect the reply.
query = make_prompt(questions[12]['team'][0]['question'], questions[12]['ner']['stage_1'])

# Start from an empty conversation history.
history = []

# Time the request so the latency can be stored next to the answer.
start_time = time.time()
message = chat.create_message(query, history=history, system_prompt=system_prompt, temperature=0.7, top_p=0.9, response_format='text')
end_time = time.time()

execution_time = end_time - start_time
# NOTE(review): the second argument presumably toggles printing (this cell shows
# the usage and the content as output, and the batch loop passes False) —
# confirm in the `chat` module.
usage = chat.get_token_usage(message, True)
content = chat.get_content(message, True)
history = chat.build_history(history, message=message)

{'prompt_tokens': 9364, 'completion_tokens': 544, 'total_tokens': 9908}
```json
{
    "raw_question": "今天是2020年10月27日，当日收盘价第3高的港股是？(以下都回答简称)",
    "data_source_reasoning": "现在进行逐步分析： 1. **解析输出格式要求**：要求返回公司简称。 2. **意图识别**： 问题关注的是特定日期的'收盘价第3高'的港股，并且要求'回答简称'。 3. **返回数据识别**： 收盘价、简称。 4. **关联信息**： 查询到收盘价第3高的港股后，需要获取该港股的简称。  5. **定位数据**： 查看Database-Table Schema，港股数据仅与'港股数据库' (HKStockDB) 中的'港股行情表现' (CS_HKStockPerformance) 表相关。CS_HKStockPerformance包含港股从最近一个交易日往前追溯一段时期的行情表现信息，包括收盘价，因此适合用于筛选收盘价第3高的港股。同时，还需要回答'简称'。查看Database-Table Schema，港股简称与'常量库'（ConstantDB）中的'港股证券主表'（HK_SecuMain）的强相关，和'港股数据库' (HKStockDB) 中的'港股公司概况' (HK_StockArchives)弱相关。根据给定的表描述，HK_SecuMain表记录港股单个证券品种的简称、中英文名称、上市交易所、上市状态等基础信息； 而HK_StockArchives表 收录港股上市公司的基础信息，包括名称、成立日期、注册地点、注册资本、公司业务、所属行业分类、主席、公司秘书、联系方式等信息，并没有明确提及简称。相较之下，HK_SecuMain更适合用于查询简称信息。 6. **结论**： 因此，要回答这个问题，我们需要先查询 HKStockDB 数据库中的 CS_HKStockPerformance 表，筛选出2020年10月27日的数据，并找出收盘价第3高的港股。然后，再查询 ConstantDB 数据库中的 HK_SecuMain 表，获取该港股的简称。",
    "data_source": [
        {"ques

In [6]:
# Work on a copy: the usage dicts are reset below, and doing that on
# `questions[12]` itself would wipe them for the batch cell, which processes
# the same list again.
t = copy.deepcopy(questions[12])

# Remove an optional Markdown code fence around the reply. `str.strip('`json')`
# treats its argument as a set of characters, so it leaves the fence in place
# whenever whitespace follows the closing backticks.
payload = re.sub(r'^\s*```(?:json)?|```\s*$', '', content)

t['table_finder'] = {'stage_1': [json.loads(payload)]}
t['token_usage'] = {}
t['token_usage']['table_finder-stage_1'] = [usage]
t['time_usage'] = {}
t['time_usage']['table_finder-stage_1'] = [f"{execution_time:.2f}s"]
t = [t]

saved_path = os.path.join(cwd, 'answer_tmp', f'stage_1-{model}-table_finder-test-{version}.json')
parse_data.write_json(t, saved_path)

## ALL

In [7]:
answers = []
# tids whose request or JSON parsing failed, kept so they can be retried.
failed_tids = []

for question in tqdm(questions[:]):
    try:
        # Only the first question of each entry is sent to the model.
        query = make_prompt(question['team'][0]['question'], question['ner']['stage_1'])

        history = []

        start_time = time.time()
        message = chat.create_message(query, history=history, system_prompt=system_prompt, temperature=0.7, top_p=0.9, response_format='text')
        end_time = time.time()

        execution_time = end_time - start_time
        usage = chat.get_token_usage(message, False)
        content = chat.get_content(message, False)

        # Remove an optional Markdown code fence; `str.strip('`json')` strips a
        # character set and misses fences followed by whitespace.
        payload = re.sub(r'^\s*```(?:json)?|```\s*$', '', content)

        # Copy so the shared `questions` entries are not modified by this run.
        res = copy.deepcopy(question)
        res['table_finder'] = {'stage_1': [json.loads(payload)]}
        res.setdefault('token_usage', {})['table_finder-stage_1'] = [usage]
        res.setdefault('time_usage', {})['table_finder-stage_1'] = [f"{execution_time:.2f}s"]

        answers.append(res)
    except Exception as exc:
        # Keep going on a bad reply, but report what went wrong. A bare
        # `except:` would also trap KeyboardInterrupt and hide the cause.
        failed_tids.append(question['tid'])
        print(f"{question['tid']}: {exc!r}")

saved_path = os.path.join(cwd, 'answer_tmp', f'stage_1-{model}-table_finder-{version}.json')
parse_data.write_json(answers, saved_path)

  0%|          | 0/101 [00:00<?, ?it/s]

100%|██████████| 101/101 [33:44<00:00, 20.05s/it]


# Deepseek-chat

In [8]:
# Model tag used in the names of the result files written by the cells below.
model = 'deepseek_v3'

In [9]:
# Read the DeepSeek API key from the environment instead of committing it to
# the notebook. The key that used to be hardcoded here must be treated as
# leaked and revoked.
deepseek_api = os.environ.get('DEEPSEEK_API_KEY')
if not deepseek_api:
    raise RuntimeError("Set the DEEPSEEK_API_KEY environment variable before running the DeepSeek cells.")

## Test

In [10]:
from openai import OpenAI

# Smoke test: send the prompt for a single question to DeepSeek.
query = make_prompt(questions[0]['team'][0]['question'], questions[0]['ner']['stage_1'])

client = OpenAI(api_key=deepseek_api, base_url="https://api.deepseek.com")

# NOTE(review): the GLM cells use temperature=0.7 / top_p=0.9, while the values
# here are the other way round — confirm this is intentional.
start_time = time.time()
response = client.chat.completions.create(
    model="deepseek-chat",
    messages=[
        {"role": "user", "content": query},
    ],
    stream=False,
    top_p=0.7,
    temperature=0.9
)
end_time = time.time()

# Convert the SDK response object into a plain dict.
response = json.loads(response.to_json())
content = response['choices'][0]['message']['content']

# Remove an optional Markdown code fence around the reply. `str.strip('`json')`
# treats its argument as a set of characters, so it leaves the fence in place
# whenever whitespace follows the closing backticks.
content = re.sub(r'^\s*```(?:json)?|```\s*$', '', content)
usage = response['usage']
execution_time = end_time - start_time

In [11]:
# Work on a copy: the usage dicts are reset below, and doing that on
# `questions[0]` itself would wipe them for the batch cell, which processes
# the same list again.
t = copy.deepcopy(questions[0])

# Remove an optional Markdown code fence around the reply (a no-op when the
# fence is already gone). `str.strip('`json')` strips a character set and
# misses fences followed by whitespace.
payload = re.sub(r'^\s*```(?:json)?|```\s*$', '', content)

t['table_finder'] = {'stage_1': [json.loads(payload)]}
t['token_usage'] = {}
t['token_usage']['table_finder-stage_1'] = [usage]
t['time_usage'] = {}
t['time_usage']['table_finder-stage_1'] = [f"{execution_time:.2f}s"]
t = [t]

saved_path = os.path.join(cwd, 'answer_tmp', f'stage_1-{model}-table_finder-test-{version}.json')
parse_data.write_json(t, saved_path)

## ALL

In [12]:
answers = []
# tids whose request or JSON parsing failed, kept so they can be retried.
failed_tids = []

for question in tqdm(questions[:]):
    try:
        # Only the first question of each entry is sent to the model.
        query = make_prompt(question['team'][0]['question'], question['ner']['stage_1'])

        # NOTE(review): the GLM cells use temperature=0.7 / top_p=0.9, while the
        # values here are the other way round — confirm this is intentional.
        start_time = time.time()
        response = client.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "user", "content": query},
            ],
            stream=False,
            top_p=0.7,
            temperature=0.9
        )
        end_time = time.time()

        # Convert the SDK response object into a plain dict.
        response = json.loads(response.to_json())
        content = response['choices'][0]['message']['content']
        # Remove an optional Markdown code fence; `str.strip('`json')` strips a
        # character set and misses fences followed by whitespace.
        content = re.sub(r'^\s*```(?:json)?|```\s*$', '', content)
        usage = response['usage']
        execution_time = end_time - start_time

        # Copy so the shared `questions` entries are not modified by this run.
        res = copy.deepcopy(question)
        res['table_finder'] = {'stage_1': [json.loads(content)]}
        res.setdefault('token_usage', {})['table_finder-stage_1'] = [usage]
        res.setdefault('time_usage', {})['table_finder-stage_1'] = [f"{execution_time:.2f}s"]

        answers.append(res)
    except Exception as exc:
        # One failed request or unparsable reply must not discard the answers
        # collected so far; record it and continue.
        failed_tids.append(question['tid'])
        print(f"{question['tid']}: {exc!r}")

saved_path = os.path.join(cwd, 'answer_tmp', f'stage_1-{model}-table_finder-{version}.json')
parse_data.write_json(answers, saved_path)

  0%|          | 0/101 [00:00<?, ?it/s]

100%|██████████| 101/101 [15:05<00:00,  8.96s/it]


# Compare Results

Compare the stage 1 results of the models, find the questions where they differ, and use those differences to determine the correct answer.

In [37]:
import os
import json
import copy

# Path to the folder containing model answer files
dir_path = os.path.join(cwd, 'answer_tmp')

# List of model names
models = ['deepseek_v3', 'glm_4_plus']

# Create a dictionary of file paths for each model's JSON file
model_files = {name: os.path.join(dir_path, f"stage_1-{name}-table_finder-{version}.json") for name in models}

# Dictionary to store the data of each model
model_data = {}

# Read the JSON data for each model.
# Loop variables are called `model_name` so the global `model` tag used by the
# saving cells is not overwritten by this cell.
for model_name, file_path in model_files.items():
    with open(file_path, 'r', encoding='utf-8') as f:
        model_data[model_name] = json.load(f)

# Sort each model's entries by question id
for key in model_data:
    model_data[key] = sorted(model_data[key], key=lambda x: x['tid'])

# Index every model's entries by normalised question id (spaces removed), so
# the comparison and the printing below use the same key.
entries_by_tid = {
    model_name: {entry['tid'].replace(' ', ''): entry for entry in model_data[model_name]}
    for model_name in models
}

# data_source of each question id across the different models
data_sources = {}
for model_name in models:
    for tid, entry in entries_by_tid[model_name].items():
        source = copy.deepcopy(entry['table_finder']['stage_1'][0]['data_source'])
        data_sources.setdefault(tid, {})[model_name] = source

# Width used to left-align the model names in the report
max_model_length = max(len(model_name) for model_name in models)

# Report every question whose data_source differs between models, or that is
# missing for some model (a batch run can skip questions that failed).
for question_id, sources in data_sources.items():
    present = [sources[model_name] for model_name in models if model_name in sources]
    complete = len(present) == len(models)
    if complete and all(result == present[0] for result in present[1:]):
        continue

    print(f"Question ID: {question_id}")

    for model_name in models:
        print(f"{model_name.ljust(max_model_length)}:")
        print('```')
        entry = entries_by_tid[model_name].get(question_id)
        if entry is None:
            print('<no result for this model>')
        else:
            print(json.dumps(entry['table_finder']['stage_1'], indent=2, ensure_ascii=False))
        print('```')

    print("")

Question ID: tttt----100
deepseek_v3:
```
[
  {
    "raw_question": "横店东磁在2019年1月份进行了多少次投资者关系活动？这些活动中参与机构最多的一次活动发生在哪天(xxxx-xx-xx的格式)，有多少家机构参与？",
    "data_source_reasoning": [
      {
        "step_1": "解析输出格式要求",
        "result": "问题要求回答投资者关系活动的次数、参与机构最多的一次活动的日期（格式为xxxx-xx-xx），以及参与机构的数量。"
      },
      {
        "step_2": "抽象化问题的信息意图",
        "cot_thinking": "用户的问题是询问某A股公司在特定时间段内进行的投资者关系活动的次数，以及这些活动中参与机构最多的一次活动的日期和参与机构数量。",
        "result": [
          "A股公司的投资者关系活动次数",
          "参与机构最多的一次活动的日期",
          "参与机构数量"
        ]
      },
      {
        "step_3": "定位所有的相关表格（需求1：投资者关系活动次数）",
        "cot_thinking": "查看 Database-Table Schema，A股公司的投资者关系活动数据与'投资者关系活动' (LC_InvestorRa) 表相关。该表记录各调研机构对上市公司调研的详情，包括调研日期、参与单位、调研人员、调研主要内容等信息。因此，LC_InvestorRa 表适合用于查询横店东磁在2019年1月份的投资者关系活动次数。"
      },
      {
        "step_4": "定位所有的相关表格（需求2：参与机构最多的一次活动的日期和参与机构数量）",
        "cot_thinking": "查看 Database-Table Schema，A股公司的投资者关系活动参与机构数据与'投资者关系活动调研明细' (LC_InvestorDetail) 表相关。该表记录参与上市公司调研活动的调研机构明细数据，包括

In [36]:
# Show only the first record; displaying the whole list floods the output.
model_data['deepseek_v3'][:1]

[{'tid': 'tttt----1',
  'team': [{'id': 'tttt----1----1-1-1',
    'question': '600872的全称、A股简称、法人、法律顾问、会计师事务所及董秘是？'},
   {'id': 'tttt----1----1-1-2',
    'question': '该公司实控人是否发生改变？如果发生变化，什么时候变成了谁？是哪国人？是否有永久境外居留权？（回答时间用XXXX-XX-XX）'},
   {'id': 'tttt----1----1-1-3', 'question': '在实控人发生变化的当年股权发生了几次转让？'}],
  'ner': {'stage_1': {'reasoning_process_cot': "根据查询内容，'600872' 是一个股票代码，指向了一个上市公司，因此应该识别为一个代码。而'全称'、'A股简称'、'法人'、'法律顾问'、'会计师事务所'及'董秘'等词汇虽然出现在查询中，但它们并不构成独立的实体，而是与'600872'相关的属性或角色，因此不需要作为实体识别。所以，我们只需识别'600872'作为代码实体。",
    'result': [{'代码': '600872'}],
    'sql': {'代码:600872': [{'query': 'SELECT * FROM ConstantDB.SecuMain WHERE 600872 IN (InnerCode, CompanyCode, SecuCode, ISIN, JSID)',
       'result': [{'ID': 315934536696,
         'InnerCode': 2120,
         'CompanyCode': 1805,
         'SecuCode': '600872',
         'ChiName': '中炬高新技术实业(集团)股份有限公司',
         'ChiNameAbbr': '中炬高新',
         'EngName': 'Jonjee Hi-Tech Industrial And Commercial Holding Co.,Ltd',
         'EngNameAbbr': 'JONJ

In [33]:
# Spot-check one question: print both models' data sources when they disagree.
idx = 'tttt----100'

per_model = data_sources[idx]
if per_model['deepseek_v3'] != per_model['glm_4_plus']:
    print(per_model)

{'deepseek_v3': [{'database': 'AStockEventsDB', 'table': 'LC_InvestorRa'}, {'database': 'AStockEventsDB', 'table': 'LC_InvestorDetail'}], 'glm_4_plus': [{'database': 'AStockEventsDB', 'table': 'LC_InvestorRa'}]}
