In [17]:
import json
from tqdm import tqdm

def extract_dialog_info1(json_file, output_file):
    """
    Flatten raw dialogs into a list of simplified dialog records and save them.

    Each value of the top-level JSON object in ``json_file`` is treated as one
    dialog. For every dialog a record is produced with the fixed placeholders
    ``goal_question="originCosql"`` and ``evidence="nan"``, the dialog's
    ``db_id`` (or ``"N/A"`` when absent) as ``db_name``, and one entry per turn
    mapping ``isUser`` -> ``isuser``, ``text`` -> ``text``,
    ``label`` -> ``type`` (default ``""``) and ``rawSql`` -> ``query``
    (default ``""``).

    Args:
        json_file: Path of the input JSON file (a JSON object whose values
            are dialogs, each with a ``turns`` list).
        output_file: Path of the output JSON file (written as a JSON list).

    Raises:
        KeyError: if a dialog has no ``turns`` or a turn lacks ``isUser`` / ``text``.
    """
    # JSON is UTF-8 by specification; be explicit so the platform default
    # encoding (e.g. GBK/cp936 on a Chinese Windows locale) is never used.
    with open(json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    new_data = []
    for dialog in data.values():
        new_dialog = {
            "goal_question": "originCosql",
            "evidence": "nan",
            "db_name": dialog.get("db_id", "N/A"),
            "turns": [
                {
                    "isuser": turn["isUser"],
                    "text": turn["text"],
                    "type": turn.get("label", ""),
                    "query": turn.get("rawSql", ""),
                }
                for turn in dialog["turns"]
            ],
        }
        new_data.append(new_dialog)

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(new_data, f, indent=4)

# Example usage, step 1: flatten the raw dialogs into extracted_data.json.
input_file, output_file = "cosql_all_info_dialogs.json", "extracted_data.json"
extract_dialog_info1(json_file=input_file, output_file=output_file)

print("finished1")

def extract_dialog_info2(json_file, output_file):
    """
    Collapse flattened dialogs into alternating user / system turns and save them.

    For each dialog in the input list:

    * a user turn keeps ``isuser`` and ``text``; its ``type`` becomes
      ``value[0]`` of the original ``type`` value (the first element when it
      is a list, the first character when it is a string), or ``""`` when the
      value is empty or missing;
    * a system turn that directly follows a user turn keeps ``isuser`` and
      ``query`` (default ``""``). Its ``text`` is taken from the next turn when
      that next turn is also a system turn (presumably the natural-language
      reply that follows the SQL turn -- TODO confirm against the data),
      otherwise from the turn itself;
    * any other system turn (first turn of a dialog, or a system turn preceded
      by another system turn) is dropped.

    Args:
        json_file: Path of the input JSON file (a JSON list of dialogs).
        output_file: Path of the output JSON file. May be the same path as
            ``json_file``: the input is fully loaded before anything is written.

    Raises:
        KeyError: if a dialog lacks ``turns`` or a kept turn lacks ``isuser`` /
            ``text``.
    """
    with open(json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    new_data = []
    for dialog in data:
        turns = dialog["turns"]
        new_turns = []
        for i, turn in enumerate(turns):
            if turn["isuser"]:
                value = turn.get("type", "")
                new_turns.append({
                    "isuser": turn["isuser"],
                    "text": turn["text"],
                    "type": value[0] if value else "",
                })
            elif i > 0 and turns[i - 1]["isuser"]:
                # Borrow the following turn's text only when it is a system
                # turn; a following *user* turn is the next question and must
                # not be stored as this turn's reply.
                next_turn = turns[i + 1] if i + 1 < len(turns) else None
                if next_turn is not None and not next_turn["isuser"]:
                    text = next_turn.get("text", "")
                else:
                    text = turn["text"]
                new_turns.append({
                    "isuser": turn["isuser"],
                    "text": text,
                    "query": turn.get("query", ""),
                })
            # Otherwise: a system turn not preceded by a user turn (e.g. the
            # second of two consecutive system turns) is intentionally dropped;
            # its text was already merged into the previous system turn.

        new_data.append({
            "goal_question": "originCosql",
            "evidence": "nan",
            "db_name": dialog.get("db_name", "N/A"),
            "turns": new_turns,
        })

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(new_data, f, indent=4)

# Example usage, step 2: reshape the turns of extracted_data.json.
# NOTE: input and output are the same file, so this overwrites step 1's result.
input_file = output_file = "extracted_data.json"
extract_dialog_info2(json_file=input_file, output_file=output_file)
print("finished2")

def filter_by_average_length(data, output_file, threshold=11.5):
    """
    Keep dialogs whose system responses average at least ``threshold`` words.

    A system response is a turn whose ``isuser`` is falsy. Words are counted
    with ``str.split()``; a dialog without system responses averages 0.

    Args:
        data: Path of the input JSON file (a JSON list of dialogs, each with a
            ``turns`` list).
        output_file: Path of the output JSON file; may equal ``data`` because
            the input is fully loaded before writing.
        threshold: Minimum (inclusive) average word count of the system
            responses for a dialog to be kept. Defaults to 11.5.

    Returns:
        list: The kept dialogs, in input order (also written to ``output_file``).
    """
    with open(data, 'r', encoding='utf-8') as f:
        dialogs = json.load(f)

    results = []
    for dialog in tqdm(dialogs):
        system_responses = [turn['text'] for turn in dialog['turns'] if not turn['isuser']]

        # Average word count per system response (split() on whitespace).
        total_words = sum(len(text.split()) for text in system_responses)
        avg_word_count = total_words / len(system_responses) if system_responses else 0

        if avg_word_count >= threshold:
            results.append(dialog)

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4)

    return results

# Step 3: keep only dialogs with long enough system replies (overwrites the file).
input_file = output_file = "extracted_data.json"
filter_by_average_length(data=input_file, output_file=output_file)

print("finished3")

finished1
finished2


100%|█████████████████████████████████████████████████████████████████████████████| 2458/2458 [00:00<00:00, 102426.13it/s]

finished3





In [18]:
import json

def merge_json_files(file1, file2, output_file):
    """Concatenate the JSON lists stored in two files and write the result.

    Args:
        file1: Path of the first JSON file; its items come first.
        file2: Path of the second JSON file; its items are appended after.
        output_file: Path of the output JSON file.

    Raises:
        TypeError: if either file does not contain a top-level JSON list.
    """
    merged = []
    for path in (file1, file2):
        with open(path, 'r', encoding='utf-8') as f:
            items = json.load(f)
        # Fail loudly: `dict + list` would raise an opaque operand error and
        # `str + str` would silently concatenate text instead of records.
        if not isinstance(items, list):
            raise TypeError(
                f"{path!r} must contain a JSON list, got {type(items).__name__}"
            )
        merged.extend(items)

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(merged, f, indent=4)

# Merge the filtered dialogs with the augmented dataset into combined_data.json.
merge_json_files(file1='extracted_data.json', file2='全部增强数据.json', output_file='combined_data.json')


In [24]:
import random

def filter_by_average_length(data, output_file, high_threshold=15, low_threshold=5,
                             keep_probability=0.5, seed=None):
    """
    Filter dialogs by the average word count of their system responses.

    Any turn whose ``isuser`` is not equal to ``True`` (including turns with no
    ``isuser`` key) counts as a system response. Words are counted with
    ``str.split()``; a dialog without system responses averages 0. Then:

    * average > ``high_threshold``: always kept;
    * ``low_threshold`` <= average <= ``high_threshold``: kept with
      probability ``keep_probability``;
    * average < ``low_threshold``: dropped.

    Args:
        data: Path of the input JSON file (a JSON list of dialogs, each with a
            ``turns`` list).
        output_file: Path of the output JSON file.
        high_threshold: Averages strictly above this are always kept (default 15).
        low_threshold: Averages below this are always dropped (default 5).
        keep_probability: Chance of keeping a dialog in the middle band
            (default 0.5).
        seed: If given, a private ``random.Random(seed)`` drives the sampling
            so the output is reproducible. If ``None`` (default), the
            module-level ``random`` functions are used, so a global
            ``random.seed(...)`` still applies.

    Returns:
        list: The kept dialogs, in input order (also written to ``output_file``).
    """
    with open(data, 'r', encoding='utf-8') as f:
        dialogs = json.load(f)

    rng = random if seed is None else random.Random(seed)

    results = []
    for dialog in tqdm(dialogs):
        # `!= True` (not plain falsiness) on purpose: a missing `isuser` key
        # is treated as a system turn.
        system_responses = [
            turn['text'] for turn in dialog['turns']
            if turn.get('isuser', "") != True  # noqa: E712
        ]

        total_words = sum(len(text.split()) for text in system_responses)
        avg_word_count = total_words / len(system_responses) if system_responses else 0

        if avg_word_count > high_threshold:
            results.append(dialog)
        elif avg_word_count >= low_threshold:
            # Borderline length: keep only a random subset of these dialogs.
            if rng.random() < keep_probability:
                results.append(dialog)

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4)

    return results

