-
Notifications
You must be signed in to change notification settings - Fork 89
/
trio.py
142 lines (127 loc) · 4.93 KB
/
trio.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import ssl
import typing
import trio
from .._exceptions import (
ConnectError,
ConnectTimeout,
ReadError,
ReadTimeout,
WriteError,
WriteTimeout,
map_exceptions,
)
from .base import AsyncNetworkBackend, AsyncNetworkStream
class TrioStream(AsyncNetworkStream):
    """Async network stream implemented on top of a ``trio.abc.Stream``.

    trio's timeout and broken/closed-resource exceptions are translated into
    this package's exception types via ``map_exceptions``.
    """

    def __init__(self, stream: trio.abc.Stream) -> None:
        # The wrapped trio stream. start_tls() wraps it in a trio.SSLStream;
        # get_extra_info() expects a trio.SocketStream at the bottom of the
        # SSL layers.
        self._stream = stream

    async def read(
        self, max_bytes: int, timeout: typing.Optional[float] = None
    ) -> bytes:
        """Receive up to *max_bytes* bytes from the stream.

        Args:
            max_bytes: Maximum number of bytes to return.
            timeout: Seconds to wait, or ``None`` to wait indefinitely.

        Returns:
            The bytes returned by the underlying ``receive_some`` call.

        Raises:
            ReadTimeout: The timeout elapsed.
            ReadError: The connection is broken, or the stream was already
                closed.
        """
        # trio.fail_after() has no "no limit" value other than infinity.
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map = {
            trio.TooSlowError: ReadTimeout,
            trio.BrokenResourceError: ReadError,
            # Reading from a stream that was already closed must surface as
            # a package error rather than leaking a raw trio exception.
            trio.ClosedResourceError: ReadError,
        }
        with map_exceptions(exc_map):
            with trio.fail_after(timeout_or_inf):
                return await self._stream.receive_some(max_bytes=max_bytes)

    async def write(
        self, buffer: bytes, timeout: typing.Optional[float] = None
    ) -> None:
        """Send all of *buffer* on the stream.

        An empty buffer returns immediately without touching the stream.

        Args:
            buffer: Bytes to send.
            timeout: Seconds to wait, or ``None`` to wait indefinitely.

        Raises:
            WriteTimeout: The timeout elapsed.
            WriteError: The connection is broken, or the stream was already
                closed.
        """
        if not buffer:
            return

        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map = {
            trio.TooSlowError: WriteTimeout,
            trio.BrokenResourceError: WriteError,
            # Mirror read(): writing to a closed stream is a WriteError.
            trio.ClosedResourceError: WriteError,
        }
        with map_exceptions(exc_map):
            with trio.fail_after(timeout_or_inf):
                await self._stream.send_all(data=buffer)

    async def aclose(self) -> None:
        """Close the wrapped trio stream."""
        await self._stream.aclose()

    async def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ) -> AsyncNetworkStream:
        """Upgrade this connection to TLS (client side).

        Args:
            ssl_context: Context used to configure the TLS session.
            server_hostname: Hostname passed to ``trio.SSLStream``, or ``None``.
            timeout: Seconds allowed for the handshake, or ``None`` for no limit.

        Returns:
            A new ``TrioStream`` wrapping the ``trio.SSLStream``.

        Raises:
            ConnectTimeout: The handshake did not finish within *timeout*.
            ConnectError: The connection broke during the handshake.
        """
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map = {
            trio.TooSlowError: ConnectTimeout,
            trio.BrokenResourceError: ConnectError,
        }
        ssl_stream = trio.SSLStream(
            self._stream,
            ssl_context=ssl_context,
            server_hostname=server_hostname,
            https_compatible=True,
            server_side=False,
        )
        with map_exceptions(exc_map):
            try:
                with trio.fail_after(timeout_or_inf):
                    await ssl_stream.do_handshake()
            except Exception:  # pragma: nocover
                # Do not leave the half-upgraded transport open if the
                # handshake fails; then re-raise the original error unchanged.
                await self.aclose()
                raise
        return TrioStream(ssl_stream)

    def get_extra_info(self, info: str) -> typing.Any:
        """Return transport details identified by *info*.

        Supported keys:
            ``"ssl_object"``: the SSL object, only when this is a TLS stream.
            ``"client_addr"``: the local socket address (``getsockname()``).
            ``"server_addr"``: the peer socket address (``getpeername()``).
            ``"socket"``: the underlying trio socket object.
            ``"is_readable"``: result of that socket's ``is_readable()``.

        Any other key returns ``None``.
        """
        if info == "ssl_object" and isinstance(self._stream, trio.SSLStream):
            # `_ssl_object` is a private trio.SSLStream attribute and is not
            # visible to type checkers, hence the ignore.
            return self._stream._ssl_object  # type: ignore
        if info == "client_addr":
            return self._get_socket_stream().socket.getsockname()
        if info == "server_addr":
            return self._get_socket_stream().socket.getpeername()
        if info == "socket":
            return self._get_socket_stream().socket
        if info == "is_readable":
            sock = self.get_extra_info("socket")
            return sock.is_readable()
        return None

    def _get_socket_stream(self) -> trio.SocketStream:
        """Unwrap any SSLStream layers and return the bottom SocketStream."""
        stream = self._stream
        while isinstance(stream, trio.SSLStream):
            stream = stream.transport_stream
        # Anything other than a SocketStream at the bottom is unsupported.
        assert isinstance(stream, trio.SocketStream)
        return stream
