Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 60 additions & 19 deletions findmy/reports/anisette.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import asyncio
import base64
import locale
import logging
Expand Down Expand Up @@ -31,7 +32,7 @@ class LocalAnisetteMapping(TypedDict):
"""JSON mapping representing state of a local Anisette provider."""

type: Literal["aniLocal"]
prov_data: str
prov_data: str | None


AnisetteMapping = Union[RemoteAnisetteMapping, LocalAnisetteMapping]
Expand Down Expand Up @@ -286,7 +287,7 @@ class LocalAnisetteProvider(BaseAnisetteProvider, util.abc.Serializable[LocalAni
def __init__(
self,
*,
state_blob: BinaryIO | None = None,
state_blob: BytesIO | None = None,
libs_path: str | Path | None = None,
) -> None:
"""Initialize the provider."""
Expand All @@ -295,41 +296,75 @@ def __init__(
if isinstance(libs_path, str):
libs_path = Path(libs_path)

if libs_path is None or not libs_path.is_file():
# we do not yet initialize Anisette in order to prevent blocking the event loop,
# since the anisette library will download the required libraries synchronously.
self._ani: Anisette | None = None

self._ani_data: AnisetteHeaders | None = None
self._libs_path: Path | None = libs_path
self._state_blob: BytesIO | None = state_blob

@property
def _is_new_session(self) -> bool:
return self._state_blob is None

async def _get_ani(self) -> Anisette:
if self._ani is not None:
return self._ani

if self._libs_path is None or not self._libs_path.is_file():
logger.info(
"The Anisette engine will download libraries required for operation, "
"this may take a few seconds...",
)
if libs_path is None:
if self._libs_path is None:
logger.info(
"To speed up future local Anisette initializations, "
"provide a filesystem path to load the libraries from.",
)

files: list[BinaryIO | Path] = []
if state_blob is not None:
files.append(state_blob)
if libs_path is not None and libs_path.exists():
files.append(libs_path)
if self._state_blob is not None:
files.append(self._state_blob)
if self._libs_path is not None and self._libs_path.exists():
files.append(self._libs_path)

self._ani = Anisette.load(*files)
self._ani_data: AnisetteHeaders | None = None
self._libs_path: Path | None = libs_path
loop = asyncio.get_running_loop()
ani = await loop.run_in_executor(None, Anisette.load, *files)
is_provisioned = await loop.run_in_executor(None, lambda: ani.is_provisioned)

if self._libs_path is not None:
ani.save_libs(self._libs_path)

if libs_path is not None:
self._ani.save_libs(libs_path)
if state_blob is not None and not self._ani.is_provisioned:
if not self._is_new_session and not is_provisioned:
logger.warning(
"The Anisette state that was loaded has not yet been provisioned. "
"Was the previous session saved properly?",
)

# pre-provision to ensure that the VM has initialized
await loop.run_in_executor(None, ani.provision)

self._ani = ani
return ani

@override
def to_json(self, dst: str | Path | None = None, /) -> LocalAnisetteMapping:
"""See :meth:`BaseAnisetteProvider.serialize`."""
with BytesIO() as buf:
self._ani.save_provisioning(buf)
prov_data = base64.b64encode(buf.getvalue()).decode("utf-8")
if self._ani is None:
# Anisette has not been called yet, so the future has not yet resolved.
# We don't want to wait here, so we just return the original state blob.
# If the state blob is None, this means we have a new session that has not
# been provisioned yet, so we will not save the provisioning data.
if self._state_blob is None:
prov_data = None
else:
prov_data = base64.b64encode(self._state_blob.getvalue()).decode("utf-8")
else:
# Anisette has been initialized, so we can save the provisioning data.
with BytesIO() as buf:
self._ani.save_provisioning(buf)
prov_data = base64.b64encode(buf.getvalue()).decode("utf-8")

return util.files.save_and_return_json(
{
Expand All @@ -352,7 +387,8 @@ def from_json(

assert val["type"] == "aniLocal"

state_blob = BytesIO(base64.b64decode(val["prov_data"]))
prov_data = val["prov_data"]
state_blob = None if prov_data is None else BytesIO(base64.b64decode(prov_data))

return cls(state_blob=state_blob, libs_path=libs_path)

Expand All @@ -365,7 +401,12 @@ async def get_headers(
with_client_info: bool = False,
) -> dict[str, str]:
"""See :meth:`BaseAnisetteProvider.get_headers`."""
self._ani_data = self._ani.get_data()
ani = await self._get_ani()

# run in executor to prevent blocking the event loop,
# since get_data may make blocking network requests.
loop = asyncio.get_running_loop()
self._ani_data = await loop.run_in_executor(None, ani.get_data)

return await super().get_headers(user_id, device_id, serial, with_client_info)

Expand Down