let all layers and flows support LayerArgs in tensorkit.layers and tensorkit.flows
haowen-xu committed Dec 17, 2020
1 parent 63b95c3 commit f5b3cd2
Showing 2 changed files with 71 additions and 10 deletions.
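
In effect, after this commit every layer class exported from tensorkit.layers (and the flow classes in tensorkit.flows) picks up default keyword arguments registered through LayerArgs even when it is instantiated directly, not only when built via LayerArgs.build. A minimal usage sketch, modelled on the new test case in tests/layers/test_builder.py below; the Dense/ReLU names and the as_default()/set_args() calls are taken from that test, and anything beyond them is an assumption:

    import tensorkit as tk

    # Register a default `activation` kwarg for Dense layers and make these
    # defaults active inside the context block.
    with tk.layers.LayerArgs().as_default() as def_args:
        def_args.set_args([tk.layers.Dense], activation=tk.layers.ReLU)

        # Constructing Dense directly now receives the default kwarg:
        # the composite layer gains a trailing ReLU, so it has two children.
        dense = tk.layers.Dense(2, 2)
        assert len(dense) == 2
        assert isinstance(dense[-1], tk.layers.ReLU)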
54 changes: 44 additions & 10 deletions tensorkit/backend/pytorch_/layers.py
@@ -1,6 +1,7 @@
import math
import types
from functools import partial
from functools import partial, wraps
from logging import getLogger
from typing import *

import mltk
@@ -227,6 +228,32 @@ def set_eval_mode(layer: Module):
return layer


def with_layer_args(cls):
def make_wrapper(type_, method):
@wraps(method)
def wrapper(*args, **kwargs):
from tensorkit.layers import get_default_layer_args
layer_args = get_default_layer_args()
default_kwargs = layer_args.get_kwargs(type_)
if default_kwargs:
getLogger(__name__).debug(
'Build instance of layer %r with default kwargs %r',
type_, default_kwargs
)
for k, v in default_kwargs.items():
kwargs.setdefault(k, v)
return method(*args, **kwargs)

type_.__with_layer_args_decorated__ = True
return wrapper

if not isinstance(cls, type) or not issubclass(cls, Module):
raise TypeError(f'`with_layer_args` can only be applied on a Module class.')

cls.__init__ = make_wrapper(cls, cls.__init__)
return cls


# ---- weight wrapper: a simple weight, or a normed weight ----
class NullParamStore(Module):
# This module is actually not used in any context.
@@ -487,7 +514,9 @@ def __new__(cls, name, parents, dct):
if annotations[attr] in (Module, ModuleList):
annotations.pop(attr)

return super().__new__(cls, name, parents, dct)
kclass = super().__new__(cls, name, parents, dct)
kclass = with_layer_args(kclass)
return kclass


class BaseLayer(Module, metaclass=BaseLayerMeta):
@@ -522,7 +551,7 @@ def extra_repr(self) -> str:
return ', '.join(buf)


class Sequential(torch_nn.Sequential):
class Sequential(torch_nn.Sequential, metaclass=BaseLayerMeta):

def __init__(self, *layers: Union[Module, Sequence[Module]]):
from tensorkit.layers import flatten_nested_layers
@@ -924,7 +953,7 @@ def _deconv_transform(self,


# ---- normalizer layers ----
class BatchNorm(torch_nn.BatchNorm1d):
class BatchNorm(torch_nn.BatchNorm1d, metaclass=BaseLayerMeta):
"""Batch normalization for dense layers."""

def __init__(self,
@@ -943,7 +972,7 @@ def _check_input_dim(self, input: Tensor):
'but the input shape is {}'.format(shape(input)))


class BatchNorm1d(torch_nn.BatchNorm1d):
class BatchNorm1d(torch_nn.BatchNorm1d, metaclass=BaseLayerMeta):
"""Batch normalization for 1D convolutional layers."""

def __init__(self,
@@ -962,7 +991,7 @@ def _check_input_dim(self, input: Tensor):
'but the input shape is {}'.format(shape(input)))


class BatchNorm2d(torch_nn.BatchNorm2d):
class BatchNorm2d(torch_nn.BatchNorm2d, metaclass=BaseLayerMeta):
"""Batch normalization for 2D convolutional layers."""

def __init__(self,
@@ -981,7 +1010,7 @@ def _check_input_dim(self, input: Tensor):
'but the input shape is {}'.format(shape(input)))


class BatchNorm3d(torch_nn.BatchNorm3d):
class BatchNorm3d(torch_nn.BatchNorm3d, metaclass=BaseLayerMeta):
"""Batch normalization for 3D convolutional layers."""

def __init__(self,
@@ -1130,7 +1159,8 @@ def step(*batch_data):


# ---- dropout layers ----
Dropout = torch_nn.Dropout
class Dropout(torch_nn.Dropout, metaclass=BaseLayerMeta):
pass


class Dropout1d(BaseLayer):
@@ -1162,8 +1192,12 @@ def forward(self, input: Tensor) -> Tensor:
return input


Dropout2d = torch_nn.Dropout2d
Dropout3d = torch_nn.Dropout3d
class Dropout2d(torch_nn.Dropout2d, metaclass=BaseLayerMeta):
pass


class Dropout3d(torch_nn.Dropout3d, metaclass=BaseLayerMeta):
pass


# ---- embedding layers ----
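The mechanism added above can be summarized as: BaseLayerMeta now passes every class it creates through with_layer_args, which wraps __init__ so that default kwargs registered in the active LayerArgs are merged into the constructor call via setdefault. Below is a simplified, self-contained sketch of that flow, not the tensorkit implementation; the _DEFAULT_KWARGS registry, set_default_kwargs helper, LayerMeta, and SimpleLayer are hypothetical stand-ins:

    from functools import wraps

    # Hypothetical stand-in for tensorkit's LayerArgs registry.
    _DEFAULT_KWARGS = {}

    def set_default_kwargs(cls, **kwargs):
        _DEFAULT_KWARGS.setdefault(cls, {}).update(kwargs)

    def with_layer_args(cls):
        # Wrap `cls.__init__` so registered defaults fill in missing kwargs.
        original_init = cls.__init__

        @wraps(original_init)
        def wrapper(self, *args, **kwargs):
            for k, v in _DEFAULT_KWARGS.get(cls, {}).items():
                kwargs.setdefault(k, v)
            original_init(self, *args, **kwargs)

        cls.__init__ = wrapper
        cls.__with_layer_args_decorated__ = True
        return cls

    class LayerMeta(type):
        # Mirrors what BaseLayerMeta.__new__ now does: decorate every new class.
        def __new__(mcs, name, bases, dct):
            kclass = super().__new__(mcs, name, bases, dct)
            return with_layer_args(kclass)

    class SimpleLayer(metaclass=LayerMeta):
        def __init__(self, units, activation=None):
            self.units = units
            self.activation = activation

    set_default_kwargs(SimpleLayer, activation='relu')
    layer = SimpleLayer(4)
    assert layer.activation == 'relu'   # the default kwarg was injected
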
27 changes: 27 additions & 0 deletions tests/layers/test_builder.py
@@ -84,6 +84,22 @@ def test_as_default(self):
self.assertIsNot(args2, def_args)
self.assertEqual(args2.get_kwargs(_RecordInitArgsLayer), {})

# as default should affect a subclass of BaseLayer
with LayerArgs().as_default() as def_args:
# RecordInitArgsLayer
def_args.set_args([_RecordInitArgsLayer], c=5)
o = args.build(_RecordInitArgsLayer)
self.assertIsInstance(o, _RecordInitArgsLayer)
self.assertEqual(o, ((), {'c': 5}))

# A simple dense layer
l1 = tk.layers.Dense(2, 2)
self.assertEqual(len(l1), 1)
def_args.set_args([tk.layers.Dense], activation=tk.layers.ReLU)
l2 = tk.layers.Dense(2, 2)
self.assertEqual(len(l2), 2)
self.assertIsInstance(l2[-1], tk.layers.ReLU)

def test_layer_names_as_types(self):
args = tk.layers.LayerArgs()
args.set_args(['dense', 'conv2d'], activation=tk.layers.LeakyReLU)
@@ -101,6 +117,17 @@ def test_layer_names_as_types(self):
self.assertIsInstance(l2[1], tk.layers.LeakyReLU)
self.assertEqual(T.shape(l2[0].weight_store()), [4, 4, 3, 3])

def ensure_all_layers_and_flows_have_layer_args_decorated(self):
from tensorkit import layers as L, flows as F
for pkg in [L, F]:
for name in dir(pkg):
val = getattr(pkg, name)
if isinstance(val, T.Module):
self.assertTrue(
val.__with_layer_args_decorated__,
msg=f'{val!r}.__with_layer_args_decorated__ == False'
)


def sequential_builder_standard_check(ctx,
fn_name,
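The new check above relies on the __with_layer_args_decorated__ flag that with_layer_args stamps onto each class it wraps. A quick way to spot-check a single exported class (assuming Dense is among the decorated layers, which follows from the BaseLayerMeta change but is not shown explicitly here):

    import tensorkit as tk

    # True for any class that has been passed through with_layer_args.
    print(getattr(tk.layers.Dense, '__with_layer_args_decorated__', False))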
