## 1. Setup Environment

### 1.1 Install package

In [1]:
from IPython.display import clear_output
from tqdm import tqdm
!pip install sentence_transformers datasets
# If you have a GPU, install faiss-gpu
!pip install faiss-gpu
# Otherwise, install faiss-cpu
# !pip install faiss-cpu
clear_output()

### 1.2 Setup Retriever
Here we use `moka-ai/m3e-base` as the Chinese retriever.

In [24]:
from sentence_transformers import SentenceTransformer

# Load the retriever: the m3e-base sentence-embedding model.
model = SentenceTransformer('moka-ai/m3e-base')

# A few sample sentences used to smoke-test the encoder.
sentences = [
    '* Moka 此文本嵌入模型由 MokaAI 训练并开源，训练脚本使用 uniem',
    '* Massive 此文本嵌入模型通过**千万级**的中文句对数据集进行训练',
    '* Mixed 此文本嵌入模型支持中英双语的同质文本相似度计算，异质文本检索等功能，未来还会支持代码检索，ALL in one'
]

# model.encode() turns the list of sentences into their embeddings.
embeddings = model.encode(sentences)

# Report each sentence together with the shape of its embedding.
sentence_embedding_pairs = list(zip(sentences, embeddings))
for sentence, embedding in sentence_embedding_pairs:
    print("Sentence:", sentence)
    print("Embedding:", embedding.shape)
    print()


Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
bin /home/howard/miniconda3/envs/torch1.13/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so
ERROR: /home/howard/miniconda3/envs/torch1.13/bin/python: undefined symbol: cudaRuntimeGetVersion
CUDA SETUP: libcudart.so path is None
CUDA SETUP: Is seems that your cuda installation is not in your path. See https://github.com/TimDettmers/bitsandbytes/issues/85 for more information.
CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!
CUDA SETUP: Highest compute capability among GPUs detected: 8.6
CUDA SETUP: Detected CUDA version 00
CUDA SETUP: Loading binary /home/howard/miniconda3/envs/torch1.13/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so...


  warn("The installed version of bitsandbytes was compiled without GPU support. "
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)


Sentence: * Moka 此文本嵌入模型由 MokaAI 训练并开源，训练脚本使用 uniem
Embedding: (768,)

Sentence: * Massive 此文本嵌入模型通过**千万级**的中文句对数据集进行训练
Embedding: (768,)

Sentence: * Mixed 此文本嵌入模型支持中英双语的同质文本相似度计算，异质文本检索等功能，未来还会支持代码检索，ALL in one
Embedding: (768,)



## 2. Preprocess QA data
### 2.1 clean qa data

In [3]:
import re

def replace_special_char(pattern, text, src="\n\n", tgt="\n"):
    """Replace ``src`` with ``tgt`` inside every region of ``text`` matching ``pattern``.

    The text is rescanned after each rewrite, so a rewrite that exposes a new
    match is processed as well. Text outside the matched regions is never
    modified.

    Args:
        pattern: Regular expression (string or compiled pattern) selecting the
            regions in which the replacement is applied.
        text: Input text.
        src: Substring to replace inside each matched region.
        tgt: Replacement for ``src``.

    Returns:
        The rewritten text.

    Note:
        A match that does not contain ``src`` is skipped instead of being
        retried forever (the previous implementation never terminated in that
        case). If ``tgt`` re-creates a match that still contains ``src``, the
        loop can still run indefinitely; choose ``pattern``/``src``/``tgt`` so
        that a rewritten region no longer matches.
    """
    regex = re.compile(pattern)
    # Everything before ``search_from`` consists only of matches that the
    # replacement left untouched; it stays at 0 whenever every match is rewritten.
    search_from = 0
    while True:
        match = regex.search(text, search_from)
        if match is None:
            break
        start, end = match.span()
        matched = text[start:end]
        rewritten = matched.replace(src, tgt)
        if rewritten == matched:
            # No progress is possible on this match: move past it. Empty
            # matches advance by one character so the scan cannot stall.
            search_from = end if end > start else start + 1
            if search_from > len(text):
                break
            continue
        text = text[:start] + rewritten + text[end:]
    return text

