In [1]:
!pip install nerda -q
!pip install seqeval -q
!pip install flair -q

In [2]:
from NERDA.datasets import get_conll_data, download_conll_data 
from NERDA.models import NERDA

from google.colab import files
import pandas as pd
import ast
import unicodedata

import numpy as np
import seqeval.metrics
import spacy
import torch
from tqdm import tqdm, trange
from transformers import LukeTokenizer, LukeForEntitySpanClassification
from flair.data import Sentence
from flair.models import SequenceTagger
import timeit
from sklearn.model_selection import train_test_split 

# must upload processed_df.csv and retrain_processed.csv files
uploaded = files.upload()

download_conll_data()
training = get_conll_data('train')
validation = get_conll_data('valid')
testing = get_conll_data('test')

Saving processed_df.csv to processed_df.csv
Saving retrain_processed.csv to retrain_processed.csv
Reading https://data.deepai.org/conll2003.zip


In [3]:
# Download the testb set of the CoNLL-2003 dataset
!wget https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testb

--2021-11-18 08:57:51--  https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testb
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 748096 (731K) [text/plain]
Saving to: ‘eng.testb.1’


2021-11-18 08:57:51 (8.98 MB/s) - ‘eng.testb.1’ saved [748096/748096]



In [4]:
def generate_labels(input_text):
  input_text = str(input_text)
  if input_text.count(" ") > 0:
    if "went to the store" in input_text:
      if input_text.count(" ") > 4:
        return ["B-PER", "I-PER", "O", "O", "O", "O"]
      return ["B-PER", "O", "O", "O", "O"]
    return ["B-PER", "I-PER"]
  else: 
    return ["B-PER"]

def get_sentence_from_name(input_name):
  input_name = str(input_name)
  return input_name.split(" ")

In [5]:
tag_scheme = [
'B-PER',
'I-PER',
'B-ORG',
'I-ORG',
'B-LOC',
'I-LOC',
'B-MISC',
'I-MISC'
]

transformer = 'studio-ousia/luke-large-finetuned-conll-2003'

# hyperparameters for network
dropout = 0.1

training_hyperparameters = {
'epochs' : 2,
'warmup_steps' : 500,                                                   
'train_batch_size': 13,                                         
'learning_rate': 1e-5
}

In [6]:
retrain_subset = pd.read_csv("retrain_processed.csv", index_col=0)
retrain_subset["tags_list"] = retrain_subset["Name"].apply(lambda x: generate_labels(x))
retrain_subset["sentences"] = retrain_subset["Name"].apply(lambda x: get_sentence_from_name(x))

non_entity_list = ["He went to the store.", "She went to the store.", "They went to the store.", "My mom went to the store.", "My dad went to the store.", "My friend went to the store."]
entities = [["O"]*len(i.split(" ")) for i in non_entity_list]
non_entity_df = pd.DataFrame(non_entity_list)
non_entity_df["Race"] = "Not Applicable"
non_entity_df["Name"] = non_entity_df[0]
non_entity_df["sentences"] = non_entity_df[0].apply(lambda x: x.split(" "))
non_entity_df["tags_list"] = non_entity_df["sentences"].apply(lambda x: ["O"]*len(x))
non_entity_df = non_entity_df.drop(columns=[0])

retrain_subset = pd.concat([retrain_subset, non_entity_df])

In [8]:
rt_train, rt_valid = train_test_split(retrain_subset, test_size=0.15, stratify=retrain_subset['Race']) 

In [9]:
retrain_dict = {"sentences": list(rt_train["sentences"]), "tags": list(rt_train["tags_list"])}
valid_dict = {"sentences": list(rt_valid["sentences"]), "tags": list(rt_valid["tags_list"])}

In [10]:
total_sentences = list(retrain_dict["sentences"]) + list( ["sentences"])
total_tags = list(retrain_dict["tags"]) + list(training["tags"])

valid_sentences = list(valid_dict["sentences"]) + list(validation["sentences"])
valid_tags = list(valid_dict["tags"]) + list(validation["tags"])

