In [1]:
import ndjson

# Newline-delimited JSON file holding one record per line.
TRAINING_FILE = "/home/yves/projects/Quill-NLP-Tools-and-Datasets/notw.ndjson"
# Passed to the tokenizer as max_length; every input is padded to this length.
MAX_SEQ_LENGTH = 100
# Number of records kept from the file (before any train/dev/test split).
TRAIN_SIZE = 10000
# Number of items held out for the test set (and, separately, the dev set).
TEST_SIZE = 500
# Batch size used for the data loaders and handed to the training routine.
BATCH_SIZE = 8


# Read the whole file, then keep only the first TRAIN_SIZE records.
with open(TRAINING_FILE) as training_handle:
    data = ndjson.load(training_handle)[:TRAIN_SIZE]

# Reduce every record to the two fields used downstream: the sentence text
# (the synthetic sentence when present, otherwise the original one) and its
# list of entities (empty when the record has none).
data = [
    {
        "text": record.get("synth_sentence", record.get("orig_sentence")),
        "entities": record.get("entities", []),
    }
    for record in data
]


In [2]:
# Map every entity label to an integer id, in order of first appearance.
# "O" (no entity) is always id 0.
label2idx = {"O": 0}

for sentence in data:
    # Each entity is a 3-item sequence; only its last item (the label) matters here.
    for _start, _end, label in sentence.get("entities", ()):
        # setdefault leaves an already-known label untouched and otherwise
        # assigns the next free id.
        label2idx.setdefault(label, len(label2idx))

print(label2idx)
        

{'O': 0, 'POSSESSIVE': 1, 'VERB': 2, 'ADV': 3, 'WOMAN': 4, 'ITS': 5, 'THEN': 6, 'CHILD': 7}


In [3]:
from typing import List

class BertInputItem:
    """Plain container for one preprocessed example used to fine-tune a BERT model.

    Attributes:
        text: the original input text
        input_ids: the ids of the input tokens
        input_mask: one integer per token indicating which tokens should be
            attended to and which are masked
        segment_ids: one segment id per token
        label_ids: one label id per token
    """

    def __init__(
        self,
        text: str,
        input_ids: List[int],
        input_mask: List[int],
        segment_ids: List[int],
        label_ids: List[int],
    ):
        # The values are stored as given; no copying or validation is done.
        self.text = text
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_ids = label_ids
        

def _label_id_at(offset, entities, label2idx):
    """Return the label id of the first entity whose range contains `offset`, else the "O" id."""
    for entity in entities:
        # entity is (start, end, label); both bounds are treated as inclusive.
        if entity[0] <= offset <= entity[1]:
            return label2idx[entity[2]]
    return label2idx["O"]


def preprocess_sequence_labelling(examples, label2idx, max_seq_length, tokenizer):
    """Convert raw examples into BertInputItems with one label id per token.

    Args:
        examples: an iterable of dicts with a "text" entry and, optionally, an
            "entities" entry holding (start, end, label) sequences.
        label2idx: a dict mapping label names to ids; must contain "O" and
            every label that occurs in the entities.
        max_seq_length: passed to the tokenizer as max_length; inputs are
            padded to this length.
        tokenizer: an object providing `encode_plus` and `convert_ids_to_tokens`.

    Returns: a list with one BertInputItem per example.

    Raises:
        AssertionError: if the number of labels differs from the number of
            input ids.

    Notes:
        Exactly one label is produced per token. When several entities cover
        the same token, the first one listed wins (previously every match
        appended a label, which made the length check fail).

        Token positions are estimated rather than read from the text: the
        first token after the leading one starts at offset 0, every later
        token that does not start with "##" is assumed to be preceded by
        exactly one character, and each token advances the offset by its own
        length (minus the "##" prefix). All tokens after the first are handled
        this way, including any trailing special or padding tokens.
        NOTE(review): this presumably drifts for text where tokens are not
        separated by single spaces (e.g. punctuation) - confirm against the data.
    """
    input_items = []
    for ex in examples:

        # Create a list of token ids
        toks = tokenizer.encode_plus(ex["text"], max_length=max_seq_length, pad_to_max_length=True)
        input_ids = toks["input_ids"]
        segment_ids = toks["token_type_ids"]
        input_mask = toks["attention_mask"]

        if "entities" not in ex:
            labels = [label2idx["O"]] * len(input_ids)
        else:
            tokens = tokenizer.convert_ids_to_tokens(input_ids)

            # The first token always gets "O"; offsets start with the second token.
            labels = [label2idx["O"]]
            cur_index = 0
            for num, tok in enumerate(tokens[1:]):
                is_continuation = tok.startswith("##")

                # Account for the single separating character assumed between
                # two tokens that are not pieces of the same word.
                if num > 0 and not is_continuation:
                    cur_index += 1

                labels.append(_label_id_at(cur_index, ex["entities"], label2idx))

                # Advance past this token; the "##" prefix is not part of the text.
                cur_index += len(tok) - 2 if is_continuation else len(tok)

        assert len(labels) == len(input_ids)

        input_items.append(
            BertInputItem(text=ex["text"],
                          input_ids=input_ids,
                          input_mask=input_mask,
                          segment_ids=segment_ids,
                          label_ids=labels))
    return input_items

