# Named Entity Recognition using BERT

## Import everything important

In [1]:
import joblib
import torch
import torch.nn as nn
from torch.utils import data

import numpy as np
import pandas as pd

from sklearn import preprocessing
from sklearn import model_selection

from tqdm import tqdm
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup

  from .autonotebook import tqdm as notebook_tqdm


## Some config

In [2]:
MAX_LEN = 128
TRAIN_BATCH_SIZE = 32
VALID_BATCH_SIZE = 8
EPOCHS = 1
BASE_MODEL_PATH = "bert-base-uncased"
MODEL_PATH = "model.bin"
TRAINING_FILE = "ner_dataset.csv"
TOKENIZER = BertTokenizer.from_pretrained(
    BASE_MODEL_PATH,
    do_lower_case=True
)

## Dataset

In [3]:
class EntityDataset(data.Dataset):
    def __init__(self, texts, pos, tags):
        self.texts = texts
        self.pos = pos
        self.tags = tags
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, item):
        text = self.texts[item]
        pos = self.pos[item]
        tags = self.tags[item]

        ids = []
        target_pos = []
        target_tag = []

        for i, s in enumerate(text):
            inputs = TOKENIZER.encode(
                s,
                add_special_tokens=False
            )
            input_len = len(inputs)
            ids.extend(inputs)
            target_pos.extend([pos[i]] * input_len)
            target_tag.extend([tags[i]] * input_len)

        ids = ids[:MAX_LEN - 2]
        target_pos = target_pos[:MAX_LEN - 2]
        target_tag = target_tag[:MAX_LEN - 2]

        ids = [101] + ids + [102]
        target_pos = [0] + target_pos + [0]
        target_tag = [0] + target_tag + [0]

        mask = [1] * len(ids)
        token_type_ids = [0] * len(ids)

        padding_len = MAX_LEN - len(ids)

        ids += ([0] * padding_len)
        mask += ([0] * padding_len)
        token_type_ids += ([0] * padding_len)
        target_pos += ([0] * padding_len)
        target_tag += ([0] * padding_len)

        return {
            "ids": torch.tensor(ids, dtype=torch.long),
            "mask": torch.tensor(mask, dtype=torch.long),
            "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
            "target_pos": torch.tensor(target_pos, dtype=torch.long),
            "target_tag": torch.tensor(target_tag, dtype=torch.long),
        }

## Training and evaluation functions

In [4]:
def train_fn(data_loader, model, optimizer, device, scheduler):
    model.train()
    final_loss = 0
    for data in tqdm(data_loader, total=len(data_loader)):
        for k, v in data.items():
            data[k] = v.to(device)
        optimizer.zero_grad()
        outputs = model(**data)
        loss = outputs[2] if len(outputs) > 2 else outputs[1]  # Ensure we get the correct loss
        loss.backward()
        optimizer.step()
        scheduler.step()
        final_loss += loss.item()
    return final_loss / len(data_loader)


def eval_fn(data_loader, model, device):
    model.eval()
    final_loss = 0
    with torch.no_grad():  # Disable gradient calculation for evaluation
        for data in tqdm(data_loader, total=len(data_loader)):
            for k, v in data.items():
                data[k] = v.to(device)
            outputs = model(**data)
            loss = outputs[2] if len(outputs) > 2 else outputs[1]  # Ensure we get the correct loss
            final_loss += loss.item()
    return final_loss / len(data_loader)

## Loss function and model

In [5]:
def loss_fn(output, target, mask, num_labels):
    lfn = nn.CrossEntropyLoss()
    active_loss = mask.view(-1) == 1
    active_logits = output.view(-1, num_labels)
    active_labels = torch.where(
        active_loss,
        target.view(-1),
        torch.tensor(lfn.ignore_index).type_as(target)
    )
    loss = lfn(active_logits, active_labels)
    return loss


