## 1. Setup Environment

### 1.1 Install package

In [1]:
from IPython.display import clear_output
from tqdm import tqdm
!pip install sentence_transformers datasets
# if you have gpu, install faiss gpt
!pip install faiss-gpu
# else instlal faiss cpu
# !pip install faiss-cpu
clear_output()

### 1.2 Setup Retriever
here we use m3e-base as the chinese Retriever

In [1]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('moka-ai/m3e-base')

#Our sentences we like to encode
sentences = [
    '* Moka 此文本嵌入模型由 MokaAI 训练并开源，训练脚本使用 uniem',
    '* Massive 此文本嵌入模型通过**千万级**的中文句对数据集进行训练',
    '* Mixed 此文本嵌入模型支持中英双语的同质文本相似度计算，异质文本检索等功能，未来还会支持代码检索，ALL in one'
]

#Sentences are encoded by calling model.encode()
embeddings = model.encode(sentences)

#Print the embeddings
for sentence, embedding in zip(sentences, embeddings):
    print("Sentence:", sentence)
    print("Embedding:", embedding.shape)
    print("")


Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
bin /home/howard/miniconda3/envs/torch1.13/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so
ERROR: /home/howard/miniconda3/envs/torch1.13/bin/python: undefined symbol: cudaRuntimeGetVersion
CUDA SETUP: libcudart.so path is None
CUDA SETUP: Is seems that your cuda installation is not in your path. See https://github.com/TimDettmers/bitsandbytes/issues/85 for more information.
CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!
CUDA SETUP: Highest compute capability among GPUs detected: 8.6
CUDA SETUP: Detected CUDA version 00
CUDA SETUP: Loading binary /home/howard/miniconda3/envs/torch1.13/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so...


  warn("The installed version of bitsandbytes was compiled without GPU support. "
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)


Sentence: * Moka 此文本嵌入模型由 MokaAI 训练并开源，训练脚本使用 uniem
Embedding: (768,)

Sentence: * Massive 此文本嵌入模型通过**千万级**的中文句对数据集进行训练
Embedding: (768,)

Sentence: * Mixed 此文本嵌入模型支持中英双语的同质文本相似度计算，异质文本检索等功能，未来还会支持代码检索，ALL in one
Embedding: (768,)



## 2. Preprocess QA data
### 2.1 clean qa data

In [2]:
import re

def replace_special_char(pattern, text, src="\n\n", tgt="\n"):
    while True:
        i = re.search(pattern, text)
        if i is None:
            break
        s, e = i.span()
        # split text into 3 parts
        q, a, b = text[:s], text[s:e], text[e:]
        a = a.replace(src, tgt)
        # merge 3 parts
        text = q + a + b
    return text

def split_text(text):
    # 处理特殊情况：Q和A分隔符也是\n\n, 替换为\n
    text = replace_special_char("Q: .*?\n\nA:", text)
    text = replace_special_char("Q[0-9]+: .*?\n\nA[0-9]+:", text)
    text = replace_special_char("[0-9]+\. .*?\n\n答：", text)
    text = replace_special_char("[0-9]+\. .*?\n\n", text)
    # 处理特殊情况：QA和QA分隔符是\n, 替换为\n\n
    text = replace_special_char('(?<!\n)\n\d+\.', text, src="\n", tgt="\n\n")
    qa_pairs = [qa for qa in text.split('\n\n') if '\n' in qa]
    return qa_pairs

def clean_qa(q, a):
    # Remove "1." at the beginning of the sentence
    cleaned_q = re.sub(r'^\d+\.\s+', '', q, flags=re.MULTILINE)
    cleaned_q = re.sub(r'^Q\d+:\s+', '', cleaned_q, flags=re.MULTILINE)
    cleaned_q = re.sub(r'^Q:\s+', '', cleaned_q, flags=re.MULTILINE)
    cleaned_q = re.sub(r'^> \d+\. ', '', cleaned_q)
    
    # Remove "答：" at the beginning of the sentence
    cleaned_a = re.sub(r'^答：', '', a, flags=re.MULTILINE)
    cleaned_a = re.sub(r'^A\d+:\s+', '', cleaned_a, flags=re.MULTILINE)
    cleaned_a = re.sub(r'^A:\s+', '', cleaned_a, flags=re.MULTILINE)
    cleaned_a = re.sub(r'^-\s+', '', cleaned_a)
    cleaned_a = re.sub(r'^回答：', '', cleaned_a)
    
    return cleaned_q.strip(), cleaned_a.strip()

