Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Add tests
Adds a number of tests to get decent test coverage (87%).

Updates tox.ini to install both requirements.txt and
requirements_test.txt
  • Loading branch information
johnfraney committed Jun 2, 2018
1 parent 6f760e5 commit 8c06b06
Show file tree
Hide file tree
Showing 14 changed files with 224 additions and 20 deletions.
2 changes: 0 additions & 2 deletions ner_trainer/conf.py
@@ -1,7 +1,6 @@
# Adapted from https://github.com/carltongibson/django-filter/blob/master/django_filters/conf.py
from django.conf import settings as dj_settings
from django.core.signals import setting_changed
from django.utils.translation import ugettext_lazy as _

SETTINGS_PREFIX = 'NER_TRAINER_'

Expand Down Expand Up @@ -56,4 +55,3 @@ def change_setting(self, setting, value, enter, **kwargs):

settings = Settings()
setting_changed.connect(settings.change_setting)

Empty file.
Empty file.
6 changes: 4 additions & 2 deletions ner_trainer/management/commands/train_ner_model.py
@@ -1,4 +1,4 @@
from django.core.management.base import BaseCommand, CommandError
from django.core.management.base import BaseCommand

from ner_trainer.conf import settings
from ner_trainer.models import Entity, Phrase
Expand All @@ -15,7 +15,9 @@ def handle(self, *args, **options):
entity_labels = Entity.objects.values_list('label', flat=True)
phrases = Phrase.tagged_objects.all()
train_data = [p.as_spacy_train_data() for p in phrases]
self.stdout.write(self.style.NOTICE(f'Training NER model with {train_iterations} iterations'))
self.stdout.write(self.style.NOTICE(
'Training NER model with {} iterations'.format(train_iterations)
))
train_ner(entity_labels=entity_labels,
train_data=train_data,
model_name=model_name,
Expand Down
1 change: 0 additions & 1 deletion ner_trainer/models.py
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-

from django.db import models
from django.urls import reverse
from model_utils.models import TimeStampedModel

from .validators import validate_all_caps
Expand Down
3 changes: 1 addition & 2 deletions ner_trainer/views.py
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
import spacy
from django.http import HttpResponse
from pathlib import Path
from rest_framework import viewsets
from rest_framework.decorators import action
Expand Down Expand Up @@ -75,7 +74,7 @@ def get_view_name(self):
return 'NER Model Test'

def get(self, request, format=None):
return Response(['Post a text field to test the NER model.'])
return Response("Post a text field to test the NER model.")

def post(self, request):
text = request.data.get('text', None)
Expand Down
Empty file modified runtests.py 100644 → 100755
Empty file.
1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -49,6 +49,7 @@ def get_version(*file_paths):
version=version,
description="""Tools for training spaCy Named Entity Recognition models in Django""",
long_description=readme + '\n\n' + history,
long_description_content_type="text/markdown",
author='John Franey',
author_email='johnfraney@gmail.com',
url='https://github.com/johnfraney/django-ner-trainer',
Expand Down
13 changes: 13 additions & 0 deletions tests/test_conf.py
@@ -0,0 +1,13 @@
from django.test import TestCase

from ner_trainer.conf import settings, DEFAULTS


class SettingsTests(TestCase):
def test_default_settings(self):
for setting_name, default_value in DEFAULTS.items():
self.assertEqual(getattr(settings, setting_name), default_value)

def test_nonexistant_setting(self):
with self.assertRaises(AttributeError):
settings.BANANA
18 changes: 18 additions & 0 deletions tests/test_forms.py
@@ -0,0 +1,18 @@
from django.test import TestCase
from ner_trainer.forms import ModelTestForm


class ModelTestFormTests(TestCase):
def test_invalid_model_test_form(self):
"""
Ensure a ModelTestForm instance without text is invalid.
"""
form = ModelTestForm()
self.assertFalse(form.is_valid())

def test_valid_model_test_form(self):
"""
Ensure a ModelTestForm instance with text is valid.
"""
form = ModelTestForm({'text': 'This is a sentence'})
self.assertTrue(form.is_valid())
60 changes: 47 additions & 13 deletions tests/test_models.py
Expand Up @@ -10,16 +10,50 @@

from django.test import TestCase

from ner_trainer import models


class TestNer_trainer(TestCase):

def setUp(self):
pass

def test_something(self):
pass

def tearDown(self):
pass
from ner_trainer.models import (
Entity,
Phrase,
)


class TestEntity(TestCase):
def test_str(self):
entity = Entity(label='PROVINCE', name='Province')
self.assertEqual(str(entity), 'Province')


class TestPhrase(TestCase):
@classmethod
def setUpClass(cls):
super(TestPhrase, cls).setUpClass()
phrase = Phrase.objects.create(text='I like London and Berlin.')
entity = Entity.objects.create(label='LOC', name='Location')
phrase.entities.create(
entity=entity,
start_index=7,
end_index=13,
)
phrase.entities.create(
entity=entity,
start_index=18,
end_index=24,
)
cls.phrase = phrase

def test_str(self):
self.assertEqual(str(self.phrase), 'I like London and Berlin.')

def test_active_phrase_manager(self):
Phrase.objects.create(text="Skipped phrase", skipped=True)
self.assertEqual(Phrase.objects.count(), 2)
self.assertEqual(Phrase.active_objects.count(), 1)

def test_tagged_phrase_manager(self):
self.assertEqual(Phrase.tagged_objects.count(), 1)
self.assertEqual(Phrase.tagged_objects.first(), self.phrase)

def test_as_spacy_train_data(self):
spacy_train_data = ('I like London and Berlin.', {
'entities': [(7, 13, 'LOC'), (18, 24, 'LOC')]
})
self.assertEqual(self.phrase.as_spacy_train_data(), spacy_train_data)
24 changes: 24 additions & 0 deletions tests/test_validators.py
@@ -0,0 +1,24 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
test_django-ner-trainer
------------
Tests for `django-ner-trainer` models module.
"""

from django.core.exceptions import ValidationError
from django.test import TestCase

from ner_trainer.validators import validate_all_caps


class TestValidators(TestCase):
def test_validate_all_caps(self):
good_label = 'PROVINCE'
self.assertEqual(validate_all_caps(good_label), None)

bad_label = 'province'
with self.assertRaises(ValidationError):
validate_all_caps(bad_label)
115 changes: 115 additions & 0 deletions tests/test_views.py
@@ -0,0 +1,115 @@
import shutil
from django.core.management import call_command
from rest_framework import status
from rest_framework.test import APITestCase

from ner_trainer.models import (
Entity,
Phrase,
)


class EntityTests(APITestCase):
def test_entity_list(self):
"""
Ensure we can list Entity objects.
"""
for i in range(5):
name = 'Entity {}'.format(i)
label = 'ENTITY_{}'.format(i)
Entity.objects.create(name=name, label=label)
response = self.client.get('/entities/', format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), 5)

def test_entity_detail(self):
"""
Ensure we can detail an Entity object.
"""
Entity.objects.create(name='Banana', label='BANANA')
response = self.client.get('/entities/BANANA/')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['name'], 'Banana')
self.assertEqual(response.data['label'], 'BANANA')


class PhraseTests(APITestCase):
@classmethod
def setUpClass(cls):
super(PhraseTests, cls).setUpClass()

def test_phrase_list(self):
"""
Ensure we can list Phrase objects.
"""
Phrase.objects.create(text='Phrase 1')
response = self.client.get('/phrases/', format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), 1)
self.assertEqual(response.data[0]['text'], 'Phrase 1')

def test_active_phrase_list(self):
"""
Ensure we can list active Phrase objects.
"""
phrases = []
for i in range(10):
text = 'Phrase {}'.format(i)
phrases.append(Phrase(text=text))
Phrase.objects.bulk_create(phrases)
response = self.client.get('/phrases/active/', format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), 10)
self.assertEqual(response.data[0]['text'], 'Phrase 0')
self.assertEqual(response.data[-1]['text'], 'Phrase 9')

def test_tagged_phrase_list(self):
"""
Ensure we can list tagged Phrase objects.
"""
phrase = Phrase.objects.create(text='I like London and Berlin.')
entity = Entity.objects.create(label='LOC', name='Location')
phrase.entities.create(
entity=entity,
start_index=7,
end_index=13,
)
phrase.entities.create(
entity=entity,
start_index=18,
end_index=24,
)
response = self.client.get('/phrases/tagged/', format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), 1)


class NERModelTestViewTests(APITestCase):
def test_get(self):
response = self.client.get('/test-ner/')
self.assertEqual(response.status_code, status.HTTP_200_OK)

def test_post_with_text_without_model(self):
data = {'text': 'This is a sentence.'}
response = self.client.post('/test-ner/', data, format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertTrue('Could not find NER model' in response.data)

def test_post_without_text(self):
data = {}
response = self.client.post('/test-ner/', data, format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertTrue('Missing required field' in response.data)

def test_post_with_model(self):
call_command('train_ner_model')
text = 'This is a sentence.'
data = {'text': text}
response = self.client.post('/test-ner/', data, format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_response_data = {
'text': 'This is a sentence.',
'entities': []
}
self.assertEqual(response.data, expected_response_data)
shutil.rmtree('spacy_model', ignore_errors=True)
1 change: 1 addition & 0 deletions tox.ini
Expand Up @@ -10,6 +10,7 @@ commands = coverage run --source ner_trainer runtests.py
deps =
django-111: Django>=1.11,<1.12
django-20: Django>=2.0,<2.1
-r{toxinidir}/requirements.txt
-r{toxinidir}/requirements_test.txt
basepython =
py36: python3.6
Expand Down

0 comments on commit 8c06b06

Please sign in to comment.