In [7]:
import json
import os
import pickle
import re

from transformers import AutoTokenizer
from datasets import load_from_disk, Dataset
from tqdm import tqdm
import pandas as pd

from Retrieval.retrieval import SparseRetrieval
from clean_dataset import preprocess

In [2]:
# Load the pretrained tokenizer and hand it to the sparse retriever.
tokenizer_name = "kykim/bert-kor-base"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
sparse_retrieval = SparseRetrieval(tokenizer=tokenizer)

Token indices sequence length is longer than the specified maximum sequence length for this model (969 > 512). Running this sequence through the model will result in indexing errors


In [8]:
# Locations of the retrieval cache, plus the datasets used by the cells below.
data_path = "/opt/ml/mrc-level2-nlp-08/Retrieval/"
caching_path = "caching/"


def _load_split_as_frame(split_dir):
    """Load a dataset directory with `load_from_disk` and return it as a pandas DataFrame."""
    return load_from_disk(split_dir).to_pandas()


# Preprocessed wiki corpus.
wiki_dataset = pd.read_json("/opt/ml/data/preprocess_wiki.json", orient="index")
# Preprocessed training split.
train_dataset = _load_split_as_frame("/opt/ml/data/new_train_dataset/train")
# Training split without preprocessing.
origin_train_dataset = _load_split_as_frame("/opt/ml/data/train_dataset/train")

In [10]:
# Build (or reload from cache) two lookup tables over the wiki corpus:
#   wiki_context_id_dict : key = context text,  value = wiki document id
#   wiki_id_context_dict : key = wiki document id, value = context text
# wiki_id_context_dict is what turns the ids returned by retrieval back into contexts.
# When the cached pickles are missing, the tables are rebuilt and cached.
caching_dir = data_path + caching_path
caching_context_id_path = caching_dir + "wiki_context_id_pair.bin"
caching_id_context_path = caching_dir + "wiki_id_context_pair.bin"

if os.path.isfile(caching_context_id_path) and os.path.isfile(caching_id_context_path):
    # NOTE: pickle.load can execute arbitrary code from the file; only load
    # cache files that were produced by this notebook.
    with open(caching_context_id_path, "rb") as f:
        wiki_context_id_dict = pickle.load(f)
    with open(caching_id_context_path, "rb") as f:
        wiki_id_context_dict = pickle.load(f)
else:
    wiki_text = wiki_dataset["text"]
    wiki_id = wiki_dataset["document_id"]
    wiki_context_id_dict = dict(zip(wiki_text, wiki_id))
    wiki_id_context_dict = dict(zip(wiki_id, wiki_text))

    # On a fresh checkout the cache directory may not exist yet; without this
    # the open(..., "wb") calls below raise FileNotFoundError.
    os.makedirs(caching_dir, exist_ok=True)
    with open(caching_context_id_path, "wb") as file:
        pickle.dump(wiki_context_id_dict, file)
    with open(caching_id_context_path, "wb") as file:
        pickle.dump(wiki_id_context_dict, file)



In [16]:
# Retrieve the top-200 wiki document ids (and their scores) for every training question.
train_questions = train_dataset['question'].to_list()
retrieval_ids, retrieval_scores = sparse_retrieval.get_topk_doc_id_and_score_for_querys(
    train_questions,
    top_k=200,
)

100%|██████████| 3952/3952 [04:53<00:00, 13.48it/s]


In [17]:
# For every training example, append retrieved wiki passages that are similar
# to the question but do NOT contain the answer to the original
# (unpreprocessed) ground-truth context.
num_extra_contexts = 4   # passages to add per example
max_candidates = 200     # never scan past this many retrieved candidates


def _select_extra_contexts(candidate_ids, id_to_context, ground_truth, answer,
                           limit, max_scan):
    """Pick up to `limit` retrieved contexts that can serve as distractors.

    Args:
        candidate_ids: retrieved wiki document ids, scanned in the order given.
        id_to_context: mapping from wiki document id to its context text.
        ground_truth: preprocessed ground-truth context; an identical passage is skipped.
        answer: answer text; any passage containing it is skipped.
        limit: maximum number of passages to return.
        max_scan: maximum number of candidates to inspect.

    Returns:
        A list with at most `limit` context strings, in retrieval order.
    """
    selected = []
    # Bound the scan by the number of candidates actually returned, so a
    # result list shorter than `max_scan` cannot cause an IndexError.
    for rank in range(min(len(candidate_ids), max_scan)):
        if len(selected) == limit:
            break
        candidate = id_to_context[candidate_ids[rank]]  # id -> context text
        if candidate != ground_truth and answer not in candidate:
            selected.append(candidate)
    return selected


new_context = []

for i in tqdm(range(len(train_dataset))):
    query = train_dataset['question'][i]
    answer = train_dataset['answers'][i]['text'][0]  # only the first answer text is checked

    # Ground truth is passed through preprocess() before being compared with
    # the retrieved wiki passages.
    pre_ground = preprocess(train_dataset['context'][i])

    extra_contexts = _select_extra_contexts(
        retrieval_ids[query],       # retrieved doc ids for this question
        wiki_id_context_dict,
        pre_ground,
        answer,
        num_extra_contexts,
        max_candidates,
    )

    # The original context always comes first, followed by the distractors.
    train_concat_list = [origin_train_dataset['context'][i]] + extra_contexts
    new_context.append(' '.join(train_concat_list))

100%|██████████| 3952/3952 [00:00<00:00, 22261.29it/s]


In [18]:
# Reload the preprocessed train split, swap in the concatenated contexts and
# store the result as a new split under the same dataset directory.
save_path = '/opt/ml/data/new_train_dataset/'
save_name = 'train_concat_dataset'
train_df = load_from_disk(save_path + "train").to_pandas().assign(context=new_context)
concat_train_dataset = Dataset.from_pandas(train_df)
concat_train_dataset.save_to_disk(save_path + save_name)


In [19]:
# Load the split registry stored next to the saved splits.
# Read it explicitly as UTF-8: the notebook writes this file back as UTF-8 with
# ensure_ascii=False, so relying on the platform default encoding here could
# mis-decode non-ASCII content.
with open(save_path + 'dataset_dict.json', encoding="utf-8") as f:
    dataset_dict = json.load(f)

In [21]:
# Show the split registry as loaded, before the new split name is added.
print(dataset_dict)

{'splits': ['train', 'validation', 'train_concat_no_duplication', 'SEP_train', 'train_pre_es_no_dup_wiki', 'train_ori_es_no_dup_wiki', 'train_concat_es_no_dup_wiki', 'train_concat_es_no_dup_more_wiki']}


In [22]:
# Register the new split name. The membership check keeps this cell idempotent:
# re-running it must not add the same name to 'splits' a second time.
if save_name not in dataset_dict['splits']:
    dataset_dict['splits'].append(save_name)
print(dataset_dict)

{'splits': ['train', 'validation', 'train_concat_no_duplication', 'SEP_train', 'train_pre_es_no_dup_wiki', 'train_ori_es_no_dup_wiki', 'train_concat_es_no_dup_wiki', 'train_concat_es_no_dup_more_wiki', 'train_concat_dataset']}


In [23]:
# Persist the updated split registry: tab-indented, non-ASCII characters written as-is.
registry_path = save_path + 'dataset_dict.json'
with open(registry_path, "w", encoding="utf-8") as out_file:
    json.dump(dataset_dict, out_file, indent="\t", ensure_ascii=False)