Welcome to the code behind Nesta's Medium article, "Training spaCy's spancat component in Python"! 👋

Although spaCy pushes for config-based training, sometimes you simply want to experiment with a component in a Jupyter notebook without the headache (and commitment) of setting everything up. And although there is plenty of material online on how to train a custom NER model in spaCy, there is virtually nothing on how to do the same for a custom spancat model.

Training a custom spancat model looks very similar to training a custom NER model, but there are two key differences:

1. the training data needs to be in a different format; and
2. you must convert your training data to a list of spaCy Examples to train the model.

This notebook will walk you through training a custom spancat model.

#### 0. Load Libraries and parameters
Let's first import the libraries and parameters we will need to run this notebook.

In [7]:
import os 
import spacy
import pandas as pd
import re
import string 
from spacy import displacy

import random
from spacy.util import minibatch, compounding
from spacy.training import Example
import pickle
from tqdm import tqdm
from typing import List
from spacy.pipeline.spancat import DEFAULT_SPANCAT_MODEL

We're going to store our labelled spans under the key "sc".

In [2]:
span_key = "sc"

#### 1. Load Data
Once we've loaded the relevant libraries and parameters, let's download and clean up our data. The README.md tells you how to download the data 💾.

We want to extract person, organisation and location tags, so we can simply replace the tags we don't care about with 'O'.

In [3]:
#load the data and forward-fill the missing sentence #s
#(.ffill() replaces fillna(method='ffill'), which newer pandas versions deprecate)
ner_data = (pd.read_csv("ner_datasetreference.csv", encoding='ISO-8859-1')
            .ffill())

#replace tags we don't care about with 'O'
ents_to_replace = ['tim', 'gpe', 'art', 'eve', 'nat']
bad_ents = []
for ent in list(set(ner_data.Tag)):
    if any(ent.endswith(e) for e in ents_to_replace):
        bad_ents.append(ent)
        
ner_data = ner_data.replace(bad_ents, 'O')
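
To make the filtering rule concrete, here is a minimal sketch on a hand-made list of tags. The tag values are illustrative and not read from the dataset. It applies the same suffix test as the cell above: any tag ending in one of the unwanted entity types becomes 'O'.

```python
# Suffixes of the entity types we are not interested in (same list as above).
ents_to_replace = ['tim', 'gpe', 'art', 'eve', 'nat']

# Hypothetical tags in the dataset's B-/I- style, for illustration only.
toy_tags = ['O', 'B-geo', 'B-tim', 'I-tim', 'B-per', 'I-per', 'B-gpe', 'B-org']

# Keep a tag unless it ends in an unwanted suffix, in which case use 'O'.
cleaned_tags = [
    'O' if tag.endswith(tuple(ents_to_replace)) else tag
    for tag in toy_tags
]
print(cleaned_tags)
```

Only the geo, per and org tags survive, which matches the three labels added to the spancat component later on.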

#### 2. Format data for spaCy
Now that we've loaded and replaced the non-relevant tags with 'O', let's format the data in a way that spaCy can handle. We will need to get the start and end of spans in the text. 

We will do this with the `get_span_indx` function.

In [4]:
def get_span_indx(
    labels: List[str]
) -> List[tuple]:
    """Gets span starts and ends for the spaCy spancat component.

    Every run of consecutive non-'O' labels is treated as one span.

    Returns a list of tuples where the first element is the span
    start (token index), the second element is the span end
    (token index, exclusive) and the third element is the span
    category, taken from the first label in the span.
    """
    label_indx = []
    temp_list = []
    for i, l in enumerate(labels):
        if l != 'O':
            temp_list.append(i)
        else:
            label_indx.append(temp_list)
            temp_list = []    
        if i == len(labels) - 1:
            label_indx.append(temp_list)

    clean_label_indx = [x for x in label_indx if len(x) > 0]
    
    #a span runs from its first token index to one past its last token index
    spans = [(indx[0], indx[-1] + 1) for indx in clean_label_indx]

    #label each span with the entity type of its first token, e.g. 'B-geo' -> 'GEO'
    return [(s[0], s[1], labels[s[0]].split('-')[-1].upper()) for s in spans]
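
To see what this grouping rule produces, here is a small sketch on hand-made labels. `contiguous_spans` is a hypothetical, compact re-implementation written for this illustration (it is not the notebook's function), assuming the same rule: every run of consecutive non-'O' labels becomes one span, named after its first label.

```python
from itertools import groupby
from typing import List, Tuple


def contiguous_spans(labels: List[str]) -> List[Tuple[int, int, str]]:
    """Group runs of consecutive non-'O' labels into (start, end, category)."""
    spans = []
    for is_entity, run in groupby(enumerate(labels), key=lambda pair: pair[1] != 'O'):
        if not is_entity:
            continue
        run = list(run)
        start, end = run[0][0], run[-1][0] + 1
        spans.append((start, end, labels[start].split('-')[-1].upper()))
    return spans


# Two separate entities: a one-token location and a two-token person.
print(contiguous_spans(['O', 'O', 'B-geo', 'O', 'B-per', 'I-per', 'O']))

# Caveat: adjacent entities of different types are merged into a single span.
print(contiguous_spans(['B-per', 'I-per', 'B-org', 'O']))
```

The second call shows a limitation of grouping purely on "not 'O'": a person followed directly by an organisation comes out as one PER span. If that pattern occurs in your data, splitting on `B-` prefixes would be one way to handle it.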

Now that we have a function to extract span start and end indices, we will format the training data for spaCy so that each item looks like this:

```
  ('Thousands of demonstrators have marched through London to protest the war in Iraq and demand the withdrawal of British troops from that country',
  {'spans': {'sc': [(6, 7, 'GEO'), (12, 13, 'GEO')]}})
```

The first element of the tuple is the text. The second element is a nested dictionary in which the value stored under the span_key (in this instance, "sc") is a list of tuples, each containing the span's token start, token end (exclusive) and span category.
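
One thing worth checking against your spaCy version: spaCy's documentation of the training data format describes the `spans` annotation as tuples of character offsets (start_char, end_char, label) rather than token indices. If that applies, token-based spans like the ones above would need converting first. Below is a minimal sketch of such a conversion, assuming the text is simply the tokens joined by single spaces. `token_spans_to_char_spans` is a hypothetical helper and not part of spaCy or the original notebook.

```python
from typing import List, Tuple


def token_spans_to_char_spans(
    tokens: List[str],
    token_spans: List[Tuple[int, int, str]],
) -> Tuple[str, List[Tuple[int, int, str]]]:
    """Join tokens with single spaces and map token spans to character spans."""
    starts, ends = [], []
    position = 0
    for token in tokens:
        starts.append(position)
        position += len(token)
        ends.append(position)
        position += 1  # the single space that follows each token
    text = " ".join(tokens)
    # A token span (s, e) covers tokens s .. e-1, so it ends where token e-1 ends.
    char_spans = [(starts[s], ends[e - 1], label) for s, e, label in token_spans]
    return text, char_spans


tokens = ("Thousands of demonstrators have marched through London to protest "
          "the war in Iraq").split()
text, char_spans = token_spans_to_char_spans(tokens, [(6, 7, 'GEO'), (12, 13, 'GEO')])
for start, end, label in char_spans:
    print(text[start:end], label)
```

Slicing the text with the converted offsets is a cheap sanity check that each span really covers the words you labelled.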

In [5]:
#Create spaCy compliant training data 
train_data = []
for sent, sent_info in ner_data.groupby("Sentence #"):
    #join the words, dropping the space before ? . ! " when whitespace or the end of the string follows
    sentence = re.sub(r'\s([?.!"](?:\s|$))', r'\1', " ".join(sent_info["Word"]))
    #strip all punctuation from the sentence
    sentence_no_punct = sentence.translate(str.maketrans('', '', string.punctuation))
    labels = list(sent_info['Tag'])
    #get_span_indx already returns (start, end, category) tuples
    span_ents = get_span_indx(labels)
    train_data.append((sentence_no_punct, {'spans': {span_key: span_ents}}))
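
To see what the two cleaning steps do to a sentence, here is a small sketch on a hand-made list of words (hypothetical, not taken from the dataset). It uses the same regular expression and the same punctuation-stripping call as the cell above.

```python
import re
import string

# Hypothetical tokens, one per row of the dataset's "Word" column.
words = ["Hello", ",", "world", "!"]

# Step 1: join with spaces and drop the space before a closing ? . ! or ".
sentence = re.sub(r'\s([?.!"](?:\s|$))', r'\1', " ".join(words))
print(repr(sentence))

# Step 2: remove every punctuation character.
sentence_no_punct = sentence.translate(str.maketrans('', '', string.punctuation))
print(repr(sentence_no_punct))

# The label list still has one entry per original token, punctuation included.
print(len(words), "labels vs", len(sentence_no_punct.split()), "words left")
```

Because punctuation tokens disappear from the text but keep their slot in the label list, span indices computed from the labels may no longer line up with the cleaned text, depending on how the tokenizer treats the leftover whitespace. It is worth spot-checking a few sentences that contain commas before an entity.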

#### 3. Train spancat component with labelled entities

Now that the training data is in a format the spancat component can be trained on, let's instantiate a blank spaCy pipeline and add our labels to the component 🏷️.

In [22]:
#instantiate blank spaCy object
nlp = spacy.blank('en')
config = {
    #minimum confidence threshold 
    "threshold": 0.2,
    "spans_key": span_key,
    "max_positive": None,
    "model": DEFAULT_SPANCAT_MODEL,
    #we're suggesting fixed n-gram length of up to 3 tokens
    "suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]},
}
nlp.add_pipe("spancat", config=config)
span = nlp.get_pipe('spancat')

#Add labels to spancat component 
for label in ('GEO', 'PER', 'ORG'):
    span.add_label(label)
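
Conceptually, an n-gram suggester with sizes [1, 2, 3] proposes every contiguous run of one, two or three tokens as a candidate span, and the classifier then scores each candidate. The sketch below reproduces that idea in plain Python. It is an illustration rather than spaCy's implementation, and `ngram_candidates` is a hypothetical name.

```python
from typing import List, Tuple


def ngram_candidates(n_tokens: int, sizes: List[int]) -> List[Tuple[int, int]]:
    """All (start, end) token spans whose length is one of `sizes`."""
    return [
        (start, start + size)
        for size in sizes
        for start in range(n_tokens - size + 1)
    ]


tokens = ["Liz", "Truss", "visited", "London"]
candidates = ngram_candidates(len(tokens), sizes=[1, 2, 3])
for start, end in candidates:
    print(tokens[start:end])
```

One practical consequence of this setup is that a span longer than three tokens is never proposed, so it can never be predicted. If your labelled spans are longer, the `sizes` list needs to cover them.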

Now that we've done that, let's list the pipes we don't want to update during training, initialise the spaCy pipeline and create an optimizer.

In [23]:
#only the spancat pipe should be updated; every other pipe is left untouched
pipe_exceptions = ["spancat"]
unaffected_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]

# initialise the spaCy pipeline and create an optimizer
nlp.initialize()
sgd = nlp.create_optimizer()

...and let the training begin! ⏲️

In [24]:
#start training spancat component 
all_losses = []
with nlp.disable_pipes(*unaffected_pipes):
    for iteration in tqdm(range(30)):
        # shuffle the examples before every iteration
        random.shuffle(train_data)
        losses = {}
        # batch up the examples using spaCy's minibatch, with batch sizes compounding from 4 towards 32
        batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
        batch_examples = []
        for batch in batches:
            texts, annotations = zip(*batch)
            # note: only the first (text, annotation) pair of each batch is converted to an Example
            example = Example.from_dict(nlp.make_doc(texts[0]), annotations[0])
            batch_examples.append(example)
        # nlp.update takes a list of Examples; here it is called once per iteration
        nlp.update(batch_examples, losses=losses, drop=0.1, sgd=sgd)
        all_losses.append(losses['spancat'])

100%|███████████████████████████████████████████| 30/30 [06:32<00:00, 13.08s/it]
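
In the loop above, `texts[0]` and `annotations[0]` select the first pair of each batch, and `nlp.update` is called once per iteration. If you would rather update on every example, batch by batch, one possible variant is sketched below. It is self-contained and uses two made-up sentences with character-offset spans (an assumption about the annotation format, and not the notebook's data), so the losses it prints say nothing about the real model.

```python
import random

import spacy
from spacy.training import Example
from spacy.util import minibatch

random.seed(0)
span_key = "sc"

# Hypothetical toy data; spans are (start_char, end_char, label).
toy_data = [
    ("Anna works in London", {"spans": {span_key: [(0, 4, "PER"), (14, 20, "GEO")]}}),
    ("Nesta is based in London", {"spans": {span_key: [(0, 5, "ORG"), (18, 24, "GEO")]}}),
]

nlp = spacy.blank("en")
spancat = nlp.add_pipe("spancat", config={"spans_key": span_key})
for label in ("GEO", "PER", "ORG"):
    spancat.add_label(label)
optimizer = nlp.initialize()

for iteration in range(5):
    random.shuffle(toy_data)
    losses = {}
    for batch in minibatch(toy_data, size=2):
        # Convert every (text, annotation) pair in the batch, not just the first.
        examples = [
            Example.from_dict(nlp.make_doc(text), annotations)
            for text, annotations in batch
        ]
        # One update per batch; losses accumulate over the iteration.
        nlp.update(examples, losses=losses, drop=0.1, sgd=optimizer)
    print(iteration, losses["spancat"])
```

Updating per batch means more optimizer steps per iteration, so each pass over the full dataset would likely take longer than the run shown above.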


We can investigate the losses by printing all_losses:

In [25]:
for loss in all_losses:
    print(loss)

157150.109375
84768.9296875
41765.94921875
12769.548828125
2690.6083984375
448.4676818847656
158.1262969970703
108.01145935058594
106.51181030273438
83.88938903808594
47.929771423339844
31.802108764648438
50.90104293823242
43.02567672729492
33.061344146728516
30.267887115478516
30.3626708984375
37.78573989868164
27.46986961364746
35.77145004272461
34.98078155517578
29.197343826293945
29.09355354309082
23.123306274414062
41.04826736450195
26.509939193725586
41.246124267578125
32.0553092956543
37.610897064208984
34.841835021972656


#### 4. Apply the custom trained spancat model 

Now that we've trained our model, let's apply it to a sentence from a news article and see which spans it predicts as people, organisations and locations.

In [30]:
# Testing the model on a sentence from yesterday's BBC news article
text = "Yes Liz Truss would deliver the governments 2050 net zero target in the UK and france but the prime minister also restated her determination to issue more oil and gas licences in the North Sea"
doc = nlp(text)

We can now print the predicted spans and their span-level confidence scores:

In [31]:
#print predicted spans and span-level scores
spans = doc.spans[span_key]
for span, confidence in zip(spans, spans.attrs["scores"]):
    print(span.label_, confidence)
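
The scores printed here are what the `threshold` of 0.2 in the config is compared against: candidates whose score clears the threshold are kept, the rest are dropped. A rough sketch of that filtering step, using made-up scores rather than real model output:

```python
# Hypothetical (span text, label, score) triples, for illustration only.
scored_spans = [
    ("Liz Truss", "PER", 0.91),
    ("Truss", "PER", 0.35),
    ("UK", "GEO", 0.78),
    ("net zero", "ORG", 0.07),
]
threshold = 0.2  # same value as in the config above

# Keep only the candidates whose score clears the threshold.
kept = [item for item in scored_spans if item[2] >= threshold]

# Show the survivors, most confident first.
for text, label, score in sorted(kept, key=lambda item: item[2], reverse=True):
    print(f"{text:<10} {label:<4} {score:.2f}")
```

Note that "Liz Truss" and "Truss" can both survive: unlike NER entities, spancat spans are allowed to overlap. Raising the threshold trades recall for precision.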

Finally, we can use spaCy's `displacy` to show the predicted spans in the text.

In [32]:
#show predicted spans 
displacy.render(doc, style="span")

And that's that! If you'd like to learn more about the spancat architecture 🏗️ and the code in a bit more detail, do refer back to the Medium article. This notebook is purely demonstrative, so we're not evaluating the model, tuning hyperparameters, experimenting with the suggester or using a spans-specific dataset.