Revert backend rnn.
fchollet committed Mar 6, 2019
1 parent 4d857be commit 8cc7a97
Showing 1 changed file with 193 additions and 6 deletions.
199 changes: 193 additions & 6 deletions keras/backend/tensorflow_backend.py
@@ -2780,12 +2780,199 @@ def rnn(step_function, inputs, initial_states,
     {{np_implementation}}
     """
-    return tf_keras_backend.rnn(step_function, inputs, initial_states,
-                                go_backwards=go_backwards,
-                                mask=mask,
-                                constants=constants,
-                                unroll=unroll,
-                                input_length=input_length)
+    ndim = len(inputs.shape)
+    if ndim < 3:
+        raise ValueError('Input should be at least 3D.')
+
+    # Transpose to time-major, i.e.
+    # from (batch, time, ...) to (time, batch, ...)
+    axes = [1, 0] + list(range(2, ndim))
+    inputs = tf.transpose(inputs, (axes))
+
+    if mask is not None:
+        if mask.dtype != tf.bool:
+            mask = tf.cast(mask, tf.bool)
+        if len(mask.shape) != 2:
+            raise ValueError(
+                'mask should have `shape=(samples, time)`, '
+                'got {}'.format(mask.shape))
+        mask = tf.transpose(mask, [1, 0])
+
+    def get_matching_mask(mask_t, ref_tensor_t):
+        # tf.where needs its condition tensor
+        # to be the same shape as its two
+        # result tensors
+        ndim = len(ref_tensor_t.shape)
+        for _ in range(ndim - 1):
+            mask_t = expand_dims(mask_t)
+        add_shape = tf.shape(ref_tensor_t)[1:]
+        multiple = tf.concat([[1], add_shape], 0)
+        return tf.tile(mask_t, multiple)
+
+    if constants is None:
+        constants = []
+
+    uses_learning_phase = [False]
+
+    if unroll:
+        if not inputs.shape[0]:
+            raise ValueError('Unrolling requires a '
+                             'fixed number of timesteps.')
+        states = initial_states
+        successive_states = []
+        successive_outputs = []
+
+        input_list = tf.unstack(inputs)
+        if go_backwards:
+            input_list.reverse()
+
+        if mask is not None:
+            mask_list = tf.unstack(mask)
+            if go_backwards:
+                mask_list.reverse()
+
+            for inp, mask_t in zip(input_list, mask_list):
+                output, new_states = step_function(inp, states + constants)
+                if getattr(output, '_uses_learning_phase', False):
+                    uses_learning_phase[0] = True
+
+                if not successive_outputs:
+                    prev_output = zeros_like(output)
+                else:
+                    prev_output = successive_outputs[-1]
+
+                output_mask_t = get_matching_mask(mask_t, output)
+                output = tf.where(output_mask_t, output, prev_output)
+
+                return_states = []
+                for state, new_state in zip(states, new_states):
+                    state_mask_t = get_matching_mask(mask_t, new_state)
+                    return_states.append(tf.where(state_mask_t,
+                                                  new_state,
+                                                  state))
+                states = return_states
+                successive_outputs.append(output)
+                successive_states.append(states)
+            last_output = successive_outputs[-1]
+            new_states = successive_states[-1]
+            outputs = tf.stack(successive_outputs)
+        else:
+            for inp in input_list:
+                output, states = step_function(inp, states + constants)
+                if getattr(output, '_uses_learning_phase', False):
+                    uses_learning_phase[0] = True
+                successive_outputs.append(output)
+                successive_states.append(states)
+            last_output = successive_outputs[-1]
+            new_states = successive_states[-1]
+            outputs = tf.stack(successive_outputs)
+
+    else:
+        if go_backwards:
+            inputs = reverse(inputs, 0)
+
+        states = tuple(initial_states)
+
+        time_steps = tf.shape(inputs)[0]
+        output, _ = step_function(inputs[0], initial_states + constants)
+        output_ta = tensor_array_ops.TensorArray(
+            dtype=output.dtype,
+            size=time_steps,
+            tensor_array_name='output_ta')
+        initial_output = zeros_like(output)
+        input_ta = tensor_array_ops.TensorArray(
+            dtype=inputs.dtype,
+            size=time_steps,
+            tensor_array_name='input_ta')
+        input_ta = input_ta.unstack(inputs)
+        time = tf.constant(0, dtype='int32', name='time')
+        while_loop_kwargs = {
+            'cond': lambda time, *_: time < time_steps,
+            'parallel_iterations': 32,
+            'swap_memory': True,
+            'maximum_iterations': input_length}
+
+        if mask is not None:
+            if go_backwards:
+                mask = reverse(mask, 0)
+
+            mask_ta = tensor_array_ops.TensorArray(
+                dtype=tf.bool,
+                size=time_steps,
+                tensor_array_name='mask_ta')
+            mask_ta = mask_ta.unstack(mask)
+
+            def _step(time, output_ta_t, output_tm1, *states):
+                """RNN step function.
+                # Arguments
+                    time: Current timestep value.
+                    output_ta_t: TensorArray.
+                    output_tm1: output Tensor from previous timestep
+                    *states: List of states.
+                # Returns
+                    Tuple: `(time + 1, output_ta_t) + tuple(new_states)`
+                """
+                current_input = input_ta.read(time)
+                mask_t = mask_ta.read(time)
+                output, new_states = step_function(current_input,
+                                                   tuple(states) +
+                                                   tuple(constants))
+                if getattr(output, '_uses_learning_phase', False):
+                    uses_learning_phase[0] = True
+                for state, new_state in zip(states, new_states):
+                    new_state.set_shape(state.shape)
+
+                output_mask_t = get_matching_mask(mask_t, output)
+                output = tf.where(output_mask_t, output, output_tm1)
+
+                new_states = [tf.where(get_matching_mask(mask_t, new_states[i]),
+                                       new_states[i],
+                                       states[i]) for i in range(len(states))]
+
+                output_ta_t = output_ta_t.write(time, output)
+                return (time + 1, output_ta_t, output) + tuple(new_states)
+
+            final_outputs = control_flow_ops.while_loop(
+                body=_step,
+                loop_vars=(time, output_ta, initial_output) + states,
+                **while_loop_kwargs)
+            new_states = final_outputs[3:]  # skip output_tm1
+        else:
+            def _step(time, output_ta_t, *states):
+                """RNN step function.
+                # Arguments
+                    time: Current timestep value.
+                    output_ta_t: TensorArray.
+                    *states: List of states.
+                # Returns
+                    Tuple: `(time + 1, output_ta_t) + tuple(new_states)`
+                """
+                current_input = input_ta.read(time)
+                output, new_states = step_function(current_input,
+                                                   tuple(states) +
+                                                   tuple(constants))
+                if getattr(output, '_uses_learning_phase', False):
+                    uses_learning_phase[0] = True
+                for state, new_state in zip(states, new_states):
+                    new_state.set_shape(state.shape)
+                output_ta_t = output_ta_t.write(time, output)
+                return (time + 1, output_ta_t) + tuple(new_states)
+
+            final_outputs = control_flow_ops.while_loop(
+                body=_step,
+                loop_vars=(time, output_ta) + states,
+                **while_loop_kwargs)
+            new_states = final_outputs[2:]
+
+        last_time = final_outputs[0]
+        output_ta = final_outputs[1]
+        outputs = output_ta.stack()
+        last_output = output_ta.read(last_time - 1)
+
+    axes = [1, 0] + list(range(2, len(outputs.shape)))
+    outputs = tf.transpose(outputs, axes)
+    last_output._uses_learning_phase = uses_learning_phase[0]
+    return last_output, outputs, new_states
 
 
 @symbolic
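For context, the restored implementation above is the public `K.rnn` backend function. Below is a minimal sketch of driving it with a simple accumulator step function, assuming Keras 2.x on the TensorFlow 1.x (graph-mode) backend; the shapes and the step function are illustrative and not part of the commit.

import numpy as np
from keras import backend as K

def step_function(inputs, states):
    # Running sum over timesteps: new state = previous state + current input.
    output = states[0] + inputs
    return output, [output]

x = K.placeholder(shape=(2, 5, 3))            # (samples, timesteps, dims)
init_states = [K.constant(np.zeros((2, 3)))]  # one initial state tensor

last_output, outputs, new_states = K.rnn(step_function, x, init_states,
                                         go_backwards=False,
                                         mask=None,
                                         unroll=False)

f = K.function([x], [last_output, outputs])
last, seq = f([np.ones((2, 5, 3), dtype='float32')])
print(last.shape)  # (2, 3): the final running sum
print(seq.shape)   # (2, 5, 3): per-timestep outputs, transposed back to batch-major

With `unroll=False` this exercises the `tf.while_loop`/`TensorArray` branch of the diff; passing `unroll=True` on a fixed-length input would exercise the Python-loop branch instead.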

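The `get_matching_mask` helper in the diff is an expand-and-tile so that `tf.where` can choose, per sample, between the new and previous value. Here is a standalone sketch of the same trick in plain TensorFlow 1.x ops; the names and shapes are illustrative only.

import tensorflow as tf

mask_t = tf.constant([True, False])  # (samples,) validity at one timestep
new_state = tf.ones((2, 4))          # (samples, units) candidate new state
old_state = tf.zeros((2, 4))         # (samples, units) carried-over state

# Expand (samples,) -> (samples, 1), then tile to (samples, units) so the
# condition matches the shape of the two result tensors, as tf.where requires.
tiled_mask = tf.tile(tf.expand_dims(mask_t, -1),
                     tf.concat([[1], tf.shape(new_state)[1:]], 0))
stepped = tf.where(tiled_mask, new_state, old_state)

with tf.Session() as sess:
    print(sess.run(stepped))
    # [[1. 1. 1. 1.]   sample 0 is unmasked: take the new state
    #  [0. 0. 0. 0.]]  sample 1 is masked: carry the old state forward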