Skip to content

Commit

Permalink
More typing (draft)
Browse files Browse the repository at this point in the history
  • Loading branch information
hgrecco committed May 2, 2023
1 parent 5643c32 commit ce185c2
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 59 deletions.
4 changes: 4 additions & 0 deletions pint/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,7 @@ def __setitem__(self, key: Any, value: Any) -> None:

FuncType = Callable[..., Any]
F = TypeVar("F", bound=FuncType)


# TODO: Improve or delete types
QuantityArgument = Any
24 changes: 17 additions & 7 deletions pint/facets/context/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,20 @@

import weakref
from collections import ChainMap, defaultdict
from typing import Any
from typing import Any, Callable
from collections.abc import Iterable

from ...facets.plain import UnitDefinition
from ...util import UnitsContainer, to_units_container
from .definitions import ContextDefinition
from ..._typing import Magnitude

Transformation = Callable[
[
Magnitude,
],
Magnitude,
]


class Context:
Expand Down Expand Up @@ -75,14 +83,14 @@ def __init__(
aliases: tuple[str] = tuple(),
defaults: dict[str, Any] | None = None,
) -> None:
self.name = name
self.aliases = aliases
self.name: str | None = name
self.aliases: tuple[str] = aliases

#: Maps (src, dst) -> transformation function
self.funcs = {}
self.funcs: dict[tuple[UnitsContainer, UnitsContainer], Transformation] = {}

#: Maps defaults variable names to values
self.defaults = defaults or {}
self.defaults: dict[str, Any] = defaults or {}

# Store Definition objects that are context-specific
self.redefinitions = []
Expand Down Expand Up @@ -154,7 +162,9 @@ def from_definition(cls, cd: ContextDefinition, to_base_func=None) -> Context:

return ctx

def add_transformation(self, src, dst, func) -> None:
def add_transformation(
self, src: UnitsContainer, dst: UnitsContainer, func: Transformation
) -> None:
"""Add a transformation function to the context."""

_key = self.__keytransform__(src, dst)
Expand Down Expand Up @@ -202,7 +212,7 @@ def _redefine(self, definition: UnitDefinition):

def hashable(
self,
) -> tuple[str | None, tuple[str, ...], frozenset, frozenset, tuple]:
) -> tuple[str | None, tuple[str], frozenset, frozenset, tuple]:
"""Generate a unique hashable and comparable representation of self, which can
be used as a key in a dict. This class cannot define ``__hash__`` because it is
mutable, and the Python interpreter does cache the output of ``__hash__``.
Expand Down
28 changes: 25 additions & 3 deletions pint/facets/group/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,28 @@

from __future__ import annotations

from typing import Callable, Any, TYPE_CHECKING

from collections.abc import Generator, Iterable
from ...util import SharedRegistryObject, getattr_maybe_raise
from .definitions import GroupDefinition

if TYPE_CHECKING:
from ..plain import UnitDefinition

DefineFunc = Callable[
[
Any,
],
None,
]
AddUnitFunc = Callable[
[
UnitDefinition,
],
None,
]


class Group(SharedRegistryObject):
"""A group is a set of units.
Expand Down Expand Up @@ -57,7 +75,7 @@ def __init__(self, name: str):
self._computed_members: frozenset[str] | None = None

