Skip to content

Commit

Permalink
Update (most) automated tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
matsuyamax committed Oct 5, 2015
1 parent 2bd4c29 commit c60e2df
Show file tree
Hide file tree
Showing 24 changed files with 529 additions and 372 deletions.
2 changes: 1 addition & 1 deletion examples/imdb_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@

# we start off with an efficient embedding layer which maps
# our vocab indices into embedding_dims dimensions
model.add(Embedding(max_features, embedding_dims, max_lenght=maxlen))
model.add(Embedding(max_features, embedding_dims, max_length=maxlen))
model.add(Dropout(0.25))

# we add a Convolution1D, which will learn nb_filter
Expand Down
36 changes: 23 additions & 13 deletions keras/layers/advanced_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ def get_output(self, train):
return T.nnet.relu(X, self.alpha)

def get_config(self):
return {"name": self.__class__.__name__,
"alpha": self.alpha}
config = {"name": self.__class__.__name__,
"alpha": self.alpha}
base_config = super(LeakyReLU, self).get_config()
return dict(base_config.items() + config.items())


class PReLU(MaskedLayer):
Expand Down Expand Up @@ -46,8 +48,10 @@ def get_output(self, train):
return pos + neg

def get_config(self):
return {"name": self.__class__.__name__,
"init": self.init.__name__}
config = {"name": self.__class__.__name__,
"init": self.init.__name__}
base_config = super(PReLU, self).get_config()
return dict(base_config.items() + config.items())


class ParametricSoftplus(MaskedLayer):
Expand Down Expand Up @@ -81,9 +85,11 @@ def get_output(self, train):
return T.nnet.softplus(self.betas * X) * self.alphas

def get_config(self):
return {"name": self.__class__.__name__,
"alpha_init": self.alpha_init,
"beta_init": self.beta_init}
config = {"name": self.__class__.__name__,
"alpha_init": self.alpha_init,
"beta_init": self.beta_init}
base_config = super(ParametricSoftplus, self).get_config()
return dict(base_config.items() + config.items())


class ThresholdedLinear(MaskedLayer):
Expand All @@ -103,11 +109,13 @@ def get_output(self, train):
return T.switch(abs(X) < self.theta, 0, X)

def get_config(self):
return {"name": self.__class__.__name__,
"theta": self.theta}
config = {"name": self.__class__.__name__,
"theta": self.theta}
base_config = super(ThresholdedLinear, self).get_config()
return dict(base_config.items() + config.items())


class ThresholdedReLu(MaskedLayer):
class ThresholdedReLU(MaskedLayer):
'''
Thresholded Rectified Activation
Expand All @@ -116,13 +124,15 @@ class ThresholdedReLu(MaskedLayer):
http://arxiv.org/pdf/1402.3337.pdf
'''
def __init__(self, theta=1.0, **kwargs):
super(ThresholdedReLu, self).__init__(**kwargs)
super(ThresholdedReLU, self).__init__(**kwargs)
self.theta = theta

def get_output(self, train):
X = self.get_input(train)
return T.switch(X > self.theta, X, 0)

def get_config(self):
return {"name": self.__class__.__name__,
"theta": self.theta}
config = {"name": self.__class__.__name__,
"theta": self.theta}
base_config = super(ThresholdedReLU, self).get_config()
return dict(base_config.items() + config.items())
2 changes: 0 additions & 2 deletions keras/layers/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def add(self, layer):
self.layers[-1].set_previous(self.layers[-2])
if not hasattr(self.layers[0], 'input'):
self.set_input()
layer.init_updates()

params, regularizers, constraints, updates = layer.get_params()
self.params += params
Expand Down Expand Up @@ -217,7 +216,6 @@ def add_node(self, layer, name, input=None, inputs=[],
'merge_mode': merge_mode,
'concat_axis': concat_axis,
'create_output': create_output})
layer.init_updates()
params, regularizers, constraints, updates = layer.get_params()
self.params += params
self.regularizers += regularizers
Expand Down
129 changes: 78 additions & 51 deletions keras/layers/convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(self, nb_filter, filter_length,
init='uniform', activation='linear', weights=None,
border_mode='valid', subsample_length=1,
W_regularizer=None, b_regularizer=None, activity_regularizer=None,
W_constraint=None, b_constraint=None, **kwargs):
W_constraint=None, b_constraint=None, input_dim=None, input_length=None, **kwargs):

