In [22]:
from transformers import AutoTokenizer
from datasets import load_from_disk

# Load the tokenizer published under the 'klue/roberta-base' name.
tokenizer = AutoTokenizer.from_pretrained('klue/roberta-base')
# True when this tokenizer pads on the right-hand side of a sequence.
pad_on_right = (tokenizer.padding_side == "right")

# Read the dataset saved on disk and keep only its 'train' split.
dataset_on_disk = load_from_disk("../data/train_dataset")
train_dataset = dataset_on_disk['train']
# Column names of the raw split, captured before any transformation.
column_names = train_dataset.column_names


In [23]:
def prepare_train_features(examples, max_length=384, doc_stride=128, verbose=True):
    """Tokenize a batch of question/context pairs and label each feature with its answer span.

    Contexts that do not fit in ``max_length`` tokens are split into several
    overlapping features (overlap controlled by ``doc_stride``). Every feature
    receives a ``start_positions`` / ``end_positions`` token index pair; when the
    example has no answer, or the answer is not fully contained in the feature,
    both labels point at the CLS token.

    Relies on the module-level ``tokenizer`` and ``pad_on_right`` names.

    Args:
        examples: Batch dict with ``'question'`` and ``'context'`` lists of
            strings and an ``'answers'`` list of dicts holding
            ``'answer_start'`` (character offsets) and ``'text'`` lists.
        max_length: Token length of every produced feature (padded/truncated).
        doc_stride: Number of overlapping tokens between consecutive features
            of the same context.
        verbose: When True (default, matches the previous behaviour) the first
            feature of the batch is dumped to stdout for inspection. Pass
            ``fn_kwargs={"verbose": False}`` to ``Dataset.map`` to silence it.

    Returns:
        The tokenizer output for the batch, without ``offset_mapping`` and
        ``overflow_to_sample_mapping``, extended with ``start_positions`` and
        ``end_positions`` lists (one entry per feature).
    """
    # The context must be the sequence that gets truncated/windowed. Its slot
    # (first or second) follows the padding side, and the sequence id searched
    # for below has to agree with the order used in the tokenizer call.
    if pad_on_right:
        first_texts, second_texts = examples['question'], examples['context']
        truncation, context_sequence_id = "only_second", 1
    else:
        first_texts, second_texts = examples['context'], examples['question']
        truncation, context_sequence_id = "only_first", 0

    # Tokenize with truncation and padding, keeping the overflow as extra
    # features so that each one overlaps a little with the previous one.
    tokenized_examples = tokenizer(
        first_texts,
        second_texts,
        truncation=truncation,
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        # return_token_type_ids=False,  # should be False for RoBERTa models and True for BERT models
        padding="max_length",
    )

    # One long context can yield several features; this maps each feature
    # back to the index of the example it came from.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # Per-token (start_char, end_char) pairs, used to convert the character
    # level answer position into token indices.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Labels to be filled in, one per feature.
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Tells, for every token, whether it belongs to the first sequence,
        # the second sequence, or neither.
        sequence_ids = tokenized_examples.sequence_ids(i)

        # The example this feature was cut from, and its answers.
        sample_index = sample_mapping[i]
        answers = examples['answers'][sample_index]

        if len(answers["answer_start"]) == 0:
            # No answer for this example: label with the CLS index.
            start_position = end_position = cls_index
        else:
            # Character span of the (first) answer inside the context text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # First token of the context in this feature.
            context_start = 0
            while sequence_ids[context_start] != context_sequence_id:
                context_start += 1

            # Last token of the context in this feature.
            context_end = len(input_ids) - 1
            while sequence_ids[context_end] != context_sequence_id:
                context_end -= 1

            answer_in_feature = (
                offsets[context_start][0] <= start_char
                and offsets[context_end][1] >= end_char
            )
            if not answer_in_feature:
                # The answer is outside this window: label with the CLS index.
                start_position = end_position = cls_index
            else:
                # Walk forward to the last context token starting at or before
                # the answer start. The scan is bounded by the context range:
                # the tokens after it carry (0, 0)-style offsets that would
                # otherwise keep the loop running when the answer starts in
                # the very last context token.
                token_start_index = context_start
                while (
                    token_start_index <= context_end
                    and offsets[token_start_index][0] <= start_char
                ):
                    token_start_index += 1
                start_position = token_start_index - 1

                # Walk backward to the first context token ending at or after
                # the answer end, again without leaving the context range.
                token_end_index = context_end
                while (
                    token_end_index >= context_start
                    and offsets[token_end_index][1] >= end_char
                ):
                    token_end_index -= 1
                end_position = token_end_index + 1

        tokenized_examples["start_positions"].append(start_position)
        tokenized_examples["end_positions"].append(end_position)

        # Optional inspection dump of the first feature in the batch.
        if verbose and i == 0:
            print(tokenized_examples['input_ids'][i])
            print(tokenized_examples['start_positions'][i], tokenized_examples['end_positions'][i])
            print(answers)
            print(tokenizer.decode(tokenized_examples['input_ids'][i][tokenized_examples['start_positions'][i]:tokenized_examples['end_positions'][i]+1]))
            for j, token in enumerate(tokenized_examples['input_ids'][i]):
                print(f'{j}: {tokenizer.decode(token)}', offsets[j])
    return tokenized_examples

