In [2]:
import sys
sys.path.append('/opt/ml/workspace/baseline')
import os
from tqdm.auto import tqdm
import numpy as np

import torch 
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer, AutoConfig
from datasets import load_from_disk, load_dataset
from DPR_train import BertEncoder

In [3]:
# Tokenizer matching the multilingual BERT checkpoint.
tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')

dataset_KLUE = load_from_disk('/opt/ml/input/data/data/train_dataset')
# Evaluate on the squad_kor_v1 validation split; swap in dataset_KLUE['validation']
# to evaluate on the on-disk dataset instead.
validate_dataset = load_dataset('squad_kor_v1')['validation']

# Questions and contexts are encoded with identical settings: pad every
# sequence to the tokenizer's max length, truncate longer ones, and return
# PyTorch tensors.
tokenize_kwargs = dict(padding='max_length', truncation=True, return_tensors='pt')
q_seqs = tokenizer(validate_dataset['question'], **tokenize_kwargs)
p_seqs = tokenizer(validate_dataset['context'], **tokenize_kwargs)

# Tensor order inside the dataset: the three passage tensors first, then the
# three question tensors. Downstream code indexes batches by these positions.
tensor_keys = ('input_ids', 'token_type_ids', 'attention_mask')
valid_dataset = TensorDataset(
    *(p_seqs[key] for key in tensor_keys),
    *(q_seqs[key] for key in tensor_keys),
)

Reusing dataset squad_kor_v1 (/opt/ml/.cache/huggingface/datasets/squad_kor_v1/squad_kor_v1/1.0.0/92f88eedc7d67b3f38389e8682eabe68caa450442cc4f7370a27873dbc045fe4)


In [5]:
# Build the passage (p) and question (q) encoders from the base BERT config
# and restore their weights from the saved state dicts.
model_config = AutoConfig.from_pretrained('bert-base-multilingual-cased')

# map_location='cpu' makes deserialization independent of the device the
# checkpoint was saved from, so this cell also works on a CPU-only host.
# The models are moved to the GPU below when one is available.
p_model = BertEncoder(model_config)
state_dict = torch.load('/opt/ml/models/DPR/p_encoder.bin', map_location='cpu')
p_model.load_state_dict(state_dict)

q_model = BertEncoder(model_config)
state_dict = torch.load('/opt/ml/models/DPR/q_encoder.bin', map_location='cpu')
q_model.load_state_dict(state_dict)

if torch.cuda.is_available():
    p_model.cuda()
    q_model.cuda()
    print('GPU enabled')

GPU enabled


In [6]:
# Retrieval evaluation: embed every passage and question of the validation
# set, score all question/passage pairs by dot product, and report how often
# the passage paired with a question (same dataset index) ranks in the top-k.
valid_loader = DataLoader(valid_dataset, 40)

# Send batches to the device chosen by the same CUDA check the model cell
# uses for .cuda(), so this cell also runs on a CPU-only host instead of
# crashing on an unconditional .cuda().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

with torch.no_grad():
    # evaluation
    print('let\'s eval')

    p_model.eval()
    q_model.eval()

    p_outputs = []
    q_outputs = []

    for batch in tqdm(valid_loader):
        batch = tuple(t.to(device) for t in batch)

        # Tensors 0-2 of a batch belong to the passage, 3-5 to the question
        # (the order used when valid_dataset was built).
        p_inputs = {'input_ids' : batch[0],
                    'token_type_ids' : batch[1],
                    'attention_mask' : batch[2]
                }

        q_inputs = {'input_ids' : batch[3],
                    'token_type_ids' : batch[4],
                    'attention_mask' : batch[5]
                }

        p_outputs.append(p_model(**p_inputs).cpu().numpy())
        q_outputs.append(q_model(**q_inputs).cpu().numpy())

    # Stack the per-batch embeddings along the batch axis. This handles a
    # shorter final batch directly.
    # assumes each encoder returns a 2-D (batch, dim) tensor - TODO confirm
    p_outputs = np.concatenate(p_outputs, axis=0)
    q_outputs = np.concatenate(q_outputs, axis=0)

    # sim_scores[i, j] = dot product of question i with passage j.
    sim_scores = np.dot(q_outputs, p_outputs.T)
    # Ascending argsort: the best-scoring passage index is LAST in each row.
    sorted_scores = np.argsort(sim_scores, axis=1)

    # NOTE(review): a hit requires the retrieved index to equal the question's
    # own index. If the same context text appears at several indices
    # (presumably the case for SQuAD-style data with multiple questions per
    # paragraph - confirm), identical passages tie and these accuracies are
    # underestimated. Consider de-duplicating contexts before scoring.
    top_ks = (1, 5, 10, 20)
    hit_counts = dict.fromkeys(top_ks, 0)
    num_examples = len(valid_dataset)

    for idx in tqdm(range(num_examples)):
        ranked = sorted_scores[idx][::-1]  # best match first
        for k in top_ks:
            if idx in ranked[:k]:
                hit_counts[k] += 1

    top_1_score, top_5_score, top_10_score, top_20_score = (
        hit_counts[k] / num_examples for k in top_ks
    )

    print({'acc/top_1': top_1_score, 'acc/top_5': top_5_score, 
                'acc/top_10': top_10_score, 'acc/top_20': top_20_score},  len(valid_dataset))

let's eval


HBox(children=(FloatProgress(value=0.0, max=145.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5774.0), HTML(value='')))


{'acc/top_1': 0.05767232421198476, 'acc/top_5': 0.2684447523380672, 'acc/top_10': 0.42570142015933493, 'acc/top_20': 0.5753377208174576} 5774