if border_mode not in {'valid', 'full', 'same'}:
raise Exception('Invalid border mode for Convolution1D:', border_mode)
Expand All @@ -77,6 +77,11 @@ def __init__(self, nb_filter, filter_length,
self.constraints = [self.W_constraint, self.b_constraint]

self.initial_weights = weights

self.input_dim = input_dim
self.input_length = input_length
if self.input_dim:
kwargs['input_shape'] = (self.input_length, self.input_dim)
super(Convolution1D, self).__init__(**kwargs)

def build(self):
Expand Down Expand Up @@ -143,18 +148,22 @@ def get_output(self, train=False):
return output

def get_config(self):
return {"name": self.__class__.__name__,
"nb_filter": self.nb_filter,
"filter_length": self.filter_length,
"init": self.init.__name__,
"activation": self.activation.__name__,
"border_mode": self.border_mode,
"subsample_length": self.subsample_length,
"W_regularizer": self.W_regularizer.get_config() if self.W_regularizer else None,
"b_regularizer": self.b_regularizer.get_config() if self.b_regularizer else None,
"activity_regularizer": self.activity_regularizer.get_config() if self.activity_regularizer else None,
"W_constraint": self.W_constraint.get_config() if self.W_constraint else None,
"b_constraint": self.b_constraint.get_config() if self.b_constraint else None}
config = {"name": self.__class__.__name__,
"nb_filter": self.nb_filter,
"filter_length": self.filter_length,
"init": self.init.__name__,
"activation": self.activation.__name__,
"border_mode": self.border_mode,
"subsample_length": self.subsample_length,
"W_regularizer": self.W_regularizer.get_config() if self.W_regularizer else None,
"b_regularizer": self.b_regularizer.get_config() if self.b_regularizer else None,
"activity_regularizer": self.activity_regularizer.get_config() if self.activity_regularizer else None,
"W_constraint": self.W_constraint.get_config() if self.W_constraint else None,
"b_constraint": self.b_constraint.get_config() if self.b_constraint else None,
"input_dim": self.input_dim,
"input_length": self.input_length}
base_config = super(Convolution1D, self).get_config()
return dict(base_config.items() + config.items())


class Convolution2D(Layer):
Expand Down Expand Up @@ -253,19 +262,21 @@ def get_output(self, train=False):
return self.activation(conv_out + self.b.dimshuffle('x', 0, 'x', 'x'))

def get_config(self):
return {"name": self.__class__.__name__,
"nb_filter": self.nb_filter,
"nb_row": self.nb_row,
"nb_col": self.nb_col,
"init": self.init.__name__,
"activation": self.activation.__name__,
"border_mode": self.border_mode,
"subsample": self.subsample,
"W_regularizer": self.W_regularizer.get_config() if self.W_regularizer else None,
"b_regularizer": self.b_regularizer.get_config() if self.b_regularizer else None,
"activity_regularizer": self.activity_regularizer.get_config() if self.activity_regularizer else None,
"W_constraint": self.W_constraint.get_config() if self.W_constraint else None,
"b_constraint": self.b_constraint.get_config() if self.b_constraint else None}
config = {"name": self.__class__.__name__,
"nb_filter": self.nb_filter,
"nb_row": self.nb_row,
"nb_col": self.nb_col,
"init": self.init.__name__,
"activation": self.activation.__name__,
"border_mode": self.border_mode,
"subsample": self.subsample,
"W_regularizer": self.W_regularizer.get_config() if self.W_regularizer else None,
"b_regularizer": self.b_regularizer.get_config() if self.b_regularizer else None,
"activity_regularizer": self.activity_regularizer.get_config() if self.activity_regularizer else None,
"W_constraint": self.W_constraint.get_config() if self.W_constraint else None,
"b_constraint": self.b_constraint.get_config() if self.b_constraint else None}
base_config = super(Convolution2D, self).get_config()
return dict(base_config.items() + config.items())


class MaxPooling1D(Layer):
Expand Down Expand Up @@ -297,10 +308,12 @@ def get_output(self, train=False):
return T.reshape(output, (output.shape[0], output.shape[1], output.shape[2]))

def get_config(self):
return {"name": self.__class__.__name__,
"stride": self.stride,
"pool_length": self.pool_length,
"ignore_border": self.ignore_border}
config = {"name": self.__class__.__name__,
"stride": self.stride,
"pool_length": self.pool_length,
"ignore_border": self.ignore_border}
base_config = super(MaxPooling1D, self).get_config()
return dict(base_config.items() + config.items())


class MaxPooling2D(Layer):
Expand Down Expand Up @@ -328,10 +341,12 @@ def get_output(self, train=False):
return output

def get_config(self):
return {"name": self.__class__.__name__,
"pool_size": self.pool_size,
"ignore_border": self.ignore_border,
"stride": self.stride}
config = {"name": self.__class__.__name__,
"pool_size": self.pool_size,
"ignore_border": self.ignore_border,
"stride": self.stride}
base_config = super(MaxPooling2D, self).get_config()
return dict(base_config.items() + config.items())


class UpSample1D(Layer):
Expand All @@ -353,8 +368,10 @@ def get_output(self, train=False):
return output

def get_config(self):
return {"name": self.__class__.__name__,
"length": self.length}
config = {"name": self.__class__.__name__,
"length": self.length}
base_config = super(UpSample1D, self).get_config()
return dict(base_config.items() + config.items())


class UpSample2D(Layer):
Expand All @@ -377,8 +394,10 @@ def get_output(self, train=False):
return output

def get_config(self):
return {"name": self.__class__.__name__,
"size": self.size}
config = {"name": self.__class__.__name__,
"size": self.size}
base_config = super(UpSample2D, self).get_config()
return dict(base_config.items() + config.items())


class ZeroPadding1D(Layer):
Expand All @@ -403,7 +422,7 @@ class ZeroPadding1D(Layer):
def __init__(self, padding=1, **kwargs):
super(ZeroPadding1D, self).__init__(**kwargs)
self.padding = padding
self.input = T.tensor4()
self.input = T.tensor3()

@property
def output_shape(self):
Expand All @@ -412,12 +431,18 @@ def output_shape(self):

def get_output(self, train=False):
X = self.get_input(train)
output = T.zeros(self.output_shape)
return T.set_subtensor(output[:, self.padding:X.shape[1]+self.padding, :], X)
input_shape = X.shape
output_shape = (input_shape[0],
input_shape[1] + 2 * self.padding,
input_shape[2])
output = T.zeros(output_shape)
return T.set_subtensor(output[:, self.padding:X.shape[1] + self.padding, :], X)

def get_config(self):
return {"name": self.__class__.__name__,
"padding": self.padding}
config = {"name": self.__class__.__name__,
"padding": self.padding}
base_config = super(ZeroPadding1D, self).get_config()
return dict(base_config.items() + config.items())


class ZeroPadding2D(Layer):
Expand Down Expand Up @@ -455,17 +480,19 @@ def output_shape(self):
def get_output(self, train=False):
X = self.get_input(train)
input_shape = X.shape
out_shape = (input_shape[0],
input_shape[1],
input_shape[2] + 2 * self.padding[0],
input_shape[3] + 2 * self.padding[1])
out = T.zeros(out_shape)
output_shape = (input_shape[0],
input_shape[1],
input_shape[2] + 2 * self.padding[0],
input_shape[3] + 2 * self.padding[1])
output = T.zeros(output_shape)
indices = (slice(None),
slice(None),
slice(self.padding[0], input_shape[2] + self.padding[0]),
slice(self.padding[1], input_shape[3] + self.padding[1]))
return T.set_subtensor(out[indices], X)
return T.set_subtensor(output[indices], X)

def get_config(self):
return {"name": self.__class__.__name__,
"padding": self.padding}
config = {"name": self.__class__.__name__,
"padding": self.padding}
base_config = super(ZeroPadding2D, self).get_config()
return dict(base_config.items() + config.items())

0 comments on commit c60e2df

Please sign in to comment.