# Training the custom Named Entity Recognizer

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import os
import csv
import json
import re
import pandas as pd
import random
import warnings
from pathlib import Path
import spacy
from spacy.util import minibatch, compounding
from spacy.matcher import PhraseMatcher

warnings.filterwarnings("ignore")

In [6]:
text = os.listdir("/qubitrics/text")  #Task 1 data
entities = os.listdir("/qubitrics/entities")  #Task 2 data

In [7]:
print(len(text))
print(len(entities))

716
716


# Fetching data

#### Reconstructing strings for training the model
- The reciept text is extracted and reconstructed into strings from `Task 1 - Scanned Reciept Localization` training data
- The true values for entities in the exracted text in taken from `Task 2 - Scanned Reciept OCR` training data

In [8]:
set_text  = set(text)
set_ent = set(entities)

training_set = list(set_text.intersection(set_ent)) #to avoid mismatch between text and entity dictionaries

In [9]:
data = pd.DataFrame(columns=["filename", "text"])

data["filename"] = training_set

data_text = []
for file in data["filename"]:
    rec_text = []
    pattern = r"\d+,\d+,\d+,\d+,\d+,\d+,\d+,\d+,(.+)"
    with open(f"/qubitrics/text/{file}") as f:
        f.seek(0)
        lines = f.readlines()
        for line in lines:
            rec_text += re.findall(pattern, line)
    data_text.append(" ".join([x.strip() for x in rec_text]))
data["text"] = data_text

ent_list = []
for file in data["filename"]:
    with open(f"/qubitrics/entities/{file}") as f:
        entity_dict = json.load(f)
        ent_list.append(entity_dict)
data["entity_dictionary"] = ent_list

data.shape

(716, 3)

In [11]:
data

Unnamed: 0,filename,text,entity_dictionary
0,X51007339157(1).txt,"SANYU STATIONERY SHOP NO. 31G&33G, JALAN SETIA...","{'company': 'SANYU STATIONERY SHOP', 'date': '..."
1,X51005763940(1).txt,HARVEY NORMAN HARVEY NORMAN M'SIA PARADIGM MAL...,"{'company': 'ELITETRAX MARKETING SDN BHD', 'da..."
2,X51005757199.txt,POPULAR BOOK CO. (M) SDN BHD (CO. NO. 113825-W...,"{'company': 'POPULAR BOOK CO. (M) SDN BHD', 'd..."
3,X51008142063.txt,KEDAI PAPAN YEW CHUAN (0005583085-K) LOT 276 J...,"{'company': 'KEDAI PAPAN YEW CHUAN', 'date': '..."
4,X51005447861.txt,POPULAR BOOK CO. (M) SDN BHD (CO. NO. 113825-W...,"{'company': 'POPULAR BOOK CO. (M) SDN BHD', 'd..."
...,...,...,...
711,X51005568827.txt,BANH MI CAFE DIMILIKI: BANH MI CAFE SDN BHD 11...,"{'company': 'BANH MI CAFE SDN BHD', 'date': '2..."
712,X51005677331.txt,"SYARIKAT PERNIAGAAN GIN KEE (81109-A) NO 290, ...","{'company': 'SYARIKAT PERNIAGAAN GIN KEE', 'da..."
713,X51005442338.txt,"PASAR MINI JIN SENG 379,JALAN PERMAS SATU, BAN...","{'company': 'PASAR MINI JIN SENG', 'date': '18..."
714,X51006414483.txt,UNIHAKKA INTERNATIONAL SDN BHD 10 APR 2018 18:...,"{'company': 'UNIHAKKA INTERNATIONAL SDN BHD', ..."


## Transforming data into spaCy compliant format

For training a spaCy model for customized named entity recognition, the data must be transformed to the following format<br>
`[(input text, entites:[(start_index, end_index, entity_name), (start_index, end_index, entity_name), ...]), ....]` 
<br>
To correcty match the entity values given in the entity dictionary to the phrases in the corresponding text file, I have used:
- Phrase Matcher
- Regular Expressions
Regular expressions are used to identify and match the phrases/tokens that were not matched by the Phrase Matcher to ensure the training data set to be as effective as possible 

In [12]:
training_data = []
id_ent = []

nlp_match = spacy.load('en_core_web_sm')
matcher = PhraseMatcher(nlp_match.vocab)
for index, row in data.iterrows():
    ent_dic = row["entity_dictionary"]
    ent = []
    phrases = list(ent_dic.values())
    patterns = [nlp_match.make_doc(phrase) for phrase in phrases]
    matcher.add("EntityList", None, *patterns)

    doc = nlp_match(row["text"])
    matches = matcher(doc)
    for match_id, start, end in matches:
        try:
            span = doc[start:end]
            if start>0:
                sb = doc[0:start]
                start_index=len(sb.text)+1
            else:
                start_index=0
            end_index= start_index+len(span.text)
        except:
            pass

        for key, value in ent_dic.items():
            if value==span.text:
                ent_tup=(start_index, end_index, key)
                ent.append(ent_tup)
    ent_set = {"company", "date", "total", "address"}
    detected_entities = set([key for start, end, key in ent])
    missed_entities = list(ent_set - detected_entities)
    if "total" in missed_entities:
        value = ent_dic["total"]
        if len(value)>0:
            catch_total = re.search(value, str(row["text"]).replace(",", ""))
            ent_tup = (catch_total.span()[0], catch_total.span()[1], "total")
            ent.append(ent_tup)
    if "date" in missed_entities:
        value = ent_dic["date"]
        if len(value)>0:
            catch_date = re.search(value, str(row["text"]))
            if catch_date == None:
                catch_date = re.search(r"\d\d[-/]*\d\d[-/]*\d\d", str(row["text"]))
            try:
                ent_tup = (catch_total.span()[0], catch_total.span()[1], "date")
                ent.append(ent_tup)
            except:
                pass
    if "company" in missed_entities:
        value = ent_dic["company"]
        catch_company = re.search(value, str(row["text"]))
        if catch_company!=None:
            ent_tup = (catch_company.span()[0], catch_company.span()[1], "company")
            ent.append(ent_tup)
        else:
            catch_company = re.search(value, str(row["text"]).replace(".", ""))
            if catch_company!=None:
                ent_tup = (catch_company.span()[0], catch_company.span()[1], "company")
                ent.append(ent_tup)
    if "address" in missed_entities:
        try:
            value = ent_dic["address"]
            catch_address = re.search(value, str(row["text"]))
            if catch_address!=None:
                ent_tup = (catch_address.span()[0], catch_address.span()[1], "address")
                ent.append(ent_tup)
        except:
            pass
    id_ent.append(len(ent))
    entity_dictionary = {"entities": ent}
    train_tup = (row["text"], entity_dictionary)
    training_data.append(train_tup)


In [13]:
len(training_data)

716

## Training the custom NER Model

### Hyperparmeter Values:
- **Number of Iterations**: 80
- **Dropout**: 0.6
- **Minimum Batch Size**: 4
- **Maximum Batch Size**: 32
- **Compounding factor**: 0.01

In [14]:
TRAIN_DATA = training_data
output_dir="/content/drive/MyDrive/qubitrics_internship_assignment/model"
n_iter = 80

In [15]:
nlp = spacy.blank("en")
ner = nlp.create_pipe("ner")
nlp.add_pipe(ner, last=True)


for _, annotations in TRAIN_DATA:
    for ent in annotations.get("entities"):
        ner.add_label(ent[2])

nlp.begin_training()
for itn in range(n_iter):
    random.shuffle(TRAIN_DATA)
    losses = {}
    batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.01))
    for batch in batches:
        texts, annotations = zip(*batch)
        try:
            nlp.update(
                texts,
                annotations, 
                drop=0.6,
                losses=losses,
            )
        except:
            pass
    print(f"{itn} Losses", losses)

output_dir = Path(output_dir)
if not output_dir.exists():
    output_dir.mkdir()
nlp.to_disk(output_dir)
print("\nSaved model to", output_dir)

0 Losses {'ner': 12974.743664056032}
1 Losses {'ner': 17605.0854382627}
2 Losses {'ner': 22397.8668254816}
3 Losses {'ner': 22230.640010044022}
4 Losses {'ner': 23507.19024447282}
5 Losses {'ner': 23773.06150963734}
6 Losses {'ner': 22811.54948925385}
7 Losses {'ner': 19172.24888464052}
8 Losses {'ner': 17384.020578325573}
9 Losses {'ner': 18561.810039345055}
10 Losses {'ner': 17603.17206580461}
11 Losses {'ner': 17206.517447793653}
12 Losses {'ner': 14790.336147851283}
13 Losses {'ner': 16355.159918238129}
14 Losses {'ner': 13728.183245581147}
15 Losses {'ner': 13496.02545978631}
16 Losses {'ner': 13030.180207959613}
17 Losses {'ner': 11069.342757673901}
18 Losses {'ner': 9289.0420544195}
19 Losses {'ner': 8272.538787234234}
20 Losses {'ner': 8160.740559127125}
21 Losses {'ner': 8738.867668556524}
22 Losses {'ner': 5865.906072477139}
23 Losses {'ner': 5362.072545458739}
24 Losses {'ner': 5091.726134207575}
25 Losses {'ner': 4252.640293221}
26 Losses {'ner': 4769.905745454959}
27 Losse