Skip to content

Commit

Permalink
#5330 Allow rank 1 arrays in column vector ops
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed Jun 12, 2018
1 parent cc5b836 commit 00bdf3c
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 1 deletion.
Expand Up @@ -2801,7 +2801,8 @@ else if(isScalar()) {
}

//Input validation: require (a) columnVector to actually be a column vector, and (b) this.size(0) to match columnVector.size(0)
if (!columnVector.isColumnVector() || this.size(0) != columnVector.size(0) || columnVector.length() <= 1) {
//Or, simply require it to be a rank 1 vector
if ((!columnVector.isColumnVector() && columnVector.rank() > 1) || this.size(0) != columnVector.size(0) || columnVector.length() <= 1) {
throw new IllegalStateException("Mismatched shapes (shape = " + Arrays.toString(shape())
+ ", column vector shape =" + Arrays.toString(columnVector.shape()) + ")");
}
Expand Down
Expand Up @@ -6550,6 +6550,70 @@ public void testTranspose_Custom(){
assertEquals(exp, out);
}

@Test
public void testRowColumnOpsRank1(){

for( int i=0; i<6; i++ ) {
INDArray orig = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
INDArray in1r = orig.dup();
INDArray in2r = orig.dup();
INDArray in1c = orig.dup();
INDArray in2c = orig.dup();

INDArray rv1 = Nd4j.create(new double[]{1, 2, 3, 4}, new long[]{1, 4});
INDArray rv2 = Nd4j.create(new double[]{1, 2, 3, 4}, new long[]{4});
INDArray cv1 = Nd4j.create(new double[]{1, 2, 3}, new long[]{3, 1});
INDArray cv2 = Nd4j.create(new double[]{1, 2, 3}, new long[]{3});

switch (i){
case 0:
in1r.addiRowVector(rv1);
in2r.addiRowVector(rv2);
in1c.addiColumnVector(cv1);
in2c.addiColumnVector(cv2);
break;
case 1:
in1r.subiRowVector(rv1);
in2r.subiRowVector(rv2);
in1c.subiColumnVector(cv1);
in2c.subiColumnVector(cv2);
break;
case 2:
in1r.muliRowVector(rv1);
in2r.muliRowVector(rv2);
in1c.muliColumnVector(cv1);
in2c.muliColumnVector(cv2);
break;
case 3:
in1r.diviRowVector(rv1);
in2r.diviRowVector(rv2);
in1c.diviColumnVector(cv1);
in2c.diviColumnVector(cv2);
break;
case 4:
in1r.rsubiRowVector(rv1);
in2r.rsubiRowVector(rv2);
in1c.rsubiColumnVector(cv1);
in2c.rsubiColumnVector(cv2);
break;
case 5:
in1r.rdiviRowVector(rv1);
in2r.rdiviRowVector(rv2);
in1c.rdiviColumnVector(cv1);
in2c.rdiviColumnVector(cv2);
break;
default:
throw new RuntimeException();
}


assertEquals(in1r, in2r);
assertEquals(in1c, in2c);

}

}


///////////////////////////////////////////////////////
protected static void fillJvmArray3D(float[][][] arr) {
Expand Down

0 comments on commit 00bdf3c

Please sign in to comment.