In [3]:
import hashlib

def hash_to_12_length(input_string):
    # Create an MD5 hash object
    md5_hash = hashlib.md5()
    
    # Convert the input string to bytes and update the hash object
    md5_hash.update(input_string.encode('utf-8'))
    
    # Get the hexadecimal representation of the hash
    hashed_string = md5_hash.hexdigest()
    
    # Truncate the hash to 12 characters
    truncated_hash = hashed_string[:12]
    
    return truncated_hash

In [4]:
from collections import Counter

def count_ngram(text, n=2):
    # 创建 n-gram 列表
    ngrams = ["".join(text[i:i+n]) for i in range(len(text) - n + 1)]
    # 统计 n-gram 出现次数
    ngram_counter = Counter(ngrams)
    # 计算总单词数
    total_words = len(ngrams)
    # 计算每个 n-gram 的频率
    ngram_freq = {ngram: ngram_counter[ngram] / total_words for ngram in ngram_counter}
    return ngram_freq

def is_repeated_text(text, threshold=0.1, n=2):
    # 计算词组频率
    ngram_freq = count_ngram(text, n)
    # 如果 max(ngram_freq.values()) 大于阈值，则返回False
    return max(ngram_freq.values()) > threshold


In [5]:
import json
import numpy as np
from pathlib import Path
wikiqa_dir = Path("data/wikiqa")
wikiqa_dir.mkdir(exist_ok=True)
qa_input_path = "data/raw_data/wikipedia-cn-20230720-ref2qa_all.json"
qa_ouput_path = "data/wikiqa/wikipedia-cn-20230720-qapairs_all.json"
document_path = "data/wikiqa/wikipedia-cn-20230720-documents_all.json"
qa_data = json.load(open(qa_input_path))

# 1. build documents
documents = []
for line in qa_data:
    doc = line["input"].strip()
    docid = hash_to_12_length(doc)
    docid2doc = dict(docid=docid, document=doc, source=line["source"])
    documents.append(docid2doc)
json.dump(documents, open(document_path, "w"), ensure_ascii=False)

# 2. build qa pairs
qa_pairs_with_docid = []
ans_lens = []
for line in qa_data:
    qa_pairs = split_text(line["output"])
    doc = line["input"].strip()
    docid = hash_to_12_length(doc)
    for i, qa_pair in enumerate(qa_pairs):
        qa_pair = qa_pair.strip().split("\n")
        q, a = qa_pair[0], "\n".join(qa_pair[1:])
        q, a = clean_qa(q, a)
        if not q or not a:
            # 最后一个qa_pair可能是空的
            if i == len(qa_pairs) - 1:
                continue
            else:
                print(q, a, qa_pair) # make sure q and a are not empty
        ans_lens.append(len(a))
        item = dict(question=q, answer=a, docid=docid)
        qa_pairs_with_docid.append(item)
        
# 3. statistics and filter
print("before filter, qa pairs num:", len(qa_pairs_with_docid))
qa_pairs_with_docid_keep = []
qa_pairs_with_docid_drop = []
ans_len_percentile99 = np.percentile(ans_lens, 99)
ans_len_percentile1 = np.percentile(ans_lens, 1)
ans_len_percentile95 = np.percentile(ans_lens, 95)
for item in qa_pairs_with_docid:
    if (len(item["answer"]) >= ans_len_percentile1 and 
        len(item["answer"]) <= ans_len_percentile99):

        if (len(item["answer"]) > ans_len_percentile95 and 
            is_repeated_text(item["answer"], threshold=0.10)):

            qa_pairs_with_docid_drop.append(item)
        else:
            qa_pairs_with_docid_keep.append(item)
    else:
        qa_pairs_with_docid_drop.append(item)
