LearnedSelfAttentionLayer requires fixed batchsize to output/fit or else fails #7777

HWBK opened this issue May 23, 2019 · 2 comments



commented May 23, 2019

Issue Description

I am trying to build a sequence classifier with the following ComputationGraphConfiguration:

ComputationGraphConfiguration confSemantic = new NeuralNetConfiguration.Builder()
                .seed( System.currentTimeMillis() )
                .updater( new Nadam() )
                .graphBuilder()   // switch to the graph builder before declaring inputs and layers
                .addInputs("mode_input", "word_input" )
                .addLayer("L1", new Bidirectional(new LSTM.Builder()
                        .activation(Activation.TANH).nOut(lstmLayerSize1).build()), "word_input" )
                .addLayer( "intent_attention" , new LearnedSelfAttentionLayer.Builder().nQueries(50).nHeads(1).build() , "L1", "mode_input"  )
                .addLayer( "slot_attention" ,  new LearnedSelfAttentionLayer.Builder().nQueries(50).nHeads(1).build() , "L1", "mode_input" )
                //.addLayer( "intent_attention" , new SelfAttentionLayer.Builder().nHeads(1).build() , "L1", "mode_input"  )
                //.addLayer( "slot_attention" ,  new SelfAttentionLayer.Builder().nHeads(1).build() , "L1", "mode_input" )
                .addLayer("slot_out", new RnnOutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX) //MCXENT + softmax for classification
                        .nOut( testData.uniqueSlotsList.size() ).build(), "slot_attention" )
                .addLayer("intent_out", new RnnOutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX) //MCXENT + softmax for classification
                        .nOut( testData.uniqueSemanticTagList.size() ).build() , "intent_attention")
                .setOutputs("slot_out", "intent_out")
                .setInputTypes(InputType.recurrent(testData.MODE_VEC_SIZE), InputType.recurrent(testData.WORD_VEC_SIZE)  )
                .build();
ComputationGraph netSemantic = new ComputationGraph(confSemantic);
netSemantic.init();   // parameters must be initialized before fit/output

word_input is a sequence of word embeddings (each of length 300). The maximum sequence length is 50.
lstmLayerSize1 is 110 and the mode_input size is 3, so the input to each attention layer has 110*2 + 3 = 223 features.
trainData and testData implement MultiDataSetIterator.
The batch size is 16.
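The 223 follows directly from the layer sizes. A minimal sketch in plain Java (no DL4J dependency), assuming the Bidirectional wrapper uses its default concatenation mode (doubling the LSTM's nOut), that the graph builder concatenates the two inputs of each attention layer along the feature axis, and that sequences are padded to the 50-step maximum. The class and method names are hypothetical:

```java
import java.util.Arrays;

/**
 * Hypothetical helper: expected shape of the activations that reach each
 * attention layer. Assumes Bidirectional concatenates forward and backward
 * outputs and that "L1" and "mode_input" are concatenated along the feature axis.
 */
public class AttentionInputShape {

    // Feature size after the Bidirectional(LSTM) output is concatenated with mode_input.
    static int mergedFeatureSize(int lstmLayerSize, int modeVecSize) {
        return 2 * lstmLayerSize + modeVecSize;
    }

    // Recurrent activations laid out as [minibatch, features, timeSeriesLength].
    static int[] attentionInputShape(int batchSize, int lstmLayerSize, int modeVecSize, int seqLen) {
        return new int[] {batchSize, mergedFeatureSize(lstmLayerSize, modeVecSize), seqLen};
    }

    public static void main(String[] args) {
        // Values from the issue: lstmLayerSize1 = 110, mode_input size = 3,
        // maximum sequence length = 50, batch size = 16.
        System.out.println(mergedFeatureSize(110, 3));                           // 223
        System.out.println(Arrays.toString(attentionInputShape(16, 110, 3, 50))); // [16, 223, 50]
    }
}
```

With a batch size of 16 this gives [16, 223, 50], the same shape that appears as yShape in the fitting error.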

After fitting some data with:

while (trainData.hasNext()) {
    netSemantic.fit(trainData.next());
}

Using SelfAttentionLayer works, but when I switch to LearnedSelfAttentionLayer and call an output method such as:

while (testData.hasNext()) {
        MultiDataSet ds = testData.next();
        // ds.asList() splits the minibatch into single-example MultiDataSets
        for ( MultiDataSet singlePrediction : ds.asList()) {
            INDArray [] out = netSemantic.output(singlePrediction.getFeatures());
        }
}

it fails with the following stack trace:

It looks like the output function expects a batch of 16 inputs of shape [223, 50] instead of the single example I pass in.
Also, during fitting, every batch must contain exactly 16 examples; otherwise fit fails with a similar error:
xShape = [14, 223, 50], yShape = [16, 223, 50]
because the last batch of the training set only had 14 examples left instead of a full batch of 16.
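The 14 is simply the remainder left after the full batches of 16. The sketch below (plain Java, hypothetical class and method names, and a hypothetical dataset size of 46) reproduces that arithmetic and shows one possible stopgap until the fix is released: skipping the incomplete final batch instead of passing it to fit. This is an assumption about an acceptable workaround, not a recommended fix; it discards a few training examples per epoch and does nothing for single-example calls to output.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical helper: sizes of consecutive minibatches for a dataset,
 * optionally skipping the trailing partial batch (stopgap only).
 */
public class MinibatchSizes {

    static List<Integer> batchSizes(int numExamples, int batchSize, boolean dropLastPartial) {
        List<Integer> sizes = new ArrayList<>();
        for (int start = 0; start < numExamples; start += batchSize) {
            int size = Math.min(batchSize, numExamples - start);
            if (size < batchSize && dropLastPartial) {
                break; // skip the incomplete final batch
            }
            sizes.add(size);
        }
        return sizes;
    }

    public static void main(String[] args) {
        int numExamples = 46; // hypothetical: 2 full batches of 16 plus 14 left over
        System.out.println(numExamples % 16);                   // 14
        System.out.println(batchSizes(numExamples, 16, false)); // [16, 16, 14]
        System.out.println(batchSizes(numExamples, 16, true));  // [16, 16]
    }
}
```

In a real training loop the equivalent check would be made on the first dimension of each minibatch's features before calling fit.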

Version Information

  • Deeplearning4j version : 1.0.0-beta4
  • platform information :
    Apache Maven 3.6.0
    Maven home: /usr/share/maven
    Java version: 1.8.0_171, vendor: Oracle Corporation, runtime: /usr/lib/jvm/java-8-oracle/jre
    Default locale: en_US, platform encoding: UTF-8
    OS name: "linux", version: "4.15.0-50-generic", arch: "amd64", family: "unix"
  • CUDA version: NA (CPU)


commented May 23, 2019

Thanks for flagging the issue.

That is an actual bug in the definition of the layer.



commented May 23, 2019

The fix has been merged into the current dev branch and will be available in snapshot builds once the dev branch has been merged into master.
