In [None]:
# !pip install openpyxl

# Load data

In [None]:
import pandas as pd 
import openpyxl 

# The two Excel exports that together hold the patent records.
file_path1 = '특허Text_20230724_1.xlsx'
file_path2 = '특허Text_20230724_2.xlsx'

# Sheet names of both workbooks, first workbook first.
# `Workbook.sheetnames` replaces `get_sheet_names()`, which was deprecated and
# has been removed from current openpyxl releases.
wb1 = openpyxl.load_workbook(file_path1)
wb2 = openpyxl.load_workbook(file_path2)
sheetList = list(wb1.sheetnames) + list(wb2.sheetnames)

data1 = pd.ExcelFile(file_path1)
data2 = pd.ExcelFile(file_path2)

# NOTE(review): sheetList[1] is the first sheet of the second workbook only if
# the first workbook contains exactly one sheet — confirm against the files.
df1 = pd.read_excel(data1,sheetList[0])
df2 = pd.read_excel(data2,sheetList[1])

# `DataFrame.append` was removed in pandas 2.0; `pd.concat` with the same
# `sort` / `ignore_index` arguments produces the same stacked frame.
df = pd.concat([df1, df2], sort=True, ignore_index=True)

len(df), df.columns
# Index(['특허/실용 구분', '발명의 명칭', '요약', '대표청구항', '출원번호', '출원일', '공개번호', '공개일',
#        '등록번호', '등록일', '출원인', '발명자', '우선권 번호', '우선권 국가', '우선권 주장일',
#        'Original IPC All'],
#       dtype='object')

In [None]:
# One lightweight record per patent row: summary, representative claim, title
# and the application number, keyed by the row position.
qrels_data = []
for row_idx in range(len(df)):
    row = df.iloc[row_idx].to_dict()
    record = {
        '_id': row_idx,
        'summary': row['요약'],
        'represent': row['대표청구항'],
        'title': row['발명의 명칭'],
        'metadata': {'출원번호': row['출원번호']},
    }
    qrels_data.append(record)

len(qrels_data), qrels_data[0]

In [None]:
from collections import defaultdict
import pandas as pd 

# Group summaries by title to see how many summaries share the same title.
q2p = defaultdict(list)
for record in qrels_data:
    q2p[record['title']].append(record['summary'])

summaries_per_title = [len(summaries) for summaries in q2p.values()]
len(q2p), pd.DataFrame(summaries_per_title).describe()

In [None]:
# Print every row whose title equals the one under inspection.
title_to_inspect = '비콘 신호를 이용하여 도어 출입을 관리하기 위한 방법 및 시스템'
for row_idx in range(len(df)):
    row = df.iloc[row_idx].to_dict()
    if row['발명의 명칭'] != title_to_inspect:
        continue
    print(row['출원번호'])
    print(row)

In [None]:
import random 

# Eyeball ten randomly drawn records (title / summary / representative claim).
for sample in random.sample(qrels_data, 10):
    print("### 발명의 명칭:\n", sample['title'])
    print("### 요약문:\n", sample['summary'])
    print("### 대표청구항:\n", sample['represent'])
    print()

# Corpus Set Generation

In [None]:
# One corpus entry per patent row, using the summary column as passage text.
summary_corpus = []
for row_idx in range(len(df)):
    row = df.iloc[row_idx].to_dict()
    entry = {
        '_id': row_idx,
        'text': row['요약'],
        'title': row['발명의 명칭'],
        'metadata': {'type': '요약문', '출원번호': row['출원번호']},
    }
    summary_corpus.append(entry)

len(summary_corpus)

In [None]:
# Peek at the first corpus entry to check its structure.
summary_corpus[0]

In [None]:
from collections import defaultdict 
from tqdm import tqdm

# title -> query id, and the reverse lookup query id -> title
all_q2id = {}
all_query_dict = {}

# summary text -> passage id
summary_p2id = {}

# passage id -> corpus record, and query id -> list of relevant passage ids
summary_corpus_dict = {}
summary_qrels = defaultdict(list)

for record in tqdm(summary_corpus):
    query = record['title']
    # A float here is how a missing summary shows up; substitute a placeholder text.
    target = '.' if type(record['text']) == float else record['text']

    # Ids are handed out in first-seen order ("Q0", "Q1", ... / "C0", "C1", ...).
    if query not in all_q2id:
        fresh_qid = f"Q{len(all_q2id)}"
        all_q2id[query] = fresh_qid
        all_query_dict[fresh_qid] = query
    qid = all_q2id[query]

    if target not in summary_p2id:
        fresh_pid = f"C{len(summary_p2id)}"
        summary_p2id[target] = fresh_pid
        # Only the first row that carries this text contributes title and metadata.
        summary_corpus_dict[fresh_pid] = {
            '_id': fresh_pid,
            'text': target,
            'title': query,
            'metadata': {'type': '요약문', '출원번호': record['metadata']['출원번호']},
        }
    pid = summary_p2id[target]

    summary_qrels[qid].append(pid)

len(all_q2id),len(summary_p2id),len(summary_corpus_dict),len(all_query_dict),len(summary_qrels) # (9357, 9524, 9524, 9357, 9357)

In [None]:
import os 
import json 

os.makedirs('new_summary_origin_20000',exist_ok=True)

# JSON Lines needs exactly one JSON object per physical line. The previous
# `indent=4` spread each record over several lines, so the line-by-line
# `json.loads` readers used later in this notebook could not parse the file.
with open('new_summary_origin_20000/corpus.jsonl','w',encoding='utf-8') as f:
    for pair in summary_corpus_dict.values():
        f.write(json.dumps(pair, ensure_ascii=False)+'\n')

print("done")

