
TransferLearning.GraphBuilder#nOutReplace does not work on layer before SubsamplingLayer #6389

Closed
DrChainsaw opened this issue Sep 6, 2018 · 1 comment

@DrChainsaw

commented Sep 6, 2018

SubsamplingLayer gets past the !vertex.hasLayer check in transfer learning since it is a layer, but it is apparently not a feed-forward layer, so nOutReplace fails when applied to the layer before it.

Trying to follow the recommendation from that assert (use removeVertex followed by addVertex, but with addLayer instead) does not seem to work either, since the expected number of inputs of the next layer is not updated. I haven't verified whether this is also the case for non-layer vertices.

I also could not find a way to fulfill the expectation stated by removeVertexKeepConnections: "Note the expectation here is to then add back another vertex with the same name or else the graph will be left in an invalid state". There is no method to add back a vertex/layer without any inputs, and at least addLayer throws an exception if inputs are not set.

Test case which reproduces the issue:

    @Test(expected = UnsupportedOperationException.class)
    public void changeNout() {
        final String inputName = "input";
        final String changeNoutName = "changeNout";
        final String poolName = "pool";
        final String afterPoolName = "afterPool";
        final String outputName = "output";
        final INDArray input = Nd4j.randn(new long[] {1, 6, 6, 6});
        final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder()
                .weightInit(new ConstantDistribution(666))
                .graphBuilder()
                .addInputs(inputName)
                .setOutputs(outputName)
                .setInputTypes(InputType.inferInputTypes(input))
                .addLayer(changeNoutName, new Convolution2D.Builder(1, 1)
                        .nOut(10)
                        .build(), inputName)
                .addLayer(poolName, new SubsamplingLayer.Builder(1,1).build(), changeNoutName)
                .addLayer(afterPoolName, new Convolution2D.Builder(1, 1)
                        .nOut(7)
                        .build(), poolName)
                .addLayer(outputName, new OutputLayer.Builder()
                        .nOut(2)
                        .build(), afterPoolName)
                .build());
        graph.init();

        // Crash!!
        final ComputationGraph newGraph = doNoutReplace(graph, changeNoutName);

        // Does not crash here, but fails because inputs of afterPoolName are not changed
        //final ComputationGraph newGraph = doLayerReplace(graph, changeNoutName);
        newGraph.init();

        assertEquals("Incorrect number of outputs!", 5, newGraph.layerSize(changeNoutName));
        assertEquals("Incorrect number of inputs!", 5, newGraph.layerInputSize(afterPoolName));

        // Crash!!
        newGraph.output(input);
    }

    private ComputationGraph doNoutReplace(ComputationGraph graph, String layerName) {
        return new TransferLearning.GraphBuilder(graph).nOutReplace(layerName, 5, WeightInit.ZERO).build();
    }

    private ComputationGraph doLayerReplace(ComputationGraph graph, String layerName) {
        final FeedForwardLayer newLayerConf = (FeedForwardLayer) graph.getLayer(layerName).conf().getLayer().clone();
        final List<String> inputs = graph.getConfiguration().getVertexInputs().get(layerName);
        newLayerConf.setNOut(5);
        return new TransferLearning.GraphBuilder(graph).removeVertexKeepConnections(layerName)
                .addLayer(layerName, newLayerConf, inputs.toArray(new String[0])).build();
    }

@AlexDBlack AlexDBlack self-assigned this Sep 7, 2018

AlexDBlack added a commit that referenced this issue Sep 17, 2018
AlexDBlack added a commit that referenced this issue Sep 18, 2018
Various fixes (#6450)
* #6401 transposei fix

* #6378 MmulTranspose fix + cleanup

* #6442 validate array order

* Cleanup and test fixes

* #6403 batch norm validation

* #6389 Fix TransferLearning nOutReplace issues; add nInReplace method

* Trigger CI
@lock (lock bot) commented Oct 18, 2018

This thread has been automatically locked since there has not been any recent activity after it was closed. Please open a new issue for related bugs.

@lock lock bot locked and limited conversation to collaborators Oct 18, 2018
