In [17]:
import pyserini.search as pys
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import TrainingArguments
from transformers import Trainer
from datasets import load_metric
import torch
import numpy as np
import pandas as pd
import time
import math as m

In [34]:

# BM25 first-stage retriever over the local Lucene index.
searcher = pys.SimpleSearcher('indexes/sample_collection_jsonl')
# Tokenizer and sequence-classification model, both loaded from the same base checkpoint.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
# NOTE(review): the recorded config output lists "DistilBertForMaskedLM" as the
# checkpoint architecture, so the classification head is presumably freshly
# initialised until fine-tuned — confirm before trusting the scores.
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
# Use the first CUDA device when available, otherwise the CPU; every tensor fed
# to `model` has to live on this same device.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model     = model.to(device)

loading configuration file https://huggingface.co/distilbert-base-uncased/resolve/main/config.json from cache at C:\Users\marij/.cache\huggingface\transformers\23454919702d26495337f3da04d1655c7ee010d5ec9d77bdb9e399e00302c0a1.91b885ab15d631bf9cee9dc9d25ece0afd932f2f5130eba28f2055b2220c0333
Model config DistilBertConfig {
  "_name_or_path": "distilbert-base-uncased",
  "activation": "gelu",
  "architectures": [
    "DistilBertForMaskedLM"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "transformers_version": "4.16.2",
  "vocab_size": 30522
}

loading file https://huggingface.co/distilbert-base-uncased/resolve/main/vocab.txt from cache at C:\Users\marij/.cache\huggingface\transformers\0e1bbfda7f6

In [35]:
def findDoc(id, mode="chunk"):
    """Look up one record of the passage collection by numeric id.

    Parameters
    ----------
    id : int
        Passage id to fetch.
    mode : str
        "linear" scans data/collection.tsv from the top and returns the line
        at 0-based index `id`, already split into its tab-separated fields.
        "chunk" reads data/collection_chunks/<base>.txt, where the k-th line
        (1-based) of chunk <base> is passage <base>+k; an id that is an exact
        multiple of 10000 is therefore the last line of the previous chunk.

    Returns
    -------
    list[str] in "linear" mode; in "chunk" mode the raw line as a str
    (callers split it themselves), or an empty string / None when the chunk
    holds no such line. Any other mode returns None.

    Raises
    ------
    FileNotFoundError
        If the collection file or the required chunk file does not exist.
    """
    if mode == "linear":
        with open("data/collection.tsv", encoding="utf8") as f:
            for i in range(id+1):
                l = f.readline()
        return l.strip().split('\t')
    elif mode == "chunk":
        res = id % 10000
        if res == 0:
            # Only the previous chunk is needed here; opening "<id>.txt" as
            # well would leak a handle and fail when that chunk is absent.
            with open(f"data/collection_chunks/{id-10000}.txt", encoding="utf8") as f:
                lines = f.readlines()
            l = lines[-1] if lines else None
        else:
            l = None
            with open(f"data/collection_chunks/{id - res}.txt", encoding="utf8") as f:
                for _ in range(res):
                    l = f.readline()

        # readline() yields '' past end-of-file, so test for "empty", not just None.
        if not l:
            print(f"Document not found/read, id: {id}, res id: {res}")

        return l
    
def loadQueries(filename = "data/queries/queries.train.tsv"):
    """Read a tab-separated query file into a {query id: query text} dict.

    Each non-blank line must hold an integer id followed by a tab and the
    query text; only the first two fields are used. The first line of the
    file is echoed to stdout as a quick sanity check.

    Parameters
    ----------
    filename : str
        Path of the UTF-8 encoded TSV file.

    Returns
    -------
    dict[int, str]
    """
    # The context manager closes the file even if reading fails.
    with open(filename, encoding="utf8") as f:
        lines = f.readlines()
    if lines:
        print(lines[0])

    queries_by_id = {}
    for line in lines:
        fields = line.strip().split('\t')
        if fields[0] == '':
            # Blank (typically trailing) line: nothing to parse.
            continue
        queries_by_id[int(fields[0])] = fields[1]
    return queries_by_id
    
def findQuery(id, cached=True):
    """Return a training query as a list whose first two items are [id, text].

    Parameters
    ----------
    id : int
        Query id to look up.
    cached : bool
        When True, read from the notebook-level `queries` dict built by
        loadQueries(); a missing id raises KeyError. Note that a later cell
        rebinds `queries` to a list of dicts, after which this path no longer
        works. When False, scan data/queries/queries.train.tsv line by line
        and print the elapsed time in seconds.

    Returns
    -------
    list, or None when the uncached scan reaches end-of-file without a match.
    """
    if cached:
        return [id, queries[id]]

    t = time.time()
    with open("data/queries/queries.train.tsv", encoding="utf8") as f:
        # Iterating the file stops at EOF by itself; the old readline() loop
        # crashed on int('') there instead of ever reaching its break.
        for line in f:
            l = line.strip().split('\t')
            if l[0] == '':
                continue
            l[0] = int(l[0])
            if l[0] == id:
                print(time.time() - t)
                return l
    print(time.time() - t)
    return None

def getTrip(f):
    """Read the next id triple from `f` and resolve it to its contents.

    A line holds tab-separated integer ids: query id, then two passage ids.
    The query id is replaced by the result of findQuery() and each passage id
    by its collection record split into tab-separated fields.

    Returns the resolved list, or None once the file yields an empty line.
    """
    parts = f.readline().strip().split('\t')
    if not parts[0]:
        return None

    trip = list(map(int, parts))
    trip[0] = findQuery(trip[0])
    for slot in (1, 2):
        trip[slot] = findDoc(trip[slot]).strip().split('\t')
    return trip

In [None]:
# {query id: query text} lookup table read by findQuery(cached=True).
# NOTE: a later cell rebinds `queries` to a list of dicts holding the hard
# queries, so this cell must be re-run before any further cached lookups.
queries = loadQueries()

In [None]:
def savePosSamples(src="data/qidpidtriples.train.full.2.tsv", dst="data/train.pos.txt", limit=4000000):
    """Keep one triple per run of identical (query id, positive id) pairs.

    Reads tab-separated triples from `src` and copies a line to `dst` whenever
    its first two fields differ from those of the previously read line, so
    only the first line of each consecutive run is kept.

    Parameters
    ----------
    src : str
        Input triples file.
    dst : str
        Output file; overwritten.
    limit : int or None
        Maximum number of input lines to read; None reads the whole file.
    """
    prev_pair = None
    # Both handles are closed on every exit path, and the loop ends at EOF
    # instead of indexing into an empty split once the input runs out.
    with open(src) as triples, open(dst, "w") as fw:
        for n_read, l in enumerate(triples):
            if limit is not None and n_read >= limit:
                break
            triple = l.strip().split("\t")
            if len(triple) < 2:
                # Blank or malformed line: no (query, positive) pair to compare.
                continue
            pair = (triple[0], triple[1])
            if pair != prev_pair:
                prev_pair = pair
                fw.write(l)


def trainBert(path="data/train.pos.txt", n_samples=4000):
    """Build (query, positive passage) text pairs from a triples file.

    Despite its name this function does not train anything: it resolves up to
    `n_samples` triples through getTrip() and returns the sentence pairs.

    Parameters
    ----------
    path : str
        Triples file, as written by savePosSamples().
    n_samples : int
        Maximum number of triples to read.

    Returns
    -------
    list[dict]
        One {"s1": query text, "s2": passage text} dict per triple.
        NOTE(review): no label field is attached — confirm the downstream
        Trainer can compute a loss from these samples.
    """
    trainingData = []

    # `with` closes the file even when a lookup inside getTrip() raises.
    with open(path) as triples:
        for i in range(n_samples):
            triple = getTrip(triples)
            if triple is None:
                break
            trainingData.append({"s1": triple[0][1], "s2": triple[1][1]})
            if i % 100 == 0:
                print(i)

    print("What took so long?")
    return trainingData
        
# Resolve the positive triples into (query, passage) text pairs.
trainingData = trainBert()
# One-off preprocessing that writes data/train.pos.txt (the file trainBert reads);
# left disabled here.
# savePosSamples()

In [None]:
def tokenize_function(sample):
    """Encode one {"s1", "s2"} sample as a sentence pair.

    The pair is padded to the tokenizer's maximum length and truncated when
    it is longer than that.
    """
    first, second = sample["s1"], sample["s2"]
    return tokenizer(first, second, padding="max_length", truncation=True)


# Tokenise every training pair up front (a plain Python list of encodings).
tokens = list(map(tokenize_function, trainingData))

# Sequential split without shuffling: the first `tval` share of the samples is
# used for training, the remainder for validation.
tval = 0.8
trainsize = int(len(trainingData)*tval)
train_tokens = tokens[:trainsize]
val_tokens = tokens[trainsize:]

In [None]:
# Trainer configuration: training only (do_eval=False), two samples per device
# per step, results written under "test_trainer". The commented-out keyword
# arguments are alternatives that are currently disabled.
training_args = TrainingArguments(output_dir="test_trainer",
    do_train=True,
    do_eval=False,
#     evaluation_strategy="steps",
#     eval_steps=1000,
#     logging_dir="exp/bart/logs",
#     num_train_epochs=1,
    per_device_train_batch_size=2,
#     per_device_eval_batch_size=32,
#     gradient_accumulation_steps=2,
    eval_accumulation_steps=1,)
# Accuracy metric object used by compute_metrics().
metric = load_metric("accuracy")


def compute_metrics(eval_pred):
    """Score a (logits, labels) pair with the notebook-level `metric`.

    The predicted class of each row is the index of its largest logit.
    """
    logits, labels = eval_pred
    predicted_classes = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predicted_classes, references=labels)

