Commit

Update backends with rnn support

fchollet committed Nov 26, 2015
1 parent 37ebbc3 commit 47ed18a
Showing 6 changed files with 510 additions and 629 deletions.
89 changes: 86 additions & 3 deletions keras/backend/tensorflow_backend.py
@@ -331,10 +331,78 @@ def gradients(loss, variables):

# CONTROL FLOW

-def rnn(step_function, inputs, initial_states, go_backwards=False):
-    '''TODO
def rnn(step_function, inputs, initial_states,
        go_backwards=False, masking=True):
    '''Iterates over the time dimension of a tensor.

    Parameters
    ----------
    inputs: tensor of temporal data of shape (samples, time, ...)
        (at least 3D).
    step_function:
        Parameters:
            input: tensor with shape (samples, ...) (no time dimension),
                representing input for the batch of samples at a certain
                time step.
            states: list of tensors.
        Returns:
            output: tensor with shape (samples, ...) (no time dimension).
            new_states: list of tensors, same length and shapes
                as 'states'.
    initial_states: list of tensors, each with shape (samples, ...)
        (no time dimension), containing the initial values for
        the states used in the step function.
    go_backwards: boolean. If True, do the iteration over
        the time dimension in reverse order.
    masking: boolean. If True, any input timestep inputs[s, t]
        that is all-zeros will be skipped (states will be passed to
        the next step unchanged) and the corresponding output will
        be all zeros.

    Returns
    -------
    A tuple (last_output, outputs, new_states).
        last_output: the latest output of the rnn, of shape (samples, ...).
        outputs: tensor with shape (samples, time, ...) where each
            entry outputs[s, t] is the output of the step function
            at time t for sample s.
        new_states: list of tensors, latest states returned by
            the step function, of shape (samples, ...).
    '''
-    pass
    inputs = tf.transpose(inputs, (1, 0, 2))
    input_list = tf.unpack(inputs)

    states = initial_states
    successive_states = []
    successive_outputs = []
    if go_backwards:
        # list.reverse() works in place and returns None,
        # so its result must not be assigned back
        input_list.reverse()
    for input in input_list:
        output, new_states = step_function(input, states)
        if masking:
            # if all-zero input timestep, return
            # all-zero output and unchanged states
            # (reduce_any expects booleans, so compare against zero first)
            switch = tf.reduce_any(tf.not_equal(input, 0.))
            output = tf.control_flow_ops.cond(switch,
                                              lambda: output,
                                              lambda: 0. * output)
            return_states = []
            for state, new_state in zip(states, new_states):
                return_states.append(tf.control_flow_ops.cond(switch,
                                                              lambda: new_state,
                                                              lambda: state))
            states = return_states
        else:
            states = new_states
        successive_outputs.append(output)
        successive_states.append(states)

    last_output = successive_outputs[-1]
    outputs = tf.pack(successive_outputs)
    new_states = successive_states[-1]

    outputs = tf.transpose(outputs, (1, 0, 2))
    return last_output, outputs, new_states

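As a quick illustration of the new API (not part of the commit), here is a minimal usage sketch: a hypothetical step function that keeps a running sum over time, driven through this backend's rnn. It assumes the same TF-0.x-era API used above (tf.unpack/tf.pack, tf.Session) and that the backend module is importable as shown.

import numpy as np
import tensorflow as tf
from keras.backend import tensorflow_backend as K

# Hypothetical step function: output_t = state_t = state_{t-1} + input_t
def step(x, states):
    summed = states[0] + x
    return summed, [summed]

x = tf.placeholder('float32', shape=(2, 5, 3))   # (samples, time, dim)
initial_states = [tf.zeros((2, 3))]              # one all-zero state
last_output, outputs, new_states = K.rnn(step, x, initial_states,
                                         masking=False)

with tf.Session() as sess:
    data = np.ones((2, 5, 3), dtype='float32')
    last, seq = sess.run([last_output, outputs], feed_dict={x: data})
    # last is all 5s (the sum of five ones); seq has shape (2, 5, 3)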

def switch(condition, then_expression, else_expression):
@@ -442,6 +510,11 @@ def conv2d(x, kernel, strides=(1, 1), border_mode='valid', dim_ordering='th'):

    strides = (1,) + strides + (1,)

    if _FLOATX == 'float64':
        # tf conv2d only supports float32
        x = tf.cast(x, 'float32')
        kernel = tf.cast(kernel, 'float32')

    if dim_ordering == 'th':
        # TF uses the last dimension as channel dimension,
        # instead of the 2nd one.
@@ -457,6 +530,9 @@ def conv2d(x, kernel, strides=(1, 1), border_mode='valid', dim_ordering='th'):
        x = tf.nn.conv2d(x, kernel, strides, padding=padding)
    else:
        raise Exception('Unknown dim_ordering: ' + str(dim_ordering))

    if _FLOATX == 'float64':
        x = tf.cast(x, 'float64')
    return x

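The same cast-down/cast-back workaround reappears in maxpool2d below; as a sketch (with a hypothetical helper name, not part of this commit), the pattern could be factored out like so:

import tensorflow as tf

def _in_float32(op, *tensors):
    # Hypothetical helper: cast float64 inputs down to float32,
    # apply the op, and cast the result back, mirroring the
    # workaround used in conv2d and maxpool2d above.
    needs_cast = any(t.dtype == tf.float64 for t in tensors)
    if needs_cast:
        tensors = [tf.cast(t, 'float32') for t in tensors]
    out = op(*tensors)
    return tf.cast(out, 'float64') if needs_cast else out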

@@ -478,6 +554,10 @@ def maxpool2d(x, pool_size, strides=(1, 1),
    strides = (1,) + strides + (1,)
    pool_size = (1,) + pool_size + (1,)

    if _FLOATX == 'float64':
        # tf max_pool only supports float32
        x = tf.cast(x, 'float32')

    if dim_ordering == 'th':
        # TF uses the last dimension as channel dimension,
        # instead of the 2nd one.
@@ -492,6 +572,9 @@ def maxpool2d(x, pool_size, strides=(1, 1),
        x = tf.nn.max_pool(x, pool_size, strides, padding=padding)
    else:
        raise Exception('Unknown dim_ordering: ' + str(dim_ordering))

    if _FLOATX == 'float64':
        x = tf.cast(x, 'float64')
    return x


81 changes: 76 additions & 5 deletions keras/backend/theano_backend.py
@@ -353,12 +353,83 @@ def gradients(loss, variables):

# CONTROL FLOW

-def rnn(step_function, inputs, initial_states, go_backwards=False):
-    '''TODO
-    Wrapper for scan
def rnn(step_function, inputs, initial_states,
        go_backwards=False, masking=True):
    '''Iterates over the time dimension of a tensor.

    Parameters
    ----------
    inputs: tensor of temporal data of shape (samples, time, ...)
        (at least 3D).
    step_function:
        Parameters:
            input: tensor with shape (samples, ...) (no time dimension),
                representing input for the batch of samples at a certain
                time step.
            states: list of tensors.
        Returns:
            output: tensor with shape (samples, ...) (no time dimension).
            new_states: list of tensors, same length and shapes
                as 'states'.
    initial_states: list of tensors, each with shape (samples, ...)
        (no time dimension), containing the initial values for
        the states used in the step function.
    go_backwards: boolean. If True, do the iteration over
        the time dimension in reverse order.
    masking: boolean. If True, any input timestep inputs[s, t]
        that is all-zeros will be skipped (states will be passed to
        the next step unchanged) and the corresponding output will
        be all zeros.

    Returns
    -------
    A tuple (last_output, outputs, new_states).
        last_output: the latest output of the rnn, of shape (samples, ...).
        outputs: tensor with shape (samples, time, ...) where each
            entry outputs[s, t] is the output of the step function
            at time t for sample s.
        new_states: list of tensors, latest states returned by
            the step function, of shape (samples, ...).
    '''
-    pass
    inputs = inputs.dimshuffle((1, 0, 2))

    def _step(*args):
        input = args[0]
        states = args[1:]
        output, new_states = step_function(input, states)
        if masking:
            # if all-zero input timestep, return
            # all-zero output and unchanged states

[Inline review comment by @EderSantana (Contributor), Nov 26, 2015]
This is not the most correct thing to do, though. Outputs depend on the
hidden states: if the states are unchanged, so should the outputs be. I
believe there won't be a lot of harm, but this is something to think about.

            switch = T.any(input)
            output = T.switch(switch, output, 0. * output)
            return_states = []
            for state, new_state in zip(states, new_states):
                return_states.append(T.switch(switch, new_state, state))
            return [output] + return_states
        else:
            return [output] + new_states

results, _ = theano.scan(
_step,
sequences=inputs,
outputs_info=[None] + initial_states,
go_backwards=go_backwards)

# deal with Theano API inconsistency
if type(results) is list:
outputs = results[0]
states = results[1:]
else:
outputs = results
states = []

outputs = T.squeeze(outputs)
last_output = outputs[-1]

outputs = outputs.dimshuffle((1, 0, 2))
states = [T.squeeze(state[-1]) for state in states]
return last_output, outputs, states

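For reference, a standalone sketch (not from this commit) of the theano.scan convention used here: outputs_info=[None] + initial_states tells scan that the first value returned by _step (the output) is not fed back into the next step, while each state is.

import numpy as np
import theano
import theano.tensor as T

X = T.tensor3()   # (time, samples, dim), i.e. already dimshuffled
s0 = T.matrix()   # initial state, (samples, dim)

def _step(x_t, s_tm1):
    s_t = s_tm1 + x_t        # running sum as the recurrent state
    return s_t, s_t          # (output, new_state); output is not fed back

results, _ = theano.scan(_step,
                         sequences=X,
                         outputs_info=[None, s0])
outputs, states = results    # scan returns a list when the step returns two

f = theano.function([X, s0], [outputs[-1], states[-1]])
floatX = theano.config.floatX
last_out, last_state = f(np.ones((5, 2, 3), dtype=floatX),
                         np.zeros((2, 3), dtype=floatX))
# both results are all 5s: the sum of five ones per feature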

def switch(condition, then_expression, else_expression):
12 changes: 9 additions & 3 deletions keras/layers/core.py
@@ -782,9 +782,15 @@ def output_shape(self):

    def get_output(self, train=False):
        X = self.get_input(train)
-        output = self.activation(K.dot(K.permute_dimensions(X, (1, 0, 2)),
-                                       self.W) + self.b)
-        return K.permute_dimensions(output, (1, 0, 2))

        def step(x, states):
            output = K.dot(x, self.W) + self.b
            return output, []

        last_output, outputs, states = K.rnn(step, X, [], masking=False)

        outputs = self.activation(outputs)
        return outputs

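In plain numpy terms (a sketch with hypothetical shapes, not part of the commit), the refactored get_output still computes the same per-timestep affine transform, with the activation applied to all timesteps at once:

import numpy as np

X = np.random.rand(4, 10, 8)   # (samples, time, input_dim)
W = np.random.rand(8, 16)      # (input_dim, output_dim)
b = np.zeros(16)

# step(x, states) computes dot(x, W) + b per timestep, with no state;
# the activation (tanh here, as an example) is then applied to all outputs.
out = np.tanh(np.einsum('stj,jk->stk', X, W) + b)
assert out.shape == (4, 10, 16)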
    def get_config(self):
        config = {"name": self.__class__.__name__,
