Skip to content

Commit

Permalink
fix: Improve error message when RPC response is not valid JSON
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarrmondragon committed Mar 3, 2023
1 parent 1649185 commit 6cc1d11
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 30 deletions.
5 changes: 5 additions & 0 deletions .changes/unreleased/Fixed-20230302-192233.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
kind: Fixed
body: Improve error message when RPC response is not valid JSON
time: 2023-03-02T19:22:33.750455-06:00
custom:
Issue: "732"
12 changes: 12 additions & 0 deletions src/citric/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,15 @@ class RPCInterfaceNotEnabledError(LimeSurveyError):
def __init__(self) -> None:
"""Create a new exception."""
super().__init__("RPC interface not enabled")


class InvalidJSONResponseError(LimeSurveyError):
"""RPC interface maybe not enabled on LimeSurvey."""

def __init__(self) -> None:
"""Create a new exception."""
msg = (
"Received a non-JSON response, verify that the JSON RPC interface is "
"enabled in global settings"
)
super().__init__(msg)
38 changes: 29 additions & 9 deletions src/citric/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

from __future__ import annotations

import json
import logging
import random
from typing import TYPE_CHECKING, TypeVar

import requests

from citric.exceptions import (
InvalidJSONResponseError,
LimeSurveyApiError,
LimeSurveyStatusError,
ResponseMismatchError,
Expand All @@ -20,11 +22,32 @@
from types import TracebackType
from typing import Any

__all__ = ["Session"]

GET_SESSION_KEY = "get_session_key"
_T = TypeVar("_T", bound="Session")
logger = logging.getLogger(__name__)


def handle_rpc_errors(result: dict[str, Any], error: str | None) -> None:
"""Handle RPC errors.
Args:
result: The result of the RPC call.
error: The error message of the RPC call.
Raises:
LimeSurveyStatusError: The response key from the response payload has
a non-null status.
LimeSurveyApiError: The response payload has a non-null error key.
"""
if isinstance(result, dict) and result.get("status") not in {"OK", None}:
raise LimeSurveyStatusError(result["status"])

if error is not None:
raise LimeSurveyApiError(error)


class Session:
"""LimeSurvey RemoteControl 2 session.
Expand Down Expand Up @@ -127,12 +150,10 @@ def _invoke(
params (Any): Positional arguments of the RPC method.
Raises:
LimeSurveyStatusError: The response key from the response payload has
a non-null status.
LimeSurveyApiError: The response payload has a non-null error key.
ResponseMismatchError: Request ID does not match the response ID.
RPCInterfaceNotEnabledError: If the JSON RPC interface is not enabled
(empty response).
InvalidJSONResponseError: If the response is not valid JSON.
Returns:
Any: An RPC result.
Expand All @@ -151,18 +172,17 @@ def _invoke(
if res.text == "":
raise RPCInterfaceNotEnabledError

data = res.json()
try:
data = res.json()
except json.JSONDecodeError as e:
raise InvalidJSONResponseError from e

result = data["result"]
error = data["error"]
response_id = data["id"]
logger.info("Invoked RPC method %s with ID %d", method, request_id)

if isinstance(result, dict) and result.get("status") not in {"OK", None}:
raise LimeSurveyStatusError(result["status"])

if error is not None:
raise LimeSurveyApiError(error)
handle_rpc_errors(result, error)

if response_id != request_id:
msg = f"Response ID {response_id} does not match request ID {request_id}"
Expand Down
55 changes: 34 additions & 21 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,30 +30,15 @@ class LimeSurveyMockAdapter(BaseAdapter):

ldap_session_key = "ldap-key"

def send( # noqa: PLR0913
def _handle_json_response(
self,
request: requests.PreparedRequest,
stream: bool = False, # noqa: FBT001, FBT002
timeout: float | tuple[float, float] | tuple[float, None] | None = None,
verify: bool | str = True, # noqa: FBT002
cert: Any | None = None,
proxies: Mapping[str, str] | None = None,
):
"""Send a mocked request."""
request_data = json.loads(request.body or "{}")

method: str,
params: list[Any],
request_id: int,
) -> requests.Response:
response = requests.Response()
response.__setattr__("_content", b"")
response.status_code = 200

method = request_data["method"]
params = request_data["params"]
request_id = request_data.get("id", 1)

output = {"result": None, "error": None, "id": request_id}

if method == "__disabled":
return response
output: dict[str, Any] = {"result": None, "error": None, "id": request_id}

if method in self.api_error_methods:
output["error"] = "API Error!"
Expand All @@ -78,6 +63,34 @@ def send( # noqa: PLR0913

return response

def send( # noqa: PLR0913
self,
request: requests.PreparedRequest,
stream: bool = False, # noqa: FBT001, FBT002
timeout: float | tuple[float, float] | tuple[float, None] | None = None,
verify: bool | str = True, # noqa: FBT002
cert: Any | None = None,
proxies: Mapping[str, str] | None = None,
):
"""Send a mocked request."""
request_data = json.loads(request.body or "{}")
method = request_data["method"]
params = request_data["params"]
request_id = request_data.get("id", 1)

if method == "__disabled":
response = requests.Response()
response.status_code = 200
return response

if method == "__not_json":
response = requests.Response()
response.status_code = 200
response.__setattr__("_content", b"this is not json")
return response

return self._handle_json_response(method, params, request_id)

def close(self):
"""Clean up adapter specific items."""

Expand Down
7 changes: 7 additions & 0 deletions tests/test_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import requests

from citric.exceptions import (
InvalidJSONResponseError,
LimeSurveyApiError,
LimeSurveyStatusError,
ResponseMismatchError,
Expand Down Expand Up @@ -117,6 +118,12 @@ def test_empty_response(session: Session):
session.__disabled()


def test_non_json_response(session: Session):
"""Test non-JSON response."""
with pytest.raises(InvalidJSONResponseError, match="Received a non-JSON response"):
session.__not_json()


def test_api_error(session: Session):
"""Test non-null error raises LimeSurveyApiError."""
with pytest.raises(LimeSurveyApiError, match="API Error!"):
Expand Down

0 comments on commit 6cc1d11

Please sign in to comment.