diff --git a/findmy/accessory.py b/findmy/accessory.py index 16f699d..f859568 100644 --- a/findmy/accessory.py +++ b/findmy/accessory.py @@ -18,6 +18,7 @@ from .util import crypto if TYPE_CHECKING: + import io from collections.abc import Generator from pathlib import Path @@ -269,7 +270,7 @@ def keys_at(self, ind: int) -> set[KeyPair]: @classmethod def from_plist( cls, - plist: str | Path | dict | bytes, + plist: str | Path | dict | bytes | io.BufferedIOBase, key_alignment_plist: str | Path | dict | bytes | None = None, *, name: str | None = None, @@ -322,7 +323,7 @@ def from_plist( ) @override - def to_json(self, path: str | Path | None = None, /) -> FindMyAccessoryMapping: + def to_json(self, path: str | Path | io.TextIOBase | None = None, /) -> FindMyAccessoryMapping: alignment_date = None if self._alignment_date is not None: alignment_date = self._alignment_date.isoformat() @@ -346,7 +347,7 @@ def to_json(self, path: str | Path | None = None, /) -> FindMyAccessoryMapping: @override def from_json( cls, - val: str | Path | FindMyAccessoryMapping, + val: str | Path | io.TextIOBase | io.BufferedIOBase | FindMyAccessoryMapping, /, ) -> FindMyAccessory: val = util.files.read_data_json(val) diff --git a/findmy/keys.py b/findmy/keys.py index af584da..d5418e2 100644 --- a/findmy/keys.py +++ b/findmy/keys.py @@ -18,6 +18,7 @@ from .util import crypto, parsers if TYPE_CHECKING: + import io from collections.abc import Generator from pathlib import Path @@ -197,7 +198,7 @@ def adv_key_bytes(self) -> bytes: return int.to_bytes(key_bytes, 28, "big") @override - def to_json(self, dst: str | Path | None = None, /) -> KeyPairMapping: + def to_json(self, dst: str | Path | io.TextIOBase | None = None, /) -> KeyPairMapping: return save_and_return_json( { "type": "keypair", @@ -210,7 +211,9 @@ def to_json(self, dst: str | Path | None = None, /) -> KeyPairMapping: @classmethod @override - def from_json(cls, val: str | Path | KeyPairMapping, /) -> KeyPair: + def from_json( + cls, val: str | Path | io.TextIOBase | io.BufferedIOBase | KeyPairMapping, / + ) -> KeyPair: val = read_data_json(val) assert val["type"] == "keypair" diff --git a/findmy/reports/account.py b/findmy/reports/account.py index 7cbd571..ce731e7 100644 --- a/findmy/reports/account.py +++ b/findmy/reports/account.py @@ -48,6 +48,7 @@ ) if TYPE_CHECKING: + import io from collections.abc import Sequence from pathlib import Path @@ -433,7 +434,7 @@ def last_name(self) -> str | None: return self._account_info["last_name"] if self._account_info else None @override - def to_json(self, path: str | Path | None = None, /) -> AccountStateMapping: + def to_json(self, path: str | Path | io.TextIOBase | None = None, /) -> AccountStateMapping: res: AccountStateMapping = { "type": "account", "ids": {"uid": self._uid, "devid": self._devid}, @@ -455,7 +456,7 @@ def to_json(self, path: str | Path | None = None, /) -> AccountStateMapping: @override def from_json( cls, - val: str | Path | AccountStateMapping, + val: str | Path | io.TextIOBase | io.BufferedIOBase | AccountStateMapping, /, *, anisette_libs_path: str | Path | None = None, @@ -1048,7 +1049,7 @@ def to_json(self, dst: str | Path | None = None, /) -> AccountStateMapping: @override def from_json( cls, - val: str | Path | AccountStateMapping, + val: str | Path | io.TextIOBase | io.BufferedIOBase | AccountStateMapping, /, *, anisette_libs_path: str | Path | None = None, diff --git a/findmy/reports/anisette.py b/findmy/reports/anisette.py index e7c7341..2f22dc6 100644 --- a/findmy/reports/anisette.py +++ b/findmy/reports/anisette.py @@ -4,6 +4,7 @@ import asyncio import base64 +import io import locale import logging import time @@ -205,7 +206,7 @@ def __init__(self, server_url: str) -> None: self._closed = False @override - def to_json(self, dst: str | Path | None = None, /) -> RemoteAnisetteMapping: + def to_json(self, dst: str | Path | io.TextIOBase | None = None, /) -> RemoteAnisetteMapping: """See :meth:`BaseAnisetteProvider.serialize`.""" return util.files.save_and_return_json( { @@ -217,7 +218,9 @@ def to_json(self, dst: str | Path | None = None, /) -> RemoteAnisetteMapping: @classmethod @override - def from_json(cls, val: str | Path | RemoteAnisetteMapping) -> RemoteAnisetteProvider: + def from_json( + cls, val: str | Path | io.TextIOBase | io.BufferedIOBase | RemoteAnisetteMapping + ) -> RemoteAnisetteProvider: """See :meth:`BaseAnisetteProvider.deserialize`.""" val = util.files.read_data_json(val) @@ -349,7 +352,7 @@ async def _get_ani(self) -> Anisette: return ani @override - def to_json(self, dst: str | Path | None = None, /) -> LocalAnisetteMapping: + def to_json(self, dst: str | Path | io.TextIOBase | None = None, /) -> LocalAnisetteMapping: """See :meth:`BaseAnisetteProvider.serialize`.""" if self._ani is None: # Anisette has not been called yet, so the future has not yet resolved. @@ -378,7 +381,7 @@ def to_json(self, dst: str | Path | None = None, /) -> LocalAnisetteMapping: @override def from_json( cls, - val: str | Path | LocalAnisetteMapping, + val: str | Path | io.TextIOBase | io.BufferedIOBase | LocalAnisetteMapping, *, libs_path: str | Path | None = None, ) -> LocalAnisetteProvider: diff --git a/findmy/reports/reports.py b/findmy/reports/reports.py index cec9eb1..3c9c162 100644 --- a/findmy/reports/reports.py +++ b/findmy/reports/reports.py @@ -21,6 +21,7 @@ from findmy.keys import HasHashedPublicKey, KeyPair, KeyPairMapping, KeyPairType if TYPE_CHECKING: + import io from collections.abc import Sequence from pathlib import Path @@ -199,7 +200,7 @@ def status(self) -> int: @overload def to_json( self, - dst: str | Path | None = None, + dst: str | Path | io.TextIOBase | None = None, /, *, include_key: Literal[True], @@ -209,7 +210,7 @@ def to_json( @overload def to_json( self, - dst: str | Path | None = None, + dst: str | Path | io.TextIOBase | None = None, /, *, include_key: Literal[False], @@ -219,7 +220,7 @@ def to_json( @overload def to_json( self, - dst: str | Path | None = None, + dst: str | Path | io.TextIOBase | None = None, /, *, include_key: None = None, @@ -229,7 +230,7 @@ def to_json( @override def to_json( self, - dst: str | Path | None = None, + dst: str | Path | io.TextIOBase | None = None, /, *, include_key: bool | None = None, @@ -258,7 +259,9 @@ def to_json( @classmethod @override - def from_json(cls, val: str | Path | LocationReportMapping, /) -> LocationReport: + def from_json( + cls, val: str | Path | io.TextIOBase | io.BufferedIOBase | LocationReportMapping, / + ) -> LocationReport: val = util.files.read_data_json(val) assert val["type"] == "locReportEncrypted" or val["type"] == "locReportDecrypted" diff --git a/findmy/util/files.py b/findmy/util/files.py index a366f5a..b34170f 100644 --- a/findmy/util/files.py +++ b/findmy/util/files.py @@ -2,6 +2,7 @@ from __future__ import annotations +import io import json import plistlib from collections.abc import Mapping @@ -11,7 +12,7 @@ _T = TypeVar("_T", bound=Mapping) -def save_and_return_json(data: _T, dst: str | Path | None) -> _T: +def save_and_return_json(data: _T, dst: str | Path | io.TextIOBase | None) -> _T: """Save and return a JSON-serializable data structure.""" if dst is None: return data @@ -19,12 +20,15 @@ def save_and_return_json(data: _T, dst: str | Path | None) -> _T: if isinstance(dst, str): dst = Path(dst) - dst.write_text(json.dumps(data, indent=4)) + if isinstance(dst, io.IOBase): + json.dump(data, dst, indent=4) + elif isinstance(dst, Path): + dst.write_text(json.dumps(data, indent=4)) return data -def read_data_json(val: str | Path | _T) -> _T: +def read_data_json(val: str | Path | io.TextIOBase | io.BufferedIOBase | _T) -> _T: """Read JSON data from a file if a path is passed, or return the argument itself.""" if isinstance(val, str): val = Path(val) @@ -32,10 +36,13 @@ def read_data_json(val: str | Path | _T) -> _T: if isinstance(val, Path): val = cast("_T", json.loads(val.read_text())) + if isinstance(val, io.IOBase): + val = cast("_T", json.load(val)) + return val -def save_and_return_plist(data: _T, dst: str | Path | None) -> _T: +def save_and_return_plist(data: _T, dst: str | Path | io.BufferedIOBase | None) -> _T: """Save and return a Plist file.""" if dst is None: return data @@ -43,12 +50,15 @@ def save_and_return_plist(data: _T, dst: str | Path | None) -> _T: if isinstance(dst, str): dst = Path(dst) - dst.write_bytes(plistlib.dumps(data)) + if isinstance(dst, io.IOBase): + dst.write(plistlib.dumps(data)) + elif isinstance(dst, Path): + dst.write_bytes(plistlib.dumps(data)) return data -def read_data_plist(val: str | Path | _T | bytes) -> _T: +def read_data_plist(val: str | Path | io.BufferedIOBase | _T | bytes) -> _T: """Read Plist data from a file if a path is passed, or return the argument itself.""" if isinstance(val, str): val = Path(val) @@ -59,4 +69,7 @@ def read_data_plist(val: str | Path | _T | bytes) -> _T: if isinstance(val, bytes): val = cast("_T", plistlib.loads(val)) + if isinstance(val, io.IOBase): + val = cast("_T", plistlib.loads(val.read())) + return val