From 00bdf3c252871eaa1b946bc5de667dd447f1ea56 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Tue, 12 Jun 2018 16:30:22 +1000 Subject: [PATCH] #5330 Allow rank 1 arrays in column vector ops --- .../nd4j/linalg/api/ndarray/BaseNDArray.java | 3 +- .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 64 +++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 24f0cbb75df1..a0c3031ec2e3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -2801,7 +2801,8 @@ else if(isScalar()) { } //Input validation: require (a) columnVector to actually be a column vector, and (b) this.size(0) to match columnVector.size(0) - if (!columnVector.isColumnVector() || this.size(0) != columnVector.size(0) || columnVector.length() <= 1) { + //Or, simply require it to be a rank 1 vector + if ((!columnVector.isColumnVector() && columnVector.rank() > 1) || this.size(0) != columnVector.size(0) || columnVector.length() <= 1) { throw new IllegalStateException("Mismatched shapes (shape = " + Arrays.toString(shape()) + ", column vector shape =" + Arrays.toString(columnVector.shape()) + ")"); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index eb4b9cf0ac61..eec1e66c95af 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -6550,6 +6550,70 @@ public void testTranspose_Custom(){ assertEquals(exp, out); } + @Test + public void testRowColumnOpsRank1(){ + + for( int i=0; i<6; i++ ) { + INDArray orig = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4); + INDArray in1r = orig.dup(); + INDArray in2r = orig.dup(); + INDArray in1c = orig.dup(); + INDArray in2c = orig.dup(); + + INDArray rv1 = Nd4j.create(new double[]{1, 2, 3, 4}, new long[]{1, 4}); + INDArray rv2 = Nd4j.create(new double[]{1, 2, 3, 4}, new long[]{4}); + INDArray cv1 = Nd4j.create(new double[]{1, 2, 3}, new long[]{3, 1}); + INDArray cv2 = Nd4j.create(new double[]{1, 2, 3}, new long[]{3}); + + switch (i){ + case 0: + in1r.addiRowVector(rv1); + in2r.addiRowVector(rv2); + in1c.addiColumnVector(cv1); + in2c.addiColumnVector(cv2); + break; + case 1: + in1r.subiRowVector(rv1); + in2r.subiRowVector(rv2); + in1c.subiColumnVector(cv1); + in2c.subiColumnVector(cv2); + break; + case 2: + in1r.muliRowVector(rv1); + in2r.muliRowVector(rv2); + in1c.muliColumnVector(cv1); + in2c.muliColumnVector(cv2); + break; + case 3: + in1r.diviRowVector(rv1); + in2r.diviRowVector(rv2); + in1c.diviColumnVector(cv1); + in2c.diviColumnVector(cv2); + break; + case 4: + in1r.rsubiRowVector(rv1); + in2r.rsubiRowVector(rv2); + in1c.rsubiColumnVector(cv1); + in2c.rsubiColumnVector(cv2); + break; + case 5: + in1r.rdiviRowVector(rv1); + in2r.rdiviRowVector(rv2); + in1c.rdiviColumnVector(cv1); + in2c.rdiviColumnVector(cv2); + break; + default: + throw new RuntimeException(); + } + + + assertEquals(in1r, in2r); + assertEquals(in1c, in2c); + + } + + } + /////////////////////////////////////////////////////// protected static void fillJvmArray3D(float[][][] arr) {