def split_text(text):
    """Split raw text into QA chunks separated by blank lines.

    The separators are normalised first, so that a question and its answer
    stay in one chunk and consecutive QA pairs fall into different chunks.

    Args:
        text: Raw text containing question/answer pairs.

    Returns:
        The chunks obtained by splitting on blank lines, keeping only chunks
        that still contain a newline.
    """
    # Special case: the question and its answer are themselves separated by a
    # blank line; collapse that separator to a single newline.
    text = replace_special_char("Q: .*?\n\nA:", text)
    text = replace_special_char("Q[0-9]+: .*?\n\nA[0-9]+:", text)
    # Raw strings: "\." and "\d" are regex escapes, not valid Python string
    # escapes. The regex engine still reads "\n" as a newline.
    text = replace_special_char(r"[0-9]+\. .*?\n\n答：", text)
    text = replace_special_char(r"[0-9]+\. .*?\n\n", text)
    # Special case: consecutive numbered QA pairs are separated by a single
    # newline; widen it to a blank line.
    text = replace_special_char(r'(?<!\n)\n\d+\.', text, src="\n", tgt="\n\n")
    qa_pairs = [qa for qa in text.split('\n\n') if '\n' in qa]
    return qa_pairs

def clean_qa(q, a):
    """Strip numbering and speaker prefixes from a question and its answer.

    Args:
        q: Raw question text.
        a: Raw answer text.

    Returns:
        A ``(question, answer)`` tuple with the prefixes removed and the
        surrounding whitespace stripped.
    """
    # (pattern, flags) pairs, applied in order. Patterns flagged MULTILINE
    # strip the prefix at the start of every line; the rest only at the very
    # start of the text.
    question_prefixes = (
        (r'^\d+\.\s+', re.MULTILINE),  # "1. "
        (r'^Q\d+:\s+', re.MULTILINE),  # "Q1: "
        (r'^Q:\s+', re.MULTILINE),     # "Q: "
        (r'^> \d+\. ', 0),             # "> 1. "
    )
    answer_prefixes = (
        (r'^答：', re.MULTILINE),
        (r'^A\d+:\s+', re.MULTILINE),  # "A1: "
        (r'^A:\s+', re.MULTILINE),     # "A: "
        (r'^-\s+', 0),                 # "- "
        (r'^回答：', 0),
    )

    question = q
    for prefix, flags in question_prefixes:
        question = re.sub(prefix, '', question, flags=flags)

    answer = a
    for prefix, flags in answer_prefixes:
        answer = re.sub(prefix, '', answer, flags=flags)

    return question.strip(), answer.strip()

In [6]:
import hashlib

def hash_to_12_length(input_string):
    """Return the first 12 hex characters of the MD5 digest of ``input_string``.

    The string is encoded as UTF-8 before hashing.
    """
    digest = hashlib.md5(input_string.encode('utf-8')).hexdigest()
    return digest[:12]

In [7]:
from collections import Counter

def count_ngram(text, n=2):
    """Compute the relative frequency of every n-gram in ``text``.

    Args:
        text: A string, or a sequence of strings, to slide the window over.
        n: Window size.

    Returns:
        A dict mapping each n-gram (window items joined into one string) to
        its count divided by the total number of n-grams. The dict is empty
        when ``text`` yields no n-gram.
    """
    occurrences = Counter()
    for start in range(len(text) - n + 1):
        occurrences["".join(text[start:start + n])] += 1
    # Total number of windows, i.e. the sum of all counts.
    total = sum(occurrences.values())
    return {gram: count / total for gram, count in occurrences.items()}

