Skip to content

Commit

Permalink
[SYSTEMML-540] [SYSTEMML-515] Allow an expression for sparsity
Browse files Browse the repository at this point in the history
- This PR also improves the performance of dropout.

Closes apache#351.
  • Loading branch information
Niketan Pansare committed Jan 19, 2017
1 parent 59aa9f1 commit 1b8d44d
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 20 deletions.
10 changes: 8 additions & 2 deletions scripts/staging/SystemML-NN/nn/layers/dropout.dml
Expand Up @@ -42,10 +42,16 @@ forward = function(matrix[double] X, double p, int seed)
* - out: Outputs, of same shape as X.
* - mask: Dropout mask used to compute the output.
*/
# Normally, we might use something like
# `mask = rand(rows=nrow(X), cols=ncol(X), min=0, max=1, seed=seed) <= p`
# to create a dropout mask. Fortunately, SystemML has a `sparsity` parameter on
# the `rand` function that allows us to create a mask directly.
if (seed == -1) {
seed = as.integer(floor(as.scalar(rand(rows=1, cols=1, min=1, max=100000))))
mask = rand(rows=nrow(X), cols=ncol(X), min=1, max=1, sparsity=p)
}
else {
mask = rand(rows=nrow(X), cols=ncol(X), min=1, max=1, sparsity=p, seed=seed)
}
mask = rand(rows=nrow(X), cols=ncol(X), min=0, max=1, seed=seed) <= p
out = X * mask / p
}

Expand Down
12 changes: 8 additions & 4 deletions src/main/java/org/apache/sysml/hops/DataGenOp.java
Expand Up @@ -100,8 +100,10 @@ public DataGenOp(DataGenMethod mthd, DataIdentifier id, HashMap<String, Hop> inp
_paramIndexMap.put(s, index);
index++;
}
if ( mthd == DataGenMethod.RAND )
_sparsity = Double.valueOf(((LiteralOp)inputParameters.get(DataExpression.RAND_SPARSITY)).getName());

Hop sparsityOp = inputParameters.get(DataExpression.RAND_SPARSITY);
if ( mthd == DataGenMethod.RAND && sparsityOp instanceof LiteralOp)
_sparsity = Double.valueOf(((LiteralOp)sparsityOp).getName());

//generate base dir
String scratch = ConfigurationManager.getScratchSpace();
Expand Down Expand Up @@ -199,7 +201,7 @@ protected double computeOutputMemEstimate( long dim1, long dim2, long nnz )
{
double ret = 0;

if ( _op == DataGenMethod.RAND ) {
if ( _op == DataGenMethod.RAND && _sparsity != -1 ) {
if( hasConstantValue(0.0) ) { //if empty block
ret = OptimizerUtils.estimateSizeEmptyBlock(dim1, dim2);
}
Expand Down Expand Up @@ -237,7 +239,7 @@ protected long[] inferOutputCharacteristics( MemoTable memo )
{
long dim1 = computeDimParameterInformation(getInput().get(_paramIndexMap.get(DataExpression.RAND_ROWS)), memo);
long dim2 = computeDimParameterInformation(getInput().get(_paramIndexMap.get(DataExpression.RAND_COLS)), memo);
long nnz = (long)(_sparsity * dim1 * dim2);
long nnz = _sparsity >= 0 ? (long)(_sparsity * dim1 * dim2) : -1;
if( dim1>0 && dim2>0 )
return new long[]{ dim1, dim2, nnz };
}
Expand Down Expand Up @@ -355,6 +357,8 @@ else if (_op == DataGenMethod.SEQ )
_nnz = 0;
else if ( dimsKnown() && _sparsity>=0 ) //general case
_nnz = (long) (_sparsity * _dim1 * _dim2);
else
_nnz = -1;
}


Expand Down
15 changes: 7 additions & 8 deletions src/main/java/org/apache/sysml/lops/DataGen.java
Expand Up @@ -199,12 +199,11 @@ private String getCPInstruction_Rand(String output)
sb.append(iLop.prepScalarLabel());
sb.append(OPERAND_DELIMITOR);

iLop = _inputParams.get(DataExpression.RAND_SPARSITY.toString()); //no variable support
iLop = _inputParams.get(DataExpression.RAND_SPARSITY.toString());
if (iLop.isVariable())
throw new LopsException(printErrorLocation()
+ "Parameter " + DataExpression.RAND_SPARSITY
+ " must be a literal for a Rand operation.");
sb.append(iLop.getOutputParameters().getLabel());
sb.append(iLop.prepScalarLabel());
else
sb.append(iLop.getOutputParameters().getLabel());
sb.append(OPERAND_DELIMITOR);

iLop = _inputParams.get(DataExpression.RAND_SEED.toString());
Expand Down Expand Up @@ -442,9 +441,9 @@ private String getMRInstruction_Rand(int inputIndex, int outputIndex)

iLop = _inputParams.get(DataExpression.RAND_SPARSITY.toString()); //no variable support
if (iLop.isVariable())
throw new LopsException(this.printErrorLocation() + "Parameter "
+ DataExpression.RAND_SPARSITY + " must be a literal for a Rand operation.");
sb.append( iLop.getOutputParameters().getLabel() );
sb.append(iLop.prepScalarLabel());
else
sb.append( iLop.getOutputParameters().getLabel() );
sb.append( OPERAND_DELIMITOR );

iLop = _inputParams.get(DataExpression.RAND_SEED.toString());
Expand Down
5 changes: 1 addition & 4 deletions src/main/java/org/apache/sysml/parser/DataExpression.java
Expand Up @@ -1162,10 +1162,7 @@ else if (dataParam instanceof StringIdentifier) {
raiseValidateError("for Rand statement " + RAND_MIN + " has incorrect value type", conditional);
}

//parameters w/o support for variable inputs (requires double/int or string constants)
if (!(getVarParam(RAND_SPARSITY) instanceof DoubleIdentifier || getVarParam(RAND_SPARSITY) instanceof IntIdentifier)) {
raiseValidateError("for Rand statement " + RAND_SPARSITY + " has incorrect value type", conditional);
}
// Since sparsity can be an arbitrary expression (SYSTEMML-515), no validation check for DoubleIdentifier/IntIdentifier is required.

if (!(getVarParam(RAND_PDF) instanceof StringIdentifier)) {
raiseValidateError("for Rand statement " + RAND_PDF + " has incorrect value type", conditional);
Expand Down
Expand Up @@ -212,7 +212,10 @@ else if ( opcode.equalsIgnoreCase(DataGen.SAMPLE_OPCODE) ) {
maxValue = Double.valueOf(s[6]).doubleValue();
}

double sparsity = Double.parseDouble(s[7]);
double sparsity = -1;
if (!s[7].contains( Lop.VARIABLE_NAME_PLACEHOLDER)) {
sparsity = Double.valueOf(s[7]);
}

long seed = DataGenOp.UNSPECIFIED_SEED;
if( !s[8].contains( Lop.VARIABLE_NAME_PLACEHOLDER)){
Expand Down
Expand Up @@ -245,7 +245,10 @@ else if ( opcode.equalsIgnoreCase(DataGen.SAMPLE_OPCODE) ) {
maxValue = Double.valueOf(s[6]).doubleValue();
}

double sparsity = Double.parseDouble(s[7]);
double sparsity = -1;
if (!s[7].contains( Lop.VARIABLE_NAME_PLACEHOLDER)) {
sparsity = Double.valueOf(s[7]);
}

long seed = DataGenOp.UNSPECIFIED_SEED;
if (!s[8].contains( Lop.VARIABLE_NAME_PLACEHOLDER)) {
Expand Down

0 comments on commit 1b8d44d

Please sign in to comment.