In [111]:
import pyserini.search as pys
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import numpy as np
import pandas as pd

In [163]:

searcher = pys.SimpleSearcher('indexes/sample_collection_jsonl')

f = open("data/queries.dev.tsv")
queries = []

for i in range(20):
    l = f.readline().split("\t")
    queries.append({"id": int(l[0]), "content": l[1].strip()})
    
print(queries[0])
f.close()

{'id': 1048578, 'content': 'cost of endless pools/swim spa'}


In [154]:

hits = searcher.search(queries[0]['content'], k=20)


for i in range(len(hits)):
    print(f'{i+1:2} {hits[i].docid:4} {hits[i].score:.5f}')

 1 7471198 21.35000
 2 7187236 20.33760
 3 5365326 19.68380
 4 7187234 19.53490
 5 7187242 18.88170
 6 2078221 18.84990
 7 7187241 18.66050
 8 6802210 18.12250
 9 6794083 17.85240
10 5365328 17.60240
11 6750054 17.18490
12 4332300 16.23630
13 6347088 16.22780
14 6347089 16.22780
15 6270168 16.07760
16 3982208 15.82690
17 7471199 15.75350
18 7313043 15.45380
19 8105762 15.30410
20 1139145 15.19000


In [153]:
def findDoc(id, mode="chunk"):
    if(mode == "linear"):
        res = ""
        f = open("data/collection.tsv",  encoding="utf8")
        for i in range(id+1):
            l = f.readline()
        print(l)
    elif mode == "chunk":
        res = id%10000
        nearest_n = id - res
        f = open(f"data/collection_chunks/{nearest_n}.txt",  encoding="utf8")
        for i in range(res):
            l = f.readline()
        return l


In [135]:
w = [2/3, 1/6, 1/6]
a = 0.5
# The tokenizer will automatically add any model specific separators (i.e. <CLS> and <SEP>) and tokens to
# the sequence, as well as compute the attention masks.

def getSim(query, doc):
    paraphrase = tokenizer(query, doc, return_tensors="pt")
    paraphrase_classification_logits = model(**paraphrase).logits
    paraphrase_results = torch.softmax(paraphrase_classification_logits, dim=1).tolist()[0]
    
    return paraphrase_results[1]
     
def calcBertScore(query, doc): 
    sentences = list(filter(lambda x: x != "", [d.strip() for d in doc.split(".")]))
    
    sims = [{"score": getSim(query, d), "sentence": d} for d in sentences]
    sims = sorted(sims, reverse = True, key = lambda x: x["score"])[:3]
    n = min(len(sims), 3)
    score = 0
    for i in range(n):
        score = score + w[i]*sims[i]["score"]
    
    return score

In [193]:
query = queries[0]["content"]


def bertRankQuery(query, k=20):
    hits = searcher.search(query, k=k)
    
    maxBM = 0
    maxBert = 0
    res = pd.DataFrame(np.zeros((len(hits), 3)), index = range(len(hits)), columns=["docid", "bm-25", "bert"])

    for i in range(len(hits)):
        hit = hits[i]
        docid = int(hit.docid)
        doc = findDoc(docid).split("\t")[1]

        score = calcBertScore(query, doc)
        print(f"Doc: {docid}, bm-25: {hit.score}, bert: {score}")
        res.at[i, "docid"] = docid
        res.at[i, "bm-25"] = hit.score
        res.at[i, "bert"] = score

        maxBM = max(maxBM, hit.score)
        maxBert = max(maxBert, score)
    res["final"] = a*res["bm-25"]/maxBM + (1-a)*res["bert"]/maxBert
    res = res.astype({'docid': 'int32'})
    res = res.sort_values("final", ascending = False)
        
    return res

res = bertRankQuery(query, 100)



Doc: 7471198, bm-25: 21.350000381469727, bert: 0.18122270703315735
Doc: 7187236, bm-25: 20.337600708007812, bert: 0.3372192780176798
Doc: 5365326, bm-25: 19.683799743652344, bert: 0.18112697328130403
Doc: 7187234, bm-25: 19.534900665283203, bert: 0.45602771391471225
Doc: 7187242, bm-25: 18.88170051574707, bert: 0.12109113112092017
Doc: 2078221, bm-25: 18.849899291992188, bert: 0.609403363118569
Doc: 7187241, bm-25: 18.660499572753906, bert: 0.048912956689794854
Doc: 6802210, bm-25: 18.122499465942383, bert: 0.04601731585959594
Doc: 6794083, bm-25: 17.852399826049805, bert: 0.35505015403032303
Doc: 5365328, bm-25: 17.602399826049805, bert: 0.08343984310825667
Doc: 6750054, bm-25: 17.184900283813477, bert: 0.1803257738550504
Doc: 4332300, bm-25: 16.236299514770508, bert: 0.055645694956183434
Doc: 6347088, bm-25: 16.227800369262695, bert: 0.11002308999498685
Doc: 6347089, bm-25: 16.227798461914062, bert: 0.0945375288526217
Doc: 6270168, bm-25: 16.077600479125977, bert: 0.06572579654554525

In [194]:
def exportRes(query, res, filename = "res.txt"):
    f = open(filename, "w")
    
    for i in range(len(res)):
#         print(res["docid"].iat[i])
        docid = res["docid"].iat[i]
        score = res["final"].iat[i]
        f.write(f"{query} Q0 {docid} {i + 1} {score} Bertserini \n")
        
exportRes(queries[0]["id"], res)
        
print(res)

      docid      bm-25      bert     final
5   2078221  18.849899  0.609403  0.873640
96  5788572  13.299000  0.705017  0.811452
3   7187234  19.534901  0.456028  0.780908
1   7187236  20.337601  0.337219  0.715447
8   6794083  17.852400  0.355050  0.669892
..      ...        ...       ...       ...
75  5488062  13.617000  0.047484  0.352575
84  4774742  13.457000  0.048387  0.349469
88  2720616  13.433400  0.048335  0.348879
89  7187240  13.402500  0.046401  0.346784
81  5800643  13.494600  0.037227  0.342434

[100 rows x 4 columns]


In [166]:
res.iloc[1]["docid"]

7187234.0