Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Arbiter data usability #5952

Merged
merged 3 commits into from Jul 24, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -17,11 +17,13 @@
package org.deeplearning4j.arbiter.optimize.api;

import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;

import java.util.List;
import java.util.Properties;
import java.util.concurrent.Callable;

/**
Expand All @@ -41,6 +43,20 @@ public interface TaskCreator {
* @param statusListeners Status listeners, that can be used for callbacks (to UI, for example)
* @return A callable that returns an OptimizationResult, once optimization is complete
*/
@Deprecated
Callable<OptimizationResult> create(Candidate candidate, DataProvider dataProvider, ScoreFunction scoreFunction,
List<StatusListener> statusListeners, IOptimizationRunner runner);

/**
* Generate a callable that can be executed to conduct the training of this model (given the model configuration)
*
* @param candidate Candidate (model) configuration to be trained
* @param dataSource Data source
* @param dataSourceProperties Properties (may be null) for the data source
* @param scoreFunction Score function to be used to evaluate the model
* @param statusListeners Status listeners, that can be used for callbacks (to UI, for example)
* @return A callable that returns an OptimizationResult, once optimization is complete
*/
Callable<OptimizationResult> create(Candidate candidate, Class<? extends DataSource> dataSource, Properties dataSourceProperties,
ScoreFunction scoreFunction, List<StatusListener> statusListeners, IOptimizationRunner runner);
}
Expand Up @@ -24,9 +24,11 @@

/**
* DataProvider interface abstracts out the providing of data
* @deprecated Use {@link DataSource}
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
@Deprecated
public interface DataProvider extends Serializable {

/**
Expand Down
@@ -0,0 +1,41 @@
package org.deeplearning4j.arbiter.optimize.api.data;

import java.io.Serializable;
import java.util.Properties;

/**
* DataSource: defines where the data should come from for training and testing.
* Note that implementations must have a no-argument contsructor
*
* @author Alex Black
*/
public interface DataSource extends Serializable {

/**
* Configure the current data source with the specified properties
* Note: These properties are fixed for the training instance, and are optionally provided by the user
* at the configuration stage.
* The properties could be anything - and are usually specific to each DataSource implementation.
* For example, values such as batch size could be set using these properties
* @param properties Properties to apply to the data source instance
*/
void configure(Properties properties);

/**
* Get test data to be used for the optimization. Usually a DataSetIterator or MultiDataSetIterator
*/
Object trainData();

/**
* Get test data to be used for the optimization. Usually a DataSetIterator or MultiDataSetIterator
*/
Object testData();

/**
* The type of data returned by {@link #trainData()} and {@link #testData()}.
* Usually DataSetIterator or MultiDataSetIterator
* @return Class of the objects returned by trainData and testData
*/
Class<?> getDataType();

}
Expand Up @@ -17,12 +17,14 @@
package org.deeplearning4j.arbiter.optimize.api.score;

import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;

import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.Properties;

/**
* ScoreFunction defines the objective of hyperparameter optimization.
Expand All @@ -44,6 +46,16 @@ public interface ScoreFunction extends Serializable {
*/
double score(Object model, DataProvider dataProvider, Map<String, Object> dataParameters);

/**
* Calculate and return the score, for the given model and data provider
*
* @param model Model to score
* @param dataSource Data source
* @param dataSourceProperties data source properties
* @return Calculated score
*/
double score(Object model, Class<? extends DataSource> dataSource, Properties dataSourceProperties);

