## 1. Setup Environment

### 1.1 Install package

### 1.2 Setup Retriever
Here we use m3e-base as the Chinese retriever (sentence-embedding model).

In [1]:
from sentence_transformers import SentenceTransformer

# Load the m3e-base sentence-embedding checkpoint. `model` is reused by the
# later cells that encode the questions.
model = SentenceTransformer('moka-ai/m3e-base')

# A few sample sentences used to smoke-test the encoder.
sentences = [
    '* Moka 此文本嵌入模型由 MokaAI 训练并开源，训练脚本使用 uniem',
    '* Massive 此文本嵌入模型通过**千万级**的中文句对数据集进行训练',
    '* Mixed 此文本嵌入模型支持中英双语的同质文本相似度计算，异质文本检索等功能，未来还会支持代码检索，ALL in one'
]

# One encode() call embeds the whole list of sentences.
embeddings = model.encode(sentences)

# Show every sentence next to the shape of its embedding.
for idx, sentence in enumerate(sentences):
    embedding = embeddings[idx]
    print("Sentence:", sentence)
    print("Embedding:", embedding.shape)
    print("")

  from .autonotebook import tqdm as notebook_tqdm


Sentence: * Moka 此文本嵌入模型由 MokaAI 训练并开源，训练脚本使用 uniem
Embedding: (768,)

Sentence: * Massive 此文本嵌入模型通过**千万级**的中文句对数据集进行训练
Embedding: (768,)

Sentence: * Mixed 此文本嵌入模型支持中英双语的同质文本相似度计算，异质文本检索等功能，未来还会支持代码检索，ALL in one
Embedding: (768,)



## 2. Preprocess QA data
### 2.1 clean qa data

In [3]:
import json
import numpy as np
from pathlib import Path
from collections import defaultdict
# Directory holding the T2Ranking dev split. The two input JSON files below
# are expected to already exist inside it.
T2Ranking_dir = Path("data/T2Ranking")
# parents=True: do not fail with FileNotFoundError when "data/" is missing.
T2Ranking_dir.mkdir(parents=True, exist_ok=True)
qa_input_path = T2Ranking_dir / "dev_qa_pairs.json"
document_path = T2Ranking_dir / "dev_document.json"

# Read with an explicit UTF-8 encoding (the payload is Chinese text) and
# close the file handles deterministically instead of leaking them.
with open(qa_input_path, encoding="utf-8") as _fh:
    qa_data = json.load(_fh)
with open(document_path, encoding="utf-8") as _fh:
    documents = json.load(_fh)
print("qa pairs", len(qa_data), "documents", len(documents))

# 1. Index documents by docid (duplicate docids collapse to the last one).
docid2doc = {doc["docid"]: doc["document"] for doc in documents}
print("docid2doc len:", len(docid2doc))

# 2. Question-length statistics (informational only; nothing is filtered here).
print("qa pairs num:", len(qa_data))
ques_lens = [len(qa["question"]) for qa in qa_data]
min_ques_len = min(ques_lens)
max_ques_len = max(ques_lens)
mean_ques_len = np.mean(ques_lens)
print("max ques len:", max_ques_len)
print("min ques len:", min_ques_len)
print("mean ques len:", mean_ques_len)

# 3. Document-length statistics, then drop documents whose length falls
#    outside the [1st, 99th] percentile range.
doc_lens = [len(docid2doc[docid]) for docid in docid2doc]
min_doc_len = min(doc_lens)
max_doc_len = max(doc_lens)
mean_doc_len = np.mean(doc_lens)
print("max doc len:", max_doc_len)
print("min doc len:", min_doc_len)
print("mean doc len:", mean_doc_len)

print("before filter, doc num:", len(docid2doc))
documents_keep = []
documents_drop = []
ans_len_99 = np.percentile(doc_lens, 99)
ans_len_1 = np.percentile(doc_lens, 1)
print("ans_len_99:", ans_len_99)
print("ans_len_1:", ans_len_1)
for item in documents:
    doc_len = len(item["document"])
    # Both bounds are inclusive.
    if ans_len_1 <= doc_len <= ans_len_99:
        documents_keep.append(item)
    else:
        documents_drop.append(item)
