In [None]:
import os
import numpy as np
import json
import pandas as pd
import torch
import pickle

from datasets import load_from_disk
from transformers import AutoTokenizer
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader, TensorDataset, SequentialSampler
from input.code.dpr.trainer_DPR import BiEncoderTrainer # trainer_DPR 모듈 위치에 따라서, from을 수정해주세요
from input.code.dpr.cls_Encoder import BertEncoder, RoBertaEncoder # cls_Encoder 모듈 위치에 따라서, from을 수정해주세요
from input.code.bm25 import make_bm25_embedding # bm25 모듈 위치에 따라서, from을 수정해주세요

In [None]:
BASE_DIR = os.getcwd()
DATA_DIR = os.path.join(BASE_DIR, "input", "data")
# Directory for the RoBERTa-based DPR encoders; the passage / question encoder paths hang off it.
# NOTE(review): P_ENCODER_DIR / Q_ENCODER_DIR are not referenced in this notebook section.
DPR_MODEL_DIR = os.path.join(BASE_DIR, "input", "code", "dpr", "roberta")
P_ENCODER_DIR = os.path.join(DPR_MODEL_DIR, "p_encoder")
Q_ENCODER_DIR = os.path.join(DPR_MODEL_DIR, "q_encoder")
# Split-keyed training data; the 'train' and 'validation' splits are consumed below.
datasets = load_from_disk(os.path.join(DATA_DIR, "train_dataset"))

In [None]:
# Load the Wikipedia passage collection. JSON text is UTF-8 by specification, so decode it
# explicitly instead of relying on the platform's locale encoding (which differs across OSes).
with open(os.path.join(DATA_DIR, "wikipedia_documents.json"), "r", encoding="utf-8") as f:
	wiki_corpus = json.load(f)
# Keep only the passage bodies, dropping exact duplicates while preserving first-seen order
# (dict keys are unique and insertion-ordered). Rebinding the name frees the raw dict.
wiki_corpus = list(dict.fromkeys(doc['text'] for doc in wiki_corpus.values()))

In [None]:
# Pretrained checkpoint shared by the tokenizer and both DPR encoders.
model_name = "klue/roberta-base"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

In [None]:
# Build two independent (non-weight-sharing) encoders from the same checkpoint: one for passages,
# one for questions.
# klue/roberta-base is a RoBERTa checkpoint, and the encoder output dirs live under ".../roberta",
# so use the RoBERTa wrapper (imported above but previously unused). BertEncoder presumably wraps
# a BERT backbone whose parameter names would not match this checkpoint, leaving it randomly
# initialised.
# NOTE(review): assumes RoBertaEncoder.forward accepts the same keyword inputs as BertEncoder
# (input_ids / attention_mask / token_type_ids) — confirm in cls_Encoder.
p_encoder = RoBertaEncoder.from_pretrained(model_name).to("cuda:0")
q_encoder = RoBertaEncoder.from_pretrained(model_name).to("cuda:0")

In [None]:
# Optimisation hyper-parameters for bi-encoder (DPR) training.
dpr_train_config = dict(
	epochs=3,
	batch_size=30,
	neg_num=2,  # presumably the number of negative passages per question — confirm in trainer_DPR
	lr=5e-5,
)
# Trainer over the train split, evaluating on validation; wiki_corpus is passed as the
# document pool (contexts_document).
bitrainer = BiEncoderTrainer(
	p_encoder=p_encoder,
	q_encoder=q_encoder,
	tokenizer=tokenizer,
	train_datasets=datasets['train'],
	eval_datasets=datasets['validation'],
	contexts_document=wiki_corpus,
	**dpr_train_config,
)

In [None]:
# Run DPR training. The trainer was handed p_encoder / q_encoder, and the rerank cells below use
# those same objects directly, so training presumably updates them in place — confirm in trainer_DPR.
# Being the cell's last expression, any return value is displayed as the cell output.
bitrainer.train()

In [None]:
# Pick 5 *distinct* validation rows for a qualitative check. np.random.choice samples with
# replacement by default, so the same question could appear twice; a seeded Generator also
# makes the sample reproducible across Restart & Run All.
sample_idx = np.random.default_rng(seed=42).choice(len(datasets['validation']), size=5, replace=False)

In [None]:
# Materialise the sampled rows once (column name -> values; 'question' and 'context' are used
# below), then read the questions from that result instead of querying the dataset a second time.
passage_dataset = datasets['validation'][sample_idx]
query = passage_dataset['question']

In [None]:
# Retrieve passages for the sampled questions; the display cell reads score[i][k] and
# pred_corpus[i][k] (pred_rank is not used there).
predictions = bitrainer.predict(passage_dataset, query)
score, pred_rank, pred_corpus = predictions

