diff --git a/tensorkit/layers/builder.py b/tensorkit/layers/builder.py
index 09d7436..a5bf949 100644
--- a/tensorkit/layers/builder.py
+++ b/tensorkit/layers/builder.py
@@ -16,7 +16,10 @@
 __all__ = ['LayerArgs', 'get_default_layer_args', 'SequentialBuilder']
 
 
-def _get_layer_class(name: str) -> type:
+def _get_layer_class(name: str, support_wildcard: bool = False) -> Optional[type]:
+    if support_wildcard and name == '*':
+        return None
+
     if not _cached_layer_class_names_map:
         # map the standard names of the layers to the layer classes
         import tensorkit as tk
@@ -88,7 +91,9 @@ def _unsplit_channel_spatial(channel, spatial):
 class LayerArgs(object):
     """A class that manages the default arguments for constructing layers."""
 
-    args: Dict[type, Dict[str, Any]]
+    # type (or None) => {arg_name: arg_val}.
+    # A None key stores the wildcard ('*') arguments, applied to all types.
+    args: Dict[Optional[type], Dict[str, Any]]
 
     def __init__(self, layer_args: Optional['LayerArgs'] = NOT_SET):
         """
@@ -145,21 +150,24 @@ def set_args(self,
         for type_ in type_or_types_:
             if isinstance(type_, str):
-                type_ = _get_layer_class(type_)
+                type_ = _get_layer_class(type_, support_wildcard=True)
             if type_ not in self.args:
                 self.args[type_] = {}
 
-            if layer_args_ is NOT_SET:
-                layer_args_ = getattr(type_, '__layer_args__', None)
-            layer_has_kwargs = getattr(type_, '__layer_has_kwargs__', False)
-            if layer_args_ is not None and not layer_has_kwargs:
-                for k in kwargs:
-                    if k not in layer_args_:
-                        raise ValueError(
-                            f'The constructor of {type_!r} does not have '
-                            f'the specified keyword argument: {k}'
-                        )
-
+            # validate the arguments (skipped for the wildcard key, which is None)
+            if type_ is not None:
+                if layer_args_ is NOT_SET:
+                    layer_args_ = getattr(type_, '__layer_args__', None)
+                layer_has_kwargs = getattr(type_, '__layer_has_kwargs__', False)
+                if layer_args_ is not None and not layer_has_kwargs:
+                    for k in kwargs:
+                        if k not in layer_args_:
+                            raise ValueError(
+                                f'The constructor of {type_!r} does not have '
+                                f'the specified keyword argument: {k}'
+                            )
+
+            # update the arguments
             self.args[type_].update(kwargs)
 
         return self
 
@@ -177,10 +185,22 @@ def get_kwargs(self, type_: Union[str, type], **kwargs) -> Dict[str, Any]:
         """
         if isinstance(type_, str):
             type_ = _get_layer_class(type_)
-        layer_args = self.args.get(type_)
-        if layer_args:
-            for key, val in layer_args.items():
+
+        # get the arguments for this type
+        args = self.args.get(type_)
+        if args:
+            for key, val in args.items():
                 kwargs.setdefault(key, val)
+
+        # get the wildcard arguments shared by all types
+        layer_args = getattr(type_, '__layer_args__', None)
+        if layer_args:  # only use args known to this type's constructor
+            args = self.args.get(None)
+            if args:
+                for key, val in args.items():
+                    if key in layer_args:
+                        kwargs.setdefault(key, val)
+
         return kwargs
 
     def build(self, type_: Union[str, type], *args, **kwargs):
@@ -349,19 +369,22 @@ def set_args(self,
             type_or_types_ = [type_or_types_]
 
         for type_ in type_or_types_:
-            if isinstance(type_, str):
-                type_ = _get_layer_class(type_)
-
-            layer_args_ = getattr(type_, '__layer_args__', None)
-            if layer_args_ and 'output_padding' in layer_args_:
-                # suggest it's a deconv layer, add 'output_size' to the valid args list
-                layer_args_ = list(layer_args_) + ['output_size']
+            if type_ == '*':
+                self.layer_args.set_args(type_, **kwargs)
+            else:
+                if isinstance(type_, str):
+                    type_ = _get_layer_class(type_)
 
-            self.layer_args.set_args(
-                type_or_types_=type_,
-                layer_args_=layer_args_,
-                **kwargs
-            )
+                layer_args_ = getattr(type_, '__layer_args__', None)
+                if layer_args_ and 'output_padding' in layer_args_:
+                    # 'output_padding' suggests a deconv layer; add 'output_size' to the valid args list
+                    layer_args_ = list(layer_args_) + ['output_size']
+
+                self.layer_args.set_args(
+                    type_or_types_=type_,
+                    layer_args_=layer_args_,
+                    **kwargs
+                )
 
         return self
 
diff --git a/tests/layers/test_builder.py b/tests/layers/test_builder.py
index abe5390..4949285 100644
--- a/tests/layers/test_builder.py
+++ b/tests/layers/test_builder.py
@@ -168,6 +168,24 @@ def __init__(self, a, b, c):
                           'argument: c'):
             LayerArgs().set_args(MyLayer4, c=3)
 
+    def test_wildcard_args(self):
+        args = LayerArgs()
+
+        args.set_args('*', out_channels=3, out_features=4, activation=LeakyReLU)
+        self.assertEqual(args.get_kwargs(Linear), {'out_features': 4})
+        self.assertEqual(args.get_kwargs(Dense), {'out_features': 4, 'activation': LeakyReLU})
+        self.assertEqual(args.get_kwargs(LinearConv2d), {'out_channels': 3})
+        self.assertEqual(args.get_kwargs(Conv2d), {'out_channels': 3, 'activation': LeakyReLU})
+        self.assertEqual(args.get_kwargs(_RecordInitArgsLayer), {})
+
+        args.set_args(_RecordInitArgsLayer, out_channels=3, out_features=4, activation=LeakyReLU)
+        self.assertEqual(args.get_kwargs(_RecordInitArgsLayer), {'out_channels': 3, 'out_features': 4, 'activation': LeakyReLU})
+
+        l = args.build(Dense, in_features=2)
+        self.assertEqual(l[0].in_features, 2)
+        self.assertEqual(l[0].out_features, 4)
+        self.assertIsInstance(l[1], LeakyReLU)
+
     def ensure_all_layers_and_flows_have_layer_args_decorated(self):
         from tensorkit import layers as L, flows as F
         for pkg in [L, F]:
@@ -420,6 +438,14 @@ def assert_in_shape(b, s):
                          'is None or an integer'):
                 _ = SequentialBuilder(in_size=[8, 9], **{arg: arg_values[arg]})
 
+    def test_wildcard_type(self):
+        builder = SequentialBuilder(5)
+        builder.set_args('*', activation=LeakyReLU)
+        self.assertEqual(builder.layer_args.get_kwargs(Dense),
+                         {'activation': LeakyReLU})
+        self.assertEqual(builder.layer_args.get_kwargs(Conv2d),
+                         {'activation': LeakyReLU})
+
     def test_arg_scope(self):
         builder = SequentialBuilder(5)
         self.assertEqual(builder.layer_args.get_kwargs(Dense), {})
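For reviewers, here is a minimal standalone sketch of the lookup order the new `get_kwargs` implements: type-specific defaults first, then wildcard (`'*'`) defaults, with the wildcard entries filtered against each type's declared constructor arguments. `WildcardArgs` and `FakeDense` below are hypothetical stand-ins written for this note, not tensorkit classes; only the `__layer_args__` attribute name and the None-key convention come from the diff above.

```python
# Illustration only: mimics the wildcard behavior of LayerArgs without tensorkit.
from typing import Any, Dict, Optional


class WildcardArgs:
    def __init__(self):
        # the None key holds the wildcard ('*') arguments shared by all types
        self.args: Dict[Optional[type], Dict[str, Any]] = {}

    def set_args(self, type_: Optional[type], **kwargs) -> 'WildcardArgs':
        self.args.setdefault(type_, {}).update(kwargs)
        return self

    def get_kwargs(self, type_: type, **kwargs) -> Dict[str, Any]:
        # 1. type-specific defaults take precedence over wildcard defaults
        for key, val in self.args.get(type_, {}).items():
            kwargs.setdefault(key, val)
        # 2. wildcard defaults apply only to the type's known constructor args
        layer_args = getattr(type_, '__layer_args__', None)
        if layer_args:
            for key, val in self.args.get(None, {}).items():
                if key in layer_args:
                    kwargs.setdefault(key, val)
        return kwargs


class FakeDense:
    # stands in for a @layer_args-decorated tensorkit layer
    __layer_args__ = ('out_features', 'activation')


args = WildcardArgs()
args.set_args(None, out_features=4, out_channels=3)  # like set_args('*', ...)
assert args.get_kwargs(FakeDense) == {'out_features': 4}  # out_channels filtered out

args.set_args(FakeDense, out_features=8)  # type-specific entry wins
assert args.get_kwargs(FakeDense) == {'out_features': 8}
```

This filtering is exactly what the new tests exercise: `set_args('*', out_channels=3, out_features=4, ...)` leaks neither `out_channels` into `Dense` nor `out_features` into `Conv2d`, and a layer without `__layer_args__` metadata (such as `_RecordInitArgsLayer` before it is configured explicitly) receives no wildcard arguments at all.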