print("after filter, doc num:", len(documents_keep))
print("drop ratio:", len(documents_drop) / len(documents))

# From here on docid2doc only contains the documents that survived the filter.
docid2doc = {doc["docid"]: doc["document"] for doc in documents_keep}

# 4. Keep only the QA pairs whose positive document survived the filter.
print("before filter, qa num:", len(qa_data))
qa_data_keep = [item for item in qa_data if item["docid"] in docid2doc]
print("after filter, qa num:", len(qa_data_keep))

# Map each question to all of its positive docids (one question may have
# several positive documents).
ques2docids = defaultdict(list)
for qa in qa_data_keep:
    ques2docids[qa["question"]].append(qa["docid"])
print("ques2docids len:", len(ques2docids))

# 5. Save the filtered QA pairs and documents. ensure_ascii=False writes raw
#    non-ASCII characters, so the files are opened explicitly as UTF-8.
out_dir = Path("data/T2Ranking/")
qa_output_path = out_dir / "dev_qa_pairs_filter.json"
document_output_path = out_dir / "dev_document_filter.json"
with open(qa_output_path, "w", encoding="utf-8") as _fh:
    json.dump(qa_data_keep, _fh, indent=2, ensure_ascii=False)
with open(document_output_path, "w", encoding="utf-8") as _fh:
    json.dump(documents_keep, _fh, indent=2, ensure_ascii=False)

: 

: 

### 2.2 build faiss index for documents

In [9]:
from datasets import load_dataset, Dataset

# Alternative loader (disabled): read the documents straight from the JSON file.
# document_dataset = load_dataset("json", data_files=document_path)
# Build the dataset from the in-memory `documents` list.
# NOTE(review): this uses the unfiltered `documents`, not `documents_keep`;
# presumably it has to match the rows of the prebuilt FAISS index that the
# next cell loads -- confirm before switching to the filtered list.
document_dataset = Dataset.from_list(documents)
print(document_dataset)
# Pretty-print the first row as a quick sanity check.
print(json.dumps(document_dataset[0], indent=2, ensure_ascii=False))

Dataset({
    features: ['docid', 'document'],
    num_rows: 2303643
})
{
  "docid": "3076cc10eea8",
  "document": "找寄件人改号码"
}


In [10]:
# If the document embeddings have already been computed and indexed, load the
# saved FAISS index from disk and attach it under the name 'doc_embedding'.
# NOTE(review): the path is an absolute, machine-specific location; adjust it
# to wherever the index file lives.
document_dataset.load_faiss_index('doc_embedding', '/ssd4T/T2Ranking/doc_embedding.faiss')
# Alias only: both names refer to the same dataset object.
document_dataset_with_emb = document_dataset
# Confirm that the 'doc_embedding' index is attached.
print(document_dataset_with_emb.is_index_initialized("doc_embedding"))

True


### 2.3 encode questions

In [16]:
# Wrap the filtered QA pairs in a Dataset and show its schema and first row.
qa_dataset = Dataset.from_list(qa_data_keep)
print(qa_dataset)
print(qa_dataset[0])

# Add a 'question_embedding' column. With batched=True, `map` hands the
# callable a batch, so batch["question"] is encoded in one model.encode call.
qa_dataset_with_emb = qa_dataset.map(
    lambda batch: {'question_embedding': model.encode(batch["question"])},
    batched=True,
)

Dataset({
    features: ['question', 'answer', 'docid'],
    num_rows: 740960
})
{'question': '鹦鹉吃自己的小鱼吗', 'answer': '', 'docid': '3fef708246b3'}


Map:   0%|          | 0/740960 [00:00<?, ? examples/s]