print("after filter, qa pairs num:", len(qa_pairs_with_docid_keep))
print("drop qa pairs num ratio:", len(qa_pairs_with_docid_drop) / len(qa_pairs_with_docid))
print("max ans len in keep:", max([len(item["answer"]) for item in qa_pairs_with_docid_keep]))
print("min ans len in keep:", min([len(item["answer"]) for item in qa_pairs_with_docid_keep]))

json.dump(qa_pairs_with_docid_keep, open(qa_ouput_path, "w"), ensure_ascii=False, indent=2)


before filter, qa pairs num: 1794483
after filter, qa pairs num: 1761612
drop qa pairs num ratio: 0.018317810756635754
max ans len in keep: 283
min ans len in keep: 9


### 2.2 build faiss index for documents

In [6]:
from datasets import load_dataset, Dataset

# document_dataset = load_dataset("json", data_files=document_path)
document_dataset = Dataset.from_list(documents)
print(document_dataset)
print(json.dumps(document_dataset[0], indent=2, ensure_ascii=False))

Dataset({
    features: ['docid', 'document', 'source'],
    num_rows: 254338
})
{
  "docid": "1e13ffb6871c",
  "document": "路易斯·萨恩斯·佩尼亚·达比拉（Luis Sáenz Peña Dávila，12月4日) ，律师和阿根廷总统（1890年—1892年）。\n他从布宜诺斯艾利斯大学法律系毕业，参加1860年宪法汇编。他是全国代理和参议员之一。1882年他占有在布宜诺斯艾利斯省最高法院的一个位子。后他被雇用作为省银行的主席，法律学院的主任和在教育委员会一个成员。\n1892年任总统，1895年1月23日他对国会提出辞职并被接受，政府转到何塞·埃瓦里斯托·乌里武鲁将军手中，他1898年完成任期。",
  "source": "wikipedia.zh2307"
}


In [7]:
# if doc embedding has been processed, then load here
document_dataset.load_faiss_index('doc_embedding', 'data/raw_data/wiki_doc_embedding.faiss')
document_dataset_with_emb = document_dataset
print(document_dataset_with_emb.is_index_initialized("doc_embedding"))

True


In [8]:
# 
document_dataset_with_emb = document_dataset.map(
    lambda example: {'doc_embedding': model.encode(example["document"])}, 
    batched=True
)
# 模型index在cpu，设置device可以把index放到对应gpu，#device is the index of the GPU to use !!!但有非常慢
document_dataset_with_emb.add_faiss_index(column='doc_embedding')
document_dataset_with_emb.save_faiss_index('doc_embedding', 
                                           'data/raw_data/wiki_doc_embedding.faiss')

  0%|          | 0/255 [00:00<?, ?ba/s]

  0%|          | 0/255 [00:00<?, ?it/s]

### 2.3 encode questions

In [14]:
# qa_dataset = load_dataset("json", data_files="wikipedia-cn-20230720-qapairs_10k.json")
qa_dataset = Dataset.from_list(qa_pairs_with_docid_keep)
print(qa_dataset)
print(qa_dataset[0])
# 
qa_dataset_with_emb = qa_dataset.map(
    lambda example: {'question_embedding': model.encode(example["question"])}, 
    batched=True
)

Dataset({
    features: ['question', 'answer', 'docid'],
    num_rows: 1761612
})
{'question': '路易斯·萨恩斯·佩尼亚·达比拉是谁？', 'answer': '路易斯·萨恩斯·佩尼亚·达比拉是阿根廷总统和律师，出生于1862年。', 'docid': '1e13ffb6871c'}


Map:   0%|          | 0/1761612 [00:00<?, ? examples/s]

## 3. Build dataset
### 3.1 retrieve relevant documents

In [15]:
import numpy as np

def retrieve_topk_documents(example):
    topk = 100
    ques_embedding = np.array(example["question_embedding"], dtype=np.float32)
    scores, retrieved_examples = document_dataset_with_emb.get_nearest_examples('doc_embedding', ques_embedding, k=topk)
    example["retrieved_docids"] = retrieved_examples["docid"]
    example["retrieved_doc_scores"] = scores.tolist()
    return example

In [16]:
# 根据硬件，调整num_proc
qa_dataset_with_retrieval = qa_dataset_with_emb.map(retrieve_topk_documents, 
                                                    num_proc=64,
                                                    remove_columns=["question_embedding"])

