Skip to content

Commit

Permalink
#6985 Vstack rank 1 edge case fix
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed Jan 14, 2019
1 parent e4f8620 commit bbb6cca
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 13 deletions.
Expand Up @@ -43,6 +43,8 @@
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.weightinit.impl.XavierInitScheme;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static org.junit.Assert.assertEquals;
Expand Down Expand Up @@ -152,9 +154,10 @@ public void testCompareMlpTrainingIris(){
b1.getArr().assign(net.getParam("1_b"));

//Check output (forward pass)
in.setArray(f);
label.setArray(l);
sd.exec();
Map<String,INDArray> placeholders = new HashMap<>();
placeholders.put("in", f);
placeholders.put("label", l);
sd.exec(placeholders, lossMse.getVarName());
INDArray outSd = a1.getArr();
INDArray outDl4j = net.output(f);

Expand Down Expand Up @@ -183,7 +186,7 @@ public void testCompareMlpTrainingIris(){

//Check gradients (before updater applied)
Map<String,INDArray> grads = net.gradient().gradientForVariable();
sd.execBackwards();
sd.execBackwards(placeholders);

//Note that the SameDiff gradients don't include the L1/L2 terms at present just from execBackwards()... these are added later
//We can check correctness though with training param checks later
Expand Down
Expand Up @@ -5411,29 +5411,32 @@ public static INDArray hstack(Collection<INDArray> arrs) {
}

/**
* Concatenates two matrices vertically. Matrices must have identical
* numbers of columns.
* Concatenates two matrices vertically. Matrices must have identical numbers of columns.<br>
* Note that for vstack on rank 1 arrays, this is equivalent to {@link Nd4j#pile(INDArray...)}. Example: vstack([3],[3]) -> [2,3]
*
* @param arrs
* @param arrs Arrays to vstack
*/
public static INDArray vstack(INDArray... arrs) {
Preconditions.checkState(arrs != null && arrs.length > 0, "No input specified to vstack (null or length 0)");
if(arrs[0].rank() == 1){
//Edge case: vstack rank 1 arrays - gives rank 2... vstack([3],[3]) -> [2,3]
return pile(arrs);
}
INDArray ret = INSTANCE.vstack(arrs);
logCreationIfNecessary(ret);
return ret;
}


/**
* Concatenates two matrices vertically. Matrices must have identical
* numbers of columns.
* Concatenates two matrices vertically. Matrices must have identical numbers of columns.<br>
* Note that for vstack on rank 1 arrays, this is equivalent to {@link Nd4j#pile(INDArray...)}. Example: vstack([3],[3]) -> [2,3]
*
* @param arrs
* @param arrs Arrays to vstack
*/
public static INDArray vstack(Collection<INDArray> arrs) {
INDArray[] arrays = arrs.toArray(new INDArray[0]);
INDArray ret = INSTANCE.vstack(arrays);
logCreationIfNecessary(ret);
return ret;
return vstack(arrays);
}

/**
Expand Down
Expand Up @@ -7192,6 +7192,21 @@ public void testEmptyCasting(){
}
}

@Test
public void testVStackRank1(){
List<INDArray> list = new ArrayList<>();
list.add(Nd4j.linspace(1,3,3, DataType.DOUBLE));
list.add(Nd4j.linspace(4,6,3, DataType.DOUBLE));
list.add(Nd4j.linspace(7,9,3, DataType.DOUBLE));

INDArray out = Nd4j.vstack(list);
INDArray exp = Nd4j.createFromArray(new double[][]{
{1,2,3},
{4,5,6},
{7,8,9}});
assertEquals(exp, out);
}

///////////////////////////////////////////////////////
protected static void fillJvmArray3D(float[][][] arr) {
int cnt = 1;
Expand Down

0 comments on commit bbb6cca

Please sign in to comment.