In [1]:
import json
import pickle
import random
from pprint import pprint
from typing import List

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from tqdm import trange

from sklearn.feature_extraction.text import TfidfVectorizer

import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

from datasets import load_dataset
from datasets import (
    Dataset,
    load_from_disk,
    concatenate_datasets,
)
from transformers import (
    AutoTokenizer,
    BertModel, RobertaModel,
    BertPreTrainedModel,
    AdamW, get_linear_schedule_with_warmup,
    TrainingArguments,
)

In [2]:
# Fix the random seeds so that runs are repeatable
def set_seed(random_seed):
    """Seed the torch (CPU and CUDA), `random` and NumPy generators.

    Args:
        random_seed: Value handed unchanged to every seeding function.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # if use multi-GPU
        random.seed,
        np.random.seed,
    )
    for seed_fn in seeders:
        seed_fn(random_seed)


set_seed(42) # magic number :)

In [3]:
# Report the installed PyTorch version, then pick the first CUDA device
# when one is available and fall back to the CPU otherwise.
print("PyTorch version:[%s]." % torch.__version__)
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print("device:[%s]." % device)

PyTorch version:[1.7.1].
device:[cuda:0].


In [5]:
# Load the dataset stored on disk and keep a direct handle on its 'train' split.
dataset = load_from_disk('/opt/ml/data/train_dataset')
train_dataset = dataset['train']

## Bi-Encoder

In [4]:
class BertEncoder(BertPreTrainedModel):
    """Bi-encoder tower: a BERT backbone whose pooled output is the embedding."""

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Encode a batch and return the pooled BERT output.

        Args:
            input_ids: Token ids passed positionally to the BERT backbone.
            attention_mask: Optional mask forwarded to the backbone.
            token_type_ids: Optional segment ids forwarded to the backbone.

        Returns:
            The element at index 1 of the backbone's output (the pooled output).
        """
        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Index 1 of the BertModel output holds the pooled representation.
        return bert_outputs[1]

In [6]:
# Read the Wikipedia dump: a JSON object whose values each carry a "text" field.
with open('/opt/ml/data/wikipedia_documents.json', "r", encoding="utf-8") as f:
    wiki = json.load(f)

# dict.fromkeys removes duplicate texts while keeping first-seen order;
# a set would not give a stable order from run to run.
corpus = list(dict.fromkeys(doc["text"] for doc in wiki.values()))

In [7]:
# Name of the checkpoint whose tokenizer is loaded; the same tokenizer object is
# later used for both the bi-encoder queries and the cross-encoder pairs.
model_checkpoint = "klue/bert-base"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [8]:
# SECURITY NOTE: pickle.load and torch.load both unpickle arbitrary Python
# objects and can execute code while doing so. Only load files you produced
# yourself or otherwise trust.

# Pre-computed passage embeddings; used below as a 2-D tensor whose rows are
# multiplied against the query embeddings on the CPU.
with open('/opt/ml/custom/passage_embedding_special_shuffle_10.bin', 'rb') as file:
    p_embs = pickle.load(file)

# The file holds a whole pickled model (presumably a BertEncoder - confirm), so
# its class must already be defined when this line runs. map_location places
# the weights on `device` even if the checkpoint was saved from another device.
q_encoder = torch.load('/opt/ml/custom/q_encoder_special_shuffle_10.pt', map_location=device)