Map (num_proc=64):   0%|          | 0/1761612 [00:00<?, ? examples/s]

## 4. Evaluate
### 4.1 Evaluate retrieval

In [17]:
def compute_topk_accuracy(predictions, true_labels, topk=5):
    assert len(predictions) == len(true_labels), "预测结果和真实标签的数量必须相同"
    
    num_correct = 0
    for pred, true_label in zip(predictions, true_labels):
        if true_label in pred[:topk]:
            num_correct += 1
    
    top1_accuracy = num_correct / len(predictions)
    return top1_accuracy

In [19]:
from tqdm import tqdm

questions, predictions, targets = [], [], []
for s in tqdm(qa_dataset_with_retrieval):
    questions.append(s["question"])
    predictions.append(s["retrieved_docids"])
    targets.append(s["docid"])

100%|██████████| 1761612/1761612 [06:57<00:00, 4214.95it/s]


In [20]:
for k in [1, 3, 5, 10, 20, 50, 100]:
    topk_acc = compute_topk_accuracy(predictions, targets, topk=k)
    print(f"top{k}:", topk_acc)

top1: 0.7622456023233266
top3: 0.8342580545545785
top5: 0.851998623987575
top10: 0.8704555827276381
top20: 0.8856388353394504
top50: 0.9032709813511716
top100: 0.9158253917434713


### 4.2 check badcase
因为我觉得top100都无法召回的问题，可能是低质量问题

In [21]:
from random import randrange
# check qa quality
index = [randrange(len(qa_dataset_with_retrieval)) for i in range(5)]
small_dataset = qa_dataset_with_retrieval.select(index).remove_columns(["retrieved_docids"])
print(json.dumps([(i["question"], i["answer"]) for i in small_dataset], ensure_ascii=False, indent=4))

[
    [
        "释道融的师父喜欢他的什么特点？",
        "释道融的师父喜欢他的精神、风采。"
    ],
    [
        "什么是水库诱发地震的震中？",
        "水库诱发地震的震中多在库底和水库边缘。"
    ],
    [
        "《如果云知道》的专辑文案中，哪首歌曲获得了金曲奖的最佳作词奖？",
        "《如果云知道》的专辑文案中，《如果云知道》获得了金曲奖的最佳作词奖。"
    ],
    [
        "Who is Gingerbread?",
        "Gingerbread is a band founded by Hong Kong singer Leung Wing Fook under the identity of \"Ming Fuk Yi\"."
    ],
    [
        "双标紫斑蝶的分布范围是哪些地区？",
        "双标紫斑蝶广泛分布于南亚、东南亚、澳洲、新几内亚等地。台湾地区于本岛中海拔地区可见，多以特有亚种归类。"
    ]
]


In [22]:
def check_badcase(questions, predictions, true_labels, topk=5):
    assert len(predictions) == len(true_labels), "预测结果和真实标签的数量必须相同"
    
    bad_cases = []
    for ques, pred, true_label in zip(questions, predictions, true_labels):
        if true_label not in pred[:topk]:
            bad_cases.append(ques)
    return bad_cases

In [23]:
top100_bad_case = check_badcase(questions, predictions, targets, topk=100)
print(len(top100_bad_case), top100_bad_case[:10])

148283 ['林肯是如何成为律师和议员的？', '与上一张商贩单曲《This Silence Is Mine/你与SciencE》相隔多久？', '歌词网站「歌Net」于作品发行前先行揭露了哪首歌曲的歌词？', '官方网站公开了哪首歌曲的音乐录像带无声制作花絮？', '谁是埃塞俄比亚运动员德斯塔·阿斯杰多姆？', '谁是苏联跳高运动员谢尔希·谢纽科夫？', '谁是波兰政治家彼得·雅罗谢维奇？', '谁是美国高尔夫球手奇克哈伯特？', '谁是沙特阿拉伯叛教者巴巴拉·麦克林托克？', '谁是美国遗传学家约翰尼·莫蒂默？']


In [24]:
top50_bad_case = check_badcase(questions, predictions, targets, topk=50)
top50_bad_case_diff = [q for q in top50_bad_case if q not in top100_bad_case]
print(len(top50_bad_case_diff), top50_bad_case_diff[:10])

