Skip to content

Commit

Permalink
✨ Add filters to HfApi.get_repo_discussions (#1845)
Browse files Browse the repository at this point in the history
* ✨ Add filters to HfApi.get_repo_discussions

* 💄 make style

* ⏪ Revert extraneous change

* 🩹 Mention filters in the documentation

* 👌 Literal type annotation

* ✅ Add some tests

* 💄 Code quality

* Apply suggestions from code review

---------

Co-authored-by: Lucain <lucainp@gmail.com>
  • Loading branch information
SBrandeis and Wauplin committed Nov 22, 2023
1 parent feb6e66 commit 1b1049a
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 4 deletions.
17 changes: 16 additions & 1 deletion docs/source/en/guides/community.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ The `HfApi` class allows you to retrieve Discussions and Pull Requests on a give

```python
>>> from huggingface_hub import get_repo_discussions
>>> for discussion in get_repo_discussions(repo_id="bigscience/bloom-1b3"):
>>> for discussion in get_repo_discussions(repo_id="bigscience/bloom"):
... print(f"{discussion.num} - {discussion.title}, pr: {discussion.is_pull_request}")

# 11 - Add Flax weights, pr: True
Expand All @@ -25,6 +25,21 @@ The `HfApi` class allows you to retrieve Discussions and Pull Requests on a give
[...]
```

`HfApi.get_repo_discussions` supports filtering by author, type (Pull Request or Discussion) and status (`open` or `closed`):

```python
>>> from huggingface_hub import get_repo_discussions
>>> for discussion in get_repo_discussions(
... repo_id="bigscience/bloom",
... author="ArthurZ",
... discussion_type="pull_request",
... discussion_status="open",
... ):
... print(f"{discussion.num} - {discussion.title} by {discussion.author}, pr: {discussion.is_pull_request}")

# 19 - Add Flax weights by ArthurZ, pr: True
```

`HfApi.get_repo_discussions` returns a [generator](https://docs.python.org/3.7/howto/functional.html#generators) that yields
[`Discussion`] objects. To get all the Discussions in a single list, run:

Expand Down
7 changes: 6 additions & 1 deletion src/huggingface_hub/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import re
from typing import Optional
import typing
from typing import Literal, Optional, Tuple


# Possible values for env variables
Expand Down Expand Up @@ -79,6 +80,10 @@ def _as_int(value: Optional[str]) -> Optional[int]:
"models": REPO_TYPE_MODEL,
}

# Accepted values for the `discussion_type` filter of `HfApi.get_repo_discussions`.
# The tuple is derived from the Literal so the runtime check and the static type cannot drift apart.
DiscussionTypeFilter = Literal["all", "discussion", "pull_request"]
DISCUSSION_TYPES: Tuple[DiscussionTypeFilter, ...] = typing.get_args(DiscussionTypeFilter)
# Accepted values for the `discussion_status` filter of `HfApi.get_repo_discussions`.
# Annotated with the *status* literal (it was previously annotated with the type-filter literal by mistake).
DiscussionStatusFilter = Literal["all", "open", "closed"]
DISCUSSION_STATUS: Tuple[DiscussionStatusFilter, ...] = typing.get_args(DiscussionStatusFilter)

# default cache
default_home = os.path.join(os.path.expanduser("~"), ".cache")
Expand Down
35 changes: 33 additions & 2 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
Union,
overload,
)
from urllib.parse import quote
from urllib.parse import quote, urlencode

