Skip to content

Commit

Permalink
refactor!: SQLAlchemy v2 and mypy strict typings / annotations (#311)
Browse files Browse the repository at this point in the history
  • Loading branch information
tony committed Jun 10, 2023
2 parents 3db49b3 + 8eaa5ea commit 4a9499b
Show file tree
Hide file tree
Showing 10 changed files with 428 additions and 268 deletions.
24 changes: 24 additions & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,31 @@ $ pip install --user --upgrade --pre unihan-db

<!-- Maintainers, insert changes / features for the next release here -->

## Development

- **Improved typings**

Move to strict mypy typings (#311)

This will make future refactoring simplifications easier and maintain code
quality in the long term, in addition to more intelligent completions.

[`mypy --strict`]: https://mypy.readthedocs.io/en/stable/command_line.html#cmdoption-mypy-strict

### Breaking

- SQLAlchemy: Upgraded to v2 (#311)

Downstream packages will require SQLAlchemy v2 at a minimum.

Benefits in include: Built-in types for mypy, being able to use SQLAlchemy
core API against ORM entities.

See also: [What's new in SQLAlchemy
2.0](https://docs.sqlalchemy.org/en/20/changelog/whatsnew_20.html),
[Migrating to SQLAlchemy
2.0](https://docs.sqlalchemy.org/en/20/changelog/migration_20.html)

- **Python 3.7 Dropped**

Python 3.7 support has been dropped (#309)
Expand All @@ -39,6 +62,7 @@ _Maintenance only, no bug fixes or features_
formatting can be done almost instantly.

This change replaces isort, flake8 and flake8 plugins.

- poetry: 1.4.0 -> 1.5.0

See also: https://github.com/python-poetry/poetry/releases/tag/1.5.0
Expand Down
7 changes: 5 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@
}


def linkcode_resolve(domain, info): # NOQA: C901
def linkcode_resolve(
domain: str, info: t.Dict[str, str]
) -> t.Union[None, str]: # NOQA: C901
"""
Determine the URL corresponding to Python object
Expand Down Expand Up @@ -204,7 +206,8 @@ def linkcode_resolve(domain, info): # NOQA: C901
except AttributeError:
pass
else:
obj = unwrap(obj)
if callable(obj):
obj = unwrap(obj)

try:
fn = inspect.getsourcefile(obj)
Expand Down
8 changes: 5 additions & 3 deletions examples/01_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@

bootstrap.bootstrap_unihan(session)

random_row = session.query(Unhn).order_by(func.random()).limit(1).first()
random_row_query = session.query(Unhn).order_by(func.random()).limit(1)

pp = pprint.PrettyPrinter(indent=0)
assert random_row_query is not None

pp.pprint(random_row.to_dict())
random_row = random_row_query.first()

pprint.pprint(bootstrap.to_dict(random_row))
124 changes: 56 additions & 68 deletions poetry.lock

Large diffs are not rendered by default.

8 changes: 2 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ Repository = "https://github.com/cihai/unihan-db"
[tool.poetry.dependencies]
python = "^3.8"
appdirs = "*"
SQLAlchemy = "<2"
SQLAlchemy = ">=2"
unihan-etl = "~=0.19.1"

[tool.poetry.group.dev.dependencies]
Expand Down Expand Up @@ -81,7 +81,6 @@ pytest-cov = "*"
black = "*"
ruff = "*"
mypy = "*"
sqlalchemy-stubs = "*"
types-appdirs = "*"

[tool.poetry.extras]
Expand All @@ -104,14 +103,11 @@ lint = [
"black",
"ruff",
"mypy",
"sqlalchemy-stubs",
"types-appdirs",
]

[tool.mypy]
plugins = [
"sqlmypy",
]
strict = true
files = [
"src/",
"tests/",
Expand Down
76 changes: 51 additions & 25 deletions src/unihan_db/bootstrap.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,33 @@
import logging
import sys
from datetime import datetime

import typing as t
from sqlalchemy import create_engine, event
from sqlalchemy.orm import class_mapper, mapper, scoped_session, sessionmaker

import sqlalchemy
from sqlalchemy.orm import Session, class_mapper, scoped_session, sessionmaker
from sqlalchemy.orm.decl_api import registry
from sqlalchemy.orm.scoping import ScopedSession
from unihan_etl import process as unihan
from unihan_etl.types import UntypedUnihanData
from unihan_etl.util import merge_dict

from . import dirs, importer
from .tables import Base, Unhn

log = logging.getLogger(__name__)

mapper_reg = registry()

if t.TYPE_CHECKING:
from unihan_etl.types import (
UntypedNormalizedData,
)

def setup_logger(logger=None, level="INFO"):

def setup_logger(
logger: t.Optional[logging.Logger] = None,
level: t.Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO",
) -> None:
"""
Setup logging for CLI use.
Expand Down Expand Up @@ -119,7 +132,7 @@ def setup_logger(logger=None, level="INFO"):
DEFAULT_FIELDS = ["ucn", "char"]


def is_bootstrapped(metadata):
def is_bootstrapped(metadata: sqlalchemy.MetaData) -> bool:
"""Return True if cihai is correctly bootstrapped."""
fields = UNIHAN_FIELDS + DEFAULT_FIELDS
if TABLE_NAME in metadata.tables.keys():
Expand All @@ -133,31 +146,38 @@ def is_bootstrapped(metadata):
return False


def bootstrap_data(options=None):
def bootstrap_data(
options: t.Union[UntypedUnihanData, None] = None,
) -> t.Optional["UntypedNormalizedData"]:
if options is None:
options = {}
_options = options

options = merge_dict(UNIHAN_ETL_DEFAULT_OPTIONS.copy(), options)
_options = merge_dict(UNIHAN_ETL_DEFAULT_OPTIONS.copy(), _options)

p = unihan.Packager(options)
p = unihan.Packager(_options)
p.download()
return p.export()


def bootstrap_unihan(session, options=None):
if options is None:
options = {}
def bootstrap_unihan(
session: t.Union[Session, ScopedSession[t.Any]],
options: t.Optional[UntypedUnihanData] = None,
) -> None:
_options = options if options is not None else {}

"""Download, extract and import unihan to database."""
if session.query(Unhn).count() == 0:
data = bootstrap_data(options)
data = bootstrap_data(_options)
assert data is not None
log.info("bootstrap Unhn table")
log.info("bootstrap Unhn table finished")
count = 0
total_count = len(data)
items = []

for char in data:
assert isinstance(char, dict)
c = Unhn(char=char["char"], ucn=char["ucn"])
importer.import_char(c, char)
items.append(c)
Expand All @@ -177,7 +197,7 @@ def bootstrap_unihan(session, options=None):
log.info("Done adding rows.")


def to_dict(obj, found=None):
def to_dict(obj: t.Any, found: t.Optional[t.Set[t.Any]] = None) -> t.Dict[str, object]:
"""
Return dictionary of an SQLAlchemy Query result.
Expand All @@ -196,21 +216,26 @@ def to_dict(obj, found=None):
dictionary representation of a SQLAlchemy query
"""

def _get_key_value(c):
def _get_key_value(c: str) -> t.Any:
if isinstance(getattr(obj, c), datetime):
return (c, getattr(obj, c).isoformat())
else:
return (c, getattr(obj, c))

_found: t.Set[t.Any]

if found is None:
found = set()
mapper = class_mapper(obj.__class__)
columns = [column.key for column in mapper.columns]
_found = set()
else:
_found = found

_mapper = class_mapper(obj.__class__)
columns = [column.key for column in _mapper.columns]

result = dict(map(_get_key_value, columns))
for name, relation in mapper.relationships.items():
if relation not in found:
found.add(relation)
for name, relation in _mapper.relationships.items():
if relation not in _found:
_found.add(relation)
related_obj = getattr(obj, name)
if related_obj is not None:
if relation.uselist:
Expand All @@ -220,7 +245,7 @@ def _get_key_value(c):
return result


def add_to_dict(b):
def add_to_dict(b: t.Any) -> t.Any:
"""
Add :func:`.to_dict` method to SQLAlchemy Base object.
Expand All @@ -233,7 +258,9 @@ def add_to_dict(b):
return b


def get_session(engine_url="sqlite:///{user_data_dir}/unihan_db.db"):
def get_session(
engine_url: str = "sqlite:///{user_data_dir}/unihan_db.db",
) -> "ScopedSession[t.Any]":
"""
Return new SQLAlchemy session object from engine string.
Expand All @@ -251,9 +278,8 @@ def get_session(engine_url="sqlite:///{user_data_dir}/unihan_db.db"):
engine_url = engine_url.format(**{"user_data_dir": dirs.user_data_dir})
engine = create_engine(engine_url)

event.listen(mapper, "after_configured", add_to_dict(Base))
Base.metadata.bind = engine
Base.metadata.create_all()
event.listen(mapper_reg, "after_configured", add_to_dict(Base))
Base.metadata.create_all(bind=engine)
session_factory = sessionmaker(bind=engine)
session = scoped_session(session_factory)

Expand Down
Loading

0 comments on commit 4a9499b

Please sign in to comment.