Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Endpoints added for enchanced fields search #1751

Merged
merged 6 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
- Added support for Protocol Buffers ([#1728](https://github.com/neptune-ai/neptune-client/pull/1728))
- Series values DTO conversion reworked with protocol buffer support ([#1738](https://github.com/neptune-ai/neptune-client/pull/1738))
- Series values fetching reworked with protocol buffer support ([#1744](https://github.com/neptune-ai/neptune-client/pull/1744))
- Added support for enhanced field definitions querying ([#1751](https://github.com/neptune-ai/neptune-client/pull/1751))

### Fixes
- Fixed `tqdm.notebook` import only in Notebook environment ([#1716](https://github.com/neptune-ai/neptune-client/pull/1716))
Expand Down
102 changes: 102 additions & 0 deletions src/neptune/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
"StringSeriesValues",
"StringPointValue",
"ImageSeriesValues",
"QueryFieldDefinitionsResult",
"NextPage",
"QueryFieldsResult",
)

import abc
Expand Down Expand Up @@ -601,6 +604,105 @@ def from_proto(data: ProtoLeaderboardEntriesSearchResultDTO) -> LeaderboardEntri
)


@dataclass
class NextPage:
limit: Optional[int]
next_page_token: Optional[str]

@staticmethod
def from_dict(data: Dict[str, Any]) -> NextPage:
return NextPage(limit=data.get("limit"), next_page_token=data.get("nextPageToken"))

@staticmethod
def from_model(model: Any) -> NextPage:
return NextPage(limit=model.limit, next_page_token=model.nextPageToken)

@staticmethod
def from_proto(data: Any) -> NextPage:
raise NotImplementedError()

def to_dto(self) -> Dict[str, Any]:
return {
"limit": self.limit,
"nextPageToken": self.next_page_token,
}


@dataclass
class QueryFieldsExperimentResult:
object_id: str
object_key: str
fields: List[Field]

@staticmethod
def from_dict(data: Dict[str, Any]) -> QueryFieldsExperimentResult:
return QueryFieldsExperimentResult(
object_id=data["experimentId"],
object_key=data["experimentShortId"],
fields=[Field.from_dict(field) for field in data["attributes"]],
)

@staticmethod
def from_model(model: Any) -> QueryFieldsExperimentResult:
return QueryFieldsExperimentResult(
object_id=model.experimentId,
object_key=model.experimentShortId,
fields=[Field.from_model(field) for field in model.attributes],
)

@staticmethod
def from_proto(data: Any) -> QueryFieldsExperimentResult:
raise NotImplementedError()


@dataclass
class QueryFieldsResult:
entries: List[QueryFieldsExperimentResult]
next_page: NextPage

@staticmethod
def from_dict(data: Dict[str, Any]) -> QueryFieldsResult:
return QueryFieldsResult(
entries=[QueryFieldsExperimentResult.from_dict(entry) for entry in data["entries"]],
next_page=NextPage.from_dict(data["nextPage"]),
)

@staticmethod
def from_model(model: Any) -> QueryFieldsResult:
return QueryFieldsResult(
entries=[QueryFieldsExperimentResult.from_model(entry) for entry in model.entries],
next_page=NextPage.from_model(model.nextPage),
)

@staticmethod
def from_proto(data: Any) -> QueryFieldsResult:
raise NotImplementedError()


@dataclass
class QueryFieldDefinitionsResult:
entries: List[FieldDefinition]
next_page: NextPage

@staticmethod
def from_dict(data: Dict[str, Any]) -> QueryFieldDefinitionsResult:
return QueryFieldDefinitionsResult(
entries=[FieldDefinition.from_dict(entry) for entry in data["entries"]],
next_page=NextPage.from_dict(data["nextPage"]),
)

@staticmethod
def from_model(model: Any) -> QueryFieldDefinitionsResult:
return QueryFieldDefinitionsResult(
entries=[FieldDefinition.from_model(entry) for entry in model.entries],
next_page=NextPage.from_model(model.nextPage),
)

@staticmethod
def from_proto(data: Any) -> QueryFieldDefinitionsResult:
raise NotImplementedError()


@dataclass
class FieldDefinition:
path: str
Expand Down
60 changes: 60 additions & 0 deletions src/neptune/api/pagination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#
# Copyright (c) 2024, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ("paginate_over",)

import abc
from dataclasses import dataclass
from typing import (
Any,
Callable,
Iterable,
Iterator,
Optional,
TypeVar,
)

from typing_extensions import Protocol

from neptune.api.models import NextPage


@dataclass
class WithPagination(abc.ABC):
next_page: Optional[NextPage]


T = TypeVar("T", bound=WithPagination)
Entry = TypeVar("Entry")


class Paginatable(Protocol):
def __call__(self, *, next_page: Optional[NextPage] = None, **kwargs: Any) -> Any: ...
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is an assumption that a Paginatable will always return an Iterator of some kind, we could make the type hint more precise. Is that correct, or we cannot guarantee that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was my initial try but I was struggling with it. You cannot create a Protocol with Generic at the same time and you have to provide a type var as an input if you would like to have something as an output. You can give it a try if you want.



def paginate_over(
getter: Paginatable,
extract_entries: Callable[[T], Iterable[Entry]],
**kwargs: Any,
) -> Iterator[Entry]:
"""
Generic approach to pagination via `NextPage`
"""
data = getter(**kwargs, next_page=None)
yield from extract_entries(data)

while data.next_page is not None and data.next_page.next_page_token is not None:
data = getter(**kwargs, next_page=data.next_page)
yield from extract_entries(data)
58 changes: 58 additions & 0 deletions src/neptune/internal/backends/hosted_neptune_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@
ImageSeriesValues,
IntField,
LeaderboardEntry,
NextPage,
QueryFieldDefinitionsResult,
QueryFieldsResult,
StringField,
StringSeriesField,
StringSeriesValues,
Expand Down Expand Up @@ -1025,6 +1028,31 @@ def get_float_series_values(
except HTTPNotFound:
raise FetchAttributeNotFoundException(path_to_str(path))

@with_api_exceptions_handler
def query_fields_within_project(
self,
project_id: QualifiedName,
field_names_filter: Optional[List[str]] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldsResult:
pagination = {"nextPage": next_page.to_dto()} if next_page else {}
params = {
"projectIdentifier": project_id,
"query": {
**pagination,
"attributeNamesFilter": field_names_filter,
"experimentIdsFilter": experiment_ids_filter,
},
**DEFAULT_REQUEST_KWARGS,
}

try:
result = self.leaderboard_client.api.queryAttributesWithinProject(**params).response().result
return QueryFieldsResult.from_model(result)
except HTTPNotFound:
raise ProjectNotFound(project_id=project_id)

@with_api_exceptions_handler
def fetch_atom_attribute_values(
self, container_id: str, container_type: ContainerType, path: List[str]
Expand Down Expand Up @@ -1143,6 +1171,36 @@ def get_model_version_url(
base_url = self.get_display_address()
return f"{base_url}/{workspace}/{project_name}/m/{model_id}/v/{sys_id}"

def query_fields_definitions_within_project(
self,
project_id: QualifiedName,
field_name_regex: Optional[str] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldDefinitionsResult:
pagination = {"nextPage": next_page.to_dto()} if next_page else {}
params = {
"projectIdentifier": project_id,
"query": {
**pagination,
"experimentIdsFilter": experiment_ids_filter,
"attributeNameRegex": field_name_regex,
},
}

try:
data = (
self.leaderboard_client.api.queryAttributeDefinitionsWithinProject(
**params,
**DEFAULT_REQUEST_KWARGS,
)
.response()
.result
)
return QueryFieldDefinitionsResult.from_model(data)
except HTTPNotFound:
raise ProjectNotFound(project_id=project_id)

def get_fields_definitions(
self,
container_id: str,
Expand Down
21 changes: 21 additions & 0 deletions src/neptune/internal/backends/neptune_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
ImageSeriesValues,
IntField,
LeaderboardEntry,
NextPage,
QueryFieldDefinitionsResult,
QueryFieldsResult,
StringField,
StringSeriesField,
StringSeriesValues,
Expand Down Expand Up @@ -334,3 +337,21 @@ def search_leaderboard_entries(
@abc.abstractmethod
def list_fileset_files(self, attribute: List[str], container_id: str, path: str) -> List[FileEntry]:
pass

@abc.abstractmethod
def query_fields_definitions_within_project(
self,
project_id: QualifiedName,
field_name_regex: Optional[str] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldDefinitionsResult: ...

@abc.abstractmethod
def query_fields_within_project(
self,
project_id: QualifiedName,
field_names_filter: Optional[List[str]] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldsResult: ...
27 changes: 27 additions & 0 deletions src/neptune/internal/backends/neptune_backend_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@
ImageSeriesValues,
IntField,
LeaderboardEntry,
NextPage,
QueryFieldDefinitionsResult,
QueryFieldsResult,
StringField,
StringPointValue,
StringSeriesField,
Expand Down Expand Up @@ -803,3 +806,27 @@ def get_fields_with_paths_filter(
self, container_id: str, container_type: ContainerType, paths: List[str], use_proto: Optional[bool] = None
) -> List[Field]:
return []

def query_fields_definitions_within_project(
self,
project_id: QualifiedName,
field_name_regex: Optional[str] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldDefinitionsResult:
return QueryFieldDefinitionsResult(
entries=[],
next_page=NextPage(next_page_token=None, limit=0),
)

def query_fields_within_project(
self,
project_id: QualifiedName,
field_names_filter: Optional[List[str]] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldsResult:
return QueryFieldsResult(
entries=[],
next_page=NextPage(next_page_token=None, limit=0),
)
26 changes: 25 additions & 1 deletion src/neptune/internal/backends/offline_neptune_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
ImageSeriesValues,
IntField,
LeaderboardEntry,
NextPage,
QueryFieldDefinitionsResult,
QueryFieldsResult,
StringField,
StringSeriesField,
StringSeriesValues,
Expand All @@ -46,7 +49,10 @@
from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock
from neptune.internal.backends.nql import NQLQuery
from neptune.internal.container_type import ContainerType
from neptune.internal.id_formats import UniqueId
from neptune.internal.id_formats import (
QualifiedName,
UniqueId,
)
from neptune.typing import ProgressBarType


Expand Down Expand Up @@ -170,3 +176,21 @@ def search_leaderboard_entries(
use_proto: Optional[bool] = None,
) -> Generator[LeaderboardEntry, None, None]:
raise NeptuneOfflineModeFetchException

def query_fields_definitions_within_project(
self,
project_id: QualifiedName,
field_name_regex: Optional[str] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldDefinitionsResult:
raise NeptuneOfflineModeFetchException

def query_fields_within_project(
self,
project_id: QualifiedName,
field_names_filter: Optional[List[str]] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldsResult:
raise NeptuneOfflineModeFetchException
Loading
Loading