Skip to content

Commit

Permalink
#7352 MultiLayerNetwork.output(DataSetIterator) validation
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed Mar 28, 2019
1 parent c01cefa commit 38e3413
Showing 1 changed file with 24 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2355,13 +2355,19 @@ public INDArray output(INDArray input) {
/**
* Generate the output for all examples/batches in the input iterator, and concatenate them into a single array.
* See {@link #output(INDArray)}<br>
* NOTE: The output array can require a considerable amount of memory for iterators with a large number of examples
* NOTE 1: The output array can require a considerable amount of memory for iterators with a large number of examples<br>
* NOTE 2: This method cannot be used for variable length time series outputs, as this would require padding arrays
* for some outputs, or returning a mask array (which cannot be done with this method). For variable length time
* series applications, use one of the other output methods. This method also cannot be used with fully convolutional
* networks with different output sizes (for example, segmentation on different input image sizes).
*
*
* @param iterator Data to pass through the network
* @return output for all examples in the iterator, concatenated into a
*/
public INDArray output(DataSetIterator iterator, boolean train) {
List<INDArray> outList = new ArrayList<>();
long[] firstOutputShape = null;
while (iterator.hasNext()) {
DataSet next = iterator.next();
INDArray features = next.getFeatures();
Expand All @@ -2371,7 +2377,23 @@ public INDArray output(DataSetIterator iterator, boolean train) {

INDArray fMask = next.getFeaturesMaskArray();
INDArray lMask = next.getLabelsMaskArray();
outList.add(this.output(features, train, fMask, lMask));
INDArray output = this.output(features, train, fMask, lMask);
outList.add(output);
if(firstOutputShape == null){
firstOutputShape = output.shape();
} else {
//Validate that shapes are the same (may not be, for some RNN variable length time series applications)
long[] currShape = output.shape();
Preconditions.checkState(firstOutputShape.length == currShape.length, "Error during forward pass:" +
"different minibatches have different output array ranks - first minibatch shape %s, last minibatch shape %s", firstOutputShape, currShape);
for( int i=1; i<currShape.length; i++ ){ //Skip checking minibatch dimension, fine if this varies
Preconditions.checkState(firstOutputShape[i] == currShape[i], "Current output shape does not match first" +
" output array shape at position %s: all dimensions must match other than the first dimension.\n" +
" For variable length output size/length use cases such as for RNNs with multiple sequence lengths," +
" use one of the other (non iterator) output methods. First batch output shape: %s, current batch output shape: %s",
i, firstOutputShape, currShape);
}
}
}
return Nd4j.concat(0, outList.toArray(new INDArray[outList.size()]));
}
Expand Down

0 comments on commit 38e3413

Please sign in to comment.