In [None]:
# Install the third-party packages imported in the next cell (NERDA, seqeval, flair).
# -q keeps pip's output quiet.
!pip install nerda -q
!pip install seqeval -q
!pip install flair -q

In [None]:
from NERDA.datasets import get_conll_data, download_conll_data 
from NERDA.models import NERDA

from google.colab import files
import pandas as pd
import ast
import unicodedata

import numpy as np
import seqeval.metrics
import spacy
import torch
from tqdm import tqdm, trange
from transformers import LukeTokenizer, LukeForEntitySpanClassification
from flair.data import Sentence
from flair.models import SequenceTagger
import timeit
from sklearn.model_selection import train_test_split 

# must upload processed_df.csv and retrain_processed.csv files
# (both are read with pd.read_csv in later cells)
uploaded = files.upload()

# Fetch the CoNLL data through NERDA's helpers and load its three splits.
# The train and valid splits are indexed with "sentences" and "tags" further down.
download_conll_data()
training = get_conll_data('train')
validation = get_conll_data('valid')
testing = get_conll_data('test')

In [None]:
# Download the testb set of the CoNLL-2003 dataset
# (a later cell reads it from the working directory via load_documents("eng.testb"))
!wget https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testb

In [None]:
def generate_labels(input_text):
  """Build one IOB2 tag per space-separated token of ``input_text``.

  The text is either a bare person name ("Ann Lee") or a name followed by the
  template phrase "went to the store". Name tokens are tagged ``B-PER`` (first)
  and ``I-PER`` (rest); the template phrase and anything after it is tagged ``O``.

  The text is split on single spaces, exactly like ``get_sentence_from_name``,
  so the returned list always has the same length as that function's token
  list. Previously the tag count was hard-coded, which misaligned tags and
  tokens for names with three or more tokens.

  Args:
    input_text: the name or templated sentence; converted with ``str()``.

  Returns:
    list[str]: one tag per token.
  """
  tokens = str(input_text).split(" ")
  template = ["went", "to", "the", "store"]

  # The name is everything before the template phrase; with no phrase present
  # the whole text is the name.
  name_length = len(tokens)
  for start in range(len(tokens) - len(template) + 1):
    if tokens[start:start + len(template)] == template:
      name_length = start
      break

  labels = ["O"] * len(tokens)
  for position in range(name_length):
    labels[position] = "B-PER" if position == 0 else "I-PER"
  return labels

def get_sentence_from_name(input_name):
  """Return the tokens of ``input_name`` obtained by splitting on single spaces.

  The argument is first converted with ``str()``, so non-string values
  (numbers, NaN, ...) are accepted as well.
  """
  text = str(input_name)
  tokens = text.split(" ")
  return tokens

In [None]:
# Entity tags handed to NERDA; the outside tag 'O' is passed separately
# where the model is constructed.
tag_scheme = [
    'B-PER', 'I-PER',
    'B-ORG', 'I-ORG',
    'B-LOC', 'I-LOC',
    'B-MISC', 'I-MISC',
]

# Pretrained checkpoint name passed to NERDA as its transformer.
transformer = 'studio-ousia/luke-large-finetuned-conll-2003'

# hyperparameters for network
dropout = 0.1

# Settings passed to NERDA as `hyperparameters`.
training_hyperparameters = {
    'epochs': 2,
    'warmup_steps': 500,
    'train_batch_size': 13,
    'learning_rate': 1e-5,
}

In [None]:
# Load the curated retraining names and derive, for every row, its token list
# ("sentences") and the matching tag list ("tags_list").
retrain_subset = pd.read_csv("retrain_processed.csv", index_col=0)
retrain_subset["tags_list"] = retrain_subset["Name"].apply(generate_labels)
retrain_subset["sentences"] = retrain_subset["Name"].apply(get_sentence_from_name)

# Hold out 15% of the curated rows for validation, stratified on 'Race'.
# NOTE(review): no random_state is passed, so the split differs between runs.
rt_train, rt_valid = train_test_split(retrain_subset, test_size=0.15, stratify=retrain_subset['Race'])

retrain_dict = {"sentences": list(rt_train["sentences"]), "tags": list(rt_train["tags_list"])}
valid_dict = {"sentences": list(rt_valid["sentences"]), "tags": list(rt_valid["tags_list"])}

# Curated rows first, then the CoNLL split. Sentences and tags must be
# concatenated from the same sources so the two lists stay aligned.
total_sentences = list(retrain_dict["sentences"]) + list(training["sentences"])
total_tags = list(retrain_dict["tags"]) + list(training["tags"])