In [9]:
def get_relavant_doc(queries, q_encoder, p_embs, k=1) :
    """Return the k best-scoring passages for every query (dot-product similarity).

    Uses the notebook-level `tokenizer` and `device`.

    Args:
        queries: Question text(s) handed to the tokenizer.
        q_encoder: Module called with the tokenizer output as keyword arguments;
            its result is used as a 2-D matrix with one row per query.
        p_embs: 2-D CPU tensor of passage embeddings, one row per passage
            (it is transposed and multiplied with the query matrix).
        k: Number of passages to keep per query. Values above the number of
            passages are clamped to it; values below zero are treated as zero.

    Returns:
        A `(scores, indices)` pair of lists with one inner list per query,
        both ordered from the highest score to the lowest. `indices` are row
        positions in `p_embs`.
    """
    q_encoder.eval()
    with torch.no_grad() :
        q_seqs_val = tokenizer(queries, padding='max_length', truncation=True, return_tensors='pt').to(device)
        q_emb = q_encoder(**q_seqs_val).to('cpu')

    dot_prod_scores = torch.mm(q_emb, p_embs.T)

    # topk returns only the k best entries (largest first), which avoids sorting
    # and converting a full row of per-passage scores for every query.
    top_k = max(0, min(k, dot_prod_scores.size(1)))
    top_scores, top_indices = torch.topk(dot_prod_scores, k=top_k, dim=1)

    return top_scores.tolist(), top_indices.tolist()

In [10]:
# Bi-encoder retrieval: keep the 500 best-scoring passages for every validation
# question. doc_indices[i] holds positions into the passage embedding matrix.
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 500)

In [15]:
# Print the bi-encoder retrieval accuracy
a_total = []
validation_bar = tqdm(dataset['validation'], desc="Dense retrieval: ")
for idx, example in enumerate(validation_bar):
    retrieved_ids = doc_indices[idx]
    tmp = {
        # The query and its id.
        "question": example["question"],
        "id": example["id"],
        # Ids of the retrieved passages and their texts joined with a single space.
        "context_id": retrieved_ids,
        "context": " ".join(corpus[pid] for pid in retrieved_ids),
    }
    has_ground_truth = "context" in example.keys() and "answers" in example.keys()
    if has_ground_truth:
        # Validation examples also carry the ground-truth context and answers.
        tmp.update(original_context=example["context"], answers=example["answers"])
    a_total.append(tmp)

b_cqas_50 = pd.DataFrame(a_total)

# A question counts as a hit when its ground-truth context appears verbatim
# inside the joined text of the retrieved passages.
correct_length = [
    row
    for row, (gold, retrieved) in enumerate(
        zip(b_cqas_50['original_context'], b_cqas_50['context'])
    )
    if gold in retrieved
]
print(len(correct_length) / len(dataset['validation']))
# Printed result: 0.916666

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…


0.9166666666666666


## Cross-Encoder

