# NER with BERT in Spark NLP

<h5>Import libraries and download datasets</h5>

In [1]:
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline

import sparknlp
from sparknlp.annotator import *
from sparknlp.common import *
from sparknlp.base import *



In [2]:
data_folder = r"../../Dataset/"

spark = sparknlp.start(gpu=True)
print("Spark NLP version: ", sparknlp.version())
print("Apache Spark version: ", spark.version)

Spark NLP version:  2.6.2
Apache Spark version:  2.4.7


<h5>Building NER pipeline</h5>

In [3]:
# Preview the raw CoNLL-2003 training file (columns: token, POS tag, chunk tag, NER tag)
path = data_folder + "conll2003/eng.train"
with open(path) as f:
    c = f.read()

print(c[:500])

-DOCSTART- -X- -X- O

EU NNP B-NP B-ORG
rejects VBZ B-VP O
German JJ B-NP B-MISC
call NN I-NP O
to TO B-VP O
boycott VB I-VP O
British JJ B-NP B-MISC
lamb NN I-NP O
. . O O

Peter NNP B-NP B-PER
Blackburn NNP I-NP I-PER

BRUSSELS NNP B-NP B-LOC
1996-08-22 CD I-NP O

The DT B-NP O
European NNP I-NP B-ORG
Commission NNP I-NP I-ORG
said VBD B-VP O
on IN B-PP O
Thursday NNP B-NP O
it PRP B-NP O
disagreed VBD B-VP O
with IN B-PP O
German JJ B-NP B-MISC
advice NN I-NP O
to TO B-PP O
consumers NNS B-NP
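
The raw file above has one token per line with four space-separated columns, blank lines between sentences, and `-DOCSTART-` lines between documents. As a rough illustration of that layout (not how `CoNLL().readDataset` is implemented), here is a minimal pure-Python sketch; the `parse_conll` helper and the inline `sample` string are hypothetical names used only for this example.

```python
# Minimal sketch of reading CoNLL-2003 style text.
# Assumption: every token line has exactly four space-separated columns
# (token, POS tag, chunk tag, NER tag).
sample = """-DOCSTART- -X- -X- O

EU NNP B-NP B-ORG
rejects VBZ B-VP O
German JJ B-NP B-MISC
call NN I-NP O
to TO B-VP O
boycott VB I-VP O
British JJ B-NP B-MISC
lamb NN I-NP O
. . O O

Peter NNP B-NP B-PER
Blackburn NNP I-NP I-PER
"""


def parse_conll(text):
    """Return a list of sentences; each sentence is a list of (token, pos, chunk, ner)."""
    sentences, current = [], []
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("-DOCSTART-"):
            # A blank line or a document marker closes the current sentence.
            if current:
                sentences.append(current)
                current = []
            continue
        token, pos, chunk, ner = line.split()
        current.append((token, pos, chunk, ner))
    if current:
        sentences.append(current)
    return sentences


sentences = parse_conll(sample)
for sent in sentences:
    print([tok for tok, _, _, _ in sent], [ner for _, _, _, ner in sent])
```

Each parsed sentence corresponds to one row of the Spark data frame built in the next cell.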


In [4]:
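from sparknlp.training import CoNLL

# Convert the CoNLL file to a Spark data frame with document, sentence, token, pos and label columns
training_data = CoNLL().readDataset(spark, path)
training_data.show()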

+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|                text|            document|            sentence|               token|                 pos|               label|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|EU rejects German...|[[document, 0, 47...|[[document, 0, 47...|[[token, 0, 1, EU...|[[pos, 0, 1, NNP,...|[[named_entity, 0...|
|     Peter Blackburn|[[document, 0, 14...|[[document, 0, 14...|[[token, 0, 4, Pe...|[[pos, 0, 4, NNP,...|[[named_entity, 0...|
| BRUSSELS 1996-08-22|[[document, 0, 18...|[[document, 0, 18...|[[token, 0, 7, BR...|[[pos, 0, 7, NNP,...|[[named_entity, 0...|
|The European Comm...|[[document, 0, 18...|[[document, 0, 18...|[[token, 0, 2, Th...|[[pos, 0, 2, DT, ...|[[named_entity, 0...|
|Germany 's repres...|[[document, 0, 21...|[[document, 0, 21...|[[token, 0, 6, Ge...|[[pos, 0, 6, NNP,..

In [5]:
training_data.count()

14041

<h5>Loading Bert</h5>

In Spark NLP, we have four pre-trained variants of BERT: <b>bert_base_uncased</b>, <b>bert_base_cased</b>, <b>bert_large_uncased</b>, and <b>bert_large_cased</b>. Which one to use depends on your use case, your training set, and the complexity of the task you are trying to model.

In the code snippet below, we load the bert_base_cased model from the Spark NLP public resources and pass the sentence and token columns to setInputCols(). The BertEmbeddings() annotator reads those two columns and writes BERT embeddings to the bert column, one 768-dimensional vector per token. In versions of Spark NLP that expose it, setPoolingLayer() selects which layer the embeddings come from: 0 is the first layer (and the fastest), -1 is the last layer, and -2 is the second-to-last hidden layer. In the version used in this notebook (2.6.2) the setter is not available, and calling it raises an AttributeError, as the last section shows.

As explained by the authors of the official BERT paper, different BERT layers capture different information. The last layer is too close to the pre-training targets (i.e. masked language modeling and next sentence prediction), so its output may be biased toward those targets; if you don't fine-tune the model, this can lead to a poor representation. If you want to use the last hidden layer anyway, set pooling_layer=-1. Beyond that, the choice is a trade-off between model accuracy and the computational resources you have.
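
The layer indices described above (0, -1, -2) follow the same convention as Python list indexing. The sketch below only illustrates that convention; `hidden_layers` is a hypothetical stand-in for the 12 transformer layers of a BERT-base model, where real hidden states would be tensors rather than strings.

```python
# Hypothetical stand-in for the 12 transformer layers of a BERT-base model.
hidden_layers = [f"layer_{i}" for i in range(1, 13)]

# 0 picks the first layer, -1 the last, -2 the second-to-last.
for pooling_layer in (0, -1, -2):
    print(pooling_layer, "->", hidden_layers[pooling_layer])
```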

In [7]:
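# Note: with setCaseSensitive(False) the tokens in bert.result come out lowercased (see the output of In [14]);
# for a cased model such as bert_base_cased, setCaseSensitive(True) is usually the better fit.
bert_annotator = BertEmbeddings.pretrained('bert_base_cased', 'en') \
 .setInputCols(["sentence",'token'])\
 .setOutputCol("bert")\
 .setCaseSensitive(False)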

bert_base_cased download started this may take some time.
Approximate size to download 389.1 MB
[OK!]


In [13]:
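from sparknlp.training import CoNLL

# Read the validation split, add BERT embeddings, and save the result as Parquet
# so that NerDLApproach.setTestDataset() can read it during training
test_data = CoNLL().readDataset(spark, data_folder + "conll2003/eng.testa")
test_data = bert_annotator.transform(test_data)
test_data.write.parquet("test_withEmbeds.parquet")
test_data.show(3)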

+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|                text|            document|            sentence|               token|                 pos|               label|                bert|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|CRICKET - LEICEST...|[[document, 0, 64...|[[document, 0, 64...|[[token, 0, 6, CR...|[[pos, 0, 6, NNP,...|[[named_entity, 0...|[[word_embeddings...|
|   LONDON 1996-08-30|[[document, 0, 16...|[[document, 0, 16...|[[token, 0, 5, LO...|[[pos, 0, 5, NNP,...|[[named_entity, 0...|[[word_embeddings...|
|West Indian all-r...|[[document, 0, 18...|[[document, 0, 18...|[[token, 0, 3, We...|[[pos, 0, 3, NNP,...|[[named_entity, 0...|[[word_embeddings...|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+

In [14]:
test_data.select("bert.result","bert.embeddings",'label.result').show()

+--------------------+--------------------+--------------------+
|              result|          embeddings|              result|
+--------------------+--------------------+--------------------+
|[cricket, -, leic...|[[0.36637592, -0....|[O, O, B-ORG, O, ...|
|[london, 1996-08-30]|[[0.78848666, -0....|          [B-LOC, O]|
|[west, indian, al...|[[0.3852343, -0.5...|[B-MISC, I-MISC, ...|
|[their, stay, on,...|[[0.4464562, -0.4...|[O, O, O, O, O, O...|
|[after, bowling, ...|[[0.50563043, -0....|[O, O, B-ORG, O, ...|
|[trailing, by, 21...|[[0.4501431, -0.8...|[O, O, O, O, B-OR...|
|[essex, ,, howeve...|[[0.72113115, -0....|[B-ORG, O, O, O, ...|
|[hussain, ,, cons...|[[0.4423417, -0.8...|[B-PER, O, O, O, ...|
|[by, the, close, ...|[[0.21335557, -0....|[O, O, O, B-ORG, ...|
|[at, the, oval, ,...|[[0.53978103, -0....|[O, O, B-LOC, O, ...|
|[he, was, well, b...|[[0.46333486, -0....|[O, O, O, O, O, B...|
|[derbyshire, kept...|[[-0.35205558, 0....|[B-ORG, O, O, O, ...|
|[australian, tom,...|[[0

In [15]:
import numpy as np

emb_vector = np.array(test_data.select("bert.embeddings").take(1))
emb_vector

array([[[[ 0.36637592, -0.86787492,  0.18460657, ...,  0.19050074,
           0.31882697,  0.21404736],
         [ 0.0042053 , -0.7304967 ,  0.10252799, ...,  0.72762662,
           0.23618948,  0.54797202],
         [ 0.53181094, -0.21557562,  0.11225671, ...,  0.6807006 ,
          -0.73567241,  0.51040608],
         ...,
         [ 0.57870978,  0.06638508, -0.49203444, ...,  0.2483449 ,
           0.50983089,  0.2516709 ],
         [ 0.15336806, -0.28964311, -0.2610296 , ...,  0.42096072,
           0.40467095,  0.25303409],
         [ 0.05699301,  0.32524562,  0.06525278, ...,  0.18811801,
           0.14695087,  0.07696211]]]])
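
The four opening brackets in the array above reflect its nesting: one row from take(1), one selected column, one vector per token, and 768 values per vector. The sketch below reproduces that nesting with plain Python lists; `nested_shape` and `rows` are hypothetical helpers, and the token count of 11 is taken from the first test sentence (CRICKET ... .).

```python
def nested_shape(obj):
    """Shape of a regularly nested list/tuple, similar in spirit to numpy's .shape."""
    shape = []
    while isinstance(obj, (list, tuple)):
        shape.append(len(obj))
        if not obj:
            break
        obj = obj[0]
    return tuple(shape)


n_tokens, dim = 11, 768

# Stand-in for take(1): a list with one row, which holds one column,
# which holds one 768-dimensional vector per token.
rows = [([[0.0] * dim for _ in range(n_tokens)],)]

print(nested_shape(rows))  # (1, 1, 11, 768)
```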

In [16]:
nerTagger = NerDLApproach()\
  .setInputCols(["sentence", "token", "bert"])\
  .setLabelColumn("label")\
  .setOutputCol("ner")\
  .setMaxEpochs(1)\
  .setLr(0.001)\
  .setPo(0.005)\
  .setBatchSize(8)\
  .setRandomSeed(0)\
  .setVerbose(1)\
  .setValidationSplit(0.2)\
  .setEvaluationLogExtended(True) \
  .setEnableOutputLogs(True)\
  .setIncludeConfidence(True)\
  .setTestDataset("test_withEmbeds.parquet")


pipeline = Pipeline(
    stages = [
    bert_annotator,
    nerTagger
  ])

You can also set the learning rate (setLr), the learning rate decay coefficient (setPo), the batch size (setBatchSize), and the dropout rate (setDropout). Please see the official repo for the full list of parameters.
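
To get a feel for how the learning rate and the decay coefficient interact, here is a small sketch. It assumes the effective rate at each epoch is lr / (1 + po * epoch), which is how the Spark NLP documentation describes the po parameter; treat the formula as an assumption and check it against the version you run. The `decayed_lr` helper is a hypothetical name.

```python
def decayed_lr(lr, po, epoch):
    """Effective learning rate at a given epoch, assuming lr / (1 + po * epoch)."""
    return lr / (1.0 + po * epoch)


# Values used in the NerDLApproach configuration above.
lr, po = 0.001, 0.005

schedule = [decayed_lr(lr, po, epoch) for epoch in range(5)]
for epoch, value in enumerate(schedule):
    print(f"epoch {epoch}: lr = {value:.6f}")
```

With po=0.005 the rate shrinks very slowly, so over a handful of epochs the decay has almost no effect.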

In [17]:
%%time

# Fit on a 1,000-sentence subset to keep the demo run short
ner_model = pipeline.fit(training_data.limit(1000))

CPU times: user 46.2 ms, sys: 22.5 ms, total: 68.7 ms
Wall time: 3min 14s


In [18]:
ner_model

PipelineModel_ae4f53a349f4

In [19]:
predictions = ner_model.transform(test_data)
predictions.show(3)

+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|                text|            document|            sentence|               token|                 pos|               label|                bert|                 ner|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|CRICKET - LEICEST...|[[document, 0, 64...|[[document, 0, 64...|[[token, 0, 6, CR...|[[pos, 0, 6, NNP,...|[[named_entity, 0...|[[word_embeddings...|[[named_entity, 0...|
|   LONDON 1996-08-30|[[document, 0, 16...|[[document, 0, 16...|[[token, 0, 5, LO...|[[pos, 0, 5, NNP,...|[[named_entity, 0...|[[word_embeddings...|[[named_entity, 0...|
|West Indian all-r...|[[document, 0, 18...|[[document, 0, 18...|[[token, 0, 3, We...|[[pos, 0, 3, NNP,...|[[named_entity, 0...|[[word_embeddings...|[[

In [20]:
predictions.select('token.result','label.result','ner.result').show(truncate=40)

+----------------------------------------+----------------------------------------+----------------------------------------+
|                                  result|                                  result|                                  result|
+----------------------------------------+----------------------------------------+----------------------------------------+
|[CRICKET, -, LEICESTERSHIRE, TAKE, OV...|   [O, O, B-ORG, O, O, O, O, O, O, O, O]|   [O, O, B-ORG, O, O, O, O, O, O, O, O]|
|                    [LONDON, 1996-08-30]|                              [B-LOC, O]|                              [B-LOC, O]|
|[West, Indian, all-rounder, Phil, Sim...|[B-MISC, I-MISC, O, B-PER, I-PER, O, ...|[B-PER, B-PER, B-PER, I-PER, I-PER, O...|
|[Their, stay, on, top, ,, though, ,, ...|[O, O, O, O, O, O, O, O, O, O, O, O, ...|[O, O, O, O, O, O, O, O, O, O, O, O, ...|
|[After, bowling, Somerset, out, for, ...|[O, O, B-ORG, O, O, O, O, O, O, O, O,...|[O, O, B-PER, O, O, O, O, O, O, O, O,...|


In [21]:
predictions.printSchema()

root
 |-- text: string (nullable = true)
 |-- document: array (nullable = false)
 |    |-- element: struct (containsNull = true)
 |    |    |-- annotatorType: string (nullable = true)
 |    |    |-- begin: integer (nullable = false)
 |    |    |-- end: integer (nullable = false)
 |    |    |-- result: string (nullable = true)
 |    |    |-- metadata: map (nullable = true)
 |    |    |    |-- key: string
 |    |    |    |-- value: string (valueContainsNull = true)
 |    |    |-- embeddings: array (nullable = true)
 |    |    |    |-- element: float (containsNull = false)
 |-- sentence: array (nullable = false)
 |    |-- element: struct (containsNull = true)
 |    |    |-- annotatorType: string (nullable = true)
 |    |    |-- begin: integer (nullable = false)
 |    |    |-- end: integer (nullable = false)
 |    |    |-- result: string (nullable = true)
 |    |    |-- metadata: map (nullable = true)
 |    |    |    |-- key: string
 |    |    |    |-- value: string (valueContainsNull = tr

In [22]:
import pyspark.sql.functions as F

predictions.select(F.explode(F.arrays_zip('token.result','label.result','ner.result')).alias("cols")) \
.select(F.expr("cols['0']").alias("token"),
        F.expr("cols['1']").alias("ground_truth"),
        F.expr("cols['2']").alias("prediction")).show(truncate=False)

+--------------+------------+----------+
|token         |ground_truth|prediction|
+--------------+------------+----------+
|CRICKET       |O           |O         |
|-             |O           |O         |
|LEICESTERSHIRE|B-ORG       |B-ORG     |
|TAKE          |O           |O         |
|OVER          |O           |O         |
|AT            |O           |O         |
|TOP           |O           |O         |
|AFTER         |O           |O         |
|INNINGS       |O           |O         |
|VICTORY       |O           |O         |
|.             |O           |O         |
|LONDON        |B-LOC       |B-LOC     |
|1996-08-30    |O           |O         |
|West          |B-MISC      |B-PER     |
|Indian        |I-MISC      |B-PER     |
|all-rounder   |O           |B-PER     |
|Phil          |B-PER       |I-PER     |
|Simmons       |I-PER       |I-PER     |
|took          |O           |O         |
|four          |O           |O         |
+--------------+------------+----------+
only showing top
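
Once tokens, ground truth and predictions are lined up like this, a token-level accuracy is easy to compute. The sketch below does it in plain Python for the 20 rows shown above; it is only an illustration on this tiny sample, not an evaluation of the model.

```python
from collections import Counter

# (token, ground_truth, prediction) rows copied from the table above.
rows = [
    ("CRICKET", "O", "O"), ("-", "O", "O"), ("LEICESTERSHIRE", "B-ORG", "B-ORG"),
    ("TAKE", "O", "O"), ("OVER", "O", "O"), ("AT", "O", "O"), ("TOP", "O", "O"),
    ("AFTER", "O", "O"), ("INNINGS", "O", "O"), ("VICTORY", "O", "O"), (".", "O", "O"),
    ("LONDON", "B-LOC", "B-LOC"), ("1996-08-30", "O", "O"),
    ("West", "B-MISC", "B-PER"), ("Indian", "I-MISC", "B-PER"),
    ("all-rounder", "O", "B-PER"), ("Phil", "B-PER", "I-PER"),
    ("Simmons", "I-PER", "I-PER"), ("took", "O", "O"), ("four", "O", "O"),
]

# Fraction of tokens whose predicted tag equals the ground-truth tag.
correct = sum(gold == pred for _, gold, pred in rows)
accuracy = correct / len(rows)

# Count each (ground_truth, prediction) pair that disagrees.
confusions = Counter((gold, pred) for _, gold, pred in rows if gold != pred)

print(f"token accuracy: {accuracy:.2f}")
for (gold, pred), n in confusions.items():
    print(f"{gold} predicted as {pred}: {n}")
```

Token accuracy is dominated by the frequent O tag, so it tends to look better than entity-level metrics on the same predictions.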

<h5>Saving and loading from local disk</h5>

In [24]:
# stages[1] is the fitted NER model (stages[0] is the BERT annotator)
ner_model.stages[1].write().save('NER_bert_20200221')

In [26]:
# Load the NER model saved above (in this notebook: 1 epoch on a 1,000-sentence subset)
loaded_ner_model = NerDLModel.load("NER_bert_20200221")\
 .setInputCols(["sentence", "token", "bert"])\
 .setOutputCol("ner")

In [27]:
predictions_loaded = loaded_ner_model.transform(test_data)

predictions_loaded.select(F.explode(F.arrays_zip('token.result','label.result','ner.result')).alias("cols")) \
.select(F.expr("cols['0']").alias("token"),
        F.expr("cols['1']").alias("ground_truth"),
        F.expr("cols['2']").alias("prediction")).show(30, truncate=False)

+--------------+------------+----------+
|token         |ground_truth|prediction|
+--------------+------------+----------+
|CRICKET       |O           |O         |
|-             |O           |O         |
|LEICESTERSHIRE|B-ORG       |B-ORG     |
|TAKE          |O           |O         |
|OVER          |O           |O         |
|AT            |O           |O         |
|TOP           |O           |O         |
|AFTER         |O           |O         |
|INNINGS       |O           |O         |
|VICTORY       |O           |O         |
|.             |O           |O         |
|LONDON        |B-LOC       |B-LOC     |
|1996-08-30    |O           |O         |
|West          |B-MISC      |B-PER     |
|Indian        |I-MISC      |B-PER     |
|all-rounder   |O           |B-PER     |
|Phil          |B-PER       |I-PER     |
|Simmons       |I-PER       |I-PER     |
|took          |O           |O         |
|four          |O           |O         |
|for           |O           |O         |
|38            |

In [28]:
import pandas as pd

df = predictions_loaded.select('token.result','label.result','ner.result').toPandas()

df

Unnamed: 0,result,result.1,result.2
0,"[CRICKET, -, LEICESTERSHIRE, TAKE, OVER, AT, T...","[O, O, B-ORG, O, O, O, O, O, O, O, O]","[O, O, B-ORG, O, O, O, O, O, O, O, O]"
1,"[LONDON, 1996-08-30]","[B-LOC, O]","[B-LOC, O]"
2,"[West, Indian, all-rounder, Phil, Simmons, too...","[B-MISC, I-MISC, O, B-PER, I-PER, O, O, O, O, ...","[B-PER, B-PER, B-PER, I-PER, I-PER, O, O, O, O..."
3,"[Their, stay, on, top, ,, though, ,, may, be, ...","[O, O, O, O, O, O, O, O, O, O, O, O, O, B-ORG,...","[O, O, O, O, O, O, O, O, O, O, O, O, O, B-ORG,..."
4,"[After, bowling, Somerset, out, for, 83, on, t...","[O, O, B-ORG, O, O, O, O, O, O, O, O, B-LOC, I...","[O, O, B-PER, O, O, O, O, O, O, O, O, O, O, O,..."
5,"[Trailing, by, 213, ,, Somerset, got, a, solid...","[O, O, O, O, B-ORG, O, O, O, O, O, O, O, O, O,...","[O, O, O, O, B-ORG, O, O, O, O, O, O, O, O, O,..."
6,"[Essex, ,, however, ,, look, certain, to, rega...","[B-ORG, O, O, O, O, O, O, O, O, O, O, O, B-PER...","[B-ORG, O, O, O, O, O, O, O, O, O, O, O, B-PER..."
7,"[Hussain, ,, considered, surplus, to, England,...","[B-PER, O, O, O, O, B-LOC, O, O, O, O, O, O, O...","[B-PER, O, O, O, O, B-ORG, O, O, O, O, O, O, O..."
8,"[By, the, close, Yorkshire, had, turned, that,...","[O, O, O, B-ORG, O, O, O, O, O, O, O, O, O, B-...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ..."
9,"[At, the, Oval, ,, Surrey, captain, Chris, Lew...","[O, O, B-LOC, O, B-ORG, O, B-PER, I-PER, O, O,...","[O, O, O, O, B-PER, O, B-PER, I-PER, O, O, O, ..."
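
The tag lists in this data frame use the B-/I-/O scheme, so comparing entities rather than single tokens means first grouping tags into spans. Below is a minimal sketch of that grouping, applied to the first five tokens of row 2 above; `bio_to_spans` is a hypothetical helper, and treating an I- tag of a different type as the start of a new span is a choice made for this example.

```python
def bio_to_spans(tokens, tags):
    """Group BIO tags into (entity_type, text) spans."""
    spans, current_type, current_tokens = [], None, []
    for token, tag in zip(tokens, tags):
        starts_new = tag.startswith("B-") or (
            tag.startswith("I-") and tag[2:] != current_type
        )
        if starts_new:
            # Close any open span, then open a new one of this type.
            if current_tokens:
                spans.append((current_type, " ".join(current_tokens)))
            current_type, current_tokens = tag[2:], [token]
        elif tag.startswith("I-"):
            # Continue the open span of the same type.
            current_tokens.append(token)
        else:
            # An O tag closes any open span.
            if current_tokens:
                spans.append((current_type, " ".join(current_tokens)))
            current_type, current_tokens = None, []
    if current_tokens:
        spans.append((current_type, " ".join(current_tokens)))
    return spans


tokens = ["West", "Indian", "all-rounder", "Phil", "Simmons"]
gold = ["B-MISC", "I-MISC", "O", "B-PER", "I-PER"]
pred = ["B-PER", "B-PER", "B-PER", "I-PER", "I-PER"]

gold_spans = bio_to_spans(tokens, gold)
pred_spans = bio_to_spans(tokens, pred)
print("gold:", gold_spans)
print("pred:", pred_spans)
print("exact matches:", set(gold_spans) & set(pred_spans))
```

On this sentence the model gets the last token tag right but no complete entity, which is the gap between token-level and entity-level scoring. The comparison here ignores span positions, which is fine for a single short sentence but not for a full evaluation.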


<h5>Bert with poolingLayer -2</h5>

In [29]:
bert_annotator.setPoolingLayer(-2)

AttributeError: 'BertEmbeddings' object has no attribute 'setPoolingLayer'
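
The error above means the setter does not exist on BertEmbeddings in the installed version. If the notebook needs to run across versions, one option is to guard optional setters. The sketch below shows the pattern with a hypothetical stand-in class, `FakeEmbeddings`, so that it runs without Spark; `call_if_supported` is also a name made up for this example.

```python
def call_if_supported(obj, method_name, *args):
    """Call obj.method_name(*args) if it exists; otherwise report and return obj unchanged."""
    method = getattr(obj, method_name, None)
    if callable(method):
        return method(*args)
    print(f"{type(obj).__name__} has no {method_name}; skipping")
    return obj


class FakeEmbeddings:
    """Hypothetical stand-in for an annotator that lacks setPoolingLayer."""

    def setCaseSensitive(self, value):
        self.case_sensitive = value
        return self


annotator = FakeEmbeddings()
call_if_supported(annotator, "setPoolingLayer", -2)    # missing setter: skipped
call_if_supported(annotator, "setCaseSensitive", True)  # existing setter: applied
```

Skipping the setter keeps the remaining cells runnable, but the embeddings then come from whatever layer the installed model exposes by default.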

In [None]:
pipeline = Pipeline(
    stages = [
    bert_annotator,
    nerTagger
  ])

In [None]:
ner_model_v2 = pipeline.fit(training_data.limit(1000))