Skip to content

Commit

Permalink
Fix HfApi.create_repo when repo_type is 'space' (#394)
Browse files Browse the repository at this point in the history
* 🐛 fix create spaces

* 🐛 fix create spaces bug

* ✨ add spaces sdk types as constant

* 🧪 add a commit that unskips spaces test

* ✅ skip spaces tests while in beta

* 🚧 update cli

* 🐛 fix missed merge confict

* 💄 style

* 🐛 add missing import

* 🎨 only hardcode gated param for spaces repos

* ✅ Update spaces tests to use prod

* ✅ use staging endpoint for spaces tests

* 🔥 remove hack
  • Loading branch information
nateraw committed Oct 29, 2021
1 parent 0466022 commit b514069
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 24 deletions.
13 changes: 12 additions & 1 deletion src/huggingface_hub/commands/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
from typing import List, Union

from huggingface_hub.commands import BaseHuggingfaceCLICommand
from huggingface_hub.constants import REPO_TYPES, REPO_TYPES_URL_PREFIXES
from huggingface_hub.constants import (
REPO_TYPES,
REPO_TYPES_URL_PREFIXES,
SPACES_SDK_TYPES,
)
from huggingface_hub.hf_api import HfApi, HfFolder
from requests.exceptions import HTTPError

Expand Down Expand Up @@ -68,6 +72,12 @@ def register_subcommand(parser: ArgumentParser):
repo_create_parser.add_argument(
"--organization", type=str, help="Optional: organization namespace."
)
repo_create_parser.add_argument(
"--space_sdk",
type=str,
help='Optional: Hugging Face Spaces SDK type. Required when --type is set to "space".',
choices=SPACES_SDK_TYPES,
)
repo_create_parser.add_argument(
"-y",
"--yes",
Expand Down Expand Up @@ -263,6 +273,7 @@ def run(self):
token=token,
organization=self.args.organization,
repo_type=self.args.type,
space_sdk=self.args.space_sdk,
)
except HTTPError as e:
print(e)
Expand Down
1 change: 1 addition & 0 deletions src/huggingface_hub/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
REPO_TYPE_DATASET = "dataset"
REPO_TYPE_SPACE = "space"
REPO_TYPES = [None, REPO_TYPE_DATASET, REPO_TYPE_SPACE]
SPACES_SDK_TYPES = ["gradio", "streamlit", "static"]

REPO_TYPES_URL_PREFIXES = {
REPO_TYPE_DATASET: "datasets/",
Expand Down
28 changes: 27 additions & 1 deletion src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@
import requests
from requests.exceptions import HTTPError

from .constants import ENDPOINT, REPO_TYPES, REPO_TYPES_MAPPING, REPO_TYPES_URL_PREFIXES
from .constants import (
ENDPOINT,
REPO_TYPES,
REPO_TYPES_MAPPING,
REPO_TYPES_URL_PREFIXES,
SPACES_SDK_TYPES,
)


if sys.version_info >= (3, 8):
Expand Down Expand Up @@ -664,6 +670,7 @@ def create_repo(
repo_type: Optional[str] = None,
exist_ok=False,
lfsmultipartthresh: Optional[int] = None,
space_sdk: Optional[str] = None,
) -> str:
"""
HuggingFace git-based system, used for models, datasets, and spaces.
Expand All @@ -679,6 +686,8 @@ def create_repo(
lfsmultipartthresh: Optional: internal param for testing purposes.
space_sdk: Choice of SDK to use if repo_type is "space". Can be "streamlit", "gradio", or "static".
Returns:
URL to the newly created repo.
"""
Expand Down Expand Up @@ -707,6 +716,22 @@ def create_repo(
json = {"name": name, "organization": organization, "private": private}
if repo_type is not None:
json["type"] = repo_type
if repo_type == "space":
if space_sdk is None:
raise ValueError(
"No space_sdk provided. `create_repo` expects space_sdk to be one of "
f"{SPACES_SDK_TYPES} when repo_type is 'space'`"
)
if space_sdk not in SPACES_SDK_TYPES:
raise ValueError(
f"Invalid space_sdk. Please choose one of {SPACES_SDK_TYPES}."
)
json["sdk"] = space_sdk
if space_sdk is not None and repo_type != "space":
warnings.warn(
"Ignoring provided space_sdk because repo_type is not 'space'."
)

if lfsmultipartthresh is not None:
json["lfsmultipartthresh"] = lfsmultipartthresh
r = requests.post(
Expand Down Expand Up @@ -821,6 +846,7 @@ def update_repo_visibility(
path_prefix += REPO_TYPES_URL_PREFIXES[repo_type]

path = "{}{}/{}/settings".format(path_prefix, namespace, name)

json = {"private": private}

r = requests.put(
Expand Down
68 changes: 46 additions & 22 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,14 @@
import unittest
from io import BytesIO

import pytest

import requests
from huggingface_hub.constants import REPO_TYPE_DATASET, REPO_TYPE_SPACE
from huggingface_hub.constants import (
REPO_TYPE_DATASET,
REPO_TYPE_SPACE,
SPACES_SDK_TYPES,
)
from huggingface_hub.file_download import cached_download, hf_hub_download
from huggingface_hub.hf_api import (
DatasetInfo,
Expand Down Expand Up @@ -153,28 +159,46 @@ def test_create_update_and_delete_dataset_repo(self):
name=DATASET_REPO_NAME, token=self._token, repo_type=REPO_TYPE_DATASET
)

@unittest.skip("skipped while spaces in beta")
def test_create_update_and_delete_space_repo(self):
self._api.create_repo(
name=SPACE_REPO_NAME, token=self._token, repo_type=REPO_TYPE_SPACE
)
res = self._api.update_repo_visibility(
name=SPACE_REPO_NAME,
token=self._token,
private=True,
repo_type=REPO_TYPE_SPACE,
)
self.assertTrue(res["private"])
res = self._api.update_repo_visibility(
name=SPACE_REPO_NAME,
token=self._token,
private=False,
repo_type=REPO_TYPE_SPACE,
)
self.assertFalse(res["private"])
self._api.delete_repo(
name=SPACE_REPO_NAME, token=self._token, repo_type=REPO_TYPE_SPACE
)
with pytest.raises(ValueError, match=r"No space_sdk provided.*"):
self._api.create_repo(
token=self._token,
name=SPACE_REPO_NAME,
repo_type=REPO_TYPE_SPACE,
space_sdk=None,
)
with pytest.raises(ValueError, match=r"Invalid space_sdk.*"):
self._api.create_repo(
token=self._token,
name=SPACE_REPO_NAME,
repo_type=REPO_TYPE_SPACE,
space_sdk="asdfasdf",
)

for sdk in SPACES_SDK_TYPES:
self._api.create_repo(
name=SPACE_REPO_NAME,
token=self._token,
repo_type=REPO_TYPE_SPACE,
space_sdk=sdk,
)
res = self._api.update_repo_visibility(
name=SPACE_REPO_NAME,
token=self._token,
private=True,
repo_type=REPO_TYPE_SPACE,
)
self.assertTrue(res["private"])
res = self._api.update_repo_visibility(
name=SPACE_REPO_NAME,
token=self._token,
private=False,
repo_type=REPO_TYPE_SPACE,
)
self.assertFalse(res["private"])
self._api.delete_repo(
name=SPACE_REPO_NAME, token=self._token, repo_type=REPO_TYPE_SPACE
)


class HfApiUploadFileTest(HfApiCommonTestWithLogin):
Expand Down

0 comments on commit b514069

Please sign in to comment.