diff --git a/tensorkit/backend/pytorch_/optim.py b/tensorkit/backend/pytorch_/optim.py
index d61afdb..0fbc7a8 100644
--- a/tensorkit/backend/pytorch_/optim.py
+++ b/tensorkit/backend/pytorch_/optim.py
@@ -84,14 +84,14 @@ class BackendOptimizer(Optimizer):
     def __init__(self,
                  params: Iterable[Variable],
                  lr: float,
-                 torch_optimizer: TorchOptimizer):
+                 torch_optimizer_factory: Callable[[], TorchOptimizer]):
         self.params = []
         for p in params:
             if any(id(p) == id(pp) for pp in self.params):
                 raise ValueError(f'Duplicated parameter: {p!r}')
             self.params.append(p)
 
-        self.torch_optimizer = torch_optimizer
+        self.torch_optimizer = torch_optimizer_factory()
         self.set_lr(lr)
 
     @property
@@ -185,7 +185,7 @@ def __init__(self,
         super().__init__(
             params=params,
             lr=lr,
-            torch_optimizer=torch.optim.SGD(
+            torch_optimizer_factory=lambda: torch.optim.SGD(
                 params=params,
                 lr=lr,
                 momentum=momentum,
@@ -211,7 +211,7 @@ def __init__(self,
         super().__init__(
             params=params,
             lr=lr,
-            torch_optimizer=torch.optim.Adam(
+            torch_optimizer_factory=lambda: torch.optim.Adam(
                 params=params,
                 lr=lr,
                 betas=(beta_1, beta_2),
@@ -237,7 +237,7 @@ def __init__(self,
         super().__init__(
             params=params,
             lr=lr,
-            torch_optimizer=torch.optim.Adamax(
+            torch_optimizer_factory=lambda: torch.optim.Adamax(
                 params=params,
                 lr=lr,
                 betas=(beta_1, beta_2),
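The optimizer change above passes a zero-argument factory to the base class instead of a ready-made `torch.optim` instance, so the wrapper itself decides when the underlying optimizer is constructed. A minimal sketch of the same pattern in plain PyTorch (the `LazyOptimizerWrapper` name and the toy parameter list are illustrative, not tensorkit code):

```python
from typing import Callable, Iterable

import torch
from torch.optim import Optimizer as TorchOptimizer


class LazyOptimizerWrapper:
    """Holds a factory and builds the torch optimizer itself."""

    def __init__(self,
                 params: Iterable[torch.nn.Parameter],
                 torch_optimizer_factory: Callable[[], TorchOptimizer]):
        self.params = list(params)
        # the factory is invoked here, by the wrapper, rather than by the caller
        self.torch_optimizer = torch_optimizer_factory()


# usage: the caller only describes *how* to build the optimizer
params = [torch.nn.Parameter(torch.zeros(3))]
wrapper = LazyOptimizerWrapper(
    params,
    torch_optimizer_factory=lambda: torch.optim.SGD(params, lr=0.1, momentum=0.9),
)
print(type(wrapper.torch_optimizer).__name__)  # SGD
```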
diff --git a/tensorkit/gnn/adj/gcn_layers.py b/tensorkit/gnn/adj/gcn_layers.py
index e78f189..b3a34ca 100644
--- a/tensorkit/gnn/adj/gcn_layers.py
+++ b/tensorkit/gnn/adj/gcn_layers.py
@@ -227,6 +227,7 @@ def _forward(self, input: Tensor, adj: List[Tensor]) -> Tensor:
         # compute the outputs of modules
         merge_mode = self.merge_mode
         outputs: List[Tensor] = []
+        output = input
         output_shape = (
             [-1] +
             input_shape[i_rank - self.feature_matrix_ndims + 1:]
@@ -243,7 +244,15 @@ def _forward(self, input: Tensor, adj: List[Tensor]) -> Tensor:
 
                 # apply the `f_i()` transformation
                 m_output = m(m_output)
-                outputs.append(m_output)
+
+                # merge if "concat", or sum if "add"
+                if merge_mode == 0:
+                    if i == 0:
+                        output = m_output
+                    else:
+                        output = output + m_output
+                else:
+                    outputs.append(m_output)
 
                 # move to next module
                 i += 1
@@ -253,25 +262,31 @@ def _forward(self, input: Tensor, adj: List[Tensor]) -> Tensor:
             if not two_dimensional_case:
                 input = reshape(input, output_shape)
             m_output = self.self_module(input)
-            outputs.append(m_output)
 
-        # merge if "concat", or sum if "add"
-        if merge_mode == 0:
-            input = add_n(outputs)
-        else:
-            input = concat(outputs, axis=self.feature_axis)
+            # merge if "concat", or sum if "add"
+            if merge_mode == 0:
+                if i == 0:
+                    output = m_output
+                else:
+                    output = output + m_output
+            else:
+                outputs.append(m_output)
+
+        # do final merge
+        if merge_mode != 0 and len(outputs) > 0:
+            output = concat(outputs, axis=self.feature_axis)
 
         # de-reference intermediate results to free the memory immediately
         outputs = []
-        m_output = input
+        input = m_output = output
 
         # add bias
         if self.use_bias:
-            m_output = input = input + self.bias_store()
+            output = input = m_output = output + self.bias_store()
 
         # apply post-linear
         if self.use_post_linear:
-            m_output = input = self.post_linear(input)
+            output = input = m_output = self.post_linear(output)
 
         # reshape to the final output shape: `(N, B1, B2, ..., K1, K2, ...)`
         if not two_dimensional_case:
@@ -279,9 +294,9 @@ def _forward(self, input: Tensor, adj: List[Tensor]) -> Tensor:
             input_shape[:i_rank - self.feature_matrix_ndims + 1] +
             shape(input)[1:]
         )
-        m_output = input = reshape(input, output_shape)
+        output = input = m_output = reshape(output, output_shape)
 
-        return input
+        return output
 
     def forward(self, input: Tensor, adj: List[Tensor]) -> Tensor:
         return self._forward(input, adj)
@@ -492,60 +507,7 @@ def __init__(self,
         if use_bias:
             out_dup = 1 + int(use_self_loop and merge_mode == 'concat')
             bias_shape = [out_features * out_dup]
-            bias_store = SimpleParamStore(
-                bias_shape, initializer=bias_init, device=device)
-        else:
-            bias_store = None
-
-        if normalizer is not None:
-            normalizer = get_layer_from_layer_or_factory(
-                'normalizer', normalizer, args=(out_features,))
-
-        if activation is not None:
-            activation = get_layer_from_layer_or_factory('activation', activation)
-
-        super().__init__(
-            module=module, self_module=self_module, self_weight=self_weight,
-            bias_store=bias_store, normalizer=normalizer, activation=activation,
-            merge_mode=merge_mode,
-        )
-class GCNDense(GCNLayer):
-    """A standard dense GCN layer."""
-
-    def __init__(self,
-                 in_features: int,
-                 out_features: int,
-                 use_self_loop: bool = True,
-                 self_weight: float = 1.,
-                 merge_mode: Union[str, GCNMergeMode] = 'add',
-                 use_bias: Optional[bool] = None,
-                 normalizer: Optional[NormalizerOrNormalizerFactory] = None,
-                 activation: Optional[LayerOrLayerFactory] = None,
-                 weight_norm: WeightNormArgType = False,
-                 weight_init: TensorInitArgType = DEFAULT_WEIGHT_INIT,
-                 bias_init: TensorInitArgType = DEFAULT_BIAS_INIT,
-                 data_init: Optional[DataInitArgType] = None,
-                 device: Optional[str] = None,
-                 ):
-        if use_bias is None:
-            use_bias = normalizer is None
-        linear_kwargs = dict(
-            use_bias=False,
-            weight_norm=weight_norm,
-            weight_init=weight_init,
-            data_init=data_init,
-            device=device,
-        )
-        module = Linear(in_features, out_features, **linear_kwargs)
-        self_module = (Linear(in_features, out_features, **linear_kwargs)
-                       if use_self_loop else None)
-
-        if use_bias:
-            out_dup = (1 + int(use_self_loop)
-                       if merge_mode == 'concat' else 1)
-            bias_shape = [out_features * out_dup]
-            bias_store = SimpleParamStore(
-                bias_shape, initializer=bias_init, device=device)
+            bias_store = SimpleParamStore(bias_shape, initializer=bias_init, device=device)
         else:
             bias_store = None
 
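The `_forward` rewrite replaces "collect every partition output in `outputs`, then `add_n`/`concat` at the end" with a running sum held in `output` when the merge mode is "add"; only the "concat" path still accumulates a list for the final `concat`. A toy sketch of the two merge strategies in plain PyTorch (the function and variable names are illustrative, not tensorkit's API):

```python
from typing import List

import torch


def merge_partition_outputs(partition_outputs: List[torch.Tensor],
                            merge_mode: str = 'add',
                            feature_axis: int = -1) -> torch.Tensor:
    """Toy illustration of the two GCN merge strategies used above."""
    if merge_mode == 'add':
        # running sum, as in the rewritten "add" branch: each partial result
        # can be dropped as soon as it has been folded into `output`
        output = partition_outputs[0]
        for t in partition_outputs[1:]:
            output = output + t
        return output
    else:  # 'concat'
        # concatenation still needs all partition outputs at once
        return torch.cat(partition_outputs, dim=feature_axis)


outs = [torch.ones(4, 3), 2 * torch.ones(4, 3)]
print(merge_partition_outputs(outs, 'add').shape)     # torch.Size([4, 3])
print(merge_partition_outputs(outs, 'concat').shape)  # torch.Size([4, 6])
```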
""" - if layer_args is not None: + if layer_args is NOT_SET: + layer_args = get_default_layer_args() + + if layer_args is None: + self.args = {} + else: self.args = {type_: {key: val for key, val in type_args.items()} for type_, type_args in layer_args.args.items()} - else: - self.args = {} + + @contextmanager + def as_default(self) -> ContextManager['LayerArgs']: + """Push this `LayerArgs` instance as the default.""" + try: + _layer_args_stack.push(self) + yield self + finally: + _layer_args_stack.pop() + + def copy(self) -> 'LayerArgs': + """ + Copy a new `LayerArgs` instance. + + Returns: + A new :class:`LayerArgs` instance. + """ + return LayerArgs(self) def set_args(self, type_or_types_: Union[ @@ -113,6 +134,9 @@ def set_args(self, Args: type_or_types_: The layer type or types. **kwargs: The default arguments to be set. + + Returns: + This :class:`LayerArgs` instance. """ if isinstance(type_or_types_, (str, type)): type_or_types_ = [type_or_types_] @@ -163,6 +187,14 @@ def build(self, type_: Union[str, type], *args, **kwargs): return type_(*args, **self.get_kwargs(type_, **kwargs)) +def get_default_layer_args() -> LayerArgs: + """Get the global default `LayerArgs` instance.""" + return _layer_args_stack.top() + + +_layer_args_stack = ContextStack[LayerArgs](lambda: LayerArgs(None)) + + class SequentialBuilder(object): """A class that helps to build a sequence layers.""" @@ -201,7 +233,6 @@ def __init__(self, to the new sequential builder. This will also override the layer args of `in_builder`. """ - # parse the argument if int(in_spec is not NOT_SET) + int(in_shape is not NOT_SET) + \ int(in_channels is not NOT_SET) + int(in_builder is not NOT_SET) != 1: diff --git a/tests/layers/test_builder.py b/tests/layers/test_builder.py index efb5b4a..efd3921 100644 --- a/tests/layers/test_builder.py +++ b/tests/layers/test_builder.py @@ -30,7 +30,7 @@ def __eq__(self, other): class LayerArgsTestCase(TestCase): - def test_set_args(self): + def test_copy_and_set_args(self): # empty default args args = tk.layers.LayerArgs() self.assertEqual(args.get_kwargs(_RecordInitArgsLayer), {}) @@ -43,7 +43,6 @@ def test_set_args(self): self.assertIs(args.set_args(_RecordInitArgsLayer, d=4), args) self.assertEqual(args.get_kwargs(_RecordInitArgsLayer), {'d': 4}) self.assertEqual(args.get_kwargs(_RecordInitArgsLayer, c=3, d=5), {'c': 3, 'd': 5}) - o = args.build(_RecordInitArgsLayer) self.assertIsInstance(o, _RecordInitArgsLayer) self.assertEqual(o, ((), {'d': 4})) @@ -53,11 +52,38 @@ def test_set_args(self): self.assertEqual(o, ((1, 2), {'c': 3, 'd': 5})) # inherit default args from previous instance + args2 = args.copy() + args2.set_args(_RecordInitArgsLayer, c=5) + self.assertEqual(args2.get_kwargs(_RecordInitArgsLayer), {'c': 5, 'd': 4}) + self.assertEqual(args.get_kwargs(_RecordInitArgsLayer), {'d': 4}) # should not change + args2 = tk.layers.LayerArgs(args) args2.set_args([_RecordInitArgsLayer], c=5) self.assertEqual(args2.get_kwargs(_RecordInitArgsLayer), {'c': 5, 'd': 4}) self.assertEqual(args.get_kwargs(_RecordInitArgsLayer), {'d': 4}) # should not change + def test_as_default(self): + # the default default + def_args = get_default_layer_args() + args = LayerArgs() + self.assertIsNot(args, def_args) + with args.as_default() as args2: + self.assertIs(args2, args) + self.assertIs(get_default_layer_args(), args) + self.assertEqual(args2.get_kwargs(_RecordInitArgsLayer), {}) + self.assertIs(get_default_layer_args(), def_args) + + # test inherit from default and inherit from none + with 
diff --git a/tests/layers/test_builder.py b/tests/layers/test_builder.py
index efb5b4a..efd3921 100644
--- a/tests/layers/test_builder.py
+++ b/tests/layers/test_builder.py
@@ -30,7 +30,7 @@ def __eq__(self, other):
 
 class LayerArgsTestCase(TestCase):
 
-    def test_set_args(self):
+    def test_copy_and_set_args(self):
         # empty default args
         args = tk.layers.LayerArgs()
         self.assertEqual(args.get_kwargs(_RecordInitArgsLayer), {})
@@ -43,7 +43,6 @@ def test_set_args(self):
         self.assertIs(args.set_args(_RecordInitArgsLayer, d=4), args)
         self.assertEqual(args.get_kwargs(_RecordInitArgsLayer), {'d': 4})
         self.assertEqual(args.get_kwargs(_RecordInitArgsLayer, c=3, d=5), {'c': 3, 'd': 5})
-
         o = args.build(_RecordInitArgsLayer)
         self.assertIsInstance(o, _RecordInitArgsLayer)
         self.assertEqual(o, ((), {'d': 4}))
@@ -53,11 +52,38 @@ def test_set_args(self):
         self.assertEqual(o, ((1, 2), {'c': 3, 'd': 5}))
 
         # inherit default args from previous instance
+        args2 = args.copy()
+        args2.set_args(_RecordInitArgsLayer, c=5)
+        self.assertEqual(args2.get_kwargs(_RecordInitArgsLayer), {'c': 5, 'd': 4})
+        self.assertEqual(args.get_kwargs(_RecordInitArgsLayer), {'d': 4})  # should not change
+
         args2 = tk.layers.LayerArgs(args)
         args2.set_args([_RecordInitArgsLayer], c=5)
         self.assertEqual(args2.get_kwargs(_RecordInitArgsLayer), {'c': 5, 'd': 4})
         self.assertEqual(args.get_kwargs(_RecordInitArgsLayer), {'d': 4})  # should not change
 
+    def test_as_default(self):
+        # the default default
+        def_args = get_default_layer_args()
+        args = LayerArgs()
+        self.assertIsNot(args, def_args)
+        with args.as_default() as args2:
+            self.assertIs(args2, args)
+            self.assertIs(get_default_layer_args(), args)
+            self.assertEqual(args2.get_kwargs(_RecordInitArgsLayer), {})
+        self.assertIs(get_default_layer_args(), def_args)
+
+        # test inherit from default and inherit from none
+        with LayerArgs().as_default() as def_args:
+            def_args.set_args([_RecordInitArgsLayer], c=5)
+            self.assertEqual(def_args.get_kwargs(_RecordInitArgsLayer), {'c': 5})
+            with LayerArgs().as_default() as args2:
+                self.assertIsNot(args2, def_args)
+                self.assertEqual(args2.get_kwargs(_RecordInitArgsLayer), {'c': 5})
+            with LayerArgs(None).as_default() as args2:
+                self.assertIsNot(args2, def_args)
+                self.assertEqual(args2.get_kwargs(_RecordInitArgsLayer), {})
+
     def test_layer_names_as_types(self):
         args = tk.layers.LayerArgs()
         args.set_args(['dense', 'conv2d'], activation=tk.layers.LeakyReLU)
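Assuming `get_default_layer_args` is re-exported under `tk.layers` (as the updated `__all__` suggests), user code could scope shared layer arguments the same way the tests above do; a hypothetical usage sketch:

```python
import tensorkit as tk

# push a scoped set of layer-arg defaults: every LayerArgs() constructed
# inside the block starts from a clone of `args`
with tk.layers.LayerArgs().as_default() as args:
    args.set_args(['dense', 'conv2d'], activation=tk.layers.LeakyReLU)
    assert tk.layers.get_default_layer_args() is args
# outside the block, the previous default is active again
```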