total_retrain_dict = {"sentences": total_sentences, "tags": total_tags}
total_valid_dict = {"sentences": valid_sentences, "tags": valid_tags}

In [11]:
model = NERDA(
dataset_training = total_retrain_dict,
dataset_validation = total_valid_dict,
tag_scheme = tag_scheme, 
tag_outside = 'O',
transformer = transformer,
dropout = dropout,
hyperparameters = training_hyperparameters
)

Device automatically set to: cuda


Downloading:   0%|          | 0.00/877 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.09G [00:00<?, ?B/s]

Some weights of the model checkpoint at studio-ousia/luke-large-finetuned-conll-2003 were not used when initializing LukeModel: ['classifier.weight', 'classifier.bias', 'luke.embeddings.position_ids']
- This IS expected if you are initializing LukeModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LukeModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading:   0%|          | 0.00/1.66k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/14.6M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/33.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/0.98k [00:00<?, ?B/s]

In [12]:
model.train()


 Epoch 1 / 2


100%|██████████| 2360/2360 [21:50<00:00,  1.80it/s]
100%|██████████| 774/774 [01:31<00:00,  8.47it/s]


Train Loss = 0.1199184441631063 Valid Loss = 0.556860538659565

 Epoch 2 / 2


100%|██████████| 2360/2360 [21:50<00:00,  1.80it/s]
100%|██████████| 774/774 [01:31<00:00,  8.47it/s]

Train Loss = 0.02337111015998343 Valid Loss = 0.040128559404297506





'Model trained successfully'

In [13]:
processed_test_df = pd.read_csv("processed_df.csv", index_col=0)
processed_test_df["tags_list"] = processed_test_df["Name"].apply(lambda x: generate_labels(x))
processed_test_df["sentences"] = processed_test_df["Name"].apply(lambda x: get_sentence_from_name(x))
processed_test_dict = {"sentences": list(processed_test_xdf["sentences"]), "tags": list(processed_test_df["tags_list"])}

In [14]:
def generate_entities(input_string):
  return model.predict([input_string])[0]

In [15]:
processed_white_df = processed_test_df.loc[processed_test_df["Race"]=="White"].reset_index(drop=True)
processed_black_df = processed_test_df.loc[processed_test_df["Race"]=="Black"].reset_index(drop=True)
processed_api_df = processed_test_df.loc[processed_test_df["Race"]=="API"].reset_index(drop=True)
processed_hispanic_df = processed_test_df.loc[processed_test_df["Race"]=="Hispanic"].reset_index(drop=True)

processed_test_dict_w = {"sentences": list(processed_white_df["sentences"]), "tags": list(processed_white_df["tags_list"])}
processed_test_dict_b = {"sentences": list(processed_black_df["sentences"]), "tags": list(processed_black_df["tags_list"])}
processed_test_dict_a = {"sentences": list(processed_api_df["sentences"]), "tags": list(processed_api_df["tags_list"])}
processed_test_dict_h = {"sentences": list(processed_hispanic_df["sentences"]), "tags": list(processed_hispanic_df["tags_list"])}

## Primarily White Names from Curated Test Data:

In [16]:
start = timeit.default_timer()
curated_test_labels_w = processed_test_dict_w["tags"]
curated_pred_labels_w = [generate_entities(processed_test_dict_w["sentences"][i]) for i in range(len(processed_test_dict_w["sentences"]))]
print(seqeval.metrics.classification_report(curated_test_labels_w, curated_pred_labels_w, digits=4)) 
stop = timeit.default_timer()
print('LUKE Runtime: {} seconds'.format(stop - start))

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

         ORG     0.0000    0.0000    0.0000         0
         PER     1.0000    0.9976    0.9988      8480

   micro avg     0.9983    0.9976    0.9980      8480
   macro avg     0.5000    0.4988    0.4994      8480
weighted avg     1.0000    0.9976    0.9988      8480

LUKE Runtime: 1778.2663463770004 seconds


## Primarily Black / African American Names from Curated Test Data:

In [17]:
start = timeit.default_timer()
curated_test_labels_b = processed_test_dict_b["tags"]
curated_pred_labels_b = [generate_entities(processed_test_dict_b["sentences"][i]) for i in range(len(processed_test_dict_b["sentences"]))]
print(seqeval.metrics.classification_report(curated_test_labels_b, curated_pred_labels_b, digits=4)) 
stop = timeit.default_timer()
print('LUKE Runtime: {} seconds'.format(stop - start))

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

         LOC     0.0000    0.0000    0.0000         0
        MISC     0.0000    0.0000    0.0000         0
         ORG     0.0000    0.0000    0.0000         0
         PER     0.9867    0.9508    0.9684      8476

   micro avg     0.9554    0.9508    0.9531      8476
   macro avg     0.2467    0.2377    0.2421      8476
weighted avg     0.9867    0.9508    0.9684      8476

LUKE Runtime: 2101.36411804 seconds


## Primarily Asian or Native Hawaiian or Other Pacific Islander Names from Curated Test Data:

In [18]:
start = timeit.default_timer()
curated_test_labels_a = processed_test_dict_a["tags"]
curated_pred_labels_a = [generate_entities(processed_test_dict_a["sentences"][i]) for i in range(len(processed_test_dict_a["sentences"]))]
print(seqeval.metrics.classification_report(curated_test_labels_a, curated_pred_labels_a, digits=4)) 
stop = timeit.default_timer()
print('LUKE Runtime: {} seconds'.format(stop - start))

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

         LOC     0.0000    0.0000    0.0000         0
        MISC     0.0000    0.0000    0.0000         0
         ORG     0.0000    0.0000    0.0000         0
         PER     0.9914    0.9859    0.9887      8468

   micro avg     0.9903    0.9859    0.9881      8468
   macro avg     0.2479    0.2465    0.2472      8468
weighted avg     0.9914    0.9859    0.9887      8468

LUKE Runtime: 2512.461219871 seconds


## Primarily Hispanic / Latino Names from Curated Test Data:

In [19]:
start = timeit.default_timer()
curated_test_labels_h = processed_test_dict_h["tags"]
curated_pred_labels_h = [generate_entities(processed_test_dict_h["sentences"][i]) for i in range(len(processed_test_dict_h["sentences"]))]
print(seqeval.metrics.classification_report(curated_test_labels_h, curated_pred_labels_h, digits=4)) 
stop = timeit.default_timer()
print('LUKE Runtime: {} seconds'.format(stop - start))

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

         LOC     0.0000    0.0000    0.0000         0
        MISC     0.0000    0.0000    0.0000         0
         ORG     0.0000    0.0000    0.0000         0
         PER     0.9986    0.9870    0.9928      8480

   micro avg     0.9868    0.9870    0.9869      8480
   macro avg     0.2496    0.2468    0.2482      8480
weighted avg     0.9986    0.9870    0.9928      8480

LUKE Runtime: 2811.11916396 seconds


In [20]:
# Load the tokenizer
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-conll-2003")

def load_documents(dataset_file):
    documents = []
    words = []
    labels = []
    sentence_boundaries = []
    with open(dataset_file) as f:
        for line in f:
            line = line.rstrip()
            if line.startswith("-DOCSTART"):
                if words:
                    documents.append(dict(
                        words=words,
                        labels=labels,
                        sentence_boundaries=sentence_boundaries
                    ))
                    words = []
                    labels = []
                    sentence_boundaries = []
                continue

            if not line:
                if not sentence_boundaries or len(words) != sentence_boundaries[-1]:
                    sentence_boundaries.append(len(words))
            else:
                items = line.split(" ")
                words.append(items[0])
                labels.append(items[-1])

    if words:
        documents.append(dict(
            words=words,
            labels=labels,
            sentence_boundaries=sentence_boundaries
        ))
        
    return documents