def is_repeated_text(text, threshold=0.1, n=2):
    """Tell whether a single n-gram dominates ``text``.

    Args:
        text: Text to inspect.
        threshold: Relative-frequency threshold.
        n: n-gram size passed to ``count_ngram``.

    Returns:
        True when the most frequent n-gram has a relative frequency strictly
        greater than ``threshold``; False otherwise, including when ``text``
        yields no n-gram at all.
    """
    ngram_freq = count_ngram(text, n)
    if not ngram_freq:
        # No n-gram to measure; max() would raise ValueError on an empty
        # sequence, and such text cannot be repetitive.
        return False
    return max(ngram_freq.values()) > threshold


In [10]:
import json
import numpy as np
from pathlib import Path
dureader_dir = Path("data/dureader")
# parents=True so the cell also works when the "data" directory does not exist yet.
dureader_dir.mkdir(parents=True, exist_ok=True)
qa_input_path = "data/raw_data/train_dureader_cleaned.jsonl"
qa_ouput_path = dureader_dir / "dureader-qapairs_all.json"
document_path = dureader_dir / "dureader-documents_all.json"

# 1. build documents
# Map docid -> document text; records sharing the same text collapse into one
# entry because the id is a hash of the text.
docid2doc = {}
with open(qa_input_path, encoding="utf-8") as qa_file:
    for line in qa_file:
        # Strip before hashing: the QA-pair step hashes the stripped text, so
        # both steps must hash the same string to produce matching docids.
        doc = json.loads(line)["materials"].strip()
        docid = hash_to_12_length(doc)
        docid2doc[docid] = doc
documents = [
    dict(docid=docid, document=doc, source="dureader")
    for docid, doc in docid2doc.items()
]
# Explicit UTF-8 because ensure_ascii=False writes non-ASCII characters as-is;
# the context manager guarantees the file is flushed and closed.
with open(document_path, "w", encoding="utf-8") as document_file:
    json.dump(documents, document_file, ensure_ascii=False, indent=2)


In [22]:

# 2. build qa pairs
qa_pairs_with_docid = []
ans_lens = []
with open(qa_input_path, encoding="utf-8") as qa_file:
    for line in qa_file:
        j = json.loads(line)
        qa_pairs = j["conversation"]
        doc = j["materials"].strip()
        docid = hash_to_12_length(doc)
        for qa_pair in qa_pairs:
            # Fall back to the alternate keys "QUS" / "AAS" when the primary
            # keys are absent.
            q = qa_pair.get("QUES") if "QUES" in qa_pair else qa_pair.get("QUS")
            a = qa_pair.get("ANS") if "ANS" in qa_pair else qa_pair.get("AAS")
            # Keep only pairs where both question and answer are non-empty.
            if q and a:
                ans_lens.append(len(a))
                item = dict(question=q, answer=a, docid=docid)
                qa_pairs_with_docid.append(item)

# 3. statistics and filter
print("before filter, qa pairs num:", len(qa_pairs_with_docid))
qa_pairs_with_docid_keep = []
qa_pairs_with_docid_drop = []
ans_len_percentile99 = np.percentile(ans_lens, 99)
ans_len_percentile1 = np.percentile(ans_lens, 1)
ans_len_percentile95 = np.percentile(ans_lens, 95)
for item in qa_pairs_with_docid:
    ans_len = len(item["answer"])
    # Drop answers whose length is outside the [1st, 99th] percentile range.
    if ans_len_percentile1 <= ans_len <= ans_len_percentile99:
        # Among the longest answers (above the 95th percentile), also drop
        # those dominated by a single repeated bigram.
        if (ans_len > ans_len_percentile95 and
                is_repeated_text(item["answer"], threshold=0.10)):
            qa_pairs_with_docid_drop.append(item)
        else:
            qa_pairs_with_docid_keep.append(item)
    else:
        qa_pairs_with_docid_drop.append(item)
