/
anyio.py
196 lines (163 loc) · 6.16 KB
/
anyio.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
from ssl import SSLContext
from typing import Awaitable, Callable, Optional

import anyio.abc
from anyio import BrokenResourceError, EndOfStream
from anyio.abc import ByteStream, SocketAttribute
from anyio.streams.tls import TLSAttribute, TLSStream

from .._exceptions import (
    CloseError,
    ConnectError,
    ConnectTimeout,
    ReadError,
    ReadTimeout,
    WriteError,
    WriteTimeout,
)
from .._types import TimeoutDict
from .._utils import is_socket_readable
from .base import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream
class SocketStream(AsyncSocketStream):
    """``AsyncSocketStream`` adapter around an anyio ``ByteStream``.

    Reads and writes are each serialized by their own anyio lock, and anyio
    failures are translated into this package's exception types.
    """

    def __init__(self, stream: ByteStream) -> None:
        self.stream = stream
        # Separate locks: a read in progress does not block a concurrent write.
        self.read_lock = anyio.create_lock()
        self.write_lock = anyio.create_lock()

    def get_http_version(self) -> str:
        """Return ``"HTTP/2"`` if TLS ALPN negotiated ``h2``, else ``"HTTP/1.1"``."""
        alpn_protocol = self.stream.extra(TLSAttribute.alpn_protocol, None)
        return "HTTP/2" if alpn_protocol == "h2" else "HTTP/1.1"

    async def start_tls(
        self,
        hostname: bytes,
        ssl_context: SSLContext,
        timeout: TimeoutDict,
    ) -> "SocketStream":
        """Perform a TLS handshake over this stream and return a new wrapper.

        :param hostname: Server name for SNI / certificate checks (ASCII).
        :param ssl_context: Context used for the handshake.
        :param timeout: Only the ``"connect"`` entry is used; it bounds the
            whole handshake.
        :returns: A new ``SocketStream`` around the TLS stream.
        :raises ConnectTimeout: if the handshake does not finish in time.
        :raises ConnectError: if the handshake fails, including ``ssl.SSLError``
            (a subclass of ``OSError``) or the peer hanging up mid-handshake.
        """
        connect_timeout = timeout.get("connect")
        try:
            async with anyio.fail_after(connect_timeout):
                ssl_stream = await TLSStream.wrap(
                    self.stream,
                    ssl_context=ssl_context,
                    hostname=hostname.decode("ascii"),
                    # Same as the backend's direct TLS connections: skip the
                    # closing TLS handshake and tolerate peers that close
                    # without close_notify.
                    standard_compatible=False,
                )
        # TimeoutError subclasses OSError, so it must be handled first.
        except TimeoutError:
            raise ConnectTimeout from None
        except (BrokenResourceError, EndOfStream, OSError) as exc:
            raise ConnectError from exc
        return SocketStream(ssl_stream)

    async def read(self, n: int, timeout: TimeoutDict) -> bytes:
        """Receive up to ``n`` bytes, bounded by the ``"read"`` timeout.

        :raises ReadTimeout: if no data arrives in time.
        :raises ReadError: if the stream is broken or the peer closed it.
        """
        read_timeout = timeout.get("read")
        async with self.read_lock:
            try:
                async with anyio.fail_after(read_timeout):
                    return await self.stream.receive(n)
            except TimeoutError:
                raise ReadTimeout from None
            except BrokenResourceError as exc:
                raise ReadError from exc
            except EndOfStream:
                raise ReadError("Server disconnected while attempting read") from None

    async def write(self, data: bytes, timeout: TimeoutDict) -> None:
        """Send all of ``data``, bounded by the ``"write"`` timeout.

        Empty ``data`` is a no-op.

        :raises WriteTimeout: if the send does not complete in time.
        :raises WriteError: if the stream is broken.
        """
        if not data:
            return
        write_timeout = timeout.get("write")
        async with self.write_lock:
            try:
                async with anyio.fail_after(write_timeout):
                    return await self.stream.send(data)
            except TimeoutError:
                raise WriteTimeout from None
            except BrokenResourceError as exc:
                raise WriteError from exc

    async def aclose(self) -> None:
        """Close the stream.

        :raises CloseError: if the stream reports it is broken while closing.
        """
        # Holding the write lock keeps close from interleaving with a send.
        async with self.write_lock:
            try:
                await self.stream.aclose()
            except BrokenResourceError as exc:
                raise CloseError from exc

    def is_readable(self) -> bool:
        """Delegate to ``is_socket_readable`` on the raw socket's descriptor."""
        # NOTE(review): presumably used to detect a peer-dropped idle
        # connection before reuse — confirm against is_socket_readable.
        sock = self.stream.extra(SocketAttribute.raw_socket)
        return is_socket_readable(sock.fileno())
class Lock(AsyncLock):
    """``AsyncLock`` implemented on top of an anyio lock."""

    def __init__(self) -> None:
        self._lock = anyio.create_lock()

    async def acquire(self) -> None:
        """Wait until the lock is free, then take it."""
        inner = self._lock
        await inner.acquire()

    async def release(self) -> None:
        """Hand the lock back to the next waiter (if any)."""
        inner = self._lock
        await inner.release()
