Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions findmy/accessory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .util import crypto

if TYPE_CHECKING:
import io
from collections.abc import Generator
from pathlib import Path

Expand Down Expand Up @@ -269,7 +270,7 @@ def keys_at(self, ind: int) -> set[KeyPair]:
@classmethod
def from_plist(
cls,
plist: str | Path | dict | bytes,
plist: str | Path | dict | bytes | io.BufferedIOBase,
key_alignment_plist: str | Path | dict | bytes | None = None,
*,
name: str | None = None,
Expand Down Expand Up @@ -322,7 +323,7 @@ def from_plist(
)

@override
def to_json(self, path: str | Path | None = None, /) -> FindMyAccessoryMapping:
def to_json(self, path: str | Path | io.TextIOBase | None = None, /) -> FindMyAccessoryMapping:
alignment_date = None
if self._alignment_date is not None:
alignment_date = self._alignment_date.isoformat()
Expand All @@ -346,7 +347,7 @@ def to_json(self, path: str | Path | None = None, /) -> FindMyAccessoryMapping:
@override
def from_json(
cls,
val: str | Path | FindMyAccessoryMapping,
val: str | Path | io.TextIOBase | io.BufferedIOBase | FindMyAccessoryMapping,
/,
) -> FindMyAccessory:
val = util.files.read_data_json(val)
Expand Down
7 changes: 5 additions & 2 deletions findmy/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .util import crypto, parsers

if TYPE_CHECKING:
import io
from collections.abc import Generator
from pathlib import Path

Expand Down Expand Up @@ -197,7 +198,7 @@ def adv_key_bytes(self) -> bytes:
return int.to_bytes(key_bytes, 28, "big")

@override
def to_json(self, dst: str | Path | None = None, /) -> KeyPairMapping:
def to_json(self, dst: str | Path | io.TextIOBase | None = None, /) -> KeyPairMapping:
return save_and_return_json(
{
"type": "keypair",
Expand All @@ -210,7 +211,9 @@ def to_json(self, dst: str | Path | None = None, /) -> KeyPairMapping:

@classmethod
@override
def from_json(cls, val: str | Path | KeyPairMapping, /) -> KeyPair:
def from_json(
cls, val: str | Path | io.TextIOBase | io.BufferedIOBase | KeyPairMapping, /
) -> KeyPair:
val = read_data_json(val)
assert val["type"] == "keypair"

Expand Down
7 changes: 4 additions & 3 deletions findmy/reports/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
)

if TYPE_CHECKING:
import io
from collections.abc import Sequence
from pathlib import Path

Expand Down Expand Up @@ -433,7 +434,7 @@ def last_name(self) -> str | None:
return self._account_info["last_name"] if self._account_info else None

@override
def to_json(self, path: str | Path | None = None, /) -> AccountStateMapping:
def to_json(self, path: str | Path | io.TextIOBase | None = None, /) -> AccountStateMapping:
res: AccountStateMapping = {
"type": "account",
"ids": {"uid": self._uid, "devid": self._devid},
Expand All @@ -455,7 +456,7 @@ def to_json(self, path: str | Path | None = None, /) -> AccountStateMapping:
@override
def from_json(
cls,
val: str | Path | AccountStateMapping,
val: str | Path | io.TextIOBase | io.BufferedIOBase | AccountStateMapping,
/,
*,
anisette_libs_path: str | Path | None = None,
Expand Down Expand Up @@ -1048,7 +1049,7 @@ def to_json(self, dst: str | Path | None = None, /) -> AccountStateMapping:
@override
def from_json(
cls,
val: str | Path | AccountStateMapping,
val: str | Path | io.TextIOBase | io.BufferedIOBase | AccountStateMapping,
/,
*,
anisette_libs_path: str | Path | None = None,
Expand Down
11 changes: 7 additions & 4 deletions findmy/reports/anisette.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import asyncio
import base64
import io
import locale
import logging
import time
Expand Down Expand Up @@ -205,7 +206,7 @@ def __init__(self, server_url: str) -> None:
self._closed = False

@override
def to_json(self, dst: str | Path | None = None, /) -> RemoteAnisetteMapping:
def to_json(self, dst: str | Path | io.TextIOBase | None = None, /) -> RemoteAnisetteMapping:
"""See :meth:`BaseAnisetteProvider.serialize`."""
return util.files.save_and_return_json(
{
Expand All @@ -217,7 +218,9 @@ def to_json(self, dst: str | Path | None = None, /) -> RemoteAnisetteMapping:

@classmethod
@override
def from_json(cls, val: str | Path | RemoteAnisetteMapping) -> RemoteAnisetteProvider:
def from_json(
cls, val: str | Path | io.TextIOBase | io.BufferedIOBase | RemoteAnisetteMapping
) -> RemoteAnisetteProvider:
"""See :meth:`BaseAnisetteProvider.deserialize`."""
val = util.files.read_data_json(val)

Expand Down Expand Up @@ -349,7 +352,7 @@ async def _get_ani(self) -> Anisette:
return ani

@override
def to_json(self, dst: str | Path | None = None, /) -> LocalAnisetteMapping:
def to_json(self, dst: str | Path | io.TextIOBase | None = None, /) -> LocalAnisetteMapping:
"""See :meth:`BaseAnisetteProvider.serialize`."""
if self._ani is None:
# Anisette has not been called yet, so the future has not yet resolved.
Expand Down Expand Up @@ -378,7 +381,7 @@ def to_json(self, dst: str | Path | None = None, /) -> LocalAnisetteMapping:
@override
def from_json(
cls,
val: str | Path | LocalAnisetteMapping,
val: str | Path | io.TextIOBase | io.BufferedIOBase | LocalAnisetteMapping,
*,
libs_path: str | Path | None = None,
) -> LocalAnisetteProvider:
Expand Down
13 changes: 8 additions & 5 deletions findmy/reports/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from findmy.keys import HasHashedPublicKey, KeyPair, KeyPairMapping, KeyPairType

if TYPE_CHECKING:
import io
from collections.abc import Sequence
from pathlib import Path

Expand Down Expand Up @@ -199,7 +200,7 @@ def status(self) -> int:
@overload
def to_json(
self,
dst: str | Path | None = None,
dst: str | Path | io.TextIOBase | None = None,
/,
*,
include_key: Literal[True],
Expand All @@ -209,7 +210,7 @@ def to_json(
@overload
def to_json(
self,
dst: str | Path | None = None,
dst: str | Path | io.TextIOBase | None = None,
/,
*,
include_key: Literal[False],
Expand All @@ -219,7 +220,7 @@ def to_json(
@overload
def to_json(
self,
dst: str | Path | None = None,
dst: str | Path | io.TextIOBase | None = None,
/,
*,
include_key: None = None,
Expand All @@ -229,7 +230,7 @@ def to_json(
@override
def to_json(
self,
dst: str | Path | None = None,
dst: str | Path | io.TextIOBase | None = None,
/,
*,
include_key: bool | None = None,
Expand Down Expand Up @@ -258,7 +259,9 @@ def to_json(

@classmethod
@override
def from_json(cls, val: str | Path | LocationReportMapping, /) -> LocationReport:
def from_json(
cls, val: str | Path | io.TextIOBase | io.BufferedIOBase | LocationReportMapping, /
) -> LocationReport:
val = util.files.read_data_json(val)
assert val["type"] == "locReportEncrypted" or val["type"] == "locReportDecrypted"

Expand Down
25 changes: 19 additions & 6 deletions findmy/util/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import io
import json
import plistlib
from collections.abc import Mapping
Expand All @@ -11,44 +12,53 @@
_T = TypeVar("_T", bound=Mapping)


def save_and_return_json(data: _T, dst: str | Path | None) -> _T:
def save_and_return_json(data: _T, dst: str | Path | io.TextIOBase | None) -> _T:
"""Save and return a JSON-serializable data structure."""
if dst is None:
return data

if isinstance(dst, str):
dst = Path(dst)

dst.write_text(json.dumps(data, indent=4))
if isinstance(dst, io.IOBase):
json.dump(data, dst, indent=4)
elif isinstance(dst, Path):
dst.write_text(json.dumps(data, indent=4))

return data


def read_data_json(val: str | Path | _T) -> _T:
def read_data_json(val: str | Path | io.TextIOBase | io.BufferedIOBase | _T) -> _T:
"""Read JSON data from a file if a path is passed, or return the argument itself."""
if isinstance(val, str):
val = Path(val)

if isinstance(val, Path):
val = cast("_T", json.loads(val.read_text()))

if isinstance(val, io.IOBase):
val = cast("_T", json.load(val))

return val


def save_and_return_plist(data: _T, dst: str | Path | None) -> _T:
def save_and_return_plist(data: _T, dst: str | Path | io.BufferedIOBase | None) -> _T:
"""Save and return a Plist file."""
if dst is None:
return data

if isinstance(dst, str):
dst = Path(dst)

dst.write_bytes(plistlib.dumps(data))
if isinstance(dst, io.IOBase):
dst.write(plistlib.dumps(data))
elif isinstance(dst, Path):
dst.write_bytes(plistlib.dumps(data))

return data


def read_data_plist(val: str | Path | _T | bytes) -> _T:
def read_data_plist(val: str | Path | io.BufferedIOBase | _T | bytes) -> _T:
"""Read Plist data from a file if a path is passed, or return the argument itself."""
if isinstance(val, str):
val = Path(val)
Expand All @@ -59,4 +69,7 @@ def read_data_plist(val: str | Path | _T | bytes) -> _T:
if isinstance(val, bytes):
val = cast("_T", plistlib.loads(val))

if isinstance(val, io.IOBase):
val = cast("_T", plistlib.loads(val.read()))

return val