Skip to content

Commit

Permalink
Use HTTP connection pools (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
goodoldneon committed Apr 24, 2024
1 parent 2480968 commit 17efc13
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 67 deletions.
63 changes: 36 additions & 27 deletions inngest/_internal/client_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ def __init__(
event_origin = const.DEFAULT_EVENT_ORIGIN
self._event_api_origin = event_origin

self._http_client = net.ThreadAwareAsyncHTTPClient().initialize()
self._http_client_sync = httpx.Client()

def _build_send_request(
self,
events: list[event_lib.Event],
Expand Down Expand Up @@ -188,7 +191,7 @@ def _build_send_request(
d["ts"] = int(time.time() * 1000)
body.append(d)

return httpx.Client().build_request(
return self._http_client_sync.build_request(
"POST",
url,
headers=headers,
Expand Down Expand Up @@ -302,7 +305,7 @@ async def _get(self, url: str) -> httpx.Response:
Perform an asynchronous HTTP GET request. Handles authn
"""

req = httpx.Client().build_request(
req = self._http_client_sync.build_request(
"GET",
url,
headers=net.create_headers(
Expand All @@ -312,20 +315,21 @@ async def _get(self, url: str) -> httpx.Response:
),
)

async with httpx.AsyncClient() as client:
return await net.fetch_with_auth_fallback(
client,
req,
signing_key=self._signing_key,
signing_key_fallback=self._signing_key_fallback,
)
return await net.fetch_with_auth_fallback(
self.logger,
self._http_client,
self._http_client_sync,
req,
signing_key=self._signing_key,
signing_key_fallback=self._signing_key_fallback,
)

def _get_sync(self, url: str) -> httpx.Response:
"""
Perform a synchronous HTTP GET request. Handles authn
"""

req = httpx.Client().build_request(
req = self._http_client_sync.build_request(
"GET",
url,
headers=net.create_headers(
Expand All @@ -335,13 +339,12 @@ def _get_sync(self, url: str) -> httpx.Response:
),
)

with httpx.Client() as client:
return net.fetch_with_auth_fallback_sync(
client,
req,
signing_key=self._signing_key,
signing_key_fallback=self._signing_key_fallback,
)
return net.fetch_with_auth_fallback_sync(
self._http_client_sync,
req,
signing_key=self._signing_key,
signing_key_fallback=self._signing_key_fallback,
)

async def _get_batch(self, run_id: str) -> list[event_lib.Event]:
"""
Expand Down Expand Up @@ -424,11 +427,18 @@ async def send(
if not isinstance(events, list):
events = [events]

async with httpx.AsyncClient() as client:
req = self._build_send_request(events)
if isinstance(req, Exception):
raise req
return _extract_ids((await client.send(req)).json())
req = self._build_send_request(events)
if isinstance(req, Exception):
raise req

res = await net.fetch_with_thready_safety(
self.logger,
self._http_client,
self._http_client_sync,
req,
)

return _extract_ids(res.json())

def send_sync(
self,
Expand All @@ -445,11 +455,10 @@ def send_sync(
if not isinstance(events, list):
events = [events]

with httpx.Client() as client:
req = self._build_send_request(events)
if isinstance(req, Exception):
raise req
return _extract_ids((client.send(req)).json())
req = self._build_send_request(events)
if isinstance(req, Exception):
raise req
return _extract_ids((self._http_client_sync.send(req)).json())

def set_logger(self, logger: types.Logger) -> None:
self.logger = logger
Expand Down
46 changes: 23 additions & 23 deletions inngest/_internal/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def _build_registration_request(
if sync_id is not None:
params[const.QueryParamKey.SYNC_ID.value] = sync_id

return httpx.Client().build_request(
return self._client._http_client_sync.build_request(
"POST",
registration_url,
headers=headers,
Expand Down Expand Up @@ -538,18 +538,19 @@ async def register(
if isinstance(req, Exception):
return CommResponse.from_error(self._client.logger, req)

async with httpx.AsyncClient() as client:
res = await net.fetch_with_auth_fallback(
client,
req,
signing_key=self._signing_key,
signing_key_fallback=self._signing_key_fallback,
)
res = await net.fetch_with_auth_fallback(
self._client.logger,
self._client._http_client,
self._client._http_client_sync,
req,
signing_key=self._signing_key,
signing_key_fallback=self._signing_key_fallback,
)

return self._parse_registration_response(
res,
server_kind,
)
return self._parse_registration_response(
res,
server_kind,
)

def register_sync(
self,
Expand All @@ -572,18 +573,17 @@ def register_sync(
if isinstance(req, Exception):
return CommResponse.from_error(self._client.logger, req)

with httpx.Client() as client:
res = net.fetch_with_auth_fallback_sync(
client,
req,
signing_key=self._signing_key,
signing_key_fallback=self._signing_key_fallback,
)
res = net.fetch_with_auth_fallback_sync(
self._client._http_client_sync,
req,
signing_key=self._signing_key,
signing_key_fallback=self._signing_key_fallback,
)

return self._parse_registration_response(
res,
server_kind,
)
return self._parse_registration_response(
res,
server_kind,
)

async def _respond(
self,
Expand Down
66 changes: 62 additions & 4 deletions inngest/_internal/net.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,37 @@
from __future__ import annotations

import hashlib
import hmac
import http
import os
import threading
import typing
import urllib.parse

import httpx

from . import const, errors, transforms, types

Method = typing.Literal["GET", "POST"]

class ThreadAwareAsyncHTTPClient(httpx.AsyncClient):
"""
Thin wrapper around httpx.AsyncClient. It keeps track of the thread it was
created in, which is critical since asyncio is not thread safe: calling an
async method in a different thread will raise an exception
"""

_creation_thread_id: typing.Optional[int] = None

def is_same_thread(self) -> bool:
if self._creation_thread_id is None:
raise Exception("did initialize ThreadAwareAsyncHTTPClient")

current_thread_id = threading.get_ident()
return self._creation_thread_id == current_thread_id

def initialize(self) -> ThreadAwareAsyncHTTPClient:
self._creation_thread_id = threading.get_ident()
return self


def create_headers(
Expand Down Expand Up @@ -84,7 +106,9 @@ def create_serve_url(


async def fetch_with_auth_fallback(
client: httpx.AsyncClient,
logger: types.Logger,
client: ThreadAwareAsyncHTTPClient,
client_sync: httpx.Client,
request: httpx.Request,
*,
signing_key: typing.Optional[str],
Expand All @@ -100,7 +124,12 @@ async def fetch_with_auth_fallback(
const.HeaderKey.AUTHORIZATION.value
] = f"Bearer {transforms.hash_signing_key(signing_key)}"

res = await client.send(request)
res = await fetch_with_thready_safety(
logger,
client,
client_sync,
request,
)
if (
res.status_code
in (http.HTTPStatus.FORBIDDEN, http.HTTPStatus.UNAUTHORIZED)
Expand All @@ -110,7 +139,13 @@ async def fetch_with_auth_fallback(
request.headers[
const.HeaderKey.AUTHORIZATION.value
] = f"Bearer {transforms.hash_signing_key(signing_key_fallback)}"
res = await client.send(request)

res = await fetch_with_thready_safety(
logger,
client,
client_sync,
request,
)

return res

Expand Down Expand Up @@ -173,6 +208,29 @@ def parse_url(url: str) -> str:
return parsed.geturl()


async def fetch_with_thready_safety(
logger: types.Logger,
client: ThreadAwareAsyncHTTPClient,
client_sync: httpx.Client,
request: httpx.Request,
) -> httpx.Response:
"""
Safely handles the situation where the async HTTP client is called in a
different thread.
"""

if client.is_same_thread() is True:
return await client.send(request)

# Python freaks out if you call httpx.AsyncClient's async methods in a
# multiple threads. To solve this, we'll use the synchronous client
# instead
logger.warning(
"called an async client method in a different thread; falling back to synchronous HTTP client"
)
return client_sync.send(request)


class RequestSignature:
_signature: typing.Optional[str] = None
_timestamp: typing.Optional[int] = None
Expand Down
25 changes: 21 additions & 4 deletions inngest/_internal/net_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def test_fails_for_both_signing_keys(self) -> None:
class Test_fetch_with_auth_fallback(unittest.IsolatedAsyncioTestCase):
def setUp(self) -> None:
super().setUp()
self._logger = unittest.mock.Mock()
self._req = httpx.Request("GET", "http://localhost")

def _create_async_transport(
Expand Down Expand Up @@ -241,7 +242,11 @@ def handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(200, content=b"", request=request)

res = await net.fetch_with_auth_fallback(
httpx.AsyncClient(transport=self._create_async_transport(handler)),
self._logger,
net.ThreadAwareAsyncHTTPClient(
transport=self._create_async_transport(handler)
).initialize(),
httpx.Client(transport=self._create_transport(handler)),
self._req,
signing_key=_signing_key,
signing_key_fallback=_signing_key_fallback,
Expand Down Expand Up @@ -283,7 +288,11 @@ def handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(200, content=b"", request=request)

res = await net.fetch_with_auth_fallback(
httpx.AsyncClient(transport=self._create_async_transport(handler)),
self._logger,
net.ThreadAwareAsyncHTTPClient(
transport=self._create_async_transport(handler)
).initialize(),
httpx.Client(transport=self._create_transport(handler)),
self._req,
signing_key=_signing_key,
signing_key_fallback=_signing_key_fallback,
Expand Down Expand Up @@ -325,7 +334,11 @@ def handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(200, content=b"", request=request)

res = await net.fetch_with_auth_fallback(
httpx.AsyncClient(transport=self._create_async_transport(handler)),
self._logger,
net.ThreadAwareAsyncHTTPClient(
transport=self._create_async_transport(handler)
).initialize(),
httpx.Client(transport=self._create_transport(handler)),
self._req,
signing_key="signkey-prod-aaaaaa",
signing_key_fallback="signkey-prod-bbbbbb",
Expand Down Expand Up @@ -362,7 +375,11 @@ def handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(200, content=b"", request=request)

res = await net.fetch_with_auth_fallback(
httpx.AsyncClient(transport=self._create_async_transport(handler)),
self._logger,
net.ThreadAwareAsyncHTTPClient(
transport=self._create_async_transport(handler)
).initialize(),
httpx.Client(transport=self._create_transport(handler)),
self._req,
signing_key=None,
signing_key_fallback=None,
Expand Down
13 changes: 4 additions & 9 deletions tests/cases/multiple_triggers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
from dataclasses import dataclass

import inngest
Expand Down Expand Up @@ -60,19 +59,15 @@ async def fn_async(ctx: inngest.Context, step: inngest.Step) -> None:
break

def run_test(self: base.TestClass) -> None:
async def trigger_event_and_wait(state_event: StateAndEvent) -> None:
await self.client.send(inngest.Event(name=state_event.event_name))
def trigger_event_and_wait(state_event: StateAndEvent) -> None:
self.client.send_sync(inngest.Event(name=state_event.event_name))
run_id = state_event.state.wait_for_run_id()
tests.helper.client.wait_for_run_status(
run_id, tests.helper.RunStatus.COMPLETED
)

async def run_all() -> None:
await asyncio.gather(
*(trigger_event_and_wait(se) for se in states_events)
)

asyncio.run(run_all())
for se in states_events:
trigger_event_and_wait(se)

assert all(se.state.run_id for se in states_events)
assert len(set(se.state.run_id for se in states_events)) == len(
Expand Down
Loading

0 comments on commit 17efc13

Please sign in to comment.