<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#New-class" data-toc-modified-id="New-class-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>New class</a></span><ul class="toc-item"><li><span><a href="#config" data-toc-modified-id="config-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>config</a></span></li><li><span><a href="#Dataset" data-toc-modified-id="Dataset-1.2"><span class="toc-item-num">1.2&nbsp;&nbsp;</span>Dataset</a></span></li><li><span><a href="#Definitions" data-toc-modified-id="Definitions-1.3"><span class="toc-item-num">1.3&nbsp;&nbsp;</span>Definitions</a></span></li></ul></li><li><span><a href="#Old-runs" data-toc-modified-id="Old-runs-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Old runs</a></span></li></ul></div>

In [1]:
import joblib
import torch
import torch.nn as nn
import transformers

import numpy as np
import pandas as pd

from sklearn import preprocessing
from sklearn import model_selection

from tqdm import tqdm
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup
from transformers import AutoTokenizer, AutoModel

from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, NamedTuple

import dask.dataframe as dd
import dask.array as da
from dask_ml.model_selection import train_test_split
# from sklearn.model_selection import train_test_split

from copy import deepcopy
import spacy
from spacy.tokenizer import Tokenizer
from spacy.lang.en import English
from pathlib import Path
from spacy.gold import docs_to_json, biluo_tags_from_offsets, spans_from_biluo_tags, iob_to_biluo

In [2]:
from spacy.tokens.span import Span as SpacySpan
def combine_overlapping(segments: List[Tuple[int, int]]) -> Iterable[Tuple[int, int]]:
    segments = sorted(set(segments))
    combined = None
    for segment in segments:
        if combined is None:
            combined = segment
        elif combined[0] <= segment[0] <= combined[1] + 1:
            combined = (combined[0], max(segment[1], combined[1]))
        else:
            yield combined
            combined = segment
    if combined is not None:
        yield combined


def maximal_spans(spans: List[SpacySpan]) -> List[SpacySpan]:
    if not spans or len(spans) < 1 :
        return []
    document = spans[0].doc
    segments = [(span.start_char, span.end_char) for span in spans]
    maximal = [document.char_span(start_char, end_char) for start_char, end_char in combine_overlapping(segments)]
    return sorted(maximal)


In [3]:
def biluo_to_iob(tags):
    out = []
    for tag in tags:
        if tag is None:
            out.append(tag)
        else:
            tag = tag.replace("U-", "B-", 1).replace("L-", "I-", 1)
            out.append(tag)
    return out

# New class
## config

In [4]:
class AnnotatedText(NamedTuple):
    text: str
    start: int
    end: int
class AnnotatedTextLabeled(NamedTuple):
    text: str
    start: int
    end: int
    label: str

In [5]:
config = dict(
    max_len = 128,
    train_batch_size = 32,
    valid_batch_size = 8,
    iterations = 3,
    model_name = "bert-base-uncased",
    training_file = "/src/ner_dataset.csv",
    path_to_model = './model.bin'
)

## Dataset

In [None]:
df=pd.read_csv(config.get('training_file'), encoding="latin-1")
df.loc[:, "Sentence #"]=df["Sentence #"].fillna(method="ffill")
df=df.groupby("Sentence #").agg({"Tag": list, "Word": lambda x: ' '.join(x)}).reset_index()
df['Tag']=df['Tag'].apply(iob_to_biluo)#.apply(lambda y: list(map(lambda x: x.upper(), y)))
nlp=spacy.load('en_core_web_sm')
df['Tags']=df[['Word', 'Tag']].apply(lambda x: spans_from_biluo_tags(nlp(x['Word']), x['Tag']), axis=1)

In [31]:
df['AnnotationLabeled']=df['Tags'].apply(lambda y: list(map(lambda x: AnnotatedTextLabeled(x.text,
                                                                                 x.start_char,
                                                                                 x.end_char,
                                                                                 x.label_
                                                                                ), y)))
df['Annotation']=df['Tags'].apply(lambda y: list(map(lambda x: AnnotatedText(x.text,
                                                                                 x.start_char,
                                                                                 x.end_char,
                                                                                ), y)))


In [32]:
# df.drop(['Tag', 'Tags'], axis=1, inplace=True, errors='ignore')
df.shape

