Welcome to the code behind Nesta's medium article, "Training spaCy's spancat component in Python"! 👋

#### 0. Load Libraries and parameters
let's first import the libraries and parameters we will need to run this notebook.

In [9]:
import os 
import spacy
import pandas as pd
import re
import string 
from spacy import displacy

import random
from spacy.util import minibatch, compounding
from spacy.training import Example
import pickle
from tqdm import tqdm
from typing import List

we're going to name the key to store our labelled spans "sc".

In [5]:
span_key = "sc"

#### 1. Load Data
Once we've loaded the relevant libraries and parameters, lets download and clean up our data. The README.md tells you how to download the data 💾.

We want to extract people, organisation and locations tags. So, we can just replace the other tags in the data we don't care about with 'O'. 

In [6]:
#load the data and fill nans with sentence #s
ner_data = (pd.read_csv("ner_datasetreference.csv", encoding='ISO-8859-1')
            .fillna(method='ffill'))

#replace tags we don't care about with 'O'
ents_to_replace = ['tim', 'gpe', 'art', 'eve', 'nat']
bad_ents = []
for ent in list(set(ner_data.Tag)):
    if any(ent.endswith(e) for e in ents_to_replace):
        bad_ents.append(ent)
        
ner_data = ner_data.replace(bad_ents, 'O')

#### 2. Format data for spaCy
Now that we've loaded and replaced the non-relevant tags with 'O', let's format the data in a way that spaCy can handle. We will need to get the start and end of spans in the text. 

We will do this with the get_span_indx function.

In [11]:
def get_span_indx(
    labels: List[str]
) -> List[tuple]:
    """Gets span starts and ends for Spacy spancat component.
        
        Returns list of tuples where the first element of the 
        tuple is the span start, the second element of the tuple
        is the span end and the third element of the tuple is
        the span category. 
    """
    label_indx = []
    temp_list = []
    for i, l in enumerate(labels):
        if l != 'O':
            temp_list.append(i)
        else:
            label_indx.append(temp_list)
            temp_list = []    
        if i == len(labels) - 1:
            label_indx.append(temp_list)

    clean_label_indx = [x for x in label_indx if len(x) > 0]
    
    spans = []
    for indx in clean_label_indx:
        if len(indx) == 1:
            spans.append((indx[0], indx[0] + 1))
        else:
            spans.append((indx[0], indx[-1] + 1))

    return [(s[0], s[1], labels[s[0]].split('-')[-1].upper()) for s in spans]

Once we've written a function to extract span start and end indices, we will format the training data for spaCy such that:

```
  ('Thousands of demonstrators have marched through London to protest the war in Iraq and demand the withdrawal of British troops from that country',
  {'spans': {'sc': [(6, 7, 'GEO'), (12, 13, 'GEO')]}})
```

Where the first element of the tuple is the text and the second element of the tuple is a nested dictionary where the values of the span_key (in this instance, "sc") is a list of tuples containing the span token start, span token end and span category. 

In [13]:
#Create spaCy compliant training data 
train_data = []
for sent, sent_info in ner_data.groupby("Sentence #"):
    sentence = re.sub(r'\s([?.!"](?:\s|$))', r'\1', " ".join(sent_info["Word"]))
    sentence_no_punct = sentence.translate(str.maketrans('', '', string.punctuation))
    labels = list(sent_info['Tag'])
    span_ents = get_span_indx(labels)
    train_data.append((sentence_no_punct, {'spans':{span_key: [(span_ent[0], span_ent[1], span_ent[2]) for span_ent in span_ents]}}))

#### 3. Train spancat component with labelled entities

Now that we've formatted the training data in such a way to be able to train the spancat component, let's instantiate a blank spaCy object and add our labels to the component 🏷️.

In [14]:
#instantiate blank spaCy object
nlp = spacy.blank('en')
#add spancat component to spaCy object
nlp.add_pipe("spancat", config={"spans_key": span_key})
#get spancat component
span=nlp.get_pipe('spancat')

#Add labels to spancat component 
for label in ('GEO', 'PER', 'ORG'):
    span.add_label(label)

Now that we've done that, let's get the pipe we want to train on, initialise the spacy object and create an optimizer.

In [15]:
#get pipe you want to train on 
pipe_exceptions = ["spancat"]
unaffected_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]

# initialise spacy object 
nlp.initialize()
sgd = nlp.create_optimizer()

...and let the training begin! ⏲️

In [16]:
# initialise spacy object
nlp.initialize()
sgd = nlp.create_optimizer()

#start training spancat component 
all_losses = []
with nlp.disable_pipes(*unaffected_pipes):
    for iteration in tqdm(range(2)):
        # shufling examples  before every iteration
        random.shuffle(train_data)
        losses = {}
        # batch up the examples using spaCy's minibatch
        batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
        batch_examples = []
        for batch in batches:
            texts, annotations = zip(*batch)
            examples = Example.from_dict(nlp.make_doc(texts[0]), annotations[0])
            batch_examples.append(examples)
        #nlp.update for spacy component takes list of examples 
        nlp.update(batch_examples, losses=losses, drop=0.1, sgd=sgd)
        all_losses.append(losses['spancat'])

100%|█████████████████████████████████████████████| 2/2 [00:26<00:00, 13.27s/it]


We can invesitage the losses by printing all_losses:

In [24]:
print(all_losses)

[162956.734375, 89576.71875]


#### 4. Apply the custom trained spancat model 

Now that we've trained our model, lets apply it to a news article sentence and see what spans it predicts as people, organisations and locations. 

In [17]:
# Testing the model on sentence from the BBC news article yesterday
text = "Yes, Liz Truss would deliver the government’s 2050 net zero target in the U.K and france, but the prime minister also restated her determination to issue more oil and gas licences in the North Sea."
doc = nlp(text)

We're able to print the predicted spans and span-level confidence scores

In [18]:
#print predicted spans and entity level scores
spans = doc.spans[span_key]
for span, confidence in zip(spans, spans.attrs["scores"]):
    print(span.label_, confidence)

GEO 0.77190596
GEO 0.73341775
ORG 0.5466638
GEO 0.6097495
PER 0.5765578
GEO 0.7282142
ORG 0.7250457
PER 0.7744567
GEO 0.6433577
GEO 0.5784256
ORG 0.70230997
PER 0.54134667
ORG 0.7090536
GEO 0.90961087
ORG 0.86782706
GEO 0.58660764
GEO 0.8568609
GEO 0.51240474
GEO 0.7408936
ORG 0.596305
GEO 0.63694704
PER 0.8106931
ORG 0.5563701
ORG 0.731171
GEO 0.51841235
PER 0.5638563
GEO 0.6822313
GEO 0.5060679


Finally, we can use spaCy's `displacy` to show the predicted spans in the text.

In [None]:
#show predicted spans 
displacy.serve(doc, style="span")




Using the 'span' visualizer
Serving on http://0.0.0.0:5000 ...



And that's that! If you'd like to learn more about the spancat architecture 🏗️ and a bit more detail about the code, do refer back to the medium article. This is purely demonstrative so we're note evaluating the model or even playing around with the suggester. 