Skip to content

Commit

Permalink
Small fix
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed Feb 9, 2019
1 parent 7a05134 commit 0f7e7d0
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 14 deletions.
Expand Up @@ -16,6 +16,7 @@

package org.deeplearning4j.spark.models.sequencevectors.learning.elements;

import org.deeplearning4j.models.embeddings.learning.impl.elements.BatchSequences;
import org.deeplearning4j.models.embeddings.learning.impl.elements.RandomUtils;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement;
Expand Down Expand Up @@ -45,6 +46,11 @@ public String getCodeName() {
return "Spark-CBOW";
}

/**
 * Single-sequence learning is not supported on the Spark side of this algorithm;
 * training is driven through {@code frameSequence(...)} instead (see below).
 *
 * @param sequence       sequence of shallow elements (unused)
 * @param nextRandom     RNG state (unused)
 * @param learningRate   learning rate (unused)
 * @param batchSequences batch container (unused)
 * @return never returns
 * @throws UnsupportedOperationException always
 */
@Override
public double learnSequence(Sequence<ShallowSequenceElement> sequence, AtomicLong nextRandom, double learningRate, BatchSequences<ShallowSequenceElement> batchSequences) {
    // Fail fast with context instead of a bare, message-less exception.
    throw new UnsupportedOperationException("learnSequence is not supported for " + getCodeName());
}

@Override
public Frame<? extends TrainingMessage> frameSequence(Sequence<ShallowSequenceElement> sequence,
AtomicLong nextRandom, double learningRate) {
Expand Down
Expand Up @@ -17,6 +17,7 @@
package org.deeplearning4j.spark.models.sequencevectors.learning.elements;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.models.embeddings.learning.impl.elements.BatchSequences;
import org.deeplearning4j.models.embeddings.learning.impl.elements.RandomUtils;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement;
Expand All @@ -39,6 +40,11 @@ public String getCodeName() {
return "Spark-SkipGram";
}

/**
 * Single-sequence learning is not supported on the Spark side of this algorithm;
 * work is dispatched via frames instead (see the {@code frame} field below).
 *
 * @param sequence       sequence of shallow elements (unused)
 * @param nextRandom     RNG state (unused)
 * @param learningRate   learning rate (unused)
 * @param batchSequences batch container (unused)
 * @return never returns
 * @throws UnsupportedOperationException always
 */
@Override
public double learnSequence(Sequence<ShallowSequenceElement> sequence, AtomicLong nextRandom, double learningRate, BatchSequences<ShallowSequenceElement> batchSequences) {
    // Fail fast with context instead of a bare, message-less exception.
    throw new UnsupportedOperationException("learnSequence is not supported for " + getCodeName());
}

protected transient AtomicLong counter;
protected transient ThreadLocal<Frame<SkipGramRequestMessage>> frame;

Expand Down
Expand Up @@ -60,16 +60,8 @@ public Axpy() {

}

/**
 * Creates an axpy-style op over the full length of {@code x}, delegating to the
 * length-specifying constructor with {@code n = x.length()}.
 * <p>
 * NOTE(review): the pasted diff interleaved the removed two-array constructors
 * ({@code (x, z, p)} and {@code (x, z, p, n)}) with this replacement, which is
 * not valid Java; only the new {@code (x, y, z, p)} form — the commit's final
 * state — is kept here.
 * <p>
 * NOTE(review): per the callers' comments ("Gradient += l1 * sign(param)" with
 * {@code Axpy(gradView, sign, gradView, coeff)}), this appears to compute
 * {@code z = x + p * y} — confirm against the native op implementation.
 *
 * @param x first input array
 * @param y second input array
 * @param z result array (may alias an input, as the regularization callers do)
 * @param p scalar multiplier
 */
public Axpy(INDArray x, INDArray y, INDArray z, double p) {
    this(x, y, z, p, x.length());
}

public Axpy(INDArray x, INDArray y, INDArray z, double p, long n) {
Expand Down Expand Up @@ -101,6 +93,6 @@ public String tensorflowName() {

/**
 * Backpropagation for this op is not implemented; failing fast is preferable to
 * the previous behaviour of silently returning {@code null}.
 * <p>
 * NOTE(review): the pasted diff left the removed {@code return null;} line in
 * front of the replacement {@code throw}, which makes the throw unreachable and
 * would not compile; only the commit's final state is kept here.
 *
 * @param f1 gradients with respect to the op outputs
 * @return never returns
 * @throws UnsupportedOperationException always
 */
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
    throw new UnsupportedOperationException("Backprop: not yet implemented");
}
}
Expand Up @@ -49,7 +49,7 @@ public void apply(INDArray param, INDArray gradView, double lr, int iteration, i
//where sign(x[i]) is -1 or 1
double coeff = l1.valueAt(iteration, epoch);
INDArray sign = Transforms.sign(param, true);
Nd4j.exec(new Axpy(sign, gradView, coeff)); //Gradient += l1 * sign(param)
Nd4j.exec(new Axpy(gradView, sign, gradView, coeff)); //Gradient += l1 * sign(param)
}

@Override
Expand Down
Expand Up @@ -54,7 +54,7 @@ public void apply(INDArray param, INDArray gradView, double lr, int iteration, i
//L = loss + l2 * 0.5 * sum_i x[i]^2
//dL/dx[i] = dloss/dx[i] + l2 * x[i]
double coeff = l2.valueAt(iteration, epoch);
Nd4j.exec(new Axpy(param, gradView, coeff)); //Gradient += scale * param
Nd4j.exec(new Axpy(gradView, param, gradView, coeff)); //Gradient += scale * param
}

@Override
Expand Down
Expand Up @@ -75,7 +75,7 @@ public void apply(INDArray param, INDArray gradView, double lr, int iteration, i
if(applyLR){
scale *= lr;
}
Nd4j.exec(new Axpy(param, gradView, scale)); //update += scale * param
Nd4j.exec(new Axpy(gradView, param, gradView, scale)); //update += scale * param
}

@Override
Expand Down

0 comments on commit 0f7e7d0

Please sign in to comment.