(47959, 6)

## Definitions

In [48]:
def group_sub_entities(entities: List[dict], tokenizer) -> dict:
    """
    Group together the adjacent tokens with the same entity predicted.
    Args:
        entities (:obj:`dict`): The entities predicted by the pipeline.
    """
    # Get the first entity in the entity group
    entity = entities[0]["entity"]
#     scores = np.mean([entity["score"] for entity in entities])
    tokens = [entity["word"] for entity in entities]

    entity_group = {
        "entity_group": entity,
#         "score": np.mean(scores),
        "word": tokenizer.convert_tokens_to_string(tokens),
    }
    return entity_group

def group_entities(entities: List[dict], tokenizer) -> List[dict]:
    """
    Find and group together the adjacent tokens with the same entity predicted.
    Args:
        entities (:obj:`dict`): The entities predicted by the pipeline.
    """

    entity_groups = []
    entity_group_disagg = []

    if entities:
        last_idx = entities[-1]["index"]

    for entity in entities:
        is_last_idx = entity["index"] == last_idx
        if not entity_group_disagg:
            entity_group_disagg += [entity]
            if is_last_idx:
                entity_groups += [group_sub_entities(entity_group_disagg, tokenizer)]
            continue

        # If the current entity is similar and adjacent to the previous entity, append it to the disaggregated entity group
        # The split is meant to account for the "B" and "I" suffixes
        if (
            entity["entity"].split("-")[-1] == entity_group_disagg[-1]["entity"].split("-")[-1]
            and entity["index"] == entity_group_disagg[-1]["index"] + 1
        ):
            entity_group_disagg += [entity]
            # Group the entities at the last entity
            if is_last_idx:
                entity_groups += [group_sub_entities(entity_group_disagg, tokenizer)]
        # If the current entity is different from the previous entity, aggregate the disaggregated entity group
        else:
            entity_groups += [group_sub_entities(entity_group_disagg, tokenizer)]
            entity_group_disagg = [entity]
            # If it's the last entity, add it to the entity groups
            if is_last_idx:
                entity_groups += [group_sub_entities(entity_group_disagg, tokenizer)]

    return entity_groups

class EntityLabel(NamedTuple):
    start: int
    end: int
    label: Union[str, int]

class EntityDataset:
    """
    This class is used with torch.utils.data.DataLoader. This class accepts dask dataframe as input.
    pytorch handles when to ask for an item from this class and which item to fetch
    """

    def __init__(self,
                 ddf: dd,
                 text_column: str,
                 label_column: str,
                 max_len: int,
                 tokenizer,
                 predict_mode=False
                 ):
        """

        :param ddf: Dask dataframe
        :param text_column: name of the column that contains the text to classify
        :param label_column: name of the column that contains labels
                             entries inside label_column need to be of format `EntityLabel`
                             if there are more than one type of entities then `label` of
                             each tuple should be int but if only single entity then the
                             label really doesn't matter
        :param max_len: maximum length of a sentence to do padding or truncating
        """
        self.text_column = text_column
        self.label_column = label_column
        self.max_len = max_len
        self.tokenizer = tokenizer
        self.predict_mode = predict_mode
        self.texts=ddf[self.text_column].compute().values
        self.labels=ddf[self.label_column].compute().values  
        self.size = len(self.texts)
    
    def __len__(self):
        return self.size

    def __getitem__(self, item):
        # get row number `item` from dask dataframe
        text = self.texts[item]#, self.text_column].compute()
        tags = self.labels[item]#, self.label_column].compute()
        ids = []
        target_tag = []

        for i, s in enumerate(text):
            inputs = self.tokenizer.encode(
                s,
                add_special_tokens=False
            )
            input_len = len(inputs)
            ids.extend(inputs)
            target_tag.extend([tags[i]] * input_len)

        ids = ids[:self.max_len - 2]
        target_tag = target_tag[:self.max_len - 2]

        ids = [101] + ids + [102]
        target_tag = [0] + target_tag + [0]

        mask = [1] * len(ids)
        token_type_ids = [0] * len(ids)

        padding_len = self.max_len - len(ids)

        ids = ids + ([0] * padding_len)
        mask = mask + ([0] * padding_len)
        token_type_ids = token_type_ids + ([0] * padding_len)
        target_tag = target_tag + ([0] * padding_len)

        return {
            "ids": torch.tensor(ids, dtype=torch.long),
            "mask": torch.tensor(mask, dtype=torch.long),
            "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
            "target_tag": torch.tensor(target_tag, dtype=torch.long),
        }

