In [1]:
import os
import json
import pandas as pd

def csv2json(data_path, encoding="utf-8"):
    """Read a CSV file and return its rows as a list of dicts (one per row).

    Keys are the CSV header names. Empty cells, which pandas reads as NaN,
    are converted to ``None``. The records then serialise to valid JSON
    (``null``) rather than the non-standard ``NaN`` token that ``json.dump``
    emits by default.

    Args:
        data_path: Path to the CSV file.
        encoding: Text encoding of the file (default UTF-8).

    Returns:
        list[dict]: the records, in file order.
    """
    frame = pd.read_csv(data_path, encoding=encoding)
    records = frame.to_dict(orient='records')
    # Replace missing values so downstream json.dump produces standard JSON.
    return [
        {key: (None if pd.isna(value) else value) for key, value in row.items()}
        for row in records
    ]

# Convert each CSV split into a JSON list of records saved next to the source
# file with the extension swapped (train.csv -> train.json).
data_paths = ["train.csv", "val.csv"]
for data_path in data_paths:
    # splitext only strips the final extension, so dots in directory names
    # or extension-less paths do not corrupt the output name.
    data_save_path = os.path.splitext(data_path)[0] + ".json"
    data = csv2json(data_path)
    with open(data_save_path, 'w', encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)

# 检索【维基百科信息（../index）】和【训练集例题（../index_ref）】构造数据

In [2]:
import os
import json
import pandas as pd
from pyserini.search.lucene import LuceneSearcher
from jinja2 import Template

# Jinja2 templates that turn a question record into a search query.
# Read explicitly as UTF-8 so loading does not depend on the platform's
# default locale encoding. kb_template is only needed by the KB search,
# which is currently disabled in select_kb.
with open("kb_search_template.j2", encoding="utf-8") as f:
    kb_template = Template(f.read())

with open("ref_search_template.j2", encoding="utf-8") as f:
    ref_template = Template(f.read())

def render(template, item):
    """Render a search-query template for a single question record.

    Records come in two casings: lowercase keys ('question' / 'options') or
    capitalised keys ('Question' / 'Options'). The presence of the lowercase
    'question' key decides which casing is used. The options field is a
    newline-separated string. The first two characters of each line (the
    option letter and a space) are dropped before rendering.
    """
    if "question" in item:
        question_key, options_key = "question", "options"
    else:
        # Same schema with capitalised column names.
        question_key, options_key = "Question", "Options"

    question_text = item[question_key]
    option_texts = [line[2:] for line in item[options_key].split("\n")]
    return template.render(question=question_text, options=option_texts)


def _save_json(obj, path):
    """Write `obj` to `path` as UTF-8 JSON, replacing the file atomically.

    The data is written to a sibling ``.tmp`` file first and then moved over
    `path` with ``os.replace``. An interrupted run therefore never leaves a
    half-written, unparseable output file behind.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, 'w', encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=4)
    os.replace(tmp_path, path)


def select_kb(data_path, kb_path, kb_num, ref_path, ref_num, data_save_path,
              checkpoint_every=1000):
    """Attach retrieved reference questions to every record of a JSON dataset.

    For each record, a query is rendered from the module-level `ref_template`
    and searched against the Lucene index at `ref_path`. The raw text of the
    top `ref_num` hits is stored under the new key ``'ref'``. Every record is
    kept, even when the search returns no hits (its ``'ref'`` is then ``[]``).

    Args:
        data_path: Input dataset; must be a ``.json`` file containing a list
            of dict records.
        kb_path: Path to the knowledge-base (Wikipedia) index. Currently
            unused: the KB search is disabled and only references are added.
        kb_num: Number of KB hits to retrieve. Currently unused (see above).
        ref_path: Path to the Lucene index of reference (training) questions.
        ref_num: Number of reference hits to keep per record.
        data_save_path: Output JSON path. It is also rewritten as a progress
            checkpoint every `checkpoint_every` records.
        checkpoint_every: Checkpoint / progress-report interval in records;
            a falsy value disables intermediate checkpoints.

    Raises:
        ValueError: if `data_path` is not a ``.json`` file.
    """
    # Validate the format before touching the file system. CSV input is no
    # longer supported; convert it with csv2json first.
    if not data_path.endswith('.json'):
        raise ValueError('Unsupported data format')
    with open(data_path, 'r', encoding="utf-8") as f:
        data = json.load(f)

    # The KB (Wikipedia) search on `kb_path` is intentionally disabled;
    # only the reference index is queried.
    ref_searcher = LuceneSearcher(ref_path)
    ref_searcher.set_language('zh')

    new_data = []
    for i, item in enumerate(data):
        ref_query = render(ref_template, item)
        ref_hits = ref_searcher.search(ref_query, k=ref_num)

        new_item = item.copy()
        new_item['ref'] = [hit.lucene_document.get("raw") for hit in ref_hits]
        # Append unconditionally: skipping records with no hits would
        # silently drop questions from the dataset.
        new_data.append(new_item)

        if checkpoint_every and i % checkpoint_every == 0:
            _save_json(new_data, data_save_path)
            print(f'Processing {i}...')

    # Final save with the complete dataset.
    _save_json(new_data, data_save_path)

# Run configuration for the retrieval pass.
data_path = "val_kb4.json"
kb_index = "../index"        # Wikipedia index; unused while KB search is disabled in select_kb
ref_index = "../index_ref"   # index of training-set example questions
kb_num = 4
ref_num = 8
# Output name: input stem + "_ref{ref_num}.json" (here val_kb4_ref8.json).
data_save_path = os.path.splitext(data_path)[0] + f"_ref{ref_num}.json"
select_kb(data_path, kb_index, kb_num, ref_index, ref_num, data_save_path)

Processing 0...
Processing 1000...
Processing 2000...
Processing 3000...
Processing 4000...
Processing 5000...
Processing 6000...
