In [1]:
import sys
sys.path.append("..")
from datasets import load_from_disk
from transformers import AutoTokenizer
import numpy as np
from tqdm import tqdm
import torch
import torch.nn.functional as F
from retrieval_model import BertEncoder,RobertaEncoder
from torch import nn
import pandas as pd
from torch.utils.data import (DataLoader,TensorDataset, SequentialSampler)
import pickle
from utils_mrc import seed_everything

In [2]:
# Fix the random seed up front so repeated runs of the notebook are comparable.
# (seed_everything comes from utils_mrc; which RNGs it seeds is not visible here.)
SEED = 42
seed_everything(SEED)

In [3]:
# Load the two DPR encoders (passage side and question side) from their
# checkpoint directories, plus the tokenizer that both sides share.
P_ENCODER_PATH = "/opt/ml/mrc-level2-nlp-08/retrieval/p_encoder"
Q_ENCODER_PATH = "/opt/ml/mrc-level2-nlp-08/retrieval/q_encoder"
# Tokenizer checkpoints that have been tried here:
#   "kykim/bert-kor-base"
#   "monologg/kobigbird-bert-base"
TOKENIZER_NAME = "kykim/bert-kor-base"

p_encoder = BertEncoder.from_pretrained(P_ENCODER_PATH)
q_encoder = BertEncoder.from_pretrained(Q_ENCODER_PATH)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

In [4]:
#train_dataset = load_from_disk("/opt/ml/data/train_dataset/train/")
#query_dataset = load_from_disk('/opt/ml/data/test_dataset/validation') # test query
#train_dataset = load_from_disk("/opt/ml/data/train_dataset/new_validation/")
#origin_valid = load_from_disk("/opt/ml/data/train_dataset/validation/")

In [5]:
# Disabled alternative input (preprocessed wiki CSV), kept for reference:
#wiki_dataset = pd.read_csv('/opt/ml/data/preprocess_wiki_doc.csv')

# Validation split whose questions are used as the retrieval queries.
QUERY_DATASET_DIR = "/opt/ml/data/train_dataset/validation/"
query_dataset = load_from_disk(QUERY_DATASET_DIR)

In [6]:
def _load_pickle(path):
    """Open ``path`` in binary mode and return the unpickled object."""
    # NOTE: pickle can execute arbitrary code while loading; only use files you trust.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# doc_id -> context
wiki_id_context = _load_pickle("/opt/ml/data/wiki_id_context_pair.bin")
# context -> doc_id
wiki_context_id = _load_pickle("/opt/ml/data/wiki_context_id_pair.bin")
# doc_id -> title
wiki_id_title = _load_pickle("/opt/ml/data/wiki_id_title_pair.bin")

In [7]:
# Every distinct passage text, in the mapping's iteration order.
wiki_corpus = list(wiki_context_id)
# Title of each passage, aligned index-for-index with wiki_corpus
# (passage text -> doc_id -> title).
wiki_title_corpus = [
    wiki_id_title[wiki_context_id[passage]] for passage in wiki_corpus
]

In [8]:
# Both lists must have the same length: one title per passage.
print(len(wiki_corpus), len(wiki_title_corpus), sep="\n")

55963
55963


In [9]:
# Question column of the loaded validation split; used below as the retrieval queries.
query = query_dataset['question']

In [10]:
# Batch size used when encoding passages and queries.
eval_batch_size = 32


def to_cuda(batch):
    """Return ``batch`` as a tuple with every tensor moved to the GPU.

    When CUDA is not available the tensors are returned unchanged (still as a
    tuple), matching the guard used below for the encoders, so the helper no
    longer raises on a CPU-only machine.

    Args:
        batch: iterable of ``torch.Tensor``.

    Returns:
        tuple of tensors, on the GPU if one is available.
    """
    if not torch.cuda.is_available():
        return tuple(batch)
    return tuple(t.cuda() for t in batch)