import requests
from requests.exceptions import HTTPError
Expand Down Expand Up @@ -93,6 +93,8 @@
from .constants import (
DEFAULT_ETAG_TIMEOUT,
DEFAULT_REVISION,
DISCUSSION_STATUS,
DISCUSSION_TYPES,
ENDPOINT,
INFERENCE_ENDPOINTS_ENDPOINT,
REGEX_COMMIT_OID,
Expand All @@ -101,6 +103,8 @@
REPO_TYPES_MAPPING,
REPO_TYPES_URL_PREFIXES,
SPACES_SDK_TYPES,
DiscussionStatusFilter,
DiscussionTypeFilter,
)
from .file_download import (
get_hf_file_metadata,
Expand Down Expand Up @@ -5196,6 +5200,9 @@ def get_repo_discussions(
self,
repo_id: str,
*,
author: Optional[str] = None,
discussion_type: Optional[DiscussionTypeFilter] = None,
discussion_status: Optional[DiscussionStatusFilter] = None,
repo_type: Optional[str] = None,
token: Optional[str] = None,
) -> Iterator[Discussion]:
Expand All @@ -5206,6 +5213,18 @@ def get_repo_discussions(
repo_id (`str`):
A namespace (user or an organization) and a repo name separated
by a `/`.
author (`str`, *optional*):
Pass a value to filter by discussion author. `None` means no filter.
Default is `None`.
discussion_type (`str`, *optional*):
Set to `"pull_request"` to fetch only pull requests, `"discussion"`
to fetch only discussions. Set to `"all"` or `None` to fetch both.
Default is `None`.
discussion_status (`str`, *optional*):
Set to `"open"` (respectively `"closed"`) to fetch only open
(respectively closed) discussions. Set to `"all"` or `None`
to fetch both.
Default is `None`.
repo_type (`str`, *optional*):
Set to `"dataset"` or `"space"` if fetching from a dataset or
space, `None` or `"model"` if fetching from a model. Default is
Expand Down Expand Up @@ -5236,11 +5255,23 @@ def get_repo_discussions(
raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}")
if repo_type is None:
repo_type = REPO_TYPE_MODEL
if discussion_type is not None and discussion_type not in DISCUSSION_TYPES:
raise ValueError(f"Invalid discussion_type, must be one of {DISCUSSION_TYPES}")
if discussion_status is not None and discussion_status not in DISCUSSION_STATUS:
raise ValueError(f"Invalid discussion_status, must be one of {DISCUSSION_STATUS}")

headers = self._build_hf_headers(token=token)
query_dict: Dict[str, str] = {}
if discussion_type is not None:
query_dict["type"] = discussion_type
if discussion_status is not None:
query_dict["status"] = discussion_status
if author is not None:
query_dict["author"] = author

def _fetch_discussion_page(page_index: int):
path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions?p={page_index}"
query_string = urlencode({**query_dict, "page_index": page_index})
path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions?{query_string}"
resp = get_session().get(path, headers=headers)
hf_raise_for_status(resp)
paginated_discussions = resp.json()
Expand Down
37 changes: 37 additions & 0 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2424,6 +2424,43 @@ def test_get_repo_discussion(self):
list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num]
)

def test_get_repo_discussion_by_type(self):
    # Each `discussion_type` value is paired with the discussion numbers the
    # listing is expected to yield, in the order the API returns them.
    expectations = (
        ("pull_request", [self.pull_request.num]),
        ("discussion", [self.discussion.num]),
        ("all", [self.discussion.num, self.pull_request.num]),
    )
    for type_filter, expected_nums in expectations:
        results = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_type=type_filter)
        # The call must hand back a generator, not an already-built list.
        self.assertIsInstance(results, types.GeneratorType)
        self.assertListEqual([item.num for item in results], expected_nums)

def test_get_repo_discussion_by_author(self):
    # Filtering on the author name "unknown" is expected to match nothing.
    results = self._api.get_repo_discussions(repo_id=self.repo_id, author="unknown")
    self.assertIsInstance(results, types.GeneratorType)
    found_nums = [item.num for item in results]
    self.assertListEqual(found_nums, [])

def test_get_repo_discussion_by_status(self):
    # Close the plain discussion first so the repo holds one open item (the
    # pull request) and one closed item (the discussion).
    self._api.change_discussion_status(self.repo_id, self.discussion.num, "closed")

    # Each `discussion_status` value is paired with the numbers it should yield.
    expectations = (
        ("open", [self.pull_request.num]),
        ("closed", [self.discussion.num]),
        ("all", [self.discussion.num, self.pull_request.num]),
    )
    for status_filter, expected_nums in expectations:
        results = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_status=status_filter)
        # The call must hand back a generator, not an already-built list.
        self.assertIsInstance(results, types.GeneratorType)
        self.assertListEqual([item.num for item in results], expected_nums)

def test_get_discussion_details(self):
retrieved = self._api.get_discussion_details(repo_id=self.repo_id, discussion_num=2)
self.assertEqual(retrieved, self.discussion)
Expand Down

0 comments on commit 1b1049a

Please sign in to comment.