Skip to content

Commit

Permalink
Merge pull request #1832 from doccano/refactoring/dataImport
Browse files Browse the repository at this point in the history
[Refactoring] data import
  • Loading branch information
Hironsan committed May 22, 2022
2 parents 65c5596 + f3b40eb commit 249926b
Show file tree
Hide file tree
Showing 10 changed files with 112 additions and 65 deletions.
43 changes: 20 additions & 23 deletions backend/data_import/celery_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from django_drf_filepond.models import TemporaryUpload

from .datasets import load_dataset
from .pipeline.catalog import AudioFile, ImageFile
from .pipeline.catalog import Format, create_file_format
from .pipeline.exceptions import (
FileImportException,
FileTypeException,
Expand All @@ -19,21 +19,15 @@
from projects.models import Project


def check_file_type(filename, file_format: str, filepath: str):
def check_file_type(filename, file_format: Format, filepath: str):
if not settings.ENABLE_FILE_TYPE_CHECK:
return
kind = filetype.guess(filepath)
if file_format == ImageFile.name:
accept_types = ImageFile.accept_types.replace(" ", "").split(",")
elif file_format == AudioFile.name:
accept_types = AudioFile.accept_types.replace(" ", "").split(",")
else:
return
if kind.mime not in accept_types:
raise FileTypeException(filename, kind.mime, accept_types)
if not file_format.validate_mime(kind.mime):
raise FileTypeException(filename, kind.mime, file_format.accept_types)


def check_uploaded_files(upload_ids: List[str], file_format: str):
def check_uploaded_files(upload_ids: List[str], file_format: Format):
errors: List[FileImportException] = []
cleaned_ids = []
temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
Expand All @@ -56,19 +50,22 @@ def check_uploaded_files(upload_ids: List[str], file_format: str):
def import_dataset(user_id, project_id, file_format: str, upload_ids: List[str], task: str, **kwargs):
project = get_object_or_404(Project, pk=project_id)
user = get_object_or_404(get_user_model(), pk=user_id)
try:
fmt = create_file_format(file_format)
upload_ids, errors = check_uploaded_files(upload_ids, fmt)
temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
filenames = [
FileName(full_path=tu.get_file_path(), generated_name=tu.file.name, upload_name=tu.upload_name)
for tu in temporary_uploads
]

upload_ids, errors = check_uploaded_files(upload_ids, file_format)
temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
filenames = [
FileName(full_path=tu.get_file_path(), generated_name=tu.file.name, upload_name=tu.upload_name)
for tu in temporary_uploads
]

dataset = load_dataset(task, file_format, filenames, project, **kwargs)
dataset.save(user, batch_size=settings.IMPORT_BATCH_SIZE)
upload_to_store(temporary_uploads)
errors.extend(dataset.errors)
return {"error": [e.dict() for e in errors]}
dataset = load_dataset(task, fmt, filenames, project, **kwargs)
dataset.save(user, batch_size=settings.IMPORT_BATCH_SIZE)
upload_to_store(temporary_uploads)
errors.extend(dataset.errors)
return {"error": [e.dict() for e in errors]}
except FileImportException as e:
return {"error": [e.dict()]}


def upload_to_store(temporary_uploads):
Expand Down
8 changes: 4 additions & 4 deletions backend/data_import/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from django.contrib.auth.models import User

from .models import DummyLabelType
from .pipeline.catalog import RELATION_EXTRACTION, TextFile, TextLine
from .pipeline.catalog import RELATION_EXTRACTION, Format
from .pipeline.data import BaseData, BinaryData, TextData
from .pipeline.exceptions import FileParseException
from .pipeline.factories import create_parser
Expand Down Expand Up @@ -210,7 +210,7 @@ def errors(self) -> List[FileParseException]:
return self.reader.errors + self.example_maker.errors + self.category_maker.errors + self.span_maker.errors


def select_dataset(project: Project, task: str, file_format: str) -> Type[Dataset]:
def select_dataset(project: Project, task: str, file_format: Format) -> Type[Dataset]:
mapping = {
DOCUMENT_CLASSIFICATION: TextClassificationDataset,
SEQUENCE_LABELING: SequenceLabelingDataset,
Expand All @@ -222,12 +222,12 @@ def select_dataset(project: Project, task: str, file_format: str) -> Type[Datase
}
if task not in mapping:
task = project.project_type
if project.is_text_project and file_format in [TextLine.name, TextFile.name]:
if project.is_text_project and file_format.is_plain_text():
return PlainDataset
return mapping[task]


def load_dataset(task: str, file_format: str, data_files: List[FileName], project: Project, **kwargs) -> Dataset:
def load_dataset(task: str, file_format: Format, data_files: List[FileName], project: Project, **kwargs) -> Dataset:
parser = create_parser(file_format, **kwargs)
reader = Reader(data_files, parser)
dataset_class = select_dataset(project, task, file_format)
Expand Down
29 changes: 29 additions & 0 deletions backend/data_import/pipeline/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pydantic import BaseModel
from typing_extensions import Literal

from .exceptions import FileFormatException
from projects.models import (
DOCUMENT_CLASSIFICATION,
IMAGE_CLASSIFICATION,
Expand Down Expand Up @@ -140,6 +141,13 @@ class Format:
def dict(cls):
return {"name": cls.name, "accept_types": cls.accept_types}

def validate_mime(self, mime: str):
return True

@staticmethod
def is_plain_text():
return False


class CSV(Format):
name = "CSV"
Expand Down Expand Up @@ -170,11 +178,19 @@ class TextFile(Format):
name = "TextFile"
accept_types = "text/*"

@staticmethod
def is_plain_text():
return True


class TextLine(Format):
name = "TextLine"
accept_types = "text/*"

@staticmethod
def is_plain_text():
return True


class CoNLL(Format):
name = "CoNLL"
Expand All @@ -185,11 +201,17 @@ class ImageFile(Format):
name = "ImageFile"
accept_types = "image/png, image/jpeg, image/bmp, image/gif"

def validate_mime(self, mime: str):
return mime in self.accept_types


class AudioFile(Format):
name = "AudioFile"
accept_types = "audio/ogg, audio/aac, audio/mpeg, audio/wav"

def validate_mime(self, mime: str):
return mime in self.accept_types


class ArgColumn(BaseModel):
encoding: encodings = "utf_8"
Expand Down Expand Up @@ -239,6 +261,13 @@ def dict(self) -> Dict:
}


def create_file_format(file_format: str) -> Format:
for format_class in Format.__subclasses__():
if format_class.name == file_format:
return format_class()
raise FileFormatException(file_format)


class Options:
options: Dict[str, List] = defaultdict(list)

Expand Down
9 changes: 9 additions & 0 deletions backend/data_import/pipeline/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,12 @@ def __str__(self):

def dict(self):
return {"filename": self.filename, "line": -1, "message": str(self)}


class FileFormatException(FileImportException):
def __init__(self, file_format: str):
self.file_format = file_format

def dict(self):
message = f"Unknown file format: {self.file_format}"
return {"message": message}
7 changes: 3 additions & 4 deletions backend/data_import/pipeline/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
CoNLL,
Excel,
FastText,
Format,
ImageFile,
TextFile,
TextLine,
Expand All @@ -23,7 +24,7 @@
)


def create_parser(file_format: str, **kwargs):
def create_parser(file_format: Format, **kwargs):
mapping = {
TextFile.name: TextFileParser,
TextLine.name: LineParser,
Expand All @@ -36,6 +37,4 @@ def create_parser(file_format: str, **kwargs):
ImageFile.name: PlainParser,
AudioFile.name: PlainParser,
}
if file_format not in mapping:
raise ValueError(f"Invalid format: {file_format}")
return mapping[file_format](**kwargs)
return mapping[file_format.name](**kwargs)
45 changes: 21 additions & 24 deletions backend/data_import/pipeline/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import uuid
from typing import Any, Optional

from pydantic import UUID4, BaseModel, validator
from pydantic import UUID4, BaseModel, ConstrainedStr, NonNegativeInt, root_validator

from .label_types import LabelTypes
from examples.models import Example
Expand All @@ -15,6 +15,10 @@
from projects.models import Project


class NonEmptyStr(ConstrainedStr):
min_length = 1


class Label(BaseModel, abc.ABC):
id: int = -1
uuid: UUID4
Expand Down Expand Up @@ -45,18 +49,11 @@ def __hash__(self):


class CategoryLabel(Label):
label: str
label: NonEmptyStr

def __lt__(self, other):
return self.label < other.label

@validator("label")
def label_is_not_empty(cls, value: str):
if value:
return value
else:
raise ValueError("is not empty.")

@classmethod
def parse(cls, example_uuid: UUID4, obj: Any):
return cls(example_uuid=example_uuid, label=obj)
Expand All @@ -65,17 +62,24 @@ def create_type(self, project: Project) -> Optional[LabelType]:
return CategoryType(text=self.label, project=project)

def create(self, user, example: Example, types: LabelTypes, **kwargs):
return CategoryModel(uuid=self.uuid, user=user, example=example, label=types.get_by_text(self.label))
return CategoryModel(uuid=self.uuid, user=user, example=example, label=types[self.label])


class SpanLabel(Label):
label: str
start_offset: int
end_offset: int
label: NonEmptyStr
start_offset: NonNegativeInt
end_offset: NonNegativeInt

def __lt__(self, other):
return self.start_offset < other.start_offset

@root_validator
def check_start_offset_is_less_than_end_offset(cls, values):
start_offset, end_offset = values.get("start_offset"), values.get("end_offset")
if start_offset >= end_offset:
raise ValueError("start_offset must be less than end_offset.")
return values

@classmethod
def parse(cls, example_uuid: UUID4, obj: Any):
if isinstance(obj, list) or isinstance(obj, tuple):
Expand All @@ -96,23 +100,16 @@ def create(self, user, example: Example, types: LabelTypes, **kwargs):
example=example,
start_offset=self.start_offset,
end_offset=self.end_offset,
label=types.get_by_text(self.label),
label=types[self.label],
)