In [12]:
class BertEncoder(BertPreTrainedModel):
    """Cross-encoder: BERT followed by dropout and a linear layer giving one score per input.

    NOTE: this class reuses the name of the bi-encoder class defined earlier in
    the notebook, so after this cell runs `BertEncoder` refers to this class.
    The name is kept because pickled checkpoints look their class up by name.
    """

    def __init__(self, config):
        # Zero-argument super() does not depend on the global name `BertEncoder`,
        # which this notebook rebinds.
        super().__init__(config)

        self.bert = BertModel(config)

        # Configs that lack `classifier_dropout` (or set it to None) fall back
        # to the general hidden dropout probability.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = torch.nn.Dropout(classifier_dropout)
        self.linear = torch.nn.Linear(config.hidden_size, 1)

        # Initialise weights only after every sub-module exists, so the scoring
        # head is covered as well as the backbone.
        self.init_weights()

    def forward(
            self,
            input_ids,
            attention_mask=None,
            token_type_ids=None
        ):
        """Score a batch of tokenized inputs.

        Args:
            input_ids: Token ids passed positionally to the BERT backbone.
            attention_mask: Optional mask forwarded to the backbone.
            token_type_ids: Optional segment ids forwarded to the backbone.

        Returns:
            Output of the final linear layer, whose last dimension has size 1.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )

        # Index 1 of the BertModel output holds the pooled representation.
        pooled_output = self.dropout(outputs[1])
        return self.linear(pooled_output)

In [13]:
# Load the saved cross-encoder (a whole pickled model, so its class must be
# defined before this line). torch.load unpickles arbitrary objects - only load
# trusted files. map_location places the weights on `device` even if the
# checkpoint was saved from a different device.
c_encoder = torch.load('/opt/ml/custom/c_encoder_e5.pt', map_location=device)

In [None]:
# Re-rank every question's bi-encoder candidates with the cross-encoder.
# Takes about 40 minutes on the validation data.
WINDOW_BATCH_SIZE = 16  # windows scored per forward pass; bounds memory on long passages
MODEL_INPUT_KEYS = ('input_ids', 'attention_mask', 'token_type_ids')

question_data = dataset['validation']['question']
c_encoder.eval()

result_scores = []
result_indices = []
with torch.no_grad() :
    for q_idx in tqdm(range(len(question_data))) :
        question = question_data[q_idx]
        question_score = []
        # leave=False keeps the per-question bars from piling up in the output.
        for passage_id in tqdm(doc_indices[q_idx], leave=False) :
            passage = corpus[passage_id]
            # Only the passage is truncated; a long passage overflows into
            # several windows (max_length=512, stride=128), each paired with
            # the full question.
            tokenized_examples = tokenizer(
                question,
                passage,
                truncation="only_second",
                max_length=512,
                stride=128,
                return_overflowing_tokens=True,
                # For a RoBERTa model pass return_token_type_ids=False and drop
                # 'token_type_ids' from MODEL_INPUT_KEYS; BERT needs them.
                padding="max_length",
                return_tensors='pt',
            )
            num_windows = tokenized_examples['input_ids'].size(0)

            score_sum = 0.0
            for start in range(0, num_windows, WINDOW_BATCH_SIZE) :
                stop = start + WINDOW_BATCH_SIZE
                c_input = {
                    key: tokenized_examples[key][start:stop].to(device)
                    for key in MODEL_INPUT_KEYS
                }
                # Assumes c_encoder returns one score per input row - confirm
                # against the checkpoint's class.
                score_sum += c_encoder(**c_input).sum().item()

            # The passage score is the mean score over its windows.
            question_score.append(score_sum / num_windows)

        # index_list holds positions into doc_indices[q_idx], best passage first.
        sort_result = torch.sort(torch.tensor(question_score), descending=True)
        scores, index_list = sort_result[0], sort_result[1]

        result_scores.append(scores.tolist())
        result_indices.append(index_list.tolist())

In [None]:
# Map the cross-encoder ranking back to corpus ids and keep the best
# RERANK_TOP_K passages per question. result_indices[q] holds positions into
# doc_indices[q]; slicing (rather than indexing 0..49) also copes with a
# question that has fewer than RERANK_TOP_K candidates.
RERANK_TOP_K = 50

final_indices = []
for q_idx, candidate_ids in enumerate(doc_indices) :
    best_positions = result_indices[q_idx][:RERANK_TOP_K]
    final_indices.append([candidate_ids[pos] for pos in best_positions])

In [None]:
# Build the table used to measure the cross-encoder accuracy
total = []
validation_bar = tqdm(dataset['validation'], desc="Dense retrieval: ")
for idx, example in enumerate(validation_bar):
    reranked_ids = final_indices[idx]
    tmp = {
        # The query and its id.
        "question": example["question"],
        "id": example["id"],
        # Ids of the re-ranked passages and their texts joined with a single space.
        "context_id": reranked_ids,
        "context": " ".join(corpus[pid] for pid in reranked_ids),
    }
    has_ground_truth = "context" in example.keys() and "answers" in example.keys()
    if has_ground_truth:
        # Validation examples also carry the ground-truth context and answers.
        tmp.update(original_context=example["context"], answers=example["answers"])
    total.append(tmp)

cqas_50 = pd.DataFrame(total)

In [None]:
# Printed result: 0.8708
# A question counts as a hit when its ground-truth context appears verbatim
# inside the joined text of the re-ranked passages.
correct_length = [
    row
    for row, (gold, retrieved) in enumerate(
        zip(cqas_50['original_context'], cqas_50['context'])
    )
    if gold in retrieved
]
print(len(correct_length) / len(dataset['validation']))

In [None]:
# Save the re-ranked retrieval table as CSV in the current working directory,
# without writing the DataFrame index as a column.
cqas_50.to_csv('b16_special_shuffle_e10_t500_ce_t50.csv', index = False)