Skip to content

Commit

Permalink
Add type annotations to most of the codebase
Browse files Browse the repository at this point in the history
  • Loading branch information
PeterJCLaw committed Jan 8, 2024
1 parent 0b15ad4 commit 5d53137
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 52 deletions.
40 changes: 32 additions & 8 deletions src/devdata/anonymisers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
import pathlib
import random
from typing import Any, TypeVar

import faker
from django.db import models

from .types import Anonymiser, GenericAnonymiser
from .utils import get_exported_pks_for_model

T = TypeVar("T")


def faker_anonymise(
generator, *args, preserve_nulls=False, unique=False, **kwargs
):
def anonymise(*, pii_value, fake, **_kwargs):
generator: str,
*args: Any,
preserve_nulls: bool = False,
unique: bool = False,
**kwargs: Any,
) -> Anonymiser:
def anonymise(*, pii_value: T, fake: faker.Faker, **_kwargs: object) -> T:
if preserve_nulls and pii_value is None:
return None

Expand All @@ -16,8 +28,15 @@ def anonymise(*, pii_value, fake, **_kwargs):
return anonymise


def preserve_internal(alternative):
def anonymise(obj, field, pii_value, **kwargs):
def preserve_internal(
alternative: GenericAnonymiser[T],
) -> GenericAnonymiser[T]:
def anonymise(
obj: models.Model,
field: str,
pii_value: T,
**kwargs: Any,
) -> T:
if getattr(obj, "is_superuser", False) or getattr(
obj, "is_staff", False
):
Expand All @@ -27,16 +46,21 @@ def anonymise(obj, field, pii_value, **kwargs):
return anonymise


def const(value, preserve_nulls=False):
def anonymise(*_, pii_value, **_kwargs):
def const(value: T, preserve_nulls: bool = False) -> GenericAnonymiser[T]:
def anonymise(*_: object, pii_value: T, **_kwargs: object) -> T:
if preserve_nulls and pii_value is None:
return None
return value

return anonymise


def random_foreign_key(obj, field, dest, **_kwargs):
def random_foreign_key(
obj: models.Model,
field: str,
dest: pathlib.Path,
**_kwargs: object,
) -> Any:
related_model = obj._meta.get_field(field).related_model
exported_pks = get_exported_pks_for_model(dest, related_model)
return random.choice(exported_pks)
31 changes: 22 additions & 9 deletions src/devdata/engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from __future__ import annotations

import json
from collections.abc import Collection
from pathlib import Path

from django.core.management import call_command
from django.core.management.color import no_style
Expand All @@ -20,14 +24,14 @@
)


def validate_strategies(only=None):
def validate_strategies(only: Collection[str] = ()) -> None:
not_found = []

for model in get_all_models():
if model._meta.abstract:
continue

app_model_label = to_app_model_label(model)
app_model_label = to_app_model_label(model) # type: ignore[arg-type] # mypy can't see that models are hashable

if app_model_label not in settings.strategies:
if only and app_model_label not in only:
Expand All @@ -49,7 +53,7 @@ def validate_strategies(only=None):
)


def export_migration_state(django_dbname, dest):
def export_migration_state(django_dbname: str, dest: Path) -> None:
file_path = migrations_file_path(dest)
file_path.parent.mkdir(parents=True, exist_ok=True)

Expand All @@ -69,7 +73,12 @@ def export_migration_state(django_dbname, dest):
json.dump(migration_state, f, indent=4, cls=DjangoJSONEncoder)


def export_data(django_dbname, dest, only=None, no_update=False):
def export_data(
django_dbname: str,
dest: Path,
only: Collection[str] = (),
no_update: bool = False,
) -> None:
model_strategies = sort_model_strategies(settings.strategies)
bar = progress(model_strategies)
for app_model_label, strategy in bar:
Expand Down Expand Up @@ -100,7 +109,11 @@ def export_data(django_dbname, dest, only=None, no_update=False):
)


def export_extras(django_dbname, dest, no_update=False):
def export_extras(
django_dbname: str,
dest: Path,
no_update: bool = False,
) -> None:
bar = progress(settings.extra_strategies)
for strategy in bar:
bar.set_postfix({"extra": strategy.name})
Expand All @@ -114,7 +127,7 @@ def export_extras(django_dbname, dest, no_update=False):
)


def import_schema(src, django_dbname):
def import_schema(src: Path, django_dbname: str) -> None:
connection = connections[django_dbname]

with disable_migrations():
Expand Down Expand Up @@ -149,7 +162,7 @@ def import_schema(src, django_dbname):
)


def import_data(src, django_dbname):
def import_data(src: Path, django_dbname: str) -> None:
model_strategies = sort_model_strategies(settings.strategies)
bar = progress(model_strategies)
for app_model_label, strategy in bar:
Expand All @@ -160,14 +173,14 @@ def import_data(src, django_dbname):
strategy.import_data(django_dbname, src, model)


def import_extras(src, django_dbname):
def import_extras(src: Path, django_dbname: str) -> None:
bar = progress(settings.extra_strategies)
for strategy in bar:
bar.set_postfix({"extra": strategy.name})
strategy.import_data(django_dbname, src)