In [25]:
# One row per distinct question: the keys of ques2docids.
q_dataset = Dataset.from_list([{"question": ques} for ques in ques2docids])
print(q_dataset)
print(q_dataset[0])

# Add a 'question_embedding' column. With batched=True, `map` hands the
# callable a batch, so batch["question"] is encoded in one model.encode call.
q_dataset_with_emb = q_dataset.map(
    lambda batch: {'question_embedding': model.encode(batch["question"])},
    batched=True,
)

Dataset({
    features: ['question'],
    num_rows: 199576
})
{'question': '鹦鹉吃自己的小鱼吗'}


Map:   0%|          | 0/199576 [00:00<?, ? examples/s]

## 3. Build dataset
### 3.1 retrieve relevant documents

In [26]:
import numpy as np

def retrieve_topk_documents(example, topk=100):
    """Attach the nearest indexed documents to one question example.

    Queries the 'doc_embedding' index of the notebook-level
    ``document_dataset_with_emb`` with the example's question embedding.

    Args:
        example: Mapping with a "question_embedding" entry; it is converted
            to a float32 numpy array before the lookup.
        topk: Number of neighbours to request. Defaults to 100, the value
            that used to be hard-coded in the body.

    Returns:
        The same ``example`` object with two entries added:
        "retrieved_docids" (the "docid" values of the retrieved rows) and
        "retrieved_doc_scores" (the scores converted with ``tolist()``).
    """
    ques_embedding = np.array(example["question_embedding"], dtype=np.float32)
    scores, retrieved_examples = document_dataset_with_emb.get_nearest_examples(
        'doc_embedding', ques_embedding, k=topk
    )
    example["retrieved_docids"] = retrieved_examples["docid"]
    example["retrieved_doc_scores"] = scores.tolist()
    return example

In [31]:
# Adjust num_proc to match the available hardware.
# Runs retrieve_topk_documents on every unique question and drops the
# 'question_embedding' column from the result.
q_dataset_with_retrieval = q_dataset_with_emb.map(retrieve_topk_documents, 
                                                  num_proc=16,
                                                  remove_columns=["question_embedding"])

Map (num_proc=16):   0%|          | 0/199576 [00:00<?, ? examples/s]

In [32]:
# Lookup table: question text -> its retrieved docids and their scores.
# If a question appears more than once, the last row wins.
q2retrieval = {
    row['question']: {
        "retrieved_docids": row['retrieved_docids'],
        "retrieved_doc_scores": row['retrieved_doc_scores'],
    }
    for row in q_dataset_with_retrieval
}

In [35]:
import numpy as np

# The data is one-to-many (a question can have several positive documents);
# for training each row is turned into a one-to-one (question, positive) pair.
def prepare_for_training_dataset(example):
    """Attach the filtered retrieval candidates to one QA row.

    Reads the notebook-level lookups ``ques2docids``, ``q2retrieval`` and
    ``docid2doc``. A retrieved document is kept when it is present in
    ``docid2doc`` and is not one of the question's *other* positive docids.
    The row's own ``example["docid"]`` is deliberately not excluded, so it
    stays in the candidate list if it was retrieved.

    Args:
        example: Mapping with "question" and "docid" entries.

    Returns:
        The same ``example`` object with "retrieved_docids" and
        "retrieved_doc_scores" set to the kept candidates, in retrieval order.
    """
    ques = example["question"]
    # Positives of this question other than this row's own docid. A set
    # makes the membership test in the loop below O(1).
    other_positive_docids = {d for d in ques2docids[ques] if d != example["docid"]}
    retrieval = q2retrieval[ques]
    docids = []
    doc_scores = []
    for d, s in zip(retrieval["retrieved_docids"], retrieval["retrieved_doc_scores"]):
        if d in docid2doc and d not in other_positive_docids:
            docids.append(d)
            doc_scores.append(s)
    example["retrieved_docids"] = docids
    example["retrieved_doc_scores"] = doc_scores
    return example