# Wire the model, the arguments and the two tokenised splits into a Trainer.
# compute_metrics is passed as None, so the compute_metrics() function defined
# in this notebook is not used by this Trainer.
# NOTE(review): train_tokens / val_tokens are plain lists of tokenizer outputs
# built from "s1"/"s2" only — confirm they carry the labels the Trainer needs.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tokens,
    eval_dataset=val_tokens,
    compute_metrics=None,
)

In [None]:
# Fine-tune `model` in place; the ranking cells further down use this same object.
trainer.train()

In [37]:
# Alternative query sets that can be swapped in for the hard queries:
# f = open("data/queries/queries.dev.tsv")
# f = open("data/msmarco-test2019-queries.tsv")

# NOTE: this rebinds `queries` from the {id: text} dict used by
# findQuery(cached=True) to a list of {"id", "content"} dicts.
queries = []

# Read at most the first 200 lines; `with` closes the file even when a line
# fails to parse.
with open("data/hard_queries.txt") as f:
    for i, line in enumerate(f):
        if i >= 200:
            break
        l = line.split("\t")
        # Skip blank lines instead of crashing on int('\n').
        if l[0].strip() != '':
            queries.append({"id": int(l[0]), "content": l[1].strip()})

print(queries[0])

{'id': 1000000, 'content': 'where does real insulin come from'}


