Skip to content
This repository was archived by the owner on Nov 19, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cads_api_client/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import functools
import warnings
from typing import Any, Literal
from typing import Any, Callable, Literal

import attrs
import multiurl.base
Expand Down Expand Up @@ -51,6 +51,7 @@ class ApiClient:
retry_after: float = 120
maximum_tries: int = 500
session: requests.Session = attrs.field(factory=requests.Session)
_log_callback: Callable[..., None] | None = None

def __attrs_post_init__(self) -> None:
if self.url is None:
Expand Down Expand Up @@ -108,6 +109,7 @@ def _get_request_kwargs(
download_options=self._download_options,
sleep_max=self.sleep_max,
cleanup=self.cleanup,
log_callback=self._log_callback,
)

@functools.cached_property
Expand Down
4 changes: 3 additions & 1 deletion cads_api_client/catalogue.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import datetime
from typing import Any
from typing import Any, Callable

import attrs
import requests
Expand Down Expand Up @@ -112,6 +112,7 @@ class Catalogue:
download_options: dict[str, Any]
sleep_max: float
cleanup: bool
log_callback: Callable[..., None] | None
force_exact_url: bool = False

def __attrs_post_init__(self) -> None:
Expand All @@ -128,6 +129,7 @@ def _request_kwargs(self) -> RequestKwargs:
download_options=self.download_options,
sleep_max=self.sleep_max,
cleanup=self.cleanup,
log_callback=self.log_callback,
)

def get_collections(self, **params: Any) -> Collections:
Expand Down
35 changes: 10 additions & 25 deletions cads_api_client/legacy_api_client.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
from __future__ import annotations

import collections
import functools
import logging
import typing
import warnings
from types import TracebackType
from typing import Any, Callable, TypeVar, cast, overload
from typing import Any, Callable, TypeVar, overload

import cdsapi.api
import multiurl
import requests

from . import __version__ as cads_api_client_version
from . import processing
from .api_client import ApiClient
from .processing import Remote, Results