In [36]:
# Adjust num_proc to match the available hardware.
# Runs prepare_for_training_dataset on every QA row and drops the
# 'question_embedding' column from the result.
qa_dataset_with_retrieval = qa_dataset_with_emb.map(prepare_for_training_dataset, 
                                                    num_proc=1,
                                                    remove_columns=["question_embedding"])

Map:   0%|          | 0/740960 [00:00<?, ? examples/s]

In [51]:
# Sanity checks: for row 0 the docid list and the score list must have the
# same length.
print(len(qa_dataset_with_retrieval[0]["retrieved_docids"]))
print(len(qa_dataset_with_retrieval[0]["retrieved_doc_scores"]))
# Is row 3's own positive docid among its retrieved candidates?
print(qa_dataset_with_retrieval[3]["docid"] in qa_dataset_with_retrieval[3]["retrieved_docids"])

87
87
True


## 4. Evaluate
### 4.1 Evaluate retrieval

In [52]:
def compute_topk_accuracy(predictions, true_labels, topk=5):
    """Return the fraction of samples whose label is in the first ``topk`` predictions.

    Args:
        predictions: Sequence of ranked candidate sequences, one per sample.
        true_labels: Sequence of gold labels aligned with ``predictions``.
        topk: Number of leading candidates inspected per sample.

    Returns:
        The hit rate as a float; 0.0 when there are no samples.

    Raises:
        AssertionError: If ``predictions`` and ``true_labels`` differ in length.
    """
    assert len(predictions) == len(true_labels), "预测结果和真实标签的数量必须相同"

    # An empty evaluation set would otherwise raise ZeroDivisionError.
    if len(predictions) == 0:
        return 0.0

    num_correct = sum(
        1
        for pred, true_label in zip(predictions, true_labels)
        if true_label in pred[:topk]
    )
    topk_accuracy = num_correct / len(predictions)
    return topk_accuracy

In [53]:
from tqdm import tqdm

# Flatten the dataset into three parallel lists for the evaluation helpers:
# the question text, the ranked retrieved docids, and the positive docid.
questions = []
predictions = []
targets = []
for sample in tqdm(qa_dataset_with_retrieval):
    questions.append(sample["question"])
    predictions.append(sample["retrieved_docids"])
    targets.append(sample["docid"])

100%|██████████| 740960/740960 [02:11<00:00, 5623.69it/s]


In [54]:
# Report the retrieval hit rate at several cut-offs.
for k in (1, 3, 5, 10, 20, 50, 100):
    topk_acc = compute_topk_accuracy(predictions=predictions, true_labels=targets, topk=k)
    print(f"top{k}:", topk_acc)

top1: 0.1833931656229756
top3: 0.3276033793996977
top5: 0.39888388037141004
top10: 0.4942102137767221
top20: 0.5847265709350032
top50: 0.6923882530770892
top100: 0.7542728352407687


### 4.2 check badcase
因为我觉得top100都无法召回的问题，可能是低质量问题

In [55]:
from random import randrange
# Spot-check QA quality on five randomly drawn rows (indices may repeat).
index = [randrange(len(qa_dataset_with_retrieval)) for _ in range(5)]
small_dataset = qa_dataset_with_retrieval.select(index)
small_dataset = small_dataset.remove_columns(["retrieved_docids"])
print(json.dumps([(row["question"], row["answer"]) for row in small_dataset], ensure_ascii=False, indent=4))

[
    [
        "发动机内部清洗油有必要吗",
        ""
    ],
    [
        "太阳穴针扎一样疼",
        ""
    ],
    [
        "王者战士制裁的用处",
        ""
    ],
    [
        "高阳膝盖不好使疼痛甚是是什么引起的",
        ""
    ],
    [
        "什么情况会被拉入黑名单买不了车票",
        ""
    ]
]