In [48]:
def loadHardQrels(filename="data/hard_qrels.txt"):
    """Load relevance judgements as a {query id: relevant passage id} dict.

    Each non-blank line must start with two whitespace-separated integers:
    the query id followed by the passage id. Further fields are ignored, and
    a query id that occurs more than once keeps its last passage id.

    Parameters
    ----------
    filename : str
        Path of the qrels file.

    Returns
    -------
    dict[int, int]
    """
    qrels_by_query = {}
    # The context manager releases the handle that was previously never closed.
    with open(filename) as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank (typically trailing) line.
                continue
            qrels_by_query[int(fields[0])] = int(fields[1])
    return qrels_by_query
qrels = loadHardQrels()

# Re-rank every judged hard query, show where its relevant passage lands under
# BM25 alone and under the combined score, and export the run to bert_hard.txt.
# bertRankQuery() and exportRes() are defined in cells further down, so those
# cells must have been executed before this one.
resses = []
# `with` flushes and closes the run file even if ranking raises or is interrupted.
with open("bert_hard.txt", "w") as fw:
    # NOTE(review): queries[1:] skips the first query — confirm this is intended.
    for query in queries[1:]:
        if(query["id"] in qrels):
            print(f"Checking for query: {query['content']}")

#             hits = searcher.search(query["content"], k=1000)
#             res = []
#             for hit in hits:
#                 res.append(int(hit.docid))
#             print(qrels[query["id"]] in res)

            res = bertRankQuery(query["content"], k=1000)
            # bertRankQuery returns rows ordered by "final", so this is the final rank.
            res["rank"] = range(1, len(res)+1)
            res = res.sort_values("bm-25", ascending = False)
            res["bm25-rank"] = range(1, len(res)+1)
            # Restore the final ordering before reporting and exporting.
            res = res.sort_values("final", ascending = False)
            print(res[res["docid"] == qrels[query["id"]]])
            exportRes(query, res, fw)
            resses.append(res)


