added LayerArgs.as_default
haowen-xu committed Dec 17, 2020
1 parent 80649eb commit 63b95c3
Showing 4 changed files with 99 additions and 80 deletions.
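The heart of this commit is a context manager that pushes a LayerArgs instance as the process-wide default, plus a get_default_layer_args() accessor; a LayerArgs() constructed with no argument then clones whatever instance is currently the default. A minimal usage sketch, assuming LayerArgs, get_default_layer_args and LeakyReLU are re-exported under tensorkit.layers as the tests below suggest:

import tensorkit as tk

# Configure some default layer arguments and push them as the global default.
args = tk.layers.LayerArgs()
args.set_args(['dense'], activation=tk.layers.LeakyReLU)

with args.as_default():
    # Inside the block, the pushed instance is the global default ...
    assert tk.layers.get_default_layer_args() is args
    # ... and a LayerArgs() constructed here clones its arguments, so
    # 'dense' layers built through it pick up the activation by default.
    inherited = tk.layers.LayerArgs()

# Outside the block, the previous default is restored.
assert tk.layers.get_default_layer_args() is not args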
10 changes: 5 additions & 5 deletions tensorkit/backend/pytorch_/optim.py
@@ -84,14 +84,14 @@ class BackendOptimizer(Optimizer):
def __init__(self,
params: Iterable[Variable],
lr: float,
torch_optimizer: TorchOptimizer):
torch_optimizer_factory: Callable[[], TorchOptimizer]):
self.params = []
for p in params:
if any(id(p) == id(pp) for pp in self.params):
raise ValueError(f'Duplicated parameter: {p!r}')
self.params.append(p)

self.torch_optimizer = torch_optimizer
self.torch_optimizer = torch_optimizer_factory()
self.set_lr(lr)

@property
@@ -185,7 +185,7 @@ def __init__(self,
super().__init__(
params=params,
lr=lr,
torch_optimizer=torch.optim.SGD(
torch_optimizer_factory=lambda: torch.optim.SGD(
params=params,
lr=lr,
momentum=momentum,
@@ -211,7 +211,7 @@ def __init__(self,
super().__init__(
params=params,
lr=lr,
torch_optimizer=torch.optim.Adam(
torch_optimizer_factory=lambda: torch.optim.Adam(
params=params,
lr=lr,
betas=(beta_1, beta_2),
@@ -237,7 +237,7 @@ def __init__(self,
super().__init__(
params=params,
lr=lr,
torch_optimizer=torch.optim.Adamax(
torch_optimizer_factory=lambda: torch.optim.Adamax(
params=params,
lr=lr,
betas=(beta_1, beta_2),
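The optim.py change above replaces a ready-made torch optimizer argument with a zero-argument factory, so the base class decides when the underlying torch.optim object is actually created, namely after the parameter list has been collected and checked for duplicates. Below is an illustrative, stripped-down sketch of that pattern, not tensorkit's actual BackendOptimizer:

from typing import Callable, Iterable, List

import torch
from torch.optim import Optimizer as TorchOptimizer


class FactoryBackedOptimizer:
    """Stripped-down sketch: validate the parameters first, then call the factory."""

    def __init__(self,
                 params: Iterable[torch.nn.Parameter],
                 lr: float,
                 torch_optimizer_factory: Callable[[], TorchOptimizer]):
        self.params: List[torch.nn.Parameter] = []
        for p in params:
            if any(id(p) == id(pp) for pp in self.params):
                raise ValueError(f'Duplicated parameter: {p!r}')
            self.params.append(p)
        # The torch optimizer is only instantiated here, after validation.
        self.torch_optimizer = torch_optimizer_factory()
        self.lr = lr


params = [torch.nn.Parameter(torch.zeros(3))]
opt = FactoryBackedOptimizer(
    params=params,
    lr=0.1,
    torch_optimizer_factory=lambda: torch.optim.SGD(params, lr=0.1),
)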
94 changes: 28 additions & 66 deletions tensorkit/gnn/adj/gcn_layers.py
@@ -227,6 +227,7 @@ def _forward(self, input: Tensor, adj: List[Tensor]) -> Tensor:
# compute the outputs of modules
merge_mode = self.merge_mode
outputs: List[Tensor] = []
output = input
output_shape = (
[-1] +
input_shape[i_rank - self.feature_matrix_ndims + 1:]
@@ -243,7 +244,15 @@ def _forward(self, input: Tensor, adj: List[Tensor]) -> Tensor:

# apply the `f_i()` transformation
m_output = m(m_output)
outputs.append(m_output)

# merge if "concat", or sum if "add"
if merge_mode == 0:
if i == 0:
output = m_output
else:
output = output + m_output
else:
outputs.append(m_output)

# move to next module
i += 1
@@ -253,35 +262,41 @@ def _forward(self, input: Tensor, adj: List[Tensor]) -> Tensor:
if not two_dimensional_case:
input = reshape(input, output_shape)
m_output = self.self_module(input)
outputs.append(m_output)

# merge if "concat", or sum if "add"
if merge_mode == 0:
input = add_n(outputs)
else:
input = concat(outputs, axis=self.feature_axis)
# merge if "concat", or sum if "add"
if merge_mode == 0:
if i == 0:
output = m_output
else:
output = output + m_output
else:
outputs.append(m_output)

# do final merge
if merge_mode != 0 and len(outputs) > 0:
output = concat(outputs, axis=self.feature_axis)

# de-reference intermediate results to free the memory immediately
outputs = []
m_output = input
input = m_output = output

# add bias
if self.use_bias:
m_output = input = input + self.bias_store()
output = input = m_output = output + self.bias_store()

# apply post-linear
if self.use_post_linear:
m_output = input = self.post_linear(input)
output = input = m_output = self.post_linear(output)

# reshape to the final output shape: `(N, B1, B2, ..., K1, K2, ...)`
if not two_dimensional_case:
output_shape = (
input_shape[:i_rank - self.feature_matrix_ndims + 1] +
shape(input)[1:]
)
m_output = input = reshape(input, output_shape)
output = input = m_output = reshape(output, output_shape)

return input
return output

def forward(self, input: Tensor, adj: List[Tensor]) -> Tensor:
return self._forward(input, adj)
@@ -492,60 +507,7 @@ def __init__(self,
if use_bias:
out_dup = 1 + int(use_self_loop and merge_mode == 'concat')
bias_shape = [out_features * out_dup]
bias_store = SimpleParamStore(
bias_shape, initializer=bias_init, device=device)
else:
bias_store = None

if normalizer is not None:
normalizer = get_layer_from_layer_or_factory(
'normalizer', normalizer, args=(out_features,))

if activation is not None:
activation = get_layer_from_layer_or_factory('activation', activation)

super().__init__(
module=module, self_module=self_module, self_weight=self_weight,
bias_store=bias_store, normalizer=normalizer, activation=activation,
merge_mode=merge_mode,
)
class GCNDense(GCNLayer):
"""A standard dense GCN layer."""

def __init__(self,
in_features: int,
out_features: int,
use_self_loop: bool = True,
self_weight: float = 1.,
merge_mode: Union[str, GCNMergeMode] = 'add',
use_bias: Optional[bool] = None,
normalizer: Optional[NormalizerOrNormalizerFactory] = None,
activation: Optional[LayerOrLayerFactory] = None,
weight_norm: WeightNormArgType = False,
weight_init: TensorInitArgType = DEFAULT_WEIGHT_INIT,
bias_init: TensorInitArgType = DEFAULT_BIAS_INIT,
data_init: Optional[DataInitArgType] = None,
device: Optional[str] = None,
):
if use_bias is None:
use_bias = normalizer is None
linear_kwargs = dict(
use_bias=False,
weight_norm=weight_norm,
weight_init=weight_init,
data_init=data_init,
device=device,
)
module = Linear(in_features, out_features, **linear_kwargs)
self_module = (Linear(in_features, out_features, **linear_kwargs)
if use_self_loop else None)

if use_bias:
out_dup = (1 + int(use_self_loop)
if merge_mode == 'concat' else 1)
bias_shape = [out_features * out_dup]
bias_store = SimpleParamStore(
bias_shape, initializer=bias_init, device=device)
bias_store = SimpleParamStore(bias_shape, initializer=bias_init, device=device)
else:
bias_store = None

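The gcn_layers.py rewrite above changes how the per-module outputs are combined: in "add" merge mode the result is now accumulated as a running sum, so every intermediate tensor can be released as soon as it has been added, while "concat" mode still collects all outputs and concatenates them once at the end. A standalone sketch of that merge strategy in plain PyTorch (not the actual GCNLayer code):

from typing import Iterable, List

import torch


def merge_module_outputs(module_outputs: Iterable[torch.Tensor],
                         merge_mode: str,
                         feature_axis: int = -1) -> torch.Tensor:
    # 'add' keeps only a running sum; 'concat' must retain every output.
    running = None
    kept: List[torch.Tensor] = []
    for m_output in module_outputs:
        if merge_mode == 'add':
            running = m_output if running is None else running + m_output
        else:
            kept.append(m_output)
    return running if merge_mode == 'add' else torch.cat(kept, dim=feature_axis)


# Example: three per-edge-type outputs of shape (N, K) = (4, 8).
outs = [torch.randn(4, 8) for _ in range(3)]
print(merge_module_outputs(outs, 'add').shape)     # torch.Size([4, 8])
print(merge_module_outputs(outs, 'concat').shape)  # torch.Size([4, 24])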
45 changes: 38 additions & 7 deletions tensorkit/layers/builder.py
@@ -1,7 +1,7 @@
from contextlib import contextmanager
from typing import *

from mltk.utils import NOT_SET
from mltk.utils import NOT_SET, ContextStack

from .activation import *
from .composed import *
@@ -13,7 +13,7 @@
from ..arg_check import *
from ..typing_ import *

__all__ = ['LayerArgs', 'SequentialBuilder']
__all__ = ['LayerArgs', 'get_default_layer_args', 'SequentialBuilder']


def _get_layer_class(name: str) -> type:
@@ -90,18 +90,39 @@ class LayerArgs(object):

args: Dict[type, Dict[str, Any]]

def __init__(self, layer_args: Optional['LayerArgs'] = None):
def __init__(self, layer_args: Optional['LayerArgs'] = NOT_SET):
"""
Construct a new :class:`LayerArgs` instance.
Args:
layer_args: Clone from this :class:`LayerArgs` instance.
"""
if layer_args is not None:
if layer_args is NOT_SET:
layer_args = get_default_layer_args()

if layer_args is None:
self.args = {}
else:
self.args = {type_: {key: val for key, val in type_args.items()}
for type_, type_args in layer_args.args.items()}
else:
self.args = {}

@contextmanager
def as_default(self) -> ContextManager['LayerArgs']:
"""Push this `LayerArgs` instance as the default."""
try:
_layer_args_stack.push(self)
yield self
finally:
_layer_args_stack.pop()

def copy(self) -> 'LayerArgs':
"""
Copy a new `LayerArgs` instance.
Returns:
A new :class:`LayerArgs` instance.
"""
return LayerArgs(self)

def set_args(self,
type_or_types_: Union[
@@ -113,6 +134,9 @@ def set_args(self,
Args:
type_or_types_: The layer type or types.
**kwargs: The default arguments to be set.
Returns:
This :class:`LayerArgs` instance.
"""
if isinstance(type_or_types_, (str, type)):
type_or_types_ = [type_or_types_]
@@ -163,6 +187,14 @@ def build(self, type_: Union[str, type], *args, **kwargs):
return type_(*args, **self.get_kwargs(type_, **kwargs))


def get_default_layer_args() -> LayerArgs:
"""Get the global default `LayerArgs` instance."""
return _layer_args_stack.top()


_layer_args_stack = ContextStack[LayerArgs](lambda: LayerArgs(None))


class SequentialBuilder(object):
"""A class that helps to build a sequence layers."""

@@ -201,7 +233,6 @@ def __init__(self,
to the new sequential builder. This will also override
the layer args of `in_builder`.
"""

# parse the argument
if int(in_spec is not NOT_SET) + int(in_shape is not NOT_SET) + \
int(in_channels is not NOT_SET) + int(in_builder is not NOT_SET) != 1:
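The new default machinery in builder.py leans on mltk.utils.ContextStack: get_default_layer_args() returns the stack's top(), and as_default() simply pushes and pops around the with block, with an initial-element factory supplying LayerArgs(None) when nothing has been pushed. The toy stand-in below only illustrates that push/pop/top protocol; the real ContextStack may differ (for example, it is presumably thread-local):

from typing import Callable, Generic, List, TypeVar

T = TypeVar('T')


class MiniContextStack(Generic[T]):
    # Toy stand-in for mltk.utils.ContextStack: push/pop/top, with a lazily
    # created initial element used whenever the stack is otherwise empty.

    def __init__(self, initial_factory: Callable[[], T]):
        self._items: List[T] = []
        self._initial_factory = initial_factory
        self._initial: List[T] = []  # holds the lazily created default

    def push(self, item: T) -> None:
        self._items.append(item)

    def pop(self) -> T:
        return self._items.pop()

    def top(self) -> T:
        if self._items:
            return self._items[-1]
        if not self._initial:
            self._initial.append(self._initial_factory())
        return self._initial[0]


stack = MiniContextStack(lambda: {'is_default': True})
assert stack.top() == {'is_default': True}   # falls back to the initial element
stack.push({'is_default': False})
assert stack.top() == {'is_default': False}
stack.pop()
assert stack.top() == {'is_default': True}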
30 changes: 28 additions & 2 deletions tests/layers/test_builder.py
@@ -30,7 +30,7 @@ def __eq__(self, other):

class LayerArgsTestCase(TestCase):

def test_set_args(self):
def test_copy_and_set_args(self):
# empty default args
args = tk.layers.LayerArgs()
self.assertEqual(args.get_kwargs(_RecordInitArgsLayer), {})
@@ -43,7 +43,6 @@ def test_set_args(self):
self.assertIs(args.set_args(_RecordInitArgsLayer, d=4), args)
self.assertEqual(args.get_kwargs(_RecordInitArgsLayer), {'d': 4})
self.assertEqual(args.get_kwargs(_RecordInitArgsLayer, c=3, d=5), {'c': 3, 'd': 5})

o = args.build(_RecordInitArgsLayer)
self.assertIsInstance(o, _RecordInitArgsLayer)
self.assertEqual(o, ((), {'d': 4}))
@@ -53,11 +52,38 @@ def test_set_args(self):
self.assertEqual(o, ((1, 2), {'c': 3, 'd': 5}))

# inherit default args from previous instance
args2 = args.copy()
args2.set_args(_RecordInitArgsLayer, c=5)
self.assertEqual(args2.get_kwargs(_RecordInitArgsLayer), {'c': 5, 'd': 4})
self.assertEqual(args.get_kwargs(_RecordInitArgsLayer), {'d': 4}) # should not change

args2 = tk.layers.LayerArgs(args)
args2.set_args([_RecordInitArgsLayer], c=5)
self.assertEqual(args2.get_kwargs(_RecordInitArgsLayer), {'c': 5, 'd': 4})
self.assertEqual(args.get_kwargs(_RecordInitArgsLayer), {'d': 4}) # should not change

def test_as_default(self):
# the default default
def_args = get_default_layer_args()
args = LayerArgs()
self.assertIsNot(args, def_args)
with args.as_default() as args2:
self.assertIs(args2, args)
self.assertIs(get_default_layer_args(), args)
self.assertEqual(args2.get_kwargs(_RecordInitArgsLayer), {})
self.assertIs(get_default_layer_args(), def_args)

# test inherit from default and inherit from none
with LayerArgs().as_default() as def_args:
def_args.set_args([_RecordInitArgsLayer], c=5)
self.assertEqual(def_args.get_kwargs(_RecordInitArgsLayer), {'c': 5})
with LayerArgs().as_default() as args2:
self.assertIsNot(args2, def_args)
self.assertEqual(args2.get_kwargs(_RecordInitArgsLayer), {'c': 5})
with LayerArgs(None).as_default() as args2:
self.assertIsNot(args2, def_args)
self.assertEqual(args2.get_kwargs(_RecordInitArgsLayer), {})

def test_layer_names_as_types(self):
args = tk.layers.LayerArgs()
args.set_args(['dense', 'conv2d'], activation=tk.layers.LeakyReLU)
