# 第四章 交叉编码器重排序
本节课，将使用交叉编码器重排序的技术，对检索到的结果进行相关性分析。重排序是一种根据结果与特定查询的相关性来排序和评分的方法。


<div class="toc">
    <ul class="toc-item">
        <li><span><a href="#一底层原理" data-toc-modified-id="一、底层原理">一、底层原理</a></span></li>
        <li>
        <span><a href="#二实现过程" data-toc-modified-id="二、实现过程">二、实现过程</a></span></li><li>
        <ul class="toc-item">
            <li><span><a href="#21-导入辅助函数" data-toc-modified-id="2.1 导入辅助函数">2.1 导入辅助函数</a></span></li>
            <li><span><a href="#22-长尾部分的重排序" data-toc-modified-id="2.2 长尾部分的重排序">2.2 长尾部分的重排序</a></span></li>
            <li><span><a href="#23-结合查询扩展的重排序" data-toc-modified-id="2.3 结合查询扩展的重排序">2.3 结合查询扩展的重排序</a></span></li>
        </ul>
        </li>
    </ul>
</div>

## 一、底层原理
在重排序过程中，在得到特定查询检索到结果之后，需要将这些结果连同查询一起传递给一个重排序模型。这可以重新排列输出，使最相关的结果具有最高的排名。另一种思考方式是，重排序模型根据查询对每个结果打分，得分最高的就是最相关的结果。最后，以选择排名最高的结果作为与特定查询最相关的结果。
 ## 二、实现过程
### 2.1 导入辅助函数

In [1]:
# 导入辅助函数并把数据加载到Chroma中 
from helper_utils import load_chroma, word_wrap, project_embeddings
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
import numpy as np

In [2]:
# 使用代理可能出现网络问题，将以下端口号1080全部替换成自己的vpn的端口号
import os
os.environ['HTTPS_PROXY']='http://127.0.0.1:1080'
os.environ["HTTP_PROXY"]='http://127.0.0.1:1080'

embedding_function = SentenceTransformerEmbeddingFunction()

chroma_collection = load_chroma(filename='microsoft_annual_report_2022.pdf', collection_name='microsoft_annual_report_2022', embedding_function=embedding_function)
chroma_collection.count()

  return self.fget.__get__(instance, owner)()


506

### 2.2 长尾部分的重排序

In [3]:
# 之前一般设定返回5个结果，现在要求返回10个结果，加入了部分可能有用的的长尾结果
query = "What has been tchhe investment in research and development?"
results = chroma_collection.query(query_texts=query, n_results=10, include=['documents', 'embeddings'])

retrieved_documents = results['documents'][0]

for document in results['documents'][0]:
    print(word_wrap(document))
    print('')

48comprehensiveincomestatements ( inmillions ) yearendedjune30, 2022
2021 2020 netincome $ 72, 738 $ 61, 271 $ 44, 281
othercomprehensiveincome ( loss ), netoftax :
netchangerelatedtoderivatives 6 19 ( 38 ) netchangerelatedtoinvestments
( 5, 360 ) ( 2, 266 ) 3, 990 translationadjustmentsandother ( 1, 146 )
873 ( 426 ) othercomprehensiveincome ( loss ) ( 6, 500 ) ( 1, 374 ) 3,
526 comprehensiveincome $ 66, 238 $ 59, 897 $ 47, 807
refertoaccompanyingnotes.

2021acquisitions otherjune30, 2022 productivityandbusiness processes $
24, 190 $ 0 $ 127 $ 24, 317 $ 599 $ ( 105 ) $ 24, 811 intelligentcloud
12, 697 505 54 13, 256 16, 879 ( b ) 47 ( b ) 30, 182
morepersonalcomputing 6, 464 5, 556 ( a ) 118 ( a ) 12, 138 648 ( 255 )
12, 531

adjustednetincome ( non - gaap ) $ 69, 447 $ 60, 651 15 %
dilutedearningspershare $ 9. 65 $ 8. 05 20 %
netincometaxbenefitrelatedtotransferofintangibleproperties ( 0. 44 ) 0
*
netincometaxbenefitrelatedtoindiasupremecourtdecisiononwithholdingtaxes
0 ( 0. 08 ) * a

In [4]:
# BERT交叉编码器同时渠道查询和文档，通过一个分类器传递，获得一个得分
# 利用该得分作为检索结果的相关性或排名的得分
from sentence_transformers import CrossEncoder
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

In [5]:
pairs = [[query, doc] for doc in retrieved_documents]
scores = cross_encoder.predict(pairs)
print("Scores:")
for score in scores:
    print(score)

Scores:
-10.778491
-11.12772
-11.024757
-10.903776
-9.883989
-10.559224
-10.512915
-11.216286
-11.215064
-10.801218


In [6]:
print("New Ordering:")
for o in np.argsort(scores)[::-1]:
    print(o+1)

New Ordering:
5
7
6
1
10
4
3
2
9
8


### 2.3 结合查询扩展的重排序

In [7]:
# 接下来把之前获得的结果排序前5名传递给LLM
original_query = "What were the most important factors that contributed to increases in revenue?"
generated_queries = [
    "What were the major drivers of revenue growth?",
    "Were there any new product launches that contributed to the increase in revenue?",
    "Did any changes in pricing or promotions impact the revenue growth?",
    "What were the key market trends that facilitated the increase in revenue?",
    "Did any acquisitions or partnerships contribute to the revenue growth?"
]

In [8]:
queries = [original_query] + generated_queries

results = chroma_collection.query(query_texts=queries, n_results=10, include=['documents', 'embeddings'])
retrieved_documents = results['documents']

In [9]:
# Deduplicate the retrieved documents
unique_documents = set()
for documents in retrieved_documents:
    for document in documents:
        unique_documents.add(document)

unique_documents = list(unique_documents)

In [10]:
pairs = []
for doc in unique_documents:
    pairs.append([original_query, doc])

In [11]:
scores = cross_encoder.predict(pairs)

In [12]:
print("Scores:")
for score in scores:
    print(score)

Scores:
-10.831182
-11.0566225
-11.184257
-10.7959385
-10.43167
-11.144808
-11.203177
-11.231071
-11.1831045
-11.09624
-11.086064
-10.929356
-11.127819
-9.105438
-11.190094
-11.283865
-10.869104
-11.024591
-11.043177
-10.629692
-11.120768
-11.068455
-10.873292
-11.136417
-10.426586
-10.566668
-10.276354
-10.340245
-11.226146
-10.291064
-11.10204


In [13]:
print("New Ordering:")
for o in np.argsort(scores)[::-1]:
    print(o)

New Ordering:
13
26
29
27
24
4
25
19
3
0
16
22
11
17
18
1
21
10
9
30
20
12
23
5
8
2
14
6
28
7
15
