diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index f84d2237d464..2a2dc338645b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -5030,8 +5030,7 @@ else if (rank() == 2 && jvmShapeInfo.javaShapeInformation[2] == 1 && indexes.len if (indexes.length < 1) throw new IllegalStateException("Invalid index found of zero length"); - // FIXME: LONG - int[] shape = LongUtils.toInts(resolution.getShapes()); + long[] shape = resolution.getShapes(); int numSpecifiedIndex = 0; for (int i = 0; i < indexes.length; i++) if (indexes[i] instanceof SpecifiedIndex) @@ -5039,7 +5038,7 @@ else if (rank() == 2 && jvmShapeInfo.javaShapeInformation[2] == 1 && indexes.len if (shape != null && numSpecifiedIndex > 0) { Generator>> gen = SpecifiedIndex.iterate(indexes); - INDArray ret = Nd4j.create(this.dataType(), ArrayUtil.toLongArray(shape), 'c'); + INDArray ret = Nd4j.create(this.dataType(), shape, 'c'); int count = 0; while (true) { try { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/ShapeOffsetResolution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/ShapeOffsetResolution.java index 3106b8c2452f..8cc303509577 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/ShapeOffsetResolution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/ShapeOffsetResolution.java @@ -234,7 +234,7 @@ else if (indexes[i] instanceof NDArrayIndexAll) //specific easy case if (numSpecified < 1 && interval < 1 && newAxis < 1 && pointIndex > 0 && numAll > 0) { - int minDimensions = Math.max(arr.rank() - pointIndex, 2); + int minDimensions = arr.rank()-pointIndex; long[] shape = new long[minDimensions]; Arrays.fill(shape, 1); long[] stride = new long[minDimensions]; @@ -277,7 +277,6 @@ else if (indexes[i] instanceof NDArrayIndexAll) this.offsets = offsets; this.offset = offset; return true; - } //intervals and all diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java index 8a923c89cf28..7539ad49c6d7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java @@ -214,6 +214,31 @@ public void testGetIndicesVector() { assertEquals(test, result); } + @Test + public void test2dGetPoint(){ + INDArray arr = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape('c',3,4); + for( int i=0; i<3; i++ ){ + INDArray exp = Nd4j.create(new double[]{i*4+1, i*4+2, i*4+3, i*4+4}); + INDArray row = arr.getRow(i); + INDArray get = arr.get(NDArrayIndex.point(i), NDArrayIndex.all()); + + assertEquals(1, row.rank()); + assertEquals(1, get.rank()); + assertEquals(exp, row); + assertEquals(exp, get); + } + + for( int i=0; i<4; i++ ){ + INDArray exp = Nd4j.create(new double[]{1+i, 5+i, 9+i}); + INDArray col = arr.getColumn(i); + INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.point(i)); + + assertEquals(1, col.rank()); + assertEquals(1, get.rank()); + assertEquals(exp, col); + assertEquals(exp, get); + } + } @Override