## 1. Setup Environment

### 1.1 Install package

In [1]:
from IPython.display import clear_output
!pip install sentence_transformers datasets
!pip install faiss-cpu
!pip install rouge-chinese 
clear_output()

### 1.2 setup chatgpt

In [None]:
import openai

openai.api_key = "YOUR_API_KEY"

# setup chatgpt api and model
def get_chat_response(prompt, top_p=0.1):
    response_data = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        top_p=top_p,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
    )
    return response_data["choices"][0]["message"]["content"]

## 2. Get Ground Truth

Ground truth = correct doc + chatgpt
用正确的文档+chatgpt生成的回复作为ground truth, 应该是upper bound。

优于检索回来的多个文档 + chatgpt生成的回复

首先检索回来的文档可能不是正确的，其次检索回来的文档会有噪音，可能不是排在第一，可能要带上其他噪音文档一起输入给chatgpt，会影响chatgpt生成的回复
### 2.1 load test dataset

In [3]:
import json
from tqdm import tqdm
from datasets import load_from_disk

wikicn_dataset = load_from_disk('../data/wikicn/wikipedia-cn-20230720-dataset/')
print(wikicn_dataset)
print(wikicn_dataset['test'][0])
documents = json.load(open("../data/wikicn/wikipedia-cn-20230720-documents_all.json"))
docid2doc = {doc['docid']: doc for doc in documents}
sample = wikicn_dataset['test'][0]
docid = sample['docid']
doc = docid2doc[docid]
print(doc)

### 2.2 chatgpt generate response

In [4]:
# design a prompt here
prompt = "根据以下文本，回答问题：\n"
prompt += f"问题：{sample['question']}\n"
prompt += f"文本：{doc['document']}\n"
prompt += "答案："
print(prompt)

In [None]:
# get response from chatgpt
response = get_chat_response(prompt)
print(response)

## 4. Evaluate
### 4.1 Evaluate retrieval

In [17]:
def compute_topk_accuracy(predictions, true_labels, topk=5):
    assert len(predictions) == len(true_labels), "预测结果和真实标签的数量必须相同"
    
    num_correct = 0
    for pred, true_label in zip(predictions, true_labels):
        if true_label in pred[:topk]:
            num_correct += 1
    
    top1_accuracy = num_correct / len(predictions)
    return top1_accuracy

In [19]:
from tqdm import tqdm

questions, predictions, targets = [], [], []
for s in tqdm(qa_dataset_with_retrieval):
    questions.append(s["question"])
    predictions.append(s["retrieved_docids"])
    targets.append(s["docid"])

100%|██████████| 1761612/1761612 [06:57<00:00, 4214.95it/s]


In [20]:
for k in [1, 3, 5, 10, 20, 50, 100]:
    topk_acc = compute_topk_accuracy(predictions, targets, topk=k)
    print(f"top{k}:", topk_acc)

top1: 0.7622456023233266
top3: 0.8342580545545785
top5: 0.851998623987575
top10: 0.8704555827276381
top20: 0.8856388353394504
top50: 0.9032709813511716
top100: 0.9158253917434713


### 4.2 check badcase
因为我觉得top100都无法召回的问题，可能是低质量问题

In [21]:
from random import randrange
# check qa quality
index = [randrange(len(qa_dataset_with_retrieval)) for i in range(5)]
small_dataset = qa_dataset_with_retrieval.select(index).remove_columns(["retrieved_docids"])
print(json.dumps([(i["question"], i["answer"]) for i in small_dataset], ensure_ascii=False, indent=4))

[
    [
        "释道融的师父喜欢他的什么特点？",
        "释道融的师父喜欢他的精神、风采。"
    ],
    [
        "什么是水库诱发地震的震中？",
        "水库诱发地震的震中多在库底和水库边缘。"
    ],
    [
        "《如果云知道》的专辑文案中，哪首歌曲获得了金曲奖的最佳作词奖？",
        "《如果云知道》的专辑文案中，《如果云知道》获得了金曲奖的最佳作词奖。"
    ],
    [
        "Who is Gingerbread?",
        "Gingerbread is a band founded by Hong Kong singer Leung Wing Fook under the identity of \"Ming Fuk Yi\"."
    ],
    [
        "双标紫斑蝶的分布范围是哪些地区？",
        "双标紫斑蝶广泛分布于南亚、东南亚、澳洲、新几内亚等地。台湾地区于本岛中海拔地区可见，多以特有亚种归类。"
    ]
]


In [22]:
def check_badcase(questions, predictions, true_labels, topk=5):
    assert len(predictions) == len(true_labels), "预测结果和真实标签的数量必须相同"
    
    bad_cases = []
    for ques, pred, true_label in zip(questions, predictions, true_labels):
        if true_label not in pred[:topk]:
            bad_cases.append(ques)
    return bad_cases

In [23]:
top100_bad_case = check_badcase(questions, predictions, targets, topk=100)
print(len(top100_bad_case), top100_bad_case[:10])

