Skip to content
This repository has been archived by the owner on Oct 19, 2023. It is now read-only.

feat: metrics middleware improvements and add logging middleware #85

Merged
merged 11 commits into from
May 23, 2023
145 changes: 116 additions & 29 deletions lcserve/backend/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import shutil
import sys
import time
import uuid
from enum import Enum
from functools import cached_property
from importlib import import_module
Expand Down Expand Up @@ -290,6 +291,7 @@ def __init__(
self._register_healthz()
self._register_modules()
self._setup_metrics()
self._setup_logging()

@property
def app(self) -> 'FastAPI':
Expand Down Expand Up @@ -360,25 +362,26 @@ def _setup_metrics(self):
meter_provider=self.meter_provider,
)

self.http_duration_counter = self.meter.create_counter(
name="http_request_duration_seconds",
description="HTTP request duration in seconds",
self.duration_counter = self.meter.create_counter(
name="lcserve_request_duration_seconds",
description="Lc-serve Request duration in seconds",
unit="s",
)

self.ws_duration_counter = self.meter.create_counter(
name="ws_request_duration_seconds",
description="WS request duration in seconds",
unit="s",
self.request_counter = self.meter.create_counter(
name="lcserve_request_count",
description="Lc-serve Request count",
)

self.app.add_middleware(
MeasureDurationHTTPMiddleware, counter=self.http_duration_counter
)
self.app.add_middleware(
MeasureDurationWebSocketMiddleware, counter=self.ws_duration_counter
MetricsMiddleware,
duration_counter=self.duration_counter,
request_counter=self.request_counter,
)

def _setup_logging(self):
    """Register the request/connection logging middleware on the app.

    The middleware is handed this gateway's own logger so its log lines
    go to the same place as the rest of the gateway's output.
    """
    application = self.app
    application.add_middleware(LoggingMiddleware, logger=self.logger)

def _register_healthz(self):
@self.app.get("/healthz")
async def __healthz():
Expand Down Expand Up @@ -1034,27 +1037,34 @@ def __init__(self, interval: int):
self.interval = interval

async def send_duration_periodically(
self, shared_data: SharedData, route: str, counter: Optional['Counter'] = None
self,
shared_data: SharedData,
route: str,
protocol: str,
counter: Optional['Counter'] = None,
):
while True:
await asyncio.sleep(self.interval)
current_time = time.perf_counter()
duration = current_time - shared_data.last_reported_time
if counter:
counter.add(
current_time - shared_data.last_reported_time, {"route": route}
current_time - shared_data.last_reported_time,
{'route': route, 'protocol': protocol},
)

shared_data.last_reported_time = current_time


class BaseMeasureDurationMiddleware:
class MetricsMiddleware:
def __init__(
self, app: ASGIApp, scope_type: str, counter: Optional['Counter'] = None
self,
app: ASGIApp,
duration_counter: Optional['Counter'] = None,
request_counter: Optional['Counter'] = None,
):
self.app = app
self.scope_type = scope_type
self.counter = counter
self.duration_counter = duration_counter
self.request_counter = request_counter
# TODO: figure out solution for static assets
self.skip_routes = [
'/docs',
Expand All @@ -1067,32 +1077,109 @@ def __init__(
]

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == self.scope_type and scope['path'] not in self.skip_routes:
if scope['path'] not in self.skip_routes:
timer = Timer(5)
shared_data = timer.SharedData(last_reported_time=time.perf_counter())
send_duration_task = asyncio.create_task(
timer.send_duration_periodically(
shared_data, scope['path'], self.counter
shared_data, scope['path'], scope['type'], self.duration_counter
)
)
try:
await self.app(scope, receive, send)
finally:
send_duration_task.cancel()
if self.counter:
self.counter.add(
if self.duration_counter:
self.duration_counter.add(
time.perf_counter() - shared_data.last_reported_time,
{"route": scope['path']},
{'route': scope['path'], 'protocol': scope['type']},
)
if self.request_counter:
self.request_counter.add(
1, {'route': scope['path'], 'protocol': scope['type']}
)
else:
await self.app(scope, receive, send)


class MeasureDurationHTTPMiddleware(BaseMeasureDurationMiddleware):
def __init__(self, app: ASGIApp, counter: Optional['Counter'] = None):
super().__init__(app, "http", counter)
class LoggingMiddleware:
    """ASGI middleware that logs one line per HTTP request / WebSocket connection.

    HTTP: a fresh UUID is generated per request, returned to the client in
    the ``X-API-Request-ID`` response header, and logged together with the
    path, client IP, status code and duration once the request finishes.

    WebSocket: a fresh UUID is generated per connection, sent to the client
    as a ``connection_id:<uuid>`` text frame immediately after the handshake
    is accepted, and logged together with the path, client IP and duration
    once the connection ends.

    Scopes that are neither ``http`` nor ``websocket`` (e.g. ``lifespan``,
    which carries no ``path``) and paths listed in ``skip_routes`` are
    forwarded to the wrapped app untouched.

    Args:
        app: The wrapped ASGI application.
        logger: Logger whose ``info`` method receives the log lines.
    """

    def __init__(self, app: 'ASGIApp', logger: 'JinaLogger'):
        self.app = app
        self.logger = logger
        # Paths that are forwarded without any logging or ID injection.
        self.skip_routes = [
            '/docs',
            '/redoc',
            '/openapi.json',
            '/healthz',
            '/dry_run',
            '/metrics',
            '/favicon.ico',
        ]

    @staticmethod
    def _client_ip(scope: 'Scope'):
        """Return the X-Forwarded-For value if present, else the peer address."""
        for name, value in scope.get('headers') or []:
            if name.decode('latin-1') == 'x-forwarded-for':
                return value.decode('latin-1')
        client = scope.get('client')
        return client[0] if client else None

    async def __call__(
        self, scope: 'Scope', receive: 'Receive', send: 'Send'
    ) -> None:
        """Handle one ASGI scope, logging it unless it is skipped.

        Any exception raised by the wrapped app is propagated unchanged;
        the log line is still emitted first (with whatever status code was
        seen so far, possibly ``None``).
        """
        scope_type = scope['type']
        # Check the type first: only http/websocket scopes have a 'path',
        # so a lifespan scope must never reach the scope['path'] lookup.
        if (
            scope_type not in ('http', 'websocket')
            or scope['path'] in self.skip_routes
        ):
            await self.app(scope, receive, send)
            return

        ip_address = self._client_ip(scope)
        request_id = uuid.uuid4() if scope_type == 'http' else None
        connection_id = uuid.uuid4() if scope_type == 'websocket' else None
        status_code = None
        start_time = time.perf_counter()

        async def custom_send(message: dict) -> None:
            nonlocal status_code

            message_type = message.get('type')
            if request_id is not None and message_type == 'http.response.start':
                status_code = message.get('status')
                # Build a new message/headers list instead of appending in
                # place, so the application's own header list is not mutated
                # (a reused response object would otherwise accumulate IDs).
                message = {
                    **message,
                    'headers': [
                        *(message.get('headers') or []),
                        (b'X-API-Request-ID', str(request_id).encode()),
                    ],
                }

            await send(message)

            # The connection ID frame may only be sent after the handshake
            # has been accepted, so it directly follows the accept message.
            if connection_id is not None and message_type == 'websocket.accept':
                await send(
                    {
                        "type": "websocket.send",
                        "text": f"connection_id:{connection_id}",
                    }
                )

        try:
            await self.app(scope, receive, custom_send)
        finally:
            # Log in `finally` so failed requests/connections are recorded too.
            duration = round(time.perf_counter() - start_time, 2)
            if scope_type == 'http':
                self.logger.info(
                    f"HTTP response id: {request_id} - Path: {scope['path']} - Client IP: {ip_address} - Status code: {status_code} - Duration: {duration}"
                )
            else:
                self.logger.info(
                    f"WebSocket disconnection id: {connection_id} - Path: {scope['path']} - Client IP: {ip_address} - Duration: {duration}"
                )
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ jina-hubble-sdk
nest-asyncio
textual
toml
# The libraries below are pinned because of https://github.com/hwchase17/langchain/issues/5113;
# they can be unpinned once that issue is resolved.
typing-inspect==0.8.0
typing_extensions==4.5.0
22 changes: 20 additions & 2 deletions tests/integration/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def get_values_from_prom(metrics, route):
return duration_seconds


def examine_prom_with_retry(start_time, metrics, expected_value, route):
def examine_request_duration_with_retry(start_time, expected_value, route):
timeout = 120
interval = 10

Expand All @@ -170,8 +170,26 @@ def examine_prom_with_retry(start_time, metrics, expected_value, route):
if elapsed_time > timeout:
pytest.fail("Timed out waiting for the Prometheus data to be populated")

duration_seconds = get_values_from_prom(metrics, route)
duration_seconds = get_values_from_prom(
"lcserve_request_duration_seconds", route
)
if round(float(duration_seconds)) == expected_value:
break

time.sleep(interval)


def examine_request_count_with_retry(start_time, expected_value, route):
    """Poll Prometheus until the request count for ``route`` matches.

    The ``lcserve_request_count`` value for ``route`` is fetched every 10
    seconds. The function returns as soon as it equals ``expected_value``;
    if more than 120 seconds have passed since ``start_time`` (a
    ``time.time()`` timestamp), the test is failed via ``pytest.fail``.
    """
    deadline_seconds = 120
    poll_seconds = 10

    while time.time() - start_time <= deadline_seconds:
        observed = get_values_from_prom("lcserve_request_count", route)
        if float(observed) == expected_value:
            return
        time.sleep(poll_seconds)

    pytest.fail("Timed out waiting for the Prometheus data to be populated")
6 changes: 3 additions & 3 deletions tests/integration/jcloud/test_basic_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ async def _test_ws_route(app_id):
await websocket.send(json.dumps({"interval": 1}))

received_messages = []
for _ in range(5):
for _ in range(6):
message = await websocket.recv()
received_messages.append(message)

assert received_messages == ["0", "1", "2", "3", "4"]
assert received_messages[1:] == ["0", "1", "2", "3", "4"]
zac-li marked this conversation as resolved.
Show resolved Hide resolved


async def _test_workspace(app_id):
Expand All @@ -63,6 +63,6 @@ async def _test_workspace(app_id):
message = await websocket.recv()
received_messages.append(message.strip())

assert received_messages == [f"Here's string {i}" for i in range(10)]
assert received_messages[1:] == [f"Here's string {i}" for i in range(10)]
except ConnectionClosedOK:
pass
9 changes: 4 additions & 5 deletions tests/integration/jcloud/test_fastapi_app.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import json

import aiohttp
import pytest
import requests
import aiohttp

from ..helper import deploy_jcloud_fastapi_app


@pytest.mark.asyncio
async def test_basic_app():
async with deploy_jcloud_fastapi_app() as app_id:
Expand All @@ -18,9 +19,7 @@ def _test_http_route(app_id):
"accept": "application/json",
"Content-Type": "application/json",
}
response = requests.get(
f"https://{app_id}.wolf.jina.ai/status", headers=headers
)
response = requests.get(f"https://{app_id}.wolf.jina.ai/status", headers=headers)

response_data = response.json()

Expand All @@ -35,4 +34,4 @@ async def _test_ws_route(app_id):
received_messages = []
async for message in websocket:
received_messages.append(message.data)
assert received_messages == ["0", "1", "2", "3", "4"]
assert received_messages[1:] == ["0", "1", "2", "3", "4"]
Loading
Loading