class EntityModel(nn.Module):
    def __init__(self,
                 num_tag: int,
                 model_name: str
                 ):
        super(EntityModel, self).__init__()
        self.num_tag = num_tag
        self.bert = AutoModel.from_pretrained(model_name)
        self.bert_drop_1 = nn.Dropout(0.3)
        self.out_tag = nn.Linear(768, self.num_tag)
        
    @staticmethod
    def _loss_fn(output, target, mask, num_labels):
        lfn = nn.CrossEntropyLoss()
        active_loss = mask.view(-1) == 1
        active_logits = output.view(-1, num_labels)
        active_labels = torch.where(
            active_loss,
            target.view(-1),
            torch.tensor(lfn.ignore_index).type_as(target)
        )
        loss = lfn(active_logits, active_labels)
        return loss
    
    def forward(
            self,
            ids,
            mask,
            token_type_ids,
            target_tag
    ):
        o1, _ = self.bert(
            ids,
            attention_mask=mask,
            token_type_ids=token_type_ids
        )

        bo_tag = self.bert_drop_1(o1)
        tag = self.out_tag(bo_tag)

        loss_tag = self._loss_fn(tag, target_tag, mask, self.num_tag)

        return tag, loss_tag
    
def get_tokens(sentence):
    """return tokens of sentence"""
    nlp = English()
    tokenizer = Tokenizer(nlp.vocab)
    return [token.text for token in tokenizer(sentence)]

