Advanced Reshape Layer #36

Closed
wants to merge 7 commits into from

Conversation

@patyork (Contributor) commented Apr 5, 2015

I've added a bit of an architecture change, as well as a new layer to allow Advanced Reshaping.

Basically, to allow reshapes involving the first dimension, the number of samples in the current batch (current_batch_size) must be available to the AdvancedReshape layer. The easiest way that I could see to allow that is to throw a new parameter (current_batch_size) to the Layer.output function. This recursively passes the current number of samples in the batch to each layer, so that it has it available if it is necessary. This required touching every layer.

The AdvancedReshape layer takes a lambda expression for the initialization parameter. This lambda should take 2 arguments: current_batch_size and current_shape (can be seen below). From these parameters, it is possible to reshape between 1D, 2D, 3D (and possibly to ND, although I am unsure on that). This lambda should return a tuple.
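
For reference, here is a minimal sketch of what such a layer could look like under this proposal (a hypothetical illustration only; the actual code in this PR may differ in details such as the import path and the exact base-class signature):

from keras.layers.core import Layer

class AdvancedReshape(Layer):
    def __init__(self, new_shape_fn):
        super(AdvancedReshape, self).__init__()
        # new_shape_fn: lambda (current_batch_size, current_shape) -> tuple
        self.new_shape_fn = new_shape_fn

    def get_output(self, train, current_batch_size):
        # get_input/get_output now receive current_batch_size, per the architecture change described above
        X = self.get_input(train, current_batch_size)
        new_shape = self.new_shape_fn(current_batch_size, X.shape)
        return X.reshape(new_shape)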

Below is an example in which the current_batch_size changes on the pass over the last batch. It is simplistic and not overly useful, but it shows that AdvancedReshape can correctly reshape from 2D -> 3D and then from 3D -> 2D without breaking over the toughest example:

import numpy as np
from keras.models import Sequential
from keras.layers.core import Dense, AdvancedReshape  # AdvancedReshape is the layer added by this PR (import path assumed)
from keras.optimizers import SGD

model = Sequential()
model.add(Dense(2,1, activation='sigmoid'))

# Reshape to 3D: (number of samples in current batch, elements in each sample, values)
model.add(AdvancedReshape(new_shape_fn=lambda current_batch_size, current_shape: (current_batch_size, current_shape[0]/current_batch_size, current_shape[1])))
# Reshape back to 2D: (number of samples in current batch * elements in each sample, values)
model.add(AdvancedReshape(new_shape_fn=lambda current_batch_size, current_shape: (current_batch_size * current_shape[1], current_shape[2])))

sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='mse', optimizer=sgd)

X = np.zeros((3,2))
Y = np.zeros((3,1))
model.fit(X, Y, batch_size=2, nb_epoch=1)
Epoch 0
2/3 [===================>..........] - ETA: 0s - loss: 0.2500
3/3 [==============================] - 0s - loss: 0.2497
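
To spell out the shape arithmetic per batch in the run above (illustration only; these are the same lambdas as in the model, and the names to_3d/to_2d are just for this demonstration):

to_3d = lambda current_batch_size, current_shape: \
    (current_batch_size, current_shape[0] // current_batch_size, current_shape[1])
to_2d = lambda current_batch_size, current_shape: \
    (current_batch_size * current_shape[1], current_shape[2])

# first batch: 2 samples, Dense output shape (2, 1)
assert to_3d(2, (2, 1)) == (2, 1, 1)
assert to_2d(2, (2, 1, 1)) == (2, 1)

# last batch: only 1 sample remains, so current_batch_size drops to 1
assert to_3d(1, (1, 1)) == (1, 1, 1)
assert to_2d(1, (1, 1, 1)) == (1, 1)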

@patyork (Contributor, Author) commented Apr 5, 2015

A less trivial example (and, in fact, one that is quite useful). The model below is a deep net: three Dense layers, a recurrent layer, and a final Dense layer:

# This model is similar to the architecture utilized in DeepSpeech [http://arxiv.org/abs/1412.5567]
#   which attained state-of-the-art performance in speech recognition in noisy environments
# The only changes would be:
#   -a Bidirectional RNN (BRNN) instead of a simple RNN layer,
#   -Clipped ReLU instead of PReLU (although PReLU may perform better)
#   -the loss would be NLL of Connectionist Temporal Classification (CTC) cost
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, AdvancedReshape  # AdvancedReshape is the layer added by this PR (import path assumed)
from keras.layers.advanced_activations import PReLU
from keras.layers.recurrent import SimpleRNN
from keras.optimizers import SGD

model = Sequential()
model.add(Dense(1520,2048))
model.add(PReLU(2048))
model.add(Dropout(p=.15))
model.add(Dense(2048,2048))
model.add(PReLU(2048))
model.add(Dropout(p=.15))
model.add(Dense(2048,2048))
model.add(PReLU(2048))
model.add(Dropout(p=.15))
model.add(AdvancedReshape(new_shape_fn=lambda current_batch_size, current_shape: (current_batch_size, current_shape[0]/current_batch_size, current_shape[1])))
model.add(SimpleRNN(2048, 2048))
model.add(PReLU(2048))
model.add(AdvancedReshape(new_shape_fn=lambda current_batch_size, current_shape: (current_batch_size*current_shape[1], current_shape[2])))
model.add(Dense(2048,30, activation='softmax'))

sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd)

@fchollet (Member) commented Apr 6, 2015

model.add(AdvancedReshape(new_shape_fn=lambda current_batch_size, current_shape: (current_batch_size, current_shape[0]/current_batch_size, current_shape[1])))
model.add(SimpleRNN(2048, 2048))
model.add(PReLU(2048))
model.add(AdvancedReshape(new_shape_fn=lambda current_batch_size, current_shape: (current_batch_size*current_shape[1], current_shape[2])))

Could you provide more detail about what is going on here, in terms of what quantities the dimensions in the successive shapes stand for? I'm having trouble following.

@patyork (Contributor, Author) commented Apr 6, 2015

Sure.

At the point in the model shown below, the batch is stacked into a 2D matrix that is passed around.

However, an RNN needs a tensor3 of size (nb_samples, time_steps, values). The lambda below creates this shape, since time_steps == current_shape[0] / current_batch_size.

model.add(AdvancedReshape(new_shape_fn=lambda current_batch_size, current_shape: (current_batch_size, current_shape[0]/current_batch_size, current_shape[1])))

Then, the data is sent through the RNN:

model.add(SimpleRNN(2048, 2048))

Finally, we want to restack the tensor3 into a 2D matrix, which is given by the lambda expression:

model.add(AdvancedReshape(new_shape_fn=lambda current_batch_size, current_shape: (current_batch_size*current_shape[1], current_shape[2])))
# which is equivalent to
model.add(AdvancedReshape(new_shape_fn=lambda current_batch_size, current_shape: (current_shape[0]*current_shape[1], current_shape[2])))

Full example, with numbers.

# let batch size = 5
# let there be 6 time steps in each sample

# Incoming shape: (5 * 6, 2048) == (30, 2048)

model.add(AdvancedReshape(new_shape_fn=lambda current_batch_size, current_shape: (current_batch_size, current_shape[0]/current_batch_size, current_shape[1])))
# Out shape: (5, 30/5, 2048) == (5, 6, 2048)

model.add(SimpleRNN(2048, 2048))
# Output shape: (5, 6, 2048)

model.add(AdvancedReshape(new_shape_fn=lambda current_batch_size, current_shape: (current_batch_size*current_shape[1], current_shape[2])))
# Output shape: (5 * 6, 2048) == (30, 2048)
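
The same arithmetic can be checked with plain NumPy (illustration only; NumPy's reshape stands in for the Theano reshape that the layer performs):

import numpy as np

batch_size, time_steps, values = 5, 6, 2048
X2d = np.zeros((batch_size * time_steps, values))   # (30, 2048): the stacked batch

X3d = X2d.reshape((batch_size, X2d.shape[0] // batch_size, X2d.shape[1]))
assert X3d.shape == (5, 6, 2048)

X2d_again = X3d.reshape((batch_size * X3d.shape[1], X3d.shape[2]))
assert X2d_again.shape == (30, 2048)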

@fchollet (Member) commented Apr 6, 2015

model.add(AdvancedReshape(new_shape_fn=lambda current_batch_size, current_shape: (current_batch_size, current_shape[0]/current_batch_size, current_shape[1])))
# Out shape: (5, 30/5, 2048) == (5, 6, 2048)

This seems to assume current_shape[0] != current_batch_size, i.e. current_shape == X.shape[1:]. But looking at the code:

nshape = self.new_shape_fn(current_batch_size, X.shape) 

Am I missing something?

@patyork (Contributor, Author) commented Apr 6, 2015

This is not an assumption; it is a fact whenever the batch size passed into model.fit is not 1 (batch_size != 1).

self.new_shape_fn(...) is the function that produces the new shape; it is the lambda given to the AdvancedReshape layer's init.
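
As a quick stand-alone check (using the numbers from the example earlier in this thread), the lambda can be called directly, outside any layer:

new_shape_fn = lambda current_batch_size, current_shape: \
    (current_batch_size, current_shape[0] // current_batch_size, current_shape[1])

print(new_shape_fn(5, (30, 2048)))  # -> (5, 6, 2048)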

I will provide a full working example in a moment. But let me think on it for a while, before I give examples with actual numbers.

@patyork (Contributor, Author) commented Apr 6, 2015

EDIT:
Accidentally closed.

@patyork closed this Apr 6, 2015
@fchollet reopened this Apr 6, 2015
@fchollet (Member) commented Apr 6, 2015

So you're applying your reshape lambda to X.shape, where X is the input to the AdvancedReshape layer. It seems to me like X.shape[0] would be the number of samples in the batch. How is this not the case?

As far as I can tell, current_batch_size is not a user-facing parameter but is the first element of the shape of the input hitting the layer at the current iteration.

I'm just trying to understand how this works...

Looking at your code:

# we'll note X = layer.get_input(train, current_batch_size) at each layer. nb_samples is arbitrary.
# we're calling this model on an input of shape [nb_samples, 1520]. According to the changes in models.py,
# that means the value of current_batch_size being propagated throughout the entire model is nb_samples.
model = Sequential() 
model.add(Dense(1520,2048)) # X.shape == [nb_samples, 1520]
model.add(PReLU(2048)) # X.shape == [nb_samples, 2048]
model.add(Dropout(p=.15))
model.add(Dense(2048,2048)) 
model.add(PReLU(2048)) # X.shape == [nb_samples, 2048]
model.add(Dropout(p=.15))
model.add(Dense(2048,2048))
model.add(PReLU(2048)) # X.shape == [nb_samples, 2048]
model.add(Dropout(p=.15))
model.add(AdvancedReshape(new_shape_fn=lambda current_batch_size, current_shape: \
(current_batch_size, current_shape[0]/current_batch_size, current_shape[1]))) # X.shape == [nb_samples, 2048]
# hence the lambda is called on (nb_samples, (nb_samples, 2048))
# and returns (nb_samples, nb_samples/nb_samples, 2048)

At least that's how I understand it. Where is this incorrect?

@patyork (Contributor, Author) commented Apr 8, 2015

This assumed that Keras supports 3D input to Dense layers (which is required for networks that include at least one recurrent layer together with non-recurrent layers).

@patyork closed this Apr 8, 2015
howard0su pushed a commit to howard0su/keras that referenced this pull request Jan 28, 2017
@harish2704

Hi, is there any update on this PR? Is there any chance that this will get merged? I am also facing a similar issue.

hubingallin pushed a commit to hubingallin/keras that referenced this pull request Sep 22, 2023
pnacht pushed a commit to pnacht/keras that referenced this pull request Nov 10, 2023