Welcome to the code behind Nesta's Medium article, "Training spaCy's spancat component in Python"! 👋

Although spaCy pushes for config-based training, sometimes you simply want to experiment with a component in a Jupyter notebook without the headache (and commitment) of setting everything up. And although there is plenty online on how to train a custom NER model in spaCy, there is virtually nothing on how to do the same for a custom spancat model.

Training a custom spancat model looks very similar to training a custom NER model, but it differs in three key ways:

1. the training data needs to be in a different format; 
2. you must convert your training data to a list of spaCy Examples to train the model;
3. you access the spans and span-level confidence scores differently.
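
As a rough illustration of the first difference, here is how the same sentence might be annotated for each component. The sentence and its offsets are invented for this sketch. The `entities` key is the one spaCy's NER training format uses, while spancat reads from `spans` under a key of your choosing ("sc" here). Unlike entities, spans are allowed to overlap.

```python
# Illustrative sentence; the offsets below were worked out by hand for this sketch.
text = "Bank of England staff met in London"

# NER-style annotation: a flat list of non-overlapping (start, end, label) tuples.
ner_annotation = {"entities": [(0, 15, "ORG"), (29, 35, "GEO")]}

# spancat-style annotation: spans grouped under a key and allowed to overlap.
spancat_annotation = {
    "spans": {
        "sc": [
            (0, 15, "ORG"),   # "Bank of England"
            (8, 15, "GEO"),   # "England", nested inside the ORG span
            (29, 35, "GEO"),  # "London"
        ]
    }
}

# Print each span's label next to the text it covers.
for start, end, label in spancat_annotation["spans"]["sc"]:
    print(label, repr(text[start:end]))
```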

This notebook will walk you through training a custom spancat model.

#### 0. Load Libraries and parameters
Let's first import the libraries and parameters we will need to run this notebook.

In [5]:
import os 
import spacy
import pandas as pd
import re
from spacy import displacy

import random
from spacy.util import minibatch, compounding
from spacy.training import Example
from tqdm import tqdm
from typing import List
from spacy.pipeline.spancat import DEFAULT_SPANCAT_MODEL

We're going to instantiate a blank spaCy object and use "sc" as the name of the key that stores our labelled spans:

In [6]:
#instantiate blank spaCy object
nlp = spacy.blank('en')
#define your span key name
span_key = "sc"

#### 1. Load Data
Once we've loaded the relevant libraries and parameters, let's download and clean up our data. The README.md tells you how to download the data 💾.

We want to extract person, organisation and location tags, so we simply replace every other tag in the data with 'O'.

In [11]:
#load the data and forward-fill the missing sentence numbers
ner_data = (pd.read_csv("data.csv", encoding='ISO-8859-1')
            .ffill())

#replace tags we don't care about with 'O'
ents_to_replace = ['tim', 'gpe', 'art', 'eve', 'nat']
bad_ents = []
for ent in list(set(ner_data.Tag)):
    if any(ent.endswith(e) for e in ents_to_replace):
        bad_ents.append(ent)
        
ner_data = ner_data.replace(bad_ents, 'O')
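
To make the filtering concrete, here is a minimal sketch of the same suffix check on a handful of tags, without pandas. The tag list is hypothetical and assumes the dataset uses a `B-`/`I-` prefix followed by a lowercase entity type, which is what the `endswith` check above implies.

```python
# Hypothetical sample of tags, assuming a B-/I- prefix plus a lowercase entity type.
tags = ["O", "B-geo", "I-geo", "B-gpe", "B-per", "I-per", "B-org", "B-tim", "I-art"]

ents_to_replace = ["tim", "gpe", "art", "eve", "nat"]

# Keep geo/per/org tags and collapse every other tag to 'O'.
cleaned = ["O" if any(tag.endswith(e) for e in ents_to_replace) else tag for tag in tags]
print(cleaned)
```

Only the `geo`, `per` and `org` tags survive, which matches the three labels added to the spancat component later on.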

#### 2. Format data for spaCy
Now that we've loaded and replaced the non-relevant tags with 'O', let's format the data in a way that spaCy can handle. We will need to get the start and end of spans in the text. 

We will do this with the get_span_indx function.

In [12]:
def get_span_indx(
    labels: List[str],
    words: List[str],
    sentence: str
) -> List[tuple]:
    """Gets span starts and ends for the spaCy spancat component.
        
        Returns a list of tuples where the first element of the 
        tuple is the span's character start, the second element is
        the span's character end and the third element is the
        span category. 
    """
    label_indx = []
    temp_list = []

    for i, l in enumerate(labels):
        if l != 'O':
            temp_list.append(i)
        else:
            label_indx.append(temp_list)
            temp_list = []    
        if i == len(labels) - 1:
            label_indx.append(temp_list)

    clean_label_indx = [x for x in label_indx if len(x) > 0]

    spans = []
    for indx in clean_label_indx:
        if len(indx) == 1:
            span = words[indx[0]]
            label = labels[indx[0]].split('-')[-1].upper()
        else:
            span = ' '.join([words[i] for i in indx])  
            label = [labels[i].split('-')[-1].upper() for i in indx][0]
        #escape the span text so regex metacharacters are matched literally
        for m in re.finditer(re.escape(span), sentence):
            spans.append((m.start(), m.end(), label))
    
    return spans 
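
One thing to watch when locating span text with regular expressions: characters such as `.` or `(` are regex metacharacters. The sketch below, on an invented sentence, shows how an unescaped pattern can over-match or fail to compile, and how `re.escape` avoids both.

```python
import re

# Invented sentence for illustration only.
sentence = "The U.S. and the UK signed; the USSR did not."
span = "U.S."

# Unescaped, each '.' matches any character, so "USSR" is matched as well.
unescaped = [(m.start(), m.end()) for m in re.finditer(span, sentence)]

# Escaped, the span text is matched literally.
escaped = [(m.start(), m.end()) for m in re.finditer(re.escape(span), sentence)]

print(unescaped)
print(escaped)

# An unbalanced bracket is not a valid pattern at all unless it is escaped.
try:
    re.compile("(Reuters")
    compiled = True
except re.error:
    compiled = False
print(compiled)
```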

Once we've written a function to extract span start and end indices, we format the training data for spaCy so that each item looks like this:

```
  ('Thousands of demonstrators have marched through London to protest the war in Iraq and demand the withdrawal of British troops from that country',
  {'spans': {'sc': [(48, 54, 'GEO'), (77, 81, 'GEO')]}})
```

The first element of the tuple is the text. The second element is a nested dictionary in which the value stored under the span_key (in this instance, "sc") is a list of tuples, each containing the span's character start, character end and span category.
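
A quick way to sanity-check offsets like these is to slice the text with them. This small sketch uses the example above and relies only on the fact that the offsets are end-exclusive, like Python slices.

```python
text = ("Thousands of demonstrators have marched through London to protest "
        "the war in Iraq and demand the withdrawal of British troops from that country")
annotation = {"spans": {"sc": [(48, 54, "GEO"), (77, 81, "GEO")]}}

# Slice the text with each (start, end) pair to recover the labelled span.
recovered = [(label, text[start:end]) for start, end, label in annotation["spans"]["sc"]]
for label, span_text in recovered:
    print(label, span_text)
```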

Finally, we convert this data structure into a spaCy Example object.

In [13]:
#Create spaCy compliant training data 
train_data = []
for sent, sent_info in ner_data.groupby("Sentence #"):
    words = list(sent_info["Word"])
    #convert words to sentence
    sentence = re.sub(r'\s([?.!"](?:\s|$))', r'\1', " ".join(words))
    #get labels
    labels = list(sent_info['Tag'])
    #identify character span starts, span ends and span categories
    try:
        span_ents = get_span_indx(labels, words, sentence)
    except re.error:
        #skip sentences whose special characters break regex matching
        continue
    #create spaCy compliant spans[span_key] dict
    annotation = {'spans':{span_key: [(span_ent[0], span_ent[1], span_ent[2]) for span_ent in span_ents]}}
    #convert sentence and annotation into spaCy example
    train_data.append(Example.from_dict(nlp.make_doc(sentence), annotation))
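
To confirm that an annotation survives the conversion, you can inspect the gold spans stored on the Example's reference doc. This is a minimal sketch on the single sentence shown earlier; `nlp_demo` is a separate blank pipeline created only for this check.

```python
import spacy
from spacy.training import Example

# Separate blank pipeline, used only for this check.
nlp_demo = spacy.blank("en")

text = ("Thousands of demonstrators have marched through London to protest "
        "the war in Iraq and demand the withdrawal of British troops from that country")
annotation = {"spans": {"sc": [(48, 54, "GEO"), (77, 81, "GEO")]}}

# Build the Example, then read the gold spans back from the reference doc.
example = Example.from_dict(nlp_demo.make_doc(text), annotation)
gold_spans = [(span.text, span.label_) for span in example.reference.spans["sc"]]
print(gold_spans)
```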

#### 3. Train spancat component with labelled entities

Now that the training data is in a format the spancat component can train on, let's add a spancat component to our blank spaCy object and add our labels to it 🏷️.

In [14]:
#spancat config 
config = {
    #minimum confidence threshold 
    "threshold": 0.5,
    #span key 
    "spans_key": span_key,
    "max_positive": None,
    "model": DEFAULT_SPANCAT_MODEL,
    #our suggester is fixed n-gram length of up to 3 tokens
    "suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]},
}
#add spancat component to nlp object
nlp.add_pipe("spancat", config=config)
#get spancat component 
spancat = nlp.get_pipe('spancat')

#Add labels to spancat component 
for label in ('GEO', 'PER', 'ORG'):
    spancat.add_label(label)
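
To get a feel for what an n-gram suggester with sizes 1 to 3 proposes, here is a pure-Python sketch. The helper `ngram_candidates` is hypothetical and is not spaCy's implementation; it only enumerates the token windows that such a suggester would offer to the model as candidate spans.

```python
from typing import List, Tuple

def ngram_candidates(tokens: List[str], sizes: List[int]) -> List[Tuple[int, int]]:
    """Return (start, end) token offsets for every n-gram of the given sizes."""
    candidates = []
    for size in sizes:
        # A window of `size` tokens can start at any position that leaves room for it.
        for start in range(len(tokens) - size + 1):
            candidates.append((start, start + size))
    return candidates

tokens = ["Liz", "Truss", "visited", "the", "North", "Sea"]
candidates = ngram_candidates(tokens, sizes=[1, 2, 3])

print(len(candidates))
print([" ".join(tokens[s:e]) for s, e in candidates if e - s == 2])
```

One consequence of this choice is that a span longer than three tokens is never proposed, so the model cannot predict it.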

Now that we've done that, let's work out which pipes to leave untouched during training, initialise the spaCy object and create an optimizer.

In [15]:
#list every pipe other than spancat so it can be disabled during training
pipe_exceptions = ["spancat"]
unaffected_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]

#initialise the spaCy object and create an optimizer
nlp.initialize()
sgd = nlp.create_optimizer()

...and let the training begin! ⏲️

In [12]:
#start training the spancat component 
all_losses = []
with nlp.disable_pipes(*unaffected_pipes):
    for iteration in tqdm(range(30)):
        # shuffling examples before every iteration
        random.shuffle(train_data)
        losses = {}
        batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
        for batch in batches:
            nlp.update(list(batch), losses=losses, drop=0.1, sgd=sgd)
        print("epoch: {} Losses: {}".format(iteration, str(losses)))
        all_losses.append(losses['spancat'])

epoch: 0 Losses: {'spancat': 23819.832373756915}
epoch: 1 Losses: {'spancat': 17845.14417686942}
epoch: 2 Losses: {'spancat': 16125.311203195335}
epoch: 3 Losses: {'spancat': 14996.864675723875}
epoch: 4 Losses: {'spancat': 14163.882134490937}
epoch: 5 Losses: {'spancat': 13418.770895299465}
epoch: 6 Losses: {'spancat': 12951.496103297977}
epoch: 7 Losses: {'spancat': 12407.294817013433}
epoch: 8 Losses: {'spancat': 11920.983098373748}
epoch: 9 Losses: {'spancat': 11526.122606375488}
epoch: 10 Losses: {'spancat': 11231.918217996834}
epoch: 11 Losses: {'spancat': 10934.958683155586}
epoch: 12 Losses: {'spancat': 10697.222331435423}
epoch: 13 Losses: {'spancat': 10371.370365341427}
epoch: 14 Losses: {'spancat': 10129.045628480642}
epoch: 15 Losses: {'spancat': 9980.203197981959}
epoch: 16 Losses: {'spancat': 9721.756982092862}
epoch: 17 Losses: {'spancat': 9501.349641645858}
epoch: 18 Losses: {'spancat': 9318.827465295537}
epoch: 19 Losses: {'spancat': 9166.11591950753}
epoch: 20 Losses: {'spancat': 9001.56578258684}
epoch: 21 Losses: {'spancat': 8874.431672896259}
epoch: 22 Losses: {'spancat': 8765.872489455149}
epoch: 23 Losses: {'spancat': 8603.444481174229}
epoch: 24 Losses: {'spancat': 8482.731262199755}
epoch: 25 Losses: {'spancat': 8357.392183293312}
epoch: 26 Losses: {'spancat': 8203.214589320942}
epoch: 27 Losses: {'spancat': 8110.240328924367}
epoch: 28 Losses: {'spancat': 8024.647758835301}
epoch: 29 Losses: {'spancat': 7984.9958213473}

100%|████████████████████████████████████████| 30/30 [1:57:17<00:00, 234.57s/it]
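
The batch size in the training loop is not fixed: `compounding(4.0, 32.0, 1.001)` starts at 4 and grows by 0.1% per batch until it is capped at 32. The sketch below re-implements that schedule in plain Python to show its shape. The helper `compounding_schedule` is hypothetical, written from spaCy's documented behaviour rather than copied from its source.

```python
from itertools import islice
from typing import Iterator

def compounding_schedule(start: float, stop: float, compound: float) -> Iterator[float]:
    """Yield start, start*compound, start*compound**2, ... capped at stop."""
    value = start
    while True:
        yield min(value, stop)
        value *= compound

# First few batch sizes for the settings used in the training loop.
first_sizes = list(islice(compounding_schedule(4.0, 32.0, 1.001), 3))
print(first_sizes)

# A faster-growing toy schedule makes the cap easy to see.
toy_sizes = list(islice(compounding_schedule(1.0, 10.0, 1.5), 7))
print(toy_sizes)
```

With a rate as small as 1.001, the batch size stays close to 4 for a long time, so early epochs consist of many small updates.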





We can investigate the losses by printing all_losses:

In [13]:
for loss in all_losses:
    print(loss)

23819.832373756915
17845.14417686942
16125.311203195335
14996.864675723875
14163.882134490937
13418.770895299465
12951.496103297977
12407.294817013433
11920.983098373748
11526.122606375488
11231.918217996834
10934.958683155586
10697.222331435423
10371.370365341427
10129.045628480642
9980.203197981959
9721.756982092862
9501.349641645858
9318.827465295537
9166.11591950753
9001.56578258684
8874.431672896259
8765.872489455149
8603.444481174229
8482.731262199755
8357.392183293312
8203.214589320942
8110.240328924367
8024.647758835301
7984.9958213473
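
To put a number on how far the loss fell and how much it had flattened by the end, here is a small sketch using three of the values printed above. The variable names are only for illustration.

```python
# Losses copied from the printed output: first epoch, second-to-last and last.
first_loss = 23819.832373756915
previous_loss = 8024.647758835301
last_loss = 7984.9958213473

# Overall relative decrease across the 30 epochs.
relative_drop = (first_loss - last_loss) / first_loss
# Relative decrease contributed by the final epoch alone.
final_step_drop = (previous_loss - last_loss) / previous_loss

print(f"Overall decrease: {relative_drop:.1%}")
print(f"Final epoch decrease: {final_step_drop:.1%}")
```

Keep in mind that this is the summed training loss, so a falling value says nothing about how well the model generalises.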


#### 4. Apply the custom trained spancat model 

Now that we've trained our model, let's apply it to a sentence from yesterday's news article on Liz Truss and see which spans it predicts as people, organisations and locations. 

In [17]:
#test the model on a sentence from yesterday's BBC news article
text = "Yes, Liz Truss would deliver the governments 2050 net zero target in the United Kingdom, but the prime minister also restated her determination to issue more oil and gas licences in the North Sea."
doc = nlp(text)

We can print the predicted spans along with their span-level confidence scores:

In [18]:
#print predicted span labels and span-level confidence scores
spans = doc.spans[span_key]
for span, confidence in zip(spans, spans.attrs["scores"]):
    print(span.label_, confidence)

PER 0.9284754
GEO 0.98638225
GEO 0.98024577
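
Because the scores live in `spans.attrs["scores"]`, in the same order as the spans, you can filter predictions with a stricter cut-off than the component's threshold. The sketch below builds a span group by hand so that it runs without a trained model: the spans and scores are made up for illustration, and `nlp_demo`, `doc_demo` and `min_confidence` are hypothetical names.

```python
import spacy
from spacy.tokens import Span

nlp_demo = spacy.blank("en")
doc_demo = nlp_demo("Liz Truss visited the North Sea.")

# Hand-built spans and made-up scores, standing in for real spancat predictions.
doc_demo.spans["sc"] = [
    Span(doc_demo, 0, 2, label="PER"),
    Span(doc_demo, 4, 6, label="GEO"),
]
doc_demo.spans["sc"].attrs["scores"] = [0.93, 0.41]

# Keep only the spans whose score clears a stricter cut-off.
min_confidence = 0.8
group = doc_demo.spans["sc"]
confident = [
    (span.text, span.label_, score)
    for span, score in zip(group, group.attrs["scores"])
    if score >= min_confidence
]
print(confident)
```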


Finally, we can use spaCy's `displacy` to show the predicted spans in the text.

In [19]:
#show predicted spans 
displacy.render(doc, style="span")

And that's that! If you'd like to learn more about the spancat architecture 🏗️ and get a bit more detail on the code, do refer back to the Medium article. This notebook is purely demonstrative, so we're not evaluating the model, tuning hyperparameters, experimenting with the suggester or using a spans-specific dataset.