Skip to content

Commit

Permalink
Merge pull request #1592 from doccano/enhancement/refactorAnnotationAPI
Browse files Browse the repository at this point in the history
[Enhancement] Refactor Annotation API
  • Loading branch information
Hironsan committed Dec 9, 2021
2 parents 14830a9 + fc4bd36 commit 0cf4e31
Show file tree
Hide file tree
Showing 10 changed files with 304 additions and 23 deletions.
189 changes: 169 additions & 20 deletions backend/api/tests/api/test_annotation.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
from rest_framework import status
from rest_framework.reverse import reverse

from ...models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, Category
from ...models import (DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING,
Category, Span, TextLabel)
from .utils import (CRUDMixin, make_annotation, make_doc, make_label,
make_user, prepare_project)


class TestAnnotationList(CRUDMixin):
model = Category
task = DOCUMENT_CLASSIFICATION
view_name = 'annotation_list'

@classmethod
def setUpTestData(cls):
cls.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
cls.project = prepare_project(task=cls.task)
cls.non_member = make_user()
doc = make_doc(cls.project.item)
for member in cls.project.users:
make_annotation(task=DOCUMENT_CLASSIFICATION, doc=doc, user=member)
cls.url = reverse(viewname='annotation_list', args=[cls.project.item.id, doc.id])
cls.make_annotation(doc, member)
cls.url = reverse(viewname=cls.view_name, args=[cls.project.item.id, doc.id])

@classmethod
def make_annotation(cls, doc, member):
make_annotation(cls.task, doc=doc, user=member)

def test_allows_project_member_to_fetch_annotation(self):
for member in self.project.users:
Expand All @@ -30,19 +38,48 @@ def test_denies_unauthenticated_user_to_fetch_annotation(self):

def test_allows_project_member_to_bulk_delete_annotation(self):
self.assert_delete(self.project.users[0], status.HTTP_204_NO_CONTENT)
count = Category.objects.count()
count = self.model.objects.count()
self.assertEqual(count, 2) # delete only own annotation


class TestCategoryList(TestAnnotationList):
model = Category
task = DOCUMENT_CLASSIFICATION
view_name = 'category_list'


class TestSpanList(TestAnnotationList):
model = Span
task = SEQUENCE_LABELING
view_name = 'span_list'

@classmethod
def make_annotation(cls, doc, member):
make_annotation(cls.task, doc=doc, user=member, start_offset=0, end_offset=1)


class TestTextList(TestAnnotationList):
model = TextLabel
task = SEQ2SEQ
view_name = 'text_list'


class TestSharedAnnotationList(CRUDMixin):
model = Category
task = DOCUMENT_CLASSIFICATION
view_name = 'annotation_list'

@classmethod
def setUpTestData(cls):
cls.project = prepare_project(task=DOCUMENT_CLASSIFICATION, collaborative_annotation=True)
cls.project = prepare_project(task=cls.task, collaborative_annotation=True)
doc = make_doc(cls.project.item)
for member in cls.project.users:
make_annotation(task=DOCUMENT_CLASSIFICATION, doc=doc, user=member)
cls.url = reverse(viewname='annotation_list', args=[cls.project.item.id, doc.id])
cls.make_annotation(doc, member)
cls.url = reverse(viewname=cls.view_name, args=[cls.project.item.id, doc.id])

@classmethod
def make_annotation(cls, doc, member):
make_annotation(cls.task, doc=doc, user=member)

def test_allows_project_member_to_fetch_all_annotation(self):
for member in self.project.users:
Expand All @@ -51,19 +88,54 @@ def test_allows_project_member_to_fetch_all_annotation(self):

def test_allows_project_member_to_bulk_delete_annotation(self):
self.assert_delete(self.project.users[0], status.HTTP_204_NO_CONTENT)
count = Category.objects.count()
count = self.model.objects.count()
self.assertEqual(count, 0) # delete all annotation in the doc


class TestSharedCategoryList(TestSharedAnnotationList):
model = Category
task = DOCUMENT_CLASSIFICATION
view_name = 'category_list'


class TestSharedSpanList(TestSharedAnnotationList):
model = Span
task = SEQUENCE_LABELING
view_name = 'span_list'
start_offset = 0

@classmethod
def make_annotation(cls, doc, member):
make_annotation(
cls.task,
doc=doc,
user=member,
start_offset=cls.start_offset,
end_offset=cls.start_offset + 1
)
cls.start_offset += 1


