# Dense Retriever 실험

In [1]:
import torch
from torch.utils.data import DataLoader

from dense_retrieval import DenseRetrieval

In [2]:
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm, trange
from typing import List, Tuple, NoReturn, Any, Optional, Union

import torch
import torch.nn.functional as F
from torch.utils.data import (DataLoader, RandomSampler, TensorDataset)

from transformers import AutoTokenizer, BertModel, BertPreTrainedModel, AdamW, TrainingArguments, get_linear_schedule_with_warmup
from datasets import Dataset, load_from_disk, concatenate_datasets

from retrieval import SparseRetrieval, timer

## 1. 데이터 로드

In [3]:
import os
import json 

# Locations of the MRC dataset, the Wikipedia corpus, and the base encoder.
data_path = "../data/"
dataset_path = "../data/train_dataset"
context_path = "wikipedia_documents.json"
model_checkpoint = "klue/bert-base"

# Load the saved dataset and merge its "train" and "validation" splits into one
# dataset (flatten_indices() is applied to each split before concatenation).
org_dataset = load_from_disk(dataset_path)
full_ds = concatenate_datasets(
    [org_dataset[split].flatten_indices() for split in ("train", "validation")]
)

# Read the Wikipedia dump and keep every distinct passage "text" exactly once,
# preserving first-seen order (dict keys are unique and insertion-ordered).
with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f:
    wiki = json.load(f)
contexts = list(dict.fromkeys(doc["text"] for doc in wiki.values()))

Loading cached processed dataset at ../data/train_dataset/train/cache-fbc57aa6e699fb0c.arrow
Loading cached processed dataset at ../data/train_dataset/validation/cache-d2fba0c42123b1d6.arrow


In [4]:
# Export only the gold "answers" column of each split to CSV for inspection.
# Reading the single column from the dataset avoids first converting the whole
# split (contexts, questions, ...) into a DataFrame and then dropping it.
# The resulting frames keep the same RangeIndex and "answers" column as before.
os.makedirs("./data", exist_ok=True)  # to_csv does not create missing directories

df_train = pd.DataFrame({"answers": org_dataset["train"]["answers"]})
df_train.to_csv("./data/train_answers.csv")

df_valid = pd.DataFrame({"answers": org_dataset["validation"]["answers"]})
df_valid.to_csv("./data/valid_answers.csv")

In [5]:
# Number of distinct Wikipedia passages after de-duplication (shown as cell output).
n_unique_contexts = len(contexts)
n_unique_contexts

56737

# List corpus passages containing the first 5 characters of the first
# validation example's context (presumably a check that the gold context exists
# in the Wikipedia dump). The probe snippet is read from the dataset once,
# rather than re-decoding the dataset row for every passage in `contexts`.
# NOTE(review): a 5-character prefix is a loose probe and may match many passages.
probe_snippet = org_dataset["validation"][0]["context"][:5]
[s for s in contexts if probe_snippet in s]

## 2. 학습 및 추론

In [6]:
# Slow (non-fast) tokenizer for the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=False)

# Training hyperparameters handed to the dense retriever's training routine.
args = TrainingArguments(
    # Optimisation
    learning_rate=1e-6,
    weight_decay=0.01,
    num_train_epochs=3,
    # Batching
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # Evaluation / output
    evaluation_strategy="epoch",
    # NOTE(review): "retireval" looks like a typo; kept as-is so existing runs
    # keep writing to the same directory.
    output_dir="dense_retireval",
)

### 학습데이터 생성

In [7]:
# Build the dense retriever over the Wikipedia corpus, using the train split as
# training data. is_bm25=True presumably enables BM25-based passage selection
# (the cell log prints "Embedding bm25 pickle load.") — confirm in dense_retrieval.py.
dense_retriever = DenseRetrieval(
    tokenize_fn=tokenizer.tokenize,
    tokenizer=tokenizer,
    data_path=data_path,
    dataset_path=dataset_path,
    context_path=context_path,
    train_data=org_dataset["train"],
    is_bm25=True,
)

## Training ##
train_dataset = dense_retriever.make_train_data(tokenizer)
# dense_retriever.init_model(model_checkpoint)
# dense_retriever.train(args, train_dataset)

## Inference (uncomment to load trained encoders instead) ##
# dense_retriever.load_model(model_checkpoint, "outputs/p_encoder_3.pt", "outputs/q_encoder_3.pt")
# dense_retriever.get_dense_embedding()
# df = dense_retriever.retrieve(full_ds[0]['question'], topk=3)

Lengths of unique wiki contexts : 56737
Embedding bm25 pickle load.
make_train_data...


In [None]:
# Reuse previously computed passage embeddings instead of re-encoding the corpus,
# then write a copy of them to ./data/dense_embedding.bin.
# `pickle` is already imported at the top of the notebook; no re-import needed.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only load embedding files you produced yourself.
with open("./data/dense_embedding3.bin", "rb") as f:
    dense_retriever.dense_p_embedding = pickle.load(f)

with open("./data/dense_embedding.bin", "wb") as f:
    pickle.dump(dense_retriever.dense_p_embedding, f)

In [None]:
# Inspect the shape of the loaded passage-embedding matrix
# (presumably (num_passages, hidden_dim) — confirm).
p_embedding = dense_retriever.dense_p_embedding
p_embedding.shape

## 3. 실험 결과 확인

In [None]:
# Retrieve the single best passage for the first training question.
# Index the row first: ds[0]["question"] reads one row, whereas
# ds["question"][0] builds the entire question column as a list first.
dense_retriever.get_relevant_doc(org_dataset["train"][0]["question"], k=1)

In [None]:
# Run retrieval with the whole validation split as input, keeping the top 5
# passages (the shape of the result depends on DenseRetrieval.retrieve).
validation_split = org_dataset["validation"]
result_retriever = dense_retriever.retrieve(validation_split, topk=5)
result_retriever

In [None]:
# Print the top-3 retrieved passages for the first 10 validation questions.
# Slicing fetches all sample questions in one read (and simply yields fewer
# if the split has fewer rows), instead of decoding a full row per iteration.
num_samples = 10
for question in org_dataset["validation"][:num_samples]["question"]:
    df = dense_retriever.retrieve(question, topk=3)
    print(df)

In [None]:
# Run DenseRetrieval.topk_experiment on the validation split at each cut-off k
# and print the summary it returns.
topK_list = [1, 10, 20, 50]
validation_split = org_dataset["validation"]
result = dense_retriever.topk_experiment(topK_list, validation_split)
print(result)