# Move both encoders to the GPU only when one is present.
if torch.cuda.is_available():
    p_encoder.cuda()
    q_encoder.cuda()

# Construct the passage dataloader.
# Every passage is padded/truncated to 512 tokens and returned as PyTorch tensors.
passage_tokenizer_kwargs = dict(
    max_length=512, padding="max_length", truncation=True, return_tensors='pt'
)
# Variant that feeds (title, passage) pairs instead of the passage alone:
#train_p_seqs = tokenizer(wiki_title_corpus, wiki_corpus, **passage_tokenizer_kwargs) # add title
train_p_seqs = tokenizer(wiki_corpus, **passage_tokenizer_kwargs)  # no title

# Tensor order here fixes the positions read back from each batch later on.
valid_dataset = TensorDataset(
    *(train_p_seqs[key] for key in ("input_ids", "attention_mask", "token_type_ids"))
)
# Sequential sampling keeps embedding row i aligned with wiki_corpus[i].
valid_sampler = SequentialSampler(valid_dataset)
valid_dataloader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=eval_batch_size)

# Inference using the passage encoder to get dense embeddings.
# Batches go to the same device the encoders were placed on (GPU only when
# available) instead of being forced onto CUDA unconditionally.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One embedding vector per passage, in wiki_corpus order.
p_embs = []

p_encoder.eval()
with torch.no_grad():
    for batch in tqdm(valid_dataloader, desc="Iteration", position=0, leave=True):
        # Tensor order matches the TensorDataset built for the passages.
        input_ids, attention_mask, token_type_ids = (t.to(device) for t in batch)

        outputs = p_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        ).cpu().numpy()
        # extend() adds one row (one passage embedding) per list element.
        p_embs.extend(outputs)

# Release cached GPU memory held by the allocator.
torch.cuda.empty_cache()

Iteration: 100%|██████████| 1749/1749 [09:44<00:00,  2.99it/s]


In [17]:
# Tokenize the questions; every question is padded/truncated to 64 tokens.
train_q_seqs = tokenizer(
    query,
    max_length=64,
    padding="max_length",
    truncation=True,
    return_tensors="pt"
)

# Deliberately NOT named `query_dataset`: that name holds the validation split
# loaded from disk, and overwriting it with a TensorDataset makes
# `query = query_dataset['question']` fail when that cell is re-run.
query_tensor_dataset = TensorDataset(
    train_q_seqs["input_ids"],
    train_q_seqs["attention_mask"],
    train_q_seqs["token_type_ids"]
)

# Sequential sampling keeps embedding row i aligned with query[i].
query_sampler = SequentialSampler(query_tensor_dataset)
query_dataloader = DataLoader(
    query_tensor_dataset,
    sampler=query_sampler,
    batch_size=eval_batch_size
)

# Inference using the question encoder to get dense query embeddings.
# Batches go to the same device the encoders were placed on (GPU only when
# available) instead of being forced onto CUDA unconditionally.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One embedding vector per question, in `query` order.
q_embs = []

q_encoder.eval()
with torch.no_grad():
    for batch in tqdm(query_dataloader, desc="Iteration", position=0, leave=True):
        # Tensor order matches the TensorDataset built for the questions.
        input_ids, attention_mask, token_type_ids = (t.to(device) for t in batch)

        outputs = q_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        ).cpu().numpy()
        # extend() adds one row (one question embedding) per list element.
        q_embs.extend(outputs)

# Release cached GPU memory held by the allocator.
torch.cuda.empty_cache()
print('done')

Iteration: 100%|██████████| 8/8 [00:00<00:00, 25.41it/s]

done





In [18]:
# Stack the collected per-item vectors into 2-D arrays (one row per item).
p_embs = np.array(p_embs)
q_embs = np.array(q_embs)

# Cheap invariants the scoring step relies on.
assert p_embs.shape[0] == len(wiki_corpus), "expected one passage embedding per wiki_corpus entry"
assert p_embs.shape[1:] == q_embs.shape[1:], "passage and query embeddings must share a dimension"

