Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce querying capabilities to fetch_runs_table #1660

Merged
merged 16 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
### Features
- Added `get_workspace_status()` method to management API ([#1662](https://github.com/neptune-ai/neptune-client/pull/1662))
- Added auto-scaling pixel values for image logging ([#1664](https://github.com/neptune-ai/neptune-client/pull/1664))
- Introduce querying capabilities to `fetch_runs_table()` ([#1660](https://github.com/neptune-ai/neptune-client/pull/1660))

### Fixes
- Restored support for SSL verification exception ([#1661](https://github.com/neptune-ai/neptune-client/pull/1661))
Expand Down
18 changes: 13 additions & 5 deletions src/neptune/api/searching_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@

from bravado.client import construct_request # type: ignore
from bravado.config import RequestConfig # type: ignore
from bravado.exception import HTTPBadRequest # type: ignore
from typing_extensions import (
Literal,
TypeAlias,
)

from neptune.exceptions import NeptuneInvalidQueryException
from neptune.internal.backends.api_model import (
AttributeType,
AttributeWithProperties,
Expand Down Expand Up @@ -142,11 +144,17 @@ def get_single_page(

http_client = client.swagger_spec.http_client

return (
http_client.request(request_params, operation=None, request_config=request_config)
.response()
.incoming_response.json()
)
try:
return (
http_client.request(request_params, operation=None, request_config=request_config)
.response()
.incoming_response.json()
)
except HTTPBadRequest as e:
title = e.response.json().get("title")
if title == "Syntax error":
raise NeptuneInvalidQueryException(nql_query=str(normalized_query))
raise e


def to_leaderboard_entry(entry: Dict[str, Any]) -> LeaderboardEntry:
Expand Down
10 changes: 10 additions & 0 deletions src/neptune/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
"StreamAlreadyUsedException",
"NeptuneUserApiInputException",
"NeptuneMaxDiskUtilizationExceeded",
"NeptuneInvalidQueryException",
]

from typing import (
Expand Down Expand Up @@ -1223,3 +1224,12 @@ def __init__(self, disk_utilization: float, utilization_limit: float):
super().__init__(
message.format(disk_utilization=disk_utilization, utilization_limit=utilization_limit, **STYLES)
)


class NeptuneInvalidQueryException(NeptuneException):
def __init__(self, nql_query: str):
message = f"""
The provided NQL query is invalid: {nql_query}.
For syntax help, see https://docs.neptune.ai/usage/nql/
"""
super().__init__(message)
38 changes: 35 additions & 3 deletions src/neptune/internal/backends/nql.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

__all__ = [
"NQLQuery",
"NQLEmptyQuery",
Expand All @@ -21,6 +24,7 @@
"NQLAttributeOperator",
"NQLAttributeType",
"NQLQueryAttribute",
"RawNQLQuery",
]

import typing
Expand All @@ -31,9 +35,11 @@

@dataclass
class NQLQuery:
pass
def eval(self) -> NQLQuery:
return self


@dataclass
class NQLEmptyQuery(NQLQuery):
def __str__(self) -> str:
return ""
Expand All @@ -49,10 +55,20 @@ class NQLQueryAggregate(NQLQuery):
items: Iterable[NQLQuery]
aggregator: NQLAggregator

def eval(self) -> NQLQuery:
self.items = list(filter(lambda nql: not isinstance(nql, NQLEmptyQuery), (item.eval() for item in self.items)))

if len(self.items) == 0:
return NQLEmptyQuery()
elif len(self.items) == 1:
return self.items[0]
return self

def __str__(self) -> str:
if self.items:
evaluated = self.eval()
if isinstance(evaluated, NQLQueryAggregate):
return "(" + f" {self.aggregator.value} ".join(map(str, self.items)) + ")"
return ""
return str(evaluated)


class NQLAttributeOperator(str, Enum):
Expand Down Expand Up @@ -85,3 +101,19 @@ def __str__(self) -> str:
value = f'"{self.value}"'

return f"(`{self.name}`:{self.type.value} {self.operator.value} {value})"


@dataclass
class RawNQLQuery(NQLQuery):
query: str

def eval(self) -> NQLQuery:
Raalsky marked this conversation as resolved.
Show resolved Hide resolved
if self.query == "":
return NQLEmptyQuery()
return self

def __str__(self) -> str:
evaluated = self.eval()
if isinstance(evaluated, RawNQLQuery):
return self.query
return str(evaluated)
33 changes: 28 additions & 5 deletions src/neptune/metadata_containers/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@
)
from neptune.metadata_containers import MetadataContainer
from neptune.metadata_containers.abstract import NeptuneObjectCallback
from neptune.metadata_containers.utils import prepare_nql_query
from neptune.metadata_containers.utils import (
build_raw_query,
deprecated_func_arg_warning_check,
prepare_nql_query,
)
from neptune.table import Table
from neptune.types.mode import Mode
from neptune.typing import (
Expand Down Expand Up @@ -196,6 +200,7 @@ def get_url(self) -> str:
def fetch_runs_table(
self,
*,
query: Optional[str] = None,
id: Optional[Union[str, Iterable[str]]] = None,
state: Optional[Union[Literal["inactive", "active"], Iterable[Literal["inactive", "active"]]]] = None,
owner: Optional[Union[str, Iterable[str]]] = None,
Expand All @@ -213,6 +218,9 @@ def fetch_runs_table(
Only runs matching all of the criteria will be returned.

Args:
query: NQL query string. Syntax: https://docs.neptune.ai/usage/nql/
Example: `"(accuracy: float > 0.88) AND (loss: float < 0.2)"`.
Exclusive with the `id`, `state`, `owner`, and `tag` parameters.
id: Neptune ID of a run, or list of several IDs.
Example: `"SAN-1"` or `["SAN-1", "SAN-2"]`.
Matching any element of the list is sufficient to pass the criterion.
Expand Down Expand Up @@ -288,25 +296,40 @@ def fetch_runs_table(
See also the API reference in the docs:
https://docs.neptune.ai/api/project#fetch_runs_table
"""

deprecated_func_arg_warning_check("fetch_runs_table", "id", id)
deprecated_func_arg_warning_check("fetch_runs_table", "state", state)
deprecated_func_arg_warning_check("fetch_runs_table", "owner", owner)
deprecated_func_arg_warning_check("fetch_runs_table", "tag", tag)

if any((id, state, owner, tag)) and query is not None:
raise ValueError(
"You can't use the 'query' parameter together with the 'id', 'state', 'owner', or 'tag' parameters."
)

ids = as_list("id", id)
states = as_list("state", state)
owners = as_list("owner", owner)
tags = as_list("tag", tag)

verify_type("query", query, (str, type(None)))
verify_type("trashed", trashed, (bool, type(None)))
verify_type("limit", limit, (int, type(None)))
verify_type("sort_by", sort_by, str)
verify_type("ascending", ascending, bool)
verify_type("progress_bar", progress_bar, (type(None), bool, type(ProgressBarCallback)))
verify_collection_type("state", states, str)

for state in states:
verify_value("state", state.lower(), ("inactive", "active"))

if isinstance(limit, int) and limit <= 0:
raise ValueError(f"Parameter 'limit' must be a positive integer or None. Got {limit}.")

nql_query = prepare_nql_query(ids, states, owners, tags, trashed)
for state in states:
verify_value("state", state.lower(), ("inactive", "active"))

if query is not None:
nql_query = build_raw_query(query, trashed=trashed)
else:
nql_query = prepare_nql_query(ids, states, owners, tags, trashed)

return MetadataContainer._fetch_entries(
self,
Expand Down
28 changes: 28 additions & 0 deletions src/neptune/metadata_containers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from datetime import datetime
from typing import (
Any,
Generator,
Iterable,
List,
Expand All @@ -30,6 +31,7 @@
)

from neptune.common.warnings import (
NeptuneDeprecationWarning,
NeptuneWarning,
warn_once,
)
Expand All @@ -42,8 +44,10 @@
NQLAggregator,
NQLAttributeOperator,
NQLAttributeType,
NQLQuery,
NQLQueryAggregate,
NQLQueryAttribute,
RawNQLQuery,
)
from neptune.internal.utils.run_state import RunState

Expand Down Expand Up @@ -169,3 +173,27 @@ def _parse_entry(entry: LeaderboardEntry) -> LeaderboardEntry:
exception=NeptuneWarning,
)
return entry


def deprecated_func_arg_warning_check(fname: str, vname: str, var: Any) -> None:
if var is not None:
msg = f"""The argument '{vname}' of the function '{fname}' is deprecated and will be removed in the future."""
warn_once(msg, exception=NeptuneDeprecationWarning)


def build_raw_query(query: str, trashed: Optional[bool]) -> NQLQuery:
raw_nql = RawNQLQuery(query)

if trashed is None:
return raw_nql

nql = NQLQueryAggregate(
items=[
raw_nql,
NQLQueryAttribute(
name="sys/trashed", type=NQLAttributeType.BOOLEAN, operator=NQLAttributeOperator.EQUALS, value=trashed
),
],
aggregator=NQLAggregator.AND,
)
return nql