Skip to content

Commit

Permalink
6 full dataset download pw (#233)
Browse files Browse the repository at this point in the history
* Add api workflow download client

* Add download option to CLI

* Add download tests and fix tox

* Add new download to cli instructions

* Minor change+

* Fix tests
  • Loading branch information
philippmwirth committed Mar 17, 2021
1 parent 50e543f commit e725894
Show file tree
Hide file tree
Showing 9 changed files with 199 additions and 14 deletions.
21 changes: 16 additions & 5 deletions docs/source/getting_started/command_line_tool.rst
Original file line number Diff line number Diff line change
Expand Up @@ -119,16 +119,27 @@ You can upload embeddings directly to the Lightly Platform using the CLI.
Download data using the CLI
-----------------------------------------------
You can download a dataset with a given tag from the Lightly Platform using the
following CLI command. The CLI provides you with two options. Either you
download just a list or copy the files from the original dataset into a new
folder. The second option is very handy for quick prototyping.
following CLI command. The CLI provides you with three options:

* Download the list of filenames for a given tag in the dataset.

* Download the images for a given tag in the dataset.

* Copy the images for a given tag from an input directory to a target directory.

The last option allows you to very quickly extract only the images in a given tag
without the need to download them explicitly.

.. code-block:: bash
# download a list of files
lightly-download tag_name=my_tag_name dataset_id=your_dataset_id token=your_token
# copy files in a tag to a new folder
# download the images and store them in an output directory
lightly-download tag_name=my_tag_name dataset_id=your_dataset_id token=your_token \
output_dir=path/to/output/dir
# copy images from an input directory to an output directory
lightly-download tag_name=my_tag_name dataset_id=your_dataset_id token=your_token \
input_dir=cat output_dir=cat_curated
input_dir=path/to/input/dir output_dir=path/to/output/dir
2 changes: 2 additions & 0 deletions lightly/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@

# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved

from lightly.api import routes
3 changes: 2 additions & 1 deletion lightly/api/api_workflow_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from lightly.api.api_workflow_upload_dataset import _UploadDatasetMixin
from lightly.api.api_workflow_upload_embeddings import _UploadEmbeddingsMixin
from lightly.api.api_workflow_download_dataset import _DownloadDatasetMixin
from lightly.api.api_workflow_sampling import _SamplingMixin
from lightly.openapi_generated.swagger_client import TagData, ScoresApi, QuotaApi
from lightly.openapi_generated.swagger_client.api.embeddings_api import EmbeddingsApi
Expand All @@ -27,7 +28,7 @@
from lightly.openapi_generated.swagger_client.configuration import Configuration


class ApiWorkflowClient(_UploadEmbeddingsMixin, _SamplingMixin, _UploadDatasetMixin, _DatasetsMixin):
class ApiWorkflowClient(_UploadEmbeddingsMixin, _SamplingMixin, _UploadDatasetMixin, _DownloadDatasetMixin, _DatasetsMixin):
"""Provides a uniform interface to communicate with the api
The APIWorkflowClient is used to communicaate with the Lightly API. The client
Expand Down
111 changes: 111 additions & 0 deletions lightly/api/api_workflow_download_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import warnings
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Union
import io
import os
import tqdm
from urllib.request import Request, urlopen
from PIL import Image

from lightly.openapi_generated.swagger_client import TagCreator
from lightly.openapi_generated.swagger_client.models.sample_create_request import SampleCreateRequest
from lightly.api.utils import check_filename, check_image, get_thumbnail_from_img, PIL_to_bytes
from lightly.api.bitmask import BitMask
from lightly.openapi_generated.swagger_client.models.initial_tag_create_request import InitialTagCreateRequest
from lightly.openapi_generated.swagger_client.models.image_type import ImageType
from lightly.data.dataset import LightlyDataset



def _make_dir_and_save_image(output_dir: str, filename: str, img: Image):
"""Saves the images and creates necessary subdirectories.
"""
path = os.path.join(output_dir, filename)

head = os.path.split(path)[0]
if not os.path.exists(head):
os.makedirs(head)

img.save(path)
img.close()


def _get_image_from_read_url(read_url: str):
"""Makes a get request to the signed read url and returns the image.
"""
request = Request(read_url, method='GET')
with urlopen(request) as response:
blob = response.read()
img = Image.open(io.BytesIO(blob))
return img


class _DownloadDatasetMixin:

def download_dataset(self,
output_dir: str,
tag_name: str = 'initial-tag',
verbose: bool = True):
"""Downloads images from the web-app and stores them in output_dir.
Args:
output_dir:
Where to store the downloaded images.
tag_name:
Name of the tag which should be downloaded.
verbose:
Whether or not to show the progress bar.
Raises:
ValueError if the specified tag does not exist on the dataset.
RuntimeError if the connection to the server failed.
"""

# check if images are available
dataset = self.datasets_api.get_dataset_by_id(self.dataset_id)
if dataset.img_type != ImageType.FULL:
# only thumbnails or metadata available
raise ValueError(
f"Dataset with id {self.dataset_id} has no downloadable images!"
)

# check if tag exists
available_tags = self._get_all_tags()
try:
print(available_tags)
tag = next(tag for tag in available_tags if tag.name == tag_name)
except StopIteration:
raise ValueError(
f"Dataset with id {self.dataset_id} has no tag {tag_name}!"
)

# get sample ids
sample_ids = self.mappings_api.get_sample_mappings_by_dataset_id(
self.dataset_id,
field='_id'
)

indices = BitMask.from_hex(tag.bit_mask_data).to_indices()
sample_ids = [sample_ids[i] for i in indices]
filenames = [self.filenames_on_server[i] for i in indices]

if verbose:
print(f'Downloading {len(sample_ids)} images:', flush=True)
pbar = tqdm.tqdm(unit='imgs', total=len(sample_ids))

# download images
for sample_id, filename in zip(sample_ids, filenames):
read_url = self.samples_api.get_sample_image_read_url_by_id(
self.dataset_id,
sample_id,
type="full",
)

img = _get_image_from_read_url(read_url)
_make_dir_and_save_image(output_dir, filename, img)

if verbose:
pbar.update(1)
12 changes: 9 additions & 3 deletions lightly/cli/download_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,12 @@ def _download_cli(cfg, is_cli_call=True):
msg += os.path.join(os.getcwd(), cfg['tag_name'] + '.txt')
print(msg, flush=True)

if cfg['input_dir'] and cfg['output_dir']:
if not cfg['input_dir'] and cfg['output_dir']:
# download full images from api
output_dir = fix_input_path(cfg['output_dir'])
api_workflow_client.download_dataset(output_dir, tag_name=tag_name)

elif cfg['input_dir'] and cfg['output_dir']:
input_dir = fix_input_path(cfg['input_dir'])
output_dir = fix_input_path(cfg['output_dir'])
print(f'Copying files from {input_dir} to {output_dir}.')
Expand Down Expand Up @@ -118,9 +122,11 @@ def download_cli(cfg):
>>> # download list of all files in tag 'my-tag' from the Lightly platform
>>> lightly-download token='123' dataset_id='XYZ' tag_name='my-tag'
>>>
>>> # download all images in tag 'my-tag' from the Lightly platform
>>> lightly-download token='123' dataset_id='XYZ' tag_name='my-tag' output_dir='my_data/'
>>>
>>> # copy all files in 'my-tag' to a new directory
>>> lightly-download token='123' dataset_id='XYZ' tag_name='my-tag' \\
>>> input_dir=data/ output_dir=new_data/
>>> lightly-download token='123' dataset_id='XYZ' tag_name='my-tag' input_dir='data/' output_dir='my_data/'
"""
Expand Down
13 changes: 10 additions & 3 deletions tests/api_workflow/mocked_api_workflow_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,16 @@ def get_sample_image_write_url_by_id(self, dataset_id, sample_id, is_thumbnail,
url = f"{sample_id}_write_url"
return url

def get_sample_image_read_url_by_id(self, dataset_id, sample_id, type, **kwargs):
url = f"{sample_id}_write_url"
return url


class MockedDatasetsApi(DatasetsApi):
def __init__(self, api_client):
no_datasets = 3
self.default_datasets = [DatasetData(name=f"dataset_{i}", id=f"dataset_{i}_id", last_modified_at=i,
type="", size_in_bytes=-1, n_samples=-1, created_at=-1)
type="", img_type="full", size_in_bytes=-1, n_samples=-1, created_at=-1)
for i in range(no_datasets)]
self.reset()

Expand All @@ -157,6 +161,9 @@ def create_dataset(self, body: DatasetCreateRequest, **kwargs):
response_ = CreateEntityResponse(id=id)
return response_

def get_dataset_by_id(self, dataset_id):
return next(dataset for dataset in self.default_datasets if dataset_id == dataset.id)

def delete_dataset_by_id(self, dataset_id, **kwargs):
datasets_without_that_id = [dataset for dataset in self.datasets if dataset.id != dataset_id]
assert len(datasets_without_that_id) == len(self.datasets) - 1
Expand Down Expand Up @@ -222,5 +229,5 @@ def __init__(self, *args, **kwargs):


class MockedApiWorkflowSetup(unittest.TestCase):
def setUp(self) -> None:
self.api_workflow_client = MockedApiWorkflowClient(token="token_xyz", dataset_id="dataset_id_xyz")
def setUp(self, token="token_xyz", dataset_id="dataset_id_xyz") -> None:
self.api_workflow_client = MockedApiWorkflowClient(token=token, dataset_id=dataset_id)
3 changes: 2 additions & 1 deletion tests/api_workflow/test_api_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
class TestApiWorkflow(MockedApiWorkflowSetup):

def setUp(self) -> None:
lightly.api.api_workflow_client.__version__ = "1.1.1"
lightly.api.api_workflow_client.__version__ = lightly.__version__
self.api_workflow_client = MockedApiWorkflowClient(token="token_xyz")

def test_error_if_version_is_incompatible(self):
lightly.api.api_workflow_client.__version__ = "0.0.0"
with self.assertRaises(ValueError):
MockedApiWorkflowClient(token="token_xyz")
lightly.api.api_workflow_client.__version__ = lightly.__version__

def test_dataset_id_nonexisting(self):
self.api_workflow_client.datasets_api.reset()
Expand Down
46 changes: 46 additions & 0 deletions tests/api_workflow/test_api_workflow_download_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import os
import shutil

from unittest.mock import patch

import PIL
import numpy as np

import torchvision

import lightly
from lightly.data.dataset import LightlyDataset

from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup
from lightly.openapi_generated.swagger_client.models.dataset_data import DatasetData



class TestApiWorkflowDownloadDataset(MockedApiWorkflowSetup):
def setUp(self) -> None:
MockedApiWorkflowSetup.setUp(self, dataset_id='dataset_0_id')
self.api_workflow_client.tags_api.no_tags = 3

def test_download_non_existing_tag(self):
with self.assertRaises(ValueError):
self.api_workflow_client.download_dataset('path/to/dir', tag_name='this_is_not_a_real_tag_name')

def test_download_thumbnails(self):
def get_thumbnail_dataset_by_id(*args):
return DatasetData(name=f'dataset', id='dataset_id', last_modified_at=0,
type='thumbnails', size_in_bytes=-1, n_samples=-1, created_at=-1)
self.api_workflow_client.datasets_api.get_dataset_by_id = get_thumbnail_dataset_by_id
with self.assertRaises(ValueError):
self.api_workflow_client.download_dataset('path/to/dir')

def test_download_dataset(self):
def my_func(read_url):
return PIL.Image.fromarray(np.zeros((32, 32))).convert('RGB')
#mock_get_image_from_readurl.return_value = PIL.Image.fromarray(np.zeros((32, 32)))
lightly.api.api_workflow_download_dataset._get_image_from_read_url = my_func
self.api_workflow_client.download_dataset('path-to-dir-remove-me', tag_name='initial-tag')
shutil.rmtree('path-to-dir-remove-me')




2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,4 @@ commands =
pip install torch==1.7.0+cu101 torchvision==0.8.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install .[all]
echo "Running video test"
make test
make test

0 comments on commit e725894

Please sign in to comment.