@property
def members(self):
def members(self) -> frozenset[str]:
"""Names of the units that are members of the group.
Calculated to include to all units in all included _used_groups.
Expand Down Expand Up @@ -143,7 +161,7 @@ def remove_groups(self, *group_names: str) -> None:

@classmethod
def from_lines(
cls, lines: Iterable[str], define_func, non_int_type: type = float
cls, lines: Iterable[str], define_func: DefineFunc, non_int_type: type = float
) -> Group:
"""Return a Group object parsing an iterable of lines.
Expand All @@ -160,11 +178,15 @@ def from_lines(
"""
group_definition = GroupDefinition.from_lines(lines, non_int_type)

if group_definition is None:
raise ValueError(f"Could not define group from {lines}")

return cls.from_definition(group_definition, define_func)

@classmethod
def from_definition(
cls, group_definition: GroupDefinition, add_unit_func=None
cls, group_definition: GroupDefinition, add_unit_func: AddUnitFunc | None = None
) -> Group:
grp = cls(group_definition.name)

Expand Down
128 changes: 86 additions & 42 deletions pint/facets/plain/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from fractions import Fraction
from numbers import Number
from token import NAME, NUMBER
from tokenize import TokenInfo

from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -33,7 +35,7 @@
from ..context import Context
from ..._typing import Quantity, Unit

from ..._typing import QuantityOrUnitLike, UnitLike
from ..._typing import QuantityOrUnitLike, UnitLike, QuantityArgument
from ..._vendor import appdirs
from ...compat import HAS_BABEL, babel_parse, tokenizer
from ...errors import DimensionalityError, RedefinitionError, UndefinedUnitError
Expand Down Expand Up @@ -75,8 +77,10 @@


@functools.lru_cache
def pattern_to_regex(pattern):
if hasattr(pattern, "finditer"):
def pattern_to_regex(pattern: str | re.Pattern[str]) -> re.Pattern[str]:
# TODO: This has been changed during typing improvements.
# if hasattr(pattern, "finditer"):
if not isinstance(pattern, str):
pattern = pattern.pattern

# Replace "{unit_name}" match string with float regex with unit_name as group
Expand Down Expand Up @@ -197,7 +201,15 @@ def __init__(
mpl_formatter: str = "{:P}",
):
#: Map a definition class to a adder methods.
self._adders = {}
self._adders: dict[
type[T],
Callable[
[
T,
],
None,
],
] = {}
self._register_definition_adders()
self._init_dynamic_classes()

Expand Down Expand Up @@ -297,7 +309,16 @@ def _after_init(self) -> None:
self._build_cache(loaded_files)
self._initialized = True

def _register_adder(self, definition_class, adder_func):
def _register_adder(
self,
definition_class: type[T],
adder_func: Callable[
[
T,
],
None,
],
) -> None:
"""Register a block definition."""
self._adders[definition_class] = adder_func

Expand All @@ -316,18 +337,18 @@ def __deepcopy__(self, memo) -> PlainRegistry:
new._init_dynamic_classes()
return new

def __getattr__(self, item):
def __getattr__(self, item: str) -> Unit:
getattr_maybe_raise(self, item)
return self.Unit(item)

def __getitem__(self, item):
def __getitem__(self, item: str):
logger.warning(
"Calling the getitem method from a UnitRegistry is deprecated. "
"use `parse_expression` method or use the registry as a callable."
)
return self.parse_expression(item)

def __contains__(self, item) -> bool:
def __contains__(self, item: str) -> bool:
"""Support checking prefixed units with the `in` operator"""
try:
self.__getattr__(item)
Expand Down Expand Up @@ -390,7 +411,7 @@ def cache_folder(self) -> pathlib.Path | None:
def non_int_type(self):
return self._non_int_type

def define(self, definition):
def define(self, definition: str | type) -> None:
"""Add unit to the registry.
Parameters
Expand All @@ -413,7 +434,7 @@ def define(self, definition):
# - then we define specific adder for each definition class. :-D
############

def _helper_dispatch_adder(self, definition):
def _helper_dispatch_adder(self, definition: Any) -> None:
"""Helper function to add a single definition,
choosing the appropiate method by class.
"""
Expand Down Expand Up @@ -474,19 +495,19 @@ def _add_alias(self, definition: AliasDefinition):
for alias in definition.aliases:
self._helper_single_adder(alias, unit, self._units, self._units_casei)

def _add_dimension(self, definition: DimensionDefinition):
def _add_dimension(self, definition: DimensionDefinition) -> None:
self._helper_adder(definition, self._dimensions, None)

def _add_derived_dimension(self, definition: DerivedDimensionDefinition):
def _add_derived_dimension(self, definition: DerivedDimensionDefinition) -> None:
for dim_name in definition.reference.keys():
if dim_name not in self._dimensions:
self._add_dimension(DimensionDefinition(dim_name))
self._helper_adder(definition, self._dimensions, None)

def _add_prefix(self, definition: PrefixDefinition):
def _add_prefix(self, definition: PrefixDefinition) -> None:
self._helper_adder(definition, self._prefixes, None)

def _add_unit(self, definition: UnitDefinition):
def _add_unit(self, definition: UnitDefinition) -> None:
if definition.is_base:
self._base_units.append(definition.name)
for dim_name in definition.reference.keys():
Expand Down Expand Up @@ -673,7 +694,7 @@ def _get_dimensionality_recurse(self, ref, exp, accumulator):
if reg.reference is not None:
self._get_dimensionality_recurse(reg.reference, exp2, accumulator)

def _get_dimensionality_ratio(self, unit1, unit2):
def _get_dimensionality_ratio(self, unit1: UnitLike, unit2: UnitLike):
"""Get the exponential ratio between two units, i.e. solve unit2 = unit1**x for x.
Parameters
Expand Down Expand Up @@ -780,7 +801,9 @@ def _get_root_units(self, input_units, check_nonmult=True):
cache[input_units] = factor, units
return factor, units

def get_base_units(self, input_units, check_nonmult=True, system=None):
def get_base_units(
self, input_units: UnitsContainer | str, check_nonmult: bool = True, system=None
):
"""Convert unit or dict of units to the plain units.
If any unit is non multiplicative and check_converter is True,
Expand Down Expand Up @@ -1104,7 +1127,32 @@ def _parse_units(

return ret

def _eval_token(self, token, case_sensitive=None, **values):
def _eval_token(
self,
token: TokenInfo,
case_sensitive: bool | None = None,
**values: QuantityArgument,
):
"""Evaluate a single token using the following rules:
1. numerical values as strings are replaced by their numeric counterparts
- integers are parsed as integers
- other numeric values are parses of non_int_type
2. strings in (inf, infinity, nan, dimensionless) with their numerical value.
3. strings in values.keys() are replaced by Quantity(values[key])
4. in other cases, the values are parsed as units and replaced by their canonical name.
Parameters
----------
token
Token to evaluate.
case_sensitive, optional
If true, a case sensitive matching of the unit name will be done in the registry.
If false, a case INsensitive matching of the unit name will be done in the registry.
(Default value = None, which uses registry setting)
**values
Other string that will be parsed using the Quantity constructor on their corresponding value.
"""
token_type = token[0]
token_text = token[1]
if token_type == NAME:
Expand Down Expand Up @@ -1139,28 +1187,25 @@ def parse_pattern(
Parameters
----------
input_string :
input_string
pattern_string:
The regex parse string
case_sensitive :
(Default value = None, which uses registry setting)
many :
The regex parse string
case_sensitive, optional
If true, a case sensitive matching of the unit name will be done in the registry.
If false, a case INsensitive matching of the unit name will be done in the registry.
(Default value = None, which uses registry setting)
many, optional
Match many results
(Default value = False)
Returns
-------
"""

if not input_string:
return [] if many else None

# Parse string
pattern = pattern_to_regex(pattern)
matched = re.finditer(pattern, input_string)
regex = pattern_to_regex(pattern)
matched = re.finditer(regex, input_string)

# Extract result(s)
results = []
Expand Down Expand Up @@ -1196,16 +1241,14 @@ def parse_expression(
Parameters
----------
input_string :
case_sensitive :
(Default value = None, which uses registry setting)
**values :
Returns
-------
input_string
case_sensitive, optional
If true, a case sensitive matching of the unit name will be done in the registry.
If false, a case INsensitive matching of the unit name will be done in the registry.
(Default value = None, which uses registry setting)
**values
Other string that will be parsed using the Quantity constructor on their corresponding value.
"""
if not input_string:
return self.Quantity(1)
Expand All @@ -1215,8 +1258,9 @@ def parse_expression(
input_string = string_preprocessor(input_string)
gen = tokenizer(input_string)

return build_eval_tree(gen).evaluate(
lambda x: self._eval_token(x, case_sensitive=case_sensitive, **values)
)
def _define_op(s: str):
return self._eval_token(s, case_sensitive=case_sensitive, **values)

return build_eval_tree(gen).evaluate(_define_op)

__call__ = parse_expression
Loading

0 comments on commit ce185c2

Please sign in to comment.