# res["rank"] = range(1, len(res)+1)

# print(res[res["docid"] == 7264269])

# print(res[res["docid"] == 7264269]["bert"]/res["bert"].max())
# res

Checking for query: where does parrot live
7.461893796920776
31.696105003356934
0.04100227355957031
       docid                                           sentence     bm-25  \
792  7264269  Parrots live in just about all of the tropical...  6.799999   

         bert     final  rank  bm25-rank  
792  0.489199  0.781416   301        306  
Checking for query: where does most of the iron ore come from
7.417783975601196
36.86902379989624
0.04298114776611328
       docid                                           sentence    bm-25  \
872  7264253  Earth's most important iron ore deposits are f...  10.8964   

        bert     final  rank  bm25-rank  
872  0.49968  0.844691   112        124  
Checking for query: where is weston, fl
9.344017744064331
37.93731474876404
0.04096412658691406
       docid         sentence     bm-25      bert     final  rank  bm25-rank
894  7942175  Weston, Florida  6.500095  0.506244  0.740717   143        220
Checking for query: where does most coal form
7.413602

Unnamed: 0,docid,sentence,bm-25,bert,final,rank


In [38]:
# The tokenizer automatically adds the model-specific special tokens (i.e. [CLS] and [SEP]) to
# the sequence pair and computes the attention masks.

def getSim(query, doc):
    """Score one (query, text) pair with the classification model.

    Returns the softmax probability of class index 1 as a Python float.
    """
    # Inference must not use dropout, which stays active if the model was
    # left in training mode.
    model.eval()
    # Truncate over-long pairs to the model's maximum length and move the
    # inputs to the device the model lives on, as calcBertScores does.
    paraphrase = tokenizer(query, doc, truncation=True, return_tensors="pt").to(device)
    # No gradients are needed for scoring; skipping them saves memory.
    with torch.no_grad():
        paraphrase_classification_logits = model(**paraphrase).logits
    paraphrase_results = torch.softmax(paraphrase_classification_logits, dim=1).tolist()[0]

    return paraphrase_results[1]
     
def calcBertScore(query, doc):
    """Score a document by its best-matching sentences.

    The document is split on "."; every non-empty piece is scored against the
    query with getSim(), and the three highest scores are combined using the
    notebook-level weights `w`. A document without sentences scores 0.
    """
    scored = []
    for piece in doc.split("."):
        sentence = piece.strip()
        if sentence != "":
            scored.append({"score": getSim(query, sentence), "sentence": sentence})

    scored.sort(key=lambda item: item["score"], reverse=True)

    total = 0
    for position, item in enumerate(scored[:3]):
        total = total + w[position]*item["score"]
    return total



In [39]:
def combine3(x):
    """Weighted sum of the (up to) three largest values in `x`.

    `x` must offer `.to_numpy()` (e.g. a pandas Series); the weights come from
    the notebook-level list `w`. An empty input yields 0.
    """
    ranked = sorted(x.to_numpy(), reverse=True)
    total = 0
    for position, value in enumerate(ranked[:3]):
        total = total + w[position]*value
    return total

def calcBertScores(query, docs, batch_size=4):
    """Score every sentence in `docs` against `query` in mini-batches.

    Parameters
    ----------
    query : str
        Query text, paired with each sentence.
    docs : pandas.DataFrame
        Must contain a "sentence" column.
    batch_size : int
        Number of sentence pairs per forward pass.

    Returns
    -------
    list[float]
        Softmax probability of class index 1 for each row, in row order.
    """
    # Materialise the sentence list once rather than once per batch.
    sentences = docs["sentence"].to_numpy().tolist()
    n = len(sentences)
    if n == 0:
        return []

    # Inference must not use dropout, which stays active if the model was
    # left in training mode.
    model.eval()
    scores = []
    # No gradients are needed for scoring; skipping them saves memory.
    with torch.no_grad():
        for start in range(0, n, batch_size):
            batch_of_second_sentences = sentences[start:start + batch_size]
            batch_sentences = [query]*len(batch_of_second_sentences)

            # truncation=True keeps over-long pairs within the model's maximum length.
            encoded_inputs = tokenizer(batch_sentences, batch_of_second_sentences, padding=True, truncation=True, return_tensors="pt").to(device)
            classification_logits = model(**encoded_inputs).logits
            results = torch.softmax(classification_logits, dim=1).tolist()
            scores.extend(x[1] for x in results)

    return scores


