Welcome to the code behind Nesta's Medium article, "Training spaCy's spancat component in Python"! 👋

Although spaCy pushes for config-based training, sometimes you simply want to experiment with a component in a Jupyter notebook without the headache (and commitment) of setting everything up. And although there is plenty of material online on how to train a custom NER model in spaCy, there is virtually nothing on how to do the same for a custom spancat model.

Training a custom spancat model looks very similar to training a custom NER model but has several key differences including:

1. the training data needs to be in a different format; 
2. you must convert your training data to a list of spaCy Examples to train the model;
3. you access the spans and span-level confidence scores differently.

This notebook will walk you through training a custom spancat model.

#### 0. Load Libraries and parameters
Let's first import the libraries and parameters we will need to run this notebook.

In [1]:
import os 
import spacy
import pandas as pd
import re
from spacy import displacy

import random
from spacy.util import minibatch, compounding
from spacy.training import Example
from tqdm import tqdm
from typing import List
from spacy.pipeline.spancat import DEFAULT_SPANCAT_MODEL

import string

We're going to instantiate a blank spaCy object and use "sc" as the key under which our labelled spans are stored:

In [2]:
#instantiate blank spaCy object
nlp = spacy.blank('en')
#define your span key name
span_key = "sc"

#### 1. Load Data
Once we've loaded the relevant libraries and parameters, let's download and clean up our data. The README.md tells you how to download the data 💾.

We want to extract person, organisation and location tags, so we can simply replace every other tag in the data with 'O'.

In [3]:
#load the data and forward-fill the missing sentence numbers
#(.ffill() is the non-deprecated equivalent of .fillna(method='ffill'))
ner_data = (pd.read_csv("data.csv", encoding='ISO-8859-1')
            .ffill())

#replace tags we don't care about with 'O'
ents_to_replace = ['tim', 'gpe', 'art', 'eve', 'nat']
bad_ents = []
for ent in list(set(ner_data.Tag)):
    if any(ent.endswith(e) for e in ents_to_replace):
        bad_ents.append(ent)
        
ner_data = ner_data.replace(bad_ents, 'O')
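
To make the suffix matching above concrete, here is a minimal pure-Python sketch of the same idea applied to a handful of tags. The tag values and the `drop_unwanted` helper are hypothetical and only there for illustration; they are not taken from the dataset or from the notebook.

```python
from typing import Iterable, List, Tuple

# Entity types we are not interested in (same list as in the cell above).
ENTS_TO_REPLACE = ("tim", "gpe", "art", "eve", "nat")


def drop_unwanted(tags: Iterable[str], unwanted: Tuple[str, ...] = ENTS_TO_REPLACE) -> List[str]:
    """Replace any tag whose entity type is in `unwanted` with 'O'."""
    # str.endswith accepts a tuple, so a single call checks every suffix.
    return ["O" if tag.endswith(unwanted) else tag for tag in tags]


# Hypothetical BIO tags, for illustration only.
example_tags = ["B-per", "I-per", "O", "B-tim", "B-geo", "I-gpe"]
cleaned_tags = drop_unwanted(example_tags)
print(cleaned_tags)
```

Only the tags ending in one of the unwanted entity types are replaced; the person and location tags pass through unchanged.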

#### 2. Format data for spaCy
Now that we've loaded and replaced the non-relevant tags with 'O', let's format the data in a way that spaCy can handle. We will need to get the start and end of spans in the text. 

We will do this with the get_span_indx function.

In [4]:
def get_span_indx(
    labels: List[str],
    words: List[str],
    sentence: str
) -> List[tuple]:
    """Gets span starts and ends for the spaCy spancat component.
        
        Returns a list of tuples where the first element of the 
        tuple is the span's character start, the second element
        is the span's character end and the third element is
        the span category. 
    """
    #gets list of indices corresponding to labelled words 
    label_indx = []
    temp_list = []

    for i, l in enumerate(labels):
        if l != 'O':
            temp_list.append(i)
        else:
            label_indx.append(temp_list)
            temp_list = []    
        if i == len(labels) - 1:
            label_indx.append(temp_list)

    clean_label_indx = [x for x in label_indx if len(x) > 0]

    spans = []
    for indx in clean_label_indx:
        if len(indx) == 1:
            span = words[indx[0]]
            label = labels[indx[0]].split('-')[-1].upper()
        else:
            span = ' '.join([words[i] for i in indx])  
            label = [labels[i].split('-')[-1].upper() for i in indx][0]
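        #remove punctuation from the span text (the sentence itself is left unchanged)
        span_clean = span.translate(str.maketrans('', '', string.punctuation))
        #skip spans that are empty after cleaning, since an empty pattern matches everywhere
        if not span_clean.strip():
            continue
        #escape the span text so it is always matched literally
        for m in re.finditer(re.escape(span_clean), sentence):
            spans.append((m.start(), m.end(), label))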
    
    return spans
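
Because `get_span_indx` locates spans by searching the sentence for the span text, a span string that occurs more than once in a sentence is matched at every occurrence. One alternative, sketched below, is to track character offsets while walking through the words. This is only a sketch under a simplifying assumption: the text is built by joining the words with single spaces, without the punctuation clean-up used in this notebook. The `bio_to_char_spans` name and the toy input are hypothetical, and unlike the function above it starts a new span at every `B-` tag.

```python
from typing import List, Tuple


def bio_to_char_spans(words: List[str], labels: List[str]) -> List[Tuple[int, int, str]]:
    """Group contiguous non-'O' labels into (start_char, end_char, LABEL) tuples.

    Assumes the text is " ".join(words), so offsets can be computed directly.
    """
    spans = []
    offset = 0        # character offset of the current word
    current = None    # [start_char, end_char, label] of the span being built
    for word, label in zip(words, labels):
        start, end = offset, offset + len(word)
        if label == "O":
            if current:
                spans.append(tuple(current))
                current = None
        elif current and not label.startswith("B-"):
            current[1] = end                  # extend the open span
        else:
            if current:                       # a new B- tag closes the open span
                spans.append(tuple(current))
            current = [start, end, label.split("-")[-1].upper()]
        offset = end + 1                      # +1 for the joining space
    if current:
        spans.append(tuple(current))
    return spans


# Hypothetical input, for illustration only.
words = ["Liz", "Truss", "visited", "London"]
labels = ["B-per", "I-per", "O", "B-geo"]
text = " ".join(words)
char_spans = bio_to_char_spans(words, labels)
print([(text[s:e], lab) for s, e, lab in char_spans])
```

The trade-off is that the offsets are only valid for the exact text they were computed against, so any later clean-up of the sentence would need to adjust them.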

Once we've written a function to extract span start and end indices, we will format the training data for spaCy such that:

```
  ('Thousands of demonstrators have marched through London to protest the war in Iraq and demand the withdrawal of British troops from that country',
  {'spans': {'sc': [(48, 54, 'GEO'), (77, 81, 'GEO')]}})
```

Here, the first element of the tuple is the text and the second element is a nested dictionary in which the value stored under the span_key (in this instance, "sc") is a list of tuples, each containing the span's character start, character end and category. 
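
As a quick sanity check of this format, the character offsets can be used to slice the text directly. The snippet below is a small sketch that uses only the example shown above:

```python
text = ("Thousands of demonstrators have marched through London to protest "
        "the war in Iraq and demand the withdrawal of British troops from that country")
annotation = {"spans": {"sc": [(48, 54, "GEO"), (77, 81, "GEO")]}}

# Slice the text with each (start, end) pair to see which characters a span covers.
covered = [(text[start:end], label) for start, end, label in annotation["spans"]["sc"]]
print(covered)  # [('London', 'GEO'), ('Iraq', 'GEO')]
```

Running the same check over your own training data is a cheap way to catch offsets that do not line up with the intended span text.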

Finally, we convert this data structure into a spaCy Example object.

In [5]:
#Create spaCy compliant training data 
train_data = []
for sent, sent_info in ner_data.groupby("Sentence #"):
    words = list(sent_info["Word"])
    #convert words to sentence and get rid of spaces between punctuation characters
    sentence = re.sub(r'\s([?.!"](?:\s|$))', r'\1', " ".join(words))
    #get labels
    labels = list(sent_info['Tag'])
    #identify character-level span starts, span ends and span categories
    span_ents = get_span_indx(labels, words, sentence)
    #create spaCy compliant spans[span_key] dict
    annotation = {'spans':{span_key: span_ents}}    
    #convert sentence and annotation into spaCy examples
    train_data.append(Example.from_dict(nlp.make_doc(sentence), annotation))
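
The `re.sub` call in the cell above is easy to misread, so here is a small standalone sketch of what it does to a joined list of words. The words are hypothetical and chosen only for illustration.

```python
import re

# Hypothetical tokens, for illustration only.
words = ["Yes", ",", "he", "left", "."]
joined = " ".join(words)

# Same pattern as in the cell above: drop the whitespace before ?, ., ! or "
# when that character is followed by whitespace or the end of the string.
sentence = re.sub(r'\s([?.!"](?:\s|$))', r'\1', joined)
print(repr(joined), "->", repr(sentence))
```

Note that the comma is not in the character class, so the space before it is left in place; only the space before the final full stop is removed.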

#### 3. Train spancat component with labelled entities

Now that the training data is in a format that the spancat component can be trained on, let's add a spancat component to our blank spaCy object and add our labels to the component 🏷️.

In [6]:
#spancat config 
config = {
    #this refers to the minimum probability to consider a prediction positive
    "threshold": 0.5,
    #the span key refers to the key in doc.spans 
    "spans_key": span_key,
    #this refers to the maximum number of labels to consider positive per span
    "max_positive": None,
    #the model, which is given a list of documents together with the (start, end) indices of the candidate spans
    "model": DEFAULT_SPANCAT_MODEL,
    #a function that suggests candidate spans; this suggester proposes every n-gram of 1, 2 or 3 tokens
    "suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]},
}
#add spancat component to nlp object
nlp.add_pipe("spancat", config=config)
#get spancat component (named spancat so it is not shadowed by the span loop variable used later)
spancat = nlp.get_pipe('spancat')

#Add labels to spancat component 
for label in ('GEO', 'PER', 'ORG'):
    spancat.add_label(label)
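
To get a feel for what the n-gram suggester in the config above hands to the model, the sketch below enumerates every candidate span of 1, 2 or 3 tokens for a short sentence. It is a conceptual illustration in plain Python, not spaCy's implementation; `ngram_candidates` and the toy tokens are hypothetical.

```python
from typing import List, Tuple


def ngram_candidates(n_tokens: int, sizes: List[int]) -> List[Tuple[int, int]]:
    """Return (start, end) token offsets for every n-gram of the given sizes."""
    candidates = []
    for size in sizes:
        for start in range(n_tokens - size + 1):
            candidates.append((start, start + size))
    return candidates


# Hypothetical tokens, for illustration only.
tokens = ["Liz", "Truss", "visited", "London"]
candidates = ngram_candidates(len(tokens), sizes=[1, 2, 3])
print(len(candidates))  # 4 unigrams + 3 bigrams + 2 trigrams = 9
print([" ".join(tokens[s:e]) for s, e in candidates])
```

One consequence of this design is that a span longer than the largest size (here, 3 tokens) is never proposed, so it can never be predicted; the `sizes` list should cover the longest spans you expect in your data.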

Now that we've done that, let's pick out the pipe we want to train, initialise the spaCy object and create an optimizer.

In [7]:
#name the pipe we want to train; every other pipe will be disabled during training
pipe_exceptions = ["spancat"]
unaffected_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]

# initialise spacy object 
nlp.initialize()
sgd = nlp.create_optimizer()

...and let the training begin! ⏲️

In [8]:
#start training the spancat component 
all_losses = []
#select_pipes(disable=...) replaces the deprecated disable_pipes in spaCy v3
with nlp.select_pipes(disable=unaffected_pipes):
    for iteration in tqdm(range(10)):
        # shuffling examples before every iteration
        random.shuffle(train_data)
        losses = {}
        batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
        for batch in batches:
            nlp.update(list(batch), losses=losses, drop=0.1, sgd=sgd)
        print("epoch: {} Losses: {}".format(iteration, str(losses)))
        all_losses.append(losses['spancat'])

 10%|████▎                                      | 1/10 [03:40<33:02, 220.26s/it]

epoch: 0 Losses: {'spancat': 21944.50796819647}


 20%|████████▌                                  | 2/10 [07:27<29:56, 224.57s/it]

epoch: 1 Losses: {'spancat': 16376.95999833569}


 30%|████████████▉                              | 3/10 [11:09<26:01, 223.12s/it]

epoch: 2 Losses: {'spancat': 14718.516249835026}


 40%|█████████████████▏                         | 4/10 [14:50<22:13, 222.30s/it]

epoch: 3 Losses: {'spancat': 13677.246304473534}


 50%|█████████████████████▌                     | 5/10 [18:30<18:27, 221.54s/it]

epoch: 4 Losses: {'spancat': 12934.318437874244}


 60%|█████████████████████████▊                 | 6/10 [22:11<14:45, 221.31s/it]

epoch: 5 Losses: {'spancat': 12232.822345284116}


 70%|██████████████████████████████             | 7/10 [25:57<11:08, 222.96s/it]

epoch: 6 Losses: {'spancat': 11644.008111855655}


 80%|██████████████████████████████████▍        | 8/10 [29:45<07:28, 224.45s/it]

epoch: 7 Losses: {'spancat': 11169.782576366852}


 90%|██████████████████████████████████████▋    | 9/10 [33:32<03:45, 225.39s/it]

epoch: 8 Losses: {'spancat': 10808.241575563407}


100%|██████████████████████████████████████████| 10/10 [37:24<00:00, 224.47s/it]

epoch: 9 Losses: {'spancat': 10357.203760585631}





**Note:** We are only training for 10 iterations to save time, which is unlikely to yield great results. A higher number of iterations, such as 30, would be better.  
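
The `compounding(4.0, 32.0, 1.001)` call in the training loop controls how the batch size grows over time. The sketch below reimplements the idea in plain Python to show the shape of the schedule: start at 4, multiply by 1.001 after every batch, and never exceed 32. It is an illustration of the concept, assuming this is how the schedule behaves, rather than spaCy's own code; `compounding_sketch` is a hypothetical name.

```python
from itertools import islice
from typing import Iterator


def compounding_sketch(start: float, stop: float, compound: float) -> Iterator[float]:
    """Yield start, start*compound, start*compound**2, ... capped at stop."""
    value = start
    while True:
        yield min(value, stop)
        value *= compound


# Look at the first 3000 values of the schedule.
sizes = list(islice(compounding_sketch(4.0, 32.0, 1.001), 3000))
print(sizes[0], round(sizes[1000], 2), sizes[-1])
```

With a growth factor this close to 1, the schedule needs on the order of two thousand batches to climb from 4 to the cap of 32, so early updates use small batches and later ones use larger batches.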

We can investigate the losses by printing all_losses:

In [9]:
for loss in all_losses:
    print(loss)

21944.50796819647
16376.95999833569
14718.516249835026
13677.246304473534
12934.318437874244
12232.822345284116
11644.008111855655
11169.782576366852
10808.241575563407
10357.203760585631
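
Raw loss values are hard to compare at a glance, so one option is to look at the relative drop from each epoch to the next. The sketch below does this with the losses printed above; `rel_drop` is a hypothetical variable name.

```python
# Per-epoch spancat losses copied from the output above.
all_losses = [
    21944.50796819647, 16376.95999833569, 14718.516249835026,
    13677.246304473534, 12934.318437874244, 12232.822345284116,
    11644.008111855655, 11169.782576366852, 10808.241575563407,
    10357.203760585631,
]

# Relative drop from one epoch to the next, as a percentage.
rel_drop = [100 * (prev - curr) / prev for prev, curr in zip(all_losses, all_losses[1:])]
for epoch, drop in enumerate(rel_drop, start=1):
    print(f"epoch {epoch}: {drop:.1f}% lower than the previous epoch")
```

The drop shrinks over time but is still positive at the final epoch, which is consistent with the note above that more iterations would likely help.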


#### 4. Apply the custom trained spancat model 

Now that we've trained our model, let's apply it to a sentence from yesterday's news article on Liz Truss and see which spans it predicts as people, organisations and locations. 

In [10]:
# Testing the model on sentence from the BBC news article yesterday
text = "Yes, Liz Truss would deliver the governments 2050 net zero target in the United Kingdom, but the prime minister also restated her determination to issue more oil and gas licences in the North Sea."
doc = nlp(text)

We can now print the predicted spans and their span-level confidence scores:

In [11]:
#print predicted span labels and span-level confidence scores
spans = doc.spans[span_key]
for span, confidence in zip(spans, spans.attrs["scores"]):
    print(span.label_, confidence)

PER 0.9257805
GEO 0.9826538
GEO 0.98000926
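
Since every predicted span comes with its own score, it is straightforward to apply a stricter cut-off after prediction than the `threshold` of 0.5 set in the config. The sketch below works on plain `(label, confidence)` pairs copied from the output above; `filter_by_confidence` and the 0.95 cut-off are hypothetical choices for illustration.

```python
from typing import List, Tuple

# (label, confidence) pairs copied from the output above.
predictions = [("PER", 0.9257805), ("GEO", 0.9826538), ("GEO", 0.98000926)]


def filter_by_confidence(preds: List[Tuple[str, float]], min_score: float) -> List[Tuple[str, float]]:
    """Keep only predictions whose confidence is at least `min_score`."""
    return [(label, score) for label, score in preds if score >= min_score]


strict = filter_by_confidence(predictions, min_score=0.95)
print(strict)
```

With this cut-off only the two GEO predictions remain, which shows how a post-hoc threshold trades recall for precision without retraining.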


Finally, we can use spaCy's `displacy` to show the predicted spans in the text.

In [12]:
#show predicted spans 
displacy.render(doc, style="span")

And that's that! If you'd like to learn more about the spancat architecture 🏗️ and see the code explained in a bit more detail, do refer back to the Medium article. This notebook is purely demonstrative: we are not training the model for a large number of iterations, evaluating it, tuning hyperparameters, experimenting with the suggester or using a spans-specific dataset.  