In [56]:
def check_badcase(questions, predictions, true_labels, topk=5):
    """Collect the questions whose label is missing from the first ``topk`` predictions.

    Args:
        questions: Sequence of question texts, aligned with the other two.
        predictions: Sequence of ranked candidate sequences, one per sample.
        true_labels: Sequence of gold labels, one per sample.
        topk: Number of leading candidates inspected per sample.

    Returns:
        List of questions (in input order, duplicates preserved) for which
        the label is not among the first ``topk`` candidates.

    Raises:
        AssertionError: If ``predictions`` and ``true_labels`` differ in length.
    """
    assert len(predictions) == len(true_labels), "预测结果和真实标签的数量必须相同"

    return [
        ques
        for ques, pred, true_label in zip(questions, predictions, true_labels)
        if true_label not in pred[:topk]
    ]

In [57]:
# Questions whose positive docid is missing from the first 100 retrieved
# candidates. A question appears once per failing QA row.
top100_bad_case = check_badcase(questions, predictions, targets, topk=100)
# Print the count and the first ten entries.
print(len(top100_bad_case), top100_bad_case[:10])

182074 ['鹦鹉的生活习性是什么', '鹦鹉的生活习性是什么', '鹦鹉鱼头洞传染吗', '鹰潭市景点排行榜', '鹰潭市景点排行榜', '鹰嘴骨折手术后能正常生活吗', '鹰嘴骨折手术后能正常生活吗', '鹰嘴骨折手术后能正常生活吗', '鹰嘴骨折手术后能正常生活吗', '盈利的盈组词']


In [58]:
# Questions that miss at top-50 but are not in the top-100 bad-case set.
# The comparison is on question text, so a question with any top-100 miss is
# excluded entirely.
top50_bad_case = check_badcase(questions, predictions, targets, topk=50)
top100_badcase_set = set(top100_bad_case)
top50_bad_case_diff = [
    ques for ques in top50_bad_case
    if ques not in top100_badcase_set
]
print(len(top50_bad_case_diff), top50_bad_case_diff[:10])

17707 ['鹦鹉吃自己的小鱼吗', '荧光pcr基因 内参关系', '硬路肩也铺沥青的吗', '用不完的尿不湿的妙用', '用电饼档爆栗子怎么样', '用电饼档爆栗子怎么样', '用两个抢票软件可以吗', '用硫磺皂洗澡后有气味', '用人单位还是派遣单位工资', '用珊瑚癣净泡脚非常疼']


In [61]:
# Questions that miss at top-20 but are not in the top-50 bad-case set
# (compared on question text), plus ten random samples of them for inspection.
top20_bad_case = check_badcase(questions, predictions, targets, topk=20)
top50_badcase_set = set(top50_bad_case)
top20_bad_case_diff = [
    ques for ques in top20_bad_case
    if ques not in top50_badcase_set
]
# Sampling is with replacement, so the same question may be drawn twice.
top20_bad_case_diff_sampled = [
    top20_bad_case_diff[randrange(len(top20_bad_case_diff))]
    for _ in range(10)
]
print(len(top20_bad_case), len(top20_bad_case_diff), top20_bad_case_diff_sampled)

307701 27614 ['电动车电池测试好坏', '网购哪个平台能买到正品衣服', '咸阳晚上哪里好玩的景点', '打孩子多久子宫就好了', '单株是什么意思', '儿童便秘应该挂什么科', '生育难免险流产可以报销多少', '汽车尿素液起什么作用', '邮政快递为什么这么慢', '非洲旅游哪里安全']


##### 观察： top20未召回的问题，主要是召回模型的问题，感觉可以不用过滤

## 5. Filtering
T2Ranking这个数据集，就不过滤了

不过滤重复问题，因为数据集天然就是1对多的

## 6. save the dataset

In [63]:
# No train/test split is performed; the whole dataset is saved as-is.
qa_dataset_with_retrieval_output_path = "data/T2Ranking/T2Ranking_train_dataset"
# Write the dataset to the directory above.
qa_dataset_with_retrieval.save_to_disk(qa_dataset_with_retrieval_output_path)

Saving the dataset (0/4 shards):   0%|          | 0/740960 [00:00<?, ? examples/s]