/
async_stream.py
146 lines (119 loc) · 4.79 KB
/
async_stream.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from abc import ABC, abstractmethod
from typing import Optional, TYPE_CHECKING, Awaitable
import os
import aiofiles
import aiofiles.os
import asyncio
if TYPE_CHECKING:
from _typeshed import FileDescriptorOrPath
# Varint byte masks: the top bit flags "another byte follows",
# the low seven bits carry the payload.
_MSB = 0b1000_0000
_REST = 0b0111_1111
class CARByteStreamException(Exception):
    """Error raised by CAR byte streams when a read or stream operation fails."""
class CARByteStream(ABC):
    """Abstract byte source with a read cursor.

    Subclasses implement ``read_slice`` and ``can_read_more``; this base
    class builds the cursor-relative helpers ``read_bytes``, ``read_u8``
    and ``read_var_int`` on top of ``read_slice``.
    """

    def __init__(self) -> None:
        # Cursor used (and advanced) by the read_* helpers.
        self._pos = 0
        # Optional bound stored by set_limit(); the base class only records
        # it, enforcement is left to the subclass.
        self._limit: Optional[int] = None

    @abstractmethod
    async def can_read_more(self) -> bool:
        """Return whether more data can be read; defined by the subclass."""

    @abstractmethod
    async def read_slice(self, start: int, end_exclusive: int) -> bytes:
        """Return the bytes in the half-open range ``[start, end_exclusive)``."""

    @property
    def pos(self) -> int:
        """Current cursor position."""
        return self._pos

    def set_limit(self, position: int) -> None:
        """Record ``position`` as the stream's limit."""
        self._limit = position

    def move(self, position: int) -> None:
        """Set the cursor to ``position`` without reading anything."""
        self._pos = position

    async def read_bytes(self, count: int, walk_forward: bool = True) -> bytes:
        """Read ``count`` bytes starting at the cursor.

        Args:
            count: Number of bytes to read.
            walk_forward: When true (the default) the cursor advances by
                ``count``; when false the cursor is left where it was.

        Returns:
            Exactly ``count`` bytes.

        Raises:
            CARByteStreamException: If ``read_slice`` returned a different
                number of bytes than requested. The cursor is not moved.
        """
        data = await self.read_slice(self._pos, self._pos + count)
        # Raise explicitly instead of asserting: asserts vanish under
        # ``python -O`` and a short read would then go unnoticed.
        if len(data) != count:
            raise CARByteStreamException(f'expected {count} bytes, got {len(data)}')
        if walk_forward:
            self._pos += count
        return data

    async def read_u8(self) -> int:
        """Read one byte at the cursor and return it as an integer."""
        return (await self.read_bytes(1))[0]

    # https://github.com/chrisdickinson/varint/blob/master/decode.js
    async def read_var_int(self) -> int:
        """Decode an unsigned varint at the cursor, least-significant group first.

        Each byte contributes its low seven bits; a byte below ``_MSB``
        terminates the value. At most eight bytes are consumed.

        Raises:
            CARByteStreamException: If the value has not terminated after
                eight bytes.
        """
        result = 0
        shift = 0
        while True:
            if shift > 49:
                raise CARByteStreamException(f'cannot decode varint (at pos {self._pos})')
            num = await self.read_u8()
            # Python ints are arbitrary precision, so a plain shift is exact
            # at every offset (the referenced JS needs a multiply past 28 bits).
            result |= (num & _REST) << shift
            shift += 7
            if num < _MSB:
                break
        return result
