Skip to content

Commit

Permalink
#7292 INDArray.reshape overload - enforce no view
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed Mar 16, 2019
1 parent 21da1c6 commit 8d2bfbb
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 0 deletions.
Expand Up @@ -4475,6 +4475,11 @@ public INDArray reshape(char order, int... newShape) {

@Override
public INDArray reshape(char order, long... newShape) {
return reshape(order, false, newShape);
}

@Override
public INDArray reshape(char order, boolean enforceView, long... newShape){
Nd4j.getCompressor().autoDecompress(this);

// special case for empty reshape
Expand Down Expand Up @@ -4539,6 +4544,11 @@ public INDArray reshape(char order, long... newShape) {
return reshapeAttempt;
}

if(enforceView){
throw new ND4JIllegalStateException("Unable to reshape array as view, called with enforceView=true. " +
"Use enforceView=false to return a copy instead, or call reshape on a non-strided array. Array shape info: " + this.shapeInfoToString().replaceAll("\n",""));
}


if (order != ordering()) {
INDArray ret = Nd4j.createUninitialized(this.dataType(), shape, order);
Expand Down
Expand Up @@ -803,6 +803,11 @@ public INDArray reshape(char order, int... newShape) {
return null;
}

@Override
public INDArray reshape(char order, boolean enforceView, long... newShape) {
return null;
}

@Override
public INDArray reshape(int[] shape) {
return null;
Expand Down
Expand Up @@ -2031,8 +2031,27 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray reshape(char order, long... newShape);

/**
* Reshapes the ndarray (can't change the length of the ndarray). Typically this will be a view, unless reshaping
* without copying is impossible.
*
* @param newShape the new shape of the ndarray
* @return the reshaped ndarray
*/
INDArray reshape(char order, int... newShape);

/**
* Reshapes the ndarray (note: it's not possible to change the length of the ndarray).
* Typically this will be a view, unless reshaping without copying (i.e., returning a view) is impossible.<br>
* In that case, the behaviour will depend on the enforceView argument:
* enforceView == true: throw an exception<br>
* enforceView == false: return a copy<br>
*
* @param newShape the new shape of the ndarray
* @return the reshaped ndarray
*/
INDArray reshape(char order, boolean enforceView, long... newShape);


/**
* Reshapes the ndarray (can't change the length of the ndarray). Typically this will be a view, unless reshaping
Expand Down
Expand Up @@ -107,6 +107,11 @@ public INDArray reshape(char order, int... newShape) {
return null;
}

@Override
public INDArray reshape(char order, boolean enforceView, long... newShape) {
return null;
}

@Override
public INDArray reshape(int[] shape) {
return null;
Expand Down
Expand Up @@ -7364,6 +7364,29 @@ public void testRollingMean() {
log.info("Average time: {} ms", (timeEnd - timeStart) / (double) iterations / (double) 1000 / (double) 1000);
}

@Test
public void testZerosRank1() {
Nd4j.zeros(new int[] { 2 }, DataType.DOUBLE);
}

@Test
public void testReshapeEnforce(){

INDArray arr = Nd4j.create(new long[]{2,2}, 'c');
INDArray arr2 = arr.reshape('c', true, 4, 1);

INDArray arr1a = Nd4j.create(new long[]{2,3}, 'c').get(NDArrayIndex.all(), NDArrayIndex.interval(0,2));
INDArray arr3 = arr1a.reshape('c', false, 4,1);
assertFalse(arr3.isView()); //Should be copy

try{
INDArray arr4 = arr1a.reshape('c', true, 4,1);
fail("Expected exception");
} catch (ND4JIllegalStateException e){
assertTrue(e.getMessage(), e.getMessage().contains("Unable to reshape array as view"));
}
}

///////////////////////////////////////////////////////
protected static void fillJvmArray3D(float[][][] arr) {
int cnt = 1;
Expand Down

0 comments on commit 8d2bfbb

Please sign in to comment.