class TextLabel(Label):
text: str
text: NonEmptyStr

def __lt__(self, other):
return self.text < other.text

@validator("text")
def text_is_not_empty(cls, value: str):
if value:
return value
else:
raise ValueError("is not empty.")

@classmethod
def parse(cls, example_uuid: UUID4, obj: Any):
return cls(example_uuid=example_uuid, text=obj)
Expand All @@ -127,7 +124,7 @@ def create(self, user, example: Example, types: LabelTypes, **kwargs):
class RelationLabel(Label):
from_id: int
to_id: int
type: str
type: NonEmptyStr

def __lt__(self, other):
return self.from_id < other.from_id
Expand All @@ -144,7 +141,7 @@ def create(self, user, example: Example, types: LabelTypes, **kwargs):
uuid=self.uuid,
user=user,
example=example,
type=types.get_by_text(self.type),
type=types[self.type],
from_id=kwargs["id_to_span"][self.from_id],
to_id=kwargs["id_to_span"][self.to_id],
)
6 changes: 3 additions & 3 deletions backend/data_import/pipeline/label_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ def __init__(self, label_type_class: Type[LabelType]):
def __contains__(self, text: str) -> bool:
return text in self.types

def __getitem__(self, text: str) -> LabelType:
return self.types[text]

def save(self, label_types: List[LabelType]):
self.label_type_class.objects.bulk_create(label_types, ignore_conflicts=True)

