Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mypy & type annotations. #37

Merged
merged 8 commits into from
Jul 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,32 @@ jobs:
run: |
poetry run flake8 . --jobs=auto --format=github

type-check:
name: Type-check
runs-on: ubuntu-20.04

steps:
- name: Checkout
uses: actions/checkout@v3

- name: Set up Python 3.9
uses: actions/setup-python@v4
with:
python-version: 3.9

- name: Set up Poetry
uses: abatilo/actions-poetry@v2.1.5
with:
poetry-version: 1.7.1

- name: Install dependencies
run: |
poetry install

- name: Type-check
run: |
poetry run mypy src tests

validate-dependencies:
name: Check dependency locks
runs-on: ubuntu-20.04
Expand Down
60 changes: 59 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ tox-gh-actions = "^2.4.0"
psycopg2-binary = "^2.8.6"
pytest-check = "^1.0.1"

mypy = "^1.8.0"
types-tqdm = "^4.66.0.20240106"

[tool.poetry.group.dev.dependencies]
tox-pyenv = "^1.1.0"

Expand All @@ -48,6 +51,21 @@ use_parentheses = true
ensure_newline_before_comments = true
line_length = 80

[tool.mypy]
warn_unused_configs = true

# Be fairly strict with our types
strict_optional = true
enable_error_code = "ignore-without-code"
disallow_incomplete_defs = true
disallow_any_generics = true
disallow_untyped_decorators = true
disallow_untyped_defs = true

[[tool.mypy.overrides]]
module = "django.*"
ignore_missing_imports = true

[tool.pytest.ini_options]
DJANGO_SETTINGS_MODULE = "testsite.settings"
django_find_project = false
Expand Down
40 changes: 32 additions & 8 deletions src/devdata/anonymisers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
import pathlib
import random
from typing import Any, TypeVar

import faker
from django.db import models

from .types import Anonymiser, GenericAnonymiser
from .utils import get_exported_pks_for_model

T = TypeVar("T")


def faker_anonymise(
generator, *args, preserve_nulls=False, unique=False, **kwargs
):
def anonymise(*, pii_value, fake, **_kwargs):
generator: str,
*args: Any,
preserve_nulls: bool = False,
unique: bool = False,
**kwargs: Any,
) -> Anonymiser:
def anonymise(*, pii_value: T, fake: faker.Faker, **_kwargs: object) -> T:
if preserve_nulls and pii_value is None:
return None

Expand All @@ -16,8 +28,15 @@ def anonymise(*, pii_value, fake, **_kwargs):
return anonymise


def preserve_internal(alternative):
def anonymise(obj, field, pii_value, **kwargs):
def preserve_internal(
alternative: GenericAnonymiser[T],
) -> GenericAnonymiser[T]:
def anonymise(
obj: models.Model,
field: str,
pii_value: T,
**kwargs: Any,
) -> T:
if getattr(obj, "is_superuser", False) or getattr(
obj, "is_staff", False
):
Expand All @@ -27,16 +46,21 @@ def anonymise(obj, field, pii_value, **kwargs):
return anonymise


def const(value, preserve_nulls=False):
def anonymise(*_, pii_value, **_kwargs):
def const(value: T, preserve_nulls: bool = False) -> GenericAnonymiser[T]:
def anonymise(*_: object, pii_value: T, **_kwargs: object) -> T:
if preserve_nulls and pii_value is None:
return None
return value

return anonymise


def random_foreign_key(obj, field, dest, **_kwargs):
def random_foreign_key(
obj: models.Model,
field: str,
dest: pathlib.Path,
**_kwargs: object,
) -> Any:
related_model = obj._meta.get_field(field).related_model
exported_pks = get_exported_pks_for_model(dest, related_model)
return random.choice(exported_pks)
31 changes: 22 additions & 9 deletions src/devdata/engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from __future__ import annotations

import json
from collections.abc import Collection
from pathlib import Path

from django.core.management import call_command
from django.core.management.color import no_style
Expand All @@ -20,14 +24,14 @@
)


def validate_strategies(only=None):
def validate_strategies(only: Collection[str] = ()) -> None:
not_found = []

for model in get_all_models():
if model._meta.abstract:
continue

app_model_label = to_app_model_label(model)
app_model_label = to_app_model_label(model) # type: ignore[arg-type] # mypy can't see that models are hashable

