Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Refactor Annotation API #1592

Merged
merged 4 commits into from
Dec 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
189 changes: 169 additions & 20 deletions backend/api/tests/api/test_annotation.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
from rest_framework import status
from rest_framework.reverse import reverse

from ...models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, Category
from ...models import (DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING,
Category, Span, TextLabel)
from .utils import (CRUDMixin, make_annotation, make_doc, make_label,
make_user, prepare_project)


class TestAnnotationList(CRUDMixin):
model = Category
task = DOCUMENT_CLASSIFICATION
view_name = 'annotation_list'

@classmethod
def setUpTestData(cls):
cls.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
cls.project = prepare_project(task=cls.task)
cls.non_member = make_user()
doc = make_doc(cls.project.item)
for member in cls.project.users:
make_annotation(task=DOCUMENT_CLASSIFICATION, doc=doc, user=member)
cls.url = reverse(viewname='annotation_list', args=[cls.project.item.id, doc.id])
cls.make_annotation(doc, member)
cls.url = reverse(viewname=cls.view_name, args=[cls.project.item.id, doc.id])

@classmethod
def make_annotation(cls, doc, member):
make_annotation(cls.task, doc=doc, user=member)

def test_allows_project_member_to_fetch_annotation(self):
for member in self.project.users:
Expand All @@ -30,19 +38,48 @@ def test_denies_unauthenticated_user_to_fetch_annotation(self):

def test_allows_project_member_to_bulk_delete_annotation(self):
self.assert_delete(self.project.users[0], status.HTTP_204_NO_CONTENT)
count = Category.objects.count()
count = self.model.objects.count()
self.assertEqual(count, 2) # delete only own annotation


class TestCategoryList(TestAnnotationList):
model = Category
task = DOCUMENT_CLASSIFICATION
view_name = 'category_list'


class TestSpanList(TestAnnotationList):
model = Span
task = SEQUENCE_LABELING
view_name = 'span_list'

@classmethod
def make_annotation(cls, doc, member):
make_annotation(cls.task, doc=doc, user=member, start_offset=0, end_offset=1)


class TestTextList(TestAnnotationList):
model = TextLabel
task = SEQ2SEQ
view_name = 'text_list'


class TestSharedAnnotationList(CRUDMixin):
model = Category
task = DOCUMENT_CLASSIFICATION
view_name = 'annotation_list'

@classmethod
def setUpTestData(cls):
cls.project = prepare_project(task=DOCUMENT_CLASSIFICATION, collaborative_annotation=True)
cls.project = prepare_project(task=cls.task, collaborative_annotation=True)
doc = make_doc(cls.project.item)
for member in cls.project.users:
make_annotation(task=DOCUMENT_CLASSIFICATION, doc=doc, user=member)
cls.url = reverse(viewname='annotation_list', args=[cls.project.item.id, doc.id])
cls.make_annotation(doc, member)
cls.url = reverse(viewname=cls.view_name, args=[cls.project.item.id, doc.id])

@classmethod
def make_annotation(cls, doc, member):
make_annotation(cls.task, doc=doc, user=member)

def test_allows_project_member_to_fetch_all_annotation(self):
for member in self.project.users:
Expand All @@ -51,19 +88,54 @@ def test_allows_project_member_to_fetch_all_annotation(self):

def test_allows_project_member_to_bulk_delete_annotation(self):
self.assert_delete(self.project.users[0], status.HTTP_204_NO_CONTENT)
count = Category.objects.count()
count = self.model.objects.count()
self.assertEqual(count, 0) # delete all annotation in the doc


class TestSharedCategoryList(TestSharedAnnotationList):
model = Category
task = DOCUMENT_CLASSIFICATION
view_name = 'category_list'


class TestSharedSpanList(TestSharedAnnotationList):
model = Span
task = SEQUENCE_LABELING
view_name = 'span_list'
start_offset = 0

@classmethod
def make_annotation(cls, doc, member):
make_annotation(
cls.task,
doc=doc,
user=member,
start_offset=cls.start_offset,
end_offset=cls.start_offset + 1
)
cls.start_offset += 1


class TestSharedTextList(TestSharedAnnotationList):
model = TextLabel
task = SEQ2SEQ
view_name = 'text_list'


class TestAnnotationCreation(CRUDMixin):
task = DOCUMENT_CLASSIFICATION
view_name = 'annotation_list'