class TrioBackend(AsyncNetworkBackend):
    """Network backend that opens TCP and Unix-socket connections via trio."""

    async def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: typing.Optional[float] = None,
        local_address: typing.Optional[str] = None,
    ) -> AsyncNetworkStream:
        """Open a TCP connection to ``host:port``.

        Args:
            host: Remote host to connect to.
            port: Remote port to connect to.
            timeout: Seconds allowed for connecting, or ``None`` for no limit.
            local_address: Optional local address forwarded to trio.

        Returns:
            A ``TrioStream`` wrapping the connected stream.

        Raises:
            ConnectTimeout: Connecting took longer than *timeout*.
            ConnectError: The connection failed (broken resource or OSError).
        """
        if timeout is None:
            limit = float("inf")
        else:
            limit = timeout

        error_mapping = {
            trio.TooSlowError: ConnectTimeout,
            trio.BrokenResourceError: ConnectError,
            OSError: ConnectError,
        }

        # Trio supports 'local_address' from 0.16.1 onwards, so the keyword
        # is forwarded only when the caller actually supplied one.
        extra_options: dict = {}
        if local_address is not None:
            extra_options["local_address"] = local_address

        with map_exceptions(error_mapping):
            with trio.fail_after(limit):
                connection: trio.abc.Stream = await trio.open_tcp_stream(
                    host=host, port=port, **extra_options
                )
        return TrioStream(connection)

    async def connect_unix_socket(
        self, path: str, timeout: typing.Optional[float] = None
    ) -> AsyncNetworkStream:  # pragma: nocover
        """Open a connection to the Unix domain socket at *path*.

        Args:
            path: Filesystem path of the socket.
            timeout: Seconds allowed for connecting, or ``None`` for no limit.

        Returns:
            A ``TrioStream`` wrapping the connected stream.

        Raises:
            ConnectTimeout: Connecting took longer than *timeout*.
            ConnectError: The connection failed (broken resource or OSError).
        """
        if timeout is None:
            limit = float("inf")
        else:
            limit = timeout

        error_mapping = {
            trio.TooSlowError: ConnectTimeout,
            trio.BrokenResourceError: ConnectError,
            OSError: ConnectError,
        }

        with map_exceptions(error_mapping):
            with trio.fail_after(limit):
                connection: trio.abc.Stream = await trio.open_unix_socket(path)
        return TrioStream(connection)

    async def sleep(self, seconds: float) -> None:
        """Suspend the current task for *seconds* using ``trio.sleep``."""
        await trio.sleep(seconds)  # pragma: nocover