In [71]:
from datasets import load_from_disk

wiki_datasets = load_from_disk("/opt/ml/data/wiki_preprocessed_droped")
train_dataset = load_from_disk("/opt/ml/data/train_dataset")
generator_dataset = load_from_disk("/opt/ml/data/generator_dataset")
wiki_datasets.load_elasticsearch_index("text", host="localhost", port="9200", es_index_name="wikipedia_contexts")

In [72]:
generator_dataset

Dataset({
    features: ['context', 'question', 'id', 'title', 'document_id', 'answers'],
    num_rows: 35740
})

In [73]:
train_dataset

DatasetDict({
    train: Dataset({
        features: ['title', 'context', 'question', 'id', 'answers', 'document_id', '__index_level_0__', 'chunks'],
        num_rows: 3952
    })
    validation: Dataset({
        features: ['title', 'context', 'question', 'id', 'answers', 'document_id', '__index_level_0__', 'chunks'],
        num_rows: 240
    })
})

In [74]:
train_dataset['train']['answers']

[{'answer_start': [229], 'text': ['하원']},
 {'answer_start': [212], 'text': ['《경영의 실제》']},
 {'answer_start': [501], 'text': ['백성']},
 {'answer_start': [615], 'text': ['중국']},
 {'answer_start': [30], 'text': ['4개']},
 {'answer_start': [91], 'text': ['드래곤']},
 {'answer_start': [68], 'text': ['형양태수 왕식']},
 {'answer_start': [583], 'text': ['이탈리아군']},
 {'answer_start': [194], 'text': ['큰아들 유']},
 {'answer_start': [839], 'text': ['왕대마을']},
 {'answer_start': [808], 'text': ['음독자살']},
 {'answer_start': [411], 'text': ['출장 잦은 건축가']},
 {'answer_start': [82], 'text': ['반신화적인 인물인 우파']},
 {'answer_start': [26], 'text': ['1951년']},
 {'answer_start': [487], 'text': ['예수']},
 {'answer_start': [520], 'text': ["'초일기'"]},
 {'answer_start': [407], 'text': ['1916년']},
 {'answer_start': [13], 'text': ['레드삭스']},
 {'answer_start': [225], 'text': ['삼판동']},
 {'answer_start': [574], 'text': ['다산 정약용']},
 {'answer_start': [166], 'text': ['대나라']},
 {'answer_start': [142], 'text': ['10달러']},
 {'answer_start': [59], 

In [100]:
columns = ['context', 'question', 'id', 'title', 'document_id', 'answers']
train_datadict = {}
for col in columns:
  train_datadict[col] = []

In [101]:
train_datadict

{'context': [],
 'question': [],
 'id': [],
 'title': [],
 'document_id': [],
 'answers': []}

In [102]:
for data in train_dataset['train']:
  for column in columns:
    train_datadict[column].append(data[column])

In [103]:
len(train_datadict['title'])

3952

In [106]:
import random


def ds_data_function(data):
  query = data['question']
  negative_contexts = []
  _, retrieved_examples = wiki_datasets.get_nearest_examples("text", query, k=100)
  for index in range(4):
    if retrieved_examples['document_id'][index] == data['document_id']:
      continue
    negative_contexts.append(retrieved_examples['text'][index])
    if len(negative_contexts) == 2:
      break
  negative_contexts.extend([retrieved_examples['text'][98], retrieved_examples['text'][99]])
  random.shuffle(negative_contexts)
  index = random.randint(0, 4)
  negative_contexts.insert(index, data['context'])
  answer_index = data['answers']['answer_start'][0]
  for negative_index in range(index):
    answer_index += len(negative_contexts[negative_index])
  data['answers']['answer_start'][0] = answer_index + index
  data['context'] = " ".join(negative_contexts)
  return data

In [107]:
new_train_dataset = train_dataset['train'].map(ds_data_function, num_proc=4)

Loading cached processed dataset at /opt/ml/data/train_dataset/train/cache-01b0ec546f29837d.arrow
Loading cached processed dataset at /opt/ml/data/train_dataset/train/cache-f139d18c104a3d1a.arrow
Loading cached processed dataset at /opt/ml/data/train_dataset/train/cache-310cfd6c7bb3eb45.arrow
Loading cached processed dataset at /opt/ml/data/train_dataset/train/cache-ca6b0ca10ab30b2e.arrow


In [108]:
search_error = new_train_dataset.filter(lambda example: example['context'][example['answers']['answer_start'][0]:example['answers']['answer_start'][0]+len(example['answers']['text'][0])] != example['answers']['text'][0])

Loading cached processed dataset at /opt/ml/data/train_dataset/train/cache-ef310b10421ffd5f.arrow


In [109]:
search_error

Dataset({
    features: ['title', 'context', 'question', 'id', 'answers', 'document_id', '__index_level_0__', 'chunks'],
    num_rows: 0
})

In [110]:
for data in new_train_dataset:
  for column in columns:
    train_datadict[column].append(data[column])

In [111]:
from datasets import Dataset, DatasetDict

second_dataset = Dataset.from_dict(train_datadict)

In [112]:
second_dataset

Dataset({
    features: ['context', 'question', 'id', 'title', 'document_id', 'answers'],
    num_rows: 7904
})

In [93]:

data = {
  'train': second_dataset,
  'validation': train_dataset['validation']
}
train = DatasetDict(data)
train.save_to_disk('gen_ds_train_datasets') # 저장위치

In [115]:
wiki_datasets['document_id'].index(0)

0

In [116]:
a = 'asidjfsidjfisdjifsdf'

In [119]:
a.find('s')

1

In [121]:
def change_document_index(example):
  document_index = wiki_datasets['document_id'].index(example['document_id'])
  wiki = wiki_datasets[document_index]
  if wiki['text'].find(example['answers']['text'][0]) == -1:
    raise '에러!'
  example['context'] = wiki['text']
  example['answers']['answer_start'][0] = wiki['text'].find(example['answers']['text'][0])
  return example 
  

In [124]:
def id_to_str(example):
  example['id'] = f"{example['id']}"
  return example

In [125]:
new_gen_data = generator_dataset.map(id_to_str)

  0%|          | 0/35740 [00:00<?, ?ex/s]

In [128]:
len(train_datadict['context'])

7904

In [129]:
for data in new_gen_data:
  for column in columns:
    train_datadict[column].append(data[column])

In [130]:
third_dataset = Dataset.from_dict(train_datadict)

In [131]:
data = {
  'train': third_dataset,
  'validation': train_dataset['validation']
}
train = DatasetDict(data)
train.save_to_disk('gen_ds_train_datasets') # 저장위치

In [132]:
train

DatasetDict({
    train: Dataset({
        features: ['context', 'question', 'id', 'title', 'document_id', 'answers'],
        num_rows: 43644
    })
    validation: Dataset({
        features: ['title', 'context', 'question', 'id', 'answers', 'document_id', '__index_level_0__', 'chunks'],
        num_rows: 240
    })
})