In [1]:
import json
import kss

In [4]:
def get_start_end_idx(answer):
    """Return the ``(start, end)`` character offsets of a SQuAD-style answer.

    ``start`` is ``answer['answer_start']`` and ``end`` is exclusive
    (``start + len(answer['text'])``), so ``context[start:end]`` is the
    answer span whenever ``answer_start`` is accurate.
    """
    span_text = answer['text']
    span_begin = answer['answer_start']
    return span_begin, span_begin + len(span_text)

def process_qa_text(context, question, answer):
    """Build a question-answering example.

    Args:
        context: passage text, inserted verbatim.
        question: question text, inserted verbatim.
        answer: the answer string, or a SQuAD-style answer record (a dict
            with a ``'text'`` key) whose stripped text is used.

    Returns:
        dict with ``source_text`` ("question: ...  context: ..."),
        ``target_text`` (the answer) and ``task`` == "qa".
    """
    # Answer records from a dataset's 'answers' list are accepted as well;
    # formatting such a dict directly would put its repr into the target.
    if isinstance(answer, dict):
        answer = answer['text'].strip()
    ans_gen_input = f"question: {question}  context: {context}"
    ans_gen_target = f"{answer}"
    return {"source_text": ans_gen_input, "target_text": ans_gen_target, "task": "qa"}

def process_qg_text(context, question, answer, qg_format="highlight"):
    """Build a question-generation example for one answer.

    Args:
        context: passage text; ``answer['answer_start']`` must index into
            exactly this string for the highlight formats.
        question: the target question text.
        answer: SQuAD-style answer dict with ``'text'`` and
            ``'answer_start'`` keys.
        qg_format: "prepend" puts the answer before the context;
            "highlight" (the default) wraps the answer span in literal
            ``{hl_token}`` markers; any other value does both.

    Returns:
        dict with ``source_text``, ``target_text`` (the question) and
        ``task`` == "qg".
    """
    answer_text = answer['text'].strip()

    if qg_format == "prepend":
        que_gen_input = f"answer: {answer_text}  context: {context}"
    else:
        # Both highlight variants share the same marked-up context.
        start_pos, end_pos = get_start_end_idx(answer)
        highlighted = f"{context[:start_pos]} {{hl_token}} {answer_text} {{hl_token}} {context[end_pos:]}"
        if qg_format == "highlight":
            que_gen_input = f"generate question: {highlighted}"
        else:
            # NOTE: any unrecognised qg_format lands here (answer + highlight).
            que_gen_input = f"answer: {answer_text} context: {highlighted}"

    que_gen_target = f"{question}"
    return {"source_text": que_gen_input, "target_text": que_gen_target, "task": "qg"}

def process_e2e_qg(paragraph):
    """Build one end-to-end question-generation example for a paragraph.

    The source is the stripped context prefixed with "generate questions: ".
    The target is every stripped question, each followed by a literal
    " {sep_token}" marker; with no questions it is just " {sep_token}".
    """
    prompt = "generate questions: " + paragraph['context'].strip()
    separator = " {sep_token} "
    stripped = [entry['question'].strip() for entry in paragraph['qas']]
    return {
        "source_text": prompt,
        "target_text": separator.join(stripped) + " {sep_token}",
        "task": "e2e_qg",
    }

def _sentence_spans(context, sents):
    """Return a ``(start, end)`` character span in *context* for each sentence."""
    spans = []
    prev_end = None
    for sent in sents:
        # Search after the previous sentence so gaps of any width are handled.
        start = context.find(sent, 0 if prev_end is None else prev_end)
        if start == -1:
            # Sentence not found verbatim (the splitter may have altered it):
            # fall back to assuming one separator char after the previous one.
            start = 0 if prev_end is None else prev_end + 1
        end = start + len(sent)
        spans.append((start, end))
        prev_end = end
    return spans


