Skip to content

Commit

Permalink
Use generic msgspec structs to fix variance issues
Browse files Browse the repository at this point in the history
  • Loading branch information
gwax committed Jul 13, 2023
1 parent 5ae6182 commit 7b92ffd
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 24 deletions.
29 changes: 16 additions & 13 deletions mtg_ssm/scryfall/fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import gzip
import os
from typing import List, cast
from typing import List

import appdirs
import msgspec
Expand Down Expand Up @@ -54,24 +54,27 @@ def _fetch_endpoint(endpoint: str) -> bytes:
def scryfetch() -> ScryfallDataSet: # pylint: disable=too-many-locals
"""Retrieve and deserialize Scryfall object data."""
print("Reading data from scryfall")
scrylist_decoder = msgspec.json.Decoder(ScryList)
bulk_data = msgspec.json.decode(
_fetch_endpoint(BULK_DATA_ENDPOINT), type=ScryList[ScryBulkData]
).data

bulk_data_list = scrylist_decoder.decode(_fetch_endpoint(BULK_DATA_ENDPOINT))
bulk_data = cast(List[ScryBulkData], bulk_data_list.data)

sets_list = scrylist_decoder.decode(_fetch_endpoint(SETS_ENDPOINT))
sets_data = cast(List[ScrySet], sets_list.data)
scrylistset_decoder = msgspec.json.Decoder(ScryList[ScrySet])
sets_list = scrylistset_decoder.decode(_fetch_endpoint(SETS_ENDPOINT))
sets_data = sets_list.data
while sets_list.has_more and sets_list.next_page is not None:
sets_list = scrylist_decoder.decode(_fetch_endpoint(sets_list.next_page))
sets_data += cast(List[ScrySet], sets_list.data)
sets_list = scrylistset_decoder.decode(_fetch_endpoint(sets_list.next_page))
sets_data += sets_list.data

migrations_list = scrylist_decoder.decode(_fetch_endpoint(MIGRATIONS_ENDPOINT))
migrations_data = cast(List[ScryMigration], migrations_list.data)
scrylistmigration_decoder = msgspec.json.Decoder(ScryList[ScryMigration])
migrations_list = scrylistmigration_decoder.decode(
_fetch_endpoint(MIGRATIONS_ENDPOINT)
)
migrations_data = migrations_list.data
while migrations_list.has_more and migrations_list.next_page is not None:
migrations_list = scrylist_decoder.decode(
migrations_list = scrylistmigration_decoder.decode(
_fetch_endpoint(migrations_list.next_page)
)
migrations_data += cast(List[ScryMigration], migrations_list.data)
migrations_data += migrations_list.data

[cards_endpoint] = [bd.download_uri for bd in bulk_data if bd.type == BULK_TYPE]
cards_data = msgspec.json.decode(
Expand Down
16 changes: 14 additions & 2 deletions mtg_ssm/scryfall/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import datetime as dt
from decimal import Decimal
from enum import Enum
from typing import Dict, List, Literal, Optional, Union
from typing import Dict, Generic, List, Literal, Optional, TypeVar, Union
from uuid import UUID

from msgspec import Struct
from typing_extensions import TypeAlias


class ScryColor(str, Enum):
Expand Down Expand Up @@ -451,16 +452,27 @@ class ScryMigration(
note: Optional[str] = None


ScryListable: TypeAlias = Union[
ScryBulkData,
ScryCard,
ScryMigration,
ScrySet,
]

_ScryListableT = TypeVar("_ScryListableT", bound=ScryListable)


class ScryList(
Struct,
Generic[_ScryListableT],
tag_field="object",
tag="list",
kw_only=True,
omit_defaults=True,
):
"""Model for https://scryfall.com/docs/api/lists"""

data: List[Union[ScrySet, ScryCard, ScryBulkData, ScryMigration]]
data: List[_ScryListableT]
has_more: bool
next_page: Optional[str] = None
total_cards: Optional[int] = None
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ classifiers = [

dependencies = [
"appdirs~=1.4",
"msgspec~=0.13",
"msgspec~=0.15",
"openpyxl~=3.0",
"requests~=2.27",
'requests-cache~=0.9.8',
Expand Down
10 changes: 5 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# pylint: disable=redefined-outer-name

from pathlib import Path
from typing import Dict, Generator, List, cast
from typing import Dict, Generator, List
from uuid import UUID

import msgspec
Expand Down Expand Up @@ -44,18 +44,18 @@ def cards_data() -> List[ScryCard]:
def sets_data() -> List[ScrySet]:
"""Fixture containing all test set data."""
with SETS_DATA_FILE.open("rb") as sets_data_file:
sets_list = msgspec.json.decode(sets_data_file.read(), type=ScryList)
return cast(List[ScrySet], sets_list.data)
sets_list = msgspec.json.decode(sets_data_file.read(), type=ScryList[ScrySet])
return sets_list.data


@pytest.fixture(scope="session")
def migrations_data() -> List[ScryMigration]:
"""Fixture containing all test migrations data."""
with MIGRATIONS_DATA_FILE.open("rb") as migrations_data_file:
migrations_list = msgspec.json.decode(
migrations_data_file.read(), type=ScryList
migrations_data_file.read(), type=ScryList[ScryMigration]
)
return cast(List[ScryMigration], migrations_list.data)
return migrations_list.data


@pytest.fixture(scope="session")
Expand Down
6 changes: 3 additions & 3 deletions tests/gen_testdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import copy
from pathlib import Path
from typing import List, cast

import msgspec

Expand Down Expand Up @@ -85,9 +84,10 @@ def main() -> None: # pylint: disable=too-many-locals,too-many-statements
print("Fetching scryfall data")
scrydata = fetcher.scryfetch()
bulk_data_list = msgspec.json.decode(
fetcher._fetch_endpoint(fetcher.BULK_DATA_ENDPOINT), type=models.ScryList
fetcher._fetch_endpoint(fetcher.BULK_DATA_ENDPOINT),
type=models.ScryList[models.ScryBulkData],
)
bulk_data = cast(List[models.ScryBulkData], bulk_data_list.data)
bulk_data = bulk_data_list.data

print("Selecting sets")
accepted_sets = sorted(
Expand Down

0 comments on commit 7b92ffd

Please sign in to comment.