def update(self, project: Project):
types = self.label_type_class.objects.filter(project=project)
self.types = {label_type.text: label_type for label_type in types}

def get_by_text(self, text: str) -> LabelType:
return self.types[text]
16 changes: 12 additions & 4 deletions backend/data_import/tests/test_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_create_type(self):
def test_create(self):
category = CategoryLabel(label="A", example_uuid=uuid.uuid4())
types = MagicMock()
types.get_by_text.return_value = mommy.make(CategoryType, project=self.project.item)
types.__getitem__.return_value = mommy.make(CategoryType, project=self.project.item)
category_model = category.create(self.user, self.example, types)
self.assertIsInstance(category_model, CategoryModel)

Expand All @@ -65,7 +65,7 @@ class TestSpanLabel(TestLabel):

def test_comparison(self):
span1 = SpanLabel(label="A", start_offset=0, end_offset=1, example_uuid=uuid.uuid4())
span2 = SpanLabel(label="A", start_offset=1, end_offset=1, example_uuid=uuid.uuid4())
span2 = SpanLabel(label="A", start_offset=1, end_offset=2, example_uuid=uuid.uuid4())
self.assertLess(span1, span2)

def test_parse_tuple(self):
Expand All @@ -82,6 +82,14 @@ def test_parse_dict(self):
self.assertEqual(span.start_offset, 0)
self.assertEqual(span.end_offset, 1)

def test_invalid_negative_offset(self):
with self.assertRaises(ValueError):
SpanLabel(label="A", start_offset=-1, end_offset=1, example_uuid=uuid.uuid4())

def test_invalid_offset(self):
with self.assertRaises(ValueError):
SpanLabel(label="A", start_offset=1, end_offset=0, example_uuid=uuid.uuid4())

def test_parse_invalid_dict(self):
example_uuid = uuid.uuid4()
with self.assertRaises(ValueError):
Expand All @@ -96,7 +104,7 @@ def test_create_type(self):
def test_create(self):
span = SpanLabel(label="A", start_offset=0, end_offset=1, example_uuid=uuid.uuid4())
types = MagicMock()
types.get_by_text.return_value = mommy.make(SpanType, project=self.project.item)
types.__getitem__.return_value = mommy.make(SpanType, project=self.project.item)
span_model = span.create(self.user, self.example, types)
self.assertIsInstance(span_model, SpanModel)

Expand Down Expand Up @@ -160,7 +168,7 @@ def test_create_type(self):
def test_create(self):
relation = RelationLabel(type="A", from_id=0, to_id=1, example_uuid=uuid.uuid4())
types = MagicMock()
types.get_by_text.return_value = mommy.make(RelationType, project=self.project.item)
types.__getitem__.return_value = mommy.make(RelationType, project=self.project.item)
id_to_span = {
0: mommy.make(SpanModel, start_offset=0, end_offset=1),
1: mommy.make(SpanModel, start_offset=2, end_offset=3),
Expand Down

0 comments on commit 249926b

Please sign in to comment.