MergeVertex with 3D convolution issue #7715

Closed
AbdelmajidB opened this issue May 12, 2019 · 4 comments


@AbdelmajidB

commented May 12, 2019

Issue Description

Hi, I'm working on a 3D U-Net model with beta4, which supports 3D convolution. Here is the 3D U-Net I created:
public ComputationGraphConfiguration.GraphBuilder unetBuilder() {

    ComputationGraphConfiguration.GraphBuilder graph = new NeuralNetConfiguration.Builder().seed(seed)
               .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
               .updater(updater)
               .weightInit(weightInit)
               .l2(5e-5)
               .miniBatch(true)
               .cacheMode(cacheMode)
               .trainingWorkspaceMode(workspaceMode)
               .inferenceWorkspaceMode(workspaceMode)
               .graphBuilder();
    graph  
            .addLayer("conv1-1", new Convolution3D.Builder(3,3,3).stride(1,1,1).nOut(64)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "input")
            .addLayer("conv1-2", new Convolution3D.Builder(3,3,3).stride(1,1,1).nOut(64)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "conv1-1")
            .addLayer("pool1", new Subsampling3DLayer.Builder(Subsampling3DLayer.PoolingType.MAX).kernelSize(2,2,2)
                    .build(), "conv1-2")

            .addLayer("conv2-1", new Convolution3D.Builder(3,3,3).stride(1,1,1).nOut(128)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "pool1")
            .addLayer("conv2-2", new Convolution3D.Builder(3,3,3).stride(1,1,1).nOut(128)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "conv2-1")
            .addLayer("pool2", new Subsampling3DLayer.Builder(Subsampling3DLayer.PoolingType.MAX).kernelSize(2,2,2)
                    .build(), "conv2-2")

            .addLayer("conv3-1", new Convolution3D.Builder(3,3,3).stride(1,1,1).nOut(256)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "pool2")
            .addLayer("conv3-2", new Convolution3D.Builder(3,3,3).stride(1,1,1).nOut(256)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "conv3-1")
            .addLayer("pool3", new Subsampling3DLayer.Builder(Subsampling3DLayer.PoolingType.MAX).kernelSize(2,2,2)
                    .build(), "conv3-2")

            .addLayer("conv4-1", new Convolution3D.Builder(3,3,3).stride(1,1,1).nOut(512)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "pool3")
            .addLayer("conv4-2", new Convolution3D.Builder(3,3,3).stride(1,1,1).nOut(512)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "conv4-1")
            .addLayer("drop4", new DropoutLayer.Builder(0.5).build(), "conv4-2")
            .addLayer("pool4", new Subsampling3DLayer.Builder(Subsampling3DLayer.PoolingType.MAX).kernelSize(2,2,2)
                    .build(), "drop4")

            .addLayer("conv5-1", new Convolution3D.Builder(3,3,3).stride(1,1,1).nOut(1024)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "pool4")
            .addLayer("conv5-2", new Convolution3D.Builder(3,3,3).stride(1,1,1).nOut(1024)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "conv5-1")
            .addLayer("drop5", new DropoutLayer.Builder(0.5).build(), "conv5-2")

            // up6
            .addLayer("up6-1", new Upsampling3D.Builder(2).build(), "drop5")
            .addLayer("up6-2", new Convolution3D.Builder(2,2,2).stride(1,1,1).nOut(512)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "up6-1")
            .addVertex("merge6", new MergeVertex(), "drop4", "up6-2")
            .addLayer("conv6-1", new Convolution3D.Builder(3,3,3).stride(1,1,1).nOut(512)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "merge6")
            .addLayer("conv6-2", new Convolution3D.Builder(3,3,3).stride(1,1,1).nOut(512)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "conv6-1")

            // up7
            .addLayer("up7-1", new Upsampling3D.Builder(2).build(), "conv6-2")
            .addLayer("up7-2", new Convolution3D.Builder(2,2,2).stride(1,1,1).nOut(256)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "up7-1")
            .addVertex("merge7", new MergeVertex(), "conv3-2", "up7-2")
            .addLayer("conv7-1", new Convolution3D.Builder(3,3,3).stride(1,1,1).nOut(256)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "merge7")
            .addLayer("conv7-2", new Convolution3D.Builder(3,3,3).stride(1,1,1).nOut(256)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "conv7-1")

            // up8
            .addLayer("up8-1", new Upsampling3D.Builder(2).build(), "conv7-2")
            .addLayer("up8-2", new Convolution3D.Builder(2,2,2).stride(1,1,1).nOut(128)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "up8-1")
            .addVertex("merge8", new MergeVertex(), "conv2-2", "up8-2")
            .addLayer("conv8-1", new Convolution3D.Builder(3,3,3).stride(1,1,1).nOut(128)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "merge8")
            .addLayer("conv8-2", new Convolution3D.Builder(3,3,3).stride(1,1,1).nOut(128)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "conv8-1")

            // up9
            .addLayer("up9-1", new Upsampling3D.Builder(2).build(), "conv8-2")
            .addLayer("up9-2", new Convolution3D.Builder(2,2,2).stride(1,1,1).nOut(64)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "up9-1")
            .addVertex("merge9", new MergeVertex(), "conv1-2", "up9-2")
            .addLayer("conv9-1", new Convolution3D.Builder(3,3,3).stride(1,1,1).nOut(64)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "merge9")
            .addLayer("conv9-2", new Convolution3D.Builder(3,3,3).stride(1,1,1).nOut(64)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "conv9-1")
            .addLayer("conv9-3", new Convolution3D.Builder(3,3,3).stride(1,1,1).nOut(2)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.RELU).build(), "conv9-2")

            .addLayer("conv10", new Convolution3D.Builder(3,3,3).stride(1,1,1).nOut(1)
                    .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                    .activation(Activation.IDENTITY).build(), "conv9-3")
            .addLayer("output", new Cnn3DLossLayer.Builder(DataFormat.NCDHW).lossFunction(LossFunctions.LossFunction.XENT)
                    .activation(Activation.SIGMOID).build(), "conv10")

            .setOutputs("output");

    return graph;
}

When I run the code, I get the following error from the MergeVertex class:

Exception in thread "main" java.lang.IllegalStateException: Unknown input type: InputTypeConvolutional3D(format=NDHWC,d=23,h=32,w=32,c=512)
    at org.deeplearning4j.nn.conf.graph.MergeVertex.getOutputType(MergeVertex.java:116)
    at org.deeplearning4j.nn.conf.ComputationGraphConfiguration.getLayerActivationTypes(ComputationGraphConfiguration.java:514)
    at org.deeplearning4j.nn.conf.ComputationGraphConfiguration.addPreProcessors(ComputationGraphConfiguration.java:427)
    at org.deeplearning4j.nn.conf.ComputationGraphConfiguration$GraphBuilder.build(ComputationGraphConfiguration.java:1171)
    at ma.enset.brain_tumor_segmentation.SemanticSegmentation3D.run(SemanticSegmentation3D.java:96)
    at ma.enset.brain_tumor_segmentation.SemanticSegmentation3D.main(SemanticSegmentation3D.java:289)
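The exception is raised in MergeVertex.getOutputType, which did not recognise the InputTypeConvolutional3D it was given. For convolutional inputs, MergeVertex concatenates activations along the channel axis, so the output type it should report for CNN3D inputs keeps the shared spatial size and sums the channel counts. The plain-Java sketch below illustrates that rule only; the Shape class is a hypothetical stand-in for InputTypeConvolutional3D, and this is not the code from the eventual fix.

```java
final class Cnn3DMerge {

    /** Hypothetical stand-in for DL4J's InputTypeConvolutional3D: data format, spatial dims, channels. */
    static final class Shape {
        final String format; // e.g. "NCDHW" or "NDHWC"
        final long d, h, w, c;

        Shape(String format, long d, long h, long w, long c) {
            this.format = format;
            this.d = d;
            this.h = h;
            this.w = w;
            this.c = c;
        }

        @Override
        public String toString() {
            return "(format=" + format + ",d=" + d + ",h=" + h + ",w=" + w + ",c=" + c + ")";
        }
    }

    /** Output shape of concatenating CNN3D activations along the channel axis. */
    static Shape mergeOutput(Shape... inputs) {
        if (inputs.length == 0) {
            throw new IllegalArgumentException("merge needs at least one input");
        }
        Shape first = inputs[0];
        long channels = 0;
        for (Shape s : inputs) {
            // Concatenation is only defined when everything except the channel axis matches.
            boolean sameLayout = s.format.equals(first.format)
                    && s.d == first.d && s.h == first.h && s.w == first.w;
            if (!sameLayout) {
                throw new IllegalStateException("Cannot merge " + first + " with " + s);
            }
            channels += s.c;
        }
        return new Shape(first.format, first.d, first.h, first.w, channels);
    }

    public static void main(String[] args) {
        // Hypothetical: drop4 merged with an up-convolution of identical spatial size.
        Shape drop4 = new Shape("NDHWC", 23, 32, 32, 512);
        Shape upConv = new Shape("NDHWC", 23, 32, 32, 512);
        System.out.println(mergeOutput(drop4, upConv)); // (format=NDHWC,d=23,h=32,w=32,c=1024)
    }
}
```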

When I debug getOutputType in the MergeVertex class, I get the following variable values:

[screenshot: Capture]
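The reported type has c=512 and an odd depth d=23, so it cannot be the output of Upsampling3D(2) (which always yields an even size); it is drop4, the first input of merge6. Assuming Subsampling3DLayer's default stride of 2 and ConvolutionMode.Truncate (the configuration sets neither for the pooling layers), each pool maps a size n to floor((n - 2) / 2) + 1. That puts the input depth between 184 and 191, and it means pool4 turns 23 into 11, which up6-1 doubles to 22, so merge6's two inputs would not have matching depths even once MergeVertex accepts CNN3D types. A minimal sketch of that arithmetic; the class and method names are hypothetical and 184 is just one depth consistent with d=23.

```java
import java.util.Arrays;

final class UNetDepthCheck {

    /** Output size of one Truncate-mode pooling step: floor((in - k) / s) + 1. */
    static int poolOut(int in, int kernel, int stride) {
        if (in < kernel) {
            throw new IllegalArgumentException("input " + in + " smaller than kernel " + kernel);
        }
        return (in - kernel) / stride + 1;
    }

    /** sizes[0] is the input size; sizes[i] is the size after i 2x2x2 pooling layers. */
    static int[] encoderSizes(int input, int pools) {
        int[] sizes = new int[pools + 1];
        sizes[0] = input;
        for (int i = 1; i <= pools; i++) {
            sizes[i] = poolOut(sizes[i - 1], 2, 2);
        }
        return sizes;
    }

    /** True if every 2x upsampling lands exactly on the size of the skip connection it is merged with. */
    static boolean skipsAlign(int input, int pools) {
        int[] s = encoderSizes(input, pools);
        for (int i = 0; i < pools; i++) {
            if (2 * s[i + 1] != s[i]) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(encoderSizes(184, 4))); // [184, 92, 46, 23, 11]
        System.out.println(skipsAlign(184, 4)); // false: 2 * 11 = 22 != 23 at merge6
        System.out.println(skipsAlign(192, 4)); // true: 192 is divisible by 2^4
    }
}
```

With four pooling layers, every skip connection lines up exactly when the input size is a multiple of 16; the same arithmetic applies to height and width.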

Version Information

  • Deeplearning4j 1.0.0-beta4
  • windows 7
  • CUDA 9.2

AbdelmajidB changed the title from "MergeVertex 3D convolution issue" to "MergeVertex with 3D convolution issue" on May 12, 2019

@AlexDBlack

Contributor

commented May 13, 2019

Thanks for reporting. Fixed here: #7724
And yes, it was just an issue of bad ordering here.
As far as I can see, this should only happen when you use setInputType, so one (albeit ugly) workaround is to set the nIn values on all layers manually instead of using setInputType.
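For the manual nIn workaround, each layer's nIn is the channel count arriving at it. Pooling, dropout and Upsampling3D leave channels unchanged, and each MergeVertex in this network concatenates a skip connection with an up-convolution of the same width, doubling the channels. A minimal sketch under those assumptions that derives the values for the configuration in the issue; the number of input channels is a parameter because the issue does not state it, and the class and method names are hypothetical.

```java
import java.util.LinkedHashMap;
import java.util.Map;

final class UNetNIn {

    /** nOut of the two convolutions at each encoder level, as in the configuration above. */
    static final int[] WIDTHS = {64, 128, 256, 512, 1024};

    /** Maps each convolution layer name to the nIn it needs, given the input's channel count. */
    static Map<String, Integer> nInValues(int inputChannels) {
        Map<String, Integer> nIn = new LinkedHashMap<>();
        int prev = inputChannels;
        // Encoder: conv1-x .. conv5-x. Pooling and dropout keep the channel count unchanged.
        for (int level = 0; level < WIDTHS.length; level++) {
            String n = Integer.toString(level + 1);
            nIn.put("conv" + n + "-1", prev);
            nIn.put("conv" + n + "-2", WIDTHS[level]);
            prev = WIDTHS[level];
        }
        // Decoder: levels 6..9 mirror encoder levels 4..1.
        for (int level = 3; level >= 0; level--) {
            String n = Integer.toString(9 - level);
            int width = WIDTHS[level];
            nIn.put("up" + n + "-2", prev);        // Upsampling3D keeps the channel count
            nIn.put("conv" + n + "-1", 2 * width); // merge: skip (width) + up-conv (width)
            nIn.put("conv" + n + "-2", width);
            prev = width;
        }
        nIn.put("conv9-3", prev); // 64 channels from conv9-2
        nIn.put("conv10", 2);     // conv9-3 has nOut(2)
        return nIn;
    }

    public static void main(String[] args) {
        // Hypothetical single-channel input volume.
        nInValues(1).forEach((name, n) -> System.out.println(name + " nIn=" + n));
    }
}
```

Each value would then go into the matching builder, for example new Convolution3D.Builder(3,3,3).nIn(1024).nOut(512) for conv6-1.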

@AbdelmajidB

Author

commented May 13, 2019

Thank you so much, Alex. How can I apply the new changes to the MergeVertex class locally in my project?

@AlexDBlack

Contributor

commented May 13, 2019

You can copy and paste the MergeVertex (configuration) class into your project and use that copy instead of the original in your networks.

FYI: the downside is that if you save the model and then try to load it on a machine without your 'custom' MergeVertex, loading will fail (class not found). For future versions of DL4J, though (if your net needs to be usable for that long), you could edit the configuration JSON in the model zip file to point to the 'real' MergeVertex.
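The JSON edit described above can be scripted with only java.util.zip. This sketch assumes the configuration is stored in a zip entry named configuration.json and that the custom vertex appears in it as a quoted, fully qualified class name; every other entry, such as the saved parameters, is copied byte for byte. The custom class name com.example.MyMergeVertex and the class RewriteVertexClass are hypothetical.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;

final class RewriteVertexClass {

    /** Assumed name of the configuration entry inside the saved model zip. */
    static final String CONFIG_ENTRY = "configuration.json";

    /** Replaces the quoted class name only, so longer names that start with it are left alone. */
    static String replaceClassName(String json, String fromClass, String toClass) {
        return json.replace("\"" + fromClass + "\"", "\"" + toClass + "\"");
    }

    /** Copies every entry of {@code in} to {@code out}, rewriting only the configuration JSON. */
    static void rewrite(Path in, Path out, String fromClass, String toClass) throws IOException {
        try (ZipInputStream zin = new ZipInputStream(Files.newInputStream(in));
             ZipOutputStream zout = new ZipOutputStream(Files.newOutputStream(out))) {
            ZipEntry entry;
            while ((entry = zin.getNextEntry()) != null) {
                byte[] data = zin.readAllBytes(); // reads only the current entry
                if (entry.getName().equals(CONFIG_ENTRY)) {
                    String json = new String(data, StandardCharsets.UTF_8);
                    data = replaceClassName(json, fromClass, toClass).getBytes(StandardCharsets.UTF_8);
                }
                zout.putNextEntry(new ZipEntry(entry.getName()));
                zout.write(data);
                zout.closeEntry();
            }
        }
    }
}
```

For example, rewrite(Path.of("model.zip"), Path.of("model-fixed.zip"), "com.example.MyMergeVertex", "org.deeplearning4j.nn.conf.graph.MergeVertex") would write a copy that points at DL4J's own MergeVertex and leave the original file untouched. Matching the quoted name rather than any substring avoids rewriting longer class names that merely begin with the custom one.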

AlexDBlack added a commit that referenced this issue May 15, 2019
Various fixes (#7724)
* #7715 Fix MergeVertex for CNN3D activations

* #7680 - AnalyzeSpark fix

* First round of uint TF import fixes

* More dtype fixes

* BFLOAT16 import

* another bunch of small fixes

* one more nano fix

* int16 fix

* int8/uint8 fix
@AlexDBlack

Contributor

commented Jun 3, 2019

AlexDBlack closed this on Jun 3, 2019

Projects
None yet
2 participants