In [4]:
from transformers import BertForTokenClassification
from transformers import BertTokenizer

# Load the cased BERT vocabulary and turn every example into a BertInputItem.
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
input_items = preprocess_sequence_labelling(
    examples=data,
    label2idx=label2idx,
    max_seq_length=MAX_SEQ_LENGTH,
    tokenizer=tokenizer,
)



I0316 20:47:21.745091 140454718867264 file_utils.py:41] PyTorch version 1.2.0+cu92 available.
I0316 20:47:22.655247 140454718867264 file_utils.py:57] TensorFlow version 2.1.0 available.
I0316 20:47:23.325318 140454718867264 tokenization_utils.py:501] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt from cache at /home/yves/.cache/torch/transformers/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1


In [None]:
import torch
import numpy as np
from torch.utils.data import TensorDataset, DataLoader

def get_data_loader(input_items: List[BertInputItem], batch_size: int, shuffle: bool=True) -> DataLoader:
    """
    Builds a DataLoader over a list of BERT input items.

    Each batch yields four long tensors, in this order: input ids, input
    mask, segment ids and label ids.

    Args:
        input_items: a list of BERT input items
        batch_size: the number of items per batch
        shuffle: whether the DataLoader should shuffle the items

    Returns: a DataLoader for the input items

    """
    def as_long_tensor(field: str) -> torch.Tensor:
        # Stack one attribute of every item into a single long tensor.
        return torch.tensor([getattr(item, field) for item in input_items], dtype=torch.long)

    fields = ("input_ids", "input_mask", "segment_ids", "label_ids")
    dataset = TensorDataset(*[as_long_tensor(field) for field in fields])

    return DataLoader(dataset, shuffle=shuffle, batch_size=batch_size)

In [None]:
import random

# Fixed seed so that the train/dev/test partition is the same on every run.
SPLIT_SEED = 42

# The slices below need more than 2 * TEST_SIZE items, otherwise the training
# set would be empty.
assert len(input_items) > 2 * TEST_SIZE, "not enough items for a train/dev/test split"

# Shuffle a copy with a dedicated, seeded generator: input_items itself keeps
# its order, so re-running this cell always produces the same split.
shuffled_items = random.Random(SPLIT_SEED).sample(input_items, len(input_items))

# Last TEST_SIZE items -> test, the TEST_SIZE items before those -> dev,
# everything else -> train.
test_items = shuffled_items[-TEST_SIZE:]
valid_items = shuffled_items[-2*TEST_SIZE:-TEST_SIZE]
train_items = shuffled_items[:-2*TEST_SIZE]

