In [279]:
from transformers import AutoTokenizer
from transformers import DataCollatorForTokenClassification
import torch
from torch.utils.data import DataLoader
from datasets import Dataset, DatasetDict
import pandas as pd

## Helper functions and declarations

In [None]:
# Tag inventory: the outside tag 'O' comes first (id 0), followed by one
# B-/I- pair per entity type.  With this ordering every B- tag gets an odd id
# and its matching I- tag the next even id, which align_labels_with_tokens
# relies on when it turns a B- tag into the corresponding I- tag.
label_names = [
    'O',
    'B-ART', 'I-ART',
    'B-CON', 'I-CON',
    'B-LOC', 'I-LOC',
    'B-MAT', 'I-MAT',
    'B-PER', 'I-PER',
    'B-SPE', 'I-SPE',
]
# Integer id -> tag name, and the inverse lookup used when parsing files.
id2label = dict(enumerate(label_names))
label2id = {name: idx for idx, name in enumerate(label_names)}

In [320]:
def align_labels_with_tokens(labels, word_ids):
    """Expand word-level label ids to one label id per token.

    Args:
        labels: sequence of integer label ids, indexed by word position.
        word_ids: for each token, the index of the word it came from, or
            None for a token that belongs to no word (special token).

    Returns:
        A list with one entry per element of ``word_ids``:
        -100 where the word id is None, the word's label for the first
        token of a word, and for the remaining tokens of the same word the
        word's label with odd ids bumped by one (B-XXX becomes I-XXX).
    """
    aligned = []
    previous_word_id = None
    for word_id in word_ids:
        starts_new_word = word_id != previous_word_id
        previous_word_id = word_id

        if word_id is None:
            # Special token: no word behind it, so no real label.
            aligned.append(-100)
        elif starts_new_word:
            # First token of a word keeps the word's own label.
            aligned.append(labels[word_id])
        else:
            # Continuation token: odd ids are B-XXX tags, whose I-XXX
            # counterpart is the next id; even ids are left as they are.
            tag = labels[word_id]
            aligned.append(tag + 1 if tag % 2 == 1 else tag)

    return aligned

def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split sentences and attach per-token labels.

    Uses the module-level ``tokenizer`` (it must exist by the time this
    function is called).

    Args:
        examples: batch mapping with a 'tokens' entry (one list of words per
            sentence) and a 'ner_tags' entry (one list of label ids per
            sentence, parallel to 'tokens').

    Returns:
        The tokenizer output for the batch with an extra 'labels' entry
        holding, for each sentence, the labels produced by
        ``align_labels_with_tokens`` for that sentence's word ids.
    """
    encoded = tokenizer(
        examples['tokens'],
        truncation=True,
        is_split_into_words=True,
    )
    # word_ids(idx) maps every token of sentence idx back to its source word.
    encoded['labels'] = [
        align_labels_with_tokens(sentence_tags, encoded.word_ids(idx))
        for idx, sentence_tags in enumerate(examples['ner_tags'])
    ]
    return encoded

def parse(sentence):
    """Turn the lines of one sentence into parallel token and tag-id lists.

    Every non-blank line is split on single spaces; field 0 is taken as the
    token and field 2 as the tag name, which is mapped to its integer id
    through the module-level ``label2id``.

    Blank or whitespace-only lines are skipped.  Without this, a trailing
    newline in the input yields an empty line, ``''.split(' ')`` returns
    ``['']`` and indexing field 2 raises IndexError.

    Args:
        sentence: iterable of strings, one line per token.

    Returns:
        ``{'tokens': [...], 'ner_tags': [...]}`` with one entry per
        non-blank line, in input order.

    Raises:
        IndexError: if a non-blank line has fewer than three fields.
        KeyError: if the tag name is not a key of ``label2id``.
    """
    tokens, ner_tags = [], []
    for phrase in sentence:
        if not phrase.strip():
            # Nothing to parse on an empty line.
            continue
        features = phrase.split(' ')
        tokens.append(features[0].strip())
        ner_tags.append(label2id[features[2].strip()])
    return {'tokens': tokens, 'ner_tags': ner_tags}

## Data Loading

In [307]:
# Read '<split>.txt' for every split and collect the parsed sentences
# column-wise: raw_datasets[split][column] is a list with one entry per
# sentence.  In each file, sentences are separated by an empty line and
# every remaining line describes one token (see parse()).
raw_datasets = {}
for split in ('train', 'val', 'test'):
    raw_datasets[split] = {}
    # Explicit encoding so the text read does not depend on the platform's
    # default locale encoding.
    with open(f'{split}.txt', 'r', encoding='utf-8') as infile:
        raw_data = infile.read()

    # Drop blank lines inside each chunk so that a trailing newline or an
    # extra empty line between sentences never reaches parse().
    sentences = [
        [line for line in chunk.split('\n') if line.strip()]
        for chunk in raw_data.split('\n\n')
    ]
    for sentence in sentences:
        if not sentence:
            # Chunk contained only whitespace (e.g. the file ends with a
            # blank line): not a real sentence.
            continue
        s_parsed = parse(sentence)
        for k, values in s_parsed.items():
            raw_datasets[split].setdefault(k, []).append(values)


# One Dataset per split, bundled in a DatasetDict keyed by split name.
datasets = {
    split_name: Dataset.from_dict(columns)
    for split_name, columns in raw_datasets.items()
}
data = DatasetDict(datasets)

## Preprocessing

### Tokenization

In [321]:
# Checkpoint whose tokenizer is loaded for sub-word tokenization.
model_checkpoint = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Run tokenize_and_align_labels over every split in batches.  The columns of
# the untokenized data are passed as remove_columns so they are dropped from
# the mapped result.
source_columns = data['train'].column_names
data2 = data.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=source_columns,
)

Map:   0%|          | 0/1992 [00:00<?, ? examples/s]

Map:   0%|          | 0/850 [00:00<?, ? examples/s]

Map:   0%|          | 0/864 [00:00<?, ? examples/s]

### Collation

In [244]:
# Batch collator for token classification, built around the same tokenizer
# used for preprocessing.  Per the transformers documentation it pads both
# the inputs and the labels of a batch to a common length.
data_collator = DataCollatorForTokenClassification(tokenizer)