Skip to content

Commit

Permalink
Snapshot update (#8194)
Browse files Browse the repository at this point in the history
* fix double consumption of rng on cpu

Signed-off-by: raver119 <raver119@gmail.com>

* Shyrma docs (#222)

* - documenting and profiling matrix_set_diag cuda kernel

Signed-off-by: Yurii <yurii@skymind.io>

* - correct formula of pnorm pooling in cuda 2d/3d kernels
- remove helper matrix_diag which duplicates work of helper matrix_set_diag

Signed-off-by: Yurii <yurii@skymind.io>

* cublasHandle sharing + lock

Signed-off-by: raver119 <raver119@gmail.com>

* cublasHandle sharing + lock

Signed-off-by: raver119 <raver119@gmail.com>

* Documentation for serialization/deserialization in NLP (#221)

* refactoring

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Javadocs

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Javadoc fixed

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Cleanup

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* dedicated lock for getCudaCublasHandle

Signed-off-by: raver119 <raver119@gmail.com>

* Small fixes (#223)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* ELU DL4J fixes (#224)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* javadoc (#225)

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* Small test compilation fix (#226)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8182 remove spark version suffix (#227)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* [WIP] Thread safety (#229)

* sync after cublas*gemm

Signed-off-by: raver119 <raver119@gmail.com>

* mutex for CublasHelper

Signed-off-by: raver119 <raver119@gmail.com>

* don't store cublasHandle in LaunchContext, it's per-device anyway

Signed-off-by: raver119 <raver119@gmail.com>

* some printout

Signed-off-by: raver119 <raver119@gmail.com>

* check for field instead

Signed-off-by: raver119 <raver119@gmail.com>

* pew-pew

Signed-off-by: raver119 <raver119@gmail.com>

* don't release ContextBuffers until device changed

Signed-off-by: raver119 <raver119@gmail.com>

* small tweak

Signed-off-by: raver119 <raver119@gmail.com>

* some logging in sgemm

Signed-off-by: raver119 <raver119@gmail.com>

* stream sync

Signed-off-by: raver119 <raver119@gmail.com>

* some more logging

Signed-off-by: raver119 <raver119@gmail.com>

* some more error checks

Signed-off-by: raver119 <raver119@gmail.com>

* one fancy test

Signed-off-by: raver119 <raver119@gmail.com>

* one fancy test

Signed-off-by: raver119 <raver119@gmail.com>

* minor AffinityManager fix

Signed-off-by: raver119 <raver119@gmail.com>

* cudaEvent error logging improvement

Signed-off-by: raver119 <raver119@gmail.com>

* ConstantHelper thread safety

Signed-off-by: raver119 <raver119@gmail.com>

* - minor corrections in ConstantTadHelper

Signed-off-by: Yurii <yurii@skymind.io>

* ConstantShapeHelper thread safety

Signed-off-by: raver119 <raver119@gmail.com>

* ConstantTadHelper.cu updated

Signed-off-by: raver119 <raver119@gmail.com>

* logging off

Signed-off-by: raver119 <raver119@gmail.com>

* logging off

Signed-off-by: raver119 <raver119@gmail.com>
  • Loading branch information
raver119 committed Sep 3, 2019
1 parent 9d03bb9 commit 7abc574
Show file tree
Hide file tree
Showing 58 changed files with 833 additions and 837 deletions.
Expand Up @@ -38,7 +38,7 @@
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-spark-inference-server_2.11</artifactId>
<version>1.0.0_spark_2-SNAPSHOT</version>
<version>1.0.0-SNAPSHOT</version>
<scope>test</scope>
</dependency>
<dependency>
Expand Down
Expand Up @@ -25,7 +25,7 @@

<artifactId>datavec-spark-inference-server_2.11</artifactId>
<packaging>jar</packaging>
<version>1.0.0_spark_2-SNAPSHOT</version>
<version>1.0.0-SNAPSHOT</version>
<name>datavec-spark-inference-server</name>

<properties>
Expand Down
2 changes: 1 addition & 1 deletion datavec/datavec-spark/pom.xml
Expand Up @@ -24,7 +24,7 @@
</parent>

<modelVersion>4.0.0</modelVersion>
<version>1.0.0_spark_2-SNAPSHOT</version>
<version>1.0.0-SNAPSHOT</version>
<artifactId>datavec-spark_2.11</artifactId>

<properties>
Expand Down
@@ -0,0 +1,63 @@
package org.deeplearning4j;

import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.util.concurrent.CountDownLatch;

@Ignore
// Stress test: trains many identical small MNIST networks concurrently to surface
// RNG / thread-safety reproducibility problems. Ignored in normal CI runs.
@Ignore
public class RandomTests {

    /**
     * For each of 3 epochs, spawns 10 worker threads, each fitting a fresh clone of the
     * same two-layer configuration on the first 100 mini-batches of MNIST, then waits for
     * all workers via a latch. A failure inside one worker is printed but does not stop
     * the remaining threads.
     */
    @Test
    public void testReproduce() throws Exception {

        final MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp())
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
                .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(10)
                        .activation(Activation.TANH).build())
                .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
                        LossFunctions.LossFunction.MCXENT).nIn(10).nOut(10)
                        .activation(Activation.SOFTMAX).build())
                .build();

        final int numThreads = 10;
        for (int e = 0; e < 3; e++) {

            final CountDownLatch latch = new CountDownLatch(numThreads);
            for (int i = 0; i < numThreads; i++) {
                final int j = i; // captured for the failure message below
                final Thread worker = new Thread(() -> {
                    try {
                        final MultiLayerNetwork net = new MultiLayerNetwork(conf.clone());
                        net.init();
                        // 100 mini-batches of 10 examples each, fixed MNIST seed 12345
                        final DataSetIterator iter = new EarlyTerminationDataSetIterator(
                                new MnistDataSetIterator(10, false, 12345), 100);
                        net.fit(iter);
                    } catch (Throwable err) {
                        // Report which worker died, but let the other threads finish.
                        System.out.println("Thread failed: " + j);
                        err.printStackTrace();
                    } finally {
                        latch.countDown();
                    }
                });
                worker.start();
            }

            latch.await();
            System.out.println("DONE " + e + "\n");
        }
    }
}
Expand Up @@ -833,14 +833,14 @@ public void testMalformedLabels1() throws Exception {
public void testB64_1() throws Exception {
String wordA = "night";
String wordB = "night day";
String encA = WordVectorSerializer.encodeB64(wordA);
String encB = WordVectorSerializer.encodeB64(wordB);
String encA = WordVectorSerializer.ReadHelper.encodeB64(wordA);
String encB = WordVectorSerializer.ReadHelper.encodeB64(wordB);

assertEquals(wordA, WordVectorSerializer.decodeB64(encA));
assertEquals(wordB, WordVectorSerializer.decodeB64(encB));
assertEquals(wordA, WordVectorSerializer.ReadHelper.decodeB64(encA));
assertEquals(wordB, WordVectorSerializer.ReadHelper.decodeB64(encB));

assertEquals(wordA, WordVectorSerializer.decodeB64(wordA));
assertEquals(wordB, WordVectorSerializer.decodeB64(wordB));
assertEquals(wordA, WordVectorSerializer.ReadHelper.decodeB64(wordA));
assertEquals(wordB, WordVectorSerializer.ReadHelper.decodeB64(wordB));

}

Expand Down

0 comments on commit 7abc574

Please sign in to comment.