def import_cleanup(src, django_dbname):
def import_cleanup(src: Path, django_dbname: str) -> None:
conn = connections[django_dbname]
with conn.cursor() as cursor:
for reset_sql in conn.ops.sequence_reset_sql(
Expand Down
9 changes: 5 additions & 4 deletions src/devdata/extras.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import json
import textwrap
from pathlib import Path
from typing import Callable, Dict, Set, Tuple
from typing import Any, Callable, Dict, Set, Tuple

from django.db import connections

Logger = Callable[[object], None]
Logger = Callable[[str], None]


class ExtraImport:
"""
Base extra defining how to get data into a fresh database.
"""

name: str
depends_on = () # type: Tuple[str, ...]

def __init__(self) -> None:
Expand All @@ -28,9 +29,9 @@ class ExtraExport:
Base extra defining how to get data out of an existing database.
"""

seen_names = set() # type: Set[Tuple[str, str]]
seen_names = set() # type: Set[str]

def __init__(self, *args, name, **kwargs):
def __init__(self, *args: Any, name: str, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

self.name = name
Expand Down
13 changes: 12 additions & 1 deletion src/devdata/management/commands/devdata_export.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import argparse
from pathlib import Path

Expand All @@ -22,6 +24,7 @@ def add_arguments(self, parser: CommandParser) -> None:
nargs=argparse.OPTIONAL,
help="Export destination",
default="./devdata",
type=Path,
)
parser.add_argument(
"only",
Expand All @@ -40,7 +43,15 @@ def add_arguments(self, parser: CommandParser) -> None:
action="store_true",
)

def handle(self, *, dest, only=None, database, no_update, **options):
def handle(
self,
*,
dest: Path,
only: list[str],
database: str,
no_update: bool,
**options: object,
) -> None:
try:
for app_model_label in only:
apps.get_model(app_model_label, require_ready=False)
Expand Down
12 changes: 10 additions & 2 deletions src/devdata/management/commands/devdata_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import_extras,
validate_strategies,
)
from ...reset_modes import MODES, DropDatabaseReset
from ...reset_modes import MODES, DropDatabaseReset, Reset
from ...settings import settings


Expand All @@ -25,6 +25,7 @@ def add_arguments(self, parser):
nargs=argparse.OPTIONAL,
help="Import source",
default="./devdata",
type=Path,
)
parser.add_argument(
"--database",
Expand All @@ -44,7 +45,14 @@ def add_arguments(self, parser):
action="store_true",
)

def handle(self, src, database, reset_mode, no_input=False, **options):
def handle(
self,
src: Path,
database: str,
reset_mode: Reset,
no_input: bool = False,
**options: object
) -> None:
try:
validate_strategies()
except AssertionError as e:
Expand Down
26 changes: 16 additions & 10 deletions src/devdata/settings.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,45 @@
from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING, Any, TypeVar

from django.conf import settings as django_settings
from django.utils.module_loading import import_string

from .extras import ExtraImport
from .types import Anonymiser
from .utils import get_all_models, to_app_model_label

if TYPE_CHECKING:
from .strategies import Strategy

T = TypeVar("T")

DEFAULT_FIELD_ANONYMISERS: dict[str, Anonymiser] = {}
DEFAULT_MODEL_ANONYMISERS: dict[str, Anonymiser] = {}
DEFAULT_FAKER_LOCALES = ["en_US"]


def import_strategy(strategy):
def import_strategy(strategy: tuple[str, dict[str, Any]] | T) -> T:
try:
klass_path, kwargs = strategy
klass_path, kwargs = strategy # type: ignore[misc]
klass = import_string(klass_path)
return klass(**kwargs)
except (ValueError, TypeError, IndexError):
return strategy
return strategy # type: ignore[return-value]


class Settings:
@property
def strategies(self):
def strategies(self) -> dict[str, list[Strategy]]:
model_strategies = django_settings.DEVDATA_STRATEGIES

ret = {}
ret: dict[str, list[Strategy]] = {}

for model in get_all_models():
if model._meta.abstract:
continue

app_model_label = to_app_model_label(model)
app_model_label = to_app_model_label(model) # type: ignore[arg-type] # mypy can't see that models are hashable

ret[app_model_label] = []
strategies = model_strategies.get(app_model_label)
Expand All @@ -55,22 +61,22 @@ def strategies(self):
return ret

@property
def extra_strategies(self):
def extra_strategies(self) -> list[ExtraImport]:
return [
import_strategy(x)
for x in getattr(django_settings, "DEVDATA_EXTRA_STRATEGIES", ())
]

@property
def field_anonymisers(self):
def field_anonymisers(self) -> dict[str, Anonymiser]:
return getattr(
django_settings,
"DEVDATA_FIELD_ANONYMISERS",
DEFAULT_FIELD_ANONYMISERS,
)

@property
def model_anonymisers(self):
def model_anonymisers(self) -> dict[str, Anonymiser]:
return getattr(
django_settings,
"DEVDATA_MODEL_ANONYMISERS",
Expand Down
5 changes: 3 additions & 2 deletions src/devdata/strategies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Set, Tuple
from typing import Any, Set, Tuple

from django.core import serializers
from django.db import models
Expand All @@ -18,6 +18,7 @@ class Strategy:
database.
"""

name: str
depends_on = () # type: Tuple[str, ...]

def __init__(self):
Expand All @@ -36,7 +37,7 @@ class Exportable:

seen_names = set() # type: Set[Tuple[str, str]]

def __init__(self, *args, name, **kwargs):
def __init__(self, *args: Any, name: str, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

self.name = name
Expand Down
9 changes: 7 additions & 2 deletions src/devdata/types.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import pathlib
from typing import Any, Protocol, TypeVar
from typing import Any, Generic, Protocol, TypeVar

import faker

T = TypeVar("T")


class Anonymiser(Protocol):
class GenericAnonymiser(Generic[T], Protocol):
def __call__(
self,
*,
obj: Any,
field: str,
pii_value: T,
fake: faker.Faker,
dest: pathlib.Path,
) -> T:
...


class Anonymiser(GenericAnonymiser[Any], Protocol):
pass
Loading

0 comments on commit 5d53137

Please sign in to comment.