def setUp(self):
self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
self.project = prepare_project(task=self.task)
self.non_member = make_user()
doc = make_doc(self.project.item)
self.data = self.create_data()
self.url = reverse(viewname=self.view_name, args=[self.project.item.id, doc.id])

def create_data(self):
label = make_label(self.project.item)
self.data = {'label': label.id}
self.url = reverse(viewname='annotation_list', args=[self.project.item.id, doc.id])
return {'label': label.id}

def test_allows_project_member_to_annotate(self):
for member in self.project.users:
Expand All @@ -76,22 +148,48 @@ def test_denies_unauthenticated_user_to_annotate(self):
self.assert_create(expected=status.HTTP_403_FORBIDDEN)


class TestCategoryCreation(TestAnnotationCreation):
view_name = 'category_list'


class TestSpanCreation(TestAnnotationCreation):
task = SEQUENCE_LABELING
view_name = 'span_list'

def create_data(self):
label = make_label(self.project.item)
return {'label': label.id, 'start_offset': 0, 'end_offset': 1}


class TestTextLabelCreation(TestAnnotationCreation):
task = SEQ2SEQ
view_name = 'text_list'

def create_data(self):
return {'text': 'example'}


class TestAnnotationDetail(CRUDMixin):
task = SEQUENCE_LABELING
view_name = 'annotation_detail'

def setUp(self):
self.project = prepare_project(task=SEQUENCE_LABELING)
self.project = prepare_project(task=self.task)
self.non_member = make_user()
doc = make_doc(self.project.item)
label = make_label(self.project.item)
annotation = make_annotation(
task=SEQUENCE_LABELING,
annotation = self.create_annotation_data(doc=doc)
self.data = {'label': label.id}
self.url = reverse(viewname=self.view_name, args=[self.project.item.id, doc.id, annotation.id])

def create_annotation_data(self, doc):
return make_annotation(
task=self.task,
doc=doc,
user=self.project.users[0],
start_offset=0,
end_offset=1
)
self.data = {'label': label.id}
self.url = reverse(viewname='annotation_detail', args=[self.project.item.id, doc.id, annotation.id])

def test_allows_owner_to_get_annotation(self):
self.assert_fetch(self.project.users[0], status.HTTP_200_OK)
Expand Down Expand Up @@ -127,15 +225,45 @@ def test_denies_non_project_member_to_delete_annotation(self):
self.assert_delete(self.non_member, status.HTTP_403_FORBIDDEN)


class TestCategoryDetail(TestAnnotationDetail):
task = DOCUMENT_CLASSIFICATION
view_name = 'category_detail'

def create_annotation_data(self, doc):
return make_annotation(task=self.task, doc=doc, user=self.project.users[0])


class TestSpanDetail(TestAnnotationDetail):
task = SEQUENCE_LABELING
view_name = 'span_detail'


class TestTextDetail(TestAnnotationDetail):
task = SEQ2SEQ
view_name = 'text_detail'

def setUp(self):
super().setUp()
self.data = {'text': 'changed'}

def create_annotation_data(self, doc):
return make_annotation(task=self.task, doc=doc, user=self.project.users[0])


class TestSharedAnnotationDetail(CRUDMixin):
task = DOCUMENT_CLASSIFICATION
view_name = 'annotation_detail'

def setUp(self):
self.project = prepare_project(task=DOCUMENT_CLASSIFICATION, collaborative_annotation=True)
self.project = prepare_project(task=self.task, collaborative_annotation=True)
doc = make_doc(self.project.item)
annotation = make_annotation(task=DOCUMENT_CLASSIFICATION, doc=doc, user=self.project.users[0])
annotation = self.make_annotation(doc, self.project.users[0])
label = make_label(self.project.item)
self.data = {'label': label.id}
self.url = reverse(viewname='annotation_detail', args=[self.project.item.id, doc.id, annotation.id])
self.url = reverse(viewname=self.view_name, args=[self.project.item.id, doc.id, annotation.id])

def make_annotation(self, doc, member):
return make_annotation(self.task, doc=doc, user=member)

def test_allows_any_member_to_get_annotation(self):
for member in self.project.users:
Expand All @@ -147,3 +275,24 @@ def test_allows_any_member_to_update_annotation(self):

def test_allows_any_member_to_delete_annotation(self):
self.assert_delete(self.project.users[1], status.HTTP_204_NO_CONTENT)


class TestSharedCategoryDetail(TestSharedAnnotationDetail):
view_name = 'category_detail'


class TestSharedSpanDetail(TestSharedAnnotationDetail):
task = SEQUENCE_LABELING
view_name = 'span_detail'

def make_annotation(self, doc, member):
return make_annotation(self.task, doc=doc, user=member, start_offset=0, end_offset=1)


class TestSharedTextDetail(TestSharedAnnotationDetail):
task = SEQ2SEQ
view_name = 'text_detail'

def setUp(self):
super().setUp()
self.data = {'text': 'changed'}
31 changes: 31 additions & 0 deletions backend/api/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
example, example_state, export_dataset, health,
import_dataset, import_export, label, project,
relation_types, role, statistics, tag, task, user)
from .views.tasks import category, span, text