print(p_embs.shape)
print(q_embs.shape)

# Author's sketch of a retrieval-result layout (presumably one list of
# {doc_id: score} entries per validation query). Kept as a comment so it is no
# longer echoed as the cell's output value:
#   valid-query1 - [{id : score}....]
#   valid-query1 - [{id : score}....]
#   valid-query1 - [{id : score}....]

(55963, 768)
(240, 768)


'\nvalid-query1 - [{id : score}....]\nvalid-query1 - [{id : score}....]\nvalid-query1 - [{id : score}....]\n\n'

In [20]:
# Persist the passage embedding matrix so it can be reused without re-encoding the corpus.
with open("/opt/ml/data/dense_embedding.bin", "wb") as embedding_file:
    pickle.dump(p_embs, embedding_file)

In [21]:
# Score every query against every passage with a dot product, then rank the
# passages for each query from highest to lowest score.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The *_cuda names are kept for compatibility with the rest of the notebook.
# They are now always defined and hold CPU tensors when no GPU is present;
# previously the matmul below raised NameError on a CPU-only machine.
p_embs_cuda = torch.Tensor(p_embs).to(device)
q_embs_cuda = torch.Tensor(q_embs).to(device)

# Shape: (number of queries, number of passages).
dot_prod_scores = torch.matmul(q_embs_cuda, torch.transpose(p_embs_cuda, 0, 1))
# rank[q] lists passage row indices for query q, best first. No squeeze() here:
# with a single query (or a single passage) it would collapse the matrix to
# 1-D and break the rank[q][j] indexing used downstream.
rank = torch.argsort(dot_prod_scores, dim=1, descending=True)

In [22]:
# Map each question to the doc ids of its best-ranked passages.
dense_p_retrieval_result = {}

# Keep up to 100 passages per query, capped by the corpus size so a small
# corpus does not run past the end of the ranking.
top_n = min(100, len(wiki_corpus))

# Pull only the needed columns to the host in a single transfer, instead of
# indexing the (possibly GPU-resident) rank tensor element by element.
# The reshape guarantees a (queries, passages) layout even if `rank` was squeezed.
top_ranked = rank.reshape(-1, len(wiki_corpus))[:, :top_n].tolist()

for i in tqdm(range(len(query))):
    # passage row index -> passage text -> doc_id
    # NOTE: results are keyed by question text, so identical questions share one entry.
    dense_p_retrieval_result[query[i]] = [
        wiki_context_id[wiki_corpus[row]] for row in top_ranked[i]
    ]

100%|██████████| 240/240 [00:00<00:00, 481.24it/s]


In [23]:
# Save the question -> ranked doc-id list mapping produced by the dense retriever.
with open("/opt/ml/data/dense_valid_retrieval.bin", "wb") as retrieval_file:
    pickle.dump(dense_p_retrieval_result, retrieval_file)

# dense_n_retrieval_result = {}
# idx = 0
# for i in tqdm(range(len(query))):
#     p_list = []
#     q = query[i]
#     for j in range(10000,10004):
#         p_list.append(wiki_corpus[rank[idx][j]])
#     dense_n_retrieval_result[q] = p_list
#     idx += 1

In [24]:
# Compare top-k retrieval accuracy of the dense retriever against the
# Elasticsearch results stored on disk.
# NOTE(review): the questions come from a different path than the split used to
# build dense_p_retrieval_result; the lookups below assume both contain the same
# questions - confirm.
valid_dataset = load_from_disk("/opt/ml/data/new_train_dataset/validation")
# NOTE: pickle can execute arbitrary code while loading; only use files you trust.
with open('/opt/ml/data/wiki_context_id_pair.bin','rb') as f:
    wiki_context_id = pickle.load(f)
with open('/opt/ml/data/wiki_id_context_pair.bin','rb') as f:
    wiki_id_context = pickle.load(f)
