Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added WordGenerator dataset #760

Merged
merged 8 commits into from
Dec 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,36 @@ Available Datasets
------------------
Here are all datasets that are available through docTR:


Public datasets
^^^^^^^^^^^^^^^

.. autoclass:: FUNSD
.. autoclass:: SROIE
.. autoclass:: CORD
.. autoclass:: OCRDataset
.. autoclass:: CharacterGenerator
.. autoclass:: DocArtefacts
.. autoclass:: IIIT5K
.. autoclass:: SVT
.. autoclass:: SVHN
.. autoclass:: SynthText
.. autoclass:: IC03
.. autoclass:: IC13

docTR synthetic datasets
^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: DocArtefacts
.. autoclass:: CharacterGenerator
.. autoclass:: WordGenerator

docTR private datasets
^^^^^^^^^^^^^^^^^^^^^^

Since many documents include sensitive / personal information, we are not able to share all the data that has been used for this project. However, we provide some guidance on how to format your own dataset into the same format so that you can use all docTR tools all the same.

.. autoclass:: DetectionDataset
.. autoclass:: RecognitionDataset
.. autoclass:: OCRDataset


Data Loading
------------
Expand Down
2 changes: 1 addition & 1 deletion doctr/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from doctr.file_utils import is_tf_available

from .classification import *
from .generator import *
from .cord import *
from .detection import *
from .doc_artefacts import *
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,47 +16,40 @@

def synthesize_text_img(
text: str,
img_size: Optional[Tuple[int, int]] = None,
font_size: Optional[int] = None,
font_size: int = 32,
font_family: Optional[str] = None,
background_color: Optional[Tuple[int, int, int]] = None,
text_color: Optional[Tuple[int, int, int]] = None,
text_pos: Optional[Tuple[int, int]] = None,
) -> Image:
"""Generate a synthetic character image with black background and white text
"""Generate a synthetic text image

Args:
text: the character to render as an image
img_size: the size of the rendered image
text: the text to render as an image
font_size: the size of the font
font_family: the font family (has to be installed on your system)
background_color: background color of the final image
text_color: text color on the final image
text_pos: offset of the text

Returns:
PIL image of the character
PIL image of the text
"""

background_color = (0, 0, 0) if background_color is None else background_color
text_color = (255, 255, 255) if text_color is None else text_color
default_h = 32
if font_size is None:
font_size = int(0.9 * default_h) if img_size is None else int(0.9 * img_size[0])

font = get_font(font_family, font_size)
text_size = font.getsize(text)
if img_size is None:
img_size = (default_h, text_size[0] if len(text) > 1 else default_h)
text_w, text_h = font.getsize(text)
h, w = int(round(1.3 * text_h)), int(round(1.1 * text_w))
# If single letter, make the image square, otherwise expand to meet the text size
img_size = (h, w) if len(text) > 1 else (max(h, w), max(h, w))

img = Image.new('RGB', img_size[::-1], color=background_color)
d = ImageDraw.Draw(img)

# Draw the character
if text_pos is None:
text_pos = (0, 0) if text_size[0] >= img_size[1] else (int(round(img_size[0] * 3 / 16)), 0)
# Offset so that the text is centered
text_pos = (int(round((img_size[1] - text_w) / 2)), int(round((img_size[0] - text_h) / 2)))
# Draw the text
d.text(text_pos, text, font=font, fill=text_color)

return img


Expand Down Expand Up @@ -105,3 +98,57 @@ def _read_sample(self, index: int) -> Tuple[Any, int]:
img = tensor_from_pil(pil_img)

return img, target


class _WordGenerator(AbstractDataset):
    """Framework-agnostic base for on-the-fly word image generation.

    Each sample is a synthetic image of a random word drawn from ``vocab``,
    paired with the word itself as the string target.

    Args:
        vocab: vocabulary to draw characters from
        min_chars: minimum number of characters in a generated word
        max_chars: maximum number of characters in a generated word
        num_samples: number of samples the dataset exposes via ``__len__``
        cache_samples: if True, pre-render all samples at construction time
        font_family: font (or list of fonts) used to render the text images
        img_transforms: composable transformations applied to each image
        sample_transforms: composable transformations applied to both image and target

    Raises:
        ValueError: if one of the requested fonts cannot be located on the system
    """

    def __init__(
        self,
        vocab: str,
        min_chars: int,
        max_chars: int,
        num_samples: int,
        cache_samples: bool = False,
        font_family: Optional[Union[str, List[str]]] = None,
        img_transforms: Optional[Callable[[Any], Any]] = None,
        sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
    ) -> None:
        self.vocab = vocab
        self.wordlen_range = (min_chars, max_chars)
        self._num_samples = num_samples
        # Normalize to a list so sampling with random.choice always works
        # (a [None] entry means "default font" downstream in get_font)
        self.font_family = font_family if isinstance(font_family, list) else [font_family]  # type: ignore[list-item]
        # Validate fonts eagerly so a bad font name fails at construction, not mid-iteration
        if isinstance(font_family, list):
            for font in self.font_family:
                try:
                    _ = get_font(font, 10)
                except OSError as err:
                    raise ValueError(f"unable to locate font: {font}") from err
        self.img_transforms = img_transforms
        self.sample_transforms = sample_transforms

        # Cache holds (image, word) pairs, not bare images
        self._data: List[Tuple[Image.Image, str]] = []
        if cache_samples:
            _words = [self._generate_string(*self.wordlen_range) for _ in range(num_samples)]
            self._data = [
                (synthesize_text_img(text, font_family=random.choice(self.font_family)), text)
                for text in _words
            ]

    def _generate_string(self, min_chars: int, max_chars: int) -> str:
        """Return a random word of length in [min_chars, max_chars] over ``self.vocab``."""
        num_chars = random.randint(min_chars, max_chars)
        return "".join(random.choice(self.vocab) for _ in range(num_chars))

    def __len__(self) -> int:
        return self._num_samples

    def _read_sample(self, index: int) -> Tuple[Any, str]:
        """Return the (image tensor, word) pair at ``index``.

        Cached samples are looked up by index; otherwise a fresh sample is
        generated on the fly (the index is then only a counter, not a key).
        """
        # Samples are already cached
        if len(self._data) > 0:
            pil_img, target = self._data[index]
        else:
            target = self._generate_string(*self.wordlen_range)
            pil_img = synthesize_text_img(target, font_family=random.choice(self.font_family))
        img = tensor_from_pil(pil_img)

        return img, target
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

from torch.utils.data._utils.collate import default_collate

from .base import _CharacterGenerator
from .base import _CharacterGenerator, _WordGenerator

__all__ = ['CharacterGenerator']
__all__ = ['CharacterGenerator', 'WordGenerator']


class CharacterGenerator(_CharacterGenerator):
Expand All @@ -30,3 +30,25 @@ class CharacterGenerator(_CharacterGenerator):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
setattr(self, 'collate_fn', default_collate)


class WordGenerator(_WordGenerator):
    """Implements a word image generation dataset

    Example::
        >>> from doctr.datasets import WordGenerator
        >>> ds = WordGenerator(vocab='abdef', min_chars=1, max_chars=10, num_samples=100)
        >>> img, target = ds[0]

    Args:
        vocab: vocabulary to take the characters from
        min_chars: minimum number of characters in a word
        max_chars: maximum number of characters in a word
        num_samples: number of samples that will be generated iterating over the dataset
        cache_samples: whether generated images should be cached firsthand
        font_family: font to use to generate the text images
        img_transforms: composable transformations that will be applied to each image
        sample_transforms: composable transformations that will be applied to both the image and the target
    """
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import tensorflow as tf

from .base import _CharacterGenerator
from .base import _CharacterGenerator, _WordGenerator

__all__ = ['CharacterGenerator']
__all__ = ['CharacterGenerator', 'WordGenerator']


class CharacterGenerator(_CharacterGenerator):
Expand Down Expand Up @@ -37,3 +37,25 @@ def collate_fn(samples):
images = tf.stack(images, axis=0)

return images, tf.convert_to_tensor(targets)


class WordGenerator(_WordGenerator):
    """Implements a word image generation dataset

    Example::
        >>> from doctr.datasets import WordGenerator
        >>> ds = WordGenerator(vocab='abdef', min_chars=1, max_chars=10, num_samples=100)
        >>> img, target = ds[0]

    Args:
        vocab: vocabulary to take the characters from
        min_chars: minimum number of characters in a word
        max_chars: maximum number of characters in a word
        num_samples: number of samples that will be generated iterating over the dataset
        cache_samples: whether generated images should be cached firsthand
        font_family: font to use to generate the text images
        img_transforms: composable transformations that will be applied to each image
        sample_transforms: composable transformations that will be applied to both the image and the target
    """
29 changes: 29 additions & 0 deletions tests/pytorch/test_datasets_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,35 @@ def test_charactergenerator():
assert targets.dtype == torch.int64


def test_wordgenerator():
    # Build a small cached dataset with a fixed vocab and word-length range
    target_size = (32, 128)
    min_chars, max_chars = 1, 10
    vocab = 'abcdef'

    ds = datasets.WordGenerator(
        vocab=vocab,
        min_chars=min_chars,
        max_chars=max_chars,
        num_samples=10,
        cache_samples=True,
        img_transforms=Resize(target_size),
    )

    assert len(ds) == 10

    # Single-sample access: image tensor of the requested size, word from the vocab
    img, label = ds[0]
    assert isinstance(img, torch.Tensor)
    assert img.dtype == torch.float32
    assert img.shape[-2:] == target_size
    assert isinstance(label, str)
    assert min_chars <= len(label) <= max_chars
    assert set(label) <= set(vocab)

    # Batched access through the dataset's collate function
    batch_imgs, batch_labels = next(iter(DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn)))
    assert isinstance(batch_imgs, torch.Tensor)
    assert batch_imgs.shape == (2, 3, *target_size)
    assert isinstance(batch_labels, list) and len(batch_labels) == 2
    assert all(isinstance(elt, str) for elt in batch_labels)


@pytest.mark.parametrize(
"num_samples, rotate",
[
Expand Down
29 changes: 29 additions & 0 deletions tests/tensorflow/test_datasets_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,35 @@ def test_charactergenerator():
assert targets.dtype == tf.int32


def test_wordgenerator():
    # Build a small cached dataset with a fixed vocab and word-length range
    target_size = (32, 128)
    min_chars, max_chars = 1, 10
    vocab = 'abcdef'

    ds = datasets.WordGenerator(
        vocab=vocab,
        min_chars=min_chars,
        max_chars=max_chars,
        num_samples=10,
        cache_samples=True,
        img_transforms=Resize(target_size),
    )

    assert len(ds) == 10

    # Single-sample access: image tensor of the requested size, word from the vocab
    img, label = ds[0]
    assert isinstance(img, tf.Tensor)
    assert img.dtype == tf.float32
    assert img.shape[:2] == target_size
    assert isinstance(label, str)
    assert min_chars <= len(label) <= max_chars
    assert set(label) <= set(vocab)

    # Batched access through the dataset's collate function (channels-last in TF)
    batch_imgs, batch_labels = next(iter(DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn)))
    assert isinstance(batch_imgs, tf.Tensor)
    assert batch_imgs.shape == (2, *target_size, 3)
    assert isinstance(batch_labels, list) and len(batch_labels) == 2
    assert all(isinstance(elt, str) for elt in batch_labels)


@pytest.mark.parametrize(
"num_samples, rotate",
[
Expand Down