Refactoring of ConvLSTM2D. Added ConvRNN2D and ConvLSTM2DCell. #9094

Merged (13 commits) Feb 25, 2018
Conversation

@gabrieldemarmiesse (Contributor) commented on Jan 16, 2018

ConvRecurrent2D has been moved to legacy/layers.

The following classes have been written:

  • ConvLSTM2DCell
  • ConvRNN2D
  • ConvLSTM2D
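
Conceptually, these compose like the existing RNN/LSTMCell split. A minimal sketch (the constructor arguments are illustrative assumptions):

from keras.layers.convolutional_recurrent import ConvLSTM2DCell, ConvRNN2D

# The cell computes a single timestep; ConvRNN2D loops it over the
# time dimension of a 5D input.
cell = ConvLSTM2DCell(filters=16, kernel_size=(3, 3), padding='same')
layer = ConvRNN2D(cell)  # ConvLSTM2D is the user-facing combination of both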

This PR has mainly involved copy-pasting and adapting code. Below are the places with the most changes, and the possible hotspots (changes I'm not so sure are correct).

ConvLSTM2DCell

  • In call, I use K.ones_like to generate the dropout mask, like this:

K.ones_like(inputs)

but in LSTMCell, this code is used:

def _generate_dropout_ones(inputs, dims):
    # Currently, CNTK can't instantiate `ones` with symbolic shapes.
    # Will update workaround once CNTK supports it.
    if K.backend() == 'cntk':
        ones = K.ones_like(K.reshape(inputs[:, 0], (-1, 1)))
        return K.tile(ones, (1, dims))
    else:
        return K.ones((K.shape(inputs)[0], dims))

I'm not sure why K.ones_like isn't used in LSTMCell.
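
A plausible explanation (an assumption, not confirmed in this thread): in LSTMCell the recurrent dropout mask needs shape (batch, units), which generally differs from the per-timestep input shape (batch, input_dim), so it cannot be built with K.ones_like on the inputs alone; in ConvLSTM2DCell the mask shape matches the input tensor, so K.ones_like suffices. A minimal sketch with hypothetical sizes:

import numpy as np
from keras import backend as K

batch, input_dim, units = 32, 10, 64  # hypothetical sizes
inputs = K.variable(np.zeros((batch, input_dim)))  # one timestep of input

# ConvLSTM2DCell-style mask: same shape as the input tensor.
input_mask = K.ones_like(inputs)                      # (batch, input_dim)

# LSTMCell-style recurrent mask: (batch, units) != input shape.
recurrent_mask = K.ones((K.shape(inputs)[0], units))  # (batch, units)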

ConvRNN2D

  • In the test function, this part kept failing:
layer = convolutional_recurrent.ConvLSTM2D(**kwargs)
layer.build(inputs.shape)
assert len(layer.losses) == 3
assert layer.activity_regularizer
output = layer(K.variable(np.ones(inputs.shape)))
assert len(layer.losses) == 4

layer.losses had length 6 instead of 4. I solved the problem by adding self.built = True at the end of the build function and by changing the way losses are collected, because the activity regularization loss was associated with the layer and not the cell (see the sketch after this list). So this part:

@property
def losses(self):
    if isinstance(self.cell, Layer):
        return self.cell.losses
    return []

became this:

@property
def losses(self):
    layer_losses = super(ConvRNN2D, self).losses
    if isinstance(self.cell, Layer):
        return self.cell.losses + layer_losses
    return layer_losses
  • Also, this is get_initial_state. I assume here that the cell does a convolution on the input, which may not be true (but for the moment it works, since ConvLSTM2DCell is the only cell class that exists for 5D tensors).
def get_initial_state(self, inputs):
    # (samples, timesteps, rows, cols, filters)
    initial_state = K.zeros_like(inputs)
    # (samples, rows, cols, filters)
    initial_state = K.sum(initial_state, axis=1)
    shape = list(self.cell.kernel_shape)
    shape[-1] = self.cell.filters
    initial_state = self.cell.input_conv(initial_state,
                                         K.zeros(tuple(shape)),
                                         padding=self.cell.padding)
    if hasattr(self.cell.state_size, '__len__'):
        return [initial_state for _ in self.cell.state_size]
    else:
        return [initial_state]
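
For reference, here is a sketch of how the losses should accumulate with the fixed property, modeled on the failing test above (the regularizer arguments and shapes are assumptions):

import numpy as np
from keras import backend as K
from keras.layers import ConvLSTM2D

# Three weight regularizers live on the cell; the activity regularizer
# belongs to the layer itself.
layer = ConvLSTM2D(filters=4, kernel_size=(3, 3), padding='same',
                   kernel_regularizer='l2',
                   recurrent_regularizer='l2',
                   bias_regularizer='l2',
                   activity_regularizer='l2')

input_shape = (2, 5, 8, 8, 3)  # (samples, timesteps, rows, cols, channels)
layer.build(input_shape)
assert len(layer.losses) == 3  # the cell's weight regularization losses

layer(K.variable(np.ones(input_shape)))
assert len(layer.losses) == 4  # plus the layer's activity regularization loss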

ConvLSTM2D

  • Nothing major. Mostly copy-pasting the LSTM class and adding arguments like kernel_size or padding.
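
For illustration, a minimal usage sketch of the resulting layer (all argument values are assumed):

from keras.models import Sequential
from keras.layers import ConvLSTM2D

# Input shape per sample: (timesteps, rows, cols, channels).
model = Sequential()
model.add(ConvLSTM2D(filters=16, kernel_size=(3, 3), padding='same',
                     return_sequences=True,
                     input_shape=(None, 40, 40, 1)))
model.compile(optimizer='adam', loss='mse')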

@gabrieldemarmiesse changed the title from "[WIP] Refactoring of ConvLSTM2D. Added ConvRNN2D and ConvLSTM2DCell." to "Refactoring of ConvLSTM2D. Added ConvRNN2D and ConvLSTM2DCell." on Jan 16, 2018
@gabrieldemarmiesse (Contributor, Author) commented on Jan 17, 2018

All tests are passing except for Theano. I'm having a hard time with this bug: theano.gof.fg.MissingInputError. It appears only when using dropout. It seems related to #2417, #2430 and #1217.

This bug appears only when calling predict, with Theano and with dropout or recurrent_dropout.

It seems Theano needs the learning phase, but I don't know how I can provide it.

It seems strange that it works for LSTM but not for my refactoring of ConvLSTM2DCell (this is essentially the same code). It's even stranger that it works with other backends but not Theano in particular.

I've done a bit of work on this PR and I'm close to passing the tests, so if someone can help me understand how to fix this, I would be very thankful.

                # check dropout
                layer_test(convolutional_recurrent.ConvLSTM2D,
                           kwargs={'data_format': data_format,
                                   'return_sequences': return_sequences,
                                   'filters': filters,
                                   'kernel_size': (num_row, num_col),
                                   'padding': 'same',
                                   'dropout': 0.1,
                                   'recurrent_dropout': 0.},
>                          input_shape=inputs.shape)

..\..\..\keras\utils\test_utils.py:91: in layer_test
    actual_output = model.predict(input_data)
..\..\..\keras\engine\training.py:1786: in predict
    self._make_predict_function()
..\..\..\keras\engine\training.py:1029: in _make_predict_function
    **kwargs)
..\..\..\keras\backend\theano_backend.py:1233: in function
    return Function(inputs, outputs, updates=updates, **kwargs)
..\..\..\keras\backend\theano_backend.py:1219: in __init__
    **kwargs)
C:\Users\smith\Miniconda3\lib\site-packages\theano\compile\function.py:317: in function
    output_keys=output_keys)
C:\Users\smith\Miniconda3\lib\site-packages\theano\compile\pfunc.py:486: in pfunc
    output_keys=output_keys)
C:\Users\smith\Miniconda3\lib\site-packages\theano\compile\function_module.py:1839: in orig_function
    name=name)
C:\Users\smith\Miniconda3\lib\site-packages\theano\compile\function_module.py:1487: in __init__
    accept_inplace)
C:\Users\smith\Miniconda3\lib\site-packages\theano\compile\function_module.py:181: in std_fgraph
    update_mapping=update_mapping)
C:\Users\smith\Miniconda3\lib\site-packages\theano\gof\fg.py:175: in __init__
    self.__import_r__(output, reason="init")
C:\Users\smith\Miniconda3\lib\site-packages\theano\gof\fg.py:346: in __import_r__
    self.__import__(variable.owner, reason=reason)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = [InplaceDimShuffle{1,0,2,3,4}(InplaceDimShuffle{0,1,2,3,4}(for{cpu,scan_fn}(*7 -> Subtensor{int64}(Shape(*16 -> Subten...int32, matrix)>, Shape(*15)), mrg_uniform{TensorType(float32, 4D),no_inplace}(<TensorType(int32, matrix)>, Shape(*15))]
apply_node = mrg_uniform{TensorType(float32, 4D),no_inplace}(<TensorType(int32, matrix)>, Shape.0)
check = True, reason = 'init'

    def __import__(self, apply_node, check=True, reason=None):
        """
            Given an apply_node, recursively search from this node to know graph,
            and then add all unknown variables and apply_nodes to this graph.
            """
        node = apply_node
    
        # We import the nodes in topological order. We only are interested
        # in new nodes, so we use all variables we know of as if they were the input set.
        # (the functions in the graph module only use the input set to
        # know where to stop going down)
        new_nodes = graph.io_toposort(self.variables, apply_node.outputs)
    
        if check:
            for node in new_nodes:
                if hasattr(node, 'fgraph') and node.fgraph is not self:
                    raise Exception("%s is already owned by another fgraph" % node)
                for r in node.inputs:
                    if hasattr(r, 'fgraph') and r.fgraph is not self:
                        raise Exception("%s is already owned by another fgraph" % r)
                    if (r.owner is None and
                            not isinstance(r, graph.Constant) and
                            r not in self.inputs):
                        # Standard error message
                        error_msg = ("Input %d of the graph (indices start "
                                     "from 0), used to compute %s, was not "
                                     "provided and not given a value. Use the "
                                     "Theano flag exception_verbosity='high', "
                                     "for more information on this error."
                                     % (node.inputs.index(r), str(node)))
>                       raise MissingInputError(error_msg, variable=r)
E                       theano.gof.fg.MissingInputError: Input 0 of the graph (indices start from 0), used to compute Elemwise{second,no_inplace}(<TensorType(float32, 4D)>, InplaceDimShuffle{x,x,x,x}.0), was not provided and not given a value. Use the Theano flag exception_verbosity='high', for more information on this error.
E                        
E                       Backtrace when that variable is created:
E                       
E                         File "C:\Users\smith\Miniconda3\lib\site-packages\_pytest\python.py", line 147, in pytest_pyfunc_call
E                           testfunction(**testargs)
E                         File "C:\Users\smith\Desktop\projects\keras_pr\keras\tests\keras\layers\convolutional_recurrent_test.py", line 140, in test_convolutional_recurrent
E                           input_shape=inputs.shape)
E                         File "C:\Users\smith\Desktop\projects\keras_pr\keras\keras\utils\test_utils.py", line 85, in layer_test
E                           y = layer(x)
E                         File "C:\Users\smith\Desktop\projects\keras_pr\keras\keras\layers\convolutional_recurrent.py", line 292, in __call__
E                           return super(ConvRNN2D, self).__call__(inputs, **kwargs)
E                         File "C:\Users\smith\Desktop\projects\keras_pr\keras\keras\engine\topology.py", line 605, in __call__
E                           output = self.call(inputs, **kwargs)
E                         File "C:\Users\smith\Desktop\projects\keras_pr\keras\keras\layers\convolutional_recurrent.py", line 1055, in call
E                           initial_state=initial_state)
E                         File "C:\Users\smith\Desktop\projects\keras_pr\keras\keras\layers\convolutional_recurrent.py", line 391, in call
E                           input_length=timesteps)
E                         File "C:\Users\smith\Desktop\projects\keras_pr\keras\keras\backend\theano_backend.py", line 1424, in rnn
E                           go_backwards=go_backwards)

C:\Users\smith\Miniconda3\lib\site-packages\theano\gof\fg.py:391: MissingInputError

@fchollet (Member) commented

Thank you for the PR. Would a contributor be available to help review it?

return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence.
go_backwards: Boolean (default False).
If True, rocess the input sequence backwards.

Contributor: rocess -> process

@fchollet (Member) left a comment:

Thanks for the PR, and thank you for your patience (couldn't find anyone to review). My biggest comment is that you should subclass RNN in order not to reimplement common methods.


Do not use in a model -- it's not a functional layer!
class ConvRNN2D(Layer):
"""Base class for recurrent layers.

Member: "for convolutional-recurrent layers"

It is at the moment not possible to stack ConvLSTM2DCells because
the stack class isn't implemented yet. It is also not recommended to
use ConvRNN2D with other cell types than ConvLSTM2DCell because
some things are assumed to be true to simplify the implementation.

Member: Implementation-level comments should be left as code comments, and not included in the user-facing docs.


# Input shape
5D tensor with shape `(num_samples, timesteps, channels, rows, cols)`.
If 'data_format' is 'channels_first' then it's a

Member: Please use the same formatting as we do in e.g. the Conv2D docstring.

This is the expected shape of your inputs *including the batch
size*.
It should be a tuple of integers, e.g. `(32, 10, 100)`.
if sequential model:

Member: Use bullet points

self.go_backwards = go_backwards
self.stateful = stateful

# Masking is only supported in Theano

Member: Is this still true? What's the underlying reason?

Contributor Author: My bad, it was in the docstring of ConvRecurrent2D, I didn't check if it was still true or not.

if cell.data_format == 'channels_first':
output_shape = [output_shape] + [(input_shape[0], cell.filters, rows, cols) for _ in range(2)]
elif cell.data_format == 'channels_last':
output_shape = [output_shape] + [(input_shape[0], rows, cols, cell.filters) for _ in range(2)]

Member: Break up long line

self._num_constants = len(constants)
additional_specs += self.constants_spec
# at this point additional_inputs cannot be empty
is_keras_tensor = hasattr(additional_inputs[0], '_keras_history')

Member: You can use K.is_keras_tensor

# at this point additional_inputs cannot be empty
is_keras_tensor = hasattr(additional_inputs[0], '_keras_history')
for tensor in additional_inputs:
if hasattr(tensor, '_keras_history') != is_keras_tensor:

Member: Same

initial_state,
constants=constants,
go_backwards=self.go_backwards,
mask=mask,

Member: Pretty sure masking would work with TF here? Does it not?

Contributor Author: My understanding of masking is, I would say, limited. I enabled masking for all backends in the last commit.

"""Abstract base class for convolutional recurrent layers.

Do not use in a model -- it's not a functional layer!
class ConvRNN2D(Layer):
Copy link
Member

@fchollet fchollet Jan 30, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe you should be able to remove a lot of redundant code by subclassing RNN. There are lots of shared methods.

@gabrieldemarmiesse (Contributor, Author):

Thank you very much for the review! I'm going to look at it this weekend.

@gabrieldemarmiesse (Contributor, Author):

The corrections were done. Quite a lot of code could be removed indeed. But I'm still stuck on the Theano issue. Do you have any ideas or leads on this?

@fchollet (Member) commented on Feb 8, 2018

Sorry, no time to look into it right now. Could someone else help out?

@fchollet self-assigned this on Feb 21, 2018
@fchollet (Member) left a comment:
Regarding the Theano issue, this is a hard limitation linked to variable creation in a symbolic loop.

Before the refactor, we would create the dropout mask outside of the loop, and pass it to the loop as a constants argument.

Now we're instead creating the mask at the first iteration of the loop. That doesn't work with Theano.

Thus we must disable dropout in Theano. The current architecture is cleaner and changing it for the sake of supporting this feature is not worth it.

Please put the following in the layer constructor:

if K.backend() == 'theano' and (dropout or recurrent_dropout):
    warnings.warn(
        'RNN dropout is no longer supported with the Theano backend '
        'due to technical limitations. '
        'You can either set `dropout` and `recurrent_dropout` to 0, '
        'or use the TensorFlow backend.')
    dropout = 0.
    recurrent_dropout = 0.
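
(The snippet above assumes import warnings at the top of the module.) For context, a rough sketch of the pre-refactor constants mechanism described above; this is simplified, assumed code, not the actual Keras implementation:

import numpy as np
from keras import backend as K

# Toy setup: batch of 2, 3 timesteps, 4 features.
x = K.variable(np.ones((2, 3, 4)))
init_states = [K.variable(np.zeros((2, 4)))]

# The dropout mask is built once, outside the symbolic loop, and handed
# to the loop as a constant -- this is what made it Theano-compatible.
mask = K.dropout(K.ones_like(x[:, 0]), level=0.1)

def step(inputs, states):
    constants = states[1:]  # K.rnn appends constants after the states
    return inputs * constants[0], [inputs]

last_output, outputs, new_states = K.rnn(step, x, init_states,
                                         constants=[mask])
# The refactored cells instead build the mask lazily, inside the first
# loop iteration, which Theano cannot handle.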

class ConvLSTM2D(ConvRecurrent2D):
"""Convolutional LSTM.
@property
def losses(self):

Member: I believe you don't need this method, the inherited one should work fine.

Contributor Author:

The issue is that removing this function causes the test to fail:

layer = convolutional_recurrent.ConvLSTM2D(**kwargs)
                    layer.build(inputs.shape)
                    assert len(layer.losses) == 3
                    assert layer.activity_regularizer
                    output = layer(K.variable(np.ones(inputs.shape)))
>                   assert len(layer.losses) == 4
E                   AssertionError: assert 3 == 4
E                    +  where 3 = len([<tf.Tensor 'add:0' shape=() dtype=float32>, <tf.Tensor 'add_1:0' shape=() dtype=float32>, <tf.Tensor 'add_2:0' shape=() dtype=float32>])
E                    +    where [<tf.Tensor 'add:0' shape=() dtype=float32>, <tf.Tensor 'add_1:0' shape=() dtype=float32>, <tf.Tensor 'add_2:0' shape=() dtype=float32>] = <keras.layers.convolutional_recurrent.ConvLSTM2D object at 0x00000262BA0BFBA8>.losses

I'm going to investigate and see what can be done.

Contributor Author:

So, after some research, I believe that we either need to keep it or change the losses property in the RNN class. Let me explain:

This test verifies that the activity regularizer loss can be obtained from the layer. If we fall back to the RNN losses property, then this code gets executed:

@property
def losses(self):
    if isinstance(self.cell, Layer):
        return self.cell.losses
    return []

which implies that the cell contains all the losses we are interested in. But this contradicts the tests that I saw in recurrent_test.py, notably these lines:

layer = layer_class(units, return_sequences=False, weights=None,
                    input_shape=(timesteps, embedding_dim),
                    activity_regularizer='l2')
assert layer.activity_regularizer
x = K.variable(np.ones((num_samples, timesteps, embedding_dim)))
layer(x)
assert len(layer.cell.get_losses_for(x)) == 0
assert len(layer.get_losses_for(x)) == 1

In the end, I see no mention anywhere in LSTMCell of the activity_regularizer property, leading me to believe that the cell does not usually hold this loss.
So this is strange. I made a last commit modifying the losses property of RNN, because I think it is expected behavior for this property to include the activity regularizer's loss. Please tell me if this is what you think is right.
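
A sketch of what that RNN change might look like, assumed by analogy with the ConvRNN2D fix shown earlier in this thread:

@property
def losses(self):
    # Combine the cell's weight regularization losses with the losses
    # attached to the layer itself (e.g. from the activity regularizer).
    layer_losses = super(RNN, self).losses
    if isinstance(self.cell, Layer):
        return self.cell.losses + layer_losses
    return layer_losses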

@gabrieldemarmiesse (Contributor, Author):

Thank you very much for the help and expertise. I don't think I could have found this on my own (and the patch would have been discarded anyway, since we want to keep the cell architecture). I'll update the PR this weekend.

@fchollet (Member) left a comment:
LGTM.
On behalf of the community, thank you for the good work!

@fchollet merged commit 125f423 into keras-team:master on Feb 25, 2018
dschwertfeger added a commit to dschwertfeger/keras that referenced this pull request on Feb 26, 2018
@gabrieldemarmiesse deleted the lstm_conv_pr branch on September 18, 2018