Skip to content

Commit

Permalink
Add support for python typing of Fields
Browse files Browse the repository at this point in the history
Add support for specifying the type of Fields with python class
typing syntax. This will ensure the appropriate type is inferred
from a Fields __get__ method, type checking when setting with
__set__, as well as setting appropriate dtypes where applicable at
runtime.
  • Loading branch information
natelust committed Jun 23, 2022
1 parent cc71245 commit f3c582d
Show file tree
Hide file tree
Showing 8 changed files with 371 additions and 50 deletions.
150 changes: 139 additions & 11 deletions python/lsst/pex/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,14 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

__all__ = ("Config", "ConfigMeta", "Field", "FieldValidationError", "UnexpectedProxyUsageError")
__all__ = (
"Config",
"ConfigMeta",
"Field",
"FieldValidationError",
"UnexpectedProxyUsageError",
"FieldTypeVar"
)

import copy
import importlib
Expand All @@ -38,14 +45,15 @@
import tempfile
import warnings

from typing import Generic, TypeVar, ForwardRef, cast, Any, MutableMapping, overload
from types import GenericAlias

# if YAML is not available that's fine and we simply don't register
# the yaml representer since we know it won't be used.
try:
import yaml
except ImportError:
yaml = None
YamlLoaders = ()
doImport = None

from .callStack import getCallStack, getStackFrame
from .comparison import compareConfigs, compareScalars, getComparisonName
Expand All @@ -60,6 +68,18 @@
YamlLoaders += (CLoader,)
except ImportError:
pass
else:
YamlLoaders = ()
doImport = None


class PexGenericAlias(GenericAlias):
def __call__(self, *args: Any, **kwds: Any) -> Any:
origin_kwargs = self._parseTypingArgs(self.__args__, kwds)
return super().__call__(*args, **(kwds | origin_kwargs))


FieldTypeVar = TypeVar("FieldTypeVar")


class UnexpectedProxyUsageError(TypeError):
Expand Down Expand Up @@ -258,18 +278,19 @@ def __init__(self, field, config, msg):
super().__init__(error)


class Field:
class Field(Generic[FieldTypeVar]):
"""A field in a `~lsst.pex.config.Config` that supports `int`, `float`,
`complex`, `bool`, and `str` data types.
Parameters
----------
doc : `str`
A description of the field for users.
dtype : type
dtype : type, optional
The field's data type. ``Field`` only supports basic data types:
`int`, `float`, `complex`, `bool`, and `str`. See
`Field.supportedTypes`.
`Field.supportedTypes`. Optional if supplied as a typing argument to
the class.
default : object, optional
The field's default value.
check : callable, optional
Expand Down Expand Up @@ -319,6 +340,13 @@ class Field:
container type (like a `lsst.pex.config.List`) depending on the field's
type. See the example, below.
Fields can be annotated with a type similar to other python classes (see
`here <https://peps.python.org/pep-0484/#generics>`_ ). Unlike most other
uses in python, this has an effect at type checking *and* runtime. If the
type is specified with a class annotation, it will be used as the value of
the ``dtype`` in the ``Field`` and there is no need to specify it as an
argument during instantiation.
Examples
--------
Instances of ``Field`` should be used as class attributes of
Expand All @@ -334,12 +362,75 @@ class Field:
>>> print(config.myInt)
5
"""
name: str
"""Identifier (variable name) used to refer to a Field within a Config
Class.
"""

supportedTypes = set((str, bool, float, int, complex))
"""Supported data types for field values (`set` of types).
"""

def __init__(self, doc, dtype, default=None, check=None, optional=False, deprecated=None):
@staticmethod
def _parseTypingArgs(
params: tuple[type, ...] | tuple[str, ...],
kwds: MutableMapping[str, Any]
) -> MutableMapping[str, Any]:
"""Parses type annotations into keyword constructor arguments.
This is a special private method that interprets type arguments (i.e.
Field[str]) into keyword arguments to be passed on to the constructor.
Subclasses of Field can implement this method to customize how they
handle turning type parameters into keyword arguments (see DictField
for an example)
Parameters
----------
params : `tuple` of `type` or `tuple` of str
Parameters passed to the type annotation. These will either be
types or strings. Strings are to interpreted as forward references
and will be treated as such.
kwds : `MutableMapping` with keys of `str` and values of `Any`
These are the user supplied keywords that are to be passed to the
Field constructor.
Returns
-------
kwds : `MutableMapping` with keys of `str` and values of `Any`
The mapping of keywords that will be passed onto the constructor
of the Field. Should be filled in with any information gleaned
from the input parameters.
Raises
------
ValueError :
Raised if params is of incorrect length.
Raised if a forward reference could not be resolved
Raised if there is a conflict between params and values in kwds
"""
if len(params) > 1:
raise ValueError("Only single type parameters are supported")
unpackedParams = params[0]
if isinstance(unpackedParams, str):
unpackedParams = ForwardRef(unpackedParams)
if (result := unpackedParams._evaluate(globals(), locals(), set())) is None: # type: ignore
raise ValueError("Could not deduce type from input")
unpackedParams = cast(type, result)
if 'dtype' in kwds and kwds['dtype'] != unpackedParams:
raise ValueError("Conflicting definition for dtype")
elif 'dtype' not in kwds:
kwds['dtype'] = unpackedParams
return kwds

def __class_getitem__(cls, params: tuple[type, ...] | type | ForwardRef):
return PexGenericAlias(cls, params)

def __init__(self, doc, dtype=None, default=None, check=None, optional=False, deprecated=None):
if dtype is None:
raise ValueError(
"dtype must either be supplied as an argument or as a type argument to the class"
)
if dtype not in self.supportedTypes:
raise ValueError("Unsupported Field dtype %s" % _typeStr(dtype))

Expand Down Expand Up @@ -570,7 +661,33 @@ def toDict(self, instance):
"""
return self.__get__(instance)

def __get__(self, instance, owner=None, at=None, label="default"):
@overload
def __get__(
self,
instance: None,
owner: Any = None,
at: Any = None,
label: str = "default"
) -> "Field":
...

@overload
def __get__(
self,
instance: "Config",
owner: Any = None,
at: Any = None,
label: str = "default"
) -> FieldTypeVar:
...

def __get__(
self,
instance,
owner=None,
at=None,
label="default"
):
"""Define how attribute access should occur on the Config instance
This is invoked by the owning config object and should not be called
directly
Expand Down Expand Up @@ -599,7 +716,13 @@ def __get__(self, instance, owner=None, at=None, label="default"):
" incorrectly initialized"
)