print("after filter, qa pairs num:", len(qa_pairs_with_docid_keep))
print("drop qa pairs num ratio:", len(qa_pairs_with_docid_drop) / len(qa_pairs_with_docid))
print("max ans len in keep:", max([len(item["answer"]) for item in qa_pairs_with_docid_keep]))
print("min ans len in keep:", min([len(item["answer"]) for item in qa_pairs_with_docid_keep]))

# Explicit UTF-8 because ensure_ascii=False writes non-ASCII characters as-is;
# the context manager guarantees the file is flushed and closed.
with open(qa_ouput_path, "w", encoding="utf-8") as qa_output_file:
    json.dump(qa_pairs_with_docid_keep, qa_output_file, ensure_ascii=False, indent=2)


before filter, qa pairs num: 162661
after filter, qa pairs num: 159893
drop qa pairs num ratio: 0.017016986247471735
max ans len in keep: 141
min ans len in keep: 10


### 2.2 build faiss index for documents

In [23]:
from datasets import load_dataset, Dataset

# Build the document dataset from the in-memory list. The commented-out line
# is the alternative of loading it from the JSON file instead.
# document_dataset = load_dataset("json", data_files=document_path)
document_dataset = Dataset.from_list(documents)
print(document_dataset)

# Pretty-print the first record as a sanity check.
first_document = document_dataset[0]
print(json.dumps(first_document, ensure_ascii=False, indent=2))

Dataset({
    features: ['docid', 'document', 'source'],
    num_rows: 11454
})
{
  "docid": "538aa9eccd44",
  "document": "选择燃气热水器时，一定要关注这几个问题：1、出水稳定性要好，不能出现忽热忽冷的现象2、快速到达设定的需求水温3、操作要智能、方便4、安全性要好，要装有安全报警装置 市场上燃气热水器品牌众多，购买时还需多加对比和仔细鉴别。方太今年主打的磁化恒温热水器在使用体验方面做了全面升级：9秒速热，可快速进入洗浴模式；水温持久稳定，不会出现忽热忽冷的现象，并通过水量伺服技术将出水温度精确控制在±0.5℃，可满足家里宝贝敏感肌肤洗护需求；配备CO和CH4双气体报警装置更安全（市场上一般多为CO单气体报警）。另外，这款热水器还有智能WIFI互联功能，只需下载个手机APP即可用手机远程操作热水器，实现精准调节水温，满足家人多样化的洗浴需求。当然方太的磁化恒温系列主要的是增加磁化功能，可以有效吸附水中的铁锈、铁屑等微小杂质，防止细菌滋生，使沐浴水质更洁净，长期使用磁化水沐浴更利于身体健康。",
  "source": "dureader"
}


In [26]:
def _embed_documents(batch):
    """Encode a batch of document texts into the 'doc_embedding' column."""
    return {'doc_embedding': model.encode(batch["document"])}


document_dataset_with_emb = document_dataset.map(_embed_documents, batched=True)

# Original author's note: the index lives on the CPU here; passing `device`
# (the index of the GPU to use) would place it on that GPU, but that was
# found to be very slow.
document_dataset_with_emb.add_faiss_index(column='doc_embedding')
document_dataset_with_emb.save_faiss_index(
    'doc_embedding',
    'data/raw_data/dureader_doc_embedding.faiss',
)

Map:   0%|          | 0/11454 [00:00<?, ? examples/s]

  0%|          | 0/12 [00:00<?, ?it/s]

### 2.3 encode questions

In [27]:
# Build the QA dataset from the filtered in-memory pairs. The commented-out
# line is the alternative of loading QA pairs from a JSON file.
# qa_dataset = load_dataset("json", data_files="wikipedia-cn-20230720-qapairs_10k.json")
qa_dataset = Dataset.from_list(qa_pairs_with_docid_keep)
print(qa_dataset)
print(qa_dataset[0])


def _embed_questions(batch):
    """Encode a batch of questions into the 'question_embedding' column."""
    return {'question_embedding': model.encode(batch["question"])}