with open('/opt/ml/data/elastic_valid_1000.bin','rb') as f:
    elastic_valid = pickle.load(f)

dense_valid = dense_p_retrieval_result
query = valid_dataset['question']
context = valid_dataset['context']

top_k_list = [1,5,10,15,20,25,30]


def top_k_accuracy(retrieved, questions, gold_contexts, id_to_context, top_k):
    """Fraction of questions whose gold context is among the first ``top_k`` hits.

    Args:
        retrieved: mapping question -> ranked list of doc ids (best first).
        questions: sequence of question strings.
        gold_contexts: ground-truth context for each question, same order.
        id_to_context: mapping doc id -> context text.
        top_k: number of leading candidates to inspect per question.

    Returns:
        Hit rate as a float; 0.0 when ``questions`` is empty.

    Raises:
        KeyError: if a question is missing from ``retrieved`` or an inspected
            doc id is missing from ``id_to_context``.
    """
    if len(questions) == 0:
        return 0.0
    hits = 0
    for question, gold in zip(questions, gold_contexts):
        # Slicing tolerates result lists shorter than top_k (no IndexError).
        candidates = retrieved[question][:top_k]
        if any(id_to_context[doc_id] == gold for doc_id in candidates):
            hits += 1
    return hits / len(questions)


for top_k in top_k_list:
    elastic_acc = top_k_accuracy(elastic_valid, query, context, wiki_id_context, top_k)
    dense_acc = top_k_accuracy(dense_valid, query, context, wiki_id_context, top_k)

    print('score_top_k : ', top_k)
    print('elastic ACC : ', elastic_acc)
    print('Dense ACC : ', dense_acc)
    print()


score_top_k :  1
elastic ACC :  0.7166666666666667
Dense ACC :  0.375

score_top_k :  5
elastic ACC :  0.8625
Dense ACC :  0.6208333333333333

score_top_k :  10
elastic ACC :  0.9125
Dense ACC :  0.6541666666666667

score_top_k :  15
elastic ACC :  0.925
Dense ACC :  0.6958333333333333

score_top_k :  20
elastic ACC :  0.9375
Dense ACC :  0.7041666666666667

score_top_k :  25
elastic ACC :  0.9416666666666667
Dense ACC :  0.725

score_top_k :  30
elastic ACC :  0.95
Dense ACC :  0.7333333333333333



In [32]:
# Accuracy figures kept as bare string literals for comparison (presumably pasted
# from an earlier run). The two strings below are identical to each other. Their
# Dense ACC values at top_k = 5, 15, 20, 25 and 30 differ from the output printed
# by the evaluation cell above; the elastic ACC values are the same.
'''
score_top_k :  1
elastic ACC :  0.7166666666666667
Dense ACC :  0.375

score_top_k :  5
elastic ACC :  0.8625
Dense ACC :  0.5875

score_top_k :  10
elastic ACC :  0.9125
Dense ACC :  0.6541666666666667

score_top_k :  15
elastic ACC :  0.925
Dense ACC :  0.7

score_top_k :  20
elastic ACC :  0.9375
Dense ACC :  0.7166666666666667

score_top_k :  25
elastic ACC :  0.9416666666666667
Dense ACC :  0.7375

score_top_k :  30
elastic ACC :  0.95
Dense ACC :  0.7625
'''
'''
score_top_k :  1
elastic ACC :  0.7166666666666667
Dense ACC :  0.375

score_top_k :  5
elastic ACC :  0.8625
Dense ACC :  0.5875

score_top_k :  10
elastic ACC :  0.9125
Dense ACC :  0.6541666666666667

score_top_k :  15
elastic ACC :  0.925
Dense ACC :  0.7

score_top_k :  20
elastic ACC :  0.9375
Dense ACC :  0.7166666666666667

score_top_k :  25
elastic ACC :  0.9416666666666667
Dense ACC :  0.7375

score_top_k :  30
elastic ACC :  0.95
Dense ACC :  0.7625
'''