#225: [WIP] have two options for fitting, one that is async and allows the end ... #226

Closed
wants to merge 1 commit
@@ -18,36 +18,47 @@

import static org.nd4j.linalg.indexing.NDArrayIndex.interval;

import akka.actor.ActorSystem;
import akka.dispatch.Futures;
import akka.dispatch.OnComplete;
import com.google.common.util.concurrent.AtomicDouble;
import org.deeplearning4j.berkeley.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.Callable;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;

import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.featuredetectors.autoencoder.recursive.Tree;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.parallel.Parallelization;
import org.deeplearning4j.util.MultiDimensionalMap;
import org.deeplearning4j.util.MultiDimensionalSet;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.AdaGrad;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import scala.concurrent.Future;

/**
* Recursive Neural Tensor Network by Socher et al.
*
@@ -339,21 +350,50 @@ INDArray randomClassificationMatrix() {
        return Nd4j.getBlasWrapper().scal((float) scalingForInit, ret);

    }

    /**
     * Trains the network on this mini batch and waits for training on the whole batch to complete
     * @param trainingBatch the trees to iterate on
     */
    public void fit(List<Tree> trainingBatch) {
        final CountDownLatch c = new CountDownLatch(trainingBatch.size());

        List<Future<Object>> futureBatch = fitAsync(trainingBatch);

        for(Future<Object> f : futureBatch) {
            f.onComplete(new OnComplete<Object>() {
                @Override
                public void onComplete(Throwable throwable, Object e) throws Throwable {
                    if(throwable != null)
                        log.warn("Error occurred training batch",throwable);

                    c.countDown();
                }
            },rnTnActorSystem.dispatcher());
        }

        try {
            c.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Trains the network on this mini batch and returns a list of futures, one per training job
     * @param trainingBatch the trees to iterate on
     * @return a list of futures, one for each tree's training job
     */
    public List<Future<Object>> fitAsync(List<Tree> trainingBatch) {
        this.trainingTrees = trainingBatch;
        int count = 0;

        List<Future<Object>> futureBatch = new ArrayList<>();

        for(final Tree t : trainingBatch) {
            log.info("Working mini batch " + count++);
            futureBatch.add(Futures.future(new Callable<Object>() {
                @Override
                public Object call() throws Exception {
                    forwardPropagateTree(t);
@@ -369,10 +409,11 @@ public Object call() throws Exception {

                    return null;
                }
            },rnTnActorSystem.dispatcher()));


        }
        return futureBatch;
    }
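For reference, a caller-side sketch of the two fitting options (not part of this diff): the names rntn and trees, the five minute timeout, and the RNTN package in the import below are assumptions for illustration. fit(trainingBatch) blocks until every tree in the batch has been trained, while fitAsync(trainingBatch) returns the underlying Akka futures so the caller can decide how, or whether, to wait, for example with scala.concurrent.Await:

import java.util.List;
import java.util.concurrent.TimeUnit;

import org.deeplearning4j.models.featuredetectors.autoencoder.recursive.Tree;
import org.deeplearning4j.models.rntn.RNTN; // package name assumed for illustration

import scala.concurrent.Await;
import scala.concurrent.Future;
import scala.concurrent.duration.Duration;

class FitOptionsExample {

    // rntn and trees are assumed to be constructed and populated elsewhere.
    static void train(RNTN rntn, List<Tree> trees) throws Exception {
        // Option 1: blocking - returns only once the whole mini batch has been trained.
        rntn.fit(trees);

        // Option 2: async - the caller owns the futures and chooses how to wait on them.
        List<Future<Object>> jobs = rntn.fitAsync(trees);
        for(Future<Object> job : jobs)
            Await.result(job, Duration.create(5, TimeUnit.MINUTES)); // cap each job's wait
    }
}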


@@ -946,10 +987,6 @@ public void run(Tree currentItem, Object[] args) {
        },rnTnActorSystem);


        // TODO: we may find a big speedup by separating the derivatives and then summing
        final AtomicDouble error = new AtomicDouble(0);
        if(!forwardPropTrees.isEmpty())
@@ -1005,12 +1042,10 @@ public double getValue() {

    @Override
    public void fit() {
    }

    @Override
    public void update(Gradient gradient) {
    }

    @Override