Skip to content

Commit

Permalink
#7092 INDArray.get(point,x)/get(x,point) returns 1d array
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed Apr 17, 2019
1 parent 69f079c commit 6a88228
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
Expand Up @@ -5030,16 +5030,15 @@ else if (rank() == 2 && jvmShapeInfo.javaShapeInformation[2] == 1 && indexes.len
if (indexes.length < 1)
throw new IllegalStateException("Invalid index found of zero length");

// FIXME: LONG
int[] shape = LongUtils.toInts(resolution.getShapes());
long[] shape = resolution.getShapes();
int numSpecifiedIndex = 0;
for (int i = 0; i < indexes.length; i++)
if (indexes[i] instanceof SpecifiedIndex)
numSpecifiedIndex++;

if (shape != null && numSpecifiedIndex > 0) {
Generator<List<List<Long>>> gen = SpecifiedIndex.iterate(indexes);
INDArray ret = Nd4j.create(this.dataType(), ArrayUtil.toLongArray(shape), 'c');
INDArray ret = Nd4j.create(this.dataType(), shape, 'c');
int count = 0;
while (true) {
try {
Expand Down
Expand Up @@ -234,7 +234,7 @@ else if (indexes[i] instanceof NDArrayIndexAll)

//specific easy case
if (numSpecified < 1 && interval < 1 && newAxis < 1 && pointIndex > 0 && numAll > 0) {
int minDimensions = Math.max(arr.rank() - pointIndex, 2);
int minDimensions = arr.rank()-pointIndex;
long[] shape = new long[minDimensions];
Arrays.fill(shape, 1);
long[] stride = new long[minDimensions];
Expand Down Expand Up @@ -277,7 +277,6 @@ else if (indexes[i] instanceof NDArrayIndexAll)
this.offsets = offsets;
this.offset = offset;
return true;

}

//intervals and all
Expand Down
Expand Up @@ -214,6 +214,31 @@ public void testGetIndicesVector() {
assertEquals(test, result);
}

@Test
public void test2dGetPoint(){
INDArray arr = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape('c',3,4);
for( int i=0; i<3; i++ ){
INDArray exp = Nd4j.create(new double[]{i*4+1, i*4+2, i*4+3, i*4+4});
INDArray row = arr.getRow(i);
INDArray get = arr.get(NDArrayIndex.point(i), NDArrayIndex.all());

assertEquals(1, row.rank());
assertEquals(1, get.rank());
assertEquals(exp, row);
assertEquals(exp, get);
}

for( int i=0; i<4; i++ ){
INDArray exp = Nd4j.create(new double[]{1+i, 5+i, 9+i});
INDArray col = arr.getColumn(i);
INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.point(i));

assertEquals(1, col.rank());
assertEquals(1, get.rank());
assertEquals(exp, col);
assertEquals(exp, get);
}
}


@Override
Expand Down

0 comments on commit 6a88228

Please sign in to comment.