Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix HfApi.create_repo when repo_type is 'space' #394

Merged
merged 13 commits into from
Oct 29, 2021
13 changes: 12 additions & 1 deletion src/huggingface_hub/commands/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
from typing import List, Union

from huggingface_hub.commands import BaseHuggingfaceCLICommand
from huggingface_hub.constants import REPO_TYPES, REPO_TYPES_URL_PREFIXES
from huggingface_hub.constants import (
REPO_TYPES,
REPO_TYPES_URL_PREFIXES,
SPACES_SDK_TYPES,
)
from huggingface_hub.hf_api import HfApi, HfFolder
from requests.exceptions import HTTPError

Expand Down Expand Up @@ -68,6 +72,12 @@ def register_subcommand(parser: ArgumentParser):
repo_create_parser.add_argument(
"--organization", type=str, help="Optional: organization namespace."
)
repo_create_parser.add_argument(
"--space_sdk",
type=str,
help='Optional: Hugging Face Spaces SDK type. Required when --type is set to "space".',
choices=SPACES_SDK_TYPES,
)
repo_create_parser.add_argument(
"-y",
"--yes",
Expand Down Expand Up @@ -263,6 +273,7 @@ def run(self):
token=token,
organization=self.args.organization,
repo_type=self.args.type,
space_sdk=self.args.space_sdk,
)
except HTTPError as e:
print(e)
Expand Down
1 change: 1 addition & 0 deletions src/huggingface_hub/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
REPO_TYPE_DATASET = "dataset"
REPO_TYPE_SPACE = "space"
REPO_TYPES = [None, REPO_TYPE_DATASET, REPO_TYPE_SPACE]
SPACES_SDK_TYPES = ["gradio", "streamlit", "static"]

REPO_TYPES_URL_PREFIXES = {
REPO_TYPE_DATASET: "datasets/",
Expand Down
34 changes: 32 additions & 2 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@
import requests
from requests.exceptions import HTTPError

from .constants import ENDPOINT, REPO_TYPES, REPO_TYPES_MAPPING, REPO_TYPES_URL_PREFIXES
from .constants import (
ENDPOINT,
REPO_TYPES,
REPO_TYPES_MAPPING,
REPO_TYPES_URL_PREFIXES,
SPACES_SDK_TYPES,
)


if sys.version_info >= (3, 8):
Expand Down Expand Up @@ -664,6 +670,7 @@ def create_repo(
repo_type: Optional[str] = None,
exist_ok=False,
lfsmultipartthresh: Optional[int] = None,
space_sdk: Optional[str] = None,
) -> str:
"""
HuggingFace git-based system, used for models, datasets, and spaces.
Expand All @@ -679,6 +686,8 @@ def create_repo(

lfsmultipartthresh: Optional: internal param for testing purposes.

space_sdk: Choice of SDK to use if repo_type is "space". Can be "streamlit", "gradio", or "static".

Returns:
URL to the newly created repo.
"""
Expand Down Expand Up @@ -707,6 +716,22 @@ def create_repo(
json = {"name": name, "organization": organization, "private": private}
if repo_type is not None:
json["type"] = repo_type
if repo_type == "space":
if space_sdk is None:
raise ValueError(
"No space_sdk provided. `create_repo` expects space_sdk to be one of "
f"{SPACES_SDK_TYPES} when repo_type is 'space'`"
)
if space_sdk not in SPACES_SDK_TYPES:
raise ValueError(
f"Invalid space_sdk. Please choose one of {SPACES_SDK_TYPES}."
)
json["sdk"] = space_sdk
if space_sdk is not None and repo_type != "space":
warnings.warn(
"Ignoring provided space_sdk because repo_type is not 'space'."
)

if lfsmultipartthresh is not None:
json["lfsmultipartthresh"] = lfsmultipartthresh
r = requests.post(
Expand Down Expand Up @@ -821,7 +846,12 @@ def update_repo_visibility(
path_prefix += REPO_TYPES_URL_PREFIXES[repo_type]

path = "{}{}/{}/settings".format(path_prefix, namespace, name)
json = {"private": private}

# HACK - spaces repo updates break without recently added 'gated' param. Hardcoding here for now.
if repo_type == "space":
json = {"private": private, "gated": False}
else:
json = {"private": private}

r = requests.put(
path,
Expand Down
71 changes: 49 additions & 22 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,14 @@
import unittest
from io import BytesIO

import pytest

import requests
from huggingface_hub.constants import REPO_TYPE_DATASET, REPO_TYPE_SPACE
from huggingface_hub.constants import (
REPO_TYPE_DATASET,
REPO_TYPE_SPACE,
SPACES_SDK_TYPES,
)
from huggingface_hub.file_download import cached_download, hf_hub_download
from huggingface_hub.hf_api import (
DatasetInfo,
Expand Down Expand Up @@ -153,28 +159,49 @@ def test_create_update_and_delete_dataset_repo(self):
name=DATASET_REPO_NAME, token=self._token, repo_type=REPO_TYPE_DATASET
)

@unittest.skip("skipped while spaces in beta")
@with_production_testing
def test_create_update_and_delete_space_repo(self):
self._api.create_repo(
name=SPACE_REPO_NAME, token=self._token, repo_type=REPO_TYPE_SPACE
)
res = self._api.update_repo_visibility(
name=SPACE_REPO_NAME,
token=self._token,
private=True,
repo_type=REPO_TYPE_SPACE,
)
self.assertTrue(res["private"])
res = self._api.update_repo_visibility(
name=SPACE_REPO_NAME,
token=self._token,
private=False,
repo_type=REPO_TYPE_SPACE,
)
self.assertFalse(res["private"])
self._api.delete_repo(
name=SPACE_REPO_NAME, token=self._token, repo_type=REPO_TYPE_SPACE
)
_api = HfApi()
_token = os.environ.get("API_TOKEN", None)
with pytest.raises(ValueError, match=r"No space_sdk provided.*"):
_api.create_repo(
token=_token,
name=SPACE_REPO_NAME,
repo_type=REPO_TYPE_SPACE,
space_sdk=None,
)
with pytest.raises(ValueError, match=r"Invalid space_sdk.*"):
_api.create_repo(
token=_token,
name=SPACE_REPO_NAME,
repo_type=REPO_TYPE_SPACE,
space_sdk="asdfasdf",
)

for sdk in SPACES_SDK_TYPES:
_api.create_repo(
name=SPACE_REPO_NAME,
token=_token,
repo_type=REPO_TYPE_SPACE,
space_sdk=sdk,
)
res = _api.update_repo_visibility(
name=SPACE_REPO_NAME,
token=_token,
private=True,
repo_type=REPO_TYPE_SPACE,
)
self.assertTrue(res["private"])
res = _api.update_repo_visibility(
name=SPACE_REPO_NAME,
token=_token,
private=False,
repo_type=REPO_TYPE_SPACE,
)
self.assertFalse(res["private"])
_api.delete_repo(
name=SPACE_REPO_NAME, token=_token, repo_type=REPO_TYPE_SPACE
)


class HfApiUploadFileTest(HfApiCommonTestWithLogin):
Expand Down