Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bson/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ class Binary(bytes):

def __new__(
cls: Type[Binary],
data: Union[memoryview, bytes, _mmap, _array[Any]],
data: Union[memoryview, bytes, bytearray, _mmap, _array[Any]],
subtype: int = BINARY_SUBTYPE,
) -> Binary:
if not isinstance(subtype, int):
Expand Down
14 changes: 9 additions & 5 deletions bson/raw_bson.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@


def _inflate_bson(
bson_bytes: bytes, codec_options: CodecOptions[RawBSONDocument], raw_array: bool = False
bson_bytes: bytes | memoryview,
codec_options: CodecOptions[RawBSONDocument],
raw_array: bool = False,
) -> dict[str, Any]:
"""Inflates the top level fields of a BSON document.

Expand All @@ -85,7 +87,9 @@ class RawBSONDocument(Mapping[str, Any]):
__codec_options: CodecOptions[RawBSONDocument]

def __init__(
self, bson_bytes: bytes, codec_options: Optional[CodecOptions[RawBSONDocument]] = None
self,
bson_bytes: bytes | memoryview,
codec_options: Optional[CodecOptions[RawBSONDocument]] = None,
) -> None:
"""Create a new :class:`RawBSONDocument`

Expand Down Expand Up @@ -135,7 +139,7 @@ class from the standard library so it can be used like a read-only
_get_object_size(bson_bytes, 0, len(bson_bytes))

@property
def raw(self) -> bytes:
def raw(self) -> bytes | memoryview:
"""The raw BSON bytes composing this document."""
return self.__raw

Expand All @@ -153,7 +157,7 @@ def __inflated(self) -> Mapping[str, Any]:

@staticmethod
def _inflate_bson(
bson_bytes: bytes, codec_options: CodecOptions[RawBSONDocument]
bson_bytes: bytes | memoryview, codec_options: CodecOptions[RawBSONDocument]
) -> Mapping[str, Any]:
return _inflate_bson(bson_bytes, codec_options)

Expand All @@ -180,7 +184,7 @@ class _RawArrayBSONDocument(RawBSONDocument):

@staticmethod
def _inflate_bson(
bson_bytes: bytes, codec_options: CodecOptions[RawBSONDocument]
bson_bytes: bytes | memoryview, codec_options: CodecOptions[RawBSONDocument]
) -> Mapping[str, Any]:
return _inflate_bson(bson_bytes, codec_options, raw_array=True)

Expand Down
2 changes: 1 addition & 1 deletion bson/son.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def popitem(self) -> Tuple[_Key, _Value]:
del self[k]
return (k, v)

def update(self, other: Optional[Any] = None, **kwargs: _Value) -> None: # type: ignore[override]
def update(self, other: Optional[Any] = None, **kwargs: _Value) -> None:
# Make progressively weaker assumptions about "other"
if other is None:
pass
Expand Down
2 changes: 1 addition & 1 deletion bson/typings.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@
_DocumentOut = Union[MutableMapping[str, Any], "RawBSONDocument"]
_DocumentType = TypeVar("_DocumentType", bound=Mapping[str, Any])
_DocumentTypeArg = TypeVar("_DocumentTypeArg", bound=Mapping[str, Any])
_ReadableBuffer = Union[bytes, memoryview, "mmap", "array"] # type: ignore[type-arg]
_ReadableBuffer = Union[bytes, memoryview, bytearray, "mmap", "array"] # type: ignore[type-arg]
2 changes: 1 addition & 1 deletion pymongo/asynchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,7 @@ def _deepcopy(
else:
if not isinstance(key, RE_TYPE):
key = copy.deepcopy(key, memo) # noqa: PLW2901
y[key] = value
y[key] = value # type:ignore[index]
return y

def _prepare_to_die(self, already_killed: bool) -> tuple[int, Optional[_CursorAddress]]:
Expand Down
6 changes: 3 additions & 3 deletions pymongo/asynchronous/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def spawn(self) -> None:
args.extend(self.opts._mongocryptd_spawn_args)
_spawn_daemon(args)

async def mark_command(self, database: str, cmd: bytes) -> bytes:
async def mark_command(self, database: str, cmd: bytes) -> bytes | memoryview:
"""Mark a command for encryption.

:param database: The database on which to run this command.
Expand All @@ -291,7 +291,7 @@ async def mark_command(self, database: str, cmd: bytes) -> bytes:
)
return res.raw

async def fetch_keys(self, filter: bytes) -> AsyncGenerator[bytes, None]:
async def fetch_keys(self, filter: bytes) -> AsyncGenerator[bytes | memoryview, None]:
"""Yields one or more keys from the key vault.

:param filter: The filter to pass to find.
Expand Down Expand Up @@ -463,7 +463,7 @@ async def encrypt(
# TODO: PYTHON-1922 avoid decoding the encrypted_cmd.
return _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS)

async def decrypt(self, response: bytes) -> Optional[bytes]:
async def decrypt(self, response: bytes | memoryview) -> Optional[bytes]:
"""Decrypt a MongoDB command response.

:param response: A MongoDB command response as BSON.
Expand Down
2 changes: 1 addition & 1 deletion pymongo/asynchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async def _getaddrinfo(
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
tuple[str, int] | tuple[str, int, int, int] | tuple[int, bytes],
]
]:
if not _IS_SYNC:
Expand Down
2 changes: 1 addition & 1 deletion pymongo/compression_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def compress(data: bytes) -> bytes:
return zstandard.ZstdCompressor().compress(data)


def decompress(data: bytes, compressor_id: int) -> bytes:
def decompress(data: bytes | memoryview, compressor_id: int) -> bytes:
if compressor_id == SnappyContext.compressor_id:
# python-snappy doesn't support the buffer interface.
# https://github.com/andrix/python-snappy/issues/65
Expand Down
16 changes: 9 additions & 7 deletions pymongo/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,15 +1352,17 @@ class _OpReply:
UNPACK_FROM = struct.Struct("<iqii").unpack_from
OP_CODE = 1

def __init__(self, flags: int, cursor_id: int, number_returned: int, documents: bytes):
def __init__(
self, flags: int, cursor_id: int, number_returned: int, documents: bytes | memoryview
):
self.flags = flags
self.cursor_id = Int64(cursor_id)
self.number_returned = number_returned
self.documents = documents

def raw_response(
self, cursor_id: Optional[int] = None, user_fields: Optional[Mapping[str, Any]] = None
) -> list[bytes]:
) -> list[bytes | memoryview]:
"""Check the response header from the database, without decoding BSON.

Check the response for errors and unpack.
Expand Down Expand Up @@ -1448,7 +1450,7 @@ def more_to_come(self) -> bool:
return False

@classmethod
def unpack(cls, msg: bytes) -> _OpReply:
def unpack(cls, msg: bytes | memoryview) -> _OpReply:
"""Construct an _OpReply from raw bytes."""
# PYTHON-945: ignore starting_from field.
flags, cursor_id, _, number_returned = cls.UNPACK_FROM(msg)
Expand All @@ -1470,7 +1472,7 @@ class _OpMsg:
MORE_TO_COME = 1 << 1
EXHAUST_ALLOWED = 1 << 16 # Only present on requests.

def __init__(self, flags: int, payload_document: bytes):
def __init__(self, flags: int, payload_document: bytes | memoryview):
self.flags = flags
self.payload_document = payload_document

Expand Down Expand Up @@ -1512,7 +1514,7 @@ def command_response(self, codec_options: CodecOptions[Any]) -> dict[str, Any]:
"""Unpack a command response."""
return self.unpack_response(codec_options=codec_options)[0]

def raw_command_response(self) -> bytes:
def raw_command_response(self) -> bytes | memoryview:
"""Return the bytes of the command response."""
return self.payload_document

Expand All @@ -1522,7 +1524,7 @@ def more_to_come(self) -> bool:
return bool(self.flags & self.MORE_TO_COME)

@classmethod
def unpack(cls, msg: bytes) -> _OpMsg:
def unpack(cls, msg: bytes | memoryview) -> _OpMsg:
"""Construct an _OpMsg from raw bytes."""
flags, first_payload_type, first_payload_size = cls.UNPACK_FROM(msg)
if flags != 0:
Expand All @@ -1541,7 +1543,7 @@ def unpack(cls, msg: bytes) -> _OpMsg:
return cls(flags, payload_document)


_UNPACK_REPLY: dict[int, Callable[[bytes], Union[_OpReply, _OpMsg]]] = {
_UNPACK_REPLY: dict[int, Callable[[bytes | memoryview], Union[_OpReply, _OpMsg]]] = {
_OpReply.OP_CODE: _OpReply.unpack,
_OpMsg.OP_CODE: _OpMsg.unpack,
}
Expand Down
9 changes: 5 additions & 4 deletions pymongo/network_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def sock(self) -> Union[socket.socket, _sslConn]:
def fileno(self) -> int:
return self.conn.fileno()

def recv_into(self, buffer: bytes) -> int:
def recv_into(self, buffer: bytes | memoryview) -> int:
return self.conn.recv_into(buffer)


Expand Down Expand Up @@ -533,14 +533,14 @@ def _resolve_pending(self, exc: Optional[Exception] = None) -> None:
fut = self._pending_listeners.popleft()
fut.set_result(b"")

def _read(self, bytes_needed: int) -> memoryview:
def _read(self, bytes_needed: int) -> bytes:
"""Read bytes."""
# Send the bytes to the listener.
if self._bytes_ready < bytes_needed:
bytes_needed = self._bytes_ready
self._bytes_ready -= bytes_needed

output_buf = bytearray(bytes_needed)
output_buf = memoryview(bytearray(bytes_needed))
n_remaining = bytes_needed
out_index = 0
while n_remaining > 0:
Expand All @@ -557,7 +557,7 @@ def _read(self, bytes_needed: int) -> memoryview:
output_buf[out_index : out_index + buf_size] = buffer[:]
out_index += buf_size
n_remaining -= buf_size
return memoryview(output_buf)
return bytes(output_buf)


async def async_sendall(conn: PyMongoBaseProtocol, buf: bytes) -> None:
Expand Down Expand Up @@ -670,6 +670,7 @@ def receive_message(
f"Message length ({length!r}) is larger than server max "
f"message size ({max_message_size!r})"
)
data: bytes | memoryview
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline))
data = decompress(receive_data(conn, length - 25, deadline), compressor_id)
Expand Down
2 changes: 1 addition & 1 deletion pymongo/synchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,7 +1007,7 @@ def _deepcopy(
else:
if not isinstance(key, RE_TYPE):
key = copy.deepcopy(key, memo) # noqa: PLW2901
y[key] = value
y[key] = value # type:ignore[index]
return y

def _prepare_to_die(self, already_killed: bool) -> tuple[int, Optional[_CursorAddress]]:
Expand Down
6 changes: 3 additions & 3 deletions pymongo/synchronous/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def spawn(self) -> None:
args.extend(self.opts._mongocryptd_spawn_args)
_spawn_daemon(args)

def mark_command(self, database: str, cmd: bytes) -> bytes:
def mark_command(self, database: str, cmd: bytes) -> bytes | memoryview:
"""Mark a command for encryption.

:param database: The database on which to run this command.
Expand All @@ -288,7 +288,7 @@ def mark_command(self, database: str, cmd: bytes) -> bytes:
)
return res.raw

def fetch_keys(self, filter: bytes) -> Generator[bytes, None]:
def fetch_keys(self, filter: bytes) -> Generator[bytes | memoryview, None]:
"""Yields one or more keys from the key vault.

:param filter: The filter to pass to find.
Expand Down Expand Up @@ -460,7 +460,7 @@ def encrypt(
# TODO: PYTHON-1922 avoid decoding the encrypted_cmd.
return _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS)

def decrypt(self, response: bytes) -> Optional[bytes]:
def decrypt(self, response: bytes | memoryview) -> Optional[bytes]:
"""Decrypt a MongoDB command response.

:param response: A MongoDB command response as BSON.
Expand Down
2 changes: 1 addition & 1 deletion pymongo/synchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _getaddrinfo(
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
tuple[str, int] | tuple[str, int, int, int] | tuple[int, bytes],
]
]:
if not _IS_SYNC:
Expand Down
4 changes: 2 additions & 2 deletions test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def test_typeddict_not_required_document_type(self) -> None:
# This should fail because the output is a Movie.
assert out["foo"] # type:ignore[typeddict-item]
# pyright gives reportTypedDictNotRequiredAccess for the following:
assert out["_id"] # type:ignore
assert out["_id"] # type:ignore[unused-ignore]

@only_type_check
def test_typeddict_empty_document_type(self) -> None:
Expand All @@ -496,7 +496,7 @@ def test_typeddict_find_notrequired(self):
out = coll.find_one({})
assert out is not None
# pyright gives reportTypedDictNotRequiredAccess for the following:
assert out["_id"] # type:ignore
assert out["_id"] # type:ignore[unused-ignore]

@only_type_check
def test_raw_bson_document_type(self) -> None:
Expand Down
6 changes: 4 additions & 2 deletions test/unified_format_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
LOCAL_MASTER_KEY,
)
from test.utils_shared import CMAPListener, camel_to_snake, parse_collection_options
from typing import Any, Union
from typing import Any, MutableMapping, Union

from bson import (
RE_TYPE,
Expand Down Expand Up @@ -162,7 +162,9 @@ def __new__(cls, name, this_bases, d):
return meta(name, resolved_bases, d)

@classmethod
def __prepare__(cls, name, this_bases):
def __prepare__(
cls, name: str, this_bases: tuple[type, ...], /, **kwds: Any
) -> MutableMapping[str, object]:
return meta.__prepare__(name, bases)

return type.__new__(metaclass, "temporary_class", (), {})
Expand Down
Loading
Loading