/
test_connection_pool.py
155 lines (119 loc) · 5.22 KB
/
test_connection_pool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from typing import AsyncIterator, Optional, Tuple

import pytest

import httpcore
from httpcore._async.base import ConnectionState
from httpcore._types import URL, Headers
class MockConnection:
    """Network-free stand-in for a pooled connection.

    It counts the response streams it has handed out that are still open.
    The connection is ACTIVE while any of them is open and drops to IDLE
    once the last one is closed, which is what the pool-state assertions in
    these tests observe.
    """

    def __init__(self, http_version):
        # Hard-coded origin; the mock never connects anywhere.
        self.origin = (b"http", b"example.org", 80)
        self.state = ConnectionState.PENDING
        self.is_http11 = http_version == "HTTP/1.1"
        self.is_http2 = http_version == "HTTP/2"
        # Response streams returned by arequest() and not yet closed.
        self.stream_count = 0

    async def arequest(
        self,
        method: bytes,
        url: URL,
        headers: Optional[Headers] = None,
        stream: Optional[httpcore.AsyncByteStream] = None,
        ext: Optional[dict] = None,
    ) -> Tuple[int, Headers, httpcore.AsyncByteStream, dict]:
        """Pretend to send a request and return an empty 200 response.

        All request arguments are ignored. The connection becomes ACTIVE
        immediately. Closing the returned stream decrements the open-stream
        count, and when it reaches zero the connection becomes IDLE.

        Returns:
            ``(200, [], response_stream, {})``.
        """
        self.state = ConnectionState.ACTIVE
        self.stream_count += 1

        async def on_close() -> None:
            self.stream_count -= 1
            if self.stream_count == 0:
                self.state = ConnectionState.IDLE

        async def aiterator() -> AsyncIterator[bytes]:
            yield b""

        # Named distinctly so it does not shadow the ``stream`` request argument.
        response_stream = httpcore.AsyncIteratorByteStream(
            aiterator=aiterator(), aclose_func=on_close
        )
        return 200, [], response_stream, {}

    async def aclose(self):
        # Nothing to release: no socket is ever opened.
        pass

    def info(self) -> str:
        """Return a summary such as ``"ConnectionState.ACTIVE"``.

        The string is built from the enum class and member names instead of
        ``str(self.state)``. On Python 3.11+, ``str()`` of an ``IntEnum``
        member is just the integer value. Formatting explicitly keeps the
        output stable whatever the enum's base class.
        """
        return f"{type(self.state).__name__}.{self.state.name}"

    def mark_as_ready(self) -> None:
        """Move the connection to the READY state."""
        self.state = ConnectionState.READY

    def is_socket_readable(self) -> bool:
        """Always ``False``: there is no real socket behind the mock."""
        return False
class ConnectionPool(httpcore.AsyncConnectionPool):
    """``AsyncConnectionPool`` whose new connections are always ``MockConnection``.

    Args:
        http_version: Either ``"HTTP/1.1"`` or ``"HTTP/2"``. It is forwarded
            to every mock connection the pool creates.
    """

    def __init__(self, http_version: str):
        super().__init__()
        self.http_version = http_version
        assert http_version in {"HTTP/1.1", "HTTP/2"}

    def _create_connection(self, **kwargs):
        # Keyword arguments from the pool are ignored; the mock hard-codes
        # its own origin and only needs the HTTP version.
        version = self.http_version
        return MockConnection(version)
async def read_body(stream: httpcore.AsyncByteStream) -> bytes:
    """Drain *stream* into a single ``bytes`` value and always close it.

    The stream is closed even if iteration raises. Closing is what lets a
    ``MockConnection`` decrement its open-stream count.
    """
    try:
        chunks = [chunk async for chunk in stream]
        return b"".join(chunks)
    finally:
        await stream.aclose()
@pytest.mark.trio
@pytest.mark.parametrize("http_version", ["HTTP/1.1", "HTTP/2"])
async def test_sequential_requests(http_version) -> None:
    """Two back-to-back requests are served by a single connection.

    For each request, the connection is ACTIVE while the response body is
    unread and IDLE once it has been read and closed.
    """
    request_url = (b"http", b"example.org", None, b"/")
    while_active = {"http://example.org": ["ConnectionState.ACTIVE"]}
    when_idle = {"http://example.org": ["ConnectionState.IDLE"]}

    async with ConnectionPool(http_version=http_version) as http:
        assert await http.get_connection_info() == {}

        for _ in range(2):
            _, _, stream, _ = await http.arequest(b"GET", request_url)
            assert await http.get_connection_info() == while_active

            await read_body(stream)
            assert await http.get_connection_info() == when_idle
@pytest.mark.trio
async def test_concurrent_requests_h11() -> None:
    """Over HTTP/1.1, a second in-flight request gets its own connection.

    Each connection returns to IDLE independently as its response is read.
    """
    request_url = (b"http", b"example.org", None, b"/")

    async with ConnectionPool(http_version="HTTP/1.1") as http:

        async def assert_states(*states: str) -> None:
            # No states at all means the pool should hold no connections.
            expected = {"http://example.org": list(states)} if states else {}
            assert await http.get_connection_info() == expected

        await assert_states()

        _, _, first_stream, _ = await http.arequest(b"GET", request_url)
        await assert_states("ConnectionState.ACTIVE")

        _, _, second_stream, _ = await http.arequest(b"GET", request_url)
        await assert_states("ConnectionState.ACTIVE", "ConnectionState.ACTIVE")

        await read_body(first_stream)
        await assert_states("ConnectionState.ACTIVE", "ConnectionState.IDLE")

        await read_body(second_stream)
        await assert_states("ConnectionState.IDLE", "ConnectionState.IDLE")
@pytest.mark.trio
async def test_concurrent_requests_h2() -> None:
    """Over HTTP/2, two in-flight requests share a single connection.

    That connection stays ACTIVE until both response streams are closed.
    """
    request_url = (b"http", b"example.org", None, b"/")

    async with ConnectionPool(http_version="HTTP/2") as http:

        async def assert_states(*states: str) -> None:
            # No states at all means the pool should hold no connections.
            expected = {"http://example.org": list(states)} if states else {}
            assert await http.get_connection_info() == expected

        await assert_states()

        _, _, first_stream, _ = await http.arequest(b"GET", request_url)
        await assert_states("ConnectionState.ACTIVE")

        _, _, second_stream, _ = await http.arequest(b"GET", request_url)
        await assert_states("ConnectionState.ACTIVE")

        await read_body(first_stream)
        await assert_states("ConnectionState.ACTIVE")

        await read_body(second_stream)
        await assert_states("ConnectionState.IDLE")