# Step 4: length-based (partly random) filtering of the merged data.
input_file, output_file = "combined_data.json", "combined_data1.json"
filter_by_average_length(data=input_file, output_file=output_file)

print("finished4")

100%|████████████████████████████████████████████████████████████████████████████| 11553/11553 [00:00<00:00, 65332.60it/s]


finished4


In [29]:
import json

# Old dialog-act style labels -> coarse answerability categories.
# Labels not listed here fall back to 'improper'.
_OLD_TYPE_MAPPING = {
    'INFORM_SQL': 'answerable',
    'INFER_SQL AFFIRM': 'answerable',
    'AMBIGUOUS': 'ambiguous',
    'CANNOT_ANSWER': 'unanswerable',
    'CANNOT_UNDERSTAND': 'unanswerable',
    'NOT_RELATED': 'unanswerable',
}

# Every value classify_type can return; these are passed through unchanged.
_NEW_TYPES = frozenset({'answerable', 'ambiguous', 'unanswerable', 'improper'})


def classify_type(old_type):
    """
    Map an old ``type`` label to its new answerability category.

    Values that are already a new category (``'answerable'``, ``'ambiguous'``,
    ``'unanswerable'``, ``'improper'``) are returned unchanged, so the
    function is idempotent. Without this, an already-converted
    ``'answerable'`` label would be remapped to ``'improper'``.

    Args:
        old_type: The old ``type`` value, e.g. ``'INFORM_SQL'``.

    Returns:
        str: The new category; ``'improper'`` for any unknown label.
    """
    if old_type in _NEW_TYPES:
        return old_type
    return _OLD_TYPE_MAPPING.get(old_type, 'improper')

