diff --git a/jupyter_server/services/events/handlers.py b/jupyter_server/services/events/handlers.py index f861f4227d..940310327c 100644 --- a/jupyter_server/services/events/handlers.py +++ b/jupyter_server/services/events/handlers.py @@ -3,13 +3,18 @@ .. versionadded:: 2.0 """ import logging +from datetime import datetime +from typing import Any, Dict, Optional from jupyter_telemetry.eventlog import _skip_message from pythonjsonlogger import jsonlogger from tornado import web, websocket +from jupyter_server.auth import authorized from jupyter_server.base.handlers import JupyterHandler +from ...base.handlers import APIHandler + AUTH_RESOURCE = "events" @@ -53,11 +58,6 @@ async def get(self, *args, **kwargs): res = super().get(*args, **kwargs) await res - @property - def event_bus(self): - """Jupyter Server's event bus that emits structured event data.""" - return self.settings["event_bus"] - def open(self): """Routes events that are emitted by Jupyter Server's EventBus to a WebSocket client in the browser. @@ -75,6 +75,59 @@ def on_close(self): self.event_bus.handlers.remove(self.logging_handler) +def validate_model(data: Dict[str, Any]) -> None: + """Validates for required fields in the JSON request body""" + required_keys = {"schema_name", "version", "event"} + for key in required_keys: + if key not in data: + raise web.HTTPError(400, f"Missing `{key}` in the JSON request body.") + + +def get_timestamp(data: Dict[str, Any]) -> Optional[datetime]: + """Parses timestamp from the JSON request body""" + try: + if "timestamp" in data: + timestamp = datetime.strptime(data["timestamp"], "%Y-%m-%dT%H:%M:%S%zZ") + else: + timestamp = None + except Exception: + raise web.HTTPError( + 400, + """Failed to parse timestamp from JSON request body, + an ISO format datetime string with UTC offset is expected, + for example, 2022-05-26T13:50:00+05:00Z""", + ) + + return timestamp + + +class EventHandler(APIHandler): + """REST api handler for events""" + + auth_resource = AUTH_RESOURCE + + @web.authenticated + @authorized + async def post(self): + payload = self.get_json_body() + if payload is None: + raise web.HTTPError(400, "No JSON data provided") + + try: + validate_model(payload) + self.event_bus.record_event( + schema_name=payload.get("schema_name"), + version=payload.get("version"), + event=payload.get("event"), + timestamp_override=get_timestamp(payload), + ) + self.set_status(204) + self.finish() + except Exception as e: + raise web.HTTPError(500, str(e)) from e + + default_handlers = [ + (r"/api/events", EventHandler), (r"/api/events/subscribe", SubscribeWebsocket), ] diff --git a/tests/services/events/test_api.py b/tests/services/events/test_api.py index 4d14f7b78b..a82144ba45 100644 --- a/tests/services/events/test_api.py +++ b/tests/services/events/test_api.py @@ -1,22 +1,41 @@ +import io import json +import logging import pathlib import pytest +import tornado +from jupyter_telemetry.eventlog import _skip_message +from pythonjsonlogger import jsonlogger + +from tests.utils import expected_http_error @pytest.fixture -def event_bus(jp_serverapp): +def eventbus_sink(jp_serverapp): event_bus = jp_serverapp.event_bus # Register the event schema defined in this directory. schema_file = pathlib.Path(__file__).parent / "mock_event.yaml" event_bus.register_schema_file(schema_file) - # event_bus.allowed_schemas = ["event.mock.jupyter.org/message"] + + sink = io.StringIO() + formatter = jsonlogger.JsonFormatter(json_serializer=_skip_message) + handler = logging.StreamHandler(sink) + handler.setFormatter(formatter) + event_bus.handlers = [handler] + event_bus.log.addHandler(handler) + + return event_bus, sink + + +@pytest.fixture +def event_bus(eventbus_sink): + event_bus, sink = eventbus_sink return event_bus async def test_subscribe_websocket(jp_ws_fetch, event_bus): - # Open a websocket connection. ws = await jp_ws_fetch("/api/events/subscribe") event_bus.record_event( @@ -26,7 +45,119 @@ async def test_subscribe_websocket(jp_ws_fetch, event_bus): ) message = await ws.read_message() event_data = json.loads(message) - # Close websocket ws.close() assert event_data.get("event_message") == "Hello, world!" + + +payload_1 = """\ +{ + "schema_name": "event.mock.jupyter.org/message", + "version": 1, + "event": { + "event_message": "Hello, world!" + }, + "timestamp": "2022-05-26T12:50:00+06:00Z" +} +""" + +payload_2 = """\ +{ + "schema_name": "event.mock.jupyter.org/message", + "version": 1, + "event": { + "event_message": "Hello, world!" + } +} +""" + + +@pytest.mark.parametrize("payload", [payload_1, payload_2]) +async def test_post_event(jp_fetch, eventbus_sink, payload): + event_bus, sink = eventbus_sink + + r = await jp_fetch("api", "events", method="POST", body=payload) + assert r.code == 204 + + output = sink.getvalue() + assert output + input = json.loads(payload) + data = json.loads(output) + assert input["event"]["event_message"] == data["event_message"] + assert data["__timestamp__"] + if "timestamp" in input: + assert input["timestamp"] == data["__timestamp__"] + + +payload_3 = """\ +{ + "schema_name": "event.mock.jupyter.org/message", + "event": { + "event_message": "Hello, world!" + } +} +""" + +payload_4 = """\ +{ + "version": 1, + "event": { + "event_message": "Hello, world!" + } +} +""" + +payload_5 = """\ +{ + "schema_name": "event.mock.jupyter.org/message", + "version": 1 +} +""" + +payload_6 = """\ +{ + "schema_name": "event.mock.jupyter.org/message", + "version": 1, + "event": { + "event_message": "Hello, world!" + }, + "timestamp": "2022-05-26 12:50:00" +} +""" + + +@pytest.mark.parametrize("payload", [payload_3, payload_4, payload_5, payload_6]) +async def test_post_event_400(jp_fetch, event_bus, payload): + with pytest.raises(tornado.httpclient.HTTPClientError) as e: + await jp_fetch("api", "events", method="POST", body=payload) + + expected_http_error(e, 400) + + +payload_7 = """\ +{ + "schema_name": "event.mock.jupyter.org/message", + "version": 1, + "event": { + "message": "Hello, world!" + } +} +""" + +payload_8 = """\ +{ + "schema_name": "event.mock.jupyter.org/message", + "version": 2, + "event": { + "message": "Hello, world!" + } +} +""" + + +@pytest.mark.parametrize("payload", [payload_7, payload_8]) +async def test_post_event_500(jp_fetch, event_bus, payload): + with pytest.raises(tornado.httpclient.HTTPClientError) as e: + await jp_fetch("api", "events", method="POST", body=payload) + + expected_http_error(e, 500)