In [None]:
# Show only a handful of records; displaying the whole corpus floods the
# notebook output and bloats the saved file.
list(summary_corpus_dict.values())[:5]

In [None]:
# Generate queries
import json 

# One {"_id", "text"} object per line, in query-id insertion order.
with open('new_summary_origin_20000/queries.jsonl','w',encoding='utf-8') as f:
    f.writelines(
        json.dumps({'_id': qid, 'text': title}, ensure_ascii=False) + '\n'
        for qid, title in all_query_dict.items()
    )


In [None]:
# Passage table in `id<TAB>text<TAB>title` column order.
# The header used to announce `id, title, text` while the rows were written as
# `id, text, title`; the header now matches the rows, which are left unchanged.
# NOTE(review): passage text is written verbatim — a tab or newline inside a
# summary would corrupt the row; confirm the data contains none.
with open('new_summary_origin_20000/summary_corpus.tsv','w',encoding='utf-8') as f:
    f.write("id\ttext\ttitle\n")
    for k,pair in summary_corpus_dict.items():
        f.write(f"{pair['_id']}\t{pair['text']}\t{pair['title']}\n")


In [None]:
# Number of distinct queries that have at least one relevant passage.
len(summary_qrels)

In [None]:
from sklearn.model_selection import train_test_split
from collections import defaultdict 

# Fixed seed so that re-running the notebook reproduces the same split; without
# it every run reshuffles the pairs and invalidates files derived from the split.
SPLIT_SEED = 42

train_qrels = defaultdict(list)
dev_qrels = defaultdict(list)
test_qrels = defaultdict(list)

# Flatten {qid: [pid, ...]} into (qid, pid) pairs; the split is done per pair.
summary_all_pairs = [(qid,pid) for qid, pid_list in summary_qrels.items() for pid in pid_list]

others, test = train_test_split(summary_all_pairs, test_size=0.3, random_state=SPLIT_SEED) # 70% train+dev / 30% test
train, dev = train_test_split(others, test_size=0.2, random_state=SPLIT_SEED) # of the remaining 70%: 80% train / 20% dev

print(len(summary_all_pairs), len(train),len(dev),len(test)) # 20000 11200 2800 6000

# Regroup each split's pairs back into {qid: [pid, ...]}.
for split_pairs, split_qrels in ((train, train_qrels), (dev, dev_qrels), (test, test_qrels)):
    for qid, pid in split_pairs:
        split_qrels[qid].append(pid)

print("qrels:", len(train_qrels), len(dev_qrels),len(test_qrels)) # qrels: 10574 2728 5784

In [None]:
import os
os.makedirs('new_summary_origin_20000/qrels',exist_ok=True)

###############
#### Qrels for train, dev, test (plus the unsplit qrels)
###############
# All four files share one layout: a header line followed by one
# `qid<TAB>pid<TAB>1` row per relevant pair.
# The header used to be written as "qid\tpid\score\n": "\s" is not an escape
# sequence, so the files got a literal backslash and only two header columns.
QRELS_HEADER = "qid\tpid\tscore\n"

qrels_by_file = {
    'train.tsv': train_qrels,
    'dev.tsv': dev_qrels,
    'test.tsv': test_qrels,
    'qrels.tsv': summary_qrels,
}

for file_name, qrels_by_query in qrels_by_file.items():
    with open(f'new_summary_origin_20000/qrels/{file_name}','w', encoding='utf-8') as f:
        f.write(QRELS_HEADER)
        for qid, pid_list in qrels_by_query.items():
            for pid in pid_list:
                f.write(f"{qid}\t{pid}\t{1}\n")


# Reformulate LLM-generated examples into queries

In [None]:
# List the contents of the `generated` directory.
!ls generated

In [None]:
import json

# The whole file is a single JSON document (not JSON Lines).
generated_path = 'generated/generated_text_summary_corpus.json'
with open(generated_path, 'r', encoding='utf-8') as f:
    generated_query = json.loads(f.read())

len(generated_query), generated_query[0]

In [None]:
# Fields available on one generated record.
generated_query[0].keys()

In [None]:
# Preview the first ten inputs next to the text the LLM generated for them.
for sample in generated_query[:10]:
    source_summary = sample['input']
    llm_output = sample['generated_text']['content']
    print("summary :", source_summary)
    print("generated_text :", llm_output)
    print()

In [None]:
from collections import defaultdict 
from tqdm import tqdm
import re 

# generated query text -> query id, and query id -> query text
gen_all_q2id={}
gen_query_dict = {}

gen_summary_corpus_dict = {}
# query id -> passage ids (ids come from the original summary corpus)
gen_summary_qrels = defaultdict(list)


# Post processing generated query / need to adapt 

# A line containing any of these markers is a prompt echo or section label, not a query.
SKIP_MARKERS = ('generated_text', 'summary:', '키워드', '유의어 변경', '# 출력', '입력 문서:', '입력문서:')

# The first whitespace-separated token must contain a digit, '-' or '.'
# (a list bullet such as "1." or "-"). Written as a raw string: the previous
# non-raw pattern relied on the invalid escape sequences "\d" and "\s".
# NOTE(review): the pattern is not anchored, so any token containing one of
# these characters passes — confirm that this looseness is intended.
BULLET_TOKEN_RE = re.compile(r'[-\d.]\s*')

