In [71]:
import json
import re
import string
import pandas as pd

from datasets import load_from_disk, Dataset, DatasetDict

# Project MRC dataset, saved in the Hugging Face `datasets` on-disk format.
# Both splits are materialised as pandas DataFrames so rows can be appended and edited.
train_dataset = load_from_disk("./data/train_dataset/")
train_df, dev_df = (
    pd.DataFrame(train_dataset[split_name]) for split_name in ('train', 'validation')
)

# Wikipedia corpus: only the values of the top-level JSON object are kept, one row each.
with open("./data/wikipedia_documents.json", 'r', encoding='utf-8') as wiki_file:
    wiki_json = json.load(wiki_file)
wiki_df = pd.DataFrame(list(wiki_json.values()))

# KorQuAD v1.0 train/dev files, kept as raw JSON; they are flattened into rows later.
with open("./data/KorQuAD_v1.0_train.json", 'r', encoding='utf-8') as korquad_train_file:
    train_json = json.load(korquad_train_file)

with open("./data/KorQuAD_v1.0_dev.json", 'r', encoding='utf-8') as korquad_dev_file:
    dev_json = json.load(korquad_dev_file)

In [72]:
def addDF(data, isTrain):
    """Append every question of one KorQuAD article to the global train/dev frame.

    Args:
        data: One entry of the KorQuAD ``data`` list. It must provide ``title``
            and ``paragraphs``; each paragraph provides ``context`` and ``qas``,
            and each qa provides ``question``, ``id`` and a non-empty ``answers``
            list whose first element has ``answer_start`` and ``text``.
        isTrain: If truthy the rows are appended to the global ``train_df``,
            otherwise to the global ``dev_df``.

    Side effects:
        Rebinds the global ``train_df`` or ``dev_df`` to a new DataFrame with the
        article's rows appended (index is reset via ``ignore_index=True``).
    """
    global train_df, dev_df

    title = data['title']

    # Collect plain records first and build the DataFrame once: concatenating
    # a one-row frame per question copies the whole target frame every time.
    records = []
    for paragraph in data['paragraphs']:
        context = paragraph['context']
        for qa in paragraph['qas']:
            first_answer = qa['answers'][0]  # only the first reference answer is kept
            records.append({
                'title': title,
                'context': context,
                'question': qa['question'],
                'id': qa['id'],
                'answers': {
                    'answer_start': [first_answer['answer_start']],
                    'text': [first_answer['text']],
                },
                # Placeholder values so the columns match the existing frames.
                'document_id': 0,
                '__index_level_0__': 0,
            })

    if not records:
        # Nothing to append; leave the global frames untouched.
        return

    new_rows = pd.DataFrame(records)
    if isTrain:
        train_df = pd.concat([train_df, new_rows], ignore_index=True)
    else:
        dev_df = pd.concat([dev_df, new_rows], ignore_index=True)

In [73]:
# Frame sizes before the KorQuAD articles are merged in.
print(train_df.shape, dev_df.shape)

# Train articles go to train_df first, then dev articles go to dev_df.
for korquad_json, is_train in ((train_json, True), (dev_json, False)):
    for article in korquad_json['data']:
        addDF(article, is_train)

# Frame sizes after the merge.
print(train_df.shape, dev_df.shape)

(3952, 7) (240, 7)
(64359, 7) (6014, 7)


In [74]:
def flattenList(nested_list):
    """Flatten one level of nesting.

    Args:
        nested_list: An iterable of iterables (strings count as iterables of
            their characters).

    Returns:
        A new list holding every item of every sub-iterable, in order.
    """
    return [item for sublist in nested_list for item in sublist]

# ASCII punctuation, escaped so it can be placed inside a regex character class.
special_characters = re.escape(string.punctuation)

def getOtherCharacters(text):
    """Return the characters of `text` that are not in any of the expected sets.

    The expected sets are removed one after another; whatever survives all of
    them is returned as a string, in its original order.

    Args:
        text: The string to inspect.

    Returns:
        `text` with every expected character removed.
    """
    expected_patterns = (
        # Hangul syllables and jamo, ASCII letters and digits, CJK ideographs, kana, space
        r"[가-힣ㄱ-ㅎㅏ-ㅣA-Za-z0-9一-龥ぁ-ゔァ-ヴー々〆〤 ]",
        # ASCII punctuation
        f'[{special_characters}]',
        # Assorted quotation marks, brackets and other explicitly allowed symbols
        r"[≪≫《》〈〉＜＞「」『』‘’“”・·°∧­ćä]",
        # Special symbols '*' and '#'
        r"[\*\#]+",
    )
    leftover = text
    for pattern in expected_patterns:
        leftover = re.sub(pattern, "", leftover)
    return leftover

