Skip to content

Commit

Permalink
#6362 Add system properties for DL4J and ND4J temporary file locations
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed Oct 4, 2018
1 parent e5ddb5c commit 41dfe0f
Show file tree
Hide file tree
Showing 17 changed files with 191 additions and 273 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ public class DL4JSystemProperties {

private DL4JSystemProperties(){ }

/**
* Applicability: DL4J ModelSerializer, ModelGuesser, Keras model import<br>
* Description: Specify the local directory where temporary files will be written. If not specified, the default
* Java temporary directory (java.io.tmpdir system property) will generally be used.
*/
public static final String DL4J_TEMP_DIR_PROPERTY = "org.deeplearning4j.tempdir";

/**
* Applicability: Numerous modules, including deeplearning4j-datasets and deeplearning4j-zoo<br>
* Description: Used to set the local location for downloaded remote resources such as datasets (like MNIST) and
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.deeplearning4j.util;

import org.deeplearning4j.config.DL4JSystemProperties;

import java.io.File;
import java.io.IOException;

/**
 * Utilities for working with temporary files.
 * <p>
 * All temporary files are placed in the directory named by the
 * {@link DL4JSystemProperties#DL4J_TEMP_DIR_PROPERTY} system property when it is set to a non-empty value;
 * otherwise the JVM default temporary directory is used.
 *
 * @author Alex Black
 */
public class DL4JFileUtils {

    private DL4JFileUtils(){ }

    /**
     * Create a temporary file in the location specified by {@link DL4JSystemProperties#DL4J_TEMP_DIR_PROPERTY} if set,
     * or the default temporary directory (usually specified by java.io.tmpdir system property).
     * <p>
     * If the configured directory does not exist yet, an attempt is made to create it (including any missing
     * parent directories) before the temporary file is created.
     *
     * @param prefix Prefix for generating file's name; must be at least 3 characters
     * @param suffix Suffix for generating file's name; may be null (".tmp" will be used if null)
     * @return A temporary file
     * @throws RuntimeException if the configured directory or the temporary file cannot be created; the
     *                          underlying {@link IOException} is preserved as the cause
     */
    public static File createTempFile(String prefix, String suffix) {
        String p = System.getProperty(DL4JSystemProperties.DL4J_TEMP_DIR_PROPERTY);
        // null means "use the JVM default temporary directory"
        File dir = (p == null || p.isEmpty()) ? null : new File(p);
        try {
            // The second isDirectory() check tolerates another thread/process creating the directory concurrently
            if (dir != null && !dir.isDirectory() && !dir.mkdirs() && !dir.isDirectory()) {
                throw new IOException("Could not create temporary file directory: " + dir.getAbsolutePath());
            }
            // File.createTempFile(prefix, suffix, null) is equivalent to File.createTempFile(prefix, suffix)
            return File.createTempFile(prefix, suffix, dir);
        } catch (IOException e){
            throw new RuntimeException("Error creating temporary file (prefix=" + prefix + ", suffix=" + suffix
                    + ", directory=" + (dir == null ? "<default>" : dir.getAbsolutePath()) + ")", e);
        }
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.config.DL4JSystemProperties;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
Expand Down Expand Up @@ -110,15 +111,18 @@ public static Object loadConfigGuess(String path) throws Exception {
* @throws Exception
*/
public static Object loadConfigGuess(InputStream stream) throws Exception {
File tmp = new File("model-" + UUID.randomUUID().toString());
String p = System.getProperty(DL4JSystemProperties.DL4J_TEMP_DIR_PROPERTY);
File tmp = DL4JFileUtils.createTempFile("model-" + UUID.randomUUID().toString(), "bin");
BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(new FileOutputStream(tmp));
IOUtils.copy(stream, bufferedOutputStream);
bufferedOutputStream.flush();
bufferedOutputStream.close();
tmp.deleteOnExit();
Object load = loadConfigGuess(tmp.getAbsolutePath());
tmp.delete();
return load;
try {
return loadConfigGuess(tmp.getAbsolutePath());
} finally {
tmp.delete();
}
}

/**
Expand Down Expand Up @@ -170,21 +174,9 @@ public static Model loadModelGuess(String path) throws Exception {
* @throws Exception
*/
public static Model loadModelGuess(InputStream stream) throws Exception {
return loadModelGuess(stream, null);
}

/**
* Load the model from the given input stream
* @param stream the path of the file to "guess"
* @param tempFileDirectory May be null. The directory in which to create any temporary files
*
* @return the loaded model
* @throws Exception
*/
public static Model loadModelGuess(InputStream stream, File tempFileDirectory) throws Exception {
//Currently (Nov 2017): KerasModelImport doesn't support loading from input streams
//Simplest solution here: write to a temporary file
File f = File.createTempFile("loadModelGuess",".bin",tempFileDirectory);
File f = DL4JFileUtils.createTempFile("loadModelGuess",".bin");
f.deleteOnExit();

try (OutputStream os = new BufferedOutputStream(new FileOutputStream(f))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ public void testWeightNoiseConfigJson() {
assertEquals(wn, ((BaseLayer) graph.getLayer(2).conf().getLayer()).getWeightNoise());

TestUtils.testModelSerialization(graph);

graph.fit(new DataSet(Nd4j.create(1,10), Nd4j.create(1,10)));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@

package org.deeplearning4j.nn.modelimport.keras.optimizers;

import org.deeplearning4j.config.DL4JSystemProperties;
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
import org.deeplearning4j.nn.modelimport.keras.e2e.KerasModelEndToEndTest;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
import org.deeplearning4j.util.DL4JFileUtils;
import org.junit.Test;
import org.nd4j.linalg.io.ClassPathResource;

Expand Down Expand Up @@ -70,7 +72,7 @@ private void importSequential(String modelPath) throws Exception {
ClassPathResource modelResource =
new ClassPathResource(modelPath,
KerasModelEndToEndTest.class.getClassLoader());
File modelFile = createTempFile("tempModel", ".h5");
File modelFile = DL4JFileUtils.createTempFile("tempModel", ".h5");
Files.copy(modelResource.getInputStream(), modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
KerasModelBuilder builder = new KerasModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath())
.enforceTrainingConfig(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.atilika.kuromoji.dict.DictionaryEntryBase;
import com.atilika.kuromoji.dict.GenericDictionaryEntry;
import com.atilika.kuromoji.dict.TokenInfoDictionary;
import org.deeplearning4j.util.DL4JFileUtils;

import java.io.*;
import java.util.ArrayList;
Expand Down Expand Up @@ -189,7 +190,7 @@ protected void writeWordIds(String filename) throws IOException {

@Deprecated
public WordIdMap getWordIdMap() throws IOException {
File file = File.createTempFile("kuromoji-wordid-", ".bin");
File file = DL4JFileUtils.createTempFile("kuromoji-wordid-", ".bin");
file.deleteOnExit();

OutputStream output = new BufferedOutputStream(new FileOutputStream(file));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.deeplearning4j.util.DL4JFileUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.compression.impl.NoOp;
import org.nd4j.linalg.api.ndarray.INDArray;
Expand Down Expand Up @@ -456,12 +457,12 @@ public static void writeWord2VecModel(Word2Vec vectors, OutputStream stream) thr
zipfile.putNextEntry(syn0);

// writing out syn0
File tempFileSyn0 = File.createTempFile("word2vec", "0");
File tempFileSyn1 = File.createTempFile("word2vec", "1");
File tempFileSyn1Neg = File.createTempFile("word2vec", "n");
File tempFileCodes = File.createTempFile("word2vec", "h");
File tempFileHuffman = File.createTempFile("word2vec", "h");
File tempFileFreqs = File.createTempFile("word2vec", "f");
File tempFileSyn0 = DL4JFileUtils.createTempFile("word2vec", "0");
File tempFileSyn1 = DL4JFileUtils.createTempFile("word2vec", "1");
File tempFileSyn1Neg = DL4JFileUtils.createTempFile("word2vec", "n");
File tempFileCodes = DL4JFileUtils.createTempFile("word2vec", "h");
File tempFileHuffman = DL4JFileUtils.createTempFile("word2vec", "h");
File tempFileFreqs = DL4JFileUtils.createTempFile("word2vec", "f");
tempFileSyn0.deleteOnExit();
tempFileSyn1.deleteOnExit();
tempFileSyn1Neg.deleteOnExit();
Expand Down Expand Up @@ -598,11 +599,11 @@ public static void writeParagraphVectors(ParagraphVectors vectors, OutputStream
zipfile.putNextEntry(syn0);

// writing out syn0
File tempFileSyn0 = File.createTempFile("paravec", "0");
File tempFileSyn1 = File.createTempFile("paravec", "1");
File tempFileCodes = File.createTempFile("paravec", "h");
File tempFileHuffman = File.createTempFile("paravec", "h");
File tempFileFreqs = File.createTempFile("paravec", "h");
File tempFileSyn0 = DL4JFileUtils.createTempFile("paravec", "0");
File tempFileSyn1 = DL4JFileUtils.createTempFile("paravec", "1");
File tempFileCodes = DL4JFileUtils.createTempFile("paravec", "h");
File tempFileHuffman = DL4JFileUtils.createTempFile("paravec", "h");
File tempFileFreqs = DL4JFileUtils.createTempFile("paravec", "h");
tempFileSyn0.deleteOnExit();
tempFileSyn1.deleteOnExit();
tempFileCodes.deleteOnExit();
Expand Down Expand Up @@ -773,11 +774,11 @@ public static ParagraphVectors readParagraphVectors(File file) throws IOExceptio
*/
@Deprecated
public static Word2Vec readWord2Vec(File file) throws IOException {
File tmpFileSyn0 = File.createTempFile("word2vec", "0");
File tmpFileSyn1 = File.createTempFile("word2vec", "1");
File tmpFileC = File.createTempFile("word2vec", "c");
File tmpFileH = File.createTempFile("word2vec", "h");
File tmpFileF = File.createTempFile("word2vec", "f");
File tmpFileSyn0 = DL4JFileUtils.createTempFile("word2vec", "0");
File tmpFileSyn1 = DL4JFileUtils.createTempFile("word2vec", "1");
File tmpFileC = DL4JFileUtils.createTempFile("word2vec", "c");
File tmpFileH = DL4JFileUtils.createTempFile("word2vec", "h");
File tmpFileF = DL4JFileUtils.createTempFile("word2vec", "f");

tmpFileSyn0.deleteOnExit();
tmpFileSyn1.deleteOnExit();
Expand Down Expand Up @@ -896,7 +897,7 @@ public static Word2Vec readWord2Vec(File file) throws IOException {
* @return
*/
public static ParagraphVectors readParagraphVectors(InputStream stream) throws IOException {
File tmpFile = File.createTempFile("restore", "paravec");
File tmpFile = DL4JFileUtils.createTempFile("restore", "paravec");
try {
FileUtils.copyInputStreamToFile(stream, tmpFile);
return readParagraphVectors(tmpFile);
Expand Down Expand Up @@ -2276,9 +2277,9 @@ public static Word2Vec readWord2VecModel(@NonNull File file, boolean extendedMod
} else {
log.debug("Trying simplified model restoration...");

tmpFileSyn0 = File.createTempFile("word2vec", "syn");
tmpFileSyn0 = DL4JFileUtils.createTempFile("word2vec", "syn");
tmpFileSyn0.deleteOnExit();
tmpFileConfig = File.createTempFile("word2vec", "config");
tmpFileConfig = DL4JFileUtils.createTempFile("word2vec", "config");
tmpFileConfig.deleteOnExit();
// we don't need full model, so we go directly to syn0 file

Expand Down Expand Up @@ -2504,7 +2505,7 @@ public static WordVectors loadStaticModel(File file) {
// if zip - that's dl4j format
try {
log.debug("Trying DL4j format...");
File tmpFileSyn0 = File.createTempFile("word2vec", "syn");
File tmpFileSyn0 = DL4JFileUtils.createTempFile("word2vec", "syn");
tmpFileSyn0.deleteOnExit();

ZipFile zipFile = new ZipFile(file);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SynchronizedSentenceIterator;
import org.deeplearning4j.util.DL4JFileUtils;
import org.deeplearning4j.util.ThreadUtils;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
Expand Down Expand Up @@ -289,7 +290,7 @@ public AbstractCoOccurrences<T> build() {
// use temp file, if no target file was specified
try {
if (this.target == null) {
this.target = File.createTempFile("cooccurrence", "map");
this.target = DL4JFileUtils.createTempFile("cooccurrence", "map");
}
this.target.deleteOnExit();
} catch (Exception e) {
Expand Down Expand Up @@ -417,8 +418,8 @@ public ShadowCopyThread() {
counter = new RoundCount(1);
tempFiles = new File[2];

tempFiles[0] = File.createTempFile("aco", "tmp");
tempFiles[1] = File.createTempFile("aco", "tmp");
tempFiles[0] = DL4JFileUtils.createTempFile("aco", "tmp");
tempFiles[1] = DL4JFileUtils.createTempFile("aco", "tmp");

tempFiles[0].deleteOnExit();
tempFiles[1].deleteOnExit();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.output.CloseShieldOutputStream;
import org.deeplearning4j.config.DL4JSystemProperties;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
Expand Down Expand Up @@ -721,7 +722,7 @@ public static void addNormalizerToModel(File f, Normalizer<?> normalizer) {
File tempFile = null;
try {
// copy existing model to temporary file
tempFile = File.createTempFile("tempcopy", "temp");
tempFile = DL4JFileUtils.createTempFile("dl4jModelSerializerTemp", "bin");
tempFile.deleteOnExit();
Files.copy(f, tempFile);
try (ZipFile zipFile = new ZipFile(tempFile);
Expand Down Expand Up @@ -776,8 +777,7 @@ public static void addObjectToFile(@NonNull File f, @NonNull String key, @NonNul
File tempFile = null;
try {
// copy existing model to temporary file
tempFile = File.createTempFile("tempcopy", "temp");
tempFile.deleteOnExit();
tempFile = DL4JFileUtils.createTempFile("dl4jModelSerializerTemp", "bin");
Files.copy(f, tempFile);
f.delete();
try (ZipFile zipFile = new ZipFile(tempFile);
Expand Down Expand Up @@ -996,7 +996,8 @@ private static void checkTempFileFromInputStream(File f) throws IOException {

private static File tempFileFromStream(InputStream is) throws IOException{
checkInputStream(is);
File tmpFile = File.createTempFile("dl4jModelSerializer", "bin");
String p = System.getProperty(DL4JSystemProperties.DL4J_TEMP_DIR_PROPERTY);
File tmpFile = DL4JFileUtils.createTempFile("dl4jModelSerializer", "bin");
try {
tmpFile.deleteOnExit();
BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(new FileOutputStream(tmpFile));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.deeplearning4j.ui.storage.FileStatsStorage;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.deeplearning4j.ui.storage.impl.QueueStatsStorageListener;
import org.deeplearning4j.util.DL4JFileUtils;
import org.nd4j.linalg.primitives.Pair;
import play.Mode;
import play.api.routing.Router;
Expand Down Expand Up @@ -231,7 +232,7 @@ public void runMain(String[] args) {
}
else if(enableRemote) {
try {
File tempStatsFile = File.createTempFile("dl4j", "UIstats");
File tempStatsFile = DL4JFileUtils.createTempFile("dl4j", "UIstats");
tempStatsFile.delete();
tempStatsFile.deleteOnExit();
enableRemoteListener(new FileStatsStorage(tempStatsFile), true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.util.ND4JFileUtils;

import java.io.*;
import java.util.ArrayList;
Expand Down Expand Up @@ -101,7 +102,7 @@ public void run() {
* @throws IOException
*/
public MiniBatchFileDataSetIterator(DataSet baseData, int batchSize, boolean delete) throws IOException {
this(baseData, batchSize, delete, new File(System.getProperty("java.io.tmpdir")));
this(baseData, batchSize, delete, ND4JFileUtils.getTempDir());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.memory.MemoryManager;
import org.nd4j.linalg.util.ND4JFileUtils;

import java.io.BufferedOutputStream;
import java.io.File;
Expand Down Expand Up @@ -182,7 +183,7 @@ public Nd4jWorkspace(@NonNull WorkspaceConfiguration configuration, @NonNull Str
}
} else if (configuration.getInitialSize() > 0) {
try {
tempFile = File.createTempFile("workspace", "tempMMAP");
tempFile = ND4JFileUtils.createTempFile("workspace", "tempMMAP");
tempFile.deleteOnExit();

// fill temp file with zeroes, up to initialSize bytes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ public class ND4JSystemProperties {
* @see #JAVACPP_MEMORY_MAX_BYTES
*/
public static final String JAVACPP_MEMORY_MAX_PHYSICAL_BYTES = "org.bytedeco.javacpp.maxphysicalbytes";

/**
* Applicability: ND4J Temporary file creation/extraction for ClassPathResource and memory mapped workspaces<br>
* Description: Specify the local directory where temporary files will be written. If not specified, the default
* Java temporary directory (java.io.tmpdir system property) will generally be used.
*/
public static final String ND4J_TEMP_DIR_PROPERTY = "org.nd4j.tempdir";

/**
* Applicability: always - but only if an ND4J backend cannot be found/loaded via standard ServiceLoader mechanisms<br>
* Description: Set this property to a set fully qualified JAR files to attempt to load before failing on
Expand Down
Loading

0 comments on commit 41dfe0f

Please sign in to comment.