for idx in tqdm(range(len(generated_query))):
    entry = generated_query[idx]
    target = entry['input']

    for raw_query in entry['generated_text']['content'].split('\n'):
        if any(marker in raw_query for marker in SKIP_MARKERS):
            continue

        # Split "<bullet> <query text>" on the first space only.
        splitted_query = raw_query.split(' ',1)
        if len(splitted_query)<2 :
            continue

        if not BULLET_TOKEN_RE.search(splitted_query[0]):
            continue

        query = splitted_query[1]

        if query in gen_all_q2id:
            qid = gen_all_q2id[query]
        else:
            qid = f"Q{len(gen_all_q2id)}"
            gen_all_q2id[query]=qid
            gen_query_dict[qid]=query 

        # The input summary must already be a known passage text; an unknown
        # text raises KeyError rather than being silently dropped.
        pid = summary_p2id[target] # from summary_origin

        gen_summary_qrels[qid].append(pid)


len(gen_all_q2id),len(gen_query_dict),len(gen_summary_qrels) 

In [None]:
# Generate queries
import json 
import os 

os.makedirs('generated_summary_ver',exist_ok=True)

# One {"_id", "text"} object per line for every LLM-generated query.
with open('generated_summary_ver/queries.jsonl','w',encoding='utf-8') as f:
    f.writelines(
        json.dumps({'_id': qid, 'text': query_text}, ensure_ascii=False) + '\n'
        for qid, query_text in gen_query_dict.items()
    )
        

In [None]:
# Number of distinct generated queries with at least one relevant passage.
len(gen_summary_qrels)

In [None]:
from sklearn.model_selection import train_test_split
from collections import defaultdict 

# Fixed seed so that re-running the notebook reproduces the same split.
GEN_SPLIT_SEED = 42

train_qrels = defaultdict(list)
dev_qrels = defaultdict(list)
test_qrels = defaultdict(list)

# Flatten {qid: [pid, ...]} into (qid, pid) pairs; the split is done per pair.
all_pairs = [(qid,pid) for qid, pid_list in gen_summary_qrels.items() for pid in pid_list]

others, test = train_test_split(all_pairs, test_size=0.2, random_state=GEN_SPLIT_SEED) # 80% train+dev / 20% test
train, dev = train_test_split(others, test_size=0.2, random_state=GEN_SPLIT_SEED) # of the remaining 80%: 80% train / 20% dev

print(len(all_pairs), len(train),len(dev),len(test)) # (12757, 8164, 2041, 2552)

# Regroup each split's pairs back into {qid: [pid, ...]}.
for split_pairs, split_qrels in ((train, train_qrels), (dev, dev_qrels), (test, test_qrels)):
    for qid, pid in split_pairs:
        split_qrels[qid].append(pid)

print("qrels:", len(train_qrels), len(dev_qrels),len(test_qrels)) # qrels: 8162 2041 2552

In [None]:
# summary_corpus[0]

os.makedirs('generated_summary_ver/qrels',exist_ok=True)

# Header shared by every qrels file. It used to be written as
# "qid\tpid\score\n": "\s" is not an escape sequence, so the file got a
# literal backslash and only two header columns.
GEN_QRELS_HEADER = "qid\tpid\tscore\n"

with open('generated_summary_ver/qrels/qrels.tsv','w', encoding='utf-8') as f:
    f.write(GEN_QRELS_HEADER)
    for qid,pid_list in gen_summary_qrels.items():
        for pid in pid_list:
            f.write(f"{qid}\t{pid}\t{1}\n")

###############
#### Qrels for train, dev, test
###############
# A later cell reads train.tsv / dev.tsv / test.tsv from this directory, but
# the code that wrote them was commented out (and one path pointed at
# 'qrels/qrels/'). They are written again here, but only when missing, so a
# split already on disk is never replaced.
for split_name, split_qrels in (('train', train_qrels), ('dev', dev_qrels), ('test', test_qrels)):
    split_path = f'generated_summary_ver/qrels/{split_name}.tsv'
    if os.path.exists(split_path):
        continue
    with open(split_path,'w', encoding='utf-8') as f:
        f.write(GEN_QRELS_HEADER)
        for qid,pid_list in split_qrels.items():
            for pid in pid_list:
                f.write(f"{qid}\t{pid}\t{1}\n")


In [None]:
# Copy the original-summary corpus file into the generated-query dataset folder.
!cp new_summary_origin_20000/corpus.jsonl generated_summary_ver

# Reformulate data for Contriever format

In [None]:
import json 

# LLM-generated queries: kept as a list in file order.
with open('generated_summary_ver/queries.jsonl','r',encoding='utf-8') as f:
    query_summary_llm_gen = list(map(json.loads, f))

# Original (title) queries: indexed by their text.
with open('new_summary_origin_20000/queries.jsonl','r',encoding='utf-8') as f:
    query_summary_origin = {record['text']: record for record in map(json.loads, f)}

len(query_summary_llm_gen),len(query_summary_origin)

In [None]:
# Peek at the first LLM-generated query record.
query_summary_llm_gen[0]

In [None]:

# Corpus records indexed by their `_id`; each line of the file is parsed as one JSON object.
with open('generated_summary_ver/corpus.jsonl','r',encoding='utf-8') as f:
    corpus_summary = {record['_id']: record for record in map(json.loads, f)}

len(corpus_summary)

In [None]:
from collections import defaultdict


def load_qrels(path):
    """Read a qrels TSV file into ``{qid: [pid, ...]}``.

    The first line is treated as a header and skipped. Every following
    non-blank line must hold three tab-separated fields (qid, pid, score);
    the score is ignored.

    Args:
        path: location of the TSV file.

    Returns:
        A ``defaultdict(list)`` mapping each query id to its passage ids, in
        file order.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if a data line does not have exactly three fields.
    """
    qrels_by_query = defaultdict(list)
    with open(path, 'r', encoding='utf-8') as f:
        f.readline()  # header
        for line in f:
            line = line.strip()
            if not line:
                # tolerate a trailing blank line instead of failing to unpack
                continue
            qid, pid, _score = line.split('\t')
            qrels_by_query[qid].append(pid)
    return qrels_by_query