# Every distinct unexpected character found anywhere in the Wikipedia corpus,
# joined into a single string (order is arbitrary because it comes from a set).
leftover_per_document = (getOtherCharacters(document_text) for document_text in wiki_df['text'])
other_characters = ''.join(set(flattenList(leftover_per_document)))

In [75]:
def _cleanText(text):
    """Strip `other_characters`, turn newlines into spaces, collapse whitespace."""
    if other_characters:
        # Escape so that no collected character can act as regex syntax, and
        # skip entirely when empty because '[]' is not a valid pattern.
        text = re.sub(f'[{re.escape(other_characters)}]', "", text)
    # Both the literal two-character sequence backslash-n and real newlines become spaces.
    text = text.replace('\\n', ' ').replace('\n', ' ')
    return ' '.join(text.split())


def _occurrenceRank(text, answer, start):
    """Count (possibly overlapping) occurrences of `answer` up to the one at `start`.

    If no occurrence begins at `start`, the total number of occurrences is returned.
    """
    rank = 0
    idx = text.find(answer)
    while idx > -1:
        rank += 1
        if idx == start:
            break
        idx = text.find(answer, idx + 1)
    return rank


def _nthOccurrence(text, answer, n):
    """Return the start index of the n-th (1-based) occurrence of `answer`, or -1."""
    idx = -1
    for _ in range(n):
        idx = text.find(answer, idx + 1)
        if idx == -1:
            break
    return idx


def removeOtherCharacters(row, isTrain):
    """Clean one row's question and context and re-locate its answer.

    The row is read from and written to the global ``train_df`` when `isTrain`
    is truthy, otherwise the global ``dev_df``. (Previously the values were
    always read from ``train_df``, so dev rows were overwritten with train data.)

    Steps:
        1. Determine which occurrence of the answer text the original
           ``answer_start`` points at.
        2. Clean the context and the question.
        3. Store the start of that same occurrence in the cleaned context as the
           new ``answer_start`` (mutating the row's ``answers`` dict in place).

    The answer text itself is not cleaned; if it can no longer be found in the
    cleaned context, ``answer_start`` is set to -1.

    Args:
        row: Index label of the row to process.
        isTrain: Selects ``train_df`` (truthy) or ``dev_df`` (falsy).
    """
    df = train_df if isTrain else dev_df

    question = df.at[row, 'question']
    answers = df.at[row, 'answers']
    start = answers['answer_start'][0]
    answer = answers['text'][0]
    text = df.at[row, 'context']

    # Remember which occurrence is the answer before the offsets shift.
    occurrence = _occurrenceRank(text, answer, start)

    text = _cleanText(text)
    question = _cleanText(question)

    answers['answer_start'][0] = _nthOccurrence(text, answer, occurrence)

    df.loc[row, 'question'] = question
    df.loc[row, 'context'] = text
    
# Clean every row of the train split, then every row of the validation split.
for frame, is_train in ((train_df, True), (dev_df, False)):
    for row_label in frame.index:
        removeOtherCharacters(row_label, is_train)

In [77]:
# Convert both cleaned frames back into `datasets.Dataset` objects.
# NOTE(review): a DataFrame is passed to `from_dict` here; confirm this is intended
# rather than `Dataset.from_pandas`.
train_dataset, dev_dataset = (Dataset.from_dict(frame) for frame in (train_df, dev_df))

# Bundle the splits under the names "train" and "validation" and write them to disk.
dataset = DatasetDict({"train": train_dataset, "validation": dev_dataset})
dataset.save_to_disk("./data/train_preprocessed/")

Saving the dataset (1/1 shards): 100%|██████████| 64359/64359 [00:00<00:00, 113778.07 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 6014/6014 [00:00<00:00, 113671.69 examples/s]