def _highlight_sentence(sents, idx):
    """Join *sents* after "extract answers:", wrapping sentence *idx* in {hl_token} markers."""
    source = "extract answers:"
    for j, sent in enumerate(sents):
        if j == idx:
            sent = "{hl_token} %s {hl_token}" % sent
        source = ("%s %s" % (source, sent)).strip()
    return source


def process_ans_ext(paragraph):
    """Build answer-extraction examples, one per sentence that contains an answer.

    The paragraph context is stripped and split into sentences with
    ``kss.split_sentences``. For every sentence containing the start of at
    least one question's first answer, an example is emitted whose source
    is the whole context with that sentence wrapped in literal
    ``{hl_token}`` markers, and whose target is the de-duplicated answer
    texts, each followed by a literal " {sep_token}".

    Args:
        paragraph: SQuAD-style paragraph dict with ``'context'`` and
            ``'qas'`` (each qa holding an ``'answers'`` list).

    Returns:
        list of dicts with ``source_text``, ``target_text`` and
        ``task`` == "ans_ext".
    """
    raw_context = paragraph['context']
    context = raw_context.strip()
    # answer_start indexes the unstripped context; shift by the leading
    # whitespace removed above so offsets match the sentence spans.
    offset = len(raw_context) - len(raw_context.lstrip())

    sents = kss.split_sentences(context)
    spans = _sentence_spans(context, sents)

    # First answer of each question; questions without answers are skipped.
    answers = [qa['answers'][0] for qa in paragraph['qas'] if qa['answers']]

    examples = []
    for idx, (start, end) in enumerate(spans):
        found = [
            ans['text'].strip()
            for ans in answers
            if start <= ans['answer_start'] - offset < end
        ]
        # dict.fromkeys de-duplicates in first-seen order; set() order varies
        # between runs because of string hash randomisation.
        unique = list(dict.fromkeys(found))
        if not unique:
            continue
        examples.append({
            'source_text': _highlight_sentence(sents, idx),
            "target_text": " {sep_token} ".join(unique) + " {sep_token}",
            "task": "ans_ext",
        })

    return examples

In [None]:
def generate_examples(file_path, tasks=('qa', 'qg', 'ans_ext', 'e2e_qg'), qg_format="highlight"):
    """Yield ``(index, example)`` pairs for the requested tasks from a SQuAD-format JSON file.

    Args:
        file_path: path to a SQuAD / KorQuAD v1-style JSON file (UTF-8).
        tasks: any subset of 'qa', 'qg', 'ans_ext', 'e2e_qg'.
        qg_format: question-generation input format passed to
            ``process_qg_text``.

    Yields:
        ``(count, example)`` with ``count`` increasing from 0.
    """
    with open(file_path, encoding='utf-8') as f:
        squad = json.load(f)

    count = 0
    for doc in squad['data']:
        for paragraph in doc['paragraphs']:
            # Not stripped: answer_start offsets index the raw context.
            context = paragraph['context']

            if 'ans_ext' in tasks:
                for example in process_ans_ext(paragraph):
                    yield count, example
                    count += 1

            # Independent of 'ans_ext'.
            if 'e2e_qg' in tasks:
                yield count, process_e2e_qg(paragraph)
                count += 1

            # qa / qg examples are produced for every question in the paragraph.
            for qa in paragraph['qas']:
                if not qa['answers']:
                    continue
                question = qa['question'].strip()
                answer = qa['answers'][0]
                for task in tasks:
                    if task == 'qa':
                        yield count, process_qa_text(context, question, answer['text'].strip())
                        count += 1
                    elif task == 'qg':
                        yield count, process_qg_text(context, question, answer, qg_format)
                        count += 1


file_path = '/home/ubuntu/workspace/kaist.ir/qa/data/korquad/KorQuAD_v1.0_dev.json'
tasks = ['qa', 'qg', 'ans_ext', 'e2e_qg']

# A bare `yield` is not allowed at notebook/module level, so the generator
# above is materialised here.
examples = list(generate_examples(file_path, tasks))
count = len(examples)