Expand Down Expand Up @@ -103,17 +101,18 @@ def __init__(
self.debug_callback = debug_callback
self.session = requests.Session() if session is None else session

self.client = self.logging_decorator(ApiClient)(
self.client = ApiClient(
url=self.url,
key=self.key,
verify=self.verify,
sleep_max=self.sleep_max,
session=self.session,
cleanup=self.delete,
maximum_tries=self.retry_max,
retry_after=self.sleep_max,
timeout=self.timeout,
progress=self.progress,
cleanup=self.delete,
sleep_max=self.sleep_max,
retry_after=self.sleep_max,
maximum_tries=self.retry_max,
session=self.session,
log_callback=self.log,
)
self.debug(
"CDSAPI %s",
Expand All @@ -137,16 +136,6 @@ def raise_not_implemented_error(self) -> None:
"This is a beta version. This functionality has not been implemented yet."
)

def logging_decorator(self, func: F) -> F:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
with LoggingContext(
logger=processing.LOGGER, quiet=self.quiet, debug=self._debug
):
return func(*args, **kwargs)

return cast(F, wrapper)

@overload
def retrieve(self, name: str, request: dict[str, Any], target: str) -> str: ...

Expand All @@ -160,20 +149,16 @@ def retrieve(
) -> str | Remote | Results:
submitted: Remote | Results
if self.wait_until_complete:
submitted = self.logging_decorator(self.client.submit_and_wait_on_results)(
submitted = self.client.submit_and_wait_on_results(
collection_id=name,
**request,
)
else:
submitted = self.logging_decorator(self.client.submit)(
submitted = self.client.submit(
collection_id=name,
**request,
)

# Decorate legacy methods
submitted.download = self.logging_decorator(submitted.download) # type: ignore[method-assign]
submitted.log = self.log # type: ignore[method-assign]

return submitted if target is None else submitted.download(target)

def log(self, level: int, *args: Any, **kwargs: Any) -> None:
Expand Down
34 changes: 27 additions & 7 deletions cads_api_client/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import time
import urllib.parse
import warnings
from typing import Any, Type, TypedDict, TypeVar
from typing import Any, Callable, Type, TypedDict, TypeVar

try:
from typing import Self
Expand Down Expand Up @@ -46,6 +46,7 @@ class RequestKwargs(TypedDict):
download_options: dict[str, Any]
sleep_max: float
cleanup: bool
log_callback: Callable[..., None] | None


class ProcessingFailedError(RuntimeError):
Expand Down Expand Up @@ -96,6 +97,13 @@ def get_level_and_message(message: str) -> tuple[int, str]:
return level, message


def log(*args: Any, callback: Callable[..., None] | None = None, **kwargs: Any) -> None:
if callback is None:
LOGGER.log(*args, **kwargs)
else:
callback(*args, **kwargs)


@attrs.define(slots=False)
class ApiResponse:
response: requests.Response
Expand All @@ -106,6 +114,7 @@ class ApiResponse:
download_options: dict[str, Any]
sleep_max: float
cleanup: bool
log_callback: Callable[..., None] | None

@property
def _request_kwargs(self) -> RequestKwargs:
Expand All @@ -117,6 +126,7 @@ def _request_kwargs(self) -> RequestKwargs:
download_options=self.download_options,
sleep_max=self.sleep_max,
cleanup=self.cleanup,
log_callback=self.log_callback,
)

@classmethod
Expand All @@ -131,6 +141,7 @@ def from_request(
download_options: dict[str, Any],
sleep_max: float,
cleanup: bool,
log_callback: Callable[..., None] | None,
log_messages: bool = True,
**kwargs: Any,
) -> T_ApiResponse:
Expand All @@ -139,11 +150,15 @@ def from_request(
robust_request = multiurl.robust(session.request, **retry_options)

inputs = kwargs.get("json", {}).get("inputs", {})
LOGGER.debug(f"{method.upper()} {url} {inputs or ''}".strip())
log(
logging.DEBUG,
f"{method.upper()} {url} {inputs or ''}".strip(),
callback=log_callback,
)
response = robust_request(
method, url, headers=headers, **request_options, **kwargs
)
LOGGER.debug(f"REPLY {response.text}")
log(logging.DEBUG, f"REPLY {response.text}", callback=log_callback)

cads_raise_for_status(response)

Expand All @@ -156,6 +171,7 @@ def from_request(
download_options=download_options,
sleep_max=sleep_max,
cleanup=cleanup,
log_callback=log_callback,
)
if log_messages:
self.log_messages()
Expand Down Expand Up @@ -223,8 +239,8 @@ def _from_rel_href(self, rel: str) -> Self | None:
out = None
return out

def log(self, level: int, *args: Any, **kwargs: Any) -> None:
LOGGER.log(level, *args, **kwargs)
def log(self, *args: Any, **kwargs: Any) -> None:
log(*args, callback=self.log_callback, **kwargs)

def info(self, *args: Any, **kwargs: Any) -> None:
self.log(logging.INFO, *args, **kwargs)
Expand Down Expand Up @@ -367,6 +383,7 @@ class Remote:
download_options: dict[str, Any]
sleep_max: float
cleanup: bool
log_callback: Callable[..., None] | None

def __attrs_post_init__(self) -> None:
self.log_start_time = None
Expand All @@ -382,6 +399,7 @@ def _request_kwargs(self) -> RequestKwargs:
download_options=self.download_options,
sleep_max=self.sleep_max,
cleanup=self.cleanup,
log_callback=self.log_callback,
)

def _log_metadata(self, metadata: dict[str, Any]) -> None:
Expand Down Expand Up @@ -579,8 +597,8 @@ def reply(self) -> dict[str, Any]:
reply.setdefault("request_id", self.request_uid)
return reply

def log(self, level: int, *args: Any, **kwargs: Any) -> None:
LOGGER.log(level, *args, **kwargs)
def log(self, *args: Any, **kwargs: Any) -> None:
log(*args, callback=self.log_callback, **kwargs)

def info(self, *args: Any, **kwargs: Any) -> None:
self.log(logging.INFO, *args, **kwargs)
Expand Down Expand Up @@ -722,6 +740,7 @@ class Processing:
download_options: dict[str, Any]
sleep_max: float
cleanup: bool
log_callback: Callable[..., None] | None
force_exact_url: bool = False

def __attrs_post_init__(self) -> None:
Expand All @@ -738,6 +757,7 @@ def _request_kwargs(self) -> RequestKwargs:
download_options=self.download_options,
sleep_max=self.sleep_max,
cleanup=self.cleanup,
log_callback=self.log_callback,
)

def get_processes(self, **params: Any) -> Processes:
Expand Down
4 changes: 3 additions & 1 deletion cads_api_client/profile.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any
from typing import Any, Callable

import attrs
import requests
Expand All @@ -18,6 +18,7 @@ class Profile:
download_options: dict[str, Any]
sleep_max: float
cleanup: bool
log_callback: Callable[..., None] | None
force_exact_url: bool = False

def __attrs_post_init__(self) -> None:
Expand All @@ -34,6 +35,7 @@ def _request_kwargs(self) -> processing.RequestKwargs:
download_options=self.download_options,
sleep_max=self.sleep_max,
cleanup=self.cleanup,
log_callback=self.log_callback,
)

def _get_api_response(
Expand Down
34 changes: 25 additions & 9 deletions tests/integration_test_70_legacy_api_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
import logging
import os
import pathlib
import time
Expand Down Expand Up @@ -200,16 +201,31 @@ def test_legacy_api_client_kwargs(api_root_url: str, api_anon_key: str) -> None:


def test_legacy_api_client_logging(
caplog: pytest.LogCaptureFixture, legacy_client: LegacyApiClient
caplog: pytest.LogCaptureFixture,
api_root_url: str,
api_anon_key: str,
) -> None:
legacy_client.info("Info message")
legacy_client.warning("Warning message")
legacy_client.error("Error message")
assert caplog.record_tuples == [
("cads_api_client.legacy_api_client", 20, "Info message"),
("cads_api_client.legacy_api_client", 30, "Warning message"),
("cads_api_client.legacy_api_client", 40, "Error message"),
]
logger = logging.getLogger("foo")
client = LegacyApiClient(
url=api_root_url,
key=api_anon_key,
info_callback=logger.info,
warning_callback=logger.warning,
error_callback=logger.error,
debug_callback=logger.debug,
)
caplog.clear()
with caplog.at_level(logging.DEBUG):
client.debug("Debug message")
client.info("Info message")
client.warning("Warning message")
client.error("Error message")
assert caplog.record_tuples == [
("foo", 10, "Debug message"),
("foo", 20, "Info message"),
("foo", 30, "Warning message"),
("foo", 40, "Error message"),
]


def test_legacy_api_client_download(
Expand Down
1 change: 1 addition & 0 deletions tests/test_10_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def cat() -> catalogue.Catalogue:
download_options={},
sleep_max=120,
cleanup=False,
log_callback=None,
)


Expand Down