/**
* Should this score function be minimized or maximized?
*
Expand Down
Expand Up @@ -19,6 +19,7 @@
import lombok.*;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
Expand All @@ -28,8 +29,10 @@
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

import java.io.IOException;
import java.lang.reflect.Constructor;
import java.util.Arrays;
import java.util.List;
import java.util.Properties;

/**
* OptimizationConfiguration ties together all of the various
Expand All @@ -46,6 +49,10 @@ public class OptimizationConfiguration {
@JsonSerialize
private DataProvider dataProvider;
@JsonSerialize
private Class<? extends DataSource> dataSource;
@JsonSerialize
private Properties dataSourceProperties;
@JsonSerialize
private CandidateGenerator candidateGenerator;
@JsonSerialize
private ResultSaver resultSaver;
Expand All @@ -63,6 +70,8 @@ public class OptimizationConfiguration {

private OptimizationConfiguration(Builder builder) {
this.dataProvider = builder.dataProvider;
this.dataSource = builder.dataSource;
this.dataSourceProperties = builder.dataSourceProperties;
this.candidateGenerator = builder.candidateGenerator;
this.resultSaver = builder.resultSaver;
this.scoreFunction = builder.scoreFunction;
Expand All @@ -74,22 +83,49 @@ private OptimizationConfiguration(Builder builder) {

//Validate the configuration: data types, score types, etc
//TODO

//Validate that the dataSource has a no-arg constructor
if(dataSource != null){
try{
dataSource.getConstructor();
} catch (NoSuchMethodException e){
throw new IllegalStateException("Data source class " + dataSource.getName() + " does not have a public no-argument constructor");
}
}
}

public static class Builder {

private DataProvider dataProvider;
private Class<? extends DataSource> dataSource;
private Properties dataSourceProperties;
private CandidateGenerator candidateGenerator;
private ResultSaver resultSaver;
private ScoreFunction scoreFunction;
private List<TerminationCondition> terminationConditions;
private Long rngSeed;

/**
* @deprecated Use {@link #dataSource(Class, Properties)}
*/
@Deprecated
public Builder dataProvider(DataProvider dataProvider) {
this.dataProvider = dataProvider;
return this;
}

/**
* DataSource: defines where the data should come from for training and testing.
* Note that implementations must have a no-argument contsructor
* @param dataSource Class for the data source
* @param dataSourceProperties May be null. Properties for configuring the data source
*/
public Builder dataSource(Class<? extends DataSource> dataSource, Properties dataSourceProperties){
this.dataSource = dataSource;
this.dataSourceProperties = dataSourceProperties;
return this;
}

public Builder candidateGenerator(CandidateGenerator candidateGenerator) {
this.candidateGenerator = candidateGenerator;
return this;
Expand Down
Expand Up @@ -24,16 +24,14 @@
import org.deeplearning4j.arbiter.optimize.api.Candidate;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
Expand Down Expand Up @@ -149,8 +147,12 @@ public void execute() {
status = processFailedCandidates(candidate);
} else {
long created = System.currentTimeMillis();
ListenableFuture<OptimizationResult> f =
execute(candidate, config.getDataProvider(), config.getScoreFunction());
ListenableFuture<OptimizationResult> f;
if(config.getDataSource() != null){
f = execute(candidate, config.getDataSource(), config.getDataSourceProperties(), config.getScoreFunction());
} else {
f = execute(candidate, config.getDataProvider(), config.getScoreFunction());
}
f.addListener(new OnCompletionListener(f), futureListenerExecutor);
queuedFutures.add(f);
totalCandidateCount.getAndIncrement();
Expand Down Expand Up @@ -366,9 +368,16 @@ public void run() {

protected abstract int maxConcurrentTasks();

@Deprecated
protected abstract ListenableFuture<OptimizationResult> execute(Candidate candidate, DataProvider dataProvider,
ScoreFunction scoreFunction);

@Deprecated
protected abstract List<ListenableFuture<OptimizationResult>> execute(List<Candidate> candidates,
DataProvider dataProvider, ScoreFunction scoreFunction);

protected abstract ListenableFuture<OptimizationResult> execute(Candidate candidate, Class<? extends DataSource> dataSource,
Properties dataSourceProperties, ScoreFunction scoreFunction);

protected abstract List<ListenableFuture<OptimizationResult>> execute(List<Candidate> candidates, Class<? extends DataSource> dataSource,
Properties dataSourceProperties, ScoreFunction scoreFunction);
}
Expand Up @@ -22,12 +22,14 @@
import lombok.Setter;
import org.deeplearning4j.arbiter.optimize.api.*;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong;

Expand Down Expand Up @@ -117,6 +119,21 @@ protected List<ListenableFuture<OptimizationResult>> execute(List<Candidate> can
return list;
}

@Override
protected ListenableFuture<OptimizationResult> execute(Candidate candidate, Class<? extends DataSource> dataSource, Properties dataSourceProperties, ScoreFunction scoreFunction) {
return execute(Collections.singletonList(candidate), dataSource, dataSourceProperties, scoreFunction).get(0);
}

@Override
protected List<ListenableFuture<OptimizationResult>> execute(List<Candidate> candidates, Class<? extends DataSource> dataSource, Properties dataSourceProperties, ScoreFunction scoreFunction) {
List<ListenableFuture<OptimizationResult>> list = new ArrayList<>(candidates.size());
for (Candidate candidate : candidates) {
Callable<OptimizationResult> task = taskCreator.create(candidate, dataSource, dataSourceProperties, scoreFunction, statusListeners, this);
list.add(executor.submit(task));
}
return list;
}

@Override
public void shutdown(boolean awaitTermination) {
if(awaitTermination){
Expand Down
Expand Up @@ -21,6 +21,7 @@
import org.deeplearning4j.arbiter.optimize.api.*;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
Expand Down Expand Up @@ -172,6 +173,11 @@ public double score(Object m, DataProvider data, Map<String, Object> dataParamet
return a * Math.pow(x2 - b * x1 * x1 + c * x1 - r, 2.0) + s * (1 - t) * Math.cos(x1) + s;
}

@Override
public double score(Object model, Class<? extends DataSource> dataSource, Properties dataSourceProperties) {
throw new UnsupportedOperationException();
}

@Override
public boolean minimize() {
return true;
Expand Down Expand Up @@ -200,7 +206,7 @@ public OptimizationResult call() throws Exception {

BraninConfig candidate = (BraninConfig) c.getValue();

double score = scoreFunction.score(candidate, null, null);
double score = scoreFunction.score(candidate, null, (Map)null);
System.out.println(candidate.getX1() + "\t" + candidate.getX2() + "\t" + score);

Thread.sleep(20);
Expand All @@ -218,5 +224,10 @@ public OptimizationResult call() throws Exception {
}
};
}

@Override
public Callable<OptimizationResult> create(Candidate candidate, Class<? extends DataSource> dataSource, Properties dataSourceProperties, ScoreFunction scoreFunction, List<StatusListener> statusListeners, IOptimizationRunner runner) {
throw new UnsupportedOperationException();
}
}
}
Expand Up @@ -18,16 +18,19 @@

import lombok.EqualsAndHashCode;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Properties;

/**
* Created by Alex on 23/07/2017.
Expand All @@ -39,6 +42,24 @@ public abstract class BaseNetScoreFunction implements ScoreFunction {
@Override
public double score(Object model, DataProvider dataProvider, Map<String, Object> dataParameters) {
Object testData = dataProvider.testData(dataParameters);
return score(model, testData);
}

@Override
public double score(Object model, Class<? extends DataSource> dataSource, Properties dataSourceProperties) {
DataSource ds;
try{
ds = dataSource.newInstance();
if (dataSourceProperties != null) {
ds.configure(dataSourceProperties);
}
} catch (Exception e){
throw new RuntimeException(e);
}
return score(model, ds.testData());
}

protected double score(Object model, Object testData){
if (model instanceof MultiLayerNetwork) {
if (testData instanceof DataSetIterator) {
return score((MultiLayerNetwork) model, (DataSetIterator) testData);
Expand Down