def load_examples(documents):
    examples = []
    max_token_length = 510
    max_mention_length = 30

    for document in tqdm(documents):
        words = document["words"]
        subword_lengths = [len(tokenizer.tokenize(w)) for w in words]
        total_subword_length = sum(subword_lengths)
        sentence_boundaries = document["sentence_boundaries"]

        for i in range(len(sentence_boundaries) - 1):
            sentence_start, sentence_end = sentence_boundaries[i:i+2]
            if total_subword_length <= max_token_length:
                # if the total sequence length of the document is shorter than the
                # maximum token length, we simply use all words to build the sequence
                context_start = 0
                context_end = len(words)
            else:
                # if the total sequence length is longer than the maximum length, we add
                # the surrounding words of the target sentence　to the sequence until it
                # reaches the maximum length
                context_start = sentence_start
                context_end = sentence_end
                cur_length = sum(subword_lengths[context_start:context_end])
                while True:
                    if context_start > 0:
                        if cur_length + subword_lengths[context_start - 1] <= max_token_length:
                            cur_length += subword_lengths[context_start - 1]
                            context_start -= 1
                        else:
                            break
                    if context_end < len(words):
                        if cur_length + subword_lengths[context_end] <= max_token_length:
                            cur_length += subword_lengths[context_end]
                            context_end += 1
                        else:
                            break

            text = ""
            for word in words[context_start:sentence_start]:
                if word[0] == "'" or (len(word) == 1 and is_punctuation(word)):
                    text = text.rstrip()
                text += word
                text += " "

            sentence_words = words[sentence_start:sentence_end]
            sentence_subword_lengths = subword_lengths[sentence_start:sentence_end]

            word_start_char_positions = []
            word_end_char_positions = []
            for word in sentence_words:
                if word[0] == "'" or (len(word) == 1 and is_punctuation(word)):
                    text = text.rstrip()
                word_start_char_positions.append(len(text))
                text += word
                word_end_char_positions.append(len(text))
                text += " "

            for word in words[sentence_end:context_end]:
                if word[0] == "'" or (len(word) == 1 and is_punctuation(word)):
                    text = text.rstrip()
                text += word
                text += " "
            text = text.rstrip()

            entity_spans = []
            original_word_spans = []
            for word_start in range(len(sentence_words)):
                for word_end in range(word_start, len(sentence_words)):
                    if sum(sentence_subword_lengths[word_start:word_end]) <= max_mention_length:
                        entity_spans.append(
                            (word_start_char_positions[word_start], word_end_char_positions[word_end])
                        )
                        original_word_spans.append(
                            (word_start, word_end + 1)
                        )

            examples.append(dict(
                text=text,
                words=sentence_words,
                entity_spans=entity_spans,
                original_word_spans=original_word_spans,
            ))

    return examples


def is_punctuation(char):
    cp = ord(char)
    if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False

In [21]:
test_documents = load_documents("eng.testb")
test_examples = load_examples(test_documents)

100%|██████████| 231/231 [00:03<00:00, 65.94it/s]


In [22]:
def get_named_entities_custom_luke(input_row):
  words = input_row["words"]
  sentence_boundaries = input_row["sentence_boundaries"]
  start = 0
  total_labels = []
  for i in sentence_boundaries: 
    if i != 0: 
      current_string = words[start:i]
      if len(current_string) >= 120:
        midpoint = len(current_string) // 2
        first_half = current_string[:midpoint]
        second_half = current_string[midpoint:]
        prediction = model.predict([first_half])[0] + model.predict([second_half])[0]
      else: 
        prediction = model.predict([current_string])[0]
      total_labels.extend(prediction)
      start = i
  return total_labels

In [23]:
test_labels = [test_documents[i]["labels"] for i in range(len(test_documents))]
pred_labels = [get_named_entities_custom_luke(test_documents[i]) for i in range(len(test_documents))]

In [24]:
print(seqeval.metrics.classification_report(test_labels, pred_labels, digits=4)) 

              precision    recall  f1-score   support

         LOC     0.9458    0.9424    0.9441      1666
        MISC     0.7566    0.8559    0.8032       701
         ORG     0.8950    0.9265    0.9105      1647
         PER     0.9558    0.9582    0.9570      1602

   micro avg     0.9075    0.9314    0.9193      5616
   macro avg     0.8883    0.9208    0.9037      5616
weighted avg     0.9101    0.9314    0.9203      5616