148283 ['林肯是如何成为律师和议员的？', '与上一张商贩单曲《This Silence Is Mine/你与SciencE》相隔多久？', '歌词网站「歌Net」于作品发行前先行揭露了哪首歌曲的歌词？', '官方网站公开了哪首歌曲的音乐录像带无声制作花絮？', '谁是埃塞俄比亚运动员德斯塔·阿斯杰多姆？', '谁是苏联跳高运动员谢尔希·谢纽科夫？', '谁是波兰政治家彼得·雅罗谢维奇？', '谁是美国高尔夫球手奇克哈伯特？', '谁是沙特阿拉伯叛教者巴巴拉·麦克林托克？', '谁是美国遗传学家约翰尼·莫蒂默？']


In [24]:
top50_bad_case = check_badcase(questions, predictions, targets, topk=50)
top50_bad_case_diff = [q for q in top50_bad_case if q not in top100_bad_case]
print(len(top50_bad_case_diff), top50_bad_case_diff[:10])

20931 ['节目的收官演唱会叫什么名字？', '如果投手因为球队的守备失误而失分的话，是否列入责失分当中？', '该电影获得了哪些奖项和提名？', 'What did Roger Taylor study at university?', '该公约的目的是什么？', '该公约的条款有哪些？', '勇者斗恶龙系列在全球出货量是多少？', '苏高利条约是否有效？', '宁德时代的发展历史是怎样的？', '宁德时代被哪个国家选为“国家名片”？']


In [25]:
top20_bad_case = check_badcase(questions, predictions, targets, topk=20)
top20_bad_case_diff = [q for q in top20_bad_case if q not in top50_bad_case]
top20_bad_case_diff_sampled = [top20_bad_case_diff[randrange(len(top20_bad_case_diff))] for i in range(10)]
print(len(top20_bad_case), len(top20_bad_case_diff), top20_bad_case_diff_sampled)

201460 30053 ['E组的比赛是在哪一天举行的？', 'What happened during the 1997-1998 civil war in the Republic of Congo?', '什么是飓风？', '松岛的陆地面积是多少？', "What is Au Kin Yee's affiliation with the Hong Kong film industry?", 'Who is the current principal of YY3?', '什么是最佳化问题？', '少女的判决是什么？', '该集的首播前，HBO在网络上播出了什么？', '2002年2月17日中国冬奥会历史上的第一枚金牌是谁获得的？']


##### 观察： top20未召回的问题，确实不少是低质量问题，也有一些是比较难的问题。也是召回系统的问题，一些实体无法正确召回

## 5. Filtering
保留top20内能召回正确文档的问题

In [26]:
qa_dataset_filtered = qa_dataset_with_retrieval.filter(lambda x: x["question"] not in top20_bad_case)
print(qa_dataset_filtered)

Filter:   0%|          | 0/1761612 [00:00<?, ? examples/s]

Dataset({
    features: ['question', 'answer', 'docid', 'retrieved_docids', 'retrieved_doc_scores'],
    num_rows: 1558310
})


过滤重复问题

In [30]:
from collections import Counter
c = Counter(qa_dataset_filtered["question"])
c.most_common(10)

[('祁汉是哪里人？', 21),
 ('好小子是什么游戏？', 19),
 ('小串成重的父亲是谁？', 18),
 ('陈横是谁的部将？', 18),
 ('常宁市有哪些文物保护单位？', 16),
 ('剑南道的治所在哪个县？', 16),
 ('陈志钊在哪个足球俱乐部效力过？', 16),
 ('李琰的女儿适谁？', 15),
 ('洋流的分类有哪些？', 15),
 ('吕壹为什么被处死？', 14)]

In [29]:
import pandas as pd

qa_df = pd.DataFrame(qa_dataset_filtered)
qa_df.drop_duplicates(subset=['question'], inplace=True)
qa_dataset_dedup = Dataset.from_pandas(qa_df).remove_columns(['__index_level_0__'])
qa_dataset_dedup

Dataset({
    features: ['question', 'answer', 'docid', 'retrieved_docids', 'retrieved_doc_scores'],
    num_rows: 1524340
})

## 6. Train and test split

In [31]:
qa_dataset_with_retrieval_output_path = "data/wikiqa/wikipedia-cn-20230720-dataset"
qa_dataset_train_test = qa_dataset_dedup.train_test_split(test_size=0.01, shuffle=True)
print(qa_dataset_train_test)
qa_dataset_train_test.save_to_disk(qa_dataset_with_retrieval_output_path)

DatasetDict({
    train: Dataset({
        features: ['question', 'answer', 'docid', 'retrieved_docids', 'retrieved_doc_scores'],
        num_rows: 1509096
    })
    test: Dataset({
        features: ['question', 'answer', 'docid', 'retrieved_docids', 'retrieved_doc_scores'],
        num_rows: 15244
    })
})


Saving the dataset (0/8 shards):   0%|          | 0/1509096 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/15244 [00:00<?, ? examples/s]

## 7. evaluate retrieval on testset

In [None]:
dataset_path = "data/"