qa_dataset_with_emb = qa_dataset.map(_embed_questions, batched=True)

Dataset({
    features: ['question', 'answer', 'docid'],
    num_rows: 159893
})
{'question': '选择燃气热水器时，需要关注哪些问题？', 'answer': '选购燃气热水器时，需要关注以下几个问题：1、出水稳定性好，不能出现忽热忽冷的现象。2、快速到达设定的需求水温。3、操作智能方便。4、安全性要好，要装有安全报警装置。', 'docid': '538aa9eccd44'}


Map:   0%|          | 0/159893 [00:00<?, ? examples/s]

## 3. Build dataset
### 3.1 retrieve relevant documents

In [28]:
import numpy as np

def retrieve_topk_documents(example, topk=100):
    """Attach the ``topk`` nearest documents to a QA example.

    Queries the 'doc_embedding' index of ``document_dataset_with_emb`` with
    the example's question embedding.

    Args:
        example: A record holding a "question_embedding" entry; it is
            modified in place.
        topk: Number of documents to retrieve. Defaults to 100, the value
            that used to be hard-coded.

    Returns:
        The same ``example`` with two added entries: "retrieved_docids" (the
        docids of the retrieved documents) and "retrieved_doc_scores" (the
        matching scores as a plain list).
    """
    # The embedding is converted to a float32 array before querying the index.
    ques_embedding = np.array(example["question_embedding"], dtype=np.float32)
    scores, retrieved_examples = document_dataset_with_emb.get_nearest_examples(
        'doc_embedding', ques_embedding, k=topk
    )
    example["retrieved_docids"] = retrieved_examples["docid"]
    example["retrieved_doc_scores"] = scores.tolist()
    return example

In [29]:
# Adjust num_proc to match the available hardware.
qa_dataset_with_retrieval = qa_dataset_with_emb.map(
    retrieve_topk_documents,
    remove_columns=["question_embedding"],
    num_proc=16,
)

Map (num_proc=16):   0%|          | 0/159893 [00:00<?, ? examples/s]

## 4. Evaluate
### 4.1 Evaluate retrieval

In [30]:
def compute_topk_accuracy(predictions, true_labels, topk=5):
    """Compute the fraction of examples whose label is among the top-k predictions.

    Args:
        predictions: One ranked sequence of predicted labels per example.
        true_labels: The expected label for each example.
        topk: How many leading predictions to consider per example.

    Returns:
        The share of examples whose true label occurs in the first ``topk``
        predictions, as a float in [0, 1]; 0.0 when there are no examples.

    Raises:
        AssertionError: If ``predictions`` and ``true_labels`` differ in length.
    """
    assert len(predictions) == len(true_labels), "预测结果和真实标签的数量必须相同"

    # Nothing to score: return 0.0 rather than dividing by zero.
    if len(predictions) == 0:
        return 0.0

    num_correct = sum(
        1
        for pred, true_label in zip(predictions, true_labels)
        if true_label in pred[:topk]
    )
    return num_correct / len(predictions)

In [33]:
from tqdm import tqdm

# Gather questions, retrieved docids and gold docids as three parallel lists.
questions, predictions, targets = [], [], []
column_sources = (
    (questions, "question"),
    (predictions, "retrieved_docids"),
    (targets, "docid"),
)
for sample in tqdm(qa_dataset_with_retrieval):
    for column, field in column_sources:
        column.append(sample[field])

100%|██████████| 159893/159893 [00:29<00:00, 5388.76it/s]


In [34]:
# Report retrieval accuracy at increasing cut-offs.
cutoffs = [1, 3, 5, 10, 20, 50, 100]
for k in cutoffs:
    topk_acc = compute_topk_accuracy(predictions, targets, topk=k)
    print(f"top{k}:", topk_acc)

top1: 0.7172484098741033
top3: 0.854408885942474
top5: 0.8917088302802499
top10: 0.9261193423101699
top20: 0.9500103194010995
top50: 0.9712933023959773
top100: 0.9820504962693801