def __set__(self, instance, value, at=None, label="assignment"):
def __set__(
self,
instance: "Config",
value: FieldTypeVar | None,
at: Any = None,
label: str = "assignment"
) -> None:
"""Set an attribute on the config instance.
Parameters
Expand Down Expand Up @@ -736,7 +859,7 @@ def __init__(self):

def __enter__(self):
self.origMetaPath = sys.meta_path
sys.meta_path = [self] + sys.meta_path
sys.meta_path = [self] + sys.meta_path # type: ignore
return self

def __exit__(self, *args):
Expand Down Expand Up @@ -764,7 +887,8 @@ def getModules(self):
return self._modules


class Config(metaclass=ConfigMeta):
# type ignore because type checker thinks ConfigMeta is Generic when it is not
class Config(metaclass=ConfigMeta): # type: ignore
"""Base class for configuration (*config*) objects.
Notes
Expand Down Expand Up @@ -809,6 +933,10 @@ class behavior.
>>> print(config.listField)
['coffee', 'green tea', 'water', 'earl grey tea']
"""
_storage: MutableMapping[str, Any]
_fields: MutableMapping[str, Field]
_history: MutableMapping[str, list[Any]]
_imports: set[Any]

def __iter__(self):
"""Iterate over fields."""
Expand Down
9 changes: 7 additions & 2 deletions python/lsst/pex/config/configChoiceField.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import copy
import weakref

from typing import ForwardRef

from .callStack import getCallStack, getStackFrame
from .comparison import compareConfigs, compareScalars, getComparisonName
from .config import Config, Field, FieldValidationError, UnexpectedProxyUsageError, _joinNamePath, _typeStr
Expand Down Expand Up @@ -143,7 +145,7 @@ def __reduce__(self):
)


class ConfigInstanceDict(collections.abc.Mapping):
class ConfigInstanceDict(collections.abc.Mapping[str, Config]):
"""Dictionary of instantiated configs, used to populate a
`~lsst.pex.config.ConfigChoiceField`.
Expand Down Expand Up @@ -479,6 +481,9 @@ def __init__(self, doc, typemap, default=None, optional=False, multi=False, depr
self.typemap = typemap
self.multi = multi

def __class_getitem__(cls, params: tuple[type, ...] | type | ForwardRef):
raise ValueError("ConfigChoiceField does not support typing argument")

def _getOrMake(self, instance, label="default"):
instanceDict = instance._storage.get(self.name)
if instanceDict is None:
Expand All @@ -491,7 +496,7 @@ def _getOrMake(self, instance, label="default"):

return instanceDict

def __get__(self, instance, owner=None):
def __get__(self, instance, owner=None) -> instanceDictClass:
if instance is None or not isinstance(instance, Config):
return self
else:
Expand Down
16 changes: 12 additions & 4 deletions python/lsst/pex/config/configField.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@

__all__ = ["ConfigField"]

from typing import Any

from .callStack import getCallStack, getStackFrame
from .comparison import compareConfigs, getComparisonName
from .config import Config, Field, FieldValidationError, _joinNamePath, _typeStr
from .config import Config, Field, FieldValidationError, _joinNamePath, _typeStr, FieldTypeVar


class ConfigField(Field):
class ConfigField(Field[FieldTypeVar]):
"""A configuration field (`~lsst.pex.config.Field` subclass) that takes a
`~lsst.pex.config.Config`-type as a value.
Expand Down Expand Up @@ -94,7 +96,7 @@ def __init__(self, doc, dtype, default=None, check=None, deprecated=None):
deprecated=deprecated,
)

def __get__(self, instance, owner=None):
def __get__(self, instance, owner=None) -> FieldTypeVar:
if instance is None or not isinstance(instance, Config):
return self
else:
Expand All @@ -105,7 +107,13 @@ def __get__(self, instance, owner=None):
self.__set__(instance, self.default, at=at, label="default")
return value

def __set__(self, instance, value, at=None, label="assignment"):
def __set__(
self,
instance: Config,
value: FieldTypeVar,
at: Any = None,
label: str = "assignment"
) -> None:
if instance._frozen:
raise FieldValidationError(self, instance, "Cannot modify a frozen Config")
name = _joinNamePath(prefix=instance._name, name=self.name)
Expand Down
31 changes: 24 additions & 7 deletions python/lsst/pex/config/configurableField.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,22 @@
import copy
import weakref

from typing import Generic, MutableMapping, ForwardRef, Any

from .callStack import getCallStack, getStackFrame
from .comparison import compareConfigs, getComparisonName
from .config import Config, Field, FieldValidationError, UnexpectedProxyUsageError, _joinNamePath, _typeStr


class ConfigurableInstance:
from .config import (
Config,
Field,
FieldValidationError,
UnexpectedProxyUsageError,
_joinNamePath,
_typeStr,
FieldTypeVar
)


class ConfigurableInstance(Generic[FieldTypeVar]):
"""A retargetable configuration in a `ConfigurableField` that proxies
a `~lsst.pex.config.Config`.
Expand Down Expand Up @@ -182,7 +192,7 @@ def __reduce__(self):
)


class ConfigurableField(Field):
class ConfigurableField(Field[FieldTypeVar]):
"""A configuration field (`~lsst.pex.config.Field` subclass) that can be
can be retargeted towards a different configurable (often a
`lsst.pipe.base.Task` subclass).
Expand Down Expand Up @@ -299,6 +309,13 @@ def __init__(self, doc, target, ConfigClass=None, default=None, check=None, depr
self.target = target
self.ConfigClass = ConfigClass

@staticmethod
def _parseTypingArgs(
params: tuple[type, ...] | tuple[ForwardRef, ...],
kwds: MutableMapping[str, Any]
) -> MutableMapping[str, Any]:
return kwds

def __getOrMake(self, instance, at=None, label="default"):
value = instance._storage.get(self.name, None)
if value is None:
Expand All @@ -308,9 +325,9 @@ def __getOrMake(self, instance, at=None, label="default"):
instance._storage[self.name] = value
return value

def __get__(self, instance, owner=None, at=None, label="default"):
def __get__(self, instance, owner=None, at=None, label="default") -> ConfigurableInstance[FieldTypeVar]:
if instance is None or not isinstance(instance, Config):
return self
return self # type: ignore
else:
return self.__getOrMake(instance, at=at, label=label)

Expand Down

0 comments on commit f3c582d

Please sign in to comment.