## Train a Named Entity Recognition model with spaCy
This project shows how to extract information from text documents (here, Bitcoin-related tweets) using transfer learning with a pretrained model from the spaCy library.


In [1]:
! pip install spacy
! pip install mlflow
! pip install scikit-learn




In [33]:
# import libraries
import json
import os
import subprocess
import sys

import mlflow
import mlflow.spacy
import spacy
from sklearn.model_selection import train_test_split
from spacy.tokens import DocBin
from spacy.util import filter_spans
from tqdm import tqdm

# Set to True to track this run (config, training metrics, model) in MLflow.
enable_mlflow = False

if enable_mlflow:
    # Close any run left open by a previous execution of this notebook, so
    # re-running the cell does not hit MLflow's "run already active" error.
    mlflow.end_run()
    # Honour MLFLOW_TRACKING_URI when it is set; otherwise use a local server.
    mlflow.set_tracking_uri(os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000"))
    mlflow.start_run()

In [34]:
# Load the annotated tweets. Each record looks like
# {'id': ..., 'text': ..., 'label': [[start, end, LABEL], ...], 'Comments': [...]}
# (see the printed sample below).
# JSON is UTF-8 and the tweets contain non-ASCII text such as emoji, so set the
# encoding explicitly instead of relying on the platform's default.
with open('data/bitcoin_tweets_annotated.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

print(data[0])

{'id': 12887, 'text': 'Blue Ridge Bank and Your Corp shares halted by NYSE after #bitcoin ATM announcement https://t.co/xaaZmaJKiV @MyBlueRidgeBank… https://t.co/sgBxMkP1SI', 'label': [[0, 15, 'ORG'], [44, 52, 'CRYPTO']], 'Comments': []}


### Prepare training data

In [35]:
# Convert each annotated record into {'text': str, 'entities': [(start, end, label), ...]},
# the shape consumed by createDocBin below.
training_data = {
    # Derive the label set from the annotations so it always matches the
    # labels actually present in the data.
    'classes': sorted({annotation[2] for example in data for annotation in example['label']}),
    'annotations': [
        {
            'text': example['text'],
            'entities': [
                (annotation[0], annotation[1], annotation[2])
                for annotation in example['label']
            ],
        }
        for example in data
    ],
}

print(training_data['annotations'][1])

{'text': '😎 Today, that\'s this #Thursday, we will do a "🎬 Take 2" with our friend @LeoWandersleb, #Btc #wallet #security expe… https://t.co/go6aDgRml5', 'entities': [(90, 94, 'CRYPTO')]}


In [36]:
# Blank English pipeline: used only for tokenization when building the DocBins.
nlp = spacy.blank("en")

def createDocBin(data: list) -> DocBin:
    """Convert annotated rows into a spaCy DocBin for `spacy train`.

    Args:
        data: rows shaped like {'text': str, 'entities': [(start, end, label), ...]},
            where start/end are character offsets into `text`.

    Returns:
        A DocBin with one Doc per row, whose `ents` are set from the annotations.

    Some annotations are dropped:
    - ones whose offsets cannot be mapped onto tokens (even after contracting
      to the tokens inside the span);
    - ones that overlap a longer entity.
    The number dropped is printed, so data problems are visible instead of
    silently shrinking the training set.
    """
    doc_bin = DocBin()
    n_total = 0
    n_kept = 0
    for training_row in tqdm(data):
        doc = nlp.make_doc(training_row['text'])
        spans = []
        for start, end, label in training_row['entities']:
            n_total += 1
            # "contract" keeps only tokens lying fully inside [start, end).
            # char_span returns None when no valid token span can be formed.
            span = doc.char_span(start, end, label=label, alignment_mode="contract")
            if span is not None:
                spans.append(span)
        # doc.ents rejects overlapping spans; filter_spans keeps the longest.
        entities = filter_spans(spans)
        n_kept += len(entities)
        doc.ents = entities
        doc_bin.add(doc)
    if n_kept < n_total:
        print(f"Dropped {n_total - n_kept} of {n_total} entity annotations "
              "(misaligned with token boundaries or overlapping).")
    return doc_bin

# Hold out 20% of the examples as the dev set that `spacy train` evaluates on.
# A fixed random_state keeps the split, and therefore the reported scores,
# reproducible across kernel restarts.
train, test = train_test_split(training_data['annotations'], test_size=0.2, random_state=42)

doc_bin_train = createDocBin(train)
doc_bin_test = createDocBin(test)
# These files are passed to `spacy train` as --paths.train / --paths.dev.
doc_bin_train.to_disk("train_data.spacy")
doc_bin_test.to_disk("test_data.spacy")

100%|██████████| 56/56 [00:00<00:00, 1817.87it/s]
100%|██████████| 14/14 [00:00<00:00, 2123.39it/s]


### Train the model with the spaCy CLI

In [37]:
# Expand the partial base_config.cfg into a complete config.cfg.
# spaCy is run via the kernel's own interpreter (sys.executable) rather than
# whatever `python` comes first on PATH, so the CLI uses this notebook's
# environment.
fill_config = subprocess.run(
    [sys.executable, "-m", "spacy", "init", "fill-config", "base_config.cfg", "config.cfg"],
    capture_output=True,
)
print(fill_config.stdout.decode(errors="replace"))
if fill_config.returncode != 0:
    # Stop here: training would otherwise run on a stale or missing config.cfg.
    raise RuntimeError("Filling config failed:\n" + fill_config.stderr.decode(errors="replace"))

if enable_mlflow:
    # Log the configuration file
    mlflow.log_artifact("config.cfg")

# Train the model. The captured stdout (kept as bytes) contains spaCy's
# per-evaluation metrics table.
train_command = [
    sys.executable, "-m", "spacy", "train", "config.cfg",
    "--output", "./",
    "--paths.train", "./train_data.spacy",
    "--paths.dev", "./test_data.spacy",
]
process = subprocess.run(train_command, capture_output=True)
stdout, stderr = process.stdout, process.stderr

if process.returncode != 0:
    # Raise instead of printing, so "Run All" stops here. Otherwise later cells
    # could load a model-best left over from an earlier run.
    raise RuntimeError("Training failed:\n" + stderr.decode(errors="replace"))
print("Training completed successfully.")


  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
[38;5;2m✔ Auto-filled config with all values[0m
[38;5;2m✔ Saved config[0m
config.cfg
You can now add your data and train your pipeline:
python -m spacy train config.cfg --paths.train ./train.spacy --paths.dev ./dev.spacy
Training completed successfully.


In [11]:
# Parse spaCy's training table from the captured `stdout` and, when MLflow is
# enabled, log every evaluation row. The table looks like:
#
# E    #       LOSS TOK2VEC  LOSS NER  ENTS_F  ENTS_P  ENTS_R  SCORE
#   0       0          0.00     48.50    0.00    0.00    0.00    0.00
# 202    1200         77.08     22.51   70.37   73.08   67.86    0.70
#
# METRIC_COLUMNS matches the columns shown above. TODO: update it if the
# pipeline in config.cfg changes and the table header differs.
METRIC_COLUMNS = ["LOSS_TOK2VEC", "LOSS_NER", "ENTS_F", "ENTS_P", "ENTS_R", "SCORE"]


def parse_training_table(text: str) -> list:
    """Return (line, step, metrics) for each numeric row of spaCy's training table.

    - `line` is the raw table row.
    - `step` is the integer from the '#' column.
    - `metrics` maps each name in METRIC_COLUMNS to its float value.

    Lines before the header row containing "LOSS" are ignored. Inside the
    table, blank lines, divider rows and any line that does not parse as
    epoch, step and one value per metric are skipped. Parsing stops at the
    "Saved pipeline" message.
    """
    rows = []
    in_table = False
    for line in text.splitlines():
        if not in_table:
            in_table = "LOSS" in line
            continue
        if "Saved pipeline" in line:
            break
        values = line.split()
        # Expect: epoch, step, then one value per metric column.
        if len(values) != 2 + len(METRIC_COLUMNS):
            continue
        try:
            step = int(values[1])
            metrics = {name: float(value) for name, value in zip(METRIC_COLUMNS, values[2:])}
        except ValueError:
            continue  # e.g. a divider row under the header
        rows.append((line, step, metrics))
    return rows


training_rows = parse_training_table(stdout.decode())
if not training_rows:
    print("No metric rows found in training output; check METRIC_COLUMNS against the table header.")
for line, step, metrics in training_rows:
    print(line)
    if enable_mlflow:
        mlflow.log_metrics(metrics, step=step)

  0       0          0.00     49.79    0.00    0.00    0.00    0.00
 19     200        118.30   1782.31   67.74   70.00   65.62    0.68
 43     400         28.18     19.29   51.85   63.64   43.75    0.52
 73     600         29.32     13.99   61.02   66.67   56.25    0.61
109     800         10.31      2.81   57.63   62.96   53.12    0.58
155    1000          0.02      0.01   63.33   67.86   59.38    0.63
208    1200        111.20     22.25   63.33   67.86   59.38    0.63
274    1400        332.04     85.99   75.00   87.50   65.62    0.75
354    1600        101.24      8.71   65.52   73.08   59.38    0.66
454    1800          0.00      0.00   70.00   75.00   65.62    0.70
554    2000          0.00      0.00   70.00   75.00   65.62    0.70
696    2200          8.82      1.98   70.00   75.00   65.62    0.70
896    2400         69.07      4.24   60.71   70.83   53.12    0.61
1096    2600         93.22      8.87   64.41   70.37   59.38    0.64
1296    2800        428.43     68.45   63.33   

In [12]:
if enable_mlflow:
    # Log the trained pipeline. The module-level `nlp` is only the blank
    # English tokenizer used to build the DocBins, so logging it would store an
    # untrained model. Load the best checkpoint written by
    # `spacy train --output ./` instead.
    trained_nlp = spacy.load("model-best")
    mlflow.spacy.log_model(spacy_model=trained_nlp, artifact_path="spacy_model")
    mlflow.end_run()

### Test the model

In [46]:
# Load the best checkpoint from training and highlight its entities on a sample tweet.
nlp_ner = spacy.load("model-best")

sample_text = "#BTC still trading at Price down: 37052.1 € this morning. Hyper Corp announced #Bitcoin https://t.co/1XNq01CaMn"
doc = nlp_ner(sample_text)

# Highlight colour for each entity label rendered by displacy.
colors = {
    "PRICE": "#F67DE3",
    "CRYPTO": "#7DF6D9",
    "ORG": "#7156F6",
}
options = dict(colors=colors)

spacy.displacy.render(doc, style="ent", options=options, jupyter=True)