class ModelRunner():
    def __init__(self):
        self.num_ents = None
    
    @staticmethod
    def _preprocess_data(ddf, text_column, label_column):
        def add_label(row):
            nlp = English()
            text = row[text_column]
            labels = row[label_column]
            this_doc = nlp(text)
            result = list()
            if isinstance(labels, list):
                for label in labels:
                    if getattr(label, 'label', None):
                        ner = label.label
                    else:
                        ner = 'this'
                    spans = list(filter(lambda x: x is not None,
                                        [this_doc.char_span(int(label.start), int(label.end), label=ner)]))
                    entities = [(span.start_char, span.end_char, ner) for span in maximal_spans(spans)]
                    result.extend(entities)
            return result

        def span_helper(row):
            "convert spact entity to IOB entity for BERT"
            doc = nlp(row[text_column])
            tags = biluo_to_iob(biluo_tags_from_offsets(doc, row['label1']))
            return tags

        ddf = ddf.assign(label1 = ddf.apply(add_label, axis=1, meta=list))
        ddf = ddf.assign(label2 = ddf.apply(span_helper, axis=1, meta=list))
        ddf = ddf.assign(text1 = ddf[text_column].apply(get_tokens, meta=list))
        return ddf
    
    @staticmethod
    def _encode_entities(ddf, label_column, meta_data):
        def infer_labels(row):
            unique_ents = set(row[label_column])
            return unique_ents

        def encode_labels(row, enc_tag):
            labels = enc_tag.transform(row[label_column])
            return labels

        enc_tag = preprocessing.LabelEncoder()
        # 1. collect all unique entities from each label in every sentence
        res = ddf.apply(infer_labels, axis=1, meta=pd.Series())
        unique_ents = set.union(*res.compute())
        unique_ents = list(unique_ents)

        # 2. fit a LabelEncoder to entities for neural network to receive
        # integer labelled classes
        _ = enc_tag.fit(unique_ents)
        ddf = ddf.assign(label3 = ddf.apply(encode_labels, args=(enc_tag,), axis=1, meta=list))
        num_ents = len(enc_tag.classes_)
        meta_data.update({
            "enc_tag": enc_tag,
            "num_ent": num_ents
        })
        return ddf

    
    
    @staticmethod
    def _train_fn(data_loader, model, optimizer, device, scheduler):
        model.train()
        final_loss = 0
        for data in tqdm(data_loader, total=len(data_loader)):
            for k, v in data.items():
                data[k] = v.to(device)
            optimizer.zero_grad()
            _, loss = model(**data)
            loss.backward()
            optimizer.step()
            scheduler.step()
            final_loss += loss.item()
        return final_loss / len(data_loader)

    @staticmethod
    def _val_fn(data_loader, model, device):
        model.eval()
        final_loss = 0
        for data in tqdm(data_loader, total=len(data_loader)):
            for k, v in data.items():
                data[k] = v.to(device)
            _, loss = model(**data)
            final_loss += loss.item()
        return final_loss / len(data_loader)

    def _create_datasets(self, ddf, text_column, label_column, max_len, tokenizer, 
                         train_batch_size, valid_batch_size):
        train_ddf, valid_ddf = train_test_split(ddf, random_state=42, test_size=0.1)

        train_dataset = EntityDataset(
            train_ddf,
            text_column=text_column,
            label_column=label_column,
            max_len=max_len,
            tokenizer=tokenizer
        )
        valid_dataset = EntityDataset(
            valid_ddf,
            text_column=text_column,
            label_column=label_column,
            max_len=max_len,
            tokenizer=tokenizer
        )
        train_data_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=train_batch_size, num_workers=4
        )
        valid_data_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=valid_batch_size, num_workers=1
        )
        return train_ddf, valid_ddf, train_data_loader, valid_data_loader

    def train(
            self,
            ddf: dd,
            text_column: str,
            label_column: str,
            path_to_model: str,
            iterations: int = 3,
            config: dict = None
    ):
        """
        :param ddf: Dask dataframe
        :param text_column: name of the column that contains the text to classify
        :param label_column: name of the column that contains labels
        :param config: config for BERT NER
        """
        # Initialization
        if config is None:
            config = {}
        max_len = config.get('max_len', 128)
        model_name = config.get('model_name', 'bert-base-uncased')
        train_batch_size = config.get('train_batch_size', 32)
        valid_batch_size = config.get('valid_batch_size', 8)
        path_to_model = Path(path_to_model)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        meta_data = deepcopy(config)
        meta_data.update({
            'text_column': text_column,
            'label_column': label_column
        })
        # Build data
        ddf = self._preprocess_data(ddf, text_column, label_column)
        ddf = ddf.persist()
        ddf = self._encode_entities(ddf, 'label2', meta_data=meta_data)
        ddf = ddf.persist()
        train_ddf, valid_ddf, train_data_loader, valid_data_loader = self._create_datasets(ddf, 'text1',
                                                                                           'label3', max_len,
                                                                                           tokenizer,
                                                                                           train_batch_size,
                                                                                           valid_batch_size)
        # Call Model
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = EntityModel(num_tag=meta_data.get('num_ent'), model_name=model_name)
        model.to(device)
        param_optimizer = list(model.named_parameters())
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_parameters = [
            {
                "params": [
                    p for n, p in param_optimizer if not any(
                        nd in n for nd in no_decay
                    )
                ],
                "weight_decay": 0.001,
            },
            {
                "params": [
                    p for n, p in param_optimizer if any(
                        nd in n for nd in no_decay
                    )
                ],
                "weight_decay": 0.0,
            }
        ]
        num_train_steps = int(
            len(train_ddf) / train_batch_size * iterations
        )
        optimizer = AdamW(optimizer_parameters, lr=3e-5)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=0,
            num_training_steps=num_train_steps
        )
        best_loss = np.inf
        for epoch in range(iterations):
            train_loss = self._train_fn(
                train_data_loader,
                model,
                optimizer,
                device,
                scheduler
            )
            test_loss = self._val_fn(
                valid_data_loader,
                model,
                device
            )
            print(f"Train Loss = {train_loss} Valid Loss = {test_loss}")
            if test_loss < best_loss:
                torch.save(model.state_dict(), path_to_model)
                best_loss = test_loss
        joblib.dump(meta_data, f"{path_to_model.parent}/meta.bin")

    @classmethod
    def run(cls,
            text,
            path_to_model
            ):
        path_to_model = Path(path_to_model)
        meta_data = joblib.load(f"{path_to_model.parent}/meta.bin")
        enc_tag = meta_data["enc_tag"]
        model_name = meta_data['model_name']
        num_ent = meta_data['num_ent']
        max_len = meta_data['max_len']
        text_column = meta_data.get('text_column', 'Word')
        label_column = meta_data.get('label_column', 'ner')

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        num_tag = len(list(enc_tag.classes_))

        tokenized_sentence = tokenizer.encode(text)
        decoded_tokenized_sentence = tokenizer.decode(tokenized_sentence)

        print(tokenized_sentence)
        print(decoded_tokenized_sentence)

        d = {text_column: [text], label_column: [None]}
        df = pd.DataFrame(data=d)
        ddf = dd.from_pandas(df, npartitions=20)
        ddf = cls._preprocess_data(ddf, text_column, label_column)
        ddf = cls._encode_entities(ddf, 'label2', meta_data=meta_data)
        test_dataset = EntityDataset(
            ddf,
            text_column='text1',
            label_column='label3',
            max_len=max_len,
            tokenizer=tokenizer,
            predict_mode=True
            )

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#         device = torch.device('cpu')
        model = EntityModel(num_tag=num_ent, model_name=model_name)
        model.load_state_dict(torch.load(path_to_model))
        model.to(device)

        with torch.no_grad():
            data = test_dataset[0]
            for k, v in data.items():
                data[k] = v.to(device).unsqueeze(0)
            tag, _ = model(**data)
            
            enc_tag_dict = dict(zip(enc_tag.transform(enc_tag.classes_), enc_tag.classes_,))
            enc_tag_dict.update({
                num_ent: 'O'
            })
            #
            tags = list(map(lambda x: enc_tag_dict.get(x),
                           tag.argmax(2).cpu().numpy().reshape(-1)[:len(tokenized_sentence)]))
        
        entities = []
        ignore_labels = ['O']
        print(tags)
        for ind, (tok, ent) in enumerate(zip(tokenized_sentence, tags[:len(tokenized_sentence)])):
            if ent not in ignore_labels:
                entity = {
                            "word": tokenizer.convert_ids_to_tokens(tok),
                            "entity": ent,
                            "index": ind,
                        }

                entities += [entity]
        return group_entities(entities, tokenizer)

