Skip to content

Commit

Permalink
Fixes and SameDiff functionality (#7807)
Browse files Browse the repository at this point in the history
* #6992 SameDiff mixed precision training support

* Placeholder shape validation

* Checkpoint listener

* SameDiff checkpoint listener

* SameDiff: Remove no longer required trainable params config from TrainingConfig

* SameDiff: add name scopes

* SameDiff name scopes - javadoc and tests

* #7802 Evaluation class - report single class not macro avg in stats() for binary case

* 7804 Arbiter - update score functions to use ND4J evaluation metric enums

* SameDiff flatbuffers export: don't export arrays for array type variables (not required)
  • Loading branch information
AlexDBlack committed May 29, 2019
1 parent ce51666 commit 1e2bcc1
Show file tree
Hide file tree
Showing 20 changed files with 1,492 additions and 236 deletions.
Expand Up @@ -18,9 +18,9 @@

import lombok.*;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

Expand All @@ -37,6 +37,13 @@ public class EvaluationScoreFunction extends BaseNetScoreFunction {

protected Evaluation.Metric metric;

/**
* @param metric Evaluation metric to calculate
*/
public EvaluationScoreFunction(@NonNull org.deeplearning4j.eval.Evaluation.Metric metric) {
this(metric.toNd4j());
}

/**
* @param metric Evaluation metric to calculate
*/
Expand Down
Expand Up @@ -19,12 +19,11 @@
import lombok.*;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.ROC;
import org.deeplearning4j.eval.ROCBinary;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

Expand Down
Expand Up @@ -17,18 +17,13 @@
package org.deeplearning4j.arbiter.scoring.impl;

import lombok.*;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.deeplearning4j.arbiter.scoring.RegressionValue;
import org.deeplearning4j.arbiter.scoring.util.ScoreUtil;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

import java.util.Properties;

/**
* Score function for regression (including multi-label regression) for a MultiLayerNetwork or ComputationGraph
* on a test set. Supports all regression metrics: {@link RegressionEvaluation.Metric}
Expand All @@ -42,6 +37,10 @@ public class RegressionScoreFunction extends BaseNetScoreFunction {

protected RegressionEvaluation.Metric metric;

public RegressionScoreFunction(@NonNull org.deeplearning4j.eval.RegressionEvaluation.Metric metric) {
this(metric.toNd4j());
}

public RegressionScoreFunction(@NonNull RegressionEvaluation.Metric metric) {
this.metric = metric;
}
Expand Down
Expand Up @@ -644,9 +644,14 @@ protected void setInstanceId() {
this.ownName = UUID.randomUUID().toString();
else {
int argIndex = 0;
String varName = sameDiff.generateNewVarName(opName(),argIndex);
String scope = sameDiff.currentNameScope();
if(scope == null)
scope = "";
else
scope = scope + "/";
String varName = scope + sameDiff.generateNewVarName(opName(),argIndex);
while(sameDiff.functionExists(varName)) {
varName = sameDiff.generateNewVarName(opName(), argIndex);
varName = scope + sameDiff.generateNewVarName(opName(), argIndex);
argIndex++;
}

Expand Down
@@ -0,0 +1,61 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.nd4j.autodiff.listeners.checkpoint;

import lombok.AllArgsConstructor;
import lombok.Data;

import java.io.Serializable;
import java.util.Arrays;

/**
* A model checkpoint, used with {@link CheckpointListener}
*
* @author Alex Black
*/
@AllArgsConstructor
@Data
public class Checkpoint implements Serializable {

private int checkpointNum;
private long timestamp;
private int iteration;
private int epoch;
private String filename;

public static String getFileHeader(){
return "checkpointNum,timestamp,iteration,epoch,filename";
}

public static Checkpoint fromFileString(String str){
String[] split = str.split(",");
if(split.length != 5){
throw new IllegalStateException("Cannot parse checkpoint entry: expected 5 entries, got " + split.length
+ " - values = " + Arrays.toString(split));
}
return new Checkpoint(
Integer.parseInt(split[0]),
Long.parseLong(split[1]),
Integer.parseInt(split[2]),
Integer.parseInt(split[3]),
split[4]);
}

public String toFileString(){
return checkpointNum + "," + timestamp + "," + iteration + "," + epoch + "," + filename;
}
}

0 comments on commit 1e2bcc1

Please sign in to comment.