Skip to content
This repository has been archived by the owner on Aug 26, 2022. It is now read-only.

REST endpoint for named entity recognition training #3

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.cache/
__pycache__/
38 changes: 38 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,44 @@ Example response:

---

### `POST` `/train/ent`

Example request:

```json
{
"text": "Google es una empresa.",
"model": "es",
"tags": [
{
"start": 0,
"len": 6,
"type": "ORG"
}
]
}
```

| Name | Type | Description |
| --- | --- | --- |
| `text` | string | text to be parsed |
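| `tags` | array | entities to be used for training named entity recognition; each item has `start` (character offset the entity starts on), `len` (length of the entity in characters) and `type` (entity type) |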
| `model` | string | identifier string for a model installed on the server |

Example response:

```json
[
{ "end": 6, "start": 0, "type": "ORG" }
]
```

| Name | Type | Description |
| --- | --- | --- |
| `end` | integer | character offset the entity ends **after** |
| `start` | integer | character offset the entity starts **on** |
| `type` | string | entity type |

### `GET` `/models`

List the names of models installed on the server.
Expand Down
22 changes: 22 additions & 0 deletions displacy/displacy_service/parse.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import unicode_literals

from spacy.gold import GoldParse
from spacy.pipeline import EntityRecognizer


class Parse(object):
def __init__(self, nlp, text, collapse_punctuation, collapse_phrases):
Expand Down Expand Up @@ -56,3 +59,22 @@ def __init__(self, nlp, text):
def to_json(self):
    """Describe every entity of ``self.doc`` as a JSON-serializable dict.

    Returns a list with one dict per item of ``self.doc.ents``, holding the
    entity's ``start`` and ``end`` character offsets and its ``type`` label.
    """
    described = []
    for span in self.doc.ents:
        described.append({
            'start': span.start_char,
            'end': span.end_char,
            'type': span.label_,
        })
    return described


class TrainEntities(object):
    """Train a model's entity recognizer on one annotated text, then re-parse it.

    After construction, ``self.doc`` holds the document obtained by running
    ``nlp`` (and then its entity recognizer) on ``text`` once the training
    passes have finished.
    """

    def __init__(self, nlp, text, tags, iterations=10):
        """Update ``nlp.entity`` with the given annotations and parse ``text``.

        Args:
            nlp: loaded pipeline; this method uses ``nlp.entity``,
                ``nlp.make_doc`` and calls ``nlp`` on the text.
            text: raw text that the annotations refer to.
            tags: iterable of dicts, each with ``start`` (character offset),
                ``len`` (length in characters) and ``type`` (entity label).
            iterations: number of update passes over the single example.
                Defaults to 10, the value that used to be hard-coded.

        Raises:
            KeyError: if a tag lacks ``start``, ``len`` or ``type``.
        """
        ner = nlp.entity
        # Convert the (start, len) pairs into (start, end, label) triples,
        # which is the form handed to GoldParse below.
        entities = [(tag['start'], tag['start'] + tag['len'], tag['type'])
                    for tag in tags]
        for _ in range(iterations):
            doc = nlp.make_doc(text)
            gold = GoldParse(doc, entities=entities)
            ner.update(doc, gold)
        # NOTE(review): assumed to run once, after all update passes - confirm.
        ner.model.end_training()
        doc = nlp(text)
        ner(doc)
        self.doc = doc

    def to_json(self):
        """Return the entities of ``self.doc`` as a list of dicts.

        Each dict has the ``start`` and ``end`` character offsets and the
        ``type`` label of one entity.
        """
        return [{'start': ent.start_char, 'end': ent.end_char, 'type': ent.label_}
                for ent in self.doc.ents]
57 changes: 47 additions & 10 deletions displacy/displacy_service/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
import spacy
import json

from spacy.pipeline import EntityRecognizer
from spacy.symbols import ENT_TYPE, TAG, DEP

import spacy.util
from spacy.tagger import Tagger

from .parse import Parse, Entities
from .parse import Parse, Entities, TrainEntities


try:
Expand All @@ -33,7 +35,16 @@

def get_model(model_name):
    """Return the model named ``model_name``, loading and caching it on first use.

    A freshly loaded model that has no tagger or no entity recognizer gets a
    new, untrained one attached, and its pipeline is set to
    ``[tagger, entity, parser]``. Later calls with the same name return the
    cached instance from ``_models``.
    """
    if model_name not in _models:
        # Load the model exactly once; the cache entry is only written after
        # the missing components have been attached.
        model = spacy.load(model_name)
        if model.tagger is None:
            model.tagger = Tagger(model.vocab, features=Tagger.feature_templates)
        if model.entity is None:
            model.entity = EntityRecognizer(
                model.vocab,
                entity_types=['PERSON', 'NORP', 'FACILITY', 'ORG', 'GPE',
                              'LOC', 'PRODUCT', 'EVENT', 'WORK_OF_ART',
                              'LANGUAGE', 'DATE', 'TIME', 'PERCENT',
                              'MONEY', 'QUANTITY', 'ORDINAL', 'CARDINAL'])
        # NOTE(review): assumed to apply to every freshly loaded model, not
        # only when the entity recognizer was missing - confirm.
        model.pipeline = [model.tagger, model.entity, model.parser]
        _models[model_name] = model
    return _models[model_name]


Expand Down Expand Up @@ -61,14 +72,20 @@ def get_pos_types(model):
return labels


def update_vocabulary(model, text):
    """Look up every token of ``text`` in ``model.vocab``.

    The text is tokenized with ``model.make_doc`` and each token's ``orth``
    is used to index the vocabulary. The looked-up values are discarded and
    nothing is returned.
    """
    tokens = model.make_doc(text)
    for token in tokens:
        # Only the lookup itself matters; its result is intentionally dropped.
        _unused = model.vocab[token.orth]


class ModelsResource(object):
    """List the available models."""

    def on_get(self, req, resp):
        """Respond with a JSON array built from ``MODELS``.

        On success the body is the serialized list and the status is 200.
        Any exception raised while building the response results in a 500
        status with no body set here.
        """
        try:
            output = list(MODELS)
            # ``output`` is a plain list, which has no ``to_json`` method;
            # serialize it directly instead of raising AttributeError.
            resp.body = json.dumps(output, sort_keys=True, indent=2)
            resp.content_type = 'text/string'
            # Append the CORS header once; appending it a second time would
            # add a second value to the same header.
            resp.append_header('Access-Control-Allow-Origin', "*")
            resp.status = falcon.HTTP_200
        except Exception:
            resp.status = falcon.HTTP_500
Expand All @@ -86,8 +103,8 @@ def on_get(self, req, resp, model_name):
}

resp.body = json.dumps(output.to_json(), sort_keys=True, indent=2)
resp.content_type = b'text/string'
resp.append_header(b'Access-Control-Allow-Origin', b"*")
resp.content_type = 'text/string'
resp.append_header('Access-Control-Allow-Origin', "*")
resp.status = falcon.HTTP_200
except Exception:
resp.status = falcon.HTTP_500
Expand All @@ -107,8 +124,8 @@ def on_post(self, req, resp):
model = get_model(model_name)
parse = Parse(model, text, collapse_punctuation, collapse_phrases)
resp.body = json.dumps(parse.to_json(), sort_keys=True, indent=2)
resp.content_type = b'text/string'
resp.append_header(b'Access-Control-Allow-Origin', b"*")
resp.content_type = 'text/string'
resp.append_header('Access-Control-Allow-Origin', "*")
resp.status = falcon.HTTP_200
except Exception:
resp.status = falcon.HTTP_500
Expand All @@ -125,15 +142,35 @@ def on_post(self, req, resp):
model = get_model(model_name)
entities = Entities(model, text)
resp.body = json.dumps(entities.to_json(), sort_keys=True, indent=2)
resp.content_type = b'text/string'
resp.append_header(b'Access-Control-Allow-Origin', b"*")
resp.content_type = 'text/string'
resp.append_header('Access-Control-Allow-Origin', "*")
resp.status = falcon.HTTP_200
except Exception:
resp.status = falcon.HTTP_500


class TrainEntResource(object):
    """Parse text and use it to train the entity recognizer."""

    def on_post(self, req, resp):
        """Train the requested model on the posted annotations.

        The request body must be a JSON object with ``text`` and ``tags``;
        ``model`` is optional and defaults to ``'en'``. On success the body is
        the JSON produced by ``TrainEntities.to_json`` and the status is 200.

        A body that is not valid UTF-8 JSON, is not a JSON object, or lacks
        ``text`` or ``tags`` is answered with 400. Any failure while loading
        the model or training is answered with 500.
        """
        try:
            json_data = json.loads(req.stream.read().decode('utf8'))
        except ValueError:
            # Covers both undecodable bytes and malformed JSON.
            resp.status = falcon.HTTP_400
            return
        if not isinstance(json_data, dict):
            resp.status = falcon.HTTP_400
            return

        text = json_data.get('text')
        tags = json_data.get('tags')
        if text is None or not isinstance(tags, list):
            # Without both fields there is nothing to train on.
            resp.status = falcon.HTTP_400
            return
        model_name = json_data.get('model', 'en')

        try:
            model = get_model(model_name)
            update_vocabulary(model, text)
            entities = TrainEntities(model, text, tags)
            resp.body = json.dumps(entities.to_json(), sort_keys=True, indent=2)
            resp.content_type = 'text/string'
            resp.append_header('Access-Control-Allow-Origin', "*")
            resp.status = falcon.HTTP_200
        except Exception:
            resp.status = falcon.HTTP_500

# Falcon application object; each route below is bound to one instance of
# its resource class.
APP = falcon.API()
APP.add_route('/dep', DepResource())
APP.add_route('/ent', EntResource())
# Training endpoint: POST handled by TrainEntResource.on_post.
APP.add_route('/train/ent', TrainEntResource())
# ``model_name`` is taken from the URL path and passed to the resource.
APP.add_route('/{model_name}/schema', SchemaResource())
APP.add_route('/models', ModelsResource())
22 changes: 22 additions & 0 deletions displacy/displacy_service/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,25 @@ def test_ents():
body='''{"text": "Google is a company.", "model": "en"}''')
ents = json.loads(result.text)
assert ents == [{"start": 0, "end": len("Google"), "type": "ORG"}]


def test_train_ents():
    """Training on one tagged Spanish sentence returns exactly that entity."""
    payload = '''{"text": "Google es una empresa.", "model": "es",
"tags": [{"start": 0, "len": 6, "type": "ORG"}]}'''
    response = TestAPI().simulate_post(path='/train/ent', body=payload)
    found = json.loads(response.text)
    expected = [{"start": 0, "end": len("Google"), "type": "ORG"}]
    assert found == expected


def test_train_and_query_ents():
    """An entity learned through /train/ent is also returned by /ent afterwards."""
    client = TestAPI()

    # Step 1: train on the tagged sentence and check the immediate answer.
    trained = client.simulate_post(
        path='/train/ent',
        body='''{"text": "Google es una empresa.", "model": "es",
"tags": [{"start": 0, "len": 6, "type": "ORG"}]}''')
    trained_ents = json.loads(trained.text)
    assert trained_ents == [{"start": 0, "end": len("Google"), "type": "ORG"}]

    # Step 2: query the same sentence through the plain entity endpoint.
    queried = client.simulate_post(
        path='/ent',
        body='''{"text": "Google es una empresa.", "model": "es"}''')
    queried_ents = json.loads(queried.text)
    assert queried_ents == [{"start": 0, "end": len("Google"), "type": "ORG"}]