In [43]:
try:
    del ddf
except:
    pass
ddf = dd.from_pandas(df[:500], npartitions=20)

In [44]:
runner = ModelRunner()
runner.train(ddf, 'Word', 'AnnotationLabeled', config.get('path_to_model'), config.get('iterations'), config)

100%|██████████| 14/14 [04:36<00:00, 19.77s/it]
100%|██████████| 9/9 [00:21<00:00,  2.42s/it]


Train Loss = 1.6947388734136308 Valid Loss = 1.0927788946363661


100%|██████████| 14/14 [00:08<00:00,  1.61it/s]
100%|██████████| 9/9 [00:00<00:00, 10.13it/s]


Train Loss = 0.9419715659958976 Valid Loss = 0.843152317735884


100%|██████████| 14/14 [00:08<00:00,  1.63it/s]
100%|██████████| 9/9 [00:01<00:00,  8.42it/s]


Train Loss = 0.75580923472132 Valid Loss = 0.7456969453228844


In [49]:
ModelRunner.run(df['Word'].values[-2], config.get('path_to_model'))

[101, 2144, 2059, 1010, 4614, 2031, 2218, 2270, 7012, 1997, 1996, 5496, 1998, 2699, 2000, 14785, 4697, 8777, 4584, 2306, 1996, 2231, 1012, 102]
[CLS] since then, authorities have held public trials of the accused and tried to marginalize moderate officials within the government. [SEP]




['B-art', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-art']


[{'entity_group': 'B-art', 'word': '[CLS]'},
 {'entity_group': 'B-art', 'word': '[SEP]'}]

In [None]:
# text = """
# The 2018–19 influenza season was a moderate severity season with two waves of influenza A 
# activity of similar magnitude during the season: A(H1N1)pdm09 predominated from October 2018 
# to mid-February 2019, and A(H3N2) activity increased from mid-February through mid-May.
# """
text = """Mr. Trump’s tweets began just moments after a Fox News report 
by Mike Tobin, a reporter for the network, about protests in Minnesota and elsewhere."""
       
         
ModelRunner.run(text, config.get('path_to_model'))