if app_model_label not in settings.strategies:
if only and app_model_label not in only:
Expand All @@ -49,7 +53,7 @@ def validate_strategies(only=None):
)


def export_migration_state(django_dbname, dest):
def export_migration_state(django_dbname: str, dest: Path) -> None:
file_path = migrations_file_path(dest)
file_path.parent.mkdir(parents=True, exist_ok=True)

Expand All @@ -69,7 +73,12 @@ def export_migration_state(django_dbname, dest):
json.dump(migration_state, f, indent=4, cls=DjangoJSONEncoder)


def export_data(django_dbname, dest, only=None, no_update=False):
def export_data(
django_dbname: str,
dest: Path,
only: Collection[str] = (),
no_update: bool = False,
) -> None:
model_strategies = sort_model_strategies(settings.strategies)
bar = progress(model_strategies)
for app_model_label, strategy in bar:
Expand Down Expand Up @@ -100,7 +109,11 @@ def export_data(django_dbname, dest, only=None, no_update=False):
)


def export_extras(django_dbname, dest, no_update=False):
def export_extras(
django_dbname: str,
dest: Path,
no_update: bool = False,
) -> None:
bar = progress(settings.extra_strategies)
for strategy in bar:
bar.set_postfix({"extra": strategy.name})
Expand All @@ -114,7 +127,7 @@ def export_extras(django_dbname, dest, no_update=False):
)


def import_schema(src, django_dbname):
def import_schema(src: Path, django_dbname: str) -> None:
connection = connections[django_dbname]

with disable_migrations():
Expand Down Expand Up @@ -149,7 +162,7 @@ def import_schema(src, django_dbname):
)


def import_data(src, django_dbname):
def import_data(src: Path, django_dbname: str) -> None:
model_strategies = sort_model_strategies(settings.strategies)
bar = progress(model_strategies)
for app_model_label, strategy in bar:
Expand All @@ -160,14 +173,14 @@ def import_data(src, django_dbname):
strategy.import_data(django_dbname, src, model)


def import_extras(src, django_dbname):
def import_extras(src: Path, django_dbname: str) -> None:
bar = progress(settings.extra_strategies)
for strategy in bar:
bar.set_postfix({"extra": strategy.name})
strategy.import_data(django_dbname, src)


def import_cleanup(src, django_dbname):
def import_cleanup(src: Path, django_dbname: str) -> None:
conn = connections[django_dbname]
with conn.cursor() as cursor:
for reset_sql in conn.ops.sequence_reset_sql(
Expand Down
22 changes: 15 additions & 7 deletions src/devdata/extras.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
from __future__ import annotations

import json
import textwrap
from pathlib import Path
from typing import Callable, Dict, Set, Tuple
from typing import Any, Callable

from django.db import connections

Logger = Callable[[object], None]
Logger = Callable[[str], None]


class ExtraImport:
"""
Base extra defining how to get data into a fresh database.
"""

depends_on = () # type: Tuple[str, ...]
name: str
depends_on: tuple[str, ...] = ()

def __init__(self) -> None:
pass
Expand All @@ -28,9 +31,9 @@ class ExtraExport:
Base extra defining how to get data out of an existing database.
"""

seen_names = set() # type: Set[Tuple[str, str]]
seen_names: set[str] = set()

def __init__(self, *args, name, **kwargs):
def __init__(self, *args: Any, name: str, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

self.name = name
Expand Down Expand Up @@ -76,7 +79,12 @@ class PostgresSequences(ExtraExport, ExtraImport):
matching primary keys.
"""

def __init__(self, *args, name="postgres-sequences", **kwargs):
def __init__(
self,
*args: Any,
name: str = "postgres-sequences",
**kwargs: Any,
) -> None:
super().__init__(*args, name=name, **kwargs)

def export_data(
Expand Down Expand Up @@ -154,7 +162,7 @@ def import_data(self, django_dbname: str, src: Path) -> None:
with self.data_file(src).open() as f:
sequences = json.load(f)

def check_simple_value(mapping: Dict[str, str], *, key: str) -> str:
def check_simple_value(mapping: dict[str, str], *, key: str) -> str:
value = mapping[key]
if not value.replace("_", "").isalnum():
raise ValueError(f"{key} is not alphanumeric")
Expand Down
Loading
Loading