class EntityModel(nn.Module):
    def __init__(self, num_tag, num_pos):
        super(EntityModel, self).__init__()
        self.num_tag = num_tag
        self.num_pos = num_pos
        self.bert = BertModel.from_pretrained(
            BASE_MODEL_PATH
        )
        self.bert_drop_1 = nn.Dropout(0.3)
        self.bert_drop_2 = nn.Dropout(0.3)
        self.out_tag = nn.Linear(768, self.num_tag)
        self.out_pos = nn.Linear(768, self.num_pos)
    
    def forward(
        self, 
        ids, 
        mask, 
        token_type_ids, 
        target_pos, 
        target_tag
    ):
        o1, _ = self.bert(
            ids, 
            attention_mask=mask, 
            token_type_ids=token_type_ids
        )

        bo_tag = self.bert_drop_1(o1)
        bo_pos = self.bert_drop_2(o1)

        tag = self.out_tag(bo_tag)
        pos = self.out_pos(bo_pos)

        loss_tag = loss_fn(tag, target_tag, mask, self.num_tag)
        loss_pos = loss_fn(pos, target_pos, mask, self.num_pos)

        loss = (loss_tag + loss_pos) / 2

        return tag, pos, loss

## Data processing

In [6]:
def process_data(data_path):
    df = pd.read_csv(data_path, encoding="latin-1")
    df.loc[:, "Sentence #"] = df["Sentence #"].fillna(method="ffill")

    enc_pos = preprocessing.LabelEncoder()
    enc_tag = preprocessing.LabelEncoder()

    df.loc[:, "POS"] = enc_pos.fit_transform(df["POS"])
    df.loc[:, "Tag"] = enc_tag.fit_transform(df["Tag"])

    sentences = df.groupby("Sentence #")["Word"].apply(list).values
    pos = df.groupby("Sentence #")["POS"].apply(list).values
    tag = df.groupby("Sentence #")["Tag"].apply(list).values
    return sentences, pos, tag, enc_pos, enc_tag

## Training

In [7]:
sentences, pos, tag, enc_pos, enc_tag = process_data(TRAINING_FILE)

meta_data = {
    "enc_pos": enc_pos,
    "enc_tag": enc_tag
}

joblib.dump(meta_data, "meta.bin")

num_pos = len(list(enc_pos.classes_))
num_tag = len(list(enc_tag.classes_))

(
    train_sentences,
    test_sentences,
    train_pos,
    test_pos,
    train_tag,
    test_tag
) = model_selection.train_test_split(
    sentences, 
    pos, 
    tag, 
    random_state=42, 
    test_size=0.1
)

train_dataset = EntityDataset(
    texts=train_sentences, pos=train_pos, tags=train_tag
)

train_data_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=TRAIN_BATCH_SIZE, num_workers=1
)

valid_dataset = EntityDataset(
    texts=test_sentences, pos=test_pos, tags=test_tag
)

valid_data_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=VALID_BATCH_SIZE, num_workers=1
)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = EntityModel(num_tag=num_tag, num_pos=num_pos)
model.to(device)

param_optimizer = list(model.named_parameters())
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
optimizer_parameters = [
    {
        "params": [
            p for n, p in param_optimizer if not any(
                nd in n for nd in no_decay
            )
        ],
        "weight_decay": 0.001,
    },
    {
        "params": [
            p for n, p in param_optimizer if any(
                nd in n for nd in no_decay
            )
        ],
        "weight_decay": 0.0,
    },
]

num_train_steps = int(
    len(train_sentences) / TRAIN_BATCH_SIZE * EPOCHS
)
optimizer = torch.optim.AdamW(optimizer_parameters, lr=3e-5) 
scheduler = get_linear_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=0, 
    num_training_steps=num_train_steps
)