In [24]:
# Replace the raw split with the output of prepare_train_features.
# The function is applied batch-wise, the original columns are dropped from
# the result, and any cached result of a previous run is not reused.
map_options = {
    "batched": True,
    "num_proc": None,
    "remove_columns": column_names,
    "load_from_cache_file": False,
}
train_dataset = train_dataset.map(prepare_train_features, **map_options)

[0, 3698, 2069, 3954, 2470, 3666, 2079, 7895, 8586, 2207, 2069, 554, 2259, 3728, 3860, 2073, 35, 2, 3666, 10346, 2252, 4013, 3666, 10450, 12, 29963, 30605, 2041, 19148, 2012, 9230, 13, 1497, 1402, 2252, 2021, 2179, 3666, 4570, 2079, 10450, 28674, 18, 3, 81, 3, 81, 2044, 2226, 17352, 2052, 10450, 2079, 27345, 3622, 18, 544, 12881, 22, 2211, 2079, 10450, 5069, 2052, 6940, 2496, 2051, 3911, 2211, 2079, 10450, 5069, 6233, 3896, 2496, 2051, 1513, 2062, 18, 6724, 2259, 26, 2440, 2052, 2307, 16, 22, 2440, 10598, 3956, 2019, 2223, 1570, 21, 19, 23, 3292, 10450, 5069, 2069, 3755, 6940, 7488, 7145, 2170, 14352, 18, 3, 81, 3, 81, 2044, 2226, 10450, 2073, 3666, 11119, 2145, 2259, 4405, 2318, 3666, 3698, 2069, 12104, 6233, 1889, 2259, 3666, 7145, 7895, 2170, 4424, 5187, 2138, 1889, 2259, 3860, 28674, 18, 11119, 2052, 5387, 2145, 3674, 2170, 3618, 5851, 16, 3698, 2069, 3954, 2470, 8199, 2079, 4668, 2069, 25154, 2085, 5851, 2069, 554, 2088, 1513, 2259, 3735, 2069, 3661, 2205, 2259, 3860, 2179, 4305, 

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[0, 3698, 2069, 3954, 2470, 3666, 2079, 7895, 8586, 2207, 2069, 554, 2259, 3728, 3860, 2073, 35, 2, 3666, 10346, 2252, 4013, 3666, 10450, 12, 29963, 30605, 2041, 19148, 2012, 9230, 13, 1497, 1402, 2252, 2021, 2179, 3666, 4570, 2079, 10450, 28674, 18, 3, 81, 3, 81, 2044, 2226, 17352, 2052, 10450, 2079, 27345, 3622, 18, 544, 12881, 22, 2211, 2079, 10450, 5069, 2052, 6940, 2496, 2051, 3911, 2211, 2079, 10450, 5069, 6233, 3896, 2496, 2051, 1513, 2062, 18, 6724, 2259, 26, 2440, 2052, 2307, 16, 22, 2440, 10598, 3956, 2019, 2223, 1570, 21, 19, 23, 3292, 10450, 5069, 2069, 3755, 6940, 7488, 7145, 2170, 14352, 18, 3, 81, 3, 81, 2044, 2226, 10450, 2073, 3666, 11119, 2145, 2259, 4405, 2318, 3666, 3698, 2069, 12104, 6233, 1889, 2259, 3666, 7145, 7895, 2170, 4424, 5187, 2138, 1889, 2259, 3860, 28674, 18, 11119, 2052, 5387, 2145, 3674, 2170, 3618, 5851, 16, 3698, 2069, 3954, 2470, 8199, 2079, 4668, 2069, 25154, 2085, 5851, 2069, 554, 2088, 1513, 2259, 3735, 2069, 3661, 2205, 2259, 3860, 2179, 4305, 

In [1]:
from datasets import load_metric

# Load the metric registered under the name 'squad'; its printed description
# below states that it computes SQuAD scores (F1 and EM).
metric = load_metric('squad')

In [2]:
# Show the metric's description: its name, the expected layout of
# predictions/references, and its usage text.
print(metric)

Metric(name: "squad", features: {'predictions': {'id': Value(dtype='string', id=None), 'prediction_text': Value(dtype='string', id=None)}, 'references': {'id': Value(dtype='string', id=None), 'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None)}}, usage: """
Computes SQuAD scores (F1 and EM).
Args:
    predictions: List of question-answers dictionaries with the following key-values:
        - 'id': id of the question-answer pair as given in the references (see below)
        - 'prediction_text': the text of the answer
    references: List of question-answers dictionaries with the following key-values:
        - 'id': id of the question-answer pair (see above),
        - 'answers': a Dict in the SQuAD dataset format
            {
                'text': list of possible texts for the answer, as a list of strings
                'answer_start': list of start positions for the answer, as a list of ints
   