# Files are loaded in the same order as before, so a missing file fails at the
# same point and leaves the same names defined.
qrels_summary_llm_gen = load_qrels('generated_summary_ver/qrels/qrels.tsv')

tr_qrels_summary_origin = load_qrels('new_summary_origin_20000/qrels/train.tsv')
dev_qrels_summary_origin = load_qrels('new_summary_origin_20000/qrels/dev.tsv')

#################
### Train, dev, test split
#################
train_qrels_summary_llm_gen = load_qrels('generated_summary_ver/qrels/train.tsv')
dev_qrels_summary_llm_gen = load_qrels('generated_summary_ver/qrels/dev.tsv')
test_qrels_summary_llm_gen = load_qrels('generated_summary_ver/qrels/test.tsv')

qrels_summary_origin = load_qrels('new_summary_origin_20000/qrels/qrels.tsv')

len(train_qrels_summary_llm_gen),len(dev_qrels_summary_llm_gen),len(test_qrels_summary_llm_gen),len(qrels_summary_origin) 

{"question":"What is the most popular operating system?","positive_ctxs":[{"text": "Windows is the most popular operating system."}],"negative_ctxs":[{"text": "Windows is the most popular programming language."}],"hard_negative_ctxs":[{"text": "Windows is the most popular game console."}],"title":"Windows","text":"Windows is the most popular operating system."}

In [None]:
def _build_contriever_records(qrels_by_query, query_by_id, corpus_by_id):
    """Turn ``{qid: [pid, ...]}`` into Contriever fine-tuning records.

    Args:
        qrels_by_query: mapping from query id to its relevant passage ids.
        query_by_id: mapping from query id to a query record with a 'text' field.
        corpus_by_id: mapping from passage id to a record with 'title' and 'text'.

    Returns:
        A list of ``{"question": ..., "positive_ctxs": [{"title", "text"}, ...]}``
        dicts, one per query, in the iteration order of ``qrels_by_query``.

    Raises:
        KeyError: if a query id or passage id is unknown.
    """
    records = []
    for qid, pid_list in qrels_by_query.items():
        positives = [
            {'title': corpus_by_id[pid]['title'], 'text': corpus_by_id[pid]['text']}
            for pid in pid_list
        ]
        records.append({
            "question": query_by_id[qid]['text'],
            "positive_ctxs": positives,
            # "negative_ctxs": [],
            # "hard_negative_ctxs": []
        })
    return records


# `query_summary_origin` is keyed by query text and `corpus_summary` by the
# string `_id`, so the former `[int(k[1:])]` list-style indexing raised
# KeyError. Queries are re-indexed by `_id`; passages are looked up by pid.
_query_origin_by_id = {info['_id']: info for info in query_summary_origin.values()}

tr_data_for_contriever_finetune_summary_origin = _build_contriever_records(
    tr_qrels_summary_origin, _query_origin_by_id, corpus_summary)
dev_data_for_contriever_finetune_summary_origin = _build_contriever_records(
    dev_qrels_summary_origin, _query_origin_by_id, corpus_summary)

len(tr_data_for_contriever_finetune_summary_origin),len(dev_data_for_contriever_finetune_summary_origin)

In [None]:
# Persist the train and dev fine-tuning records as JSON Lines (train first).
contriever_outputs = (
    ('new_summary_origin_20000/train.data.for_contriever.jsonl', tr_data_for_contriever_finetune_summary_origin),
    ('new_summary_origin_20000/dev.data.for_contriever.jsonl', dev_data_for_contriever_finetune_summary_origin),
)
for out_path, records in contriever_outputs:
    with open(out_path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(record, ensure_ascii=False) + '\n' for record in records)


# Making data for new summary origin 20000

In [None]:
import json 

# Original (title) queries, indexed by query text.
with open('new_summary_origin_20000/queries.jsonl','r',encoding='utf-8') as f:
    query_summary_origin = {record['text']: record for record in map(json.loads, f)}

print(len(query_summary_llm_gen),len(query_summary_origin))

# Corpus records, indexed by `_id`.
with open('new_summary_origin_20000/corpus.jsonl','r',encoding='utf-8') as f:
    corpus_summary = {record['_id']: record for record in map(json.loads, f)}

print(len(corpus_summary))

In [None]:
# Load previously produced fine-tuning records and index them by question text.
# If the same question appears on several lines, the last one wins.
# NOTE(review): this exact path is also written by a later cell of this
# notebook, so on a fresh checkout the file must already exist from an earlier
# run — confirm where it originally comes from.
with open('new_summary_origin_20000/train.w_negative.data.for_contriever.jsonl','r',encoding='utf-8') as f:
#     origin_contriever_data=[json.loads(l) for l in f]
    origin_contriever_data={json.loads(l)['question']:json.loads(l) for l in f}

len(origin_contriever_data)#,origin_contriever_data[0].keys()

In [None]:
# Re-read the train / dev fine-tuning records written earlier (one JSON object per line).
with open('new_summary_origin_20000/train.data.for_contriever.jsonl','r',encoding='utf-8') as f:
    tr_data_for_contriever_finetune_summary_origin = list(map(json.loads, f))

with open('new_summary_origin_20000/dev.data.for_contriever.jsonl','r',encoding='utf-8') as f:
    dev_data_for_contriever_finetune_summary_origin = list(map(json.loads, f))

len(tr_data_for_contriever_finetune_summary_origin),len(dev_data_for_contriever_finetune_summary_origin)

