Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add default preprocess fn for ssl (#331)
- Loading branch information
Showing
4 changed files
with
91 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,3 +6,4 @@ torchvision==0.10.0 | |
scipy==1.7.1 | ||
transformers==4.12.3 | ||
wandb==0.12.7 | ||
albumentations==1.1.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from docarray import Document | ||
|
||
|
||
def vision_preprocessor(
    doc: Document,
    height: int = 224,
    width: int = 224,
    channel_axis: int = -1,
):
    """Randomly augment the image stored in the `blob` field of a Document.

    When the document has no content, the image is first loaded from
    `doc.uri`. The blob is then passed through, in order: a horizontal flip,
    a color jitter step, a random resized crop, a gaussian blur and a grid
    dropout.

    :param doc: The document to preprocess. It is modified in place: its
        `blob` is replaced by the augmented image.
    :param height: Image height, used when loading from `uri` and as the
        output height of the random resized crop.
    :param width: Image width, used when loading from `uri` and as the
        output width of the random resized crop.
    :param channel_axis: The axis of the color channel in the input image,
        by default -1, i.e. the expected input is H, W, C.
    :return: The augmented image, which is also assigned to `doc.blob`.
    :raises AttributeError: If the document has neither content nor a `uri`.

    .. note::
        Before augmenting, the channel axis is moved to the last dimension
        of the blob and it is left there. If a channel-first layout is
        needed afterwards, call `doc.set_image_blob_channel_axis(-1, 0)`
        to revert it.
    """
    import albumentations as A

    if doc.content is None:
        # Nothing to augment yet: the only fallback is loading from the uri.
        if not doc.uri:
            raise AttributeError('Can not load `blob` field from the given document.')
        doc.load_uri_to_image_blob(
            width=width, height=height, channel_axis=channel_axis
        )

    # The pipeline below is fed a channel-last image, so move the channel
    # axis unless it already is the last one (-1, or 2 for a 3-d blob).
    if channel_axis not in [-1, 2]:
        doc.set_image_blob_channel_axis(channel_axis, -1)

    # `p` is the probability with which each transform is applied.
    augmentations = [
        A.HorizontalFlip(p=0.5),
        # NOTE(review): every jitter range is set to 0, which presumably
        # leaves the colors untouched -- confirm this is intended.
        A.ColorJitter(p=1, brightness=0, contrast=0, saturation=0, hue=0),
        A.RandomResizedCrop(width=width, height=height, p=1),
        A.GaussianBlur(p=1),
        A.GridDropout(p=0.5),
    ]
    pipeline = A.Compose(augmentations)

    augmented = pipeline(image=doc.blob)
    doc.blob = augmented['image']
    return doc.blob
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import os | ||
|
||
import numpy as np | ||
import pytest | ||
from docarray import Document | ||
|
||
from finetuner.tuner.augmentation import vision_preprocessor | ||
|
||
cur_dir = os.path.dirname(os.path.abspath(__file__)) | ||
|
||
|
||
@pytest.mark.parametrize(
    'doc, height, width, num_channels, channel_axis',
    [
        # channel-last RGB blobs of two different sizes
        (Document(blob=np.random.rand(224, 224, 3)), 224, 224, 3, -1),
        (Document(blob=np.random.rand(256, 256, 3)), 256, 256, 3, -1),
        # grayscale
        (Document(blob=np.random.rand(256, 256, 1)), 256, 256, 1, -1),
        # channel axis at 0th position
        (Document(blob=np.random.rand(3, 224, 224)), 224, 224, 3, 0),
        # load from uri
        (Document(uri=os.path.join(cur_dir, 'resources/lena.png')), 512, 512, 3, 1),
    ],
)
def test_vision_preprocessor(doc, height, width, num_channels, channel_axis):
    """Check that the preprocessor returns a channel-last image of the
    requested size that differs from the blob the document started with."""
    # For the uri case there is no blob yet at this point.
    blob_before = doc.blob

    result = vision_preprocessor(doc, height, width, channel_axis)

    assert result is not None
    expected_shape = (height, width, num_channels)
    assert result.shape == expected_shape
    assert not np.array_equal(blob_before, result)
|
||
|
||
def test_vision_preprocessor_fail_given_no_blob_and_uri():
    """A document with neither a blob nor a uri must be rejected."""
    empty_doc = Document()
    with pytest.raises(AttributeError):
        vision_preprocessor(empty_doc)