Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactoring] data import #1832

Merged
merged 4 commits into from
May 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
43 changes: 20 additions & 23 deletions backend/data_import/celery_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from django_drf_filepond.models import TemporaryUpload

from .datasets import load_dataset
from .pipeline.catalog import AudioFile, ImageFile
from .pipeline.catalog import Format, create_file_format
from .pipeline.exceptions import (
FileImportException,
FileTypeException,
Expand All @@ -19,21 +19,15 @@
from projects.models import Project


def check_file_type(filename, file_format: str, filepath: str):
def check_file_type(filename, file_format: Format, filepath: str):
if not settings.ENABLE_FILE_TYPE_CHECK:
return
kind = filetype.guess(filepath)
if file_format == ImageFile.name:
accept_types = ImageFile.accept_types.replace(" ", "").split(",")
elif file_format == AudioFile.name:
accept_types = AudioFile.accept_types.replace(" ", "").split(",")
else:
return
if kind.mime not in accept_types:
raise FileTypeException(filename, kind.mime, accept_types)
if not file_format.validate_mime(kind.mime):
raise FileTypeException(filename, kind.mime, file_format.accept_types)


def check_uploaded_files(upload_ids: List[str], file_format: str):
def check_uploaded_files(upload_ids: List[str], file_format: Format):
errors: List[FileImportException] = []
cleaned_ids = []
temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
Expand All @@ -56,19 +50,22 @@ def check_uploaded_files(upload_ids: List[str], file_format: str):
def import_dataset(user_id, project_id, file_format: str, upload_ids: List[str], task: str, **kwargs):
project = get_object_or_404(Project, pk=project_id)
user = get_object_or_404(get_user_model(), pk=user_id)
try:
fmt = create_file_format(file_format)
upload_ids, errors = check_uploaded_files(upload_ids, fmt)
temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
filenames = [
FileName(full_path=tu.get_file_path(), generated_name=tu.file.name, upload_name=tu.upload_name)
for tu in temporary_uploads
]

upload_ids, errors = check_uploaded_files(upload_ids, file_format)
temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
filenames = [
FileName(full_path=tu.get_file_path(), generated_name=tu.file.name, upload_name=tu.upload_name)
for tu in temporary_uploads
]

dataset = load_dataset(task, file_format, filenames, project, **kwargs)
dataset.save(user, batch_size=settings.IMPORT_BATCH_SIZE)
upload_to_store(temporary_uploads)
errors.extend(dataset.errors)
return {"error": [e.dict() for e in errors]}
dataset = load_dataset(task, fmt, filenames, project, **kwargs)
dataset.save(user, batch_size=settings.IMPORT_BATCH_SIZE)
upload_to_store(temporary_uploads)
errors.extend(dataset.errors)
return {"error": [e.dict() for e in errors]}
except FileImportException as e:
return {"error": [e.dict()]}


def upload_to_store(temporary_uploads):
Expand Down
8 changes: 4 additions & 4 deletions backend/data_import/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from django.contrib.auth.models import User

from .models import DummyLabelType
from .pipeline.catalog import RELATION_EXTRACTION, TextFile, TextLine
from .pipeline.catalog import RELATION_EXTRACTION, Format
from .pipeline.data import BaseData, BinaryData, TextData
from .pipeline.exceptions import FileParseException
from .pipeline.factories import create_parser
Expand Down Expand Up @@ -210,7 +210,7 @@ def errors(self) -> List[FileParseException]:
return self.reader.errors + self.example_maker.errors + self.category_maker.errors + self.span_maker.errors


def select_dataset(project: Project, task: str, file_format: str) -> Type[Dataset]:
def select_dataset(project: Project, task: str, file_format: Format) -> Type[Dataset]:
mapping = {
DOCUMENT_CLASSIFICATION: TextClassificationDataset,
SEQUENCE_LABELING: SequenceLabelingDataset,
Expand All @@ -222,12 +222,12 @@ def select_dataset(project: Project, task: str, file_format: str) -> Type[Datase
}
if task not in mapping:
task = project.project_type
if project.is_text_project and file_format in [TextLine.name, TextFile.name]:
if project.is_text_project and file_format.is_plain_text():
return PlainDataset
return mapping[task]


def load_dataset(task: str, file_format: str, data_files: List[FileName], project: Project, **kwargs) -> Dataset:
def load_dataset(task: str, file_format: Format, data_files: List[FileName], project: Project, **kwargs) -> Dataset:
parser = create_parser(file_format, **kwargs)
reader = Reader(data_files, parser)
dataset_class = select_dataset(project, task, file_format)
Expand Down
29 changes: 29 additions & 0 deletions backend/data_import/pipeline/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pydantic import BaseModel
from typing_extensions import Literal

from .exceptions import FileFormatException
from projects.models import (
DOCUMENT_CLASSIFICATION,
IMAGE_CLASSIFICATION,
Expand Down Expand Up @@ -140,6 +141,13 @@ class Format:
def dict(cls):
return {"name": cls.name, "accept_types": cls.accept_types}

def validate_mime(self, mime: str):
return True

@staticmethod
def is_plain_text():
return False


class CSV(Format):
name = "CSV"
Expand Down Expand Up @@ -170,11 +178,19 @@ class TextFile(Format):
name = "TextFile"
accept_types = "text/*"

@staticmethod
def is_plain_text():
return True


class TextLine(Format):
name = "TextLine"
accept_types = "text/*"

@staticmethod
def is_plain_text():
return True


class CoNLL(Format):
name = "CoNLL"
Expand All @@ -185,11 +201,17 @@ class ImageFile(Format):
name = "ImageFile"
accept_types = "image/png, image/jpeg, image/bmp, image/gif"

def validate_mime(self, mime: str):
return mime in self.accept_types


class AudioFile(Format):
name = "AudioFile"
accept_types = "audio/ogg, audio/aac, audio/mpeg, audio/wav"

def validate_mime(self, mime: str):
return mime in self.accept_types


class ArgColumn(BaseModel):
encoding: encodings = "utf_8"
Expand Down Expand Up @@ -239,6 +261,13 @@ def dict(self) -> Dict:
}


def create_file_format(file_format: str) -> Format:
for format_class in Format.__subclasses__():
if format_class.name == file_format:
return format_class()
raise FileFormatException(file_format)


class Options:
options: Dict[str, List] = defaultdict(list)

Expand Down
9 changes: 9 additions & 0 deletions backend/data_import/pipeline/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,12 @@ def __str__(self):

def dict(self):
return {"filename": self.filename, "line": -1, "message": str(self)}


class FileFormatException(FileImportException):
def __init__(self, file_format: str):
self.file_format = file_format

def dict(self):
message = f"Unknown file format: {self.file_format}"
return {"message": message}
7 changes: 3 additions & 4 deletions backend/data_import/pipeline/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
CoNLL,
Excel,
FastText,
Format,
ImageFile,
TextFile,
TextLine,
Expand All @@ -23,7 +24,7 @@
)


def create_parser(file_format: str, **kwargs):
def create_parser(file_format: Format, **kwargs):
mapping = {
TextFile.name: TextFileParser,
TextLine.name: LineParser,
Expand All @@ -36,6 +37,4 @@ def create_parser(file_format: str, **kwargs):
ImageFile.name: PlainParser,
AudioFile.name: PlainParser,
}
if file_format not in mapping:
raise ValueError(f"Invalid format: {file_format}")
return mapping[file_format](**kwargs)
return mapping[file_format.name](**kwargs)
45 changes: 21 additions & 24 deletions backend/data_import/pipeline/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import uuid
from typing import Any, Optional

from pydantic import UUID4, BaseModel, validator
from pydantic import UUID4, BaseModel, ConstrainedStr, NonNegativeInt, root_validator

from .label_types import LabelTypes
from examples.models import Example
Expand All @@ -15,6 +15,10 @@
from projects.models import Project


class NonEmptyStr(ConstrainedStr):
min_length = 1


class Label(BaseModel, abc.ABC):
id: int = -1
uuid: UUID4
Expand Down Expand Up @@ -45,18 +49,11 @@ def __hash__(self):


class CategoryLabel(Label):
label: str
label: NonEmptyStr

def __lt__(self, other):
return self.label < other.label

@validator("label")
def label_is_not_empty(cls, value: str):
if value:
return value
else:
raise ValueError("is not empty.")

@classmethod
def parse(cls, example_uuid: UUID4, obj: Any):
return cls(example_uuid=example_uuid, label=obj)
Expand All @@ -65,17 +62,24 @@ def create_type(self, project: Project) -> Optional[LabelType]:
return CategoryType(text=self.label, project=project)

def create(self, user, example: Example, types: LabelTypes, **kwargs):
return CategoryModel(uuid=self.uuid, user=user, example=example, label=types.get_by_text(self.label))
return CategoryModel(uuid=self.uuid, user=user, example=example, label=types[self.label])


class SpanLabel(Label):
label: str
start_offset: int
end_offset: int
label: NonEmptyStr
start_offset: NonNegativeInt
end_offset: NonNegativeInt

def __lt__(self, other):
return self.start_offset < other.start_offset

@root_validator
def check_start_offset_is_less_than_end_offset(cls, values):
start_offset, end_offset = values.get("start_offset"), values.get("end_offset")
if start_offset >= end_offset:
raise ValueError("start_offset must be less than end_offset.")
return values

@classmethod
def parse(cls, example_uuid: UUID4, obj: Any):
if isinstance(obj, list) or isinstance(obj, tuple):
Expand All @@ -96,23 +100,16 @@ def create(self, user, example: Example, types: LabelTypes, **kwargs):
example=example,
start_offset=self.start_offset,
end_offset=self.end_offset,
label=types.get_by_text(self.label),
label=types[self.label],
)


class TextLabel(Label):
text: str
text: NonEmptyStr

def __lt__(self, other):
return self.text < other.text

@validator("text")
def text_is_not_empty(cls, value: str):
if value:
return value
else:
raise ValueError("is not empty.")

@classmethod
def parse(cls, example_uuid: UUID4, obj: Any):
return cls(example_uuid=example_uuid, text=obj)
Expand All @@ -127,7 +124,7 @@ def create(self, user, example: Example, types: LabelTypes, **kwargs):
class RelationLabel(Label):
from_id: int
to_id: int
type: str
type: NonEmptyStr

def __lt__(self, other):
return self.from_id < other.from_id
Expand All @@ -144,7 +141,7 @@ def create(self, user, example: Example, types: LabelTypes, **kwargs):
uuid=self.uuid,
user=user,
example=example,
type=types.get_by_text(self.type),
type=types[self.type],
from_id=kwargs["id_to_span"][self.from_id],
to_id=kwargs["id_to_span"][self.to_id],
)
6 changes: 3 additions & 3 deletions backend/data_import/pipeline/label_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ def __init__(self, label_type_class: Type[LabelType]):
def __contains__(self, text: str) -> bool:
return text in self.types

def __getitem__(self, text: str) -> LabelType:
return self.types[text]

def save(self, label_types: List[LabelType]):
self.label_type_class.objects.bulk_create(label_types, ignore_conflicts=True)

def update(self, project: Project):
types = self.label_type_class.objects.filter(project=project)
self.types = {label_type.text: label_type for label_type in types}

def get_by_text(self, text: str) -> LabelType:
return self.types[text]
16 changes: 12 additions & 4 deletions backend/data_import/tests/test_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_create_type(self):
def test_create(self):
category = CategoryLabel(label="A", example_uuid=uuid.uuid4())
types = MagicMock()
types.get_by_text.return_value = mommy.make(CategoryType, project=self.project.item)
types.__getitem__.return_value = mommy.make(CategoryType, project=self.project.item)
category_model = category.create(self.user, self.example, types)
self.assertIsInstance(category_model, CategoryModel)

Expand All @@ -65,7 +65,7 @@ class TestSpanLabel(TestLabel):

def test_comparison(self):
span1 = SpanLabel(label="A", start_offset=0, end_offset=1, example_uuid=uuid.uuid4())
span2 = SpanLabel(label="A", start_offset=1, end_offset=1, example_uuid=uuid.uuid4())
span2 = SpanLabel(label="A", start_offset=1, end_offset=2, example_uuid=uuid.uuid4())
self.assertLess(span1, span2)

def test_parse_tuple(self):
Expand All @@ -82,6 +82,14 @@ def test_parse_dict(self):
self.assertEqual(span.start_offset, 0)
self.assertEqual(span.end_offset, 1)

def test_invalid_negative_offset(self):
with self.assertRaises(ValueError):
SpanLabel(label="A", start_offset=-1, end_offset=1, example_uuid=uuid.uuid4())

def test_invalid_offset(self):
with self.assertRaises(ValueError):
SpanLabel(label="A", start_offset=1, end_offset=0, example_uuid=uuid.uuid4())

def test_parse_invalid_dict(self):
example_uuid = uuid.uuid4()
with self.assertRaises(ValueError):
Expand All @@ -96,7 +104,7 @@ def test_create_type(self):
def test_create(self):
span = SpanLabel(label="A", start_offset=0, end_offset=1, example_uuid=uuid.uuid4())
types = MagicMock()
types.get_by_text.return_value = mommy.make(SpanType, project=self.project.item)
types.__getitem__.return_value = mommy.make(SpanType, project=self.project.item)
span_model = span.create(self.user, self.example, types)
self.assertIsInstance(span_model, SpanModel)

Expand Down Expand Up @@ -160,7 +168,7 @@ def test_create_type(self):
def test_create(self):
relation = RelationLabel(type="A", from_id=0, to_id=1, example_uuid=uuid.uuid4())
types = MagicMock()
types.get_by_text.return_value = mommy.make(RelationType, project=self.project.item)
types.__getitem__.return_value = mommy.make(RelationType, project=self.project.item)
id_to_span = {
0: mommy.make(SpanModel, start_offset=0, end_offset=1),
1: mommy.make(SpanModel, start_offset=2, end_offset=3),
Expand Down