In [2]:
import os
import re
import random
import string
from typing import List
import spacy
from spacy import displacy
from spacy.util import minibatch, compounding
from spacy.training import Example
from spacy.pipeline.spancat import DEFAULT_SPANCAT_MODEL
import pandas as pd
from tqdm import tqdm
#key under doc.spans that will hold the span annotations
span_key = "sc"
#start from an empty English pipeline
nlp = spacy.blank("en")

In [3]:
#load the data and forward-fill nans (this is what fills in the sentence #s)
#DataFrame.ffill() replaces fillna(method='ffill'), which is deprecated
#since pandas 2.1
ner_data = pd.read_csv("data.csv", encoding='ISO-8859-1').ffill()

#replace tags we don't care about with 'O'
ents_to_replace = ['tim', 'gpe', 'art', 'eve', 'nat']
#str.endswith accepts a tuple of suffixes, so no inner loop is needed
bad_ents = [
    tag for tag in ner_data["Tag"].unique()
    if tag.endswith(tuple(ents_to_replace))
]

#only rewrite the Tag column, so a value in another column that happens to
#equal one of the tag strings is left untouched
ner_data = ner_data.assign(Tag=ner_data["Tag"].replace(bad_ents, 'O'))

In [4]:
def get_span_indx(
    labels: List[str],
    words: List[str],
    sentence: str
) -> List[tuple]:
    """Gets span starts and ends for Spacy spancat component.

        Runs of consecutive labelled (non-'O') words are grouped into
        entities. A 'B-' tag or a change of entity type starts a new
        entity even when no 'O' separates it from the previous one.
        The words of each entity are joined with single spaces, stripped
        of punctuation and surrounding whitespace, and then searched for
        in ``sentence`` as a whole word/phrase.

        Args:
            labels: one tag per word, e.g. 'B-geo', 'I-per' or 'O'.
            words: the words of the sentence, aligned with ``labels``.
            sentence: the text in which the spans are located.

        Returns:
            List of unique ``(start, end, label)`` tuples in order of
            discovery, where ``start``/``end`` are character offsets into
            ``sentence`` and ``label`` is the upper-cased part of the tag
            after its last '-'.

        Note:
            Every occurrence of an entity's text in ``sentence`` is
            returned, including occurrences whose own words are tagged
            'O'. Entities whose punctuation-free text does not occur in
            ``sentence`` (e.g. 'U.S.' -> 'US') produce no span.
    """
    #group the indices of consecutive labelled words into entities
    entity_groups = []
    current = []
    previous_type = None
    for i, tag in enumerate(labels):
        if tag == 'O':
            if current:
                entity_groups.append(current)
                current = []
            previous_type = None
            continue
        tag_type = tag.split('-')[-1].upper()
        #a 'B-' tag or a different entity type begins a new entity, so
        #directly adjacent entities are not merged under the first label
        if current and (tag.upper().startswith('B-') or tag_type != previous_type):
            entity_groups.append(current)
            current = []
        current.append(i)
        previous_type = tag_type
    if current:
        entity_groups.append(current)

    punctuation_table = str.maketrans('', '', string.punctuation)
    spans = []
    seen = set()
    for group in entity_groups:
        text = ' '.join(words[i] for i in group)
        label = labels[group[0]].split('-')[-1].upper()
        #remove punctuation and strip whitespace for spans
        text_clean = text.translate(punctuation_table).strip()
        #an empty pattern would match at every position of the sentence
        if not text_clean:
            continue
        #escape the text and anchor it on non-word characters so it cannot
        #match inside a longer word (e.g. 'Iran' inside 'Iranian')
        pattern = r'(?<!\w)' + re.escape(text_clean) + r'(?!\w)'
        for m in re.finditer(pattern, sentence):
            span = (m.start(), m.end(), label)
            #a repeated entity would otherwise be reported once per repeat
            if span not in seen:
                seen.add(span)
                spans.append(span)

    return spans

In [5]:
#Build one spaCy Example per sentence
train_data = []
for sentence_id, sentence_rows in ner_data.groupby("Sentence #"):
    tokens = list(sentence_rows["Word"])
    tags = list(sentence_rows['Tag'])
    #join the tokens, then drop the space in front of ? . ! " when that
    #character is followed by whitespace or the end of the text
    joined_tokens = " ".join(tokens)
    text = re.sub(r'\s([?.!"](?:\s|$))', r'\1', joined_tokens)
    #(start, end, category) tuples for the labelled spans
    labelled_spans = get_span_indx(tags, tokens, text)
    #pair the tokenized text with its spans[span_key] annotation
    doc = nlp.make_doc(text)
    train_data.append(
        Example.from_dict(doc, {'spans': {span_key: labelled_spans}})
    )

In [6]:
#spancat settings - the parameter descriptions paraphrase spaCy's documentation
config = {
    #minimum probability for a prediction to count as positive
    "threshold": 0.5,
    #key in doc.spans under which the spans are stored
    "spans_key": span_key,
    #maximum number of labels to consider positive per span
    "max_positive": None,
    #model instance that receives the documents together with the start/end
    #indices of the labelled spans
    "model": DEFAULT_SPANCAT_MODEL,
    #function that suggests spans: fixed n-grams of 1, 2 and 3 tokens
    "suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]},
}
#add the spancat component; add_pipe returns the component it created
span = nlp.add_pipe("spancat", config=config)

#register the span categories on the component
for label in ('GEO', 'PER', 'ORG'):
    span.add_label(label)

In [7]:
#pipes to train; every other pipe name in the pipeline goes to unaffected_pipes
pipe_exceptions = ["spancat"]
unaffected_pipes = [
    name for name in nlp.pipe_names
    if name not in pipe_exceptions
]

#initialise the pipeline, then create the optimizer
nlp.initialize()
sgd = nlp.create_optimizer()