In [2]:
import random

from tqdm import tqdm
from datasets import load_from_disk, DatasetDict

In [3]:
# Question-answering data; its 'train' and 'validation' splits are consumed below.
train_dataset = load_from_disk("/opt/ml/data/train_dataset")

# Wikipedia passages that are searched when mining negative contexts.
wiki_datasets = load_from_disk("/opt/ml/data/wiki_preprocessed_droped")

# Hook the "text" column up to the Elasticsearch index named "wikipedia_contexts"
# on the local server so get_nearest_examples("text", ...) can be used later.
# NOTE(review): presumably this index has already been built — confirm.
wiki_datasets.load_elasticsearch_index(
  "text",
  host="localhost",
  port="9200",
  es_index_name="wikipedia_contexts",
)

In [4]:
# Reshape every wiki record into a {'content': ..., 'meta': {...}} document,
# keeping the passage text as the content and title/document_id as metadata.
dicts = [
  {
    'content': record['text'],
    'meta': {
      'title': record['title'],
      'document_id': record['document_id']
    }
  }
  for record in tqdm(wiki_datasets)
]

100%|██████████| 55962/55962 [00:09<00:00, 6079.98it/s]


In [5]:
def generate_dpr_dataset(target_dataset, dataset_name, k=100, hard_pool_size=15,
                         num_hard_negatives=5, negative_pool_start=50,
                         num_negatives=10, index_dataset=None):
  """Build DPR-style training records with mined negative contexts.

  For every example, the question is used as a query against the "text" index
  of the retrieval dataset. Hits whose ``document_id`` equals the example's own
  ``document_id`` are discarded; the remaining hits keep the order in which the
  index returned them. Hard negatives are sampled from the first
  ``hard_pool_size`` remaining hits, ordinary negatives from the remaining hits
  at position ``negative_pool_start`` onward.

  Args:
    target_dataset: Iterable of examples providing the keys 'question',
      'answers' (with a 'text' entry), 'title', 'context' and 'document_id'.
    dataset_name: Value stored under the 'dataset' key of every record.
    k: Number of hits requested from the index for each question.
    hard_pool_size: Size of the leading slice that hard negatives are drawn from.
    num_hard_negatives: Number of hard negatives to sample per question.
    negative_pool_start: Start offset of the slice ordinary negatives are drawn
      from.
    num_negatives: Number of ordinary negatives to sample per question.
    index_dataset: Object exposing ``get_nearest_examples(column, query, k=...)``
      that returns ``(scores, examples)`` where ``examples`` maps
      'document_id', 'title' and 'text' to parallel lists. Defaults to the
      module-level ``wiki_datasets``.

  Returns:
    A list with one dict per example, holding the keys 'dataset', 'question',
    'answers', 'positive_ctxs', 'hard_negative_ctxs' and 'negative_ctxs'.
    When a pool holds fewer candidates than requested, every candidate in that
    pool is used instead of raising, so the negative lists may be shorter than
    ``num_hard_negatives`` / ``num_negatives`` (possibly empty).
  """
  if index_dataset is None:
    # Fall back to the notebook-level retrieval corpus, as the function
    # always did before the parameter existed.
    index_dataset = wiki_datasets

  dpr_train_datas = []
  for data in tqdm(target_dataset):
    query = data['question']
    scores, retrieved_examples = index_dataset.get_nearest_examples("text", query, k=k)

    # Walk over however many hits actually came back instead of assuming
    # exactly k; zip stops at the shortest of the parallel lists.
    negatives = [
      {
        'title': title,
        'text': text,
        'score': score,
        'title_score': 0,
        'passage_id': doc_id
      }
      for score, doc_id, title, text in zip(
        scores,
        retrieved_examples['document_id'],
        retrieved_examples['title'],
        retrieved_examples['text'],
      )
      # The gold passage must never be offered as a negative.
      if doc_id != data['document_id']
    ]

    hard_pool = negatives[:hard_pool_size]
    easy_pool = negatives[negative_pool_start:]
    # Clamp the sample sizes: random.sample raises ValueError when asked for
    # more items than the pool contains.
    hard_negatives = random.sample(hard_pool, min(num_hard_negatives, len(hard_pool)))
    easy_negatives = random.sample(easy_pool, min(num_negatives, len(easy_pool)))

    dpr_train_datas.append({
      'dataset': dataset_name,
      'question': query,
      'answers': data['answers']['text'],
      'positive_ctxs': [{
        'title': data['title'],
        'text': data['context'],
        'score': 1000,
        'title_score': 1,
        'passage_id': data['document_id']
      }],
      'hard_negative_ctxs': hard_negatives,
      # Ordinary negatives get their score reset to 0; build copies so the
      # dicts held in `negatives` are left untouched.
      'negative_ctxs': [{**ctx, 'score': 0} for ctx in easy_negatives]
    })
  return dpr_train_datas

In [6]:
# Generate the DPR records for the train split first, then the validation split,
# tagging each with its own dataset label.
dpr_train_datas, dpr_valid_datas = (
  generate_dpr_dataset(train_dataset[split], label)
  for split, label in (('train', 'original_train'), ('validation', 'original_valid'))
)

100%|██████████| 3952/3952 [05:01<00:00, 13.09it/s]
100%|██████████| 240/240 [00:18<00:00, 13.11it/s]


In [7]:
import json

# Persist both record lists as UTF-8 JSON; ensure_ascii=False keeps non-ASCII
# characters readable in the files instead of \u-escaping them.
with open('train.json', 'w', encoding='UTF-8') as train_fp:
  json.dump(dpr_train_datas, train_fp, ensure_ascii=False)

with open('valid.json', 'w', encoding='UTF-8') as valid_fp:
  json.dump(dpr_valid_datas, valid_fp, ensure_ascii=False)