class Semaphore(AsyncSemaphore):
    """Semaphore with ``max_value`` slots whose ``acquire`` can time out.

    On timeout, ``acquire`` raises an instance of the caller-supplied
    ``exc_class``. The underlying anyio semaphore is created lazily on first
    use rather than in ``__init__``.
    """

    def __init__(self, max_value: int, exc_class: type):
        self.max_value = max_value
        self.exc_class = exc_class
        # Created on first access of the ``semaphore`` property.
        self._semaphore: Optional[anyio.abc.Semaphore] = None

    @property
    def semaphore(self) -> anyio.abc.Semaphore:
        """The wrapped anyio semaphore, created on first access."""
        # NOTE(review): presumably deferred because anyio.create_semaphore
        # needs a running event loop while this object may be built outside
        # one — confirm.
        if self._semaphore is None:
            self._semaphore = anyio.create_semaphore(self.max_value)
        return self._semaphore

    async def acquire(self, timeout: Optional[float] = None) -> None:
        """Take a slot, waiting at most ``timeout`` seconds (forever if ``None``).

        :raises exc_class: if no slot became free before the timeout.
        """
        async with anyio.move_on_after(timeout):
            await self.semaphore.acquire()
            return
        # Only reached when move_on_after cancelled the wait.
        raise self.exc_class()

    async def release(self) -> None:
        """Return a slot obtained with ``acquire``."""
        await self.semaphore.release()
class AnyIOBackend(AsyncBackend):
    """``AsyncBackend`` implementation built on anyio."""

    async def open_tcp_stream(
        self,
        hostname: bytes,
        port: int,
        ssl_context: Optional[SSLContext],
        timeout: TimeoutDict,
        *,
        local_address: Optional[str],
    ) -> AsyncSocketStream:
        """Open a TCP connection to ``hostname:port``, optionally over TLS.

        :param hostname: Target host, decoded as UTF-8; also the TLS server name.
        :param port: Target TCP port.
        :param ssl_context: If given, a TLS handshake follows the connect.
        :param timeout: Only the ``"connect"`` entry is used; it bounds the
            connect and the handshake together.
        :param local_address: Optional local host to bind to.
        :raises ConnectTimeout: if the connect timeout expires.
        :raises ConnectError: if connecting or the handshake fails.
        """
        unicode_host = hostname.decode("utf-8")
        return await self._open_stream(
            lambda: anyio.connect_tcp(unicode_host, port, local_host=local_address),
            unicode_host,
            ssl_context,
            timeout.get("connect"),
        )

    async def open_uds_stream(
        self,
        path: str,
        hostname: bytes,
        ssl_context: Optional[SSLContext],
        timeout: TimeoutDict,
    ) -> AsyncSocketStream:
        """Open a Unix domain socket connection to ``path``, optionally over TLS.

        :param path: Filesystem path of the socket.
        :param hostname: Decoded as UTF-8 and used as the TLS server name.
        :param ssl_context: If given, a TLS handshake follows the connect.
        :param timeout: Only the ``"connect"`` entry is used.
        :raises ConnectTimeout: if the connect timeout expires.
        :raises ConnectError: if connecting or the handshake fails.
        """
        unicode_host = hostname.decode("utf-8")
        return await self._open_stream(
            lambda: anyio.connect_unix(path),
            unicode_host,
            ssl_context,
            timeout.get("connect"),
        )

    async def _open_stream(
        self,
        connect: Callable[[], Awaitable[anyio.abc.ByteStream]],
        unicode_host: str,
        ssl_context: Optional[SSLContext],
        connect_timeout: Optional[float],
    ) -> AsyncSocketStream:
        """Run ``connect`` plus an optional TLS handshake under one timeout.

        Shared by the ``open_*_stream`` methods. Maps failures to
        ``ConnectTimeout``/``ConnectError``. ``OSError`` covers connect failures
        and ``ssl.SSLError``; ``EndOfStream`` covers a peer hang-up during the
        handshake.
        """
        try:
            async with anyio.fail_after(connect_timeout):
                stream: anyio.abc.ByteStream = await connect()
                if ssl_context:
                    try:
                        stream = await TLSStream.wrap(
                            stream,
                            hostname=unicode_host,
                            ssl_context=ssl_context,
                            standard_compatible=False,
                        )
                    except BaseException:
                        # Don't leak the raw socket when the handshake fails or
                        # is cancelled by the timeout. aclose_forcefully closes
                        # it inside its own cancelled scope, so it also works
                        # here after cancellation.
                        await anyio.aclose_forcefully(stream)
                        raise
        # TimeoutError subclasses OSError, so it must be handled first.
        except TimeoutError:
            raise ConnectTimeout from None
        except (BrokenResourceError, EndOfStream, OSError) as exc:
            raise ConnectError from exc
        return SocketStream(stream=stream)

    def create_lock(self) -> AsyncLock:
        """Return a new anyio-backed ``Lock``."""
        return Lock()

    def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore:
        """Return a ``Semaphore`` with ``max_value`` slots raising ``exc_class`` on timeout."""
        return Semaphore(max_value, exc_class=exc_class)

    async def time(self) -> float:
        """Return anyio's current clock value."""
        return await anyio.current_time()

    async def sleep(self, seconds: float) -> None:
        """Suspend the current task for ``seconds``."""
        await anyio.sleep(seconds)