best_loss = np.inf
for epoch in range(EPOCHS):
    train_loss = train_fn(
        train_data_loader, 
        model, 
        optimizer, 
        device, 
        scheduler
    )
    test_loss = eval_fn(
        valid_data_loader,
        model,
        device
    )
    print(f"Train Loss = {train_loss} Valid Loss = {test_loss}")
    if test_loss < best_loss:
        torch.save(model.state_dict(), MODEL_PATH)
        best_loss = test_loss

  df.loc[:, "POS"] = enc_pos.fit_transform(df["POS"])
  df.loc[:, "Tag"] = enc_tag.fit_transform(df["Tag"])
  0%|          | 0/1349 [00:00<?, ?it/s]Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/mdaniyalk/miniforge3/envs/workshop/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/mdaniyalk/miniforge3/envs/workshop/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'EntityDataset' on <module '__main__' (built-in)>
  0%|          | 0/1349 [00:03<?, ?it/s]


KeyboardInterrupt: 

## Inference

In [8]:
import re

meta_data = joblib.load("meta.bin")
enc_pos = meta_data["enc_pos"]
enc_tag = meta_data["enc_tag"]

num_pos = len(list(enc_pos.classes_))
num_tag = len(list(enc_tag.classes_))

sentence = "It's easier to dial three numbers than it is to look up a number and dial about four different departments. You're trying to decide whether you need to dial the Wilson Police Department, the Sheriff's Department, the Fire Department of the Rescue and you can dial this three numbers and get all four departments at one time. The 911 system will use the same number of dispatchers or telecommunicators as they like to be called as the current separate systems. "
tokenized_sentence = TOKENIZER.encode(sentence)
sentence = sentence.split()
print(sentence)
print(tokenized_sentence)

test_dataset = EntityDataset(
    texts=[sentence], 
    pos=[[0] * len(sentence)], 
    tags=[[0] * len(sentence)]
)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = EntityModel(num_tag=num_tag, num_pos=num_pos)
model.load_state_dict(torch.load(MODEL_PATH))
model.to(device)

with torch.no_grad():
    data = test_dataset[0]
    for k, v in data.items():
        data[k] = v.to(device).unsqueeze(0)
    tag, pos, _ = model(**data)

    print(
        enc_tag.inverse_transform(
            tag.argmax(2).cpu().numpy().reshape(-1)
        )[:len(tokenized_sentence)]
    )
    print(
        enc_pos.inverse_transform(
            pos.argmax(2).cpu().numpy().reshape(-1)
        )[:len(tokenized_sentence)]
    )
       # print(tokenized_sentence[i])

['Cigarettes', 'in', 'this', 'machine', 'cost', '65', 'cents.', 'Now', "that's", '55', 'cents', 'for', 'growing', 'the', 'tobacco,', 'manufacturing,', 'and', 'distributing', 'the', 'cigarettes.', '2', 'cents', 'state', 'tax,', 'and', '8', 'cents', 'federal', 'tax.', 'The', 'people', 'in', 'the', 'tobacco', 'fields', 'receive', 'about', 'a', 'nickel', 'a', 'pack.', 'Adding', 'another', '8', 'cents', 'tax', 'will', 'not', 'give', 'them', 'a', 'penny', 'more.', 'They', 'fear', 'they', 'will', 'end', 'up', 'losing', 'a', 'lot.', "They'll", 'lower', 'quotas', 'if', 'the', 'usage', 'cuts', 'down.', 'And', '8', 'cents', "can't", 'help', 'that', 'any.', 'Now', 'if', 'they', 'lower', 'a', 'quota,', "what's", 'that', 'do', 'to', 'you?', 'It', 'reduces', 'the', 'amount', 'of', 'tobacco', 'I', 'can', 'grow.', 'As', 'you', 'can', 'tell,', 'Tom', 'and', 'Alan', 'Broadwell', 'are', 'already', 'counting', 'their', 'losses.', 'Last', 'year', 'their', 'quota', 'was', '32', 'acres.', 'This', 'year', 'it'