class ChunkedMemoryByteStream(CARByteStream):
    """In-memory stream that is filled incrementally with ``append_bytes``.

    ``read_slice`` waits until the requested range has been buffered, and
    fails if the stream is marked complete before that happens.
    """

    def __init__(self) -> None:
        super().__init__()
        # Set by mark_complete(); no further bytes are accepted afterwards.
        self._complete = False
        # Every byte appended so far.
        self._bytes = bytearray()
        # Notified whenever bytes are appended or the stream is completed.
        self._added_bytes_cond = asyncio.Condition()

    async def mark_complete(self) -> None:
        """Flag the stream as finished and wake any waiting readers."""
        self._complete = True
        async with self._added_bytes_cond:
            self._added_bytes_cond.notify_all()

    async def append_bytes(self, b: bytes) -> None:
        """Append ``b`` to the buffer and wake any waiting readers.

        Raises:
            CARByteStreamException: If the stream was already marked complete.
            TypeError: If ``b`` is not ``bytes``, ``bytearray`` or ``memoryview``.
        """
        if self._complete:
            raise CARByteStreamException('tried to append bytes but complete flag was set!')
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if not isinstance(b, (bytes, bytearray, memoryview)):
            raise TypeError(f'expected a bytes-like object, got {type(b).__name__}')
        self._bytes.extend(b)
        async with self._added_bytes_cond:
            self._added_bytes_cond.notify_all()

    async def read_slice(self, start: int, end_exclusive: int) -> bytes:
        """Return bytes ``[start, end_exclusive)``, waiting for them if needed.

        Raises:
            CARByteStreamException: If the range is empty or reversed, if it
                extends past the configured limit, or if the stream is
                complete and does not hold enough bytes.
        """
        if start >= end_exclusive:
            raise CARByteStreamException('only positive slices are allowed')
        # ``is not None`` so that a limit of 0 is enforced instead of being
        # mistaken for "no limit".
        if self._limit is not None and end_exclusive > self._limit:
            raise CARByteStreamException('limit will be breached')
        # Check the predicate while holding the condition's lock, as
        # asyncio.Condition expects; wait() releases it while suspended.
        async with self._added_bytes_cond:
            while end_exclusive > len(self._bytes):
                if self._complete:
                    raise CARByteStreamException('waiting for bytes, but complete flag was set!')
                await self._added_bytes_cond.wait()
        return bytes(self._bytes[start:end_exclusive])

    async def can_read_more(self) -> bool:
        """Return whether the cursor is before the limit or the buffered data."""
        if self._limit is not None:
            return self._pos < self._limit
        # NOTE(review): with no limit the bound is ``len - 1``, so a cursor on
        # the last buffered byte reports False, unlike the limit branch above.
        # Kept as-is; confirm whether this is intended.
        return self._pos < len(self._bytes) - 1
class FileByteStream(CARByteStream):
    """Stream that reads slices from a file through an OS file descriptor.

    Usable as an async context manager; leaving the ``async with`` block
    closes the stream.
    """

    _fd: Optional[int]

    def __init__(self, file: 'FileDescriptorOrPath', *, close_fd: bool=True) -> None:
        """Wrap ``file``.

        Args:
            file: An already-open file descriptor (``int``), or a path that
                is opened read-only.
            close_fd: Whether ``close()`` should close the descriptor.
        """
        super().__init__()
        if isinstance(file, int):
            self._fd = file
        else:
            self._fd = os.open(file, os.O_RDONLY)
        # File size, looked up on the first can_read_more() call and cached.
        self._size: Optional[int] = None
        self._close_fd = close_fd

    def close(self) -> None:
        """Close the stream, and the descriptor too when ``close_fd`` was true.

        Raises:
            CARByteStreamException: If the stream was already closed.
        """
        if self._fd is None:
            raise CARByteStreamException('this stream was already closed')
        if self._close_fd:
            os.close(self._fd)
        self._fd = None

    async def __aenter__(self) -> 'FileByteStream':
        return self

    async def __aexit__(self, *args): # type: ignore
        # If the body already called close(), closing again would raise and
        # could mask an exception propagating out of the ``async with`` body.
        if self._fd is not None:
            self.close()

    async def can_read_more(self) -> bool:
        """Return whether the cursor is before the limit or the file's end.

        Raises:
            CARByteStreamException: If the size must be looked up but the
                stream is already closed.
        """
        if self._size is None:
            if self._fd is None:
                raise CARByteStreamException('this stream was already closed')
            self._size = await aiofiles.os.path.getsize(self._fd)
        # ``is not None`` so that a limit of 0 is honoured instead of being
        # mistaken for "no limit".
        if self._limit is not None:
            return self._pos < self._limit
        # NOTE(review): with no limit the bound is ``size - 1``, so a cursor on
        # the file's last byte reports False, unlike the limit branch above.
        # Kept as-is; confirm whether this is intended.
        return self._pos < self._size - 1

    async def read_slice(self, start: int, end_exclusive: int) -> bytes:
        """Read bytes ``[start, end_exclusive)`` from the file.

        May return fewer bytes than requested if the file ends first.

        NOTE(review): unlike ``ChunkedMemoryByteStream.read_slice``, the limit
        set via ``set_limit`` is not checked here; confirm whether it should be.

        Raises:
            CARByteStreamException: If ``start`` is negative, the stream is
                closed, or the range is empty or reversed.
        """
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if start < 0:
            raise CARByteStreamException('slice start must not be negative')
        if self._fd is None:
            raise CARByteStreamException('this stream was already closed')
        if start >= end_exclusive:
            raise CARByteStreamException('only positive slices are allowed')
        # closefd=False: closing this temporary file object must leave the
        # underlying descriptor open for later reads.
        async with aiofiles.open(self._fd, 'rb', closefd=False) as f:
            # Seek only when the descriptor is not already at ``start``.
            if await f.tell() != start:
                await f.seek(start)
            result = await f.read(end_exclusive - start)
        return result