def filter_turns(data_file, output_file):
    """
    Strip every turn down to the fields used downstream and relabel user turns.

    A user turn (truthy ``isuser``) becomes
    ``{'isuser': True, 'text': ..., 'type': classify_type(turn['type'])}``;
    every other turn becomes
    ``{'isuser': False, 'text': ..., 'query': turn.get('query', "")}``.
    All other turn fields are discarded; dialog-level fields are kept.

    Args:
        data_file: Path of the input JSON file (a JSON list of dialogs, each
            with a ``turns`` list).
        output_file: Path of the output JSON file.

    Raises:
        KeyError: if a turn lacks ``isuser`` or ``text``, or a user turn lacks
            ``type``.
    """
    # Explicit UTF-8, matching the statistics cell, instead of the locale default.
    with open(data_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    for dialog in data:
        dialog['turns'] = [
            {'isuser': True, 'text': turn['text'], 'type': classify_type(turn['type'])}
            if turn['isuser']
            else {'isuser': False, 'text': turn['text'], 'query': turn.get('query', "")}
            for turn in dialog['turns']
        ]

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)


# Example usage, step 5: normalise turn fields and relabel user-turn types.
filter_turns(data_file='combined_data1.json', output_file='combined_data2.json')
print("finished5")

finished5


In [6]:
# Count statistics for the dataset
import json

with open('datasets/MMSQL_train.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

# NOTE: despite its name, `total_turns` counts dialogs (one per top-level item),
# so "Q&A/Turns" below is the average number of user questions per dialog.
total_turns = len(data)
user_turns = 0
total_user_text_length = 0

# Count user questions and their total length in whitespace-separated words.
for item in data:
    for turn in item['turns']:
        if turn['isuser']:
            user_turns += 1
            total_user_text_length += len(turn['text'].split())

average_question_length = total_user_text_length / user_turns if user_turns else 0
# Guarded like the average above, so an empty dataset prints 0 instead of
# raising ZeroDivisionError.
questions_per_dialog = user_turns / total_turns if total_turns else 0

print(f"Total turns: {total_turns}")
print(f"Q&A: {user_turns}")
print(f"Q&A/Turns: {questions_per_dialog:.2f}")
print(f"Average question length: {average_question_length:.2f} words")


Total turns: 6493
Q&A: 38666
Q&A/Turns: 5.96
Average question length: 11.42 words