# Only the training loader shuffles; dev and test keep a fixed order so their
# outputs can be matched back to valid_items / test_items.
test_dl = get_data_loader(test_items, BATCH_SIZE, shuffle=False)
dev_dl = get_data_loader(valid_items, BATCH_SIZE, shuffle=False)
train_dl = get_data_loader(train_items, BATCH_SIZE, shuffle=True)


In [None]:
import sys
sys.path.append('../')

from quillnlp.models.bert.train import train
from transformers import BertModel

# Token-classification head on top of cased BERT, one output per label,
# moved to the GPU.
model = BertForTokenClassification.from_pretrained(
    "bert-base-cased",
    num_labels=len(label2idx),
).to("cuda")

# NOTE(review): the fifth argument (32/BATCH_SIZE, a float) is presumably the
# number of gradient accumulation steps - confirm against train().
train(
    model,
    train_dl,
    dev_dl,
    BATCH_SIZE,
    32/BATCH_SIZE,
    device="cuda",
)

I0316 20:47:26.952414 140454718867264 configuration_utils.py:256] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json from cache at /home/yves/.cache/torch/transformers/b945b69218e98b3e2c95acf911789741307dec43c698d35fad11c1ae28bda352.3d5adf10d3445c36ce131f4c6416aa62e9b58e1af56b97664773f4858a46286e
I0316 20:47:26.953918 140454718867264 configuration_utils.py:292] Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": null,
  "do_sample": false,
  "eos_token_ids": null,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "layer_norm_eps": 1e-12,
  "length_penalty": 1.0,
  "max_length": 20,
  "max_position_embeddin

HBox(children=(IntProgress(value=0, description='Training iteration', max=1125, style=ProgressStyle(descriptio…




HBox(children=(IntProgress(value=0, description='Evaluation iteration', max=63, style=ProgressStyle(descriptio…


Loss history: []
Dev loss: 1.241124156921629


Epoch:   5%|▌         | 1/20 [02:06<40:07, 126.70s/it]

HBox(children=(IntProgress(value=0, description='Training iteration', max=1125, style=ProgressStyle(descriptio…




HBox(children=(IntProgress(value=0, description='Evaluation iteration', max=63, style=ProgressStyle(descriptio…


Loss history: [1.241124156921629]
Dev loss: 1.0270915882928031


Epoch:  10%|█         | 2/20 [04:14<38:04, 126.92s/it]

HBox(children=(IntProgress(value=0, description='Training iteration', max=1125, style=ProgressStyle(descriptio…




HBox(children=(IntProgress(value=0, description='Evaluation iteration', max=63, style=ProgressStyle(descriptio…


Loss history: [1.241124156921629, 1.0270915882928031]
Dev loss: 0.8729870442360167


Epoch:  15%|█▌        | 3/20 [06:20<35:56, 126.87s/it]

HBox(children=(IntProgress(value=0, description='Training iteration', max=1125, style=ProgressStyle(descriptio…

In [None]:
from quillnlp.models.bert.train import evaluate

# Evaluation runs on the CPU.
device = "cpu"
output_model_file = "/tmp/model.bin"
print("Loading model from", output_model_file)

# Deserialize the saved weights onto the CPU, then rebuild the classifier
# around them.
model_state_dict = torch.load(output_model_file, map_location="cpu")
model = BertForTokenClassification.from_pretrained(
    "bert-base-cased",
    state_dict=model_state_dict,
    num_labels=len(label2idx),
).to(device)

# evaluate() returns four values; only the last two are kept here.
_, _, test_correct, test_predicted = evaluate(model, test_dl, device)

In [None]:
# Inverse of label2idx: label id -> label name.
idx2label = dict((idx, label) for label, idx in label2idx.items())

# For every test item, print its text followed by the distinct labels that
# were predicted and the distinct labels that were expected.
for item, gold_ids, predicted_ids in zip(test_items, test_correct, test_predicted):
    print(item.text)
    for label_id in set(predicted_ids):
        print("Found:", idx2label[label_id])
    for label_id in set(gold_ids):
        print("Correct:", idx2label[label_id])
    