valid_sentences = list(valid_dict["sentences"]) + list(validation["sentences"])
valid_tags = list(valid_dict["tags"]) + list(validation["tags"])

total_retrain_dict = {"sentences": total_sentences, "tags": total_tags}
total_valid_dict = {"sentences": valid_sentences, "tags": valid_tags}

In [None]:
# Build the NERDA model from the combined curated + CoNLL datasets and the
# configuration values defined earlier in the notebook.
model = NERDA(
    transformer=transformer,
    tag_scheme=tag_scheme,
    tag_outside='O',
    dataset_training=total_retrain_dict,
    dataset_validation=total_valid_dict,
    dropout=dropout,
    hyperparameters=training_hyperparameters,
)

In [None]:
# Train the model on the datasets that were passed to the NERDA constructor.
model.train()

In [None]:
# Load the curated evaluation names and derive, for every row, its token list
# ("sentences") and the matching tag list ("tags_list").
processed_test_df = pd.read_csv("processed_df.csv", index_col=0)
processed_test_df["tags_list"] = processed_test_df["Name"].apply(generate_labels)
processed_test_df["sentences"] = processed_test_df["Name"].apply(get_sentence_from_name)
# Both columns come from processed_test_df (the original read "sentences" from
# an undefined name, which raised NameError).
processed_test_dict = {"sentences": list(processed_test_df["sentences"]), "tags": list(processed_test_df["tags_list"])}

In [None]:
def generate_entities(input_string):
  """Return the model's prediction for a single input.

  ``input_string`` is wrapped in a one-element batch for ``model.predict`` and
  the first (only) prediction is returned. Despite the parameter name, the
  evaluation cells pass token lists taken from the "sentences" columns.
  """
  batch = [input_string]
  predictions = model.predict(batch)
  return predictions[0]

In [None]:
def _rows_for_race(race):
  """Rows of processed_test_df whose 'Race' column equals ``race``, re-indexed from zero."""
  matches = processed_test_df["Race"] == race
  return processed_test_df.loc[matches].reset_index(drop=True)


def _as_eval_dict(frame):
  """Collect a frame's token lists and tag lists into a sentences/tags dict."""
  return {"sentences": list(frame["sentences"]), "tags": list(frame["tags_list"])}


# One subset of the curated test data per value of the 'Race' column.
processed_white_df = _rows_for_race("White")
processed_black_df = _rows_for_race("Black")
processed_api_df = _rows_for_race("API")
processed_hispanic_df = _rows_for_race("Hispanic")

processed_test_dict_w = _as_eval_dict(processed_white_df)
processed_test_dict_b = _as_eval_dict(processed_black_df)
processed_test_dict_a = _as_eval_dict(processed_api_df)
processed_test_dict_h = _as_eval_dict(processed_hispanic_df)

## Primarily White Names from Curated Test Data:

In [None]:
# Predict and score the "White" subset; the timer covers prediction and the report.
start = timeit.default_timer()
curated_test_labels_w = processed_test_dict_w["tags"]
curated_pred_labels_w = [
    generate_entities(sentence)
    for sentence in processed_test_dict_w["sentences"]
]
print(seqeval.metrics.classification_report(
    curated_test_labels_w,
    curated_pred_labels_w,
    digits=4,
))
stop = timeit.default_timer()
print('LUKE Runtime: {} seconds'.format(stop - start))

## Primarily Black / African American Names from Curated Test Data:

In [None]:
# Predict and score the "Black" subset; the timer covers prediction and the report.
start = timeit.default_timer()
curated_test_labels_b = processed_test_dict_b["tags"]
curated_pred_labels_b = [
    generate_entities(sentence)
    for sentence in processed_test_dict_b["sentences"]
]
print(seqeval.metrics.classification_report(
    curated_test_labels_b,
    curated_pred_labels_b,
    digits=4,
))
stop = timeit.default_timer()
print('LUKE Runtime: {} seconds'.format(stop - start))

## Primarily Asian or Native Hawaiian or Other Pacific Islander Names from Curated Test Data:

In [None]:
# Predict and score the "API" subset; the timer covers prediction and the report.
start = timeit.default_timer()
curated_test_labels_a = processed_test_dict_a["tags"]
curated_pred_labels_a = [
    generate_entities(sentence)
    for sentence in processed_test_dict_a["sentences"]
]
print(seqeval.metrics.classification_report(
    curated_test_labels_a,
    curated_pred_labels_a,
    digits=4,
))
stop = timeit.default_timer()
print('LUKE Runtime: {} seconds'.format(stop - start))