### 4.2 check badcase
因为我觉得top100都无法召回的问题，可能是低质量问题

In [35]:
from random import randrange
# check qa quality
# Draw five random row positions (with replacement) to eyeball QA quality.
dataset_size = len(qa_dataset_with_retrieval)
index = [randrange(dataset_size) for _ in range(5)]
small_dataset = qa_dataset_with_retrieval.select(index).remove_columns(["retrieved_docids"])
qa_preview = [(row["question"], row["answer"]) for row in small_dataset]
print(json.dumps(qa_preview, ensure_ascii=False, indent=4))

[
    [
        "志高的服务电话是多少？",
        "志高的服务电话是0571-8513-5103。"
    ],
    [
        "排球场地有哪些规格类型可选？",
        "排球场地可选择全塑(QS)型或混合(HH)型，颜色可以是铁红、草绿或根据用户需求定制，厚度一般为7-10mm或按用户要求定制。"
    ],
    [
        "什么是怀孕最显著也是最早的信号？",
        "月经停止是怀孕最显著也是最早的一个信号，如果在无避孕措施下进行了性生活而出现月经停止的话，很可能就是怀孕了。"
    ],
    [
        "无创DNA检测的价格大概在多少范围内？",
        "无创DNA检测的价格大约在2000到2800之间。"
    ],
    [
        "《择天记》的首播时间是什么时候？",
        "《择天记》于2017年4月17日登陆湖南卫视“青春进行时”时段开始播出。"
    ]
]


In [36]:
def check_badcase(questions, predictions, true_labels, topk=5):
    """Collect the questions whose true label is missing from the top-k predictions.

    Args:
        questions: The question for each example.
        predictions: One ranked sequence of predicted labels per example.
        true_labels: The expected label for each example.
        topk: How many leading predictions to consider per example.

    Returns:
        The questions, in input order, whose true label does not occur in the
        first ``topk`` predictions.

    Raises:
        AssertionError: If ``predictions`` and ``true_labels`` differ in length.
    """
    assert len(predictions) == len(true_labels), "预测结果和真实标签的数量必须相同"

    return [
        question
        for question, ranked, expected in zip(questions, predictions, true_labels)
        if expected not in ranked[:topk]
    ]

In [37]:
# Questions whose gold document is absent from the top-100 retrieved documents.
top100_bad_case = check_badcase(questions, predictions, targets, topk=100)
num_top100_bad_case = len(top100_bad_case)
print(num_top100_bad_case, top100_bad_case[:10])

2870 ['新闻来源提到了哪个娱乐频道网站？', '请问您的手机是什么型号的呢?', '售后服务中心的地址和联系电话在哪里能查询到？', '请问您的手机是什么型号的呢?', '什么是四维检查？', '为什么需要进行四维检查？', '为什么选医院要考虑服务态度和正规性？', '为什么要近水楼台先得月？', '如果术前有炎症，是否需要先治好才能做手术？', '这个剧情是根据哪部作品改编的？']


In [38]:
# Questions missed at top-50 but not in the top-100 bad-case list.
top50_bad_case = check_badcase(questions, predictions, targets, topk=50)
# A set gives constant-time membership tests instead of scanning the list for
# every question.
top100_bad_questions = set(top100_bad_case)
top50_bad_case_diff = [q for q in top50_bad_case if q not in top100_bad_questions]
print(len(top50_bad_case_diff), top50_bad_case_diff[:10])

1704 ['电视剧《斗破苍穹》预计何时上映播出？', '电视剧《斗破苍穹》预计何时上映播出？', '电视剧《斗破苍穹》何时上映播出？', '请问您的手机是什么型号的呢？', '什么是阻击模式，如何进入该模式？', '需要提前购买回程的车票吗？', '驾车路线中需要转弯几次？', 'mAh和mA之间有什么关系？', '沿着哪些地点可以到达长江路淮南街站？', '还有其他类似的软件吗？']


