diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py
index 3ab0e7453ee..56b5c8342f7 100644
--- a/keras/backend/tensorflow_backend.py
+++ b/keras/backend/tensorflow_backend.py
@@ -2780,12 +2780,199 @@ def rnn(step_function, inputs, initial_states,
 
     {{np_implementation}}
     """
-    return tf_keras_backend.rnn(step_function, inputs, initial_states,
-                                go_backwards=go_backwards,
-                                mask=mask,
-                                constants=constants,
-                                unroll=unroll,
-                                input_length=input_length)
+    ndim = len(inputs.shape)
+    if ndim < 3:
+        raise ValueError('Input should be at least 3D.')
+
+    # Transpose to time-major, i.e.
+    # from (batch, time, ...) to (time, batch, ...)
+    axes = [1, 0] + list(range(2, ndim))
+    inputs = tf.transpose(inputs, axes)
+
+    if mask is not None:
+        if mask.dtype != tf.bool:
+            mask = tf.cast(mask, tf.bool)
+        if len(mask.shape) != 2:
+            raise ValueError(
+                'mask should have `shape=(samples, time)`, '
+                'got {}'.format(mask.shape))
+        mask = tf.transpose(mask, [1, 0])
+
+        def get_matching_mask(mask_t, ref_tensor_t):
+            # tf.where needs its condition tensor
+            # to be the same shape as its two
+            # result tensors
+            ndim = len(ref_tensor_t.shape)
+            for _ in range(ndim - 1):
+                mask_t = expand_dims(mask_t)
+            add_shape = tf.shape(ref_tensor_t)[1:]
+            multiple = tf.concat([[1], add_shape], 0)
+            return tf.tile(mask_t, multiple)
+
+    # A tuple, so it concatenates with `tuple(states)` in every branch below.
+    constants = () if constants is None else tuple(constants)
+
+    uses_learning_phase = [False]
+
+    if unroll:
+        if not inputs.shape[0]:
+            raise ValueError('Unrolling requires a '
+                             'fixed number of timesteps.')
+        states = initial_states
+        successive_states = []
+        successive_outputs = []
+
+        input_list = tf.unstack(inputs)
+        if go_backwards:
+            input_list.reverse()
+
+        if mask is not None:
+            mask_list = tf.unstack(mask)
+            if go_backwards:
+                mask_list.reverse()
+
+            for inp, mask_t in zip(input_list, mask_list):
+                output, new_states = step_function(inp, tuple(states) + constants)
+                if getattr(output, '_uses_learning_phase', False):
+                    uses_learning_phase[0] = True
+
+                if not successive_outputs:
+                    prev_output = zeros_like(output)
+                else:
+                    prev_output = successive_outputs[-1]
+
+                output_mask_t = get_matching_mask(mask_t, output)
+                output = tf.where(output_mask_t, output, prev_output)
+
+                return_states = []
+                for state, new_state in zip(states, new_states):
+                    state_mask_t = get_matching_mask(mask_t, new_state)
+                    return_states.append(tf.where(state_mask_t,
+                                                  new_state,
+                                                  state))
+                states = return_states
+                successive_outputs.append(output)
+                successive_states.append(states)
+            last_output = successive_outputs[-1]
+            new_states = successive_states[-1]
+            outputs = tf.stack(successive_outputs)
+        else:
+            for inp in input_list:
+                output, states = step_function(inp, tuple(states) + constants)
+                if getattr(output, '_uses_learning_phase', False):
+                    uses_learning_phase[0] = True
+                successive_outputs.append(output)
+                successive_states.append(states)
+            last_output = successive_outputs[-1]
+            new_states = successive_states[-1]
+            outputs = tf.stack(successive_outputs)
+
+    else:
+        if go_backwards:
+            inputs = reverse(inputs, 0)
+
+        states = tuple(initial_states)
+
+        time_steps = tf.shape(inputs)[0]
+        output, _ = step_function(inputs[0], states + constants)
+        output_ta = tensor_array_ops.TensorArray(
+            dtype=output.dtype,
+            size=time_steps,
+            tensor_array_name='output_ta')
+        initial_output = zeros_like(output)
+        input_ta = tensor_array_ops.TensorArray(
+            dtype=inputs.dtype,
+            size=time_steps,
+            tensor_array_name='input_ta')
+        input_ta = input_ta.unstack(inputs)
+        time = tf.constant(0, dtype='int32', name='time')
+        while_loop_kwargs = {
+            'cond': lambda time, *_: time < time_steps,
+            'parallel_iterations': 32,
+            'swap_memory': True,
+            'maximum_iterations': input_length}
+
+        if mask is not None:
+            if go_backwards:
+                mask = reverse(mask, 0)
+
+            mask_ta = tensor_array_ops.TensorArray(
+                dtype=tf.bool,
+                size=time_steps,
+                tensor_array_name='mask_ta')
+            mask_ta = mask_ta.unstack(mask)
+
+            def _step(time, output_ta_t, output_tm1, *states):
+                """Masked RNN step: run one timestep inside the while loop.
+                # Arguments
+                    time: Current timestep value.
+                    output_ta_t: TensorArray the step output is written to.
+                    output_tm1: Output tensor from the previous timestep.
+                    *states: List of states.
+                # Returns
+                    Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)`
+                """
+                current_input = input_ta.read(time)
+                mask_t = mask_ta.read(time)
+                output, new_states = step_function(current_input,
+                                                   tuple(states) +
+                                                   tuple(constants))
+                if getattr(output, '_uses_learning_phase', False):
+                    uses_learning_phase[0] = True
+                for state, new_state in zip(states, new_states):
+                    new_state.set_shape(state.shape)
+
+                output_mask_t = get_matching_mask(mask_t, output)
+                output = tf.where(output_mask_t, output, output_tm1)
+
+                new_states = [tf.where(get_matching_mask(mask_t, new_states[i]),
+                                       new_states[i],
+                                       states[i]) for i in range(len(states))]
+
+                output_ta_t = output_ta_t.write(time, output)
+                return (time + 1, output_ta_t, output) + tuple(new_states)
+
+            final_outputs = control_flow_ops.while_loop(
+                body=_step,
+                loop_vars=(time, output_ta, initial_output) + states,
+                **while_loop_kwargs)
+            new_states = final_outputs[3:]  # skip time, output_ta, output_tm1
+        else:
+            def _step(time, output_ta_t, *states):
+                """Unmasked RNN step: run one timestep inside the while loop.
+                # Arguments
+                    time: Current timestep value.
+                    output_ta_t: TensorArray the step output is written to.
+                    *states: List of states.
+                # Returns
+                    Tuple: `(time + 1, output_ta_t) + tuple(new_states)`
+                """
+                current_input = input_ta.read(time)
+                output, new_states = step_function(current_input,
+                                                   tuple(states) +
+                                                   tuple(constants))
+                if getattr(output, '_uses_learning_phase', False):
+                    uses_learning_phase[0] = True
+                for state, new_state in zip(states, new_states):
+                    new_state.set_shape(state.shape)
+                output_ta_t = output_ta_t.write(time, output)
+                return (time + 1, output_ta_t) + tuple(new_states)
+
+            final_outputs = control_flow_ops.while_loop(
+                body=_step,
+                loop_vars=(time, output_ta) + states,
+                **while_loop_kwargs)
+            new_states = final_outputs[2:]
+
+        last_time = final_outputs[0]
+        output_ta = final_outputs[1]
+        outputs = output_ta.stack()
+        last_output = output_ta.read(last_time - 1)
+
+    axes = [1, 0] + list(range(2, len(outputs.shape)))
+    outputs = tf.transpose(outputs, axes)
+    last_output._uses_learning_phase = uses_learning_phase[0]
+    return last_output, outputs, new_states
 
 
 @symbolic