## Primarily Hispanic / Latino Names from Curated Test Data:

In [None]:
# Predict and score the "Hispanic" subset; the timer covers prediction and the report.
start = timeit.default_timer()
curated_test_labels_h = processed_test_dict_h["tags"]
curated_pred_labels_h = [
    generate_entities(sentence)
    for sentence in processed_test_dict_h["sentences"]
]
print(seqeval.metrics.classification_report(
    curated_test_labels_h,
    curated_pred_labels_h,
    digits=4,
))
stop = timeit.default_timer()
print('LUKE Runtime: {} seconds'.format(stop - start))

In [None]:
# Load the tokenizer
# (load_examples calls tokenizer.tokenize on each word to count its subword tokens)
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-conll-2003")

def load_documents(dataset_file):
    """Parse a CoNLL-style file into a list of documents.

    Lines starting with "-DOCSTART" separate documents, blank lines separate
    sentences, and every other line is split on single spaces: the first field
    is the word and the last field is its label.

    Args:
        dataset_file: path of the file to read.

    Returns:
        list[dict]: one dict per document with keys
            ``words`` (list[str]),
            ``labels`` (list[str], parallel to ``words``) and
            ``sentence_boundaries`` (list[int]): word offsets where sentences
            start/end, always beginning with 0 and ending with ``len(words)``.
    """
    documents = []
    words = []
    labels = []
    sentence_boundaries = []
    with open(dataset_file) as f:
        for line in f:
            line = line.rstrip()
            if line.startswith("-DOCSTART"):
                # A new document begins: flush the one collected so far.
                if words:
                    documents.append(_close_document(words, labels, sentence_boundaries))
                    words = []
                    labels = []
                    sentence_boundaries = []
                continue

            if not line:
                # Blank line = sentence break; consecutive blank lines must not
                # record the same offset twice.
                if not sentence_boundaries or len(words) != sentence_boundaries[-1]:
                    sentence_boundaries.append(len(words))
            else:
                items = line.split(" ")
                words.append(items[0])
                labels.append(items[-1])

    if words:
        documents.append(_close_document(words, labels, sentence_boundaries))

    return documents


def _close_document(words, labels, sentence_boundaries):
    """Package one parsed document, making sure its boundaries span every word."""
    # Without a blank line before the first sentence or after the last one, that
    # sentence would fall outside every (start, end) boundary pair and be
    # silently skipped by consumers of sentence_boundaries.
    if not sentence_boundaries or sentence_boundaries[0] != 0:
        sentence_boundaries.insert(0, 0)
    if sentence_boundaries[-1] != len(words):
        sentence_boundaries.append(len(words))
    return dict(
        words=words,
        labels=labels,
        sentence_boundaries=sentence_boundaries
    )


