/
connection_pool.py
354 lines (305 loc) · 13.6 KB
/
connection_pool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
import ssl
import sys
from types import TracebackType
from typing import AsyncIterable, AsyncIterator, List, Optional, Type
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
from .._models import Origin, Request, Response
from .._synchronization import AsyncEvent, AsyncLock
from ..backends.auto import AutoBackend
from ..backends.base import AsyncNetworkBackend
from .connection import AsyncHTTPConnection
from .interfaces import AsyncConnectionInterface, AsyncRequestInterface
class RequestStatus:
    """
    Book-keeping record for one request handled by the connection pool.

    Holds the request itself, the connection currently assigned to it (or
    `None` while it is still queued), and an event that is signalled when a
    connection gets assigned.
    """

    def __init__(self, request: Request):
        """
        Create a status record for `request` with no connection assigned yet.
        """
        self._connection_acquired = AsyncEvent()
        self.connection: Optional[AsyncConnectionInterface] = None
        self.request = request

    def set_connection(self, connection: AsyncConnectionInterface) -> None:
        """
        Assign `connection` to this request and signal the acquisition event.

        Must only be called while no connection is assigned.
        """
        assert self.connection is None
        self.connection = connection
        self._connection_acquired.set()

    def unset_connection(self) -> None:
        """
        Drop the assigned connection, returning the request to a queued state.

        A new event object is installed so that a later call to
        `wait_for_connection()` waits on the new event rather than on the one
        that has already been signalled.
        """
        assert self.connection is not None
        self._connection_acquired = AsyncEvent()
        self.connection = None

    async def wait_for_connection(
        self, timeout: Optional[float] = None
    ) -> AsyncConnectionInterface:
        """
        Wait for the acquisition event, then return the assigned connection.

        `timeout` is forwarded unchanged to the event's `wait()` call.
        """
        await self._connection_acquired.wait(timeout=timeout)
        assigned = self.connection
        assert assigned is not None
        return assigned
