In [1]:
import sys
sys.path.append("..")
from datasets import load_from_disk, Dataset
from transformers import AutoTokenizer
import numpy as np
from tqdm import tqdm, trange
import torch
import torch.nn.functional as F
from retrieval_model import BertEncoder
from torch import nn
import pandas as pd
from torch.utils.data import (DataLoader, RandomSampler, TensorDataset, SequentialSampler)
from tqdm import tqdm, trange
import pickle

In [2]:
p_encoder = BertEncoder.from_pretrained("/opt/ml/mrc-level2-nlp-08/retrieval/p_encoder")
q_encoder = BertEncoder.from_pretrained("/opt/ml/mrc-level2-nlp-08/retrieval/q_encoder")
tokenizer =  AutoTokenizer.from_pretrained("klue/bert-base")
if torch.cuda.is_available():
        p_encoder.cuda()
        q_encoder.cuda()

In [None]:
with open('/opt/ml/data/wiki_context_id_pair.bin','rb') as f:
    wiki_context_id = pickle.load(f)
with open('/opt/ml/data/wiki_id_context_pair.bin','rb') as f:
    wiki_id_context = pickle.load(f)

In [48]:
def get_relevant_doc(q_encoder, p_encoder, query, id_context, elastic_score_dict,k=100):
    with torch.no_grad():
        p_encoder.eval()
        q_encoder.eval()

        q_seqs_val = tokenizer([query], padding="max_length", truncation=True, return_tensors='pt').to('cuda')
        #q_emb = q_encoder(**q_seqs_val).to('cpu')  #(num_query, emb_dim)
        q_emb = q_encoder(**q_seqs_val)

        p_embs = []
        # for p in context:
        #     p = tokenizer(p, padding="max_length", truncation=True, return_tensors='pt').to('cuda')
        #     p_emb = p_encoder(**p).to('cpu').numpy()
        #     p_embs.append(p_emb)
        context = list(map(lambda x: x[1],id_context))
        
        p_seqs = tokenizer(context, padding="max_length", truncation=True, return_tensors='pt')
        
        dataset = TensorDataset(
            p_seqs["input_ids"],
            p_seqs["attention_mask"],
            p_seqs["token_type_ids"]
        )
        sampler = SequentialSampler(dataset)
        dataloader = DataLoader(
            dataset,
            sampler=sampler,
            batch_size=100
        )
        for _, batch in enumerate(dataloader):
            batch = tuple(t.cuda() for t in batch)

            p_inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2]
            }
            
            #outputs = p_encoder(**p_inputs).to("cpu").numpy()
            outputs = p_encoder(**p_inputs)
            #print(outputs.size())
            p_embs.extend(outputs.unsqueeze(0))

    #p_embs = torch.Tensor(p_embs).squeeze()  # (num_passage, emb_dim)
   # p_embs = torch.Tensor(p_embs)
    p_embs = torch.cat(p_embs)
    #print(p_embs.size())
    dot_prod_scores = torch.matmul(q_emb, torch.transpose(p_embs, 0, 1))
    elastic_score = torch.Tensor([elastic_score_dict[query]]).cuda()

    
    #weight_score = weight * dot_prod_scores
    #print(weight_score[0][0])
    temp_score = elastic_score + dot_prod_scores
    # print(dot_prod_scores)
    # print(dot_prod_scores.size())
    # print(temp_score)
    
    #rank = torch.argsort(dot_prod_scores, dim=1, descending=True).squeeze()
    rank = torch.argsort(temp_score, dim=1, descending=True).squeeze()
    #print(rank)

    return_context = []
    for i in range(k):
        return_context.append(id_context[rank[i]])
    
    return dot_prod_scores.squeeze(), return_context #,rank[:k]

In [4]:
with open('/opt/ml/data/elastic_train_id_ctx.bin','rb') as f:
    elastic_valid = pickle.load(f)
with open('/opt/ml/data/elastic_train_id_ctx_score.bin','rb') as f:
    elastic_valid_score = pickle.load(f)

In [50]:
import re

In [51]:
def preprocess(text):
    text = re.sub(r'\n', ' ', text)
    text = re.sub(r'\\n', ' ', text) # remove newline character
    text = re.sub(r'\s+', ' ', text) # remove continuous spaces
    text = re.sub(r'#', ' ', text)

    return text

In [52]:
wiki_dataset = pd.read_json('/opt/ml/data/wikipedia_documents.json',orient='index')

In [53]:
wiki_dataset['text'] = wiki_dataset['text'].apply(preprocess)
wiki_dataset = wiki_dataset.drop_duplicates(['text','title'],ignore_index=True) 

In [6]:
querys = list(elastic_valid.keys())

In [54]:
with open('/opt/ml/data/hybrid_train_retrieval.bin','rb') as f:
    hy = pickle.load(f)

In [55]:
hybrid_train_id_ctx = {}
for i in tqdm(range(len(querys))):
    q = querys[i]
    ctxs = hy[q]
    top_n_id_text = []
    for ctx in ctxs:
        wiki_id = wiki_dataset['document_id'][wiki_dataset['text'] == ctx]
        top_n_id_text.append((wiki_id, ctx))
    hybrid_train_id_ctx[q] = top_n_id_text
    


100%|██████████| 3952/3952 [39:00<00:00,  1.69it/s]


In [None]:
with open('/opt/ml/data/hybrid_train_id_ctx.bin', "wb") as file:
    pickle.dump(hybrid_train_id_ctx, file)