def load_examples(documents):
    """Build one span-classification example per sentence of every document.

    For each sentence, a text is assembled from the sentence plus as much
    surrounding document context as fits into ``max_token_length`` subword
    tokens, and every candidate word span inside the sentence is enumerated.

    Args:
        documents: list of dicts with keys ``words`` and
            ``sentence_boundaries``, as produced by ``load_documents``.

    Returns:
        list[dict]: one dict per sentence with keys
            ``text`` (str): detokenised context + sentence + context,
            ``words`` (list[str]): the words of the target sentence,
            ``entity_spans`` (list[tuple[int, int]]): character
                (start, end) offsets into ``text`` for each candidate span,
            ``original_word_spans`` (list[tuple[int, int]]): the matching
                (first word index, last word index + 1) within the sentence.
    """
    examples = []
    max_token_length = 510
    max_mention_length = 30

    for document in tqdm(documents):
        words = document["words"]
        subword_lengths = [len(tokenizer.tokenize(w)) for w in words]
        total_subword_length = sum(subword_lengths)
        sentence_boundaries = document["sentence_boundaries"]

        for i in range(len(sentence_boundaries) - 1):
            sentence_start, sentence_end = sentence_boundaries[i:i+2]
            if total_subword_length <= max_token_length:
                # if the total sequence length of the document is shorter than the
                # maximum token length, we simply use all words to build the sequence
                context_start = 0
                context_end = len(words)
            else:
                # if the total sequence length is longer than the maximum length, we add
                # the surrounding words of the target sentence to the sequence until it
                # reaches the maximum length
                context_start = sentence_start
                context_end = sentence_end
                cur_length = sum(subword_lengths[context_start:context_end])
                # Grow the window one word to the left, then one to the right, per
                # round. The loop condition stops the expansion once both ends of
                # the document are reached; with a bare `while True` a document
                # made of one over-long sentence would loop forever.
                while context_start > 0 or context_end < len(words):
                    if context_start > 0:
                        if cur_length + subword_lengths[context_start - 1] <= max_token_length:
                            cur_length += subword_lengths[context_start - 1]
                            context_start -= 1
                        else:
                            break
                    if context_end < len(words):
                        if cur_length + subword_lengths[context_end] <= max_token_length:
                            cur_length += subword_lengths[context_end]
                            context_end += 1
                        else:
                            break

            # Left context: words are joined with spaces, except that an
            # apostrophe-initial word or a single punctuation character is glued
            # to the preceding word.
            text = ""
            for word in words[context_start:sentence_start]:
                if word[0] == "'" or (len(word) == 1 and is_punctuation(word)):
                    text = text.rstrip()
                text += word
                text += " "

            sentence_words = words[sentence_start:sentence_end]
            sentence_subword_lengths = subword_lengths[sentence_start:sentence_end]

            # Target sentence: same joining rule, while recording where each word
            # starts and ends inside `text`.
            word_start_char_positions = []
            word_end_char_positions = []
            for word in sentence_words:
                if word[0] == "'" or (len(word) == 1 and is_punctuation(word)):
                    text = text.rstrip()
                word_start_char_positions.append(len(text))
                text += word
                word_end_char_positions.append(len(text))
                text += " "

            # Right context.
            for word in words[sentence_end:context_end]:
                if word[0] == "'" or (len(word) == 1 and is_punctuation(word)):
                    text = text.rstrip()
                text += word
                text += " "
            text = text.rstrip()

            # Enumerate every (word_start..word_end inclusive) span of the sentence
            # that passes the mention-length check.
            entity_spans = []
            original_word_spans = []
            for word_start in range(len(sentence_words)):
                for word_end in range(word_start, len(sentence_words)):
                    # NOTE(review): the slice stops before word_end, so the last
                    # word of the span is not counted in the length check.
                    # Confirm whether that is intended before changing it.
                    if sum(sentence_subword_lengths[word_start:word_end]) <= max_mention_length:
                        entity_spans.append(
                            (word_start_char_positions[word_start], word_end_char_positions[word_end])
                        )
                        original_word_spans.append(
                            (word_start, word_end + 1)
                        )

            examples.append(dict(
                text=text,
                words=sentence_words,
                entity_spans=entity_spans,
                original_word_spans=original_word_spans,
            ))

    return examples


def is_punctuation(char):
    """Tell whether a single character counts as punctuation.

    Every ASCII symbol that is not a letter, digit or space is treated as
    punctuation (so "$" and "^" qualify); beyond that, any character whose
    Unicode category starts with "P" qualifies.
    """
    code_point = ord(char)
    ascii_symbol_ranges = ((33, 47), (58, 64), (91, 96), (123, 126))
    if any(low <= code_point <= high for low, high in ascii_symbol_ranges):
        return True
    return unicodedata.category(char).startswith("P")

In [None]:
# Parse the downloaded testb file into documents, then build the per-sentence
# span examples from them.
test_documents = load_documents("eng.testb")
test_examples = load_examples(test_documents)

In [None]:
def get_named_entities_custom_luke(input_row):
  """Predict tags for one document, sentence by sentence.

  ``input_row`` is a dict with ``words`` and ``sentence_boundaries``. Each
  non-zero boundary closes a sentence running from the previous boundary;
  sentences of 120 or more words are cut at their midpoint and the two halves
  are predicted separately. Returns the predictions of all sentences
  concatenated into one flat list.
  """
  tokens = input_row["words"]
  boundaries = input_row["sentence_boundaries"]
  collected = []
  previous_boundary = 0
  for boundary in boundaries:
    if boundary == 0:
      # A leading 0 only marks the start of the first sentence.
      continue
    sentence = tokens[previous_boundary:boundary]
    if len(sentence) < 120:
      sentence_labels = model.predict([sentence])[0]
    else:
      half = len(sentence) // 2
      head, tail = sentence[:half], sentence[half:]
      sentence_labels = model.predict([head])[0] + model.predict([tail])[0]
    collected.extend(sentence_labels)
    previous_boundary = boundary
  return collected

In [None]:
# Gold tags and model predictions, one flat list per document.
test_labels = [document["labels"] for document in test_documents]
pred_labels = [get_named_entities_custom_luke(document) for document in test_documents]

In [None]:
# Print seqeval's classification report for the testb documents, to four decimals.
print(seqeval.metrics.classification_report(test_labels, pred_labels, digits=4)) 