class AsyncConnectionPool(AsyncRequestInterface):
    """
    A connection pool for making HTTP requests.
    """

    def __init__(
        self,
        ssl_context: Optional[ssl.SSLContext] = None,
        max_connections: Optional[int] = 10,
        max_keepalive_connections: Optional[int] = None,
        keepalive_expiry: Optional[float] = None,
        http1: bool = True,
        http2: bool = False,
        retries: int = 0,
        local_address: Optional[str] = None,
        uds: Optional[str] = None,
        network_backend: Optional[AsyncNetworkBackend] = None,
    ) -> None:
        """
        A connection pool for making HTTP requests.

        Parameters:
            ssl_context: An SSL context to use for verifying connections.
                If not specified, the default `httpcore.default_ssl_context()`
                will be used.
            max_connections: The maximum number of concurrent HTTP connections that
                the pool should allow. Any attempt to send a request on a pool that
                would exceed this amount will block until a connection is available.
            max_keepalive_connections: The maximum number of idle HTTP connections
                that will be maintained in the pool.
            keepalive_expiry: The duration in seconds that an idle HTTP connection
                may be maintained for before being expired from the pool.
            http1: A boolean indicating if HTTP/1.1 requests should be supported
                by the connection pool. Defaults to True.
            http2: A boolean indicating if HTTP/2 requests should be supported by
                the connection pool. Defaults to False.
            retries: The maximum number of retries when trying to establish a
                connection.
            local_address: Local address to connect from. Can also be used to connect
                using a particular address family. Using `local_address="0.0.0.0"`
                will connect using an `AF_INET` address (IPv4), while using
                `local_address="::"` will connect using an `AF_INET6` address (IPv6).
            uds: Path to a Unix Domain Socket to use instead of TCP sockets.
            network_backend: A backend instance to use for handling network I/O.
        """
        self._ssl_context = ssl_context

        # `None` means "no limit", which is represented as `sys.maxsize`.
        self._max_connections = (
            sys.maxsize if max_connections is None else max_connections
        )
        self._max_keepalive_connections = (
            sys.maxsize
            if max_keepalive_connections is None
            else max_keepalive_connections
        )
        # The keep-alive limit can never exceed the overall connection limit.
        self._max_keepalive_connections = min(
            self._max_connections, self._max_keepalive_connections
        )

        self._keepalive_expiry = keepalive_expiry
        self._http1 = http1
        self._http2 = http2
        self._retries = retries
        self._local_address = local_address
        self._uds = uds

        # `_pool` is kept ordered with the most recently used connection first.
        self._pool: List[AsyncConnectionInterface] = []
        # `_requests` is the FIFO queue of requests known to the pool.
        self._requests: List[RequestStatus] = []
        self._pool_lock = AsyncLock()
        self._network_backend = (
            AutoBackend() if network_backend is None else network_backend
        )

    def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
        """
        Build a new connection to `origin` using the pool's configuration.
        """
        return AsyncHTTPConnection(
            origin=origin,
            ssl_context=self._ssl_context,
            keepalive_expiry=self._keepalive_expiry,
            http1=self._http1,
            http2=self._http2,
            retries=self._retries,
            local_address=self._local_address,
            uds=self._uds,
            network_backend=self._network_backend,
        )

    @property
    def connections(self) -> List[AsyncConnectionInterface]:
        """
        Return a list of the connections currently in the pool.

        The returned list is a copy, so mutating it does not affect the pool.

        For example:

        ```python
        >>> pool.connections
        [
            <AsyncHTTPConnection ['https://example.com:443', HTTP/1.1, ACTIVE, Request Count: 6]>,
            <AsyncHTTPConnection ['https://example.com:443', HTTP/1.1, IDLE, Request Count: 9]>,
            <AsyncHTTPConnection ['http://example.com:80', HTTP/1.1, IDLE, Request Count: 1]>,
        ]
        ```
        """
        return list(self._pool)

    def _discard_request(self, status: RequestStatus) -> None:
        """
        Remove `status` from the request queue if it is still registered.

        The membership check matters because `aclose()` replaces the queue
        with an empty list; an unguarded `remove()` would then raise
        `ValueError` and mask whatever exception is being propagated.
        """
        if status in self._requests:
            self._requests.remove(status)

    async def _attempt_to_acquire_connection(self, status: RequestStatus) -> bool:
        """
        Attempt to provide a connection that can handle the given origin.

        Returns True if a connection was assigned to `status`, and False if
        the request has to remain queued.
        """
        origin = status.request.url.origin

        # If there are queued requests in front of us, then don't acquire a
        # connection. We handle requests strictly in order.
        first_waiting = next(
            (queued for queued in self._requests if queued.connection is None),
            None,
        )
        if first_waiting is not None and first_waiting is not status:
            return False

        # Reuse an existing connection if one is currently available.
        for idx, connection in enumerate(self._pool):
            if connection.can_handle_request(origin) and connection.is_available():
                # Move the connection to the front of the pool.
                self._pool.pop(idx)
                self._pool.insert(0, connection)
                status.set_connection(connection)
                return True

        # If the pool is currently full, attempt to close one idle connection.
        # Iterating in reverse picks the idle connection nearest the end of
        # the pool, i.e. the least recently used one.
        if len(self._pool) >= self._max_connections:
            for idx, connection in reversed(list(enumerate(self._pool))):
                if connection.is_idle():
                    await connection.aclose()
                    self._pool.pop(idx)
                    break

        # If the pool is still full, then we cannot acquire a connection.
        if len(self._pool) >= self._max_connections:
            return False

        # Otherwise create a new connection.
        connection = self.create_connection(origin)
        self._pool.insert(0, connection)
        status.set_connection(connection)
        return True

    async def _close_expired_connections(self) -> None:
        """
        Clean up the connection pool by closing off any connections that have expired.
        """
        # Close any connections that have expired their keep-alive time.
        # Iterating over a reversed snapshot keeps the remaining indexes
        # valid while entries are popped.
        for idx, connection in reversed(list(enumerate(self._pool))):
            if connection.has_expired():
                await connection.aclose()
                self._pool.pop(idx)

        # If the pool size exceeds the maximum number of allowed keep-alive connections,
        # then close off idle connections as required.
        pool_size = len(self._pool)
        for idx, connection in reversed(list(enumerate(self._pool))):
            if connection.is_idle() and pool_size > self._max_keepalive_connections:
                await connection.aclose()
                self._pool.pop(idx)
                pool_size -= 1

    async def handle_async_request(self, request: Request) -> Response:
        """
        Send an HTTP request, and return an HTTP response.

        This is the core implementation that is called into by `.request()` or `.stream()`.

        The request is queued, a connection is acquired for it (waiting if
        necessary, bounded by the `pool` entry of the request's `timeout`
        extension), and the request is then sent on that connection.

        Raises:
            UnsupportedProtocol: If the URL scheme is empty or is neither
                `http` nor `https`.

        Any exception raised while waiting for a connection or while sending
        the request is propagated after the request has been removed from the
        pool's queue.
        """
        scheme = request.url.scheme.decode()
        if scheme == "":
            raise UnsupportedProtocol(
                "Request URL is missing an 'http://' or 'https://' protocol."
            )
        if scheme not in ("http", "https"):
            raise UnsupportedProtocol(
                f"Request URL has an unsupported protocol '{scheme}://'."
            )

        # The pool timeout does not change between retries, so look it up once.
        # Note that it is applied to each wait for a connection individually.
        timeouts = request.extensions.get("timeout", {})
        timeout = timeouts.get("pool", None)

        status = RequestStatus(request)
        self._requests.append(status)

        try:
            async with self._pool_lock:
                await self._close_expired_connections()
                await self._attempt_to_acquire_connection(status)
        except BaseException:
            # If acquiring the lock or the housekeeping fails or is cancelled,
            # the request must not be left at the head of the queue, where it
            # would block every request queued behind it.
            self._discard_request(status)
            raise

        while True:
            try:
                connection = await status.wait_for_connection(timeout=timeout)
            except BaseException:
                # If we timeout here, or if the task is cancelled, then make
                # sure to remove the request from the queue before bubbling
                # up the exception.
                self._discard_request(status)
                raise

            try:
                response = await connection.handle_async_request(request)
            except ConnectionNotAvailable:
                # The ConnectionNotAvailable exception is a special case, that
                # indicates we need to retry the request on a new connection.
                #
                # The most common case where this can occur is when multiple
                # requests are queued waiting for a single connection, which
                # might end up as an HTTP/2 connection, but which actually ends
                # up as HTTP/1.1.
                try:
                    async with self._pool_lock:
                        # Maintain our position in the request queue, but reset the
                        # status so that the request becomes queued again.
                        status.unset_connection()
                        await self._attempt_to_acquire_connection(status)
                except BaseException:
                    # Same reasoning as for the initial acquisition: never
                    # leave an abandoned request registered in the queue.
                    self._discard_request(status)
                    raise
            except BaseException:
                await self.response_closed(status)
                raise
            else:
                break

        # When we return the response, we wrap the stream in a special class
        # that handles notifying the connection pool once the response
        # has been released.
        assert isinstance(response.stream, AsyncIterable)
        return Response(
            status=response.status,
            headers=response.headers,
            content=ConnectionPoolByteStream(response.stream, self, status),
            extensions=response.extensions,
        )

    async def response_closed(self, status: RequestStatus) -> None:
        """
        This method acts as a callback once the request/response cycle is complete.

        It is called into from the `ConnectionPoolByteStream.aclose()` method.

        `status` must have a connection assigned. The request is removed from
        the queue, its connection is dropped from the pool if it has closed,
        and queued requests are then offered connections in order.
        """
        assert status.connection is not None
        connection = status.connection

        self._discard_request(status)

        async with self._pool_lock:
            # Update the state of the connection pool.
            if connection.is_closed() and connection in self._pool:
                self._pool.remove(connection)

            # Since we've had a response closed, it's possible we'll now be able
            # to service one or more requests that are currently pending.
            for pending in self._requests:
                if pending.connection is None:
                    acquired = await self._attempt_to_acquire_connection(pending)
                    # If we could not acquire a connection for a queued request
                    # then we don't need to check anymore requests that are
                    # queued later behind it.
                    if not acquired:
                        break

            # Housekeeping.
            await self._close_expired_connections()

    async def aclose(self) -> None:
        """
        Close any connections in the pool.

        Raises:
            RuntimeError: If requests were still registered with the pool when
                it was closed. The connections are closed and the pool state
                is reset before the error is raised.
        """
        async with self._pool_lock:
            requests_still_in_flight = len(self._requests)

            for connection in self._pool:
                await connection.aclose()
            self._pool = []
            self._requests = []

            if requests_still_in_flight:
                raise RuntimeError(
                    f"The connection pool was closed while {requests_still_in_flight} "
                    f"HTTP requests/responses were still in-flight."
                )

    async def __aenter__(self) -> "AsyncConnectionPool":
        """Enter the async context manager, returning the pool itself."""
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        """Exit the async context manager, closing the pool."""
        await self.aclose()
class ConnectionPoolByteStream:
    """
    Response body stream that reports back to the connection pool on close.

    Iteration is delegated to the wrapped stream; closing this wrapper closes
    the wrapped stream (when it supports it) and then calls the pool's
    `response_closed()` callback for the associated request.
    """

    def __init__(
        self,
        stream: AsyncIterable[bytes],
        pool: AsyncConnectionPool,
        status: RequestStatus,
    ) -> None:
        """
        Wrap `stream`, remembering the owning `pool` and the request `status`.
        """
        self._status = status
        self._pool = pool
        self._stream = stream

    async def __aiter__(self) -> AsyncIterator[bytes]:
        """
        Yield each chunk of the wrapped stream unchanged.
        """
        async for chunk in self._stream:
            yield chunk

    async def aclose(self) -> None:
        """
        Close the wrapped stream, then notify the pool.

        The pool is notified from a `finally` clause, so it happens even if
        closing the wrapped stream raises.
        """
        try:
            # Only streams that expose `aclose` are closed explicitly.
            if hasattr(self._stream, "aclose"):
                await self._stream.aclose()  # type: ignore
        finally:
            await self._pool.response_closed(self._status)