Skip to content

Commit

Permalink
Add config options for label preprocessing (Refs #5)
Browse files Browse the repository at this point in the history
- Expose normalize_unicode parameter of LmdbDataset
- Add remove_whitespace flag to allow disabling whitespace removal in labels
  • Loading branch information
baudm committed Jul 28, 2022
1 parent 98959c9 commit e8ea463
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 7 deletions.
6 changes: 4 additions & 2 deletions configs/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ data:
_target_: strhub.data.module.SceneTextDataModule
root_dir: data
train_dir: ???
batch_size: ${model.batch_size}
img_size: ${model.img_size}
charset_train: ${model.charset_train}
charset_test: ${model.charset_test}
max_label_length: ${model.max_label_length}
batch_size: ${model.batch_size}
num_workers: 2
remove_whitespace: true
normalize_unicode: true
augment: true
num_workers: 2

trainer:
_target_: pytorch_lightning.Trainer
Expand Down
9 changes: 6 additions & 3 deletions strhub/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,14 @@ class LmdbDataset(Dataset):
"""

def __init__(self, root: str, charset: str, max_label_len: int, min_image_dim: int = 0,
normalize_unicode: bool = True, unlabelled: bool = False, transform: Optional[Callable] = None,
remove_whitespace: bool = True, normalize_unicode: bool = True,
unlabelled: bool = False, transform: Optional[Callable] = None,
num_workers: int = 1):
self.env = lmdb.open(root, max_readers=num_workers, max_spare_txns=num_workers,
readonly=True, create=False, readahead=False, meminit=False, lock=False)
self.max_label_len = max_label_len
self.min_image_dim = min_image_dim
self.remove_whitespace = remove_whitespace
self.normalize_unicode = normalize_unicode
self.unlabelled = unlabelled
self.transform = transform
Expand All @@ -81,8 +83,9 @@ def _preprocess_labels(self, charset):
index += 1 # lmdb starts with 1
label_key = f'label-{index:09d}'.encode()
label = txn.get(label_key).decode()
# There shouldn't be any whitespace in the labels but try to remove them for good measure
label = ''.join(label.split())
# Normally, whitespace is removed from the labels.
if self.remove_whitespace:
label = ''.join(label.split())
# Normalize unicode composites (if any) and convert to compatible ASCII characters
if self.normalize_unicode:
label = unicodedata.normalize('NFKD', label).encode('ascii', 'ignore').decode()
Expand Down
7 changes: 6 additions & 1 deletion strhub/data/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class SceneTextDataModule(pl.LightningDataModule):

def __init__(self, root_dir: str, train_dir: str, img_size: Sequence[int], max_label_length: int,
charset_train: str, charset_test: str, batch_size: int, num_workers: int, augment: bool,
remove_whitespace: bool = True, normalize_unicode: bool = True,
min_image_dim: int = 0, rotation: int = 0, collate_fn: Optional[Callable] = None):
super().__init__()
self.root_dir = root_dir
Expand All @@ -42,6 +43,8 @@ def __init__(self, root_dir: str, train_dir: str, img_size: Sequence[int], max_l
self.batch_size = batch_size
self.num_workers = num_workers
self.augment = augment
self.remove_whitespace = remove_whitespace
self.normalize_unicode = normalize_unicode
self.min_image_dim = min_image_dim
self.rotation = rotation
self.collate_fn = collate_fn
Expand Down Expand Up @@ -69,7 +72,7 @@ def train_dataset(self):
transform = self.get_transform(self.img_size, self.augment)
root = PurePath(self.root_dir, 'train', self.train_dir)
self._train_dataset = build_tree_dataset(root, self.charset_train, self.max_label_length,
self.min_image_dim,
self.min_image_dim, self.remove_whitespace, self.normalize_unicode,
transform=transform, num_workers=self.num_workers)
return self._train_dataset

Expand All @@ -79,6 +82,7 @@ def val_dataset(self):
transform = self.get_transform(self.img_size)
root = PurePath(self.root_dir, 'val')
self._val_dataset = build_tree_dataset(root, self.charset_test, self.max_label_length,
self.min_image_dim, self.remove_whitespace, self.normalize_unicode,
transform=transform, num_workers=self.num_workers)
return self._val_dataset

Expand All @@ -96,6 +100,7 @@ def test_dataloaders(self, subset):
transform = self.get_transform(self.img_size, rotation=self.rotation)
root = PurePath(self.root_dir, 'test')
datasets = {s: LmdbDataset(str(root.joinpath(s)), self.charset_test, self.max_label_length,
self.min_image_dim, self.remove_whitespace, self.normalize_unicode,
transform=transform) for s in subset}
return {k: DataLoader(v, batch_size=self.batch_size, num_workers=self.num_workers,
pin_memory=True, collate_fn=self.collate_fn)
Expand Down
2 changes: 1 addition & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def main():
model.freeze() # disable autograd
hp = model.hparams
datamodule = SceneTextDataModule('data', '_unused_', hp.img_size, hp.max_label_length, hp.charset_train,
hp.charset_test, args.batch_size, args.num_workers, False, args.rotation)
hp.charset_test, args.batch_size, args.num_workers, False, rotation=args.rotation)

test_set = SceneTextDataModule.TEST_BENCHMARK_SUB + SceneTextDataModule.TEST_BENCHMARK
if args.new:
Expand Down

0 comments on commit e8ea463

Please sign in to comment.