20931 ['节目的收官演唱会叫什么名字？', '如果投手因为球队的守备失误而失分的话，是否列入责失分当中？', '该电影获得了哪些奖项和提名？', 'What did Roger Taylor study at university?', '该公约的目的是什么？', '该公约的条款有哪些？', '勇者斗恶龙系列在全球出货量是多少？', '苏高利条约是否有效？', '宁德时代的发展历史是怎样的？', '宁德时代被哪个国家选为“国家名片”？']


In [25]:
top20_bad_case = check_badcase(questions, predictions, targets, topk=20)
top20_bad_case_diff = [q for q in top20_bad_case if q not in top50_bad_case]
top20_bad_case_diff_sampled = [top20_bad_case_diff[randrange(len(top20_bad_case_diff))] for i in range(10)]
print(len(top20_bad_case), len(top20_bad_case_diff), top20_bad_case_diff_sampled)

201460 30053 ['E组的比赛是在哪一天举行的？', 'What happened during the 1997-1998 civil war in the Republic of Congo?', '什么是飓风？', '松岛的陆地面积是多少？', "What is Au Kin Yee's affiliation with the Hong Kong film industry?", 'Who is the current principal of YY3?', '什么是最佳化问题？', '少女的判决是什么？', '该集的首播前，HBO在网络上播出了什么？', '2002年2月17日中国冬奥会历史上的第一枚金牌是谁获得的？']


##### 观察： top20未召回的问题，确实不少是低质量问题，也有一些是比较难的问题。也是召回系统的问题，一些实体无法正确召回

## 5. Filtering
保留top20内能召回正确文档的问题

In [26]:
qa_dataset_filtered = qa_dataset_with_retrieval.filter(lambda x: x["question"] not in top20_bad_case)
print(qa_dataset_filtered)

Filter:   0%|          | 0/1761612 [00:00<?, ? examples/s]

Dataset({
    features: ['question', 'answer', 'docid', 'retrieved_docids', 'retrieved_doc_scores'],
    num_rows: 1558310
})


过滤重复问题

In [30]:
from collections import Counter
c = Counter(qa_dataset_filtered["question"])
c.most_common(10)

[('祁汉是哪里人？', 21),
 ('好小子是什么游戏？', 19),
 ('小串成重的父亲是谁？', 18),
 ('陈横是谁的部将？', 18),
 ('常宁市有哪些文物保护单位？', 16),
 ('剑南道的治所在哪个县？', 16),
 ('陈志钊在哪个足球俱乐部效力过？', 16),
 ('李琰的女儿适谁？', 15),
 ('洋流的分类有哪些？', 15),
 ('吕壹为什么被处死？', 14)]

In [29]:
import pandas as pd

qa_df = pd.DataFrame(qa_dataset_filtered)
qa_df.drop_duplicates(subset=['question'], inplace=True)
qa_dataset_dedup = Dataset.from_pandas(qa_df).remove_columns(['__index_level_0__'])
qa_dataset_dedup

Dataset({
    features: ['question', 'answer', 'docid', 'retrieved_docids', 'retrieved_doc_scores'],
    num_rows: 1524340
})

## 6. Train and test split

In [31]:
qa_dataset_with_retrieval_output_path = "data/wikiqa/wikipedia-cn-20230720-dataset"
qa_dataset_train_test = qa_dataset_dedup.train_test_split(test_size=0.01, shuffle=True)
print(qa_dataset_train_test)
qa_dataset_train_test.save_to_disk(qa_dataset_with_retrieval_output_path)

DatasetDict({
    train: Dataset({
        features: ['question', 'answer', 'docid', 'retrieved_docids', 'retrieved_doc_scores'],
        num_rows: 1509096
    })
    test: Dataset({
        features: ['question', 'answer', 'docid', 'retrieved_docids', 'retrieved_doc_scores'],
        num_rows: 15244
    })
})


Saving the dataset (0/8 shards):   0%|          | 0/1509096 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/15244 [00:00<?, ? examples/s]

## 7. evaluate retrieval on testset

In [None]:
dataset_path = "data/"