In [15]:
import os
from datasets import load_dataset
from transformers import AutoTokenizer

In [16]:
# Load the tokenizer that matches the encoder checkpoint used in this notebook.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="skt/A.X-Encoder-base",
)

In [17]:
# Every split lives in its own `<split>.parquet` file under the run directory.
data_dir = "./data/run1"
data_files = {
    split: os.path.join(data_dir, f"{split}.parquet")
    for split in ("train", "val", "test")
}

# Load all three splits at once into a single DatasetDict.
ds = load_dataset('parquet', data_files=data_files)

Generating train split: 0 examples [00:00, ? examples/s]

Generating val split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [18]:
# Bare last expression: displays the DatasetDict summary (splits, columns, row counts).
ds

DatasetDict({
    train: Dataset({
        features: ['file_id', 'doc_id', 'title', 'author', 'publisher', 'date', 'topic', 'original_topic', 'sentence_ids', 'text', 'entities'],
        num_rows: 11992
    })
    val: Dataset({
        features: ['file_id', 'doc_id', 'title', 'author', 'publisher', 'date', 'topic', 'original_topic', 'sentence_ids', 'text', 'entities'],
        num_rows: 1499
    })
    test: Dataset({
        features: ['file_id', 'doc_id', 'title', 'author', 'publisher', 'date', 'topic', 'original_topic', 'sentence_ids', 'text', 'entities'],
        num_rows: 1499
    })
})

In [19]:
# Entity types that get their own B-/I- tag pair in the BIO scheme.
ENTITY_TYPES = ["PS", "LC", "OG", "DT", "TI", "QT"]

# "O" comes first so it gets id 0; each entity type then contributes
# its B- tag immediately followed by its I- tag.
labels = ["O"]
for ent in ENTITY_TYPES:
    labels.extend((f"B-{ent}", f"I-{ent}"))

# Bidirectional lookup tables between tag names and integer ids.
label2id = {name: index for index, name in enumerate(labels)}
id2label = dict(enumerate(labels))

def _bio_label_ids(offsets, entities, label_map):
    """Convert one example's token offsets and entity spans into BIO label ids.

    Args:
        offsets: iterable of (start, end) character offsets, one per token.
        entities: list of {"start": ..., "end": ..., "label": ...} dicts, or
            None / empty when the example has no entities.
        label_map: mapping from BIO tag name (e.g. "B-PS") to integer id;
            must contain "O".

    Returns:
        A list with one entry per token: -100 for zero-width tokens, otherwise
        the id of the token's BIO tag.
    """
    # A null `entities` cell is treated the same as "no entities".
    spans = [(e["start"], e["end"], e["label"]) for e in (entities or ())]

    # Indices of spans that have already received their B- tag.
    started = set()
    label_ids = []

    for start, end in offsets:
        if start == end:
            # Zero-width offsets mark special tokens ([CLS], [SEP], ...);
            # -100 is the value conventionally ignored by the training loss.
            label_ids.append(-100)
            continue

        tag = "O"
        for span_idx, (ent_start, ent_end, ent_label) in enumerate(spans):
            # Half-open interval overlap between the token and the entity.
            if start < ent_end and end > ent_start:
                # The FIRST token that overlaps a span opens it with B-, even
                # when its start offset does not coincide with the entity
                # start; every later token of the same span gets I-.
                if span_idx in started:
                    tag = f"I-{ent_label}"
                else:
                    started.add(span_idx)
                    tag = f"B-{ent_label}"
                break  # the first listed entity that matches wins

        # Entity types without a tag in label_map fall back to "O".
        label_ids.append(label_map.get(tag, label_map["O"]))

    return label_ids


def encode_and_align_labels(examples):
    """Tokenize a batch of texts and attach token-level BIO label ids.

    Relies on the module-level `tokenizer` and `label2id`.

    Args:
        examples: batch dict with
            examples["text"]: list of strings;
            examples["entities"]: list (one item per text) of entity-span
                lists, each span being {"start": ..., "end": ..., "label": ...}
                with character offsets into the text.

    Returns:
        The tokenizer output without "offset_mapping" and with an added
        "labels" entry: one list of label ids per example, aligned with the
        tokens. Zero-width (special) tokens are labelled -100, tokens that
        overlap an entity get B-/I- ids, everything else gets the "O" id.
    """
    tokenized = tokenizer(
        examples["text"],
        padding=False,
        truncation=True,
        return_offsets_mapping=True,
    )

    # Offsets are only needed for alignment; drop them from the model inputs.
    offset_mappings = tokenized.pop("offset_mapping")

    tokenized["labels"] = [
        _bio_label_ids(offsets, entities, label2id)
        for offsets, entities in zip(offset_mappings, examples["entities"])
    ]
    return tokenized


In [20]:
# Encode the training split in batches. All original columns (text, entities,
# metadata) are dropped so only the tokenizer outputs and labels remain.
train_ds = ds["train"]
encoded_train_ds = train_ds.map(
    function=encode_and_align_labels,
    batched=True,
    remove_columns=train_ds.column_names,
)

Map:   0%|          | 0/11992 [00:00<?, ? examples/s]

In [25]:
# Bare last expression: displays the encoded training set's features and row count.
encoded_train_ds

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
    num_rows: 11992
})

In [29]:
input_ids = encoded_train_ds[0]['input_ids']
labels = encoded_train_ds[0]['labels']

In [32]:
print(input_ids[:20])
print(labels[:20])

[0, 41405, 20794, 48, 34213, 32133, 1828, 35527, 31889, 20290, 20629, 1829, 34090, 34169, 22371, 20395, 35094, 20794, 20430, 32088]
[-100, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 6, 6, 6, 0, 0]
