In [1]:
import warnings
warnings.filterwarnings('ignore')

from typing_extensions import TypedDict
from typing import List,Any
import re

In [2]:
IntList = List[int] # A list of token_ids
IntListList = List[IntList] # A List of List of token_ids, e.g. a Batch

In [3]:
pattern = r'(\b[A-Z][a-z]+\b)(\s\b[A-Z][a-z]+\b)*'
re.compile(pattern)

def get_annotations(text, pattern):
    annotations = []
    for match in re.finditer(pattern, text):
        label_dic = dict()
        label_dic['start'] = match.start()
        label_dic['end'] = match.end()
        label_dic['label'] = 'CLEntity' # Entity starting with a capital letter
        annotations.append(label_dic)
    return annotations

In [4]:
json_data = []
book = open("./example-texts/red-rain-in-iliad.txt")
for line in book:
    line = line.strip()
    
    line_data = dict()
    line_data['content'] = line
    line_data['annotations'] = get_annotations(line, pattern)
    json_data.append(line_data)
print(json_data)

[{'content': 'STOP', 'annotations': []}, {'content': '', 'annotations': []}, {'content': '', 'annotations': []}, {'content': '', 'annotations': []}, {'content': 'Early Journal Content on JSTOR, Free to Anyone in the World', 'annotations': [{'start': 0, 'end': 21, 'label': 'CLEntity'}, {'start': 32, 'end': 36, 'label': 'CLEntity'}, {'start': 40, 'end': 46, 'label': 'CLEntity'}, {'start': 54, 'end': 59, 'label': 'CLEntity'}]}, {'content': '', 'annotations': []}, {'content': 'This article is one of nearly 500,000 scholarly works digitized and made freely available to everyone in', 'annotations': [{'start': 0, 'end': 4, 'label': 'CLEntity'}]}, {'content': 'the world by JSTOR.', 'annotations': []}, {'content': '', 'annotations': []}, {'content': 'Known as the Early Journal Content, this set of works include research articles, news, letters, and other', 'annotations': [{'start': 0, 'end': 5, 'label': 'CLEntity'}, {'start': 13, 'end': 34, 'label': 'CLEntity'}]}, {'content': 'writings publishe

In [5]:
# with open('example-texts/iliad.txt') as fo:
#     text = fo.read()

In [6]:
# annotations = []
# for match in re.finditer(pattern, text):
#     label_dic = dict()
#     label_dic['start'] = match.start()
#     label_dic['end'] = match.end()
#     label_dic['text'] = text[match.start():match.end()]
#     label_dic['label'] = 'CL-Entity' # Entity starting with a capital letter
#     annotations.append(label_dic)
# print(len(annotations))


In [7]:
from transformers import BertTokenizerFast,  BatchEncoding
from tokenizers import Encoding

def align_tokens_and_annotations_bilou(tokenized: Encoding, annotations):
    tokens = tokenized.tokens
    aligned_labels = ["O"] * len(
        tokens
    )  # Make a list to store our labels the same length as our tokens
    for anno in annotations:
        annotation_token_ix_set = set()# A set that stores the token indices of the annotation
        for char_ix in range(anno["start"], anno["end"]):

            token_ix = tokenized.char_to_token(char_ix)
            if token_ix is not None:
                annotation_token_ix_set.add(token_ix)

        if len(annotation_token_ix_set) == 1:
            # If there is only one token
            token_ix = annotation_token_ix_set.pop()
            prefix = (
                "U"  # This annotation spans one token so is prefixed with U for unique
            )
            aligned_labels[token_ix] = f"{prefix}-{anno['label']}"

        else:

            last_token_in_anno_ix = len(annotation_token_ix_set) - 1
            for num, token_ix in enumerate(sorted(annotation_token_ix_set)):
                if num == 0:
                    prefix = "B"
                elif num == last_token_in_anno_ix:
                    prefix = "L"  # Its the last token
                else:
                    prefix = "I"  # We're inside of a multi token annotation
                aligned_labels[token_ix] = f"{prefix}-{anno['label']}"
    return aligned_labels

In [8]:
# try an exmaple
example = {'content': 'We encourage people to read and share the Early Journal Content openly and to tell others that this', 'annotations': [{'start': 0, 'end': 2, 'label': 'CLEntity'}, {'start': 42, 'end': 63, 'label': 'CLEntity'}]}
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased') # Load a pre-trained tokenizer
tokenized_batch : BatchEncoding = tokenizer(example["content"])
tokenized_text : Encoding = tokenized_batch[0]
labels = align_tokens_and_annotations_bilou(tokenized_text, example["annotations"])

for token, label in zip(tokenized_text.tokens, labels):
    print(token, "-", label)

[CLS] - O
We - U-CLEntity
encourage - O
people - O
to - O
read - O
and - O
share - O
the - O
Early - B-CLEntity
Journal - I-CLEntity
Content - L-CLEntity
openly - O
and - O
to - O
tell - O
others - O
that - O
this - O
[SEP] - O


In [9]:
import itertools

class LabelSet:
    def __init__(self, labels: List[str]):
        self.labels_to_id = {}
        self.ids_to_label = {}
        self.labels_to_id["O"] = 0
        self.ids_to_label[0] = "O"
        num = 0  # in case there are no labels
        # Writing BILU will give us incremntal ids for the labels
        for _num, (label, s) in enumerate(itertools.product(labels, "BILU")):
            num = _num + 1  # skip 0
            l = f"{s}-{label}"
            self.labels_to_id[l] = num
            self.ids_to_label[num] = l
        # Add the OUTSIDE label - no label for the token

    def get_aligned_label_ids_from_annotations(self, tokenized_text, annotations):
        raw_labels = align_tokens_and_annotations_bilou(tokenized_text, annotations)
        return list(map(self.labels_to_id.get, raw_labels))


example_label_set = LabelSet(labels=["CLEntity"])
aligned_label_ids = example_label_set.get_aligned_label_ids_from_annotations(
    tokenized_text, example["annotations"]
)
tokens = tokenized_text.tokens
for token, label in zip(tokens, aligned_label_ids):
    print(token, "-", label)

[CLS] - 0
We - 4
encourage - 0
people - 0
to - 0
read - 0
and - 0
share - 0
the - 0
Early - 1
Journal - 2
Content - 3
openly - 0
and - 0
to - 0
tell - 0
others - 0
that - 0
this - 0
[SEP] - 0


In [10]:
from dataclasses import dataclass
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizerFast

In [11]:
@dataclass
class TrainingExample:
    input_ids: IntList
    attention_masks: IntList
    labels: IntList


class TraingDataset(Dataset):
    def __init__(
        self,
        data: Any,
        label_set: LabelSet,
        tokenizer: PreTrainedTokenizerFast,
        tokens_per_batch=32,
        window_stride=None,
    ):
        self.label_set = label_set
        if window_stride is None:
            self.window_stride = tokens_per_batch
        self.tokenizer = tokenizer
        self.texts = []
        self.annotations = []

        for example in data:
            self.texts.append(example["content"])
            self.annotations.append(example["annotations"])
        ###TOKENIZE All THE DATA
        tokenized_batch = self.tokenizer(self.texts, add_special_tokens=False)
        ###ALIGN LABELS ONE EXAMPLE AT A TIME
        aligned_labels = []
        for ix in range(len(tokenized_batch.encodings)):
            encoding = tokenized_batch.encodings[ix]
            raw_annotations = self.annotations[ix]
            aligned = label_set.get_aligned_label_ids_from_annotations(
                encoding, raw_annotations
            )
            aligned_labels.append(aligned)
        ###END OF LABEL ALIGNMENT

        ###MAKE A LIST OF TRAINING EXAMPLES. (This is where we add padding)
        self.training_examples: List[TrainingExample] = []
        empty_label_id = "O"
        for encoding, label in zip(tokenized_batch.encodings, aligned_labels):
            length = len(label)  # How long is this sequence
            for start in range(0, length, self.window_stride):

                end = min(start + tokens_per_batch, length)

                # How much padding do we need ?
                padding_to_add = max(0, tokens_per_batch - end + start)
                self.training_examples.append(
                    TrainingExample(
                        # Record the tokens
                        input_ids=encoding.ids[start:end]  # The ids of the tokens
                        + [self.tokenizer.pad_token_id]
                        * padding_to_add,  # padding if needed
                        labels=(
                            label[start:end]
                            + [-100] * padding_to_add  # padding if needed
                        ),  # -100 is a special token for padding of labels,
                        attention_masks=(
                            encoding.attention_mask[start:end]
                            + [0]
                            * padding_to_add  # 0'd attenetion masks where we added padding
                        ),
                    )
                )

    def __len__(self):
        return len(self.training_examples)

    def __getitem__(self, idx) -> TrainingExample:

        return self.training_examples[idx]

In [12]:
label_set = LabelSet(labels=["CLEntity"])
ds = TraingDataset(
    data=json_data, tokenizer=tokenizer, label_set=label_set, tokens_per_batch=16
)
ex = ds[10]
print(ex)

TrainingExample(input_ids=[1284, 8343, 1234, 1106, 2373, 1105, 2934, 1103, 4503, 3603, 27551, 9990, 1105, 1106, 1587, 1639], attention_masks=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], labels=[4, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0])


In [13]:
import torch


class TraingingBatch:
    def __getitem__(self, item):
        return getattr(self, item)

    def __init__(self, examples: List[TrainingExample]):
        self.input_ids: torch.Tensor
        self.attention_masks: torch.Tensor
        self.labels: torch.Tensor
        input_ids: IntListList = []
        masks: IntListList = []
        labels: IntListList = []
        for ex in examples:
            input_ids.append(ex.input_ids)
            masks.append(ex.attention_masks)
            labels.append(ex.labels)
        self.input_ids = torch.LongTensor(input_ids)
        self.attention_masks = torch.LongTensor(masks)
        self.labels = torch.LongTensor(labels)

In [14]:
from torch.utils.data.dataloader import DataLoader
from transformers import BertForTokenClassification, AdamW, BertTokenizer
import torch

In [15]:
model = BertForTokenClassification.from_pretrained(
    "bert-base-cased", num_labels=len(ds.label_set.ids_to_label.values())
)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cas

In [16]:
from transformers import Trainer, TrainingArguments
training_args = TrainingArguments(
    output_dir = "./trainning_output/output1.txt",
    num_train_epochs=10,
    per_device_train_batch_size=4,  # batch size per device during training
    weight_decay=0.01,               # strength of weight decay
    load_best_model_at_end=True,
    logging_steps=200,
    evaluation_strategy="steps",
)

In [17]:
optimizer = AdamW(model.parameters(), lr=5e-6)

In [18]:
dataloader = DataLoader(
    ds,
    collate_fn=TraingingBatch,
    batch_size=4,
    shuffle=True,
)

In [19]:
print(list(dataloader)[0].input_ids)
print(list(dataloader)[0].labels)
print(len(list(dataloader)[0].input_ids[0]))
print(len(list(dataloader)[0].labels[0]))

tensor([[ 3482,   117,  1105, 20421,  1142,  3438,  1111,  2174,  8225,   119,
           147,  9272,  9565,  1110,  1226,  1104],
        [16664,  2414,  1105,  1150,  1138,  1680,  1103,  2016,  1104,  1103,
             0,     0,     0,     0,     0,     0],
        [ 5424,   119,  1130, 12371,   117,  1103, 24705, 20937,  1389,  4924,
          1867,  1104,  1103,     0,     0,     0],
        [13915,  1112,  1103,  4503,  3603, 27551,   117,  1142,  1383,  1104,
          1759,  1511,  1844,  4237,   117,  2371]])
tensor([[   4,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0, -100, -100],
        [   4,    0,    1,    2,    3,    0,    4,    0,    1,    3,    0,    0,
            0,    0,    0,    0],
        [   1,    2,    2,    2,    3,    0,    0,    0,    0,    0,    4, -100,
         -100, -100, -100, -100],
        [   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
         -100, -100, -100, -100]])
16
16


In [20]:
for num, batch in enumerate(dataloader):
    
    output = model(
        input_ids=batch.input_ids,
        attention_mask=batch.attention_masks,
        labels=batch.labels,
    )
    print(output.loss)
    output.loss.backward()
    optimizer.step()
    if num > 20:
        break

tensor(1.6521, grad_fn=<NllLossBackward>)
tensor(1.5041, grad_fn=<NllLossBackward>)
tensor(1.5877, grad_fn=<NllLossBackward>)
tensor(1.4045, grad_fn=<NllLossBackward>)
tensor(1.2162, grad_fn=<NllLossBackward>)
tensor(1.3801, grad_fn=<NllLossBackward>)
tensor(1.2443, grad_fn=<NllLossBackward>)
tensor(1.3830, grad_fn=<NllLossBackward>)
tensor(1.0788, grad_fn=<NllLossBackward>)
tensor(1.3048, grad_fn=<NllLossBackward>)
tensor(1.0057, grad_fn=<NllLossBackward>)
tensor(1.0927, grad_fn=<NllLossBackward>)
tensor(0.8509, grad_fn=<NllLossBackward>)
tensor(0.9573, grad_fn=<NllLossBackward>)
tensor(1.0834, grad_fn=<NllLossBackward>)
tensor(0.7541, grad_fn=<NllLossBackward>)
tensor(0.9077, grad_fn=<NllLossBackward>)
tensor(0.5362, grad_fn=<NllLossBackward>)
tensor(0.6874, grad_fn=<NllLossBackward>)
tensor(0.8043, grad_fn=<NllLossBackward>)
tensor(1.0622, grad_fn=<NllLossBackward>)
tensor(0.7217, grad_fn=<NllLossBackward>)