In [None]:
# Show each sampled question, its gold passage, and the top-5 DPR predictions with scores.
# Iterate over however many questions were actually sampled instead of a hard-coded 5, so this
# cell stays correct if the sample size changes.
for i, (question, true_passage) in enumerate(zip(query, passage_dataset['context'])):
	print("[Query] : ", question)
	print("[True Passage] \n", true_passage)
	# NOTE(review): assumes predict() returned at least 5 ranked candidates per query.
	for k in range(5):
		print(f"Top-{k+1} Passage")
		print("[Score] : ", score[i][k])
		print("[Predicted Passage] \n", pred_corpus[i][k])
		print("=" * 15)

# Rerank BM25 Test-Set Candidates with DPR

In [None]:
# Array form of the corpus so candidate passages can be gathered with an index array.
# NOTE(review): np.array over a list of str builds a fixed-width unicode ('<U…') array that pads
# every passage to the longest one; if memory is tight consider dtype=object (and re-check any
# code that calls .tolist() on single elements).
wiki_corpus = np.array(wiki_corpus)
# Only the 'validation' split of the saved test data is used.
test_dataset = load_from_disk(os.path.join(DATA_DIR, "test_dataset"))['validation']

In [None]:
# Reuse the cached BM25 query-by-passage scores when present, otherwise compute them.
bm25_cache_path = os.path.join(DATA_DIR, "test_wiki_bm25_embedding.bin")
if not os.path.exists(bm25_cache_path):
	print("BM25 embedding file not exists")
	# NOTE(review): whether make_bm25_embedding writes bm25_cache_path is not visible here.
	bm25_embedding = make_bm25_embedding(DATA_DIR=DATA_DIR, tokenizer=tokenizer, full_ds=test_dataset, context=wiki_corpus)
else:
	# pickle.load can execute arbitrary code: only load a cache file this pipeline produced.
	with open(bm25_cache_path, "rb") as f:
		bm25_embedding = pickle.load(f)

In [None]:
# Row i: indices of all wiki passages ordered by BM25 score for test question i, highest first
# (ascending argsort along the last axis, then reversed along axis 1).
top_k_idx = np.flip(np.argsort(bm25_embedding), axis=1)

In [None]:
# Pre-tokenise every test question to fixed-length (50-token) GPU tensors and wrap them in a
# one-question-per-batch loader.
# NOTE(review): q_dataloader is not consumed by the cells shown — the scoring loop re-tokenises
# each question itself. Either drop this cell or switch that loop over to q_dataloader.
q_seqs = tokenizer(test_dataset['question'],
				   padding="max_length",
				   truncation=True,
				   max_length=50,
				   return_tensors="pt").to("cuda:0")
q_test = TensorDataset(q_seqs['input_ids'],
					   q_seqs['attention_mask'],
					   q_seqs['token_type_ids'])
q_sampler = SequentialSampler(q_test)
# The sampler was previously built but never handed to the loader; pass it explicitly.
q_dataloader = DataLoader(q_test, sampler=q_sampler, batch_size=1)

In [None]:
query_per_score = []
top_k = 50
# Inference mode (e.g. no dropout) for both encoders before scoring.
q_encoder.eval()
p_encoder.eval()
with torch.no_grad():
	question_iter = tqdm(test_dataset['question'], desc=f"Top-{top_k} Question Per Wiki Passage")

	# Distinct loop names (`question`, `dpr_scores`) so this cell no longer overwrites the
	# `q_seqs` tensors and the `score` predictions that earlier cells created.
	for i, question in enumerate(question_iter):
		q_inputs = tokenizer(question,
							 padding="max_length",
							 truncation=True,
							 max_length=50,
							 return_tensors="pt").to("cuda:0")
		# The top_k BM25 candidates for question i, in BM25 order (best first).
		bm25_candidates = wiki_corpus[top_k_idx[i][:top_k]].tolist()
		p_inputs = tokenizer(bm25_candidates,
							 padding="max_length",
							 truncation=True,
							 max_length=500,
							 return_tensors="pt").to("cuda:0")

		q_outputs = q_encoder(**q_inputs).to('cpu')
		p_outputs = p_encoder(**p_inputs).to('cpu')
		# Dot-product similarity of the question against every candidate. With one question per
		# step this is presumably a 1 x top_k matrix, stored as a nested list [[...]]; consumers
		# should flatten it before sorting.
		dpr_scores = torch.matmul(q_outputs, p_outputs.T).tolist()
		query_per_score.append(dpr_scores)

In [None]:
# Re-order each question's BM25 candidates by DPR score, highest first, and persist the result:
# rerank_corpus[i] is the list of candidate passages for test question i.
# Each per-query score entry may be nested ([[s_0, ..., s_{k-1}]], a 1 x k matrix). np.ravel
# flattens it so [::-1] reverses the *candidate* order; on the nested 2-D array, [::-1] only
# flipped the singleton row axis and left passages in ascending (worst-first) order.
rerank_corpus = [
	wiki_corpus[top_k_idx[i, np.argsort(np.ravel(scores))[::-1]]].tolist()
	for i, scores in enumerate(query_per_score)
]
with open(os.path.join(DATA_DIR, "rerank_corpus.bin"), "wb") as f:
	pickle.dump(rerank_corpus, f)