In [None]:
# Copy the stored negative contexts onto every train and dev record, matching
# on the question text. A question missing from `origin_contriever_data`
# raises KeyError.
for dataset in (tr_data_for_contriever_finetune_summary_origin,
                dev_data_for_contriever_finetune_summary_origin):
    for record in dataset:
        stored = origin_contriever_data[record['question']]
        record['negative_ctxs'] = stored['negative_ctxs']

len(tr_data_for_contriever_finetune_summary_origin),len(dev_data_for_contriever_finetune_summary_origin)

In [None]:
# Write the records, now carrying `negative_ctxs`, as JSON Lines (train first).
w_negative_outputs = (
    ('new_summary_origin_20000/train.w_negative.data.for_contriever.jsonl', tr_data_for_contriever_finetune_summary_origin),
    ('new_summary_origin_20000/dev.w_negative.data.for_contriever.jsonl', dev_data_for_contriever_finetune_summary_origin),
)
for out_path, records in w_negative_outputs:
    with open(out_path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(record, ensure_ascii=False) + '\n' for record in records)

# Hard negative mining

In [None]:
# List the files of the processed LLM-generated dataset directory.
!ls processed/corpus2.size20000.summary_llm_gen.1012_ver1.4 

In [None]:
import json
# Read with an explicit UTF-8 encoding, like the other JSONL readers in this
# notebook; the platform default encoding is not guaranteed to decode Korean text.
with open('processed/corpus2.size20000.summary_llm_gen.1012_ver1.4/test.filtered.data.for_contriever.jsonl','r',encoding='utf-8') as f:
    test= [json.loads(l) for l in f]
len(test),test[0]

In [None]:
# Print the question of every filtered test record, separated by blank lines.
for example in test:
    question = example['question']
    print("Question:", question)
    print()

In [None]:
import pprint 
import random 
from pyserini.search.lucene import LuceneSearcher

from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch

from collections import defaultdict

import torch
import numpy as np
import json
import random
import faiss
from tqdm import tqdm
import pickle
from datasets import load_dataset
from collections import defaultdict
import time
# from easydict import EasyDict
# from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification

import glob
import sys 
# Make the local Contriever checkout importable.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
if not '/mnt/sda/hanseok/projects/nable_kaist/baselines/contriever' in sys.path:
    sys.path.append('/mnt/sda/hanseok/projects/nable_kaist/baselines/contriever')
    print(sys.path)
import os 

# from contriever.src import contriever as contriever


# Evaluator
# Built without a retrieval model; in this notebook it is only used for its evaluate helpers.
retriever = EvaluateRetrieval(None, score_function=None)
metrics = defaultdict(list)  # store final results
print("### Custom data mode")
# data_path = 'processed/corpus2.size20000.summary_llm_gen.1012_ver1.4'
# data_path = 'processed/summary_origin2'
# data_path = 'processed/corpus2.subset.summary_llm_gen.1021_ver1.5'
# NOTE(review): earlier cells write to 'new_summary_origin_20000/' (no 'processed/' prefix) — confirm this is the same data.
data_path = 'processed/new_summary_origin_20000/'

# Loads corpus / queries from data_path together with the train qrels file.
corpus, queries, qrels = GenericDataLoader(data_folder=data_path,qrels_file = os.path.join(data_path,'qrels/train.tsv')).load_custom()
# corpus, queries, qrels = GenericDataLoader(data_folder=data_path,qrels_file = os.path.join(data_path,'qrels/qrels.tsv')).load_custom()

# BM25
# NOTE(review): the index is named 'summary_origin2_ko' while data_path points at 'new_summary_origin_20000' — confirm the doc ids match.
bm25_searcher = LuceneSearcher('../baselines/bm25/indexes/summary_origin2_ko')
bm25_searcher.set_language('ko')


In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Tokenizer and sequence-classification model are loaded from the same checkpoint.
reranker_checkpoint = "amberoad/bert-multilingual-passage-reranking-msmarco"
tokenizer = AutoTokenizer.from_pretrained(reranker_checkpoint)
reranker = AutoModelForSequenceClassification.from_pretrained(reranker_checkpoint)

In [None]:
# Alternative inputs used in other runs:
#   processed/corpus2.subset.summary_llm_gen.1021_ver1.5/data.for_contriever.jsonl
#   processed/corpus2.size20000.summary_llm_gen.1012_ver1.4/train.data.for_contriever.jsonl
contriever_train_path = 'processed/summary_origin2/train.data.for_contriever.jsonl'
with open(contriever_train_path, 'r', encoding='utf-8') as f:
    contriever_data = list(map(json.loads, f))

len(contriever_data),contriever_data[0].keys()

# data_for_contriever_finetune_summary_origin.append({
#         "question": query_info['text'],
#         "positive_ctxs": pos_info,
#         # "negative_ctxs": [],
#         # "hard_negative_ctxs": [] 
#     })



In [None]:
import random

# Print ten randomly drawn fine-tuning records.
for record in random.sample(contriever_data, 10):
    print(record)

### train dataset filtering

In [None]:
import pprint 
import random 
from pyserini.search.lucene import LuceneSearcher
from tqdm import tqdm 
import torch 

# qid -> {gold_pid: 1} for the queries that pass both checks below
train_filtered_qrels = defaultdict(dict)
# 1-based BM25 rank of every non-gold hit seen for the scored queries
rank_info=[]

reranker.cuda()
reranker.eval()

