From 1e2bcc1026cc9c6b8117a4eb78ad12a82367ef04 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Wed, 29 May 2019 22:52:12 +1000 Subject: [PATCH] Fixes and SameDiff functionality (#7807) * #6992 SameDiff mixed precision training support * Placeholder shape validation * Checkpoint listener * SameDiff checkpoint listener * SameDiff: Remove no longer required trainable params config from TrainingConfig * SameDiff: add name scopes * SameDiff name scopes - javadoc and tests * #7802 Evaluation class - report single class not macro avg in stats() for binary case * 7804 Arbiter - update score functions to use ND4J evaluation metric enums * SameDiff flatbuffers export: don't export arrays for array type variables (not required) --- .../scoring/impl/EvaluationScoreFunction.java | 9 +- .../scoring/impl/ROCScoreFunction.java | 7 +- .../scoring/impl/RegressionScoreFunction.java | 11 +- .../functions/DifferentialFunction.java | 9 +- .../listeners/checkpoint/Checkpoint.java | 61 ++ .../checkpoint/CheckpointListener.java | 608 ++++++++++++++++++ .../org/nd4j/autodiff/samediff/NameScope.java | 31 + .../nd4j/autodiff/samediff/SDVariable.java | 5 + .../org/nd4j/autodiff/samediff/SameDiff.java | 352 +++++----- .../autodiff/samediff/TrainingConfig.java | 37 +- .../autodiff/samediff/internal/Variable.java | 1 - .../evaluation/classification/Evaluation.java | 12 +- .../java/org/nd4j/linalg/api/shape/Shape.java | 26 + .../memory/abstracts/Nd4jWorkspace.java | 2 +- .../autodiff/samediff/NameScopeTests.java | 132 ++++ .../nd4j/autodiff/samediff/SameDiffTests.java | 75 ++- .../samediff/SameDiffTrainingTest.java | 65 ++ .../listeners/CheckpointListenerTest.java | 232 +++++++ .../java/org/nd4j/evaluation/EvalTest.java | 44 ++ .../linalg/api/buffer/BaseDataBuffer.java | 9 + 20 files changed, 1492 insertions(+), 236 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/Checkpoint.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/CheckpointListener.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/NameScope.java create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/EvaluationScoreFunction.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/EvaluationScoreFunction.java index e735a2bd9e9e..7e71425d5d23 100644 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/EvaluationScoreFunction.java +++ b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/EvaluationScoreFunction.java @@ -18,9 +18,9 @@ import lombok.*; import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; -import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; @@ -37,6 +37,13 @@ public class EvaluationScoreFunction extends BaseNetScoreFunction { protected Evaluation.Metric metric; + /** 
+ * @param metric Evaluation metric to calculate + */ + public EvaluationScoreFunction(@NonNull org.deeplearning4j.eval.Evaluation.Metric metric) { + this(metric.toNd4j()); + } + /** * @param metric Evaluation metric to calculate */ diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/ROCScoreFunction.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/ROCScoreFunction.java index 707531729ca7..b8c71d5d1f34 100644 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/ROCScoreFunction.java +++ b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/ROCScoreFunction.java @@ -19,12 +19,11 @@ import lombok.*; import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter; -import org.deeplearning4j.eval.Evaluation; -import org.deeplearning4j.eval.ROC; -import org.deeplearning4j.eval.ROCBinary; -import org.deeplearning4j.eval.ROCMultiClass; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.nd4j.evaluation.classification.ROC; +import org.nd4j.evaluation.classification.ROCBinary; +import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/RegressionScoreFunction.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/RegressionScoreFunction.java index c9c550513393..fdcc6934e6c6 100644 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/RegressionScoreFunction.java +++ b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/RegressionScoreFunction.java @@ -17,18 +17,13 @@ package org.deeplearning4j.arbiter.scoring.impl; import lombok.*; -import org.deeplearning4j.arbiter.optimize.api.data.DataSource; -import org.deeplearning4j.arbiter.scoring.RegressionValue; -import org.deeplearning4j.arbiter.scoring.util.ScoreUtil; import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; -import org.deeplearning4j.eval.RegressionEvaluation; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.nd4j.evaluation.regression.RegressionEvaluation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; -import java.util.Properties; - /** * Score function for regression (including multi-label regression) for a MultiLayerNetwork or ComputationGraph * on a test set. 
Supports all regression metrics: {@link RegressionEvaluation.Metric} @@ -42,6 +37,10 @@ public class RegressionScoreFunction extends BaseNetScoreFunction { protected RegressionEvaluation.Metric metric; + public RegressionScoreFunction(@NonNull org.deeplearning4j.eval.RegressionEvaluation.Metric metric) { + this(metric.toNd4j()); + } + public RegressionScoreFunction(@NonNull RegressionEvaluation.Metric metric) { this.metric = metric; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java index a3f40797ef21..680e2772bef5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java @@ -644,9 +644,14 @@ protected void setInstanceId() { this.ownName = UUID.randomUUID().toString(); else { int argIndex = 0; - String varName = sameDiff.generateNewVarName(opName(),argIndex); + String scope = sameDiff.currentNameScope(); + if(scope == null) + scope = ""; + else + scope = scope + "/"; + String varName = scope + sameDiff.generateNewVarName(opName(),argIndex); while(sameDiff.functionExists(varName)) { - varName = sameDiff.generateNewVarName(opName(), argIndex); + varName = scope + sameDiff.generateNewVarName(opName(), argIndex); argIndex++; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/Checkpoint.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/Checkpoint.java new file mode 100644 index 000000000000..0c0a1429d607 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/Checkpoint.java @@ -0,0 +1,61 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.autodiff.listeners.checkpoint; + +import lombok.AllArgsConstructor; +import lombok.Data; + +import java.io.Serializable; +import java.util.Arrays; + +/** + * A model checkpoint, used with {@link CheckpointListener} + * + * @author Alex Black + */ +@AllArgsConstructor +@Data +public class Checkpoint implements Serializable { + + private int checkpointNum; + private long timestamp; + private int iteration; + private int epoch; + private String filename; + + public static String getFileHeader(){ + return "checkpointNum,timestamp,iteration,epoch,filename"; + } + + public static Checkpoint fromFileString(String str){ + String[] split = str.split(","); + if(split.length != 5){ + throw new IllegalStateException("Cannot parse checkpoint entry: expected 5 entries, got " + split.length + + " - values = " + Arrays.toString(split)); + } + return new Checkpoint( + Integer.parseInt(split[0]), + Long.parseLong(split[1]), + Integer.parseInt(split[2]), + Integer.parseInt(split[3]), + split[4]); + } + + public String toFileString(){ + return checkpointNum + "," + timestamp + "," + iteration + "," + epoch + "," + filename; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/CheckpointListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/CheckpointListener.java new file mode 100644 index 000000000000..7cd12ba2e449 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/CheckpointListener.java @@ -0,0 +1,608 @@ +package org.nd4j.autodiff.listeners.checkpoint; + + +import com.google.common.io.Files; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.io.IOUtils; +import org.nd4j.autodiff.listeners.At; +import org.nd4j.autodiff.listeners.BaseListener; +import org.nd4j.autodiff.listeners.Loss; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.dataset.api.MultiDataSet; + +import java.io.*; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.text.SimpleDateFormat; +import java.util.*; +import java.util.concurrent.TimeUnit; + +/** + * + * CheckpointListener: The goal of this listener is to periodically save a copy of the model during training..
+ * Model saving may be done:
+ * 1. Every N epochs
+ * 2. Every N iterations
+ * 3. Every T time units (every 15 minutes, for example)
+ * Or some combination of the 3.
+ *
+ * Models can be restored using {@link #loadCheckpoint(File, int)}, {@link #loadLastCheckpoint(File)} and {@link #loadCheckpoint(int)}.
+ * Checkpoints can be obtained using {@link #lastCheckpoint()} and {@link #availableCheckpoints()}.
+ *
+ * Example 1: Saving a checkpoint every 2 epochs, keeping all model files
+ *
+ * {@code CheckpointListener l = new CheckpointListener.Builder("/save/directory")
+ *       .keepAll() //Don't delete any models
+ *       .saveEveryNEpochs(2)
+ *       .build();
+ * }
+ * 
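+ * With the configuration above, model files named like {@code SameDiff_checkpoint-0_epoch-1_iter-100_2019-05-29_12-00-00.bin}
+ * (the exact name and timestamp shown are illustrative) are written to the save directory, and each save appends a
+ * "checkpointNum,timestamp,iteration,epoch,filename" line to a {@code checkpointInfo.txt} record file.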
+ *
+ * Example 2: Saving a checkpoint every 1000 iterations, but keeping only the last 3 models (all older model
+ * files will be automatically deleted)
+ *
+ * {@code CheckpointListener l = new CheckpointListener.Builder(new File("/save/directory"))
+ *          .keepLast(3)
+ *          .saveEveryNIterations(1000)
+ *          .build();
+ * }
+ * 
+ *
+ * Example 3: Saving a checkpoint every 15 minutes, keeping the most recent 3 and otherwise every 4th checkpoint
+ * file:
+ *
+ * {@code CheckpointListener l = new CheckpointListener.Builder(new File("/save/directory"))
+ *          .keepLastAndEvery(3, 4)
+ *          .saveEvery(15, TimeUnit.MINUTES)
+ *          .build();
+ * }
+ * 
+ *
+ * Note that you can mix these: for example, to save every epoch and every 15 minutes (independent of last save time):
+ * {@code .saveEveryEpoch().saveEvery(15, TimeUnit.MINUTES)}
+ * To save every epoch, and every 15 minutes since the last model save, use:
+ * {@code .saveEveryEpoch().saveEvery(15, TimeUnit.MINUTES, true)}
+ * Note that in this last example, the sinceLast parameter is true. This means the 15-minute counter will be
+ * reset any time a model is saved.
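+ *
+ * A minimal end-to-end sketch of attaching the listener (assumptions: {@code sd} is a SameDiff instance with
+ * a TrainingConfig already set, and {@code iter} is a DataSetIterator; both are hypothetical stand-ins):
+ * {@code CheckpointListener l = new CheckpointListener.Builder("/save/directory")
+ *          .keepLast(3)
+ *          .saveEveryEpoch()
+ *          .build();
+ * sd.addListeners(Collections.singletonList(l));   //Attach the listener before training
+ * sd.fit(iter, 10);                                //Checkpoints are then written during fit
+ * }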
+ * + * @author Alex Black + */ +@Slf4j +public class CheckpointListener extends BaseListener implements Serializable { + + private enum KeepMode {ALL, LAST, LAST_AND_EVERY}; + + private File rootDir; + private String fileNamePrefix; + private KeepMode keepMode; + private int keepLast; + private int keepEvery; + private boolean logSaving; + private boolean deleteExisting; + + private Integer saveEveryNEpochs; + private Integer saveEveryNIterations; + private boolean saveEveryNIterSinceLast; + private Long saveEveryAmount; + private TimeUnit saveEveryUnit; + private Long saveEveryMs; + private boolean saveEverySinceLast; + + private int lastCheckpointNum = -1; + private File checkpointRecordFile; + + private Checkpoint lastCheckpoint; + private long startTime = -1; + private int startIter = -1; + private Long lastSaveEveryMsNoSinceLast; + + private CheckpointListener(Builder builder){ + this.rootDir = builder.rootDir; + this.fileNamePrefix = builder.fileNamePrefix; + this.keepMode = builder.keepMode; + this.keepLast = builder.keepLast; + this.keepEvery = builder.keepEvery; + this.logSaving = builder.logSaving; + this.deleteExisting = builder.deleteExisting; + + this.saveEveryNEpochs = builder.saveEveryNEpochs; + this.saveEveryNIterations = builder.saveEveryNIterations; + this.saveEveryNIterSinceLast = builder.saveEveryNIterSinceLast; + this.saveEveryAmount = builder.saveEveryAmount; + this.saveEveryUnit = builder.saveEveryUnit; + this.saveEverySinceLast = builder.saveEverySinceLast; + + if(saveEveryAmount != null){ + saveEveryMs = TimeUnit.MILLISECONDS.convert(saveEveryAmount, saveEveryUnit); + } + + this.checkpointRecordFile = new File(rootDir, "checkpointInfo.txt"); + if(this.checkpointRecordFile.exists() && this.checkpointRecordFile.length() > 0){ + + if(deleteExisting){ + //Delete any files matching: + //"checkpoint_" + checkpointNum + "_" + modelType + ".zip"; + this.checkpointRecordFile.delete(); + File[] files = rootDir.listFiles(); + if(files != null && files.length > 0){ + for(File f : files){ + String name = f.getName(); + if(name.startsWith("checkpoint_") && (name.endsWith("MultiLayerNetwork.zip") || name.endsWith("ComputationGraph.zip"))){ + f.delete(); + } + } + } + } else { + throw new IllegalStateException("Detected existing checkpoint files at directory " + rootDir.getAbsolutePath() + + ". Use deleteExisting(true) to delete existing checkpoint files when present."); + } + } + } + + @Override + public void epochEnd(SameDiff sameDiff, At at) { + if(saveEveryNEpochs != null && (at.epoch()+1) % saveEveryNEpochs == 0){ + //Save: + saveCheckpoint(sameDiff, at); + } + //General saving conditions: don't need to check here - will check in iterationDone + } + + @Override + public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) { + if (startTime < 0) { + startTime = System.currentTimeMillis(); + startIter = at.iteration(); + return; + } + + //Check iterations saving condition: + if(saveEveryNIterations != null){ + if(saveEveryNIterSinceLast){ + //Consider last saved model when deciding whether to save + long lastSaveIter = (lastCheckpoint != null ? 
lastCheckpoint.getIteration() : startIter); + if(at.iteration() - lastSaveIter >= saveEveryNIterations){ + saveCheckpoint(sd, at); + return; + } + } else { + //Same every N iterations, regardless of saving time + if((at.iteration()+1) % saveEveryNIterations == 0){ + saveCheckpoint(sd, at); + return; + } + } + } + + //Check time saving condition: + long time = System.currentTimeMillis(); + if(saveEveryUnit != null){ + if(saveEverySinceLast){ + //Consider last saved when deciding whether to save + long lastSaveTime = (lastCheckpoint != null ? lastCheckpoint.getTimestamp() : startTime); + if((time - lastSaveTime) >= saveEveryMs){ + saveCheckpoint(sd, at); + return; + } + } else { + //Save periodically, regardless of when last model was saved + long lastSave = (lastSaveEveryMsNoSinceLast != null ? lastSaveEveryMsNoSinceLast : startTime); + if((time - lastSave) > saveEveryMs){ + saveCheckpoint(sd, at); + lastSaveEveryMsNoSinceLast = time; + return; + } + } + } + } + + private void saveCheckpoint(SameDiff sd, At at) { + try{ + saveCheckpointHelper(sd, at); + } catch (Exception e){ + throw new RuntimeException("Error saving checkpoint", e); + } + } + + private void saveCheckpointHelper(SameDiff model, At at) throws Exception { + if(!checkpointRecordFile.exists()){ + checkpointRecordFile.createNewFile(); + writeCheckpointInfo(Checkpoint.getFileHeader() + "\n", checkpointRecordFile); + } + + Checkpoint c = new Checkpoint(++lastCheckpointNum, System.currentTimeMillis(), at.iteration(), at.epoch(),null); + String filename = getFileName(lastCheckpointNum, at, c.getTimestamp()); + c.setFilename(filename); + + File saveFile = new File(rootDir, c.getFilename()); + model.asFlatFile(saveFile); + + String s = c.toFileString(); + writeCheckpointInfo(s + "\n", checkpointRecordFile); + + if(logSaving){ + log.info("Model checkpoint saved: epoch {}, iteration {}, path: {}", c.getEpoch(), c.getIteration(), + new File(rootDir, c.getFilename()).getPath() ); + } + this.lastCheckpoint = c; + + + //Finally: determine if we should delete some old models... 
+ if(keepMode == null || keepMode == KeepMode.ALL){ + return; + } else if(keepMode == KeepMode.LAST){ + List checkpoints = availableCheckpoints(); + Iterator iter = checkpoints.iterator(); + while(checkpoints.size() > keepLast){ + Checkpoint toRemove = iter.next(); + File f = getFileForCheckpoint(toRemove); + f.delete(); + iter.remove(); + } + } else { + //Keep mode: last N and every M + for(Checkpoint cp : availableCheckpoints()){ + if(cp.getCheckpointNum() > 0 && (cp.getCheckpointNum()+1) % keepEvery == 0){ + //One of the "every M to keep" models + continue; + } else if(cp.getCheckpointNum() > lastCheckpointNum - keepLast ){ //Example: latest is 5, keep last 2 -> keep checkpoints 4 and 5 + //One of last N to keep + continue; + } + //Otherwise: delete file + File f = getFileForCheckpoint(cp); + f.delete(); + } + } + } + + //Filename format: "_checkpoint-#_epoch-#_iter-#_YYYY-MM-dd_HH-MM-ss.bin" + private String getFileName(int checkpointNum, At at, long time){ + StringBuilder sb = new StringBuilder(); + if(fileNamePrefix != null){ + sb.append(fileNamePrefix); + if(!fileNamePrefix.endsWith("_")){ + sb.append("_"); + } + } + sb.append("checkpoint-") + .append(checkpointNum) + .append("_epoch-").append(at.epoch()) + .append("_iter-").append(at.iteration()); + + SimpleDateFormat sdf = new SimpleDateFormat("YYYY-MM-dd_HH-mm-ss"); + String date = sdf.format(new Date(time)); + sb.append("_").append(date) + .append(".bin"); + + return sb.toString(); + } + + private static String writeCheckpointInfo(String str, File f){ + try { + if(!f.exists()){ + f.createNewFile(); + } + Files.append(str, f, StandardCharsets.UTF_8); + } catch (IOException e){ + throw new RuntimeException(e); + } + return str; + } + + /** + * List all available checkpoints. A checkpoint is 'available' if the file can be loaded. Any checkpoint files that + * have been automatically deleted (given the configuration) will not be returned here. + * + * @return List of checkpoint files that can be loaded + */ + public List availableCheckpoints(){ + if(!checkpointRecordFile.exists()){ + return Collections.emptyList(); + } + + return availableCheckpoints(rootDir); + } + + /** + * List all available checkpoints. A checkpoint is 'available' if the file can be loaded. Any checkpoint files that + * have been automatically deleted (given the configuration) will not be returned here. + * Note that the checkpointInfo.txt file must exist, as this stores checkpoint information + * + * @return List of checkpoint files that can be loaded from the specified directory + */ + public static List availableCheckpoints(File directory){ + File checkpointRecordFile = new File(directory, "checkpointInfo.txt"); + Preconditions.checkState(checkpointRecordFile.exists(), "Could not find checkpoint record file at expected path %s", checkpointRecordFile.getAbsolutePath()); + + List lines; + try(InputStream is = new BufferedInputStream(new FileInputStream(checkpointRecordFile))){ + lines = IOUtils.readLines(is); + } catch (IOException e){ + throw new RuntimeException("Error loading checkpoint data from file: " + checkpointRecordFile.getAbsolutePath(), e); + } + + List out = new ArrayList<>(lines.size()-1); //Assume first line is header + for( int i=1; i all = availableCheckpoints(rootDir); + if(all.isEmpty()){ + return null; + } + return all.get(all.size()-1); + } + + /** + * Get the model file for the given checkpoint. 
Checkpoint model file must exist + * + * @param checkpoint Checkpoint to get the model file for + * @return Model file for the checkpoint + */ + public File getFileForCheckpoint(Checkpoint checkpoint){ + return getFileForCheckpoint(checkpoint.getCheckpointNum()); + } + + /** + * Get the model file for the given checkpoint number. Checkpoint model file must exist + * + * @param checkpointNum Checkpoint number to get the model file for + * @return Model file for the checkpoint + */ + public File getFileForCheckpoint(int checkpointNum) { + return getFileForCheckpoint(rootDir, checkpointNum); + } + + public static File getFileForCheckpoint(File rootDir, int checkpointNum){ + //Scan the root directory, for a file matching the checkpoint filename pattern: + //Filename format: "_checkpoint-#_epoch-#_iter-#_YYYY-MM-dd_HH-MM-ss.bin" + + if(checkpointNum < 0){ + throw new IllegalArgumentException("Invalid checkpoint number: " + checkpointNum); + } + + String contains = "_checkpoint-" + checkpointNum + "_epoch-"; + + File[] allFiles = rootDir.listFiles(); + if(allFiles != null){ + for(File f : allFiles){ + if(f.getAbsolutePath().contains(contains)){ + return f; + } + } + } + + throw new IllegalStateException("Model file for checkpoint " + checkpointNum + " does not exist"); + } + + /** + * Load a given checkpoint number + * + */ + public SameDiff loadCheckpoint(int checkpointNum){ + return loadCheckpoint(rootDir, checkpointNum); + } + + /** + * Load a SameDiff instance for the given checkpoint that resides in the specified root directory + * + * @param rootDir Directory that the checkpoint resides in + * @param checkpointNum Checkpoint model number to load + * @return The loaded model + */ + public static SameDiff loadCheckpoint(File rootDir, int checkpointNum){ + File f = getFileForCheckpoint(rootDir, checkpointNum); + try { + return SameDiff.fromFlatFile(f); + } catch (IOException e){ + throw new RuntimeException("Error loading checkpoint " + checkpointNum + " from root directory " + rootDir.getAbsolutePath(), e); + } + } + + /** + * Load the last (most recent) checkpoint from the specified root directory + * @param rootDir Root directory to load checpoint from + * @return ComputationGraph for last checkpoint + */ + public static SameDiff loadLastCheckpoint(File rootDir){ + Checkpoint last = lastCheckpoint(rootDir); + return loadCheckpoint(rootDir, last.getCheckpointNum()); + } + + public static Builder builder(@NonNull File rootDir){ + return new Builder(rootDir); + } + + public static class Builder { + + private File rootDir; + private String fileNamePrefix = "SameDiff"; + private KeepMode keepMode; + private int keepLast; + private int keepEvery; + private boolean logSaving = true; + private boolean deleteExisting = false; + + private Integer saveEveryNEpochs; + private Integer saveEveryNIterations; + private boolean saveEveryNIterSinceLast; + private Long saveEveryAmount; + private TimeUnit saveEveryUnit; + private boolean saveEverySinceLast; + + /** + * @param rootDir Root directory to save models to + */ + public Builder(@NonNull String rootDir){ + this(new File(rootDir)); + } + + /** + * @param rootDir Root directory to save models to + */ + public Builder(@NonNull File rootDir){ + this.rootDir = rootDir; + } + + public Builder fileNamePrefix(String fileNamePrefix){ + this.fileNamePrefix = fileNamePrefix; + return this; + } + + /** + * Save a model at the end of every epoch + */ + public Builder saveEveryEpoch(){ + return saveEveryNEpochs(1); + } + + /** + * Save a model at the end of every 
N epochs + */ + public Builder saveEveryNEpochs(int n){ + this.saveEveryNEpochs = n; + return this; + } + + /** + * Save a model every N iterations + */ + public Builder saveEveryNIterations(int n){ + return saveEveryNIterations(n, false); + } + + /** + * Save a model every N iterations (if sinceLast == false), or if N iterations have passed since + * the last model vas saved (if sinceLast == true) + */ + public Builder saveEveryNIterations(int n, boolean sinceLast){ + this.saveEveryNIterations = n; + this.saveEveryNIterSinceLast = sinceLast; + return this; + } + + /** + * Save a model periodically + * + * @param amount Quantity of the specified time unit + * @param timeUnit Time unit + */ + public Builder saveEvery(long amount, TimeUnit timeUnit){ + return saveEvery(amount, timeUnit, false); + } + + /** + * Save a model periodically (if sinceLast == false), or if the specified amount of time has elapsed since + * the last model was saved (if sinceLast == true) + * + * @param amount Quantity of the specified time unit + * @param timeUnit Time unit + */ + public Builder saveEvery(long amount, TimeUnit timeUnit, boolean sinceLast){ + this.saveEveryAmount = amount; + this.saveEveryUnit = timeUnit; + this.saveEverySinceLast = sinceLast; + return this; + } + + /** + * Keep all model checkpoints - i.e., don't delete any. Note that this is the default. + */ + public Builder keepAll(){ + this.keepMode = KeepMode.ALL; + return this; + } + + /** + * Keep only the last N most recent model checkpoint files. Older checkpoints will automatically be deleted. + * @param n Number of most recent checkpoints to keep + */ + public Builder keepLast(int n){ + if(n <= 0){ + throw new IllegalArgumentException("Number of model files to keep should be > 0 (got: " + n + ")"); + } + this.keepMode = KeepMode.LAST; + this.keepLast = n; + return this; + } + + /** + * Keep the last N most recent model checkpoint files, and every M checkpoint files.
+ * For example: suppose you save every 100 iterations, for 2050 iteration, and use keepLastAndEvery(3,5). + * This means after 2050 iterations you would have saved 20 checkpoints - some of which will be deleted. + * Those remaining in this example: iterations 500, 1000, 1500, 1800, 1900, 2000. + * @param nLast Most recent checkpoints to keep + * @param everyN Every N checkpoints to keep (regardless of age) + */ + public Builder keepLastAndEvery(int nLast, int everyN){ + if(nLast <= 0){ + throw new IllegalArgumentException("Most recent number of model files to keep should be > 0 (got: " + + nLast + ")"); + } + if(everyN <= 0){ + throw new IllegalArgumentException("Every n model files to keep should be > 0 (got: " + + everyN + ")"); + } + + this.keepMode = KeepMode.LAST_AND_EVERY; + this.keepLast = nLast; + this.keepEvery = everyN; + return this; + } + + /** + * If true (the default) log a message every time a model is saved + * + * @param logSaving Whether checkpoint saves should be logged or not + */ + public Builder logSaving(boolean logSaving){ + this.logSaving = logSaving; + return this; + } + + /** + * If the checkpoint listener is set to save to a non-empty directory, should the CheckpointListener-related + * content be deleted?
+ * This is disabled by default (and instead, an exception will be thrown if existing data is found)
+ * WARNING: Be careful when enabling this, as it deletes all saved checkpoint models in the specified directory! + */ + public Builder deleteExisting(boolean deleteExisting){ + this.deleteExisting = deleteExisting; + return this; + } + + public CheckpointListener build(){ + if(saveEveryNEpochs == null && saveEveryAmount == null && saveEveryNIterations == null){ + throw new IllegalStateException("Cannot construct listener: no models will be saved (must use at least" + + " one of: save every N epochs, every N iterations, or every T time periods)"); + } + + return new CheckpointListener(this); + } + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/NameScope.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/NameScope.java new file mode 100644 index 000000000000..aca2ce9cbcc8 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/NameScope.java @@ -0,0 +1,31 @@ +package org.nd4j.autodiff.samediff; + +import lombok.Data; + +import java.io.Closeable; + +/** + * Used with {@link SameDiff#withNameScope(String)} + * + * @author Alex Black + */ +@Data +public class NameScope implements Closeable { + private final SameDiff sameDiff; + private final String name; + + public NameScope(SameDiff sameDiff, String name){ + this.sameDiff = sameDiff; + this.name = name; + } + + @Override + public void close() { + sameDiff.closeNameScope(this); + } + + @Override + public String toString(){ + return "NameScope(" + name + ")"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java index 4a5915afcf58..28aafe2d3e1f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java @@ -90,6 +90,11 @@ public SDVariable(@NonNull String varName, @NonNull VariableType varType, @NonNu " SDVariables - variable \"%s\" is of type %s but was provided a weight initialization scheme %s", varName, varType, weightInitScheme); Preconditions.checkState(dataType != DataType.UNKNOWN, "Unknown datatype is not allowed for SDVariables (variable name: %s)", varName); + String nameScope = sameDiff.currentNameScope(); + if(nameScope != null){ + varName = nameScope + "/" + varName; + } + this.varName = varName; this.variableType = varType; this.dataType = dataType; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index 4e9041bad774..b926839d427d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -72,7 +72,6 @@ import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4UnresolvedOutputVariables; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.learning.GradientUpdater; import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.primitives.AtomicBoolean; @@ -126,6 +125,8 @@ public class SameDiff extends SDBaseOps { private List 
listeners = new ArrayList<>(); + private final List nameScopes = new ArrayList<>(); //Used as a stack + /////////////////////////////////////// //Fields related to training @Getter @@ -133,9 +134,7 @@ public class SameDiff extends SDBaseOps { @Getter private boolean initializedTraining; //True if training setup has been done @Getter - private INDArray updaterState; //Updater state array (1d, length equal to number of trainable parameters) - @Getter - private Map updaterViews; //Views of updaterState array for each trainable parameter + private Map updaterStates; //Updater state array (as vector, before splitting/reshaping) for each trainable parameter @Getter private Map updaterMap; //GradientUpdater instance for each trainable parameter @@ -438,6 +437,94 @@ public void addListeners(Collection listeners){ this.listeners.addAll(listeners); } + /** + * @return The current name scope, if any (null otherwise). See {@link #withNameScope(String)} for more details. + */ + public String currentNameScope(){ + if(nameScopes.isEmpty()) + return null; + + //Would use String.join but that is Java 8+ + StringBuilder sb = new StringBuilder(); + boolean first = true; + for(NameScope ns : nameScopes){ + if(!first){ + sb.append("/"); + } + sb.append(ns.getName()); + first = false; + } + return sb.toString(); + } + + /** + * @return The name with the current name scope (if any) appended. See {@link #withNameScope(String)} + */ + protected String nameWithScope(String name){ + String scope = currentNameScope(); + if(scope == null){ + return name; + } + return scope + "/" + name; + } + + //Intentionally package private + void addNameScope(NameScope nameScope){ + nameScopes.add(nameScope); + } + + //Intentionally package private + void closeNameScope(NameScope nameScope){ + //Check that the name scope is closed correctly/in order + Preconditions.checkState(!nameScopes.isEmpty(), "Cannot close name scope: no name scopes are currently defined"); + Preconditions.checkState(nameScopes.get(nameScopes.size()-1).equals(nameScope), + "Cannot close name scope %s: Name scopes must be closed in order. Current name scopes: \"%s\"", nameScope, currentNameScope()); + + nameScopes.remove(nameScopes.size()-1); + } + + /** + * Create a name scope. Name scopes append a prefix to the names of any variables and ops created while they are open. + *
+     *  {@code
+     *  SameDiff sd = SameDiff.create();
+     *  SDVariable x = sd.var("x", DataType.FLOAT, 5);
+     *  SDVariable y;
+     *  try(NameScope ns = sd.withNameScope("myScope")){
+     *      y = sd.var("y", DataType.FLOAT, 5);
+     *  }
+     *  SDVariable z = sd.var("z", DataType.FLOAT, 5);
+     *
+     *  String xName = x.getVarName();      //RESULT: "x"
+     *  String yName = y.getVarName();      //RESULT: "myScope/y"
+     *  String zName = z.getVarName();      //RESULT: "z"
+     *  }
+     * 
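+     *
+     * Name scopes must be closed in the order they were opened (the try-with-resources pattern shown here
+     * guarantees this); closing a scope out of order will fail with an IllegalStateException.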
+     *
+     * Note that name scopes can also be nested:
+     *
+     *  {@code
+     *  SameDiff sd = SameDiff.create();
+     *  SDVariable x;
+     *  try(NameScope ns = sd.withNameScope("first")){
+     *      try(NameScope ns2 = sd.withNameScope("second")){
+     *          x = sd.var("x", DataType.FLOAT, 5);
+     *      }
+     *  }
+     *  String xName = x.getVarName();      //RESULT: "first/second/x"
+     *  }
+     * 
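+     *
+     * Ops created while a name scope is open are prefixed in the same way, so generated op and output variable
+     * names also gain the scope prefix. A sketch (the exact generated name is illustrative):
+     *  {@code
+     *  SameDiff sd = SameDiff.create();
+     *  SDVariable x = sd.var("x", DataType.FLOAT, 5);
+     *  SDVariable y;
+     *  try(NameScope ns = sd.withNameScope("myScope")){
+     *      y = x.add(1.0);     //The add op and its output variable are created under "myScope/"
+     *  }
+     *  }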
+ * + * + * @param nameScope Name of the name scope to open/create + * @return The NameScope object + */ + public NameScope withNameScope(String nameScope){ + NameScope ns = new NameScope(this, nameScope); + addNameScope(ns); + return ns; + } + /** * @param sameDiff @@ -827,7 +914,6 @@ public void associateArrayWithVariable(INDArray arr, SDVariable variable) { Preconditions.checkState(variable.dataType() == arr.dataType(), "Variable \"%s\" has datatype %s: cannot associate array with type %s with this variable", variable.getVarName(), variable.dataType(), arr.dataType()); - // FIXME: remove this before release if (sessions.get(Thread.currentThread().getId()) == null) { sessions.put(Thread.currentThread().getId(), new InferenceSession(this)); } @@ -865,6 +951,13 @@ public void associateArrayWithVariable(INDArray arr, SDVariable variable) { session.getNodeOutputs().put(varId, arr); //throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY"); case PLACEHOLDER: + //Validate placeholder shapes: + long[] phShape = variable.placeholderShape(); + Preconditions.checkState(Shape.shapeMatchesPlaceholder(phShape, arr.shape()), + "Invalid array shape: cannot associate an array with shape %ndShape with a placeholder of shape %s:" + + "shape is wrong rank or does not match on one or more dimensions", arr, phShape); + + long tid = Thread.currentThread().getId(); if(!placeholdersPerThread.containsKey(tid)){ placeholdersPerThread.put(tid, new HashMap()); @@ -1600,10 +1693,15 @@ protected synchronized void fitHelper(MultiDataSetIterator iter, int numEpochs, int iteration = trainingConfig.getIterationCount(); int e = trainingConfig.getEpochCount(); - for (String s : trainingConfig.getTrainableParams()) { - //TODO fix using inference session - INDArray param = variables.get(s).getVariable().getArr(); - SDVariable gradVar = variables.get(s).getVariable().getGradient(); + for(Variable v : variables.values()){ + //Only update trainable params - float type parameters (variable type vars) + SDVariable sdv = v.getVariable(); + if(sdv.getVariableType() != VariableType.VARIABLE || !sdv.dataType().isFPType()) + continue; + + + INDArray param = sdv.getArr(); + SDVariable gradVar = sdv.getGradient(); if(gradVar == null){ //Not all trainable parameters have gradients defined. //Consider graph: in1->loss1; in2->loss2, where we optimize only loss1. @@ -1629,13 +1727,13 @@ protected synchronized void fitHelper(MultiDataSetIterator iter, int numEpochs, //Apply updater. Note that we need to reshape to [1,length] for updater INDArray reshapedView = Shape.newShapeNoCopy(grad, new long[]{1, grad.length()}, grad.ordering() == 'f'); //TODO make sure we always reshape in same order! 
- Preconditions.checkState(reshapedView != null, "Error reshaping array for parameter \"%s\": array is a view?", s); - GradientUpdater u = updaterMap.get(s); + Preconditions.checkState(reshapedView != null, "Error reshaping array for parameter \"%s\": array is a view?", sdv); + GradientUpdater u = updaterMap.get(sdv.getVarName()); try { u.applyUpdater(reshapedView, iteration, e); } catch (Throwable t) { - throw new RuntimeException("Error applying updater " + u.getClass().getSimpleName() + " to parameter \"" + s - + "\": either parameter size is inconsistent between iterations, or \"" + s + "\" should not be a trainable parameter?", t); + throw new RuntimeException("Error applying updater " + u.getClass().getSimpleName() + " to parameter \"" + sdv.getVarName() + + "\": either parameter size is inconsistent between iterations, or \"" + sdv.getVarName() + "\" should not be a trainable parameter?", t); } //Post-apply regularization (weight decay) @@ -1656,7 +1754,7 @@ protected synchronized void fitHelper(MultiDataSetIterator iter, int numEpochs, if(hasListeners){ for(Listener l : listeners){ - l.preUpdate(this, at, variables.get(s), reshapedView); + l.preUpdate(this, at, v, reshapedView); } } @@ -1703,10 +1801,6 @@ protected synchronized void fitHelper(MultiDataSetIterator iter, int numEpochs, trainingConfig.incrementIterationCount(); } - if(i < numEpochs - 1) { - iter.reset(); - } - if(incrementEpochCount) { if(hasListeners){ for(Listener l : listeners){ @@ -1715,6 +1809,10 @@ protected synchronized void fitHelper(MultiDataSetIterator iter, int numEpochs, } trainingConfig.incrementEpochCount(); } + + if(i < numEpochs - 1) { + iter.reset(); + } } } @@ -1733,14 +1831,16 @@ public double calcRegularizationScore() { return 0.0; } - if(trainingConfig.getTrainableParams() == null || trainingConfig.getTrainableParams().isEmpty()) - initializeTraining(); - List l = trainingConfig.getRegularization(); double loss = 0.0; - for (String s : trainingConfig.getTrainableParams()) { + for(Variable v : variables.values()){ + SDVariable sdv = v.getVariable(); + if(sdv.getVariableType() != VariableType.VARIABLE || !sdv.dataType().isFPType()){ + //Only trainable parameters (FP and variable type vars) contribute to regularization score + continue; + } for(Regularization r : l){ - INDArray arr = getVariable(s).getArr(); + INDArray arr = sdv.getArr(); loss += r.score(arr, trainingConfig.getIterationCount(), trainingConfig.getEpochCount()); } } @@ -1757,62 +1857,24 @@ protected void initializeTraining(){ if(trainingConfig == null) { throw new ND4JIllegalStateException("Please specify a training config with setTrainingConfig"); } - //First: infer the variables to be optimized if required - if(trainingConfig.getTrainableParams() == null || trainingConfig.getTrainableParams().size() == 0) { - //Variable is trainable if it's not the output of some function - //TODO also - should be floating point type - List trainVarList = new ArrayList<>(); - for(Variable var : variables.values()){ - SDVariable v = var.getVariable(); - String n = v.getVarName(); - if(variables.get(n).getOutputOfOp() == null && //Is a leaf (not the output of a function) - !isPlaceHolder(n) && //and not a placeholder - !variables.get(n).getVariable().isConstant() && //and not a constant - (trainingConfig.getDataSetFeatureMapping() == null || !trainingConfig.getDataSetFeatureMapping().contains(n)) && //and not an input (this really should be a placeholder, but we can't guarantee that...) 
- (trainingConfig.getDataSetLabelMapping() == null || !trainingConfig.getDataSetLabelMapping().contains(n)) && //and not a label (this really should be a placeholder, but we can't guarantee that...) - (trainingConfig.getDataSetFeatureMaskMapping() == null || !trainingConfig.getDataSetFeatureMaskMapping().contains(n)) && //and not a feature mask (this really should be a placeholder, but we can't guarantee that...) - (trainingConfig.getDataSetLabelMaskMapping() == null || !trainingConfig.getDataSetLabelMaskMapping().contains(n))){ //and not a label input (this really should be a placeholder, but we can't guarantee that...) - trainVarList.add(n); - } + updaterStates = new HashMap<>(); + updaterMap = new HashMap<>(); + for(Variable v : variables.values()){ + if(v.getVariable().getVariableType() != VariableType.VARIABLE || !v.getVariable().dataType().isFPType()){ + //Skip non-trainable parameters + continue; } - trainingConfig.setTrainableParams(trainVarList); - log.info("Inferred trainable variables: {}", trainVarList); - } - - //Allocate updater state - long numTrainableParams = 0; - DataType dt = null; //TODO support mixed precision variables - https://github.com/deeplearning4j/deeplearning4j/issues/6992 - for(String s : trainingConfig.getTrainableParams()) { - SDVariable v = variables.get(s).getVariable(); - Preconditions.checkState(v != null, "No variable found for trainable parameter name \"%s\"", s); - - INDArray arr = v.getArr(); - Preconditions.checkState(arr != null, "No array found for trainable parameter \"%s\"", s); - numTrainableParams += arr.length(); - if(dt == null) - dt = arr.dataType(); - } - - long updaterStateSize = trainingConfig.getUpdater().stateSize(numTrainableParams); - - if(updaterStateSize > 0) { - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - updaterState = Nd4j.createUninitialized(dt, 1, updaterStateSize); + INDArray arr = v.getVariable().getArr(); + long stateSize = trainingConfig.getUpdater().stateSize(arr.length()); + if(stateSize == 0){ + //Updater has no state array (such as SGD or No-Op updaters + continue; } - } + INDArray view = Nd4j.createUninitialized(arr.dataType(), 1, stateSize); - long viewSoFar = 0; - updaterViews = new HashMap<>(); - updaterMap = new HashMap<>(); - for(String s : trainingConfig.getTrainableParams()) { - long thisSize = trainingConfig.getUpdater().stateSize(variables.get(s).getVariable().getArr().length()); - INDArray view = (updaterStateSize == 0 || thisSize == 0 ? null : - updaterState.get(NDArrayIndex.interval(0, 1), NDArrayIndex.interval(viewSoFar, viewSoFar + thisSize))); - - updaterViews.put(s, view); - updaterMap.put(s, trainingConfig.getUpdater().instantiate(view, true)); - viewSoFar += thisSize; + updaterStates.put(v.getName(), view); + updaterMap.put(v.getName(), trainingConfig.getUpdater().instantiate(view, true)); } initializedTraining = true; @@ -2202,8 +2264,15 @@ public SDVariable var(@NonNull String name, @NonNull WeightInitScheme weightInit //TODO only allowing null datatype for TF import (it's fixed in a later step) - don't want this in the public API! public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme, org.nd4j.linalg.api.buffer.DataType dataType, long... 
shape) { - if (variables.containsKey(name) && variables.get(name).getVariable().getArr() != null) - throw new IllegalArgumentException("Another variable with the name " + name + " already exists."); + String withScope = nameWithScope(name); + if (variables.containsKey(withScope)) { + if(nameScopes.isEmpty()){ + throw new IllegalArgumentException("Another variable with the name " + name + " already exists (current name scope: \"" + + currentNameScope() + "\""); + } else { + throw new IllegalArgumentException("Another variable with the name " + name + " already exists."); + } + } if (name == null || name.length() < 1) name = getNewVarName(); @@ -2511,52 +2580,14 @@ public void convertToConstants(List variables){ } - if(trainingConfig != null){ - Set toRemove = new HashSet<>(); - boolean anyTrainableParmsModified = false; - List origTrainableParams = trainingConfig.getTrainableParams(); - for(SDVariable v : variables){ - toRemove.add(v.getVarName()); - if(!anyTrainableParmsModified && origTrainableParams.contains(v.getVarName())){ - anyTrainableParmsModified = true; - } - } - - - //Remove updater state for this variable: updaterState, updaterViews, updaterMap - if(anyTrainableParmsModified) { - List newTrainableParams = new ArrayList<>(); - for (String s : origTrainableParams) { - if (!toRemove.contains(s)) { - newTrainableParams.add(s); - } - } - trainingConfig.setTrainableParams(newTrainableParams); - } - - if(initializedTraining){ - List newUpdaterState = new ArrayList<>(); - for (String s : origTrainableParams) { - INDArray stateArr = updaterViews.get(s); - if (!toRemove.contains(s)) { - newUpdaterState.add(stateArr); - } - } - - updaterState = newUpdaterState.isEmpty() ? null : Nd4j.concat(0, newUpdaterState.toArray(new INDArray[newUpdaterState.size()])); - //Now, update updaterViews map: - long viewSoFar = 0; - updaterViews = new HashMap<>(); - updaterMap = new HashMap<>(); - for(String s : trainingConfig.getTrainableParams()) { - long thisSize = trainingConfig.getUpdater().stateSize(this.variables.get(s).getVariable().getArr().length()); - INDArray view = (updaterState == null || thisSize == 0 ? 
null : - updaterState.get(NDArrayIndex.interval(0, 1), NDArrayIndex.interval(viewSoFar, viewSoFar + thisSize))); - - updaterViews.put(s, view); - updaterMap.put(s, trainingConfig.getUpdater().instantiate(view, false)); - viewSoFar += thisSize; + if (trainingConfig != null && initializedTraining) { + //Remove updater state for now constant variables + for (SDVariable v : variables) { + INDArray state = updaterStates.remove(v.getVarName()); + if (state != null) { //Already null for constants + state.close(); //Deallocate now, instead of waiting for GC } + updaterMap.remove(v.getVarName()); } } } @@ -2620,41 +2651,21 @@ public void convertToVariables(@NonNull List constants){ //For training: need to add new updater state - if(trainingConfig != null){ - List newTrainableParams = new ArrayList<>(trainingConfig.getTrainableParams()); - List convertedToVars = new ArrayList<>(); - for(SDVariable v : constants){ - newTrainableParams.add(v.getVarName()); - convertedToVars.add(v.getVarName()); - } - trainingConfig.setTrainableParams(newTrainableParams); - - + if (trainingConfig != null && initializedTraining) { //Add updater state for this variable: updaterState, updaterViews, updaterMap - if(initializedTraining){ - long extraStateSize = 0; - for (String s : convertedToVars) { - INDArray arr = getVariable(s).getArr(); - long stateSize = trainingConfig.getUpdater().stateSize(arr.length()); - extraStateSize += stateSize; - } - if(extraStateSize > 0) { - INDArray newState = Nd4j.createUninitialized(updaterState.dataType(), 1, extraStateSize); - - updaterState = (updaterState == null ? newState : Nd4j.concat(1, updaterState, newState)); - //Now, update updaterViews map: - long viewSoFar = 0; - updaterViews = new HashMap<>(); - updaterMap = new HashMap<>(); - for (String s : trainingConfig.getTrainableParams()) { - long thisSize = trainingConfig.getUpdater().stateSize(this.variables.get(s).getVariable().getArr().length()); - INDArray view = (updaterState == null || thisSize == 0 ? 
null : - updaterState.get(NDArrayIndex.interval(0, 1), NDArrayIndex.interval(viewSoFar, viewSoFar + thisSize))); - - updaterViews.put(s, view); - boolean init = convertedToVars.contains(s); //Only initialize/zero the states for the new variables - updaterMap.put(s, trainingConfig.getUpdater().instantiate(view, init)); - viewSoFar += thisSize; + for (SDVariable v : constants) { + if (!updaterStates.containsKey(v.getOwnName())) { + //Create new updater state + INDArray arr = v.getArr(); + long thisSize = trainingConfig.getUpdater().stateSize(arr.length()); + if (thisSize > 0) { + INDArray stateArr = Nd4j.create(arr.dataType(), 1, thisSize); + updaterStates.put(v.getVarName(), stateArr); + GradientUpdater u = trainingConfig.getUpdater().instantiate(stateArr, true); + updaterMap.put(v.getVarName(), u); + } else { + GradientUpdater u = trainingConfig.getUpdater().instantiate(null, true); + updaterMap.put(v.getVarName(), u); } } } @@ -2763,14 +2774,6 @@ public void renameVariable(String from, String to){ trainingConfig.setDataSetLabelMaskMapping(l); } - if(trainingConfig.getTrainableParams() != null && trainingConfig.getTrainableParams().contains(from)){ - List l = new ArrayList<>(trainingConfig.getTrainableParams()); - while(l.contains(from)){ - l.set(l.indexOf(from), to); - } - trainingConfig.setTrainableParams(l); - } - if(trainingConfig.getLossVariables() != null && trainingConfig.getLossVariables().contains(from)){ List l = new ArrayList<>(trainingConfig.getLossVariables()); while(l.contains(from)){ @@ -4024,7 +4027,8 @@ public void resolveVariablesWith(Map arrays) { for (Map.Entry e : arrays.entrySet()) { SDVariable varForName = getVariable(e.getKey()); if (varForName == null) { - throw new ND4JIllegalStateException("No variable name found for " + e.getKey()); + throw new ND4JIllegalStateException("A placeholder array was provided for variable with name \"" + e.getKey() + + "\" but no variable with this name exists"); } Variable v = variables.get(e.getKey()); @@ -4087,6 +4091,13 @@ public SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String return varToUpdate; } + String nameScope = currentNameScope(); + if(nameScope != null){ + if(!newVarName.startsWith(nameScope)){ + newVarName = nameScope + "/" + newVarName; + } + } + val oldVarName = varToUpdate.getVarName(); varToUpdate.setVarName(newVarName); updateVariableName(oldVarName, newVarName); @@ -4504,13 +4515,17 @@ public ByteBuffer asFlatBuffers(long graphId, @NonNull ExecutorConfiguration con reverseMap.put(variable.getVarName(), varIdx); - log.trace("Adding [{}] as [{}]", variable.getVarName(), varIdx); + log.trace("Adding [{}] as [{}]", variable.getVarName(), varIdx); int shape = 0; int name = bufferBuilder.createString(variable.getVarName()); - int array = arr == null ? 0 : arr.toFlatArray(bufferBuilder); + int array = 0; int id = IntPair.createIntPair(bufferBuilder, varIdx, outputNum); - byte varType = (byte)variable.getVariableType().ordinal(); + byte varType = (byte) variable.getVariableType().ordinal(); + if(variable.isConstant() || variable.isPlaceHolder() || variable.getVariableType() == VariableType.VARIABLE) { + //Don't export array type (i.e., activations), these are always replaced/re-calculated on each step + array = arr == null ? 
0 : arr.toFlatArray(bufferBuilder); + } if (variable.getVariableType() == VariableType.PLACEHOLDER) { val shp = variable.getShape(); @@ -4603,7 +4618,6 @@ public ByteBuffer asFlatBuffers(long graphId, @NonNull ExecutorConfiguration con this.variables.get(e.getKey()).setVariableIndex(e.getValue()); } } - return bufferBuilder.dataBuffer(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java index 05fe87b3449b..6c87fb550344 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java @@ -39,9 +39,6 @@ *
 *     <li>The L1 and L2 regularization coefficients (set to 0.0 by default)</li>
 *     <li>The DataSet feature and label mapping - which defines how the feature/label arrays from the DataSet/MultiDataSet
 *     should be associated with SameDiff variables (usually placeholders)</li>
- *     <li>Optional: The names of the trainable parameters. The trainable parameters are inferred automatically if not set here, though
- *     can be overridden if some parameters should not be modified during training (or if the automatic inference of the trainable
- *     parameters is not suitable/correct)</li>
  • * * The TrainingConfig instance also stores the iteration count and the epoch count - these values are updated during training * and are used for example in learning rate schedules. @@ -62,7 +59,6 @@ public class TrainingConfig { private List dataSetLabelMapping; private List dataSetFeatureMaskMapping; private List dataSetLabelMaskMapping; - private List trainableParams; //Will be inferred automatically if null private List lossVariables; private int iterationCount; private int epochCount; @@ -81,7 +77,7 @@ public class TrainingConfig { */ public TrainingConfig(IUpdater updater, List regularization, String dataSetFeatureMapping, String dataSetLabelMapping) { this(updater, regularization, true, Collections.singletonList(dataSetFeatureMapping), Collections.singletonList(dataSetLabelMapping), - Collections.emptyList(), Collections.emptyList(), null, null); + Collections.emptyList(), Collections.emptyList(), null); } /** @@ -98,11 +94,9 @@ public TrainingConfig(IUpdater updater, List regularization, Str * @param dataSetLabelMapping As per dataSetFeatureMapping, but for the DataSet/MultiDataSet labels * @param dataSetFeatureMaskMapping May be null. If non-null, the variables that the MultiDataSet feature mask arrays should be associated with. * @param dataSetLabelMaskMapping May be null. If non-null, the variables that the MultiDataSet label mask arrays should be associated with. - * @param trainableParams May be null. If null: the set of trainable parameters will automatically be inferred from the SameDiff structure. - * If non-null, this defines the set of parameters that should be modified during training */ public TrainingConfig(IUpdater updater, List regularization, boolean minimize, List dataSetFeatureMapping, List dataSetLabelMapping, - List dataSetFeatureMaskMapping, List dataSetLabelMaskMapping, List trainableParams, List lossVariables) { + List dataSetFeatureMaskMapping, List dataSetLabelMaskMapping, List lossVariables) { this.updater = updater; this.regularization = regularization; this.minimize = minimize; @@ -110,7 +104,6 @@ public TrainingConfig(IUpdater updater, List regularization, boo this.dataSetLabelMapping = dataSetLabelMapping; this.dataSetFeatureMaskMapping = dataSetFeatureMaskMapping; this.dataSetLabelMaskMapping = dataSetLabelMaskMapping; - this.trainableParams = trainableParams; this.lossVariables = lossVariables; } @@ -150,7 +143,6 @@ public static class Builder { private List dataSetLabelMapping; private List dataSetFeatureMaskMapping; private List dataSetLabelMaskMapping; - private List trainableParams; //Will be inferred automatically if null private List lossVariables; private boolean skipValidation = false; private boolean markLabelsUnused = false; @@ -365,29 +357,6 @@ public Builder dataSetLabelMaskMapping(List dataSetLabelMaskMapping){ return this; } - /** - * Define the set of trainable parameters for the network.
    - * The trainable parameters are not set by default, which means they will be inferred automatically.
    - * The set of trainable parameters (variables) can be set here - any excluded from being set here won't be - * modified during training - * @param trainableParams Set of parameters/variables to train - */ - public Builder trainableParams(String... trainableParams){ - return trainableParams(Arrays.asList(trainableParams)); - } - - /** - * Define the set of trainable parameters for the network.
    - * The trainable parameters are not set by default, which means they will be inferred automatically.
- * The set of trainable parameters (variables) can be set here - any excluded from being set here won't be - * modified during training - * @param trainableParams Set of parameters/variables to train - */ - public Builder trainableParams(List<String> trainableParams) { - this.trainableParams = trainableParams; - return this; - } - - public Builder skipBuilderValidation(boolean skip){ this.skipValidation = skip; return this; }
@@ -409,7 +378,7 @@ public TrainingConfig build(){ } return new TrainingConfig(updater, regularization, minimize, dataSetFeatureMapping, dataSetLabelMapping, - dataSetFeatureMaskMapping, dataSetLabelMaskMapping, trainableParams, lossVariables); + dataSetFeatureMaskMapping, dataSetLabelMaskMapping, lossVariables); } }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/Variable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/Variable.java index 0d5b2be9d590..670b21dda90b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/Variable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/Variable.java
@@ -31,7 +31,6 @@ public class Variable { protected String name; protected SDVariable variable; - protected Object shapeInfo; //TODO decide type, or if even to include (Variable class should ideally be immutable) protected List<String> inputsForOp; protected List<String> controlDepsForOp; //if an op control dependency (x -> opY) exists, then "opY" will be in this list protected List<String> controlDepsForVar; //if a variable control dependency (x -> varY) exists, then "varY" will be in this list
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java index 594db7f92453..30ed1b3f75f7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java
@@ -647,9 +647,9 @@ private String stats(boolean suppressWarnings, boolean includeConfusion, boolean int nClasses = confusion.getClasses().size(); DecimalFormat df = new DecimalFormat("0.0000"); double acc = accuracy(); - double precisionMacro = precision(EvaluationAveraging.Macro); - double recallMacro = recall(EvaluationAveraging.Macro); - double f1Macro = f1(EvaluationAveraging.Macro); + double precision = precision(); //Macro precision for N>2, or binary class 1 (only) precision by default + double recall = recall(); //Macro recall for N>2, or binary class 1 (only) recall by default + double f1 = f1(); //Macro F1 for N>2, or binary class 1 (only) F1 by default builder.append("\n========================Evaluation Metrics========================"); builder.append("\n # of classes: ").append(nClasses); builder.append("\n Accuracy: ").append(format(df, acc)); if (topN > 1) { double topNAcc = topNAccuracy(); builder.append("\n Top ").append(topN).append(" Accuracy: ").append(format(df, topNAcc)); } - builder.append("\n Precision: ").append(format(df, precisionMacro)); + builder.append("\n Precision: ").append(format(df, precision)); if (nClasses > 2 && averagePrecisionNumClassesExcluded() > 0) { int ex =
averagePrecisionNumClassesExcluded(); builder.append("\t(").append(ex).append(" class"); @@ -665,7 +665,7 @@ private String stats(boolean suppressWarnings, boolean includeConfusion, boolean builder.append("es"); builder.append(" excluded from average)"); } - builder.append("\n Recall: ").append(format(df, recallMacro)); + builder.append("\n Recall: ").append(format(df, recall)); if (nClasses > 2 && averageRecallNumClassesExcluded() > 0) { int ex = averageRecallNumClassesExcluded(); builder.append("\t(").append(ex).append(" class"); @@ -673,7 +673,7 @@ private String stats(boolean suppressWarnings, boolean includeConfusion, boolean builder.append("es"); builder.append(" excluded from average)"); } - builder.append("\n F1 Score: ").append(format(df, f1Macro)); + builder.append("\n F1 Score: ").append(format(df, f1)); if (nClasses > 2 && averageF1NumClassesExcluded() > 0) { int ex = averageF1NumClassesExcluded(); builder.append("\t(").append(ex).append(" class"); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java index 24a4b7e51b80..8c2224df1430 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java @@ -3762,6 +3762,32 @@ public static long[] reductionShape(INDArray x, int[] dimension, boolean newForm return retShape; } + /** + * Determine whether the placeholder shape and the specified shape are compatible.
    + * Shapes are compatible if:
    + * (a) They are both the same length (same array rank, or null)
    + * (b) At each position either phShape[i] == -1 or phShape[i] == arrShape[i] + * + * @param phShape Placeholder shape + * @param arrShape Array shape to check if it matches the placeholder shape + * @return True if the array shape is compatible with the placeholder shape + */ + public static boolean shapeMatchesPlaceholder(long[] phShape, long[] arrShape) { + if (phShape == null && arrShape == null) + return true; //Rank 0? + if (phShape == null || arrShape == null) + return false; + if (phShape.length != arrShape.length) + return false; + for (int i = 0; i < phShape.length; i++) { + if (phShape[i] > 0) {//for <0 case: Any value for this dimension is OK (i.e., -1s) + if (phShape[i] != arrShape[i]) { + return false; + } + } + } + return true; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/abstracts/Nd4jWorkspace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/abstracts/Nd4jWorkspace.java index 78897b5f8090..be694234dc06 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/abstracts/Nd4jWorkspace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/abstracts/Nd4jWorkspace.java @@ -157,7 +157,7 @@ public Nd4jWorkspace(@NonNull WorkspaceConfiguration configuration, @NonNull Str "For cyclic workspace overallocation should be positive integral value."); stepsNumber = (int) (workspaceConfiguration.getOverallocationLimit() + 1); - log.debug("Steps: {}", stepsNumber); + log.trace("Steps: {}", stepsNumber); } //if (workspaceConfiguration.getPolicyLearning() == LearningPolicy.OVER_TIME && workspaceConfiguration.getCyclesBeforeInitialization() < 1) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java new file mode 100644 index 000000000000..c5ca3aa13d34 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java @@ -0,0 +1,132 @@ +package org.nd4j.autodiff.samediff; + +import org.junit.Test; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.internal.Variable; +import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.factory.Nd4jBackend; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class NameScopeTests extends BaseNd4jTest { + + public NameScopeTests(Nd4jBackend b){ + super(b); + } + + @Override + public char ordering(){ + return 'c'; + } + + @Test + public void testVariableNameScopesBasic(){ + + SameDiff sd = SameDiff.create(); + SDVariable v = sd.var("x"); + try(NameScope ns = sd.withNameScope("nameScope")){ + SDVariable v2 = sd.var("x2"); + assertEquals("nameScope/x2", v2.getVarName()); + assertTrue(sd.getVariables().containsKey("nameScope/x2")); + assertEquals("nameScope", sd.currentNameScope()); + + SDVariable v3 = sd.var("x"); + assertEquals("nameScope/x", v3.getVarName()); + assertTrue(sd.getVariables().containsKey("nameScope/x")); + + try(NameScope ns2 = sd.withNameScope("scope2")){ + assertEquals("nameScope/scope2", sd.currentNameScope()); + SDVariable v4 = sd.var("x"); + assertEquals("nameScope/scope2/x", v4.getVarName()); + 
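The new Shape.shapeMatchesPlaceholder utility introduced above treats any non-positive placeholder dimension as a wildcard (the guard is phShape[i] > 0, so a -1 entry matches any size at that position). A short sketch of the intended semantics; the shapes are illustrative only:

    import org.nd4j.linalg.api.shape.Shape;

    public class PlaceholderShapeSketch {
        public static void main(String[] args) {
            long[] ph = {-1, 4};  //e.g. a [minibatch, 4] placeholder
            System.out.println(Shape.shapeMatchesPlaceholder(ph, new long[]{3, 4}));     //true: -1 matches 3, and 4 == 4
            System.out.println(Shape.shapeMatchesPlaceholder(ph, new long[]{3, 5}));     //false: 4 != 5
            System.out.println(Shape.shapeMatchesPlaceholder(ph, new long[]{3, 4, 5}));  //false: rank mismatch
            System.out.println(Shape.shapeMatchesPlaceholder(null, null));               //true: both null (rank 0 case)
        }
    }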
assertTrue(sd.getVariables().containsKey("nameScope/scope2/x")); + } + + assertEquals("nameScope", sd.currentNameScope()); + } + } + + @Test + public void testOpFieldsAndNames(){ + + SameDiff sd = SameDiff.create(); + SDVariable x = sd.var("x", DataType.FLOAT, 1); + SDVariable y; + SDVariable z; + + SDVariable add; + SDVariable addWithName; + SDVariable merge; + SDVariable mergeWithName; + try(NameScope ns = sd.withNameScope("s1")){ + y = sd.var("y", DataType.FLOAT, 1); + add = x.add(y); + addWithName = x.add("addxy", y); + try(NameScope ns2 = sd.withNameScope("s2")){ + z = sd.var("z", DataType.FLOAT, 1); + merge = sd.math().mergeMax(y, z); + mergeWithName = sd.math.mergeMax("mmax", y, z); + } + } + SDVariable a = sd.var("a", DataType.FLOAT, 1); + + assertEquals("x", x.getVarName()); + assertEquals("s1/y", y.getVarName()); + assertEquals("s1/s2/z", z.getVarName()); + assertEquals("a", a.getVarName()); + + assertTrue(add.getVarName(), add.getVarName().startsWith("s1/")); + assertEquals("s1/addxy", addWithName.getVarName()); + + assertTrue(merge.getVarName(), merge.getVarName().startsWith("s1/s2/")); + assertEquals("s1/s2/mmax", mergeWithName.getVarName()); + + Set allowedVarNames = new HashSet<>(Arrays.asList("x", "s1/y", "s1/s2/z", "a", + add.getVarName(), addWithName.getVarName(), merge.getVarName(), mergeWithName.getVarName())); + Set allowedOpNames = new HashSet<>(); + + //Check op names: + Map ops = sd.getOps(); + System.out.println(ops.keySet()); + + for(String s : ops.keySet()){ + assertTrue(s, s.startsWith("s1") || s.startsWith("s1/s2")); + allowedOpNames.add(s); + } + + //Check fields - Variable, SDOp, etc + for(Variable v : sd.getVariables().values()){ + assertTrue(v.getVariable().getVarName(), allowedVarNames.contains(v.getVariable().getVarName())); + assertEquals(v.getName(), v.getVariable().getVarName()); + if(v.getInputsForOp() != null){ + for(String s : v.getInputsForOp()){ + assertTrue(s, allowedOpNames.contains(s)); + } + } + + if(v.getOutputOfOp() != null){ + assertTrue(allowedOpNames.contains(v.getOutputOfOp())); + } + } + + assertTrue(allowedOpNames.containsAll(sd.getOps().keySet())); + + for(SameDiffOp op : sd.getOps().values()){ + assertTrue(allowedOpNames.contains(op.getName())); + assertEquals(op.getName(), op.getOp().getOwnName()); + if(op.getInputsToOp() != null){ + assertTrue(allowedVarNames.containsAll(op.getInputsToOp())); + } + + if(op.getOutputsOfOp() != null){ + assertTrue(allowedVarNames.containsAll(op.getOutputsOfOp())); + } + } + } +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index 862d49e50a8c..1ac93a833de3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -47,6 +47,7 @@ import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.checkutil.NDArrayCreationUtil; import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.dataset.MultiDataSet; import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -211,13 +212,13 @@ public void testSaveWriteWithTrainingConfig() throws Exception { SameDiff sameDiff1 = SameDiff.restoreFromTrainingConfigZip(newFile); assertEquals(sameDiff.getTrainingConfig().getUpdater(), 
sameDiff1.getTrainingConfig().getUpdater()); - assertEquals(sameDiff.getUpdaterState(), sameDiff1.getUpdaterState()); + assertEquals(sameDiff.getUpdaterStates(), sameDiff1.getUpdaterStates()); sameDiff.saveWithTrainingConfig(newFile); sameDiff1 = SameDiff.restoreFromTrainingConfigZip(newFile); assertEquals(sameDiff.getTrainingConfig().getUpdater(), sameDiff1.getTrainingConfig().getUpdater()); - assertEquals(sameDiff.getUpdaterState(), sameDiff1.getUpdaterState()); + assertEquals(sameDiff.getUpdaterStates(), sameDiff1.getUpdaterStates()); } @@ -2745,11 +2746,6 @@ public void testConvertToConstant() { INDArray out = tanh.eval(); - List tp = c.getTrainableParams(); - assertEquals(2, tp.size()); - assertTrue(tp.contains("w")); - assertTrue(tp.contains("b")); - w.convertToConstant(); INDArray out2 = tanh.eval(); @@ -2790,11 +2786,6 @@ public void testConvertToVariable() { INDArray out = tanh.eval(); sd.fit(new SingletonMultiDataSetIterator(new DataSet(inArr, null).toMultiDataSet()), 1); - List tp = c.getTrainableParams(); - assertEquals(1, tp.size()); - assertFalse(tp.contains("w")); - assertTrue(tp.contains("b")); - w.convertToVariable(); INDArray out2 = tanh.eval(); @@ -3155,4 +3146,64 @@ public void testVariableRenaming2(){ v3.rename("newName"); sd.fit(new DataSet(Nd4j.rand(DataType.FLOAT, 3, 4), null)); } + + @Test + public void testPlaceholderShapeValidation(){ + SameDiff sd = SameDiff.create(); + SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, 3, 4); + SDVariable ph2 = sd.placeHolder("ph2", DataType.FLOAT, -1, 4); + SDVariable ph3 = sd.placeHolder("ph3", DataType.FLOAT, 3, -1); + SDVariable ph4 = sd.placeHolder("ph4", DataType.FLOAT, -1, -1); + + INDArray correctShape = Nd4j.create(DataType.FLOAT, 3, 4); + INDArray wrongShape = Nd4j.create(DataType.FLOAT, 2, 3); + INDArray wrongRank1 = Nd4j.create(DataType.FLOAT, 1); + INDArray wrongRank2 = Nd4j.create(DataType.FLOAT, 3, 4, 5); + for(SDVariable v : new SDVariable[]{ph1, ph2, ph3, ph4}){ + v.setArray(correctShape); + + if(v != ph4) { + try { + v.setArray(wrongShape); + fail("Expected exception"); + } catch (Exception t) { + String msg = t.getMessage(); + assertTrue(msg, msg.contains("shape") && msg.contains("[2, 3]") && msg.contains(Arrays.toString(v.placeholderShape()))); + } + } + + try{ + v.setArray(wrongRank1); + fail("Expected exception"); + } catch (Exception t){ + String msg = t.getMessage(); + assertTrue(msg, msg.contains("shape") && msg.contains("[1]") && msg.contains(Arrays.toString(v.placeholderShape()))); + } + + try{ + v.setArray(wrongRank2); + fail("Expected exception"); + } catch (Exception t){ + String msg = t.getMessage(); + assertTrue(msg, msg.contains("shape") && msg.contains("[3, 4, 5]") && msg.contains(Arrays.toString(v.placeholderShape()))); + } + } + + //Also try training: + SDVariable sum = sd.math.mergeAdd(ph1, ph2, ph3, ph4); + SDVariable mean = sum.mean(); + MultiDataSet mds = new MultiDataSet(new INDArray[]{wrongShape, wrongShape, wrongShape, wrongShape}, null); + + sd.setTrainingConfig(TrainingConfig.builder() + .dataSetFeatureMapping("ph1", "ph2", "ph3", "ph4") + .markLabelsUnused() + .updater(new Adam(1e-3)).build()); + + try{ + sd.fit(mds); + } catch (Exception t){ + String msg = t.getMessage(); + assertTrue(msg, msg.contains("shape") && msg.contains("[2, 3]")); + } + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java index 
1dbf3ad3db52..1526381f2697 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java @@ -23,6 +23,7 @@ import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.IrisDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; @@ -123,6 +124,70 @@ public void irisTrainingSanityCheck() { } } + + + @Test + public void testTrainingMixedDtypes(){ + + for (String u : new String[]{"adam", "nesterov", "adamax", "amsgrad"}) { + + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4); + + SDVariable inHalf = in.castTo(DataType.HALF); + SDVariable inDouble = in.castTo(DataType.DOUBLE); + + SDVariable wFloat = sd.var("wFloat", Nd4j.rand(DataType.FLOAT, 4, 3)); + SDVariable wDouble = sd.var("wDouble", Nd4j.rand(DataType.DOUBLE, 4, 3)); + SDVariable wHalf = sd.var("wHalf", Nd4j.rand(DataType.HALF, 4, 3)); + + SDVariable outFloat = in.mmul(wFloat); + SDVariable outDouble = inDouble.mmul(wDouble); + SDVariable outHalf = inHalf.mmul(wHalf); + + SDVariable sum = outFloat.add(outDouble.castTo(DataType.FLOAT)).add(outHalf.castTo(DataType.FLOAT)); + + SDVariable loss = sum.std(true); + + IUpdater updater; + switch (u) { + case "sgd": + updater = new Sgd(1e-2); + break; + case "adam": + updater = new Adam(1e-2); + break; + case "nesterov": + updater = new Nesterovs(1e-2); + break; + case "adamax": + updater = new AdaMax(1e-2); + break; + case "amsgrad": + updater = new AMSGrad(1e-2); + break; + default: + throw new RuntimeException(); + } + + TrainingConfig conf = new TrainingConfig.Builder() + .l2(1e-4) + .updater(updater) + .dataSetFeatureMapping("in") + .markLabelsUnused() + .build(); + + sd.setTrainingConfig(conf); + + DataSet ds = new DataSet(Nd4j.rand(DataType.FLOAT, 3, 4), null); + + for( int i=0; i<10; i++ ){ + sd.fit(ds); + } + } + + } + @Override public char ordering() { return 'c'; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java new file mode 100644 index 000000000000..cf99ebbaa838 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java @@ -0,0 +1,232 @@ +package org.nd4j.autodiff.samediff.listeners; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.nd4j.autodiff.listeners.checkpoint.CheckpointListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.TrainingConfig; +import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.dataset.IrisDataSetIterator; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.factory.Nd4jBackend; +import org.nd4j.linalg.learning.config.Adam; +import org.nd4j.linalg.primitives.Pair; + +import java.io.File; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import 
java.util.concurrent.TimeUnit; + +import static junit.framework.TestCase.assertTrue; +import static org.junit.Assert.assertEquals; + +public class CheckpointListenerTest extends BaseNd4jTest { + + public CheckpointListenerTest(Nd4jBackend backend){ + super(backend); + } + + @Override + public char ordering(){ + return 'c'; + } + + @Rule + public TemporaryFolder testDir = new TemporaryFolder(); + + public static SameDiff getModel(){ + Nd4j.getRandom().setSeed(12345); + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4); + SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3); + SDVariable w = sd.var("W", Nd4j.rand(DataType.FLOAT, 4, 3)); + SDVariable b = sd.var("b", DataType.FLOAT, 3); + + SDVariable mmul = in.mmul(w).add(b); + SDVariable softmax = sd.nn().softmax(mmul); + SDVariable loss = sd.loss().logLoss("loss", label, softmax); + + sd.setTrainingConfig(TrainingConfig.builder() + .dataSetFeatureMapping("in") + .dataSetLabelMapping("label") + .updater(new Adam(1e-2)) + .weightDecay(1e-2, true) + .build()); + + return sd; + } + + public static DataSetIterator getIter(){ + return new IrisDataSetIterator(15, 150); + } + + + @Test + public void testCheckpointEveryEpoch() throws Exception { + File dir = testDir.newFolder(); + + SameDiff sd = getModel(); + CheckpointListener l = CheckpointListener.builder(dir) + .saveEveryNEpochs(1) + .build(); + + sd.setListeners(l); + + DataSetIterator iter = getIter(); + sd.fit(iter, 3); + + File[] files = dir.listFiles(); + String s1 = "checkpoint-0_epoch-0_iter-9"; //Note: epoch is 10 iterations, 0-9, 10-19, 20-29, etc + String s2 = "checkpoint-1_epoch-1_iter-19"; + String s3 = "checkpoint-2_epoch-2_iter-29"; + boolean found1 = false; + boolean found2 = false; + boolean found3 = false; + for(File f : files){ + String s = f.getAbsolutePath(); + if(s.contains(s1)) + found1 = true; + if(s.contains(s2)) + found2 = true; + if(s.contains(s3)) + found3 = true; + } + assertEquals(4, files.length); //3 checkpoints and 1 text file (metadata) + assertTrue(found1 && found2 && found3); + } + + @Test + public void testCheckpointEvery5Iter() throws Exception { + File dir = testDir.newFolder(); + + SameDiff sd = getModel(); + CheckpointListener l = CheckpointListener.builder(dir) + .saveEveryNIterations(5) + .build(); + + sd.setListeners(l); + + DataSetIterator iter = getIter(); + sd.fit(iter, 2); //2 epochs = 20 iter + + File[] files = dir.listFiles(); + List names = Arrays.asList( + "checkpoint-0_epoch-0_iter-4", + "checkpoint-1_epoch-0_iter-9", + "checkpoint-2_epoch-1_iter-14", + "checkpoint-3_epoch-1_iter-19"); + boolean[] found = new boolean[names.size()]; + for(File f : files){ + String s = f.getAbsolutePath(); + System.out.println(s); + for( int i=0; i names = Arrays.asList( + "checkpoint-2_epoch-3_iter-30", + "checkpoint-3_epoch-4_iter-40"); + boolean[] found = new boolean[names.size()]; + for(File f : files){ + String s = f.getAbsolutePath(); + System.out.println(s); + for( int i=0; i cpNums = new HashSet<>(); + Set epochNums = new HashSet<>(); + for(File f2 : files){ + if(!f2.getPath().endsWith(".bin")){ + continue; + } + count++; + int idx = f2.getName().indexOf("epoch-"); + int end = f2.getName().indexOf("_", idx); + int num = Integer.parseInt(f2.getName().substring(idx + "epoch-".length(), end)); + epochNums.add(num); + + int start = f2.getName().indexOf("checkpoint-"); + end = f2.getName().indexOf("_", start + "checkpoint-".length()); + int epochNum = Integer.parseInt(f2.getName().substring(start + 
"checkpoint-".length(), end)); + cpNums.add(epochNum); + } + + assertEquals(cpNums.toString(), 5, cpNums.size()); + Assert.assertTrue(cpNums.toString(), cpNums.containsAll(Arrays.asList(2, 5, 7, 8, 9))); + Assert.assertTrue(epochNums.toString(), epochNums.containsAll(Arrays.asList(5, 11, 15, 17, 19))); + + assertEquals(5, l.availableCheckpoints().size()); + } +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalTest.java index 2ccef06bf881..2aac466d1696 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalTest.java @@ -28,6 +28,7 @@ import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.util.FeatureUtil; +import java.text.DecimalFormat; import java.util.*; import static org.junit.Assert.*; @@ -1044,6 +1045,49 @@ public void testLabelReset(){ String s2 = e1.stats(); assertEquals(s1, s2); + } + + @Test + public void testEvalStatsBinaryCase(){ + //Make sure we report class 1 precision/recall/f1 not macro averaged, for binary case + + Evaluation e = new Evaluation(); + + INDArray l0 = Nd4j.createFromArray(new double[]{1,0}).reshape(1,2); + INDArray l1 = Nd4j.createFromArray(new double[]{0,1}).reshape(1,2); + + e.eval(l1, l1); + e.eval(l1, l1); + e.eval(l1, l1); + e.eval(l0, l0); + e.eval(l1, l0); + e.eval(l1, l0); + e.eval(l0, l1); + + double tp = 3; + double fp = 1; + double fn = 2; + + double prec = tp / (tp + fp); + double rec = tp / (tp + fn); + double f1 = 2 * prec * rec / (prec + rec); + + assertEquals(prec, e.precision(), 1e-6); + assertEquals(rec, e.recall(), 1e-6); + + DecimalFormat df = new DecimalFormat("0.0000"); + + String stats = e.stats(); + //System.out.println(stats); + + String stats2 = stats.replaceAll("( )+", " "); + + String recS = " Recall: " + df.format(rec); + String preS = " Precision: " + df.format(prec); + String f1S = "F1 Score: " + df.format(f1); + assertTrue(stats2, stats2.contains(recS)); + assertTrue(stats2, stats2.contains(preS)); + assertTrue(stats2, stats2.contains(f1S)); } } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java index 47af1e28a825..fbfbd1aa5be3 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java @@ -1349,6 +1349,15 @@ public byte[] asBytes() { throw new RuntimeException(e); } break; + case UTF8: + byte[] temp4 = new byte[(int)length]; + asNio().get(temp4); + try { + dos.write(temp4); + } catch (IOException e){ + throw new RuntimeException(e); + } + break; default: throw new UnsupportedOperationException("Unknown data type: [" + dataType + "]"); }