[WIP] No more axis (#6902)
Axis overhaul
raver119 authored and sshepel committed Jan 8, 2019
1 parent fa169cc commit dd52fea
Showing 254 changed files with 1,613 additions and 23,627 deletions.
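Two patterns recur throughout the Java-side hunks below: `execAndReturn(op)` collapses into `exec(op)`, which now returns the result array itself, and reduction axes move out of the `exec(...)` call and into the op's constructor. A minimal before/after sketch adapted from these hunks; the two halves target the API before and after this change, so they will not compile against a single version, and the op-class imports (OldMulOp, EuclideanDistance) are omitted because their package locations have moved between releases:

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

INDArray x = Nd4j.rand(2, 3);
INDArray y = Nd4j.rand(2, 3);
INDArray z = Nd4j.createUninitialized(new int[]{2, 3});
int[] dims = {1};

// Before: execAndReturn hands back the result array, and the reduction
// axes are passed as a separate argument to exec(...).
INDArray prodOld = Nd4j.getExecutioner().execAndReturn(new OldMulOp(x, y, z));
INDArray distOld = Nd4j.getExecutioner().exec(new EuclideanDistance(x, y), dims);

// After: exec(...) itself returns the result, and the axes are carried
// by the op's constructor.
INDArray prodNew = Nd4j.getExecutioner().exec(new OldMulOp(x, y, z));
INDArray distNew = Nd4j.getExecutioner().exec(new EuclideanDistance(x, y, dims));
```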
@@ -88,7 +88,7 @@ public INDArray applyDropout(INDArray inputActivations, INDArray output, int ite
noise = workspaceMgr.createUninitialized(ArrayType.INPUT, inputActivations.shape(), inputActivations.ordering());
Nd4j.getExecutioner().exec(new GaussianDistribution(noise, 1.0, stdev));

- return Nd4j.getExecutioner().execAndReturn(new OldMulOp(inputActivations, noise, output));
+ return Nd4j.getExecutioner().exec(new OldMulOp(inputActivations, noise, output));
}

@Override
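In other words, the hunk above applies multiplicative Gaussian noise centered at 1 (stdev is computed elsewhere in this class):

$$\text{output} = \text{input} \odot n, \qquad n_i \sim \mathcal{N}\!\left(1,\ \text{stdev}^2\right)$$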
@@ -96,8 +96,7 @@ public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) {
case Subtract:
if (inputs.length != 2)
throw new IllegalArgumentException("ElementWise subtraction only supports 2 inputs");
- return Nd4j.getExecutioner().execAndReturn(
- new OldSubOp(inputs[0], inputs[1], workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, inputs[0].shape())));
+ return Nd4j.getExecutioner().exec(new OldSubOp(inputs[0], inputs[1], workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, inputs[0].shape())));
case Product:
INDArray product = workspaceMgr.dup(ArrayType.ACTIVATIONS, inputs[0]);
for (int i = 1; i < inputs.length; i++) {
@@ -86,7 +86,7 @@ public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) {
return x.divColumnVector(xNorm2);
} else {
INDArray out = Nd4j.createUninitialized(x.shape(), x.ordering());
- return Nd4j.getExecutioner().execAndReturn(new BroadcastDivOp(x, xNorm2, out, 0));
+ return Nd4j.getExecutioner().exec(new BroadcastDivOp(x, xNorm2, out, 0));
}
}
}
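Both branches of this vertex compute the same normalization: each example (index i along dimension 0) is divided by its own L2 norm, with `xNorm2` holding the per-example norms computed above the visible hunk:

$$y_{i,\dots} = \frac{x_{i,\dots}}{\lVert x_{i,\dots} \rVert_2}$$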
@@ -77,7 +77,7 @@ public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) {
}

try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)) {
- return Nd4j.getExecutioner().exec(new EuclideanDistance(a, b), dimensions);
+ return Nd4j.getExecutioner().exec(new EuclideanDistance(a, b, dimensions));
}
}

@@ -112,7 +112,7 @@ public Pair<Gradient, INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr wo
}
} else {
//RNN and CNN case - Broadcast along dimension 0
- dLda = Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(diff, first, diff, 0));
+ dLda = Nd4j.getExecutioner().exec(new BroadcastMulOp(diff, first, diff, 0));
try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)) {
dLdb = dLda.neg();
}
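With the axes now carried by the op itself, the forward pass above reduces over the configured `dimensions` and yields one Euclidean distance per example (i indexes dimension 0, k runs over the reduced axes):

$$d_i = \sqrt{\sum_{k \in \text{dimensions}} \left(a_{i,k} - b_{i,k}\right)^2}$$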
@@ -225,7 +225,7 @@ public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspac
.build();
Nd4j.getExecutioner().exec(op);

- INDArray isMax = Nd4j.getExecutioner().execAndReturn(new IsMax(col2d, col2d, 1));
+ INDArray isMax = Nd4j.getExecutioner().exec(new IsMax(col2d, col2d, 1));
isMax.muliColumnVector(epsilon1d);
break;
case AVG:
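The `IsMax` plus column-vector multiply above is the usual max-pooling gradient: within each pooling window (here a row of the `col2d` im2col view), the upstream gradient is routed only to the element that produced the maximum:

$$\frac{\partial L}{\partial x_k} = \mathbb{1}\!\left[x_k = \max_j x_j\right] \cdot \frac{\partial L}{\partial y}$$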
@@ -279,7 +279,7 @@ These make zero difference for local training (other than perhaps when using FP1
dxhat = epsilon.mul(layerConf.getGamma());
} else {
//Standard case
- dxhat = Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(epsilon, gamma,
+ dxhat = Nd4j.getExecutioner().exec(new BroadcastMulOp(epsilon, gamma,
Nd4j.createUninitialized(epsilon.shape(), epsilon.ordering()), 1));
}

@@ -292,9 +292,8 @@ These make zero difference for local training (other than perhaps when using FP1
INDArray dxmu2 = xMu.sum(0, 2, 3).muli(-2.0 / effectiveBatchSize).muli(dLdVar);
INDArray dLdmu = dxmu1.addi(dxmu2);

- INDArray dLdx = Nd4j.getExecutioner().execAndReturn(new BroadcastDivOp(dxhat, std, dxhat, 1))
- .addi(Nd4j.getExecutioner().execAndReturn(
- new BroadcastMulOp(xMu, dLdVar.muli(2.0 / effectiveBatchSize), xMu, 1)));
+ INDArray dLdx = Nd4j.getExecutioner().exec(new BroadcastDivOp(dxhat, std, dxhat, 1))
+ .addi(Nd4j.getExecutioner().exec(new BroadcastMulOp(xMu, dLdVar.muli(2.0 / effectiveBatchSize), xMu, 1)));
Nd4j.getExecutioner()
.execAndReturn(new BroadcastAddOp(dLdx, dLdmu.muli(1.0 / effectiveBatchSize), dLdx, 1));

@@ -513,9 +512,9 @@ public INDArray preOutput(INDArray x, TrainingMode training, LayerWorkspaceMgr w
if (!Shape.strideDescendingCAscendingF(x))
x = x.dup(); //TODO: temp Workaround for broadcast bug. To be removed when fixed
xMu = workspaceMgr.createUninitialized(ArrayType.INPUT, x.shape(), x.ordering());
- xMu = Nd4j.getExecutioner().execAndReturn(new BroadcastSubOp(x, mean,xMu, 1));
+ xMu = Nd4j.getExecutioner().exec(new BroadcastSubOp(x, mean,xMu, 1));
xHat = workspaceMgr.createUninitialized(ArrayType.INPUT, x.shape(), x.ordering());
- xHat = Nd4j.getExecutioner().execAndReturn(new BroadcastDivOp(xMu, std,xHat, 1));
+ xHat = Nd4j.getExecutioner().exec(new BroadcastDivOp(xMu, std,xHat, 1));

if (layerConf.isLockGammaBeta()) {
//Special case: gamma/beta have fixed values for all outputs
@@ -531,10 +530,8 @@ public INDArray preOutput(INDArray x, TrainingMode training, LayerWorkspaceMgr w
} else {
//Standard case: gamma and beta are learned per parameter
activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.shape(), x.ordering());
- activations = Nd4j.getExecutioner().execAndReturn(
- new BroadcastMulOp(xHat, gamma, activations, 1));
- activations = Nd4j.getExecutioner()
- .execAndReturn(new BroadcastAddOp(activations, beta, activations, 1));
+ activations = Nd4j.getExecutioner().exec(new BroadcastMulOp(xHat, gamma, activations, 1));
+ activations = Nd4j.getExecutioner().exec(new BroadcastAddOp(activations, beta, activations, 1));
}
} else {
// TODO setup BatchNorm for RNN http://arxiv.org/pdf/1510.01378v1.pdf
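For orientation, the hunks above are the standard batch-normalization forward pass and its backward pass through the mini-batch statistics; the code's dxhat, dLdVar, dLdmu, dLdx, xMu and std map directly onto the terms below, with m = effectiveBatchSize:

$$\hat{x} = \frac{x - \mu}{\sigma}, \qquad y = \gamma \hat{x} + \beta$$

$$\frac{\partial L}{\partial \hat{x}} = \frac{\partial L}{\partial y} \odot \gamma, \qquad \frac{\partial L}{\partial x} = \frac{\partial L}{\partial \hat{x}} \cdot \frac{1}{\sigma} + \frac{\partial L}{\partial \sigma^2} \cdot \frac{2\,(x - \mu)}{m} + \frac{\partial L}{\partial \mu} \cdot \frac{1}{m}$$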
@@ -303,8 +303,8 @@ private INDArray epsilonHelperFullArray(INDArray inputArray, INDArray epsilon, i

switch (poolingType) {
case MAX:
- INDArray isMax = Nd4j.getExecutioner().execAndReturn(new IsMax(inputArray.dup(), poolDim));
- return Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(isMax, epsilon, isMax, broadcastDims));
+ INDArray isMax = Nd4j.getExecutioner().exec(new IsMax(inputArray.dup(), poolDim));
+ return Nd4j.getExecutioner().exec(new BroadcastMulOp(isMax, epsilon, isMax, broadcastDims));
case AVG:
//if out = avg(in,dims) then dL/dIn = 1/N * dL/dOut
int n = 1;
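The AVG branch's comment, written out: when the output is the mean of N pooled inputs, each input receives an equal 1/N share of the upstream gradient:

$$y = \frac{1}{N}\sum_{i=1}^{N} x_i \;\;\Rightarrow\;\; \frac{\partial L}{\partial x_i} = \frac{1}{N}\,\frac{\partial L}{\partial y}$$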
@@ -558,7 +558,7 @@ static public Pair<Gradient, INDArray> backpropGradientHelper(final NeuralNetCon
INDArray deltao = deltaoNext;
Nd4j.getExecutioner().exec(new OldMulOp(nablaOut, sigmahOfS, deltao));
if (sigmoidGates) {
- INDArray sigmaoPrimeOfZo = Nd4j.getExecutioner().execAndReturn(new TimesOneMinus(ao.dup('f'))); //Equivalent to sigmoid deriv on zo
+ INDArray sigmaoPrimeOfZo = Nd4j.getExecutioner().exec(new TimesOneMinus(ao.dup('f'))); //Equivalent to sigmoid deriv on zo
deltao.muli(sigmaoPrimeOfZo);
} else {
deltao.assign(gateActivationFn.backprop(fwdPass.oz[time], deltao).getFirst()); //Deltao needs to be modified in-place
@@ -610,8 +610,7 @@ static public Pair<Gradient, INDArray> backpropGradientHelper(final NeuralNetCon
deltag.muli(ai);
deltag.muli(nablaCellState);
} else {
- INDArray temp2 = Nd4j.getExecutioner().execAndReturn(
- new OldMulOp(ai, nablaCellState, Nd4j.createUninitialized(ai.shape(), 'f')));
+ INDArray temp2 = Nd4j.getExecutioner().exec(new OldMulOp(ai, nablaCellState, Nd4j.createUninitialized(ai.shape(), 'f')));
deltag.assign(gateActivationFn.backprop(fwdPass.gz[time], temp2).getFirst());
//TODO activation functions with params; optimize (no assign)
}
@@ -620,8 +619,7 @@ static public Pair<Gradient, INDArray> backpropGradientHelper(final NeuralNetCon
//Network input delta:
INDArray zi = fwdPass.iz[time];
INDArray deltai = deltaiNext;
- temp = Nd4j.getExecutioner().execAndReturn(
- new OldMulOp(ag, nablaCellState, Nd4j.createUninitialized(deltai.shape(), 'f')));
+ temp = Nd4j.getExecutioner().exec(new OldMulOp(ag, nablaCellState, Nd4j.createUninitialized(deltai.shape(), 'f')));
deltai.assign(afn.backprop(zi, temp).getFirst());
//TODO activation functions with params; also: optimize this (no assign)
//Shape: [m,n^L]
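`TimesOneMinus` in the first hunk is the usual shortcut for the sigmoid derivative expressed through the activation itself (here the output-gate activation a_o), which is why it can stand in for a full backprop through the gate when sigmoid gates are in use:

$$\sigma'(z_o) = \sigma(z_o)\left(1 - \sigma(z_o)\right) = a_o\,(1 - a_o), \qquad \delta_o = \nabla_{\text{out}} \odot h(s) \odot \sigma'(z_o)$$

where h is the configured cell-state output activation (sigmahOfS in the code).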
@@ -221,8 +221,7 @@ public double optimize(INDArray parameters, INDArray gradients, INDArray searchD

// check for convergence on delta x
if ((step < stepMin) || Nd4j.getExecutioner()
- .execAndReturn(new Eps(parameters, candidateParameters,
- Nd4j.createUninitialized(DataType.BOOL, candidateParameters.shape(), candidateParameters.ordering()),
+ .exec(new Eps(parameters, candidateParameters,Nd4j.createUninitialized(DataType.BOOL, candidateParameters.shape(), candidateParameters.ordering()),
candidateParameters.length())).castTo(DataType.FLOAT).sumNumber().longValue() == candidateParameters.length()) {
score = setScoreFor(parameters, workspaceMgr);
log.debug("EXITING BACKTRACK: Jump too small (stepMin = {}). Exiting and using original params. Score = {}",
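`Eps` here appears to act as an element-wise approximate-equality test (true where the two values agree to within a small tolerance); summing the boolean result and comparing it with the parameter count therefore aborts the backtracking line search only when every parameter is unchanged to within that tolerance:

$$\left|\theta_i^{\text{cand}} - \theta_i\right| < \varepsilon \;\;\forall i \;\;\Longrightarrow\;\; \text{step too small; keep the original parameters}$$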
@@ -134,9 +134,9 @@ public static INDArray maskedPoolingEpsilonTimeSeries(PoolingType poolingType, I
Nd4j.getExecutioner().exec(new BroadcastAddOp(input, negInfMask, withInf, 0, 2));
//At this point: all the masked out steps have value -inf, hence can't be the output of the MAX op

- INDArray isMax = Nd4j.getExecutioner().execAndReturn(new IsMax(withInf, 2));
+ INDArray isMax = Nd4j.getExecutioner().exec(new IsMax(withInf, 2));

- return Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(isMax, epsilon2d, isMax, 0, 1));
+ return Nd4j.getExecutioner().exec(new BroadcastMulOp(isMax, epsilon2d, isMax, 0, 1));
case AVG:
case SUM:
//if out = sum(in,dims) then dL/dIn = dL/dOut -> duplicate to each step and mask
@@ -290,9 +290,9 @@ public static INDArray maskedPoolingEpsilonCnn(PoolingType poolingType, INDArray
Nd4j.getExecutioner().exec(new BroadcastAddOp(input, negInfMask, withInf, dimensions));
//At this point: all the masked out steps have value -inf, hence can't be the output of the MAX op

- INDArray isMax = Nd4j.getExecutioner().execAndReturn(new IsMax(withInf, 2, 3));
+ INDArray isMax = Nd4j.getExecutioner().exec(new IsMax(withInf, 2, 3));

- return Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(isMax, epsilon2d, isMax, 0, 1));
+ return Nd4j.getExecutioner().exec(new BroadcastMulOp(isMax, epsilon2d, isMax, 0, 1));
case AVG:
case SUM:
//if out = sum(in,dims) then dL/dIn = dL/dOut -> duplicate to each step and mask
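Both hunks above rely on the same masking trick for max pooling over padded positions: assuming `negInfMask` holds a very large negative value at masked steps and zero elsewhere, a masked entry can never win the max, so the gradient routed back through `IsMax` is automatically zero there:

$$\tilde{x} = x + \text{negInfMask}, \qquad \frac{\partial L}{\partial x_t} = \mathbb{1}\!\left[\tilde{x}_t = \max_{t'} \tilde{x}_{t'}\right] \cdot \frac{\partial L}{\partial y}$$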
40 changes: 24 additions & 16 deletions libnd4j/blas/NativeOps.h
@@ -119,7 +119,8 @@ class ND4J_EXPORT NativeOps {
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
- int *dimension, int dimensionLength);
+ void *hDimension, Nd4jLong *hDimensionShape,
+ void *dDimension, Nd4jLong *dDimensionShape);

/**
*
@@ -142,7 +143,8 @@ class ND4J_EXPORT NativeOps {
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
- int *dimension,int dimensionLength);
+ void *hDimension, Nd4jLong *hDimensionShape,
+ void *dDimension, Nd4jLong *dDimensionShape);


void execBroadcastBool(
@@ -154,7 +156,8 @@ class ND4J_EXPORT NativeOps {
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
- int *dimension,int dimensionLength);
+ void *hDimension, Nd4jLong *hDimensionShape,
+ void *dDimension, Nd4jLong *dDimensionShape);

/**
*
@@ -248,7 +251,8 @@ class ND4J_EXPORT NativeOps {
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
- int *dimension,int dimensionLength);
+ void *hDimension, Nd4jLong *hDimensionShape,
+ void *dDimension, Nd4jLong *dDimensionShape);


void execReduceSame(Nd4jPointer *extraPointers,
@@ -258,7 +262,8 @@ class ND4J_EXPORT NativeOps {
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
- int *dimension, int dimensionLength);
+ void *hDimension, Nd4jLong *hDimensionShape,
+ void *dDimension, Nd4jLong *dDimensionShape);


void execReduceBool(Nd4jPointer *extraPointers,
@@ -268,7 +273,8 @@ class ND4J_EXPORT NativeOps {
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
- int *dimension, int dimensionLength);
+ void *hDimension, Nd4jLong *hDimensionShape,
+ void *dDimension, Nd4jLong *dDimensionShape);


void execReduceLong(Nd4jPointer *extraPointers,
@@ -278,7 +284,8 @@ class ND4J_EXPORT NativeOps {
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
- int *dimension, int dimensionLength);
+ void *hDimension, Nd4jLong *hDimensionShape,
+ void *dDimension, Nd4jLong *dDimensionShape);

/**
*
@@ -343,8 +350,8 @@ class ND4J_EXPORT NativeOps {
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
- int *dimension,
- int dimensionLength,
+ void *hDimension, Nd4jLong *hDimensionShape,
+ void *dDimension, Nd4jLong *dDimensionShape,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets);

@@ -358,7 +365,8 @@ class ND4J_EXPORT NativeOps {
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
- int *dimension, int dimensionLength,
+ void *hDimension, Nd4jLong *hDimensionShape,
+ void *dDimension, Nd4jLong *dDimensionShape,
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets);

@@ -443,8 +451,8 @@ class ND4J_EXPORT NativeOps {
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
- int *dimension,
- int dimensionLength,
+ void *hDimension, Nd4jLong *hDimensionShape,
+ void *dDimension, Nd4jLong *dDimensionShape,
bool biasCorrected,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);

@@ -520,8 +528,8 @@ class ND4J_EXPORT NativeOps {
void *hScalars, Nd4jLong *hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo,
void *extraParams,
- int *dimension,
- int dimensionLength,
+ void *hDimension, Nd4jLong *hDimensionShape,
+ void *dDimension, Nd4jLong *dDimensionShape,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);

@@ -534,8 +542,8 @@ class ND4J_EXPORT NativeOps {
void *hScalars, Nd4jLong *hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo,
void *extraParams,
- int *dimension,
- int dimensionLength,
+ void *hDimension, Nd4jLong *hDimensionShape,
+ void *dDimension, Nd4jLong *dDimensionShape,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);

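Throughout NativeOps.h, the raw `int *dimension, int dimensionLength` pair is replaced by the same host/device buffer-plus-shape-info arguments used for every other tensor, meaning the reduction axes now cross the native boundary as an ordinary array rather than a bare C pointer and length. A hedged Java-side sketch of the op-level call this corresponds to (the buffer plumbing inside the executioner is not shown in this diff, and the op-class import is omitted because its package has moved between releases):

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

INDArray a = Nd4j.rand(new int[]{2, 3, 4});
INDArray b = Nd4j.rand(new int[]{2, 3, 4});
int[] dims = {1, 2};   // reduce over the last two axes

// The axes are handed to the op itself; with the new native signatures the
// executioner can pass them down as a regular array (buffer plus shape info)
// instead of an int pointer and an explicit length.
INDArray dist = Nd4j.getExecutioner().exec(new EuclideanDistance(a, b, dims));
// one Euclidean distance per slice along dimension 0
```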