diff --git a/gin/local/__init__.py b/gin/local/__init__.py new file mode 100644 index 0000000..60ccad5 --- /dev/null +++ b/gin/local/__init__.py @@ -0,0 +1,19 @@ +# coding=utf-8 +# Copyright 2020 The Gin-Config Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Init file for the `gin.local` package.""" + +from gin.local.config import bind +from gin.local.config import Config diff --git a/gin/local/config.py b/gin/local/config.py new file mode 100644 index 0000000..f427434 --- /dev/null +++ b/gin/local/config.py @@ -0,0 +1,191 @@ +# coding=utf-8 +# Copyright 2020 The Gin-Config Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines the `gin.Config` class and associated functions.""" + +import copy +import inspect +from typing import Any, Callable, Dict, Optional + +from gin.local import partial + +import tree + + +class ConfigState: + """Encapsulates state associated with a `Config`. + + This is separated out into its own class to avoid any possibility of name + collisions when assigning parameters to a `Config` instance. + """ + + def __init__(self, fn_or_cls: Callable[..., Any]): + self.fn_or_cls = fn_or_cls + self.signature = inspect.signature(fn_or_cls) + self.has_kwargs = any( # Used to disable param name validation. + param.kind == inspect.Parameter.VAR_KEYWORD + for param in self.signature.parameters.values()) + self.call = False + + def validate(self, arg_name, unused_value): + """Validates `arg_name` to ensure it is a parameter of `self.fn_or_cls`.""" + if not self.has_kwargs and arg_name not in self.signature.parameters: + raise TypeError(f"No argument named '{arg_name}' in {self.fn_or_cls}.") + + +class Config: + """Captures configuration for a specific function or class. + + This class represents the configuration for a given function or class, + exposing configured parameters as mutable attributes. For example, for a class + + TestClass: + + def __init__(self, arg, kwarg=None): + self.arg = arg + self.kwarg = kwarg + + a configuration may (for instance) be accomplished via + + class_config = Config(TestClass, kwarg='kwarg') + class_config.arg = 1 + + This `Config` instance may then be passed to the `bind` function to obtain + a "partial class" with values bound for the `arg` and `kwarg` parameters of + the test class constructor: + + partial_class = bind(class_config) + instance = partial_class() + assert instance.arg == 'arg' + assert instance.kwarg == 'kwarg' + + A given `Config` instance may be "called" to make it an "instance + configuration". This will have the effect that when `bind` is called, the + result of calling the corresponding partial will be provided instead of the + partial itself: + + instance_config = class_config() + instance = bind(instance_config) + assert instance.arg == 'arg' + assert instance.kwarg == 'kwarg' + + The instance config becomes separated from the class config, so any further + changes to `class_config` are not reflected by `instance_config` (and vice + versa). + """ + + __state__: ConfigState # Lets pytype know about the __state__ attribute. + + def __init__(self, fn_or_cls: Callable[..., Any], *args, **kwargs): + """Initialize for `fn_or_cls`, optionally specifying parameters. + + Args: + fn_or_cls: The function or class to configure. + *args: Any positional arguments to configure for `fn_or_cls`. + **kwargs: Any keyword arguments to configure for `fn_or_cls`. + """ + super().__setattr__('__state__', ConfigState(fn_or_cls)) + signature = self.__state__.signature + bound_arguments = signature.bind_partial(*args, **kwargs) + for name, value in bound_arguments.arguments.items(): + if signature.parameters[name].kind == inspect.Parameter.POSITIONAL_ONLY: + raise ValueError('Positional only arguments not supported.') + if signature.parameters[name].kind == inspect.Parameter.VAR_POSITIONAL: + raise ValueError('Variable positional arguments not supported.') + setattr(self, name, value) + + # Providing this pass-through method prevents spurious pytype errors. + def __getattr__(self, name: str): + """Get parameter with given `name`.""" + super().__getattribute__(name) + + def __setattr__(self, name: str, value: Any): + """Sets parameter `name` to `value`.""" + self.__state__.validate(name, value) # Make sure it's a valid param name. + super().__setattr__(name, value) + + def __repr__(self): + formatted_fn_or_cls = self.__state__.fn_or_cls.__qualname__ + formatted_params = [f'{k}={v}' for k, v in params(self).items()] + return f"Config[{formatted_fn_or_cls}]({', '.join(formatted_params)})" + + def __copy__(self): + config_copy = object.__new__(type(self)) + new_dict = copy.copy(self.__dict__) + new_dict['__state__'] = copy.deepcopy(self.__state__) + config_copy.__dict__.update(new_dict) + return config_copy + + def __call__(self): + """Creates a "called" copy of this `Config` instance.""" + if self.__state__.call: + raise ValueError('The config has already been marked as called.') + new_config = copy.copy(self) + new_config.__state__.call = True + return new_config + + +def params(config: Config): + """Returns a dictionary of the parameters specified by `config`.""" + return { + name: value for name, value in vars(config).items() if name != '__state__' + } + + +def bind(config: Config, memo: Optional[Dict[Config, Any]] = None) -> Any: + """Binds `config`, returning a `partial` with bound parameters. + + This is the core function for turning a `Config` into a (partially) bound + object. It recursively walks through `config`'s parameters, binding any nested + `Config` instances. The returned result is a callable `partial` with all + config parameters set. + + If the same `Config` instance is seen multiple times during traversal of the + configuration tree, `bind` is called only once (for the first instance + encountered), and the result is reused for subsequent copies of the instance. + This is achieved via the `memo` dictionary (similar to `deepcopy`). This has + the effect that for configured class instances, each separate config instance + is in one-to-one correspondence with an actual instance of the configured + class after calling `bind` (shared config instances <=> shared class + instances). + + Args: + config: A `Config` instance to bind. + memo: An optional dictionary mapping `Config` instances to their "bound" + values. This is used to map shared instances of a "instantiated" `Config` + in the configuration tree to a single shared object instance/value after + binding. If an empty dictionary is supplied, it will be filled with a + mapping of all `Config` instances in the full tree reachable from `config` + to their corresponding partial or instance values. + + Returns: + The bound version of `config`. + """ + memo = {} if memo is None else memo + + def map_fn(leaf): + return bind(leaf, memo) if isinstance(leaf, Config) else leaf + + if config not in memo: + kwargs = {} + for name, value in params(config).items(): + value = tree.map_structure(map_fn, value) + kwargs[name] = value + state = config.__state__ + bindings = state.signature.bind_partial(**kwargs) + result = partial.partial(state.fn_or_cls, *bindings.args, **bindings.kwargs) + memo[config] = result() if state.call else result + + return memo[config] diff --git a/gin/local/partial.py b/gin/local/partial.py new file mode 100644 index 0000000..3808f1d --- /dev/null +++ b/gin/local/partial.py @@ -0,0 +1,117 @@ +# coding=utf-8 +# Copyright 2020 The Gin-Config Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines a generic `partial` that works for both classes and functions.""" + +import functools +import inspect + +from typing import Any, Callable, Type + + +def _make_meta_call_wrapper(cls: Type[object]): + """Creates a pickle-compatible wrapper for `type(cls).__call__`. + + This function works in tandem with `_decorate_fn_or_cls` below. It wraps + `type(cls).__call__`, which is in general responsible for creating a new + instance of `cls` or one of its subclasses. In cases where the to-be-created + class is Gin's dynamically-subclassed version of `cls`, the wrapper here + instead returns an instance of `cls`, which isn't a dynamic subclass and more + generally doesn't have any Gin-related magic applied. This means the instance + is compatible with pickling, and is totally transparent to any inspections by + user code (since it really is an instance of the original type). + + Args: + cls: The class whose metaclass's call method should be wrapped. + + Returns: + A wrapped version of the `type(cls).__call__`. + """ + cls_meta = type(cls) + + @functools.wraps(cls_meta.__call__) + def meta_call_wrapper(new_cls: Type[object], *args, **kwargs): + # If `new_cls` (the to-be-created class) is a direct subclass of `cls`, we + # can be sure that it's Gin's dynamically created subclass. In this case, + # we directly create an instance of `cls` instead. Otherwise, some further + # dynamic subclassing by user code has likely occurred, and we just create + # an instance of `new_cls` to avoid issues. This instance is likely not + # compatible with pickle, but that's generally true of dynamically created + # subclasses and would require some user workaround with or without Gin. + if new_cls.__bases__ == (cls,): + new_cls = cls + return cls_meta.__call__(new_cls, *args, **kwargs) + + return meta_call_wrapper + + +def partialclass(cls, *args, **kwargs): + """Creates a class with partially-specified parameters. + + This class should generally behave interchangeably with `cls` in most + settings. The method used here is to create a dynamic subclass of `cls`, with + a metaclass which is itself a dynamic subclass of `cls`'s metaclass. This + metaclass supplies partial parameters to `cls` during instance creation, with + that result that constructing a `partial_cls` yields actual instances of + `cls`. + + The returned `partial_cls` will have the following properties: + + - `issubclass(partial_cls, cls) == True` + - `issubclass(cls, partial_cls) == False` + - `isinstance(partial_cls, type(cls)) == True` + - `type(partial_cls(...)) == cls` + + Args: + cls: The class to partially specify parameters for. + *args: Positional parameters to provide when constructing `cls`. + **kwargs: Keyword arguments to provide when constructing `cls`. + + Returns: + A dynamic subclass of `cls`, with parameters partially specified. + """ + cls_meta = type(cls) + meta_call = _make_meta_call_wrapper(cls) # See this for more details. + # Construct a new metaclass, subclassing the one from `cls`, supplying our + # decorated `__call__`. Most often this is just subclassing Python's `type`, + # but when `cls` has a custom metaclass set, this ensures that it will + # continue to work properly. + decorating_meta = type(cls_meta)(cls_meta.__name__, (cls_meta,), { + '__call__': functools.partialmethod(meta_call, *args, **kwargs), + }) + # Now we construct our class. This is a subclass of `cls`, but only with + # wrapper-related overrides, since currying parameters is all handled via the + # metaclass's `__call__` method. Note that we let '__annotations__' simply get + # forwarded to the base class, since creating a new type doesn't set this + # attribute by default. + overrides = { + attr: getattr(cls, attr) + for attr in ('__module__', '__name__', '__qualname__', '__doc__') + } + # If `cls` won't have a `__dict__` attribute, disable `__dict__` creation on + # our subclass as well. This seems like generally correct behavior, and also + # prevents errors that can arise under some very specific circumstances due + # to a CPython bug in type creation. + if getattr(cls, '__dictoffset__', None) == 0: + overrides['__slots__'] = () + # Finally, create the partial class using the metaclass created above. + return decorating_meta(cls.__name__, (cls,), overrides) + + +def partial(fn_or_cls: Callable[..., Any], *args, **kwargs): + if inspect.isclass(fn_or_cls): + return partialclass(fn_or_cls, *args, **kwargs) + else: + return functools.partial(fn_or_cls, *args, **kwargs) diff --git a/tests/local/config_test.py b/tests/local/config_test.py new file mode 100644 index 0000000..30b4ed8 --- /dev/null +++ b/tests/local/config_test.py @@ -0,0 +1,157 @@ +# coding=utf-8 +# Copyright 2020 The Gin-Config Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the `gin.local.config` module.""" + +from absl.testing import absltest +from gin.local import config + +import tree + + +class TestClass: + + def __init__(self, arg1, arg2, kwarg1=None, kwarg2=None): + self.arg1 = arg1 + self.arg2 = arg2 + self.kwarg1 = kwarg1 + self.kwarg2 = kwarg2 + + +def test_function(arg1, arg2, kwarg1=None, kwarg2=None): + return (arg1, arg2, kwarg1, kwarg2) + + +class ConfigTest(absltest.TestCase): + + def test_config_for_classes(self): + class_config = config.Config(TestClass, 1, kwarg2='kwarg2') + self.assertEqual(class_config.arg1, 1) + self.assertEqual(class_config.kwarg2, 'kwarg2') + class_config.arg1 = 'arg1' + self.assertEqual(class_config.arg1, 'arg1') + class_config.arg2 = 'arg2' + class_config.kwarg1 = 'kwarg1' + + partial_class = config.bind(class_config) + instance = partial_class() + self.assertEqual(instance.arg1, 'arg1') + self.assertEqual(instance.arg2, 'arg2') + self.assertEqual(instance.kwarg1, 'kwarg1') + self.assertEqual(instance.kwarg2, 'kwarg2') + + def test_config_for_functions(self): + function_config = config.Config(test_function, 1, kwarg2='kwarg2') + self.assertEqual(function_config.arg1, 1) + self.assertEqual(function_config.kwarg2, 'kwarg2') + function_config.arg1 = 'arg1' + self.assertEqual(function_config.arg1, 'arg1') + function_config.arg2 = 'arg2' + function_config.kwarg1 = 'kwarg1' + + partial_function = config.bind(function_config) + self.assertEqual(partial_function(), ('arg1', 'arg2', 'kwarg1', 'kwarg2')) + + def test_nested_configs(self): + function_config1_args = ('innermost1', 'innermost2', 'kw1', 'kw2') + function_config1 = config.Config(test_function, *function_config1_args) + + class_config = config.Config( + TestClass, arg1=function_config1, arg2=function_config1()) + function_config2 = config.Config( + test_function, arg1=class_config, arg2=class_config()) + + function_config2_args = config.bind(function_config2)() + + test_class = function_config2_args[0] + self.assertTrue(issubclass(test_class, TestClass)) + + test_class_instance = test_class() + self.assertEqual(type(test_class_instance), TestClass) + self.assertEqual(test_class_instance.arg1(), function_config1_args) + self.assertEqual(test_class_instance.arg2, function_config1_args) + + test_class_instance = function_config2_args[1] + self.assertEqual(type(test_class_instance), TestClass) + self.assertEqual(test_class_instance.arg1(), function_config1_args) + self.assertEqual(test_class_instance.arg2, function_config1_args) + + def test_instance_sharing(self): + class_config = config.Config( + TestClass, 'arg1', 'arg2', kwarg1='kwarg1', kwarg2='kwarg2') + instance_config = class_config() + instance_config.arg1 = 'shared_arg1' + + # Changing instance config parameters doesn't change the class config. + self.assertEqual(class_config.arg1, 'arg1') + + function_config = config.Config(test_function, class_config(), { + 'key1': [instance_config, instance_config], + 'key2': (instance_config,) + }) + + memo = {} + function_args = config.bind(function_config(), memo=memo) + separate_instance = function_args[0] + shared_instance = memo[instance_config] + structure = function_args[1] + + self.assertIsNot(shared_instance, separate_instance) + for leaf in tree.flatten(structure): + self.assertIs(leaf, shared_instance) + + self.assertEqual(separate_instance.arg1, 'arg1') + self.assertEqual(shared_instance.arg1, 'shared_arg1') + + def test_memo_override(self): + class_config = config.Config( + TestClass, 'arg1', 'arg2', kwarg1='kwarg1', kwarg2='kwarg2') + instance_config = class_config() + function_config = config.Config(test_function, instance_config, { + 'key1': [instance_config, instance_config], + 'key2': (instance_config,) + }) + + overridden_instance_value = object() + memo = {instance_config: overridden_instance_value} + function_args = config.bind(function_config(), memo=memo) + instance = function_args[0] + structure = function_args[1] + + self.assertIs(instance, overridden_instance_value) + for leaf in tree.flatten(structure): + self.assertIs(leaf, overridden_instance_value) + + def test_call_config_twice_error(self): + class_config = config.Config( + TestClass, 'arg1', 'arg2', kwarg1='kwarg1', kwarg2='kwarg2') + expected_err_msg = r'The config has already been marked as called\.' + with self.assertRaisesRegex(ValueError, expected_err_msg): + class_config()() + + def test_params(self): + class_config = config.Config( + TestClass, 'arg1', 'arg2', kwarg1='kwarg1', kwarg2='kwarg2') + params = config.params(class_config) + self.assertEqual(params, { + 'arg1': 'arg1', + 'arg2': 'arg2', + 'kwarg1': 'kwarg1', + 'kwarg2': 'kwarg2' + }) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/local/partial_test.py b/tests/local/partial_test.py new file mode 100644 index 0000000..971c1d4 --- /dev/null +++ b/tests/local/partial_test.py @@ -0,0 +1,122 @@ +# coding=utf-8 +# Copyright 2020 The Gin-Config Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the `gin.local.partial` module.""" + +import typing + +from absl.testing import absltest +from gin.local import partial + + +class TestMetaClass(type): + + def __call__(cls, *args, **kwargs): + instance = super().__call__(*args, **kwargs) + instance.meta_was_run = True + return instance + + +class TestClass(metaclass=TestMetaClass): + """A test class for testing partial.partialclass.""" + + def __init__(self, arg1, arg2, kwarg1=None, kwarg2=None): + self.arg1 = arg1 + self.arg2 = arg2 + self.kwarg1 = kwarg1 + self.kwarg2 = kwarg2 + + def method(self, arg1, arg2, kwarg1=None, kwarg2=None): + return (self, arg1, arg2, kwarg1, kwarg2) + + +class TestNamedTuple(typing.NamedTuple): + arg1: int + arg2: int + kwarg1: str = 'default' + kwarg2: str = 'default' + + +def test_function(arg1, arg2, kwarg1=None, kwarg2=None): + return (arg1, arg2, kwarg1, kwarg2) + + +class PartialTest(absltest.TestCase): + + def test_partial_class(self): + partial_class = partial.partialclass(TestClass, 1, kwarg2='kwarg2') + self.assertTrue(issubclass(partial_class, TestClass)) + self.assertIsInstance(partial_class, TestMetaClass) + self.assertEqual(partial_class.__module__, TestClass.__module__) + self.assertEqual(partial_class.__name__, TestClass.__name__) + self.assertEqual(partial_class.__qualname__, TestClass.__qualname__) + self.assertEqual(partial_class.__doc__, TestClass.__doc__) + + instance = partial_class(2, kwarg1='kwarg1') + self.assertEqual(type(instance), TestClass) + self.assertEqual(instance.arg1, 1) + self.assertEqual(instance.arg2, 2) + self.assertEqual(instance.kwarg1, 'kwarg1') + self.assertEqual(instance.kwarg2, 'kwarg2') + self.assertTrue(instance.meta_was_run) + + def test_partial_class_dynamic_subclass(self): + partial_class = partial.partialclass(TestClass, 1, kwarg2='kwarg2') + + class DynamicSubclass(partial_class): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.subclass_init_called = True + + instance = DynamicSubclass(2, kwarg1='kwarg1') + self.assertEqual(type(instance), DynamicSubclass) + self.assertEqual(instance.arg1, 1) + self.assertEqual(instance.arg2, 2) + self.assertEqual(instance.kwarg1, 'kwarg1') + self.assertEqual(instance.kwarg2, 'kwarg2') + self.assertTrue(instance.meta_was_run) + self.assertTrue(instance.subclass_init_called) + + def test_partial_namedtuple(self): + partial_class = partial.partialclass(TestNamedTuple, 1, kwarg2='kwarg2') + self.assertEqual(partial_class.__slots__, ()) + + instance = partial_class(2, kwarg1='kwarg1') + self.assertEqual(type(instance), TestNamedTuple) + self.assertEqual(instance.arg1, 1) + self.assertEqual(instance.arg2, 2) + self.assertEqual(instance.kwarg1, 'kwarg1') + self.assertEqual(instance.kwarg2, 'kwarg2') + + def test_partial(self): + partial_class = partial.partial(TestClass, 1, 2, 'kwarg1', 'kwarg2') + instance = partial_class() + self.assertEqual(type(instance), TestClass) + self.assertEqual(instance.arg1, 1) + self.assertEqual(instance.arg2, 2) + self.assertEqual(instance.kwarg1, 'kwarg1') + self.assertEqual(instance.kwarg2, 'kwarg2') + + partial_method = partial.partial( + instance.method, 'a', 'b', kwarg1='c', kwarg2='d') + self.assertEqual(partial_method(), (instance, 'a', 'b', 'c', 'd')) + + partial_fn = partial.partial(test_function, 'one', 'two', 'three', 'four') + self.assertEqual(partial_fn(), ('one', 'two', 'three', 'four')) + + +if __name__ == '__main__': + absltest.main()