def expandSentences(query, hits):
    """Explode retrieval hits into one row per sentence.

    For each hit the passage text (second tab-separated field of its
    collection record) is split on "."; every non-empty, stripped piece
    becomes a row carrying the hit's document id and retrieval score.
    `query` is accepted but not used.

    Returns a DataFrame with the columns "docid", "sentence" and "bm-25".
    """
    doc_ids, texts, bm25_scores = [], [], []

    for hit in hits:
        doc_id = int(hit.docid)
        body = findDoc(doc_id).split("\t")[1]
        for piece in body.split("."):
            sentence = piece.strip()
            if sentence == "":
                continue
            doc_ids.append(doc_id)
            texts.append(sentence)
            bm25_scores.append(hit.score)

    return pd.DataFrame({"docid": doc_ids, "sentence": texts, "bm-25": bm25_scores})

    


def bertRankQuery(query, k=20):
    """Retrieve `k` hits for `query` and re-rank them with the BERT scores.

    Stages, with the elapsed seconds of each printed to stdout:
      1. search the index and expand the hits into sentences;
      2. score every sentence against the query;
      3. aggregate per document (first row of each group, with "bert"
         replaced by combine3 over the group's sentence scores) and compute
         "final" = a * bm-25/max(bm-25) + (1-a) * bert/max(bert).

    Returns a DataFrame with one row per document, sorted by "final" in
    descending order.
    """
    started = time.time()
    hits = searcher.search(query, k=k)
    sentences = expandSentences(query, hits)
    print(time.time()-started)
    started = time.time()

    sentences["bert"] = calcBertScores(query, sentences)
    print(time.time()-started)
    started = time.time()

    # Both groupbys order their groups by docid, so the rows line up.
    per_doc = sentences.groupby("docid", as_index=False).first()
    per_doc["bert"] = sentences.groupby("docid")["bert"].aggregate(combine3).to_numpy()

    top_bm25 = per_doc["bm-25"].max()
    top_bert = per_doc["bert"].max()
    per_doc["final"] = a*per_doc["bm-25"]/top_bm25 + (1-a)*per_doc["bert"]/top_bert

    ranked = per_doc.astype({'docid': 'int32'})
    ranked = ranked.sort_values("final", ascending = False)
    print(time.time()-started)

    return ranked

# Weights for the best, second-best and third-best sentence score of a document.
w = [2/3, 1/6, 1/6]
# Interpolation weight: share of normalised BM25 in the final score; (1 - a) goes to BERT.
a = 0.5

res = []
its = 99

# Rank every loaded query and write the run to bert_hard.txt.
# `with` flushes and closes the file even if the loop is interrupted.
# exportRes() is defined in a cell further down, so that cell must run first.
with open("bert_hard.txt", "w") as fw:
    for query in queries:
        its = its + 1
        print(its)
        res = bertRankQuery(query["content"], 500)
        print(res)
        exportRes(query, res, fw)
        print(its)





100
4.391444444656372


KeyboardInterrupt: 

In [42]:
def exportRes(query, res, f):
    """Append the ranking `res` for `query` to the open file `f`.

    One line is written per row of `res`, in row order:
    "<query id> Q0 <docid> <1-based rank> <final score> Bertserini".
    """
    qid = query["id"]
    docids = res["docid"]
    finals = res["final"]
    for rank in range(1, len(res) + 1):
        docid = docids.iat[rank - 1]
        score = finals.iat[rank - 1]
        f.write(f"{qid} Q0 {docid} {rank} {score} Bertserini \n")
                
# Debug leftover: prints whatever `res` currently holds in the kernel
# (the recorded output below is an empty list).
print(res)

[]


In [None]:
# Scratch cell mirroring the per-document aggregation inside bertRankQuery().
# NOTE(review): no visible cell defines `scores` at notebook level (it is only
# a local of calcBertScores), so this presumably raises NameError on a fresh
# kernel — confirm or remove.
res["bert"] = scores
    
# First row of each docid group, indexed by docid.
agr = res.groupby("docid").first()
# Replace "bert" with the weighted top-3 sentence score of each document.
agr["bert"] = res.groupby("docid")["bert"].aggregate(combine3)
agr