In [39]:
# Questions missed at top-20 but not in the top-50 bad-case list.
top20_bad_case = check_badcase(questions, predictions, targets, topk=20)
# A set gives constant-time membership tests instead of scanning the list for
# every question.
top50_bad_questions = set(top50_bad_case)
top20_bad_case_diff = [q for q in top20_bad_case if q not in top50_bad_questions]
# Show ten randomly drawn examples (drawn with replacement).
top20_bad_case_diff_sampled = [
    top20_bad_case_diff[randrange(len(top20_bad_case_diff))] for _ in range(10)
]
print(len(top20_bad_case), len(top20_bad_case_diff), top20_bad_case_diff_sampled)

7993 3331 ['这首歌的演唱者是谁？', '可以从一个半岛入境再从另一个半岛出境吗？', '如何在《龙魂》中获得大礼包？', '2017年3月8日至2017年3月14日期间有什么活动？', '22寸行李箱的尺寸是多少？', '更换空调器压缩机时需要注意什么？', '一升车预计何时上市？', '有什么建议可以查看三本院校的招生简章吗？', '生化危机6还有其他的游戏模式吗？', '题记中描述了什么样的情感？']


##### 观察： top20未召回的问题，确实不少是低质量问题，也有一些是比较难的问题

## 5. Filtering
保留top20内能召回正确文档的问题, 过滤重复问题

In [44]:
import pandas as pd

# Keep the first occurrence of every question text. drop_duplicates returns a
# new frame, so no in-place mutation is needed.
qa_df = pd.DataFrame(qa_dataset_with_retrieval).drop_duplicates(subset=['question'])
# preserve_index=False keeps the pandas index out of the dataset, so there is
# no '__index_level_0__' column to remove afterwards (removing it would fail
# whenever that column is not produced).
qa_dataset_dedup = Dataset.from_pandas(qa_df, preserve_index=False)
qa_dataset_dedup

Dataset({
    features: ['question', 'answer', 'docid', 'retrieved_docids', 'retrieved_doc_scores'],
    num_rows: 135333
})

In [45]:
# Drop every question that appears in the top-20 bad-case list. The set makes
# each membership test constant-time instead of a scan over the whole list for
# every row.
top20_bad_questions = set(top20_bad_case)
qa_dataset_filtered = qa_dataset_dedup.filter(lambda x: x["question"] not in top20_bad_questions)
print(qa_dataset_filtered)

Filter:   0%|          | 0/135333 [00:00<?, ? examples/s]

Dataset({
    features: ['question', 'answer', 'docid', 'retrieved_docids', 'retrieved_doc_scores'],
    num_rows: 128224
})


## 6. Train and test split

In [46]:
qa_dataset_with_retrieval_output_path = "data/dureader/dureader_dataset"
# Hold out 1% of the examples as the test split. The fixed seed makes the
# shuffled split reproducible across runs.
qa_dataset_train_test = qa_dataset_filtered.train_test_split(test_size=0.01, shuffle=True, seed=42)
print(qa_dataset_train_test)
qa_dataset_train_test.save_to_disk(qa_dataset_with_retrieval_output_path)

DatasetDict({
    train: Dataset({
        features: ['question', 'answer', 'docid', 'retrieved_docids', 'retrieved_doc_scores'],
        num_rows: 126941
    })
    test: Dataset({
        features: ['question', 'answer', 'docid', 'retrieved_docids', 'retrieved_doc_scores'],
        num_rows: 1283
    })
})


Saving the dataset (0/1 shards):   0%|          | 0/126941 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1283 [00:00<?, ? examples/s]

## 7. evaluate retrieval model on testset

In [1]:
from datasets import load_from_disk

# Reload the saved train/test splits from disk and keep the test split.
dureader_splits = load_from_disk("data/dureader/dureader_dataset/")
test_ds = dureader_splits["test"]
print(test_ds)

