Skip to content

Commit

Permalink
lstmBlockCell forward pass - first pass complete
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed Feb 20, 2019
1 parent 825e800 commit dc9d906
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ CUSTOM_OP_IMPL(lstmBlockCell, 8, 7, false, 2, 1) {
return Status::OK();
}

DECLARE_TYPES(lstmCell) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
DECLARE_TYPES(lstmBlockCell) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}


DECLARE_SHAPE_FN(lstmBlockCell) {
Expand All @@ -104,10 +104,12 @@ DECLARE_SHAPE_FN(lstmBlockCell) {
//TODO: shape validation

// evaluate output shapeInfos
const int bS = xt[1];
const int numUnits = cLast[2];
Nd4jLong *s(nullptr);
ALLOCATE(s, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS, numUnits]
ALLOCATE(s, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); // [bS, numUnits]

s[0] = rank;
s[0] = 2;
s[1] = bS;
s[2] = numUnits;

Expand Down
46 changes: 27 additions & 19 deletions libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,43 +161,51 @@ void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast
* 6: Current cell output [bS, numProj], time t
*/
const bool peephole = (bool)params[0]; // if true, provide peephole connections
/*
const double forgetBias = params[1];
double clippingCellValue = params[2]; // clipping value for ct, if it is not equal to zero, then cell state is clipped
double clippingProjValue = params[3]; // clipping value for projected ht, if it is not equal to zero, then projected cell output is clipped
const double forgetBias = params[4];


const int bS = xt->sizeAt(0);
const int inSize = xt->sizeAt(1);
const int numProj = ht_1->sizeAt(1);
const int numUnits = ct_1->sizeAt(1);

auto z = mmul(*xt, *Wx) + mmul(*ht_1, *Wh) + *b; // [bS x 4*numUnits] + [bS x 4*numUnits] + [1 x 4*numUnits] = [bS x 4*numUnits]
//Concat inputs: [xt, yt-1]: concat([bs,nIn],[bs,nOut]) -> [bs, (nIn+nOut)]
auto concat = new nd4j::ops::concat();
auto variableSpace = new VariableSpace();
variableSpace->putVariable(-1, xt);
variableSpace->putVariable(-2, yLast);
Context block(1, variableSpace);
block.getIArguments()->push_back(1); //Dim 1
auto concatInputs = concat.execute(block);

auto zit = z({0,0, 0, numUnits}); // z for input gate, = mmul(Wxi,xt) + mmul(Whi,ht_1) + bi = [bS x numUnits]
auto zft = z({0,0, numUnits, 2*numUnits}); // z for forget gate, = mmul(Wxf,xt) + mmul(Whf,ht_1) + bf = [bS x numUnits]
auto zct = z({0,0, 2*numUnits, 3*numUnits}); // z for cell state, = mmul(Wxc,xt) + mmul(Whc,ht_1) + bc = [bS x numUnits]
auto zot = z({0,0, 3*numUnits, 4*numUnits}); // z for output gate, = mmul(Wxo,xt) + mmul(Who,ht_1) + bo = [bS x numUnits]

auto mmul = mmul(*concatInputs, *W); //mmul: [bs, (nIn+numUnits)]* [(inSize+numUnits), 4*numUnits] = [bs, 4*numUnits]

auto zz = z({0,0, 0, numUnits}); // z for input gate, [bS x numUnits]
auto zf = z({0,0, numUnits, 2*numUnits}); // z for forget gate, [bS x numUnits]
auto zi = z({0,0, 2*numUnits, 3*numUnits}); // z for input modulation gate, [bS x numUnits]
auto zo = z({0,0, 3*numUnits, 4*numUnits}); // z for output gate, [bS x numUnits]

if(peephole) { // add peephole connections: z + ct_1*Wc
zit += (*ct_1) * (*Wc)({0, numUnits}); // add peephole connections to input gate
zft += (*ct_1) * (*Wc)({numUnits, 2*numUnits}); // add peephole connections to forget gate
zi += (*ct_1) * (*Wci); // add peephole connections to input gate
zf += (*ct_1) * (*Wcf); // add peephole connections to forget gate
}

// current sell state = ft*ct_1 + it*activation(mmul(Wxc,xt) + mmul(Whc,ht_1) + bc
ct->assign( sigmoid(zft + forgetBias) * (*ct_1) + sigmoid(zit) * activation(zct) );
if(forgetBias > 0.0){
zft += forgetBias;
}
c->assign( sigmoid(zft) * (*cLast) + sigmoid(zit) * activation(zct) );

// if clipping value is provided then cell state is clipped by this value prior to the cell output activation
if(clippingCellValue != 0.)
clipping(ct, clippingCellValue);
if(clippingCellValue != 0.0)
clipping(c, clippingCellValue);

if(peephole)
zot += (*ct) * (*Wc)({{2*numUnits, 3*numUnits}}); // add peephole connections to output gate zot + ct*Wc
zot += (*c) * (*Wcf); // add peephole connections to output gate zot + ct*Wc

// current cell output = ot*activation(ct)
auto htNoPeepHole = sigmoid(zot) * activation(*ct); // = [bS x numUnits]
ht->assign(&htNoPeepHole);
*/
ht->assign(sigmoid(zo) * activation(*ct));
}


Expand Down

0 comments on commit dc9d906

Please sign in to comment.