urlpatterns_project = [
path(
Expand Down Expand Up @@ -112,6 +113,36 @@
view=annotation.AnnotationDetail.as_view(),
name='annotation_detail'
),
path(
route='examples/<int:example_id>/categories',
view=category.CategoryListAPI.as_view(),
name='category_list'
),
path(
route='examples/<int:example_id>/categories/<int:annotation_id>',
view=category.CategoryDetailAPI.as_view(),
name='category_detail'
),
path(
route='examples/<int:example_id>/spans',
view=span.SpanListAPI.as_view(),
name='span_list'
),
path(
route='examples/<int:example_id>/spans/<int:annotation_id>',
view=span.SpanDetailAPI.as_view(),
name='span_detail'
),
path(
route='examples/<int:example_id>/texts',
view=text.TextLabelListAPI.as_view(),
name='text_list'
),
path(
route='examples/<int:example_id>/texts/<int:annotation_id>',
view=text.TextLabelDetailAPI.as_view(),
name='text_detail'
),
path(
route='tags',
view=tag.TagList.as_view(),
Expand Down
Empty file.
57 changes: 57 additions & 0 deletions backend/api/views/tasks/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from django.core.exceptions import ValidationError
from django.shortcuts import get_object_or_404
from rest_framework import generics, status
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response

from ...models import Project
from ...permissions import IsInProjectOrAdmin, IsOwnAnnotation


class BaseListAPI(generics.ListCreateAPIView):
annotation_class = None
pagination_class = None
permission_classes = [IsAuthenticated & IsInProjectOrAdmin]
swagger_schema = None

@property
def project(self):
return get_object_or_404(Project, pk=self.kwargs['project_id'])

def get_queryset(self):
queryset = self.annotation_class.objects.filter(example=self.kwargs['example_id'])
if not self.project.collaborative_annotation:
queryset = queryset.filter(user=self.request.user)
return queryset

def create(self, request, *args, **kwargs):
request.data['example'] = self.kwargs['example_id']
try:
response = super().create(request, args, kwargs)
except ValidationError as err:
response = Response({'detail': err.messages}, status=status.HTTP_400_BAD_REQUEST)
return response

def perform_create(self, serializer):
serializer.save(example_id=self.kwargs['example_id'], user=self.request.user)

def delete(self, request, *args, **kwargs):
queryset = self.get_queryset()
queryset.all().delete()
return Response(status=status.HTTP_204_NO_CONTENT)


class BaseDetailAPI(generics.RetrieveUpdateDestroyAPIView):
lookup_url_kwarg = 'annotation_id'
swagger_schema = None

@property
def project(self):
return get_object_or_404(Project, pk=self.kwargs['project_id'])

def get_permissions(self):
if self.project.collaborative_annotation:
self.permission_classes = [IsAuthenticated & IsInProjectOrAdmin]
else:
self.permission_classes = [IsAuthenticated & IsInProjectOrAdmin & IsOwnAnnotation]
return super().get_permissions()
18 changes: 18 additions & 0 deletions backend/api/views/tasks/category.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from ...models import Category
from ...serializers import CategorySerializer
from .base import BaseDetailAPI, BaseListAPI


class CategoryListAPI(BaseListAPI):
annotation_class = Category
serializer_class = CategorySerializer

def create(self, request, *args, **kwargs):
if self.project.single_class_classification:
self.get_queryset().delete()
return super().create(request, args, kwargs)


class CategoryDetailAPI(BaseDetailAPI):
queryset = Category.objects.all()
serializer_class = CategorySerializer