In [5]:
#读取cosql真实数据
import json
from tqdm import tqdm

def determine_user_type(types):
    """Collapse the labels of a user turn into one coarse type.

    The rules are tried in order and the first one with any trigger label
    present in ``types`` wins:

    * 'INFORM_SQL' / 'INFER_SQL'        -> 'INFORM_SQL'
    * 'THANK_YOU' / 'GOOD_BYE'          -> 'IMPROPER'
    * 'CANNOT_ANSWER' / 'NOT_RELATED'   -> 'CANNOT_ANSWER'
    * 'CANNOT_UNDERSTAND' / 'AMBIGUOUS' -> 'AMBIGUOUS'

    When no rule matches, the first label is returned unchanged, or
    'UNKNOWN' if ``types`` is empty.
    """
    # Ordered (trigger labels, resulting type) pairs; order encodes priority.
    priority_rules = (
        (('INFORM_SQL', 'INFER_SQL'), 'INFORM_SQL'),
        (('THANK_YOU', 'GOOD_BYE'), 'IMPROPER'),
        (('CANNOT_ANSWER', 'NOT_RELATED'), 'CANNOT_ANSWER'),
        (('CANNOT_UNDERSTAND', 'AMBIGUOUS'), 'AMBIGUOUS'),
    )
    for triggers, category in priority_rules:
        if any(label in types for label in triggers):
            return category

    # No rule matched: fall back to the first raw label, if there is one.
    if not types:
        return 'UNKNOWN'
    return types[0]
        

def determine_nonuser_type(types):
    """Collapse the labels of a non-user turn into one type.

    Priority order:

    1. 'CLARIFY' if present in ``types``.
    2. 'CONFIRM_SQL' if present in ``types``.
    3. Otherwise the first label, or 'UNKNOWN' when ``types`` is empty.

    The original check looked for the misspelled label 'CLARIFYL', so the
    clarify branch never fired for turns actually labelled 'CLARIFY' unless
    that label happened to come first. The misspelling is still accepted so
    any input the old code mapped to 'CLARIFY' keeps mapping there.
    """
    if 'CLARIFY' in types or 'CLARIFYL' in types:
        return 'CLARIFY'
    if 'CONFIRM_SQL' in types:
        return 'CONFIRM_SQL'
    # No priority label found: keep the first raw label, if there is one.
    return types[0] if types else 'UNKNOWN'

# Load the raw dialogue JSON file.
with open('datasets/cosql_dataset/cosql_all_info_dialogs.json', 'r', encoding='utf-8') as src_file:
    data = json.load(src_file)


def _new_system_turn():
    """Return a fresh, empty accumulator for one merged non-user turn."""
    return {'isUser': False, 'text': '', 'query': '', 'result': [], 'type': ''}


# Reduce every dialogue to its db id plus a simplified list of turns, merging
# consecutive non-user turns into a single entry.
extracted_data = []
for dialog in tqdm(data.values()):
    merged_turns = []
    pending = _new_system_turn()  # non-user content gathered since the last user turn

    for raw_turn in dialog.get('turns', []):
        labels = raw_turn.get('label', [])

        if not raw_turn.get('isUser'):
            # Non-user turn: fold it into the pending accumulator.
            system_type = determine_nonuser_type(labels)
            if raw_turn.get('isSql'):
                pending['result'].append(raw_turn.get('sql_result'))
                # Assumes a single query per merged non-user turn; a later
                # SQL turn overwrites the stored query.
                pending['query'] = raw_turn.get('rawSql')
            elif 'label' in raw_turn:
                # Only non-SQL turns that carry a label contribute text/type.
                pending['text'] = raw_turn.get('text')
                pending['type'] = system_type
            continue

        # User turn: first flush whatever non-user content was accumulated
        # (needs at least a query, a result, or some text), then record it.
        user_type = determine_user_type(labels)
        if pending['query'] or pending['result'] or pending['text']:
            merged_turns.append(pending)
            pending = _new_system_turn()
        merged_turns.append({
            'isUser': True,
            'text': raw_turn.get('text'),
            'type': user_type,
        })

    # Make sure trailing non-user content at the end of the dialogue is kept.
    has_typed_text = pending['text'] and pending['type']
    if pending['query'] or pending['result'] or has_typed_text:
        merged_turns.append(pending)

    # One output record per dialogue: database id plus the processed turns.
    extracted_data.append({
        'db_name': dialog.get('db_id'),
        'turns': merged_turns,
    })

# Save the extracted data to a new JSON file.
with open('cosql_real.json', 'w', encoding='utf-8') as out_file:
    json.dump(extracted_data, out_file, ensure_ascii=False, indent=4)


100%|███████████████████████████████████████████████████████████████████████████████████████████| 2458/2458 [00:00<00:00, 65894.56it/s]