Dataset({
    features: ['question', 'answer', 'docid', 'retrieved_docids', 'retrieved_doc_scores'],
    num_rows: 1283
})


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import sys
sys.path.append("../")
from metric import RetrievalMetrics
from misc import recursive_round

def evaluate_retriever(dataset):
    """Score the stored retrieval results of ``dataset`` with ``RetrievalMetrics``.

    Args:
        dataset: Iterable of records with "retrieved_doc_scores",
            "retrieved_docids" and "docid" entries.

    Returns:
        The computed metrics after ``recursive_round``.

    Note:
        The gold docid is located with ``.index``; presumably that raises if
        the docid is missing from "retrieved_docids" (assumes a list) —
        verify before evaluating unfiltered data.
    """
    metrics = RetrievalMetrics()
    for sample in dataset:
        # ranking metric
        # Each raw score p becomes 1 / (1 + p), so a smaller raw score yields
        # a larger ranking score (presumably the raw scores are distances —
        # TODO confirm).
        raw_scores = sample["retrieved_doc_scores"]
        rank_preds = [1 / (1 + raw_score) for raw_score in raw_scores]
        # Exactly one positive: the position of the gold document.
        rank_targets = [False for _ in rank_preds]
        gold_position = sample["retrieved_docids"].index(sample["docid"])
        rank_targets[gold_position] = True
        metrics.update([rank_preds], [rank_targets])
    return recursive_round(metrics.compute())

In [3]:
import json

# Evaluate the stored retrieval results on the test split and print the report.
eval_results = evaluate_retriever(test_ds)
eval_report = json.dumps(eval_results, indent=2)
print(eval_report)

{
  "HitRate@1": 0.761,
  "HitRate@5": 0.933,
  "HitRate@10": 0.971,
  "MRR": 0.838,
  "MAP@1": 0.761,
  "MAP@5": 0.83,
  "MAP@10": 0.836,
  "NDCG@1": 0.761,
  "NDCG@5": 0.856,
  "NDCG@10": 0.869,
  "Recall@1": 0.761,
  "Recall@5": 0.933,
  "Recall@10": 0.971
}


In [35]:
def compute_topk_accuracy(predictions, true_labels, topk=5):
    """Compute the fraction of examples whose label is among the top-k predictions.

    Args:
        predictions: One ranked sequence of predicted labels per example.
        true_labels: The expected label for each example.
        topk: How many leading predictions to consider per example.

    Returns:
        The share of examples whose true label occurs in the first ``topk``
        predictions, as a float in [0, 1]; 0.0 when there are no examples.

    Raises:
        AssertionError: If ``predictions`` and ``true_labels`` differ in length.
    """
    assert len(predictions) == len(true_labels), "预测结果和真实标签的数量必须相同"

    # Nothing to score: return 0.0 rather than dividing by zero.
    if len(predictions) == 0:
        return 0.0

    num_correct = sum(
        1
        for pred, true_label in zip(predictions, true_labels)
        if true_label in pred[:topk]
    )
    return num_correct / len(predictions)

In [36]:
from tqdm import tqdm

# Gather questions, retrieved docids and gold docids of the test split as
# three parallel lists.
questions, predictions, targets = [], [], []
column_sources = (
    (questions, "question"),
    (predictions, "retrieved_docids"),
    (targets, "docid"),
)
for sample in tqdm(test_ds):
    for column, field in column_sources:
        column.append(sample[field])

100%|██████████| 1283/1283 [00:00<00:00, 13866.34it/s]


In [38]:
# Report Recall@k on the test split for a few cut-offs.
recall_cutoffs = [1, 3, 5, 10]
for k in recall_cutoffs:
    topk_acc = compute_topk_accuracy(predictions, targets, topk=k)
    print(f"Recall@{k}:", topk_acc)

Recall@1: 0.7614964925954794
Recall@3: 0.8916601714731099
Recall@5: 0.9329696024941543
Recall@10: 0.9711613406079501
