added wildcard matching in LayerArgs
haowen-xu committed Dec 18, 2020
1 parent f8731a1 commit 278aeb7
Showing 2 changed files with 78 additions and 29 deletions.
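
The change in a nutshell: `set_args('*', ...)` registers default arguments once for every layer type, and `get_kwargs` then hands each layer only the wildcard arguments its constructor actually declares. A minimal sketch of the new behavior, adapted from the tests added in this commit (assuming `LayerArgs`, `Dense`, `Conv2d`, and `LeakyReLU` can be imported from `tensorkit.layers`, as the test module does):

    from tensorkit.layers import LayerArgs, Dense, Conv2d, LeakyReLU

    args = LayerArgs()
    # register defaults for all layer types at once
    args.set_args('*', out_channels=3, activation=LeakyReLU)

    # each type keeps only the wildcard args its constructor declares
    args.get_kwargs(Dense)   # -> {'activation': LeakyReLU}
    args.get_kwargs(Conv2d)  # -> {'out_channels': 3, 'activation': LeakyReLU}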
81 changes: 52 additions & 29 deletions tensorkit/layers/builder.py
@@ -16,7 +16,10 @@
 __all__ = ['LayerArgs', 'get_default_layer_args', 'SequentialBuilder']
 
 
-def _get_layer_class(name: str) -> type:
+def _get_layer_class(name: str, support_wildcard: bool = False) -> Optional[type]:
+    if support_wildcard and name == '*':
+        return None
+
     if not _cached_layer_class_names_map:
         # map the standard names of the layers to the layer classes
         import tensorkit as tk
@@ -88,7 +91,9 @@ def _unsplit_channel_spatial(channel, spatial):
 class LayerArgs(object):
     """A class that manages the default arguments for constructing layers."""
 
-    args: Dict[type, Dict[str, Any]]
+    # type? => {arg_name: arg_val}.
+    # None type indicates arguments for all types.
+    args: Dict[Optional[type], Dict[str, Any]]
 
     def __init__(self, layer_args: Optional['LayerArgs'] = NOT_SET):
         """
@@ -145,21 +150,24 @@ def set_args(self,
 
         for type_ in type_or_types_:
             if isinstance(type_, str):
-                type_ = _get_layer_class(type_)
+                type_ = _get_layer_class(type_, support_wildcard=True)
             if type_ not in self.args:
                 self.args[type_] = {}
 
-            if layer_args_ is NOT_SET:
-                layer_args_ = getattr(type_, '__layer_args__', None)
-            layer_has_kwargs = getattr(type_, '__layer_has_kwargs__', False)
-            if layer_args_ is not None and not layer_has_kwargs:
-                for k in kwargs:
-                    if k not in layer_args_:
-                        raise ValueError(
-                            f'The constructor of {type_!r} does not have '
-                            f'the specified keyword argument: {k}'
-                        )
-
+            # validate the arguments
+            if type_ is not None:
+                if layer_args_ is NOT_SET:
+                    layer_args_ = getattr(type_, '__layer_args__', None)
+                layer_has_kwargs = getattr(type_, '__layer_has_kwargs__', False)
+                if layer_args_ is not None and not layer_has_kwargs:
+                    for k in kwargs:
+                        if k not in layer_args_:
+                            raise ValueError(
+                                f'The constructor of {type_!r} does not have '
+                                f'the specified keyword argument: {k}'
+                            )
+
+            # update the arguments
             self.args[type_].update(kwargs)
 
         return self
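
With this change, `set_args('*', ...)` resolves `'*'` to `None` (via `support_wildcard=True`), so wildcard defaults are stored under the `None` key of `self.args` and skip the per-constructor keyword validation, since there is no single constructor to validate against. A rough sketch of the resulting internal state, using names from this diff (the dict layout is an internal detail, not public API):

    args = LayerArgs()
    args.set_args('*', activation=LeakyReLU)  # stored under the key None, not validated
    args.set_args(Dense, out_features=4)      # stored under Dense, validated as before

    # args.args is now roughly:
    # {None: {'activation': LeakyReLU}, Dense: {'out_features': 4}}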
@@ -177,10 +185,22 @@ def get_kwargs(self, type_: Union[str, type], **kwargs) -> Dict[str, Any]:
         """
         if isinstance(type_, str):
             type_ = _get_layer_class(type_)
-        layer_args = self.args.get(type_)
-        if layer_args:
-            for key, val in layer_args.items():
+
+        # get the arguments for this type
+        args = self.args.get(type_)
+        if args:
+            for key, val in args.items():
                 kwargs.setdefault(key, val)
+
+        # get the arguments for all types
+        layer_args = getattr(type_, '__layer_args__', None)
+        if layer_args:  # only use known args
+            args = self.args.get(None)
+            if args:
+                for key, val in args.items():
+                    if key in layer_args:
+                        kwargs.setdefault(key, val)
+
         return kwargs
 
     def build(self, type_: Union[str, type], *args, **kwargs):
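
`get_kwargs` now merges defaults in two stages: per-type arguments first, then wildcard arguments, where a wildcard argument is applied only if it appears in the type's `__layer_args__` list. Because both stages go through `kwargs.setdefault`, explicitly passed keyword arguments always win. For example (the first two calls mirror assertions in the new test; the `out_features=8` call is an extra illustration of the setdefault behavior):

    args = LayerArgs()
    args.set_args('*', out_channels=3, out_features=4, activation=LeakyReLU)

    args.get_kwargs(Linear)  # -> {'out_features': 4}; Linear declares no 'activation'
    args.get_kwargs(Dense)   # -> {'out_features': 4, 'activation': LeakyReLU}
    args.get_kwargs(Dense, out_features=8)
                             # -> {'out_features': 8, 'activation': LeakyReLU}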
@@ -349,19 +369,22 @@ def set_args(self,
             type_or_types_ = [type_or_types_]
 
         for type_ in type_or_types_:
-            if isinstance(type_, str):
-                type_ = _get_layer_class(type_)
-
-            layer_args_ = getattr(type_, '__layer_args__', None)
-            if layer_args_ and 'output_padding' in layer_args_:
-                # suggest it's a deconv layer, add 'output_size' to the valid args list
-                layer_args_ = list(layer_args_) + ['output_size']
+            if type_ == '*':
+                self.layer_args.set_args(type_, **kwargs)
+            else:
+                if isinstance(type_, str):
+                    type_ = _get_layer_class(type_)
 
-            self.layer_args.set_args(
-                type_or_types_=type_,
-                layer_args_=layer_args_,
-                **kwargs
-            )
+                layer_args_ = getattr(type_, '__layer_args__', None)
+                if layer_args_ and 'output_padding' in layer_args_:
+                    # suggest it's a deconv layer, add 'output_size' to the valid args list
+                    layer_args_ = list(layer_args_) + ['output_size']
+
+                self.layer_args.set_args(
+                    type_or_types_=type_,
+                    layer_args_=layer_args_,
+                    **kwargs
+                )
 
         return self
 
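
`SequentialBuilder.set_args` forwards `'*'` directly to the underlying `LayerArgs`, bypassing class resolution and the deconv `output_size` special case, since a wildcard has no concrete constructor to inspect. Usage, mirroring the new builder test:

    builder = SequentialBuilder(5)
    builder.set_args('*', activation=LeakyReLU)
    builder.layer_args.get_kwargs(Dense)   # -> {'activation': LeakyReLU}
    builder.layer_args.get_kwargs(Conv2d)  # -> {'activation': LeakyReLU}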
26 changes: 26 additions & 0 deletions tests/layers/test_builder.py
@@ -168,6 +168,24 @@ def __init__(self, a, b, c):
                                     'argument: c'):
             LayerArgs().set_args(MyLayer4, c=3)
 
+    def test_wildcard_args(self):
+        args = LayerArgs()
+
+        args.set_args('*', out_channels=3, out_features=4, activation=LeakyReLU)
+        self.assertEqual(args.get_kwargs(Linear), {'out_features': 4})
+        self.assertEqual(args.get_kwargs(Dense), {'out_features': 4, 'activation': LeakyReLU})
+        self.assertEqual(args.get_kwargs(LinearConv2d), {'out_channels': 3})
+        self.assertEqual(args.get_kwargs(Conv2d), {'out_channels': 3, 'activation': LeakyReLU})
+        self.assertEqual(args.get_kwargs(_RecordInitArgsLayer), {})
+
+        args.set_args(_RecordInitArgsLayer, out_channels=3, out_features=4, activation=LeakyReLU)
+        self.assertEqual(args.get_kwargs(_RecordInitArgsLayer), {'out_channels': 3, 'out_features': 4, 'activation': LeakyReLU})
+
+        l = args.build(Dense, in_features=2)
+        self.assertEqual(l[0].in_features, 2)
+        self.assertEqual(l[0].out_features, 4)
+        self.assertIsInstance(l[1], LeakyReLU)
+
     def ensure_all_layers_and_flows_have_layer_args_decorated(self):
         from tensorkit import layers as L, flows as F
         for pkg in [L, F]:
@@ -420,6 +438,14 @@ def assert_in_shape(b, s):
                                     'is None or an integer'):
                 _ = SequentialBuilder(in_size=[8, 9], **{arg: arg_values[arg]})
 
+    def test_wildcard_type(self):
+        builder = SequentialBuilder(5)
+        builder.set_args('*', activation=LeakyReLU)
+        self.assertEqual(builder.layer_args.get_kwargs(Dense),
+                         {'activation': LeakyReLU})
+        self.assertEqual(builder.layer_args.get_kwargs(Conv2d),
+                         {'activation': LeakyReLU})
+
     def test_arg_scope(self):
         builder = SequentialBuilder(5)
         self.assertEqual(builder.layer_args.get_kwargs(Dense), {})
