Skip to content

Commit

Permalink
#7305 Fix getColumn on row vector (returning scalar, not view)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed Mar 28, 2019
1 parent e7a7327 commit 1b3d149
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
Expand Up @@ -4901,9 +4901,6 @@ public INDArray getColumn(long c) {
return this;
else if (isColumnVector() && c > 0)
throw new IllegalArgumentException("Illegal index for row");
else if(isRowVector()) {
return Nd4j.scalar(getDouble(c));
}
return get(NDArrayIndex.all(), NDArrayIndex.point(c));
}

Expand Down
Expand Up @@ -7436,6 +7436,15 @@ public void testMeshgridDtypes() {
Nd4j.meshgrid(Nd4j.createFromArray(1, 2, 3), Nd4j.createFromArray(4, 5, 6));
}

@Test
public void testGetColumnRowVector(){
INDArray arr = Nd4j.create(1,4);
INDArray col = arr.getColumn(0);
System.out.println(Arrays.toString(col.shape()));
assertArrayEquals(new long[]{1,1}, col.shape());
}


///////////////////////////////////////////////////////
protected static void fillJvmArray3D(float[][][] arr) {
int cnt = 1;
Expand Down

0 comments on commit 1b3d149

Please sign in to comment.