Skip to content

Commit

Permalink
feat: add default preprocess fn for ssl (#331)
Browse files Browse the repository at this point in the history
  • Loading branch information
bwanglzu committed Jan 11, 2022
1 parent b86c0e6 commit bc25c37
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/requirements-cicd.txt
Expand Up @@ -6,3 +6,4 @@ torchvision==0.10.0
scipy==1.7.1
transformers==4.12.3
wandb==0.12.7
albumentations==1.1.0
45 changes: 45 additions & 0 deletions finetuner/tuner/augmentation.py
@@ -0,0 +1,45 @@
from docarray import Document


def vision_preprocessor(
    doc: Document,
    height: int = 224,
    width: int = 224,
    channel_axis: int = -1,
):
    """Randomly augment the image stored in the `blob` field of a Document.

    The image is passed through an albumentations pipeline made of a
    horizontal flip, color jitter, random resized crop, gaussian blur and
    grid dropout. The augmented image is written back to ``doc.blob``.

    :param doc: The document to preprocess. When it carries no content, the
        image is loaded from ``doc.uri`` via ``load_uri_to_image_blob`` using
        the given ``width``, ``height`` and ``channel_axis``.
    :param height: Image height, also used as the crop output height.
    :param width: Image width, also used as the crop output width.
    :param channel_axis: Axis of the color channel in the input blob. By
        default -1, i.e. the expected input layout is H, W, C.
    :return: The augmented image blob (the same value assigned to ``doc.blob``).
    :raises AttributeError: If the document has neither content nor a ``uri``.

    .. note::
        When ``channel_axis`` is neither -1 nor 2, the channel axis of the
        blob is moved to the last position before augmenting and is left
        there. Callers that need the channel on another axis have to move it
        back themselves, e.g. ``doc.set_image_blob_channel_axis(-1, 0)``.
    """
    import albumentations as A

    if doc.content is None:
        # Nothing to augment yet: the only fallback is loading from the uri.
        if not doc.uri:
            raise AttributeError('Can not load `blob` field from the given document.')
        doc.load_uri_to_image_blob(
            width=width, height=height, channel_axis=channel_axis
        )

    # The pipeline below is fed a channel-last image.
    already_channel_last = channel_axis in (-1, 2)
    if not already_channel_last:
        doc.set_image_blob_channel_axis(channel_axis, -1)

    # `p` is the probability with which each step is applied.
    steps = [
        A.HorizontalFlip(p=0.5),
        # NOTE(review): every jitter range is 0, so this step presumably
        # leaves the colors untouched -- confirm this is intended.
        A.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0, p=1),
        A.RandomResizedCrop(height=height, width=width, p=1),
        A.GaussianBlur(p=1),
        A.GridDropout(p=0.5),
    ]
    pipeline = A.Compose(steps)

    augmented = pipeline(image=doc.blob)
    doc.blob = augmented['image']
    return doc.blob
Binary file added tests/unit/tuner/resources/lena.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
45 changes: 45 additions & 0 deletions tests/unit/tuner/test_augmentation.py
@@ -0,0 +1,45 @@
import os

import numpy as np
import pytest
from docarray import Document

from finetuner.tuner.augmentation import vision_preprocessor

# Absolute path of the directory holding this test module; test resources
# (e.g. `resources/lena.png`) are resolved relative to it.
cur_dir = os.path.split(os.path.abspath(__file__))[0]


@pytest.mark.parametrize(
    'doc, height, width, num_channels, channel_axis',
    [
        # channel-last RGB blobs of two different sizes
        (Document(blob=np.random.rand(224, 224, 3)), 224, 224, 3, -1),
        (Document(blob=np.random.rand(256, 256, 3)), 256, 256, 3, -1),
        # grayscale
        (Document(blob=np.random.rand(256, 256, 1)), 256, 256, 1, -1),
        # channel axis at 0th position
        (Document(blob=np.random.rand(3, 224, 224)), 224, 224, 3, 0),
        # load from uri
        (
            Document(uri=os.path.join(cur_dir, 'resources/lena.png')),
            512,
            512,
            3,
            1,
        ),
    ],
)
def test_vision_preprocessor(doc, height, width, num_channels, channel_axis):
    """The preprocessor returns a channel-last blob that differs from the input."""
    # For the uri-only document this is captured before any image is loaded.
    blob_before = doc.blob
    expected_shape = (height, width, num_channels)

    result = vision_preprocessor(doc, height, width, channel_axis)

    assert result is not None
    assert result.shape == expected_shape
    assert not np.array_equal(blob_before, result)


def test_vision_preprocessor_fail_given_no_blob_and_uri():
    """A Document with neither content nor uri is rejected with AttributeError."""
    empty_doc = Document()
    with pytest.raises(AttributeError):
        vision_preprocessor(empty_doc)

0 comments on commit bc25c37

Please sign in to comment.