cnt=0  # queries skipped because the reranker score of the gold pair is below 0.9
for qid, pid_dict in tqdm(qrels.items(), total=len(qrels)): # train_qrels mode 
    query = queries[qid]
    gold_pid = list(pid_dict.keys())[0]  # only the first listed passage is treated as gold
    p_info = corpus[gold_pid]

    # `max_length` alone does not shorten anything: `truncation=True` is needed,
    # otherwise pairs longer than 512 tokens reach the model untruncated.
    reranker_inputs= tokenizer(query,p_info['text'],return_tensors='pt',truncation=True,max_length=512).to('cuda')
    # Inference only, so no autograd graph is needed.
    with torch.no_grad():
        logits = reranker(**reranker_inputs).logits
    true_prob = torch.softmax(logits, dim=1).tolist()[0][1]  # probability of class index 1
    if true_prob < 0.9: 
        cnt+=1
        continue 

    ######################
    #### BM25 
    ######################
    hits = bm25_searcher.search(query)

    # Nothing is collected here any more, but the name is still (re)bound.
    # NOTE(review): a later cell in this notebook reads `negative_passages`
    # without defining it — confirm before removing this line.
    negative_passages = [] 

    # `flag` stays True only when BM25 returned no passage other than the gold one.
    flag=True 
    for i in range(len(hits)):
        pred_docid = hits[i].docid
        if pred_docid != gold_pid:
            rank_info.append(i+1)
            flag=False 

    if flag:
        train_filtered_qrels[qid]={gold_pid:1}

len(train_filtered_qrels), cnt, len(rank_info)
# len(contriever_data), cnt

In [None]:
# Peek at the first fine-tuning record.
contriever_data[0]

In [None]:
# Query texts of the training queries that survived filtering, then total vs. distinct count.
filtered_query_list = []
for qid, pid_dict in train_filtered_qrels.items():
    if pid_dict:
        filtered_query_list.append(queries[qid])
len(filtered_query_list),len(set(filtered_query_list))

# filtered_test_contriever_data = []
# for pair in test_contriever_data:
#     if pair['question'] in filtered_query_list:
#         filtered_test_contriever_data.append(pair)

# len(filtered_test_contriever_data),filtered_test_contriever_data[0]

In [None]:
# Alternative output used in other runs:
#   processed/corpus2.size20000.summary_llm_gen.1012_ver1.4/train.w_negative.data.for_contriever.jsonl
train_w_negative_path = 'processed/summary_origin2/train.w_negative.data.for_contriever.jsonl'
with open(train_w_negative_path, 'w', encoding='utf-8') as f:
    f.writelines(json.dumps(record, ensure_ascii=False) + '\n' for record in contriever_data)

### test set filtering

In [None]:
# Evaluator
# Rebinds retriever / corpus / queries / qrels / bm25_searcher for the
# 'summary_origin2' data, this time with the full qrels file.
retriever = EvaluateRetrieval(None, score_function=None)
metrics = defaultdict(list)  # store final results
print("### Custom data mode")
# data_path = 'processed/corpus2.size20000.summary_llm_gen.1012_ver1.4'
# data_path = 'processed/corpus2.subset.summary_llm_gen.1021_ver1.5/'
data_path = 'processed/summary_origin2'

# corpus, queries, qrels = GenericDataLoader(data_folder=data_path,qrels_file = os.path.join(data_path,'qrels/test.tsv')).load_custom()
corpus, queries, qrels = GenericDataLoader(data_folder=data_path,qrels_file = os.path.join(data_path,'qrels/qrels.tsv')).load_custom()

# BM25
bm25_searcher = LuceneSearcher('../baselines/bm25/indexes/summary_origin2_ko')
bm25_searcher.set_language('ko')


In [None]:
# Alternative input used in other runs:
#   processed/corpus2.size20000.summary_llm_gen.1012_ver1.4/test.data.for_contriever.jsonl
test_contriever_path = 'processed/summary_origin2/train.w_negative.data.for_contriever.jsonl'
with open(test_contriever_path, 'r', encoding='utf-8') as f:
    test_contriever_data = list(map(json.loads, f))

len(test_contriever_data),test_contriever_data[0].keys()

# data_for_contriever_finetune_summary_origin.append({
#         "question": query_info['text'],
#         "positive_ctxs": pos_info,
#         # "negative_ctxs": [],
#         # "hard_negative_ctxs": [] 
#     })



In [None]:
# Number of queries in the currently loaded qrels.
len(qrels)

In [None]:
import pprint 
import random 
from pyserini.search.lucene import LuceneSearcher
from tqdm import tqdm 
import torch 

# Not filled by this cell; kept so the summary line below still works.
test_filtered_qrels = defaultdict(list)
# qid -> {gold_pid: 1} for queries whose gold passage BM25 ranks first, or
# second behind a top hit that the reranker scores >= 0.9
high_lexical_test_filtered_qrels = defaultdict(list)

reranker.cuda()
reranker.eval()
cnt=0  # unused while the gold-pair confidence filter is disabled
rank_info=[]  # 1-based BM25 rank at which the gold passage was found

# NOTE(review): `contriever_data[idx]` assumes the records are aligned 1:1 and
# in the same order as `qrels` — confirm, otherwise negatives are attached to
# the wrong question or the index runs out of range.
for idx,(qid, pid_dict) in tqdm(enumerate(qrels.items()), total=len(qrels)): # train_qrels mode 
    query = queries[qid]
    gold_pid = list(pid_dict.keys())[0]  # only the first listed passage is treated as gold

    # The gold pair used to be scored by the reranker here, but its < 0.9
    # filter is disabled and the score was never read, so that forward pass
    # was dropped.

    ######################
    #### BM25 
    ######################
    hits = bm25_searcher.search(query)

    # Fresh list per query. It was previously never initialised in this cell,
    # which is a NameError on a clean kernel, or a stale list shared by every
    # record when left over from another cell.
    negative_passages = [] 
    for i in range(len(hits)):
        pred_docid = hits[i].docid
        p_info = corpus[pred_docid]

        if pred_docid == gold_pid:
            rank_info.append(i+1)
            high_lexical_test_filtered_qrels[qid]={gold_pid:1}
            break 

        if i == 0:
            # Top hit is not the gold passage: keep it as a hard negative when
            # the reranker is not confident (< 0.9) that it is relevant.
            # Inputs must live on the same device as the model, and
            # `truncation=True` is required for `max_length` to take effect.
            reranker_inputs= tokenizer(query,p_info['text'],return_tensors='pt',truncation=True,max_length=512).to('cuda')
            with torch.no_grad():
                logits = reranker(**reranker_inputs).logits
            true_prob = torch.softmax(logits, dim=1).tolist()[0][1]
            if true_prob < 0.9:
                negative_passages.append(p_info)
                break
            # otherwise fall through and look at the second hit
        else:
            # Never look past the second hit.
            break 

    contriever_data[idx]['hard_negative_ctxs']= negative_passages

len(test_filtered_qrels),len(high_lexical_test_filtered_qrels), cnt, len(rank_info)

In [None]:
import pprint 
import random 
from pyserini.search.lucene import LuceneSearcher
from tqdm import tqdm 
import torch 

# qid -> {gold_pid: 1} for confident queries whose gold passage BM25 did NOT return
test_filtered_qrels = defaultdict(list)
# qid -> {gold_pid: 1} for confident queries whose gold passage BM25 did return
high_lexical_test_filtered_qrels = defaultdict(list)

reranker.cuda()
reranker.eval()
cnt=0  # queries skipped because the reranker score of the gold pair is below 0.9
rank_info=[]  # 1-based BM25 rank of the gold passage, when found

for qid, pid_dict in tqdm(qrels.items(), total=len(qrels)): # train_qrels mode 
    query = queries[qid]
    gold_pid = list(pid_dict.keys())[0]  # only the first listed passage is treated as gold
    p_info = corpus[gold_pid]

    # `truncation=True` is required for `max_length` to take effect; without it
    # pairs longer than 512 tokens reach the model untruncated.
    reranker_inputs= tokenizer(query,p_info['text'],return_tensors='pt',truncation=True,max_length=512).to('cuda')
    # Inference only, so no autograd graph is needed.
    with torch.no_grad():
        logits = reranker(**reranker_inputs).logits
    true_prob = torch.softmax(logits, dim=1).tolist()[0][1]  # probability of class index 1
    if true_prob < 0.9: 
        cnt+=1
        continue 

    ######################
    #### BM25 
    ######################
    hits = bm25_searcher.search(query)

    # `flag` stays True when the gold passage is absent from the BM25 hits.
    flag=True
    for rank, hit in enumerate(hits, start=1):
        if hit.docid == gold_pid:
            flag=False
            rank_info.append(rank)
            high_lexical_test_filtered_qrels[qid]={gold_pid:1}
            break 

    if flag:
        test_filtered_qrels[qid]={gold_pid:1}

len(test_filtered_qrels),len(high_lexical_test_filtered_qrels), cnt, len(rank_info)

In [None]:
import random

# Inspect up to ten random queries that did NOT end up in test_filtered_qrels:
# show the query, its gold passage and the top-5 BM25 hits.
# `random.sample` needs a sequence: passing the dict view directly is
# deprecated since Python 3.9 and raises TypeError from 3.11 on. The sample
# size is capped so small qrels do not raise ValueError.
qrels_items = list(qrels.items())
for qid, pid_dict in random.sample(qrels_items, min(10, len(qrels_items))):
    if qid in test_filtered_qrels:
        continue

    query = queries[qid]
    pprint.pprint(f"# Q_info: {qid} / {query}")

    gold_pid = list(pid_dict.keys())[0]
    p_info = corpus[gold_pid]
    print(f"# Gold Passage_info: {gold_pid} \n",p_info)
    hits = bm25_searcher.search(query,5)

    print("# Prediction - BM25")
    for i in range(len(hits)):
        print(f'{i+1:2} {hits[i].docid:4} {hits[i].score:.5f}')
        pred_docid = hits[i].docid
        p_info = corpus[pred_docid]
        print(f"pred - {i}:\n{p_info}")
    print()


In [None]:
from collections import Counter 
import pandas as pd 
# Summary statistics of the ranks collected in rank_info.
pd.DataFrame(rank_info).describe()

In [None]:
# (qid, {gold_pid: 1}) pairs ordered by the numeric part of the query id ("Q12" -> 12).
# Swap in `test_filtered_qrels.items()` to export the other subset instead.
def _qid_number(qid_and_gold):
    return int(qid_and_gold[0][1:])

sorted_dict = sorted(high_lexical_test_filtered_qrels.items(), key=_qid_number)

In [None]:
# Sanity check: container type, the second entry and the number of entries.
type(sorted_dict),sorted_dict[1],len(sorted_dict)

In [None]:
# Spot-check one query and one passage by id.
queries['Q42'],corpus['C41']

In [None]:
# Alternative outputs used in other runs:
#   ../baselines/bm25/data/query/filtered.query_llm.query_llm.1012_ver1.4.tsv
#   ../baselines/bm25/data/query/high_lexical.filtered.query_llm.query_llm.1012_ver1.4.tsv
#   ../baselines/bm25/data/query/high_semantic.filtered.summary_origin2.tsv
# Written as `qid<TAB>query text`. Explicit UTF-8, like the other writers in
# this notebook: the platform default encoding may not be able to encode Korean text.
with open('../baselines/bm25/data/query/high_lexical.filtered.summary_origin2.tsv','w',encoding='utf-8') as f:
    for pair in sorted_dict:
        k,v = pair
        if not v:
            # skip ids that have no gold passage attached
            continue
        f.write(f"{k}\t{queries[k]}\n")

In [None]:
# Alternative outputs used in other runs:
#   processed/corpus2.size20000.summary_llm_gen.1012_ver1.4/qrels/filtered.test.tsv
#   processed/corpus2.size20000.summary_llm_gen.1012_ver1.4/qrels/high_lexical.filtered.test.tsv
#   processed/summary_origin2/qrels/high_semantic.filtered.qrels.tsv
# One `qid<TAB>gold pid<TAB>1` row per entry of sorted_dict that has a gold passage.
with open('processed/summary_origin2/qrels/high_lexical.filtered.qrels.tsv','w') as f:
    f.write("qid\tpid\tscore\n")
    for qid, gold_dict in sorted_dict:
        if gold_dict:
            gold_pid = list(gold_dict.keys())[0]
            f.write(f"{qid}\t{gold_pid}\t{1}\n")

In [None]:
# Fields available on one loaded fine-tuning record.
test_contriever_data[0].keys()

In [None]:
# Query texts of the high-lexical subset, then total vs. distinct count.
# (Iterate `test_filtered_qrels` instead to build the other subset.)
high_lexical_filtered_query_list = []
for qid, pid_dict in high_lexical_test_filtered_qrels.items():
    if pid_dict:
        high_lexical_filtered_query_list.append(queries[qid])
len(high_lexical_filtered_query_list),len(set(high_lexical_filtered_query_list))

In [None]:
# Keep only the records whose question belongs to the high-lexical subset.
# Membership is tested against a set: scanning the list for every record was
# quadratic in the number of queries.
high_lexical_query_set = set(high_lexical_filtered_query_list)
high_lexical_filtered_test_contriever_data = [
    pair for pair in test_contriever_data
    if pair['question'] in high_lexical_query_set
]

len(high_lexical_filtered_test_contriever_data),high_lexical_filtered_test_contriever_data[0]

In [None]:
# Alternative outputs used in other runs:
#   processed/summary_origin2/semantic.filtered.data.for_contriever.jsonl
#   processed/corpus2.size20000.summary_llm_gen.1012_ver1.4/test.filtered.data.for_contriever.jsonl
#   processed/corpus2.size20000.summary_llm_gen.1012_ver1.4/test.high_lexical.filtered.data.for_contriever.jsonl
# `ensure_ascii=False` emits raw Korean characters, so the file is opened with
# an explicit UTF-8 encoding instead of the platform default.
with open('processed/summary_origin2/lexical.filtered.data.for_contriever.jsonl','w',encoding='utf-8') as f:
    for l in high_lexical_filtered_test_contriever_data:
        f.write(json.dumps(l, ensure_ascii=False)+'\n')

In [None]:
import pandas as pd

# Invert the filtered qrels: gold passage id -> ids of the queries pointing at
# it, then summarise how many queries share a passage.
pid2qid = defaultdict(list)
for qid, pid_dict in sorted_dict:
    if pid_dict:
        query = queries[qid]
        gold_pid = next(iter(pid_dict))
        pid2qid[gold_pid].append(qid)

queries_per_passage = [len(qids) for qids in pid2qid.values()]
len(pid2qid), pd.DataFrame(queries_per_passage).describe()

In [None]:
# For every query kept in test_filtered_qrels: print the query, its gold
# passage and the top BM25 hit, then score that single query's BM25 ranking
# with the evaluator and print its NDCG@10.
# for k,v in pid2qid.items():
for qid, pid_dict in test_filtered_qrels.items():
# for pair in sorted_dict:
    # qid,pid_dict = pair
    
    if not pid_dict:
        continue

    query = queries[qid]
    pprint.pprint(f"# Q_info: {qid} / {query}")
    
    gold_pid = list(pid_dict.keys())[0]    
    p_info = corpus[gold_pid]
    pprint.pprint(f"# Gold Passage_info: {p_info} ")
    
    # if len(v)>=5:
    # print()
    # for qid in v:
    #     query = queries[qid]
    #     pprint.pprint(f"# Q_info: {qid} / {query}")
    
    # Both containers are recreated per query, so the metrics printed below
    # cover this one query only.
    metrics = defaultdict(list)  # store final results
    temp_result = defaultdict(dict)

    hits = bm25_searcher.search(query)
    
    # print("# Prediction - BM25")
    for i in range(len(hits)):
        # print(f'{i+1:2} {hits[i].docid:4} {hits[i].score:.5f}')
        pred_docid = hits[i].docid

        # results layout: {qid: {docid: BM25 score}}
        temp_result[qid][pred_docid]=float(hits[i].score)
        
        p_info = corpus[pred_docid]
        if i==0:
            pprint.pprint(f"pred - {i}:\n{p_info}")

        # reranker_inputs= tokenizer(query,p_info['text'],return_tensors='pt',max_length=512)
        # logits = reranker(**reranker_inputs).logits
        # true_prob = torch.softmax(logits, dim=1).tolist()[0][1]
        
    # ndcg, _map, recall, precision = retriever.evaluate(qrels, temp_result, [1,10,100])#retriever.k_values)
    # Standard metrics use the full `qrels` at cut-offs 1 and 10.
    ndcg, _map, recall, precision = retriever.evaluate(qrels, temp_result, [1,10])#retriever.k_values)
    # The string entries are computed through evaluate_custom, against
    # `test_filtered_qrels` and at `retriever.k_values` rather than [1, 10].
    for metric in (ndcg, _map, recall, precision, "mrr", "recall_cap", "hole"):
        if isinstance(metric, str):
            metric = retriever.evaluate_custom(test_filtered_qrels, temp_result, retriever.k_values, metric=metric)
        for key, value in metric.items():
            metrics[key].append(value)

    pprint.pprint(f"NDCG@10:{metrics['NDCG@10']}")
    print("\n\n")
    
