fix MLP (+complete type annotations for transform module) (#527)
msperber committed Aug 24, 2018
1 parent 8af53ea commit 82e6633
Showing 1 changed file with 65 additions and 40 deletions.
105 changes: 65 additions & 40 deletions xnmt/modelparts/transforms.py
@@ -1,3 +1,6 @@
import numbers
from typing import Optional, Sequence

import dynet as dy

from xnmt import param_collections, param_initializers
@@ -17,10 +20,6 @@ class Identity(Transform, Serializable):
"""
yaml_tag = "!Identity"

@serializable_init
def __init__(self):
pass

def transform(self, input_expr: dy.Expression) -> dy.Expression:
return input_expr

@@ -29,11 +28,11 @@ class Linear(Transform, Serializable):
  Linear projection with optional bias.
  Args:
    input_dim (int): input dimension
    output_dim (int): hidden dimension
    bias (bool): whether to add a bias
    param_init (ParamInitializer): how to initialize weight matrices
    bias_init (ParamInitializer): how to initialize bias vectors
    input_dim: input dimension
    output_dim: hidden dimension
    bias: whether to add a bias
    param_init: how to initialize weight matrices
    bias_init: how to initialize bias vectors
  """

yaml_tag = "!Linear"
@@ -42,9 +41,9 @@ class Linear(Transform, Serializable):
  def __init__(self,
               input_dim: int = Ref("exp_global.default_layer_dim"),
               output_dim: int = Ref("exp_global.default_layer_dim"),
               bias=True,
               param_init=Ref("exp_global.param_init", default=bare(param_initializers.GlorotInitializer)),
               bias_init=Ref("exp_global.bias_init", default=bare(param_initializers.ZeroInitializer))):
               bias: bool=True,
               param_init: param_initializers.ParamInitializer = Ref("exp_global.param_init", default=bare(param_initializers.GlorotInitializer)),
               bias_init: param_initializers.ParamInitializer = Ref("exp_global.bias_init", default=bare(param_initializers.ZeroInitializer))) -> None:
    self.bias = bias
    self.input_dim = input_dim
    self.output_dim = output_dim
@@ -67,12 +66,12 @@ class NonLinear(Transform, Serializable):
  Linear projection with optional bias and non-linearity.
  Args:
    input_dim (int): input dimension
    output_dim (int): hidden dimension
    bias (bool): whether to add a bias
    input_dim: input dimension
    output_dim: hidden dimension
    bias: whether to add a bias
    activation: One of ``tanh``, ``relu``, ``sigmoid``, ``elu``, ``selu``, ``asinh`` or ``identity``.
    param_init (ParamInitializer): how to initialize weight matrices
    bias_init (ParamInitializer): how to initialize bias vectors
    param_init: how to initialize weight matrices
    bias_init: how to initialize bias vectors
  """

yaml_tag = "!NonLinear"
@@ -83,8 +82,8 @@ def __init__(self,
               output_dim: int = Ref("exp_global.default_layer_dim"),
               bias: bool = True,
               activation: str = 'tanh',
               param_init=Ref("exp_global.param_init", default=bare(param_initializers.GlorotInitializer)),
               bias_init=Ref("exp_global.bias_init", default=bare(param_initializers.ZeroInitializer))):
               param_init: param_initializers.ParamInitializer = Ref("exp_global.param_init", default=bare(param_initializers.GlorotInitializer)),
               bias_init: param_initializers.ParamInitializer = Ref("exp_global.bias_init", default=bare(param_initializers.ZeroInitializer))) -> None:
    self.bias = bias
    self.output_dim = output_dim
    self.input_dim = input_dim
@@ -127,15 +126,15 @@ class AuxNonLinear(NonLinear, Serializable):
  NonLinear with an additional auxiliary input.
  Args:
    input_dim (int): input dimension
    output_dim (int): hidden dimension
    aux_input_dim (int): auxiliary input dimension.
    input_dim: input dimension
    output_dim: hidden dimension
    aux_input_dim: auxiliary input dimension.
      The actual input dimension is aux_input_dim + input_dim. This is useful
      for when you want to do something like input feeding.
    bias (bool): whether to add a bias
    bias: whether to add a bias
    activation: One of ``tanh``, ``relu``, ``sigmoid``, ``elu``, ``selu``, ``asinh`` or ``identity``.
    param_init (ParamInitializer): how to initialize weight matrices
    bias_init (ParamInitializer): how to initialize bias vectors
    param_init: how to initialize weight matrices
    bias_init: how to initialize bias vectors
  """

yaml_tag = "!AuxNonLinear"
@@ -147,8 +146,8 @@ def __init__(self,
               aux_input_dim: int = Ref("exp_global.default_layer_dim"),
               bias: bool = True,
               activation: str = 'tanh',
               param_init=Ref("exp_global.param_init", default=bare(param_initializers.GlorotInitializer)),
               bias_init=Ref("exp_global.bias_init", default=bare(param_initializers.ZeroInitializer))):
               param_init: param_initializers.ParamInitializer = Ref("exp_global.param_init", default=bare(param_initializers.GlorotInitializer)),
               bias_init: param_initializers.ParamInitializer = Ref("exp_global.bias_init", default=bare(param_initializers.ZeroInitializer))) -> None:
    original_input_dim = input_dim
    input_dim += aux_input_dim
    super().__init__(
Expand All @@ -170,25 +169,50 @@ class MLP(Transform, Serializable):

  @serializable_init
  def __init__(self,
               input_dim: int = Ref("exp_global.default_layer_dim"),
               hidden_dim: int = Ref("exp_global.default_layer_dim"),
               output_dim: int = Ref("exp_global.default_layer_dim"),
               input_dim: numbers.Integral = Ref("exp_global.default_layer_dim"),
               hidden_dim: numbers.Integral = Ref("exp_global.default_layer_dim"),
               output_dim: numbers.Integral = Ref("exp_global.default_layer_dim"),
               bias: bool = True,
               activation: str = 'tanh',
               hidden_layers: int = 1,
               param_init=Ref("exp_global.param_init", default=bare(param_initializers.GlorotInitializer)),
               bias_init=Ref("exp_global.bias_init", default=bare(param_initializers.ZeroInitializer))):
    self.layers = []
    if hidden_layers > 0:
      self.layers = [NonLinear(input_dim=input_dim, output_dim=hidden_dim, bias=bias, activation=activation, param_init=param_init, bias_init=bias_init)]
      self.layers += [NonLinear(input_dim=hidden_dim, output_dim=hidden_dim, bias=bias, activation=activation, param_init=param_init, bias_init=bias_init) for _ in range(1,hidden_layers)]
      self.layers += [Linear(input_dim=hidden_dim, output_dim=output_dim, bias=bias, param_init=param_init, bias_init=bias_init)]
               hidden_layers: numbers.Integral = 1,
               param_init: param_initializers.ParamInitializer = Ref("exp_global.param_init", default=bare(param_initializers.GlorotInitializer)),
               bias_init: param_initializers.ParamInitializer = Ref("exp_global.bias_init", default=bare(param_initializers.ZeroInitializer)),
               layers: Optional[Sequence[Transform]] = None) -> None:
    self.layers = self.add_serializable_component("layers",
                                                  layers,
                                                  lambda: MLP._create_layers(num_layers=hidden_layers,
                                                                             input_dim=input_dim,
                                                                             hidden_dim=hidden_dim,
                                                                             output_dim=output_dim,
                                                                             bias=bias,
                                                                             activation=activation,
                                                                             param_init=param_init,
                                                                             bias_init=bias_init))

  @staticmethod
  def _create_layers(num_layers: numbers.Integral, input_dim: numbers.Integral, hidden_dim: numbers.Integral,
                     output_dim: numbers.Integral, bias: bool, activation: str,
                     param_init: param_initializers.ParamInitializer, bias_init: param_initializers.ParamInitializer) \
          -> Sequence[Transform]:
    layers = []
    if num_layers > 0:
      layers = [NonLinear(input_dim=input_dim, output_dim=hidden_dim, bias=bias, activation=activation,
                          param_init=param_init, bias_init=bias_init)]
      layers += [NonLinear(input_dim=hidden_dim, output_dim=hidden_dim, bias=bias, activation=activation,
                           param_init=param_init, bias_init=bias_init) for _ in range(1, num_layers)]
    layers += [Linear(input_dim=hidden_dim if num_layers>0 else input_dim,
                      output_dim=output_dim,
                      bias=bias,
                      param_init=param_init,
                      bias_init=bias_init)]
    return layers

  def transform(self, expr: dy.Expression) -> dy.Expression:
    for layer in self.layers:
      expr = layer.transform(expr)
    return expr


class Cwise(Transform, Serializable):
"""
A component-wise transformation that can be an arbitrary unary DyNet operation.
@@ -198,9 +222,10 @@ class Cwise(Transform, Serializable):
"""
yaml_tag = "!Cwise"
@serializable_init
def __init__(self, op="rectify"):
def __init__(self, op: str = "rectify") -> None:
self.op = getattr(dy, op, None)
if not self.op:
raise ValueError(f"DyNet does not have an operation '{op}'.")
def transform(self, input_expr: dy.Expression):

def transform(self, input_expr: dy.Expression) -> dy.Expression:
return self.op(input_expr)
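
To make the behavioural part of the "fix MLP" change concrete: with the old constructor, hidden_layers=0 left self.layers empty, so transform returned its input unchanged and never projected to output_dim; the new _create_layers appends the final Linear unconditionally and reads from input_dim when there are no hidden layers. The following standalone sketch mirrors only that layer-stacking logic (the layer_plan helper is hypothetical and uses plain strings in place of xnmt's NonLinear/Linear components, so it runs without dynet or xnmt):

# Standalone sketch of the stacking logic in MLP._create_layers above.
# layer_plan is a hypothetical stand-in; strings replace the real xnmt components.
from typing import List

def layer_plan(num_layers: int, input_dim: int, hidden_dim: int, output_dim: int) -> List[str]:
  plan: List[str] = []
  if num_layers > 0:
    plan.append(f"NonLinear({input_dim}->{hidden_dim})")
    plan.extend(f"NonLinear({hidden_dim}->{hidden_dim})" for _ in range(1, num_layers))
  # Appended outside the if-block: with zero hidden layers the projection
  # reads from input_dim directly, the case the old constructor left empty.
  plan.append(f"Linear({hidden_dim if num_layers > 0 else input_dim}->{output_dim})")
  return plan

print(layer_plan(num_layers=0, input_dim=512, hidden_dim=256, output_dim=10))  # ['Linear(512->10)']
print(layer_plan(num_layers=2, input_dim=512, hidden_dim=256, output_dim=10))  # ['NonLinear(512->256)', 'NonLinear(256->256)', 'Linear(256->10)']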
