Skip to content

Commit

Permalink
chore: log all requests, also during retry (DEV-3213) (#759)
Browse files Browse the repository at this point in the history
  • Loading branch information
jnussbaum committed Jan 24, 2024
1 parent 2a04f5a commit fd05080
Showing 1 changed file with 65 additions and 111 deletions.
176 changes: 65 additions & 111 deletions src/dsp_tools/utils/connection_live.py
Expand Up @@ -3,8 +3,9 @@
import time
from dataclasses import dataclass, field
from datetime import datetime
from functools import partial
from importlib.metadata import version
from typing import Any, Callable, Optional, cast
from typing import Any, Literal, Optional, cast

import regex
from requests import JSONDecodeError, ReadTimeout, RequestException, Response, Session
Expand All @@ -19,6 +20,35 @@
logger = get_logger(__name__)


@dataclass
class RequestParameters:
method: Literal["POST", "GET", "PUT", "DELETE"]
url: str
timeout: int
data: dict[str, Any] | None = None
data_serialized: bytes | None = field(init=False, default=None)
headers: dict[str, str] | None = None
files: dict[str, tuple[str, Any]] | None = None

def __post_init__(self) -> None:
self.data_serialized = self._serialize_payload(self.data)

def _serialize_payload(self, payload: dict[str, Any] | None) -> bytes | None:
# If data is not encoded as bytes, issues can occur with non-ASCII characters,
# where the content-length of the request will turn out to be different from the actual length.
return json.dumps(payload, cls=SetEncoder, ensure_ascii=False).encode("utf-8") if payload else None

def as_kwargs(self) -> dict[str, Any]:
return {
"method": self.method,
"url": self.url,
"timeout": self.timeout,
"data": self.data_serialized,
"headers": self.headers,
"files": self.files,
}


@dataclass
class ConnectionLive:
"""
Expand Down Expand Up @@ -56,6 +86,7 @@ def login(self, email: str, password: str) -> None:
response = self.post(
route="/v2/authentication",
data={"email": email, "password": password},
timeout=10,
)
except PermanentConnectionError as e:
raise UserError(err_msg) from e
Expand Down Expand Up @@ -95,7 +126,7 @@ def post(
timeout: int | None = None,
) -> dict[str, Any]:
"""
Make a HTTP POST request to the server to which this connection has been established.
Make an HTTP POST request to the server to which this connection has been established.
Args:
route: route that will be called on the server
Expand All @@ -110,32 +141,14 @@ def post(
Raises:
PermanentConnectionError: if the server returns a permanent error
"""
if not route.startswith("/"):
route = f"/{route}"
url = self.server + route
if data:
headers = headers or {}
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json; charset=UTF-8"
timeout = timeout or self.timeout_put_post

self._log_request(
method="POST",
url=url,
data=data,
uploaded_file=files["file"][0] if files else None,
headers=headers,
timeout=timeout,
)
response = self._try_network_action(
lambda: self.session.post(
url=url,
headers=headers,
timeout=timeout,
data=self._serialize_payload(data),
files=files,
)
params = RequestParameters(
"POST", self._make_url(route), timeout or self.timeout_put_post, data, headers, files
)
response = self._try_network_action(params)
return cast(dict[str, Any], response.json())

def get(
Expand All @@ -144,7 +157,7 @@ def get(
headers: dict[str, str] | None = None,
) -> dict[str, Any]:
"""
Make a HTTP GET request to the server to which this connection has been established.
Make an HTTP GET request to the server to which this connection has been established.
Args:
route: route that will be called on the server
Expand All @@ -156,25 +169,8 @@ def get(
Raises:
PermanentConnectionError: if the server returns a permanent error
"""
if not route.startswith("/"):
route = f"/{route}"
url = self.server + route
timeout = self.timeout_get_delete

self._log_request(
method="GET",
url=url,
data=None,
headers=headers,
timeout=timeout,
)
response = self._try_network_action(
lambda: self.session.get(
url=url,
headers=headers,
timeout=timeout,
)
)
params = RequestParameters("GET", self._make_url(route), self.timeout_get_delete, headers=headers)
response = self._try_network_action(params)
return cast(dict[str, Any], response.json())

def put(
Expand All @@ -184,7 +180,7 @@ def put(
headers: dict[str, str] | None = None,
) -> dict[str, Any]:
"""
Make a HTTP GET request to the server to which this connection has been established.
Make an HTTP GET request to the server to which this connection has been established.
Args:
route: route that will be called on the server
Expand All @@ -197,30 +193,12 @@ def put(
Raises:
PermanentConnectionError: if the server returns a permanent error
"""
if not route.startswith("/"):
route = f"/{route}"
url = self.server + route
if data:
headers = headers or {}
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json; charset=UTF-8"
timeout = self.timeout_put_post

self._log_request(
method="PUT",
url=url,
data=data,
headers=headers,
timeout=timeout,
)
response = self._try_network_action(
lambda: self.session.put(
url=url,
headers=headers,
data=self._serialize_payload(data),
timeout=timeout,
)
)
params = RequestParameters("PUT", self._make_url(route), self.timeout_put_post, data, headers)
response = self._try_network_action(params)
return cast(dict[str, Any], response.json())

def delete(
Expand All @@ -229,7 +207,7 @@ def delete(
headers: dict[str, str] | None = None,
) -> dict[str, Any]:
"""
Make a HTTP GET request to the server to which this connection has been established.
Make an HTTP GET request to the server to which this connection has been established.
Args:
route: route that will be called on the server
Expand All @@ -241,35 +219,25 @@ def delete(
Raises:
PermanentConnectionError: if the server returns a permanent error
"""
params = RequestParameters("DELETE", self._make_url(route), self.timeout_get_delete, headers=headers)
response = self._try_network_action(params)
return cast(dict[str, Any], response.json())

def _make_url(self, route: str) -> str:
if not route.startswith("/"):
route = f"/{route}"
url = self.server + route
timeout = self.timeout_get_delete

self._log_request(
method="DELETE",
url=url,
data=None,
headers=headers,
timeout=timeout,
)
response = self.session.delete(
url=url,
headers=headers,
timeout=timeout,
)
return cast(dict[str, Any], response.json())
return self.server + route

def _try_network_action(self, action: Callable[[], Response]) -> Response:
def _try_network_action(self, params: RequestParameters) -> Response:
"""
Try 7 times to execute a HTTP request.
Try 7 times to execute an HTTP request.
If a timeout error, a ConnectionError, or a requests.RequestException occur,
or if the response indicates that there is a non-permanent server-side problem,
this function waits and retries the HTTP request.
The waiting times are 1, 2, 4, 8, 16, 32, 64 seconds.
Args:
action: a lambda with the code to be executed, or a function
params: keyword arguments for the HTTP request
Raises:
PermanentConnectionError: if the server returns a permanent error
Expand All @@ -278,8 +246,10 @@ def _try_network_action(self, action: Callable[[], Response]) -> Response:
Returns:
the return value of action
"""
action = partial(self.session.request, **params.as_kwargs())
for i in range(7):
try:
self._log_request(params)
response = action()
except (TimeoutError, ReadTimeout, ReadTimeoutError):
self._log_and_sleep(reason="Timeout Error", retry_counter=i, exc_info=True)
Expand Down Expand Up @@ -318,10 +288,9 @@ def _log_response(self, response: Response) -> None:
content = self._anonymize(response.json())
except JSONDecodeError:
content = {"content": response.text}
response_headers = self._anonymize(dict(response.headers))
dumpobj = {
"status code": response.status_code,
"response headers": response_headers,
"status_code": response.status_code,
"headers": self._anonymize(dict(response.headers)),
"content": content,
}
logger.debug(f"RESPONSE: {json.dumps(dumpobj)}")
Expand Down Expand Up @@ -352,30 +321,15 @@ def _in_testing_environment(self) -> bool:
in_testing_env = os.getenv("DSP_TOOLS_TESTING") # set in .github/workflows/tests-on-push.yml
return in_testing_env == "true"

def _log_request(
self,
method: str,
url: str,
data: dict[str, Any] | None,
timeout: int,
headers: dict[str, str] | None = None,
uploaded_file: str | None = None,
) -> None:
headers = headers or {}
headers.update({k: str(v) for k, v in self.session.headers.items()})
headers = self._anonymize(headers)
data = self._anonymize(data)
def _log_request(self, params: RequestParameters) -> None:
dumpobj = {
"HTTP request": method,
"url": url,
"headers": headers,
"timetout": timeout,
"payload": data,
"uploaded file": uploaded_file,
"method": params.method,
"url": params.url,
"headers": self._anonymize(dict(self.session.headers) | (params.headers or {})),
"timeout": params.timeout,
}
if params.data:
dumpobj["data"] = self._anonymize(params.data)
if params.files:
dumpobj["files"] = params.files["file"][0]
logger.debug(f"REQUEST: {json.dumps(dumpobj, cls=SetEncoder)}")

def _serialize_payload(self, payload: dict[str, Any] | None) -> bytes | None:
# If data is not encoded as bytes, issues can occur with non-ASCII characters,
# where the content-length of the request will turn out to be different from the actual length.
return json.dumps(payload, cls=SetEncoder, ensure_ascii=False).encode("utf-8") if payload else None

0 comments on commit fd05080

Please sign in to comment.