class TestSharedTextList(TestSharedAnnotationList):
model = TextLabel
task = SEQ2SEQ
view_name = 'text_list'


class TestAnnotationCreation(CRUDMixin):
task = DOCUMENT_CLASSIFICATION
view_name = 'annotation_list'

def setUp(self):
self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
self.project = prepare_project(task=self.task)
self.non_member = make_user()
doc = make_doc(self.project.item)
self.data = self.create_data()
self.url = reverse(viewname=self.view_name, args=[self.project.item.id, doc.id])

def create_data(self):
label = make_label(self.project.item)
self.data = {'label': label.id}
self.url = reverse(viewname='annotation_list', args=[self.project.item.id, doc.id])
return {'label': label.id}

def test_allows_project_member_to_annotate(self):
for member in self.project.users:
Expand All @@ -76,22 +148,48 @@ def test_denies_unauthenticated_user_to_annotate(self):
self.assert_create(expected=status.HTTP_403_FORBIDDEN)


class TestCategoryCreation(TestAnnotationCreation):
view_name = 'category_list'


class TestSpanCreation(TestAnnotationCreation):
task = SEQUENCE_LABELING
view_name = 'span_list'

def create_data(self):
label = make_label(self.project.item)
return {'label': label.id, 'start_offset': 0, 'end_offset': 1}


class TestTextLabelCreation(TestAnnotationCreation):
task = SEQ2SEQ
view_name = 'text_list'

def create_data(self):
return {'text': 'example'}


class TestAnnotationDetail(CRUDMixin):
task = SEQUENCE_LABELING
view_name = 'annotation_detail'

def setUp(self):
self.project = prepare_project(task=SEQUENCE_LABELING)
self.project = prepare_project(task=self.task)
self.non_member = make_user()
doc = make_doc(self.project.item)
label = make_label(self.project.item)
annotation = make_annotation(
task=SEQUENCE_LABELING,
annotation = self.create_annotation_data(doc=doc)
self.data = {'label': label.id}
self.url = reverse(viewname=self.view_name, args=[self.project.item.id, doc.id, annotation.id])

def create_annotation_data(self, doc):
return make_annotation(
task=self.task,
doc=doc,
user=self.project.users[0],
start_offset=0,
end_offset=1
)
self.data = {'label': label.id}
self.url = reverse(viewname='annotation_detail', args=[self.project.item.id, doc.id, annotation.id])

def test_allows_owner_to_get_annotation(self):
self.assert_fetch(self.project.users[0], status.HTTP_200_OK)
Expand Down Expand Up @@ -127,15 +225,45 @@ def test_denies_non_project_member_to_delete_annotation(self):
self.assert_delete(self.non_member, status.HTTP_403_FORBIDDEN)


class TestCategoryDetail(TestAnnotationDetail):
task = DOCUMENT_CLASSIFICATION
view_name = 'category_detail'

def create_annotation_data(self, doc):
return make_annotation(task=self.task, doc=doc, user=self.project.users[0])


class TestSpanDetail(TestAnnotationDetail):
task = SEQUENCE_LABELING
view_name = 'span_detail'


class TestTextDetail(TestAnnotationDetail):
task = SEQ2SEQ
view_name = 'text_detail'

def setUp(self):
super().setUp()
self.data = {'text': 'changed'}

def create_annotation_data(self, doc):
return make_annotation(task=self.task, doc=doc, user=self.project.users[0])


class TestSharedAnnotationDetail(CRUDMixin):
task = DOCUMENT_CLASSIFICATION
view_name = 'annotation_detail'

def setUp(self):
self.project = prepare_project(task=DOCUMENT_CLASSIFICATION, collaborative_annotation=True)
self.project = prepare_project(task=self.task, collaborative_annotation=True)
doc = make_doc(self.project.item)
annotation = make_annotation(task=DOCUMENT_CLASSIFICATION, doc=doc, user=self.project.users[0])
annotation = self.make_annotation(doc, self.project.users[0])
label = make_label(self.project.item)
self.data = {'label': label.id}
self.url = reverse(viewname='annotation_detail', args=[self.project.item.id, doc.id, annotation.id])
self.url = reverse(viewname=self.view_name, args=[self.project.item.id, doc.id, annotation.id])

def make_annotation(self, doc, member):
return make_annotation(self.task, doc=doc, user=member)

def test_allows_any_member_to_get_annotation(self):
for member in self.project.users:
Expand All @@ -147,3 +275,24 @@ def test_allows_any_member_to_update_annotation(self):

def test_allows_any_member_to_delete_annotation(self):
self.assert_delete(self.project.users[1], status.HTTP_204_NO_CONTENT)


class TestSharedCategoryDetail(TestSharedAnnotationDetail):
view_name = 'category_detail'


class TestSharedSpanDetail(TestSharedAnnotationDetail):
task = SEQUENCE_LABELING
view_name = 'span_detail'

def make_annotation(self, doc, member):
return make_annotation(self.task, doc=doc, user=member, start_offset=0, end_offset=1)


class TestSharedTextDetail(TestSharedAnnotationDetail):
task = SEQ2SEQ
view_name = 'text_detail'

def setUp(self):
super().setUp()
self.data = {'text': 'changed'}
31 changes: 31 additions & 0 deletions backend/api/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
example, example_state, export_dataset, health,
import_dataset, import_export, label, project,
relation_types, role, statistics, tag, task, user)
from .views.tasks import category, span, text

urlpatterns_project = [
path(
Expand Down Expand Up @@ -112,6 +113,36 @@
view=annotation.AnnotationDetail.as_view(),
name='annotation_detail'
),
path(
route='examples/<int:example_id>/categories',
view=category.CategoryListAPI.as_view(),
name='category_list'
),
path(
route='examples/<int:example_id>/categories/<int:annotation_id>',
view=category.CategoryDetailAPI.as_view(),
name='category_detail'
),
path(
route='examples/<int:example_id>/spans',
view=span.SpanListAPI.as_view(),
name='span_list'
),
path(
route='examples/<int:example_id>/spans/<int:annotation_id>',
view=span.SpanDetailAPI.as_view(),
name='span_detail'
),
path(
route='examples/<int:example_id>/texts',
view=text.TextLabelListAPI.as_view(),
name='text_list'
),
path(
route='examples/<int:example_id>/texts/<int:annotation_id>',
view=text.TextLabelDetailAPI.as_view(),
name='text_detail'
),
path(
route='tags',
view=tag.TagList.as_view(),
Expand Down
Empty file.
57 changes: 57 additions & 0 deletions backend/api/views/tasks/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from django.core.exceptions import ValidationError
from django.shortcuts import get_object_or_404
from rest_framework import generics, status
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response

from ...models import Project
from ...permissions import IsInProjectOrAdmin, IsOwnAnnotation


class BaseListAPI(generics.ListCreateAPIView):
annotation_class = None
pagination_class = None
permission_classes = [IsAuthenticated & IsInProjectOrAdmin]
swagger_schema = None

@property
def project(self):
return get_object_or_404(Project, pk=self.kwargs['project_id'])

def get_queryset(self):
queryset = self.annotation_class.objects.filter(example=self.kwargs['example_id'])
if not self.project.collaborative_annotation:
queryset = queryset.filter(user=self.request.user)
return queryset

def create(self, request, *args, **kwargs):
request.data['example'] = self.kwargs['example_id']
try:
response = super().create(request, args, kwargs)
except ValidationError as err:
response = Response({'detail': err.messages}, status=status.HTTP_400_BAD_REQUEST)
return response

def perform_create(self, serializer):
serializer.save(example_id=self.kwargs['example_id'], user=self.request.user)

def delete(self, request, *args, **kwargs):
queryset = self.get_queryset()
queryset.all().delete()
return Response(status=status.HTTP_204_NO_CONTENT)


class BaseDetailAPI(generics.RetrieveUpdateDestroyAPIView):
lookup_url_kwarg = 'annotation_id'
swagger_schema = None

@property
def project(self):
return get_object_or_404(Project, pk=self.kwargs['project_id'])

def get_permissions(self):
if self.project.collaborative_annotation:
self.permission_classes = [IsAuthenticated & IsInProjectOrAdmin]
else:
self.permission_classes = [IsAuthenticated & IsInProjectOrAdmin & IsOwnAnnotation]
return super().get_permissions()
18 changes: 18 additions & 0 deletions backend/api/views/tasks/category.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from ...models import Category
from ...serializers import CategorySerializer
from .base import BaseDetailAPI, BaseListAPI


class CategoryListAPI(BaseListAPI):
annotation_class = Category
serializer_class = CategorySerializer

def create(self, request, *args, **kwargs):
if self.project.single_class_classification:
self.get_queryset().delete()
return super().create(request, args, kwargs)


class CategoryDetailAPI(BaseDetailAPI):
queryset = Category.objects.all()
serializer_class = CategorySerializer

0 comments on commit 0cf4e31

Please sign in to comment.