diff --git a/README.md b/README.md index 3d685b8..e95e62b 100644 --- a/README.md +++ b/README.md @@ -2639,7 +2639,10 @@ dataset_objects = client.get_dataset_objects( dataset="YOUR_DATASET_NAME", version="latest", # default is "latest" tags=["cat"], - licenses=["MIT"] + licenses=["fastlabel"], + types=["train", "valid"], # choices are "train", "valid", "test" and "none" (Optional) + offset=0, # default is 0 (Optional) + limit=1000, # default is 1000, and must be less than or equal to 1000 (Optional) ) ``` @@ -2663,6 +2666,9 @@ client.download_dataset_objects( version="latest", # default is "latest" tags=["cat"], types=["train", "valid"], # choices are "train", "valid", "test" and "none" (Optional) + licenses=["fastlabel"], + offset=0, # default is 0 (Optional) + limit=1000, # default is 1000, and must be less than or equal to 1000 (Optional) ) ``` diff --git a/fastlabel/__init__.py b/fastlabel/__init__.py index ee4fee5..1970c28 100644 --- a/fastlabel/__init__.py +++ b/fastlabel/__init__.py @@ -28,6 +28,7 @@ from .api import Api from .exceptions import FastLabelInvalidException +from .query import DatasetObjectGetQuery logger = logging.getLogger(__name__) logging.basicConfig( @@ -3961,6 +3962,7 @@ def get_dataset_objects( tags: Optional[List[str]] = None, licenses: Optional[List[str]] = None, revision_id: str = None, + types: Optional[List[Union[str, DatasetObjectType]]] = None, offset: int = 0, limit: int = 1000, ) -> list: @@ -3973,6 +3975,31 @@ def get_dataset_objects( revision_id is dataset rebision (Optional). Only use specify one of revision_id or version. 
""" + endpoint = "dataset-objects-v2" + types = [DatasetObjectType.create(type_) for type_ in types or []] + params = self._prepare_params( + dataset=dataset, + version=version, + tags=tags, + licenses=licenses, + revision_id=revision_id, + types=types, + offset=offset, + limit=limit, + ) + return self.api.get_request(endpoint, params=params) + + def _prepare_params( + self, + dataset: str, + offset: int, + limit: int, + version: str, + revision_id: str, + tags: Optional[List[str]], + licenses: Optional[List[str]], + types: Optional[List[DatasetObjectType]], + ) -> DatasetObjectGetQuery: if version and revision_id: raise FastLabelInvalidException( "only use specify one of revisionId or version.", 400 @@ -3981,56 +4008,47 @@ def get_dataset_objects( raise FastLabelInvalidException( "Limit must be less than or equal to 1000.", 422 ) - endpoint = "dataset-objects-v2" - params = {"dataset": dataset, "offset": offset, "limit": limit} + params: DatasetObjectGetQuery = { + "dataset": dataset, + "offset": offset, + "limit": limit, + } if revision_id: params["revisionId"] = revision_id if version: params["version"] = version - - tags = tags or [] if tags: params["tags"] = tags if licenses: params["licenses"] = licenses - return self.api.get_request(endpoint, params=params) + if types: + params["types"] = [t.value for t in types] + return params def download_dataset_objects( self, dataset: str, path: str, version: str = "", + revision_id: str = "", tags: Optional[List[str]] = None, + licenses: Optional[List[str]] = None, types: Optional[List[Union[str, DatasetObjectType]]] = None, offset: int = 0, limit: int = 1000, ): endpoint = "dataset-objects-v2/signed-urls" - if limit > 1000: - raise FastLabelInvalidException( - "Limit must be less than or equal to 1000.", 422 - ) - params = {"dataset": dataset, "offset": offset, "limit": limit} - if version: - params["version"] = version - if tags: - params["tags"] = tags - if types: - try: - types = list( - map( - lambda t: t - if 
isinstance(t, DatasetObjectType) - else DatasetObjectType(t), - types, - ) - ) - except ValueError: - raise FastLabelInvalidException( - f"types must be {[k for k in DatasetObjectType.__members__.keys()]}.", - 422, - ) - params["types"] = [t.value for t in types] + types = [DatasetObjectType.create(type_) for type_ in types or []] + params = self._prepare_params( + dataset=dataset, + offset=offset, + limit=limit, + version=version, + revision_id=revision_id, + tags=tags, + types=types, + licenses=licenses, + ) response = self.api.get_request(endpoint, params=params) diff --git a/fastlabel/const.py b/fastlabel/const.py index f505bbc..ccaad9f 100644 --- a/fastlabel/const.py +++ b/fastlabel/const.py @@ -254,3 +254,15 @@ class DatasetObjectType(Enum): train = "train" valid = "valid" test = "test" + + @classmethod + def create(cls, value: "str | DatasetObjectType") -> "DatasetObjectType": + if isinstance(value, cls): + return value + try: + return cls(value) + except ValueError: + raise ValueError( + f"Invalid DatasetObjectType: {value}. " + f"types must be {[k for k in DatasetObjectType.__members__.keys()]}" + ) diff --git a/fastlabel/query.py b/fastlabel/query.py new file mode 100644 index 0000000..98992e9 --- /dev/null +++ b/fastlabel/query.py @@ -0,0 +1,12 @@ +from typing import List, Optional, TypedDict + + +class DatasetObjectGetQuery(TypedDict, total=False): + dataset: str + version: str + revisionId: str + tags: Optional[List[str]] + licenses: Optional[List[str]] + types: Optional[List[str]] + offset: int + limit: int