-
Notifications
You must be signed in to change notification settings - Fork 251
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add api workflow download client * Add download option to CLI * Add download tests and fix tox * Add new download to cli instructions * Minor change+ * Fix tests
- Loading branch information
1 parent
50e543f
commit e725894
Showing
9 changed files
with
199 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,5 @@ | |
|
||
# Copyright (c) 2020. Lightly AG and its affiliates. | ||
# All Rights Reserved | ||
|
||
from lightly.api import routes |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import warnings | ||
from concurrent.futures.thread import ThreadPoolExecutor | ||
from typing import Union | ||
import io | ||
import os | ||
import tqdm | ||
from urllib.request import Request, urlopen | ||
from PIL import Image | ||
|
||
from lightly.openapi_generated.swagger_client import TagCreator | ||
from lightly.openapi_generated.swagger_client.models.sample_create_request import SampleCreateRequest | ||
from lightly.api.utils import check_filename, check_image, get_thumbnail_from_img, PIL_to_bytes | ||
from lightly.api.bitmask import BitMask | ||
from lightly.openapi_generated.swagger_client.models.initial_tag_create_request import InitialTagCreateRequest | ||
from lightly.openapi_generated.swagger_client.models.image_type import ImageType | ||
from lightly.data.dataset import LightlyDataset | ||
|
||
|
||
|
||
def _make_dir_and_save_image(output_dir: str, filename: str, img: Image): | ||
"""Saves the images and creates necessary subdirectories. | ||
""" | ||
path = os.path.join(output_dir, filename) | ||
|
||
head = os.path.split(path)[0] | ||
if not os.path.exists(head): | ||
os.makedirs(head) | ||
|
||
img.save(path) | ||
img.close() | ||
|
||
|
||
def _get_image_from_read_url(read_url: str):
    """Fetches the image behind a signed read url.

    Issues a GET request to read_url, reads the complete response body and
    opens it as a PIL image backed by an in-memory buffer.

    Args:
        read_url:
            Signed url from which the image can be read.

    Returns:
        The image opened from the downloaded bytes.

    """
    with urlopen(Request(read_url, method='GET')) as connection:
        payload = connection.read()
    # the payload is fully in memory, so the image does not depend on the
    # (already closed) connection
    return Image.open(io.BytesIO(payload))
|
||
|
||
class _DownloadDatasetMixin:
    """Mixin which adds a dataset download to an api workflow client.

    The method below reads the following members from the instance it is
    mixed into: dataset_id, datasets_api, mappings_api, samples_api,
    filenames_on_server and _get_all_tags().

    """

    def download_dataset(self,
                         output_dir: str,
                         tag_name: str = 'initial-tag',
                         verbose: bool = True):
        """Downloads images from the web-app and stores them in output_dir.

        Args:
            output_dir:
                Where to store the downloaded images.
            tag_name:
                Name of the tag which should be downloaded.
            verbose:
                Whether or not to show the progress bar.

        Raises:
            ValueError if the dataset has no full images or if the specified
            tag does not exist on the dataset.

        Errors raised by the api clients or while fetching an image are not
        handled here and propagate to the caller.

        """

        # check if images are available
        dataset = self.datasets_api.get_dataset_by_id(self.dataset_id)
        if dataset.img_type != ImageType.FULL:
            # anything but full images cannot be downloaded
            raise ValueError(
                f"Dataset with id {self.dataset_id} has no downloadable images!"
            )

        # check if tag exists
        available_tags = self._get_all_tags()
        tag = next((t for t in available_tags if t.name == tag_name), None)
        if tag is None:
            raise ValueError(
                f"Dataset with id {self.dataset_id} has no tag {tag_name}!"
            )

        # get sample ids
        sample_ids = self.mappings_api.get_sample_mappings_by_dataset_id(
            self.dataset_id,
            field='_id'
        )

        # the bit mask of the tag selects which samples belong to it
        indices = BitMask.from_hex(tag.bit_mask_data).to_indices()
        sample_ids = [sample_ids[i] for i in indices]
        filenames = [self.filenames_on_server[i] for i in indices]

        pbar = None
        if verbose:
            print(f'Downloading {len(sample_ids)} images:', flush=True)
            pbar = tqdm.tqdm(unit='imgs', total=len(sample_ids))

        try:
            # download images
            for sample_id, filename in zip(sample_ids, filenames):
                read_url = self.samples_api.get_sample_image_read_url_by_id(
                    self.dataset_id,
                    sample_id,
                    type="full",
                )

                img = _get_image_from_read_url(read_url)
                _make_dir_and_save_image(output_dir, filename, img)

                if pbar is not None:
                    pbar.update(1)
        finally:
            # always finalize the progress bar, also on a failed download
            if pbar is not None:
                pbar.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import os | ||
import shutil | ||
|
||
from unittest.mock import patch | ||
|
||
import PIL | ||
import numpy as np | ||
|
||
import torchvision | ||
|
||
import lightly | ||
from lightly.data.dataset import LightlyDataset | ||
|
||
from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup | ||
from lightly.openapi_generated.swagger_client.models.dataset_data import DatasetData | ||
|
||
|
||
|
||
class TestApiWorkflowDownloadDataset(MockedApiWorkflowSetup):
    """Tests for download_dataset of the mocked api workflow client."""

    def setUp(self) -> None:
        MockedApiWorkflowSetup.setUp(self, dataset_id='dataset_0_id')
        self.api_workflow_client.tags_api.no_tags = 3

    def test_download_non_existing_tag(self):
        """An unknown tag name must be rejected with a ValueError."""
        with self.assertRaises(ValueError):
            self.api_workflow_client.download_dataset('path/to/dir', tag_name='this_is_not_a_real_tag_name')

    def test_download_thumbnails(self):
        """A dataset built with type='thumbnails' must raise a ValueError."""
        def get_thumbnail_dataset_by_id(*args):
            return DatasetData(name='dataset', id='dataset_id', last_modified_at=0,
                               type='thumbnails', size_in_bytes=-1, n_samples=-1, created_at=-1)
        # only the client instance of this test is modified
        self.api_workflow_client.datasets_api.get_dataset_by_id = get_thumbnail_dataset_by_id
        with self.assertRaises(ValueError):
            self.api_workflow_client.download_dataset('path/to/dir')

    def test_download_dataset(self):
        """Downloading the initial tag must create the output directory."""
        output_dir = 'path-to-dir-remove-me'

        def fake_get_image_from_read_url(read_url):
            return PIL.Image.fromarray(np.zeros((32, 32))).convert('RGB')

        # remove the output even if the download or an assertion fails
        self.addCleanup(shutil.rmtree, output_dir, ignore_errors=True)

        # patch() restores the real function on exit, so the fake does not
        # leak into other tests
        with patch(
            'lightly.api.api_workflow_download_dataset._get_image_from_read_url',
            new=fake_get_image_from_read_url,
        ):
            self.api_workflow_client.download_dataset(output_dir, tag_name='initial-tag')

        self.assertTrue(os.path.isdir(output_dir))
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters