Skip to content

Commit

Permalink
Fix axpy args order for regularization
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed Feb 11, 2019
1 parent b78775f commit 1d17030
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
Expand Up @@ -49,7 +49,7 @@ public void apply(INDArray param, INDArray gradView, double lr, int iteration, i
//where sign(x[i]) is -1 or 1
double coeff = l1.valueAt(iteration, epoch);
INDArray sign = Transforms.sign(param, true);
- Nd4j.exec(new Axpy(gradView, sign, gradView, coeff)); //Gradient += l1 * sign(param)
+ Nd4j.exec(new Axpy(sign, gradView, gradView, coeff)); //Gradient = l1 * sign(param) + gradient
}

@Override
Expand Down
Expand Up @@ -54,7 +54,7 @@ public void apply(INDArray param, INDArray gradView, double lr, int iteration, i
//L = loss + l2 * 0.5 * sum_i x[i]^2
//dL/dx[i] = dloss/dx[i] + l2 * x[i]
double coeff = l2.valueAt(iteration, epoch);
- Nd4j.exec(new Axpy(gradView, param, gradView, coeff)); //Gradient += scale * param
+ Nd4j.exec(new Axpy(param, gradView, gradView, coeff)); //Gradient = scale * param + gradient
}

@Override
Expand Down
Expand Up @@ -75,7 +75,7 @@ public void apply(INDArray param, INDArray gradView, double lr, int iteration, i
if(applyLR){
scale *= lr;
}
- Nd4j.exec(new Axpy(gradView, param, gradView, scale)); //update += scale * param
+ Nd4j.exec(new Axpy(param, gradView, gradView, scale)); //update = scale * param + update
}

@Override
Expand Down

0 comments on commit 1d17030

Please sign in to comment.