Permalink
Browse files

tests should be reproducible

  • Loading branch information...
1 parent e1d70ac commit 71972511e85be8abc2c60455932f31c21f19385a @agibsonccc agibsonccc committed Feb 11, 2016
@@ -37,7 +37,7 @@
*
* @author Adam Gibson
*/
-public class DefaultOpExecutioner implements OpExecutioner {
+public class DefaultOpExecutioner implements OpExecutioner {
protected ExecutionMode executionMode = ExecutionMode.JAVA;
@@ -358,7 +358,7 @@ public INDArray execAndReturn(TransformOp op, int... dimension) {
@Override
public INDArray execAndReturn(ScalarOp op, int... dimension) {
- return exec(op, dimension).z();
+ return exec(op, dimension);
}
@Override
@@ -565,4 +565,9 @@ protected void doBroadcastOp(BroadcastOp op) {
}
}
}
+
+ @Override
+ public INDArray exec(BroadcastOp broadcast, int... dimension) {
+ throw new UnsupportedOperationException();
+ }
}
@@ -109,10 +109,16 @@
* Execute an accumulation along one or more dimensions
* @param accumulation the accumulation
* @param dimension the dimension
- * @return the accmulation op
+ * @return the accumulation op
*/
INDArray exec(Accumulation accumulation, int...dimension);
-
+ /**
+ * Execute an broadcast along one or more dimensions
+ * @param broadcast the accumulation
+ * @param dimension the dimension
+ * @return the broadcast op
+ */
+ INDArray exec(BroadcastOp broadcast, int...dimension);
/**
* Execute an accumulation along one or more dimensions
@@ -7,6 +7,7 @@
import org.nd4j.linalg.api.ops.*;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
+import org.nd4j.linalg.api.ops.impl.transforms.Floor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
@@ -123,62 +124,126 @@ public INDArray exec(Accumulation op, int... dimension) {
java.nio.IntBuffer dimensionBuffer = Shape.toBuffer(dimension);
if(op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
if(op instanceof Variance) {
- loop.execSummaryStats(op.opNum(),
- op.x().data().asNioDouble(),
- op.x().shapeInfo(), (DoubleBuffer) op.extraArgsBuff(),
- op.z().data().asNioDouble(),
- op.z().shapeInfo(),
- dimensionBuffer, dimension.length);
+ if(ret.isScalar()) {
+ ret.putScalar(0,loop.execSummaryStatsScalar(
+ op.opNum()
+ ,op.x().data().asNioDouble(),
+ op.x().shapeInfo(),
+ (DoubleBuffer) op.extraArgsBuff()));
+ }
+ else {
+ loop.execSummaryStats(op.opNum(),
+ op.x().data().asNioDouble(),
+ op.x().shapeInfo(), (DoubleBuffer) op.extraArgsBuff(),
+ op.z().data().asNioDouble(),
+ op.z().shapeInfo(),
+ dimensionBuffer, dimension.length);
+ }
+
}
else if(op.y() != null) {
- loop.execReduce3(op.opNum(),
- op.x().data().asNioDouble(),
- op.x().shapeInfo(), (DoubleBuffer) op.extraArgsBuff(),
- op.y().data().asNioDouble(),
- op.y().shapeInfo(),
- op.z().data().asNioDouble(),
- op.z().shapeInfo(),
- dimensionBuffer,dimension.length);
+ if(ret.isScalar()) {
+ ret.putScalar(0,loop.execReduce3Scalar(op.opNum(),
+ op.x().data().asNioDouble(),
+ op.x().shapeInfo(),
+ (DoubleBuffer)op.extraArgsBuff(),
+ op.y().data().asNioDouble(),
+ op.y().shapeInfo()));
+ }
+ else {
+ loop.execReduce3(op.opNum(),
+ op.x().data().asNioDouble(),
+ op.x().shapeInfo(), (DoubleBuffer) op.extraArgsBuff(),
+ op.y().data().asNioDouble(),
+ op.y().shapeInfo(),
+ op.z().data().asNioDouble(),
+ op.z().shapeInfo(),
+ dimensionBuffer,dimension.length);
+ }
+
}
else {
- loop.execReduce(op.opNum(),
- op.x().data().asNioDouble(),
- op.x().shapeInfo(), (DoubleBuffer) op.extraArgsBuff(),
- op.z().data().asNioDouble(),
- op.z().shapeInfo(),
- dimensionBuffer, dimension.length);
+ if(ret.isScalar()) {
+ ret.putScalar(0,loop.execReduceScalar(
+ op.opNum(),
+ op.x().data().asNioDouble(),
+ op.x().shapeInfo(),
+ (DoubleBuffer) op.extraArgsBuff()));
+ }
+ else {
+ loop.execReduce(op.opNum(),
+ op.x().data().asNioDouble(),
+ op.x().shapeInfo(), (DoubleBuffer) op.extraArgsBuff(),
+ op.z().data().asNioDouble(),
+ op.z().shapeInfo(),
+ dimensionBuffer, dimension.length);
+ }
+
}
}
else {
if(op instanceof Variance) {
- loop.execSummaryStats(op.opNum(),
- op.x().data().asNioFloat(),
- op.x().shapeInfo(), (FloatBuffer) op.extraArgsBuff(),
- op.z().data().asNioFloat(),
- op.z().shapeInfo(),
- dimensionBuffer, dimension.length);
+ if(ret.isScalar()) {
+ ret.putScalar(0,loop.execSummaryStatsScalar(
+ op.opNum()
+ ,op.x().data().asNioFloat(),
+ op.x().shapeInfo(),
+ (FloatBuffer) op.extraArgsBuff()));
+ }
+ else {
+ loop.execSummaryStats(op.opNum(),
+ op.x().data().asNioFloat(),
+ op.x().shapeInfo(), (FloatBuffer) op.extraArgsBuff(),
+ op.z().data().asNioFloat(),
+ op.z().shapeInfo(),
+ dimensionBuffer, dimension.length);
+ }
+
}
+
else if(op.y() != null) {
- loop.execReduce3(op.opNum(),
- op.x().data().asNioFloat(),
- op.x().shapeInfo(), (FloatBuffer) op.extraArgsBuff(),
- op.y().data().asNioFloat(),
- op.y().shapeInfo(),
- op.z().data().asNioFloat(),
- op.z().shapeInfo(),
- dimensionBuffer,dimension.length);
+ if(ret.isScalar()) {
+ ret.putScalar(0,loop.execReduce3Scalar(op.opNum(),
+ op.x().data().asNioFloat(),
+ op.x().shapeInfo(),
+ (FloatBuffer)op.extraArgsBuff(),
+ op.y().data().asNioFloat(),
+ op.y().shapeInfo()));
+ }
+ else {
+ loop.execReduce3(op.opNum(),
+ op.x().data().asNioFloat(),
+ op.x().shapeInfo(), (FloatBuffer) op.extraArgsBuff(),
+ op.y().data().asNioFloat(),
+ op.y().shapeInfo(),
+ op.z().data().asNioFloat(),
+ op.z().shapeInfo(),
+ dimensionBuffer,dimension.length);
+ }
+
}
else {
- loop.execReduce(op.opNum(),
- op.x().data().asNioFloat(),
- op.x().shapeInfo(), (FloatBuffer) op.extraArgsBuff(),
- op.z().data().asNioFloat(),
- op.z().shapeInfo(),
- dimensionBuffer, dimension.length);
+ if(ret.isScalar()) {
+ ret.putScalar(0,loop.execReduceScalar(
+ op.opNum(),
+ op.x().data().asNioFloat(),
+ op.x().shapeInfo(),
+ (FloatBuffer) op.extraArgsBuff()));
+ }
+ else {
+ loop.execReduce(op.opNum(),
+ op.x().data().asNioFloat(),
+ op.x().shapeInfo(), (FloatBuffer) op.extraArgsBuff(),
+ op.z().data().asNioFloat(),
+ op.z().shapeInfo(),
+ dimensionBuffer, dimension.length);
+ }
+
}
}
- return op.z();
+
+ return ret;
}
private void exec(ScalarOp op) {
@@ -220,50 +285,107 @@ private void exec(TransformOp op) {
if(op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
if(op.y() != null) {
- loop.execPairwiseTransform(
- op.opNum(),
- op.x().data().asNioDouble(),
- op.x().shapeInfo(),op.y().data().asNioDouble(),
- op.y().shapeInfo(),op.z().data().asNioDouble(),
- op.z().shapeInfo(),
- (DoubleBuffer) op.extraArgsBuff(),
- op.n());
+ if(op.x().elementWiseStride() >=1 && op.y().elementWiseStride() >= 1) {
+ loop.execPairwiseTransform
+ (op.opNum(),
+ op.x().data().asNioDouble(),
+ op.x().elementWiseStride(),
+ op.y().data().asNioDouble(),
+ op.y().elementWiseStride(),
+ op.z().data().asNioDouble(),
+ op.z().elementWiseStride(),
+ (DoubleBuffer) op.extraArgsBuff(),
+ op.n());
+
+ }
+ else {
+ loop.execPairwiseTransform
+ (op.opNum(),
+ op.x().data().asNioDouble(),
+ op.x().elementWiseStride(),
+ op.y().data().asNioDouble(),
+ op.y().elementWiseStride(),
+ op.z().data().asNioDouble(),
+ op.z().elementWiseStride(),
+ (DoubleBuffer) op.extraArgsBuff(),
+ op.n());
+ }
+
}
else {
- loop.execTransform(op.opNum(),
- op.x().data().asNioDouble(),
- op.x().shapeInfo(),
- op.z().data().asNioDouble(),
- op.z().shapeInfo(),
- (DoubleBuffer) op.extraArgsBuff(), op.n());
+ if(op.x().elementWiseStride() >= 1) {
+ loop.execTransform(op.opNum(),
+ op.x().data().asNioDouble(),
+ op.x().elementWiseStride(),
+ op.z().data().asNioDouble(),
+ op.z().elementWiseStride(),
+ (DoubleBuffer) op.extraArgsBuff(), op.n());
+ }
+ else {
+ loop.execTransform(op.opNum(),
+ op.x().data().asNioDouble(),
+ op.x().shapeInfo(),
+ op.z().data().asNioDouble(),
+ op.z().shapeInfo(),
+ (DoubleBuffer) op.extraArgsBuff(), op.n());
+ }
+
}
}
else {
if(op.y() != null) {
- int yEleStride = op.y().elementWiseStride();
- loop.execPairwiseTransform(op.opNum(),
- op.x().data().asNioFloat(),
- op.x().shapeInfo(),
- op.y().data().asNioFloat(),
- op.y().shapeInfo(),
- op.z().data().asNioFloat(),
- op.z().shapeInfo(),
- (FloatBuffer) op.extraArgsBuff(),op.n());
+ if(op.x().elementWiseStride() >=1 && op.y().elementWiseStride() >= 1) {
+ loop.execPairwiseTransform
+ (op.opNum(),
+ op.x().data().asNioFloat(),
+ op.x().elementWiseStride(),
+ op.y().data().asNioFloat(),
+ op.y().elementWiseStride(),
+ op.z().data().asNioFloat(),
+ op.z().elementWiseStride(),
+ (FloatBuffer) op.extraArgsBuff(),
+ op.n());
+
+ }
+ else {
+ loop.execPairwiseTransform
+ (op.opNum(),
+ op.x().data().asNioFloat(),
+ op.x().elementWiseStride(),
+ op.y().data().asNioFloat(),
+ op.y().elementWiseStride(),
+ op.z().data().asNioFloat(),
+ op.z().elementWiseStride(),
+ (FloatBuffer) op.extraArgsBuff(),
+ op.n());
+ }
}
else {
- loop.execTransform(op.opNum(),
- op.x().data().asNioFloat(),
- op.x().shapeInfo(),
- op.z().data().asNioFloat(),
- op.z().shapeInfo(),
- (FloatBuffer) op.extraArgsBuff(),op.n());
+ if(op.x().elementWiseStride() >= 1) {
+ loop.execTransform(op.opNum(),
+ op.x().data().asNioFloat(),
+ op.x().elementWiseStride(),
+ op.z().data().asNioFloat(),
+ op.z().elementWiseStride(),
+ (FloatBuffer) op.extraArgsBuff(), op.n());
+ }
+ else {
+ loop.execTransform(op.opNum(),
+ op.x().data().asNioFloat(),
+ op.x().shapeInfo(),
+ op.z().data().asNioFloat(),
+ op.z().shapeInfo(),
+ (FloatBuffer) op.extraArgsBuff(), op.n());
+ }
+
}
}
}
}
- private void exec(BroadcastOp op,int...dimension) {
+ @Override
+ public INDArray exec(BroadcastOp op,int...dimension) {
java.nio.IntBuffer dimensionBuffer = Shape.toBuffer(dimension);
if(op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
loop.execBroadcast(op.opNum(),
@@ -281,6 +403,8 @@ private void exec(BroadcastOp op,int...dimension) {
,op.z().data().asNioFloat(),op.z().shapeInfo(),
dimensionBuffer,dimension.length);
}
+
+ return op.z();
}
private void exec(IndexAccumulation op) {
@@ -326,7 +326,8 @@ public native double execSummaryStatsScalar(int opNum,DoubleBuffer x,
* @param result
* @param resultShapeInfo
*/
- public native void execSummaryStats(int opNum,DoubleBuffer x,
+ public native void execSummaryStats(int opNum,
+ DoubleBuffer x,
IntBuffer xShapeInfo,
DoubleBuffer extraParams,
DoubleBuffer result,
Oops, something went wrong.

0 comments on commit 7197251

Please sign in to comment.