Skip to content

Commit

Permalink
AbstractTrainer no longer stores a local copy of dbName.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Dec 21, 2016
1 parent eb43202 commit 1e76739
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 33 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -1,7 +1,7 @@
CHANGELOG CHANGELOG
========= =========


Version 0.8.0-SNAPSHOT - Build 20161220 Version 0.8.0-SNAPSHOT - Build 20161221
--------------------------------------- ---------------------------------------


- Improved Validation: - Improved Validation:
Expand All @@ -20,6 +20,7 @@ Version 0.8.0-SNAPSHOT - Build 20161220
- Removed the AbstractWrapper and Modeler inherits directly from AbstractTrainer. - Removed the AbstractWrapper and Modeler inherits directly from AbstractTrainer.
- Created a TrainableBundle to keep track of the Trainables of Modeler, AbstractBoostingBagging and StepwiseRegression. - Created a TrainableBundle to keep track of the Trainables of Modeler, AbstractBoostingBagging and StepwiseRegression.
- Removed automatic save after fit, now save() must be called. - Removed automatic save after fit, now save() must be called.
- AbstractTrainer no longer stores a local copy of dbName.


Version 0.7.1-SNAPSHOT - Build 20161217 Version 0.7.1-SNAPSHOT - Build 20161217
--------------------------------------- ---------------------------------------
Expand Down
Expand Up @@ -194,6 +194,7 @@ public void predict(Dataframe newData) {
protected void _fit(Dataframe trainingData) { protected void _fit(Dataframe trainingData) {
TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters(); TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();
Configuration conf = knowledgeBase.getConf(); Configuration conf = knowledgeBase.getConf();
String dbName = knowledgeBase.getDbc().getDatabaseName();


//reset previous entries on the bundle //reset previous entries on the bundle
resetBundle(); resetBundle();
Expand Down Expand Up @@ -264,6 +265,7 @@ private void resetBundle() {
private void initBundle() { private void initBundle() {
TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters(); TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();
Configuration conf = knowledgeBase.getConf(); Configuration conf = knowledgeBase.getConf();
String dbName = knowledgeBase.getDbc().getDatabaseName();


if(!bundle.containsKey("dataTransformer")) { if(!bundle.containsKey("dataTransformer")) {
AbstractTransformer.AbstractTrainingParameters dtParams = trainingParameters.getDataTransformerTrainingParameters(); AbstractTransformer.AbstractTrainingParameters dtParams = trainingParameters.getDataTransformerTrainingParameters();
Expand Down
Expand Up @@ -42,19 +42,19 @@
*/ */
public class InMemoryConnector extends AbstractDatabaseConnector { public class InMemoryConnector extends AbstractDatabaseConnector {


private final String database; private String dbName;
private final InMemoryConfiguration dbConf; private final InMemoryConfiguration dbConf;


/** /**
* @param database * @param dbName
* @param dbConf * @param dbConf
* @see AbstractDatabaseConnector#AbstractDatabaseConnector() * @see AbstractDatabaseConnector#AbstractDatabaseConnector()
*/ */
protected InMemoryConnector(String database, InMemoryConfiguration dbConf) { protected InMemoryConnector(String dbName, InMemoryConfiguration dbConf) {
super(); super();
this.database = database; this.dbName = dbName;
this.dbConf = dbConf; this.dbConf = dbConf;
logger.trace("Opened db "+ database); logger.trace("Opened db {}", dbName);
} }


/** {@inheritDoc} */ /** {@inheritDoc} */
Expand Down Expand Up @@ -109,9 +109,9 @@ public void close() {
return; return;
} }
super.close(); super.close();
logger.trace("Closed db "+ database); logger.trace("Closed db {}", dbName);
} }

/** {@inheritDoc} */ /** {@inheritDoc} */
@Override @Override
public void clear() { public void clear() {
Expand Down Expand Up @@ -150,7 +150,7 @@ public <T extends Map> void dropBigMap(String name, T map) {
/** {@inheritDoc} */ /** {@inheritDoc} */
@Override @Override
public String getDatabaseName() { public String getDatabaseName() {
return database; return dbName;
} }


private Path getRootPath() { private Path getRootPath() {
Expand All @@ -161,6 +161,6 @@ private Path getRootPath() {
outputFolder = System.getProperty("java.io.tmpdir"); //write them to the tmp directory outputFolder = System.getProperty("java.io.tmpdir"); //write them to the tmp directory
} }


return Paths.get(outputFolder + File.separator + database); return Paths.get(outputFolder + File.separator + dbName);
} }
} }
Expand Up @@ -41,7 +41,7 @@
*/ */
public class MapDBConnector extends AbstractDatabaseConnector { public class MapDBConnector extends AbstractDatabaseConnector {


private final String database; private String dbName;
private final MapDBConfiguration dbConf; private final MapDBConfiguration dbConf;


/** /**
Expand Down Expand Up @@ -77,15 +77,15 @@ private enum DBType {
private final Map<DBType, DB> dbRegistry = new HashMap<>(); private final Map<DBType, DB> dbRegistry = new HashMap<>();


/** /**
* @param database * @param dbName
* @param dbConf * @param dbConf
* @see AbstractDatabaseConnector#AbstractDatabaseConnector() * @see AbstractDatabaseConnector#AbstractDatabaseConnector()
*/ */
protected MapDBConnector(String database, MapDBConfiguration dbConf) { protected MapDBConnector(String dbName, MapDBConfiguration dbConf) {
super(); super();
this.database = database; this.dbName = dbName;
this.dbConf = dbConf; this.dbConf = dbConf;
logger.trace("Opened db "+ database); logger.trace("Opened db {}", dbName);
} }


/** {@inheritDoc} */ /** {@inheritDoc} */
Expand Down Expand Up @@ -140,7 +140,7 @@ public void close() {
super.close(); super.close();


closeDBRegistry(); closeDBRegistry();
logger.trace("Closed db "+ database); logger.trace("Closed db {}", dbName);
} }


/** {@inheritDoc} */ /** {@inheritDoc} */
Expand Down Expand Up @@ -256,7 +256,7 @@ public <T extends Map> void dropBigMap(String name, T map) {
/** {@inheritDoc} */ /** {@inheritDoc} */
@Override @Override
public String getDatabaseName() { public String getDatabaseName() {
return database; return dbName;
} }


//private methods of connector class //private methods of connector class
Expand Down Expand Up @@ -406,6 +406,6 @@ private Path getRootPath() {
outputFolder = System.getProperty("java.io.tmpdir"); //write them to the tmp directory outputFolder = System.getProperty("java.io.tmpdir"); //write them to the tmp directory
} }


return Paths.get(outputFolder + File.separator + database); return Paths.get(outputFolder + File.separator + dbName);
} }
} }
Expand Up @@ -107,11 +107,6 @@ public static abstract class AbstractTrainingParameters implements TrainingParam
*/ */
protected final Logger logger = LoggerFactory.getLogger(getClass()); protected final Logger logger = LoggerFactory.getLogger(getClass());


/**
* The name of the Database where we persist our data.
*/
protected final String dbName; //FIXME: do we really need the dbName here? Perhaps a temp name is good enough

/** /**
* The KnowledgeBase instance of the algorithm. * The KnowledgeBase instance of the algorithm.
*/ */
Expand All @@ -120,23 +115,23 @@ public static abstract class AbstractTrainingParameters implements TrainingParam
/** /**
* Constructor which is called on model initialization before training. * Constructor which is called on model initialization before training.
* *
* @param baseDBname * @param dbName
* @param conf * @param conf
* @param trainingParameters * @param trainingParameters
*/ */
protected AbstractTrainer(String baseDBname, Configuration conf, TP trainingParameters) { protected AbstractTrainer(String dbName, Configuration conf, TP trainingParameters) {
dbName = baseDBname + conf.getDbConfig().getDBnameSeparator() + this.getClass().getSimpleName(); dbName += conf.getDbConfig().getDBnameSeparator() + this.getClass().getSimpleName();
knowledgeBase = new KnowledgeBase<>(dbName, conf, trainingParameters); knowledgeBase = new KnowledgeBase<>(dbName, conf, trainingParameters);
} }


/** /**
* Constructor which is called when we pre-trained load persisted models. * Constructor which is called when we pre-trained load persisted models.
* *
* @param baseDBname * @param dbName
* @param conf * @param conf
*/ */
protected AbstractTrainer(String baseDBname, Configuration conf) { protected AbstractTrainer(String dbName, Configuration conf) {
dbName = baseDBname + conf.getDbConfig().getDBnameSeparator() + this.getClass().getSimpleName(); dbName += conf.getDbConfig().getDBnameSeparator() + this.getClass().getSimpleName();;
knowledgeBase = new KnowledgeBase<>(dbName, conf); knowledgeBase = new KnowledgeBase<>(dbName, conf);
} }


Expand Down
Expand Up @@ -205,6 +205,7 @@ protected void _predict(Dataframe newData) {
@Override @Override
protected void _fit(Dataframe trainingData) { protected void _fit(Dataframe trainingData) {
Configuration conf = knowledgeBase.getConf(); Configuration conf = knowledgeBase.getConf();
DatabaseConnector dbc = knowledgeBase.getDbc();
TP trainingParameters = knowledgeBase.getTrainingParameters(); TP trainingParameters = knowledgeBase.getTrainingParameters();
MP modelParameters = knowledgeBase.getModelParameters(); MP modelParameters = knowledgeBase.getModelParameters();


Expand Down Expand Up @@ -236,7 +237,7 @@ protected void _fit(Dataframe trainingData) {
//training the weak classifiers //training the weak classifiers
int t=0; int t=0;
int retryCounter = 0; int retryCounter = 0;
String prefix = dbName+conf.getDbConfig().getDBnameSeparator()+DB_INDICATOR; String prefix = dbc.getDatabaseName()+conf.getDbConfig().getDBnameSeparator()+DB_INDICATOR;
while(t<totalWeakClassifiers) { while(t<totalWeakClassifiers) {
logger.debug("Training Weak learner {}", t); logger.debug("Training Weak learner {}", t);


Expand Down Expand Up @@ -343,12 +344,13 @@ private void resetBundle() {
} }


private void initBundle() { private void initBundle() {
Configuration conf = knowledgeBase.getConf();
DatabaseConnector dbc = knowledgeBase.getDbc();
MP modelParameters = knowledgeBase.getModelParameters(); MP modelParameters = knowledgeBase.getModelParameters();
TP trainingParameters = knowledgeBase.getTrainingParameters(); TP trainingParameters = knowledgeBase.getTrainingParameters();
Configuration conf = knowledgeBase.getConf();


//the number of weak classifiers is the minimum between the classifiers that were defined in training parameters AND the number of the weak classifiers that were kept //the number of weak classifiers is the minimum between the classifiers that were defined in training parameters AND the number of the weak classifiers that were kept
String prefix = dbName+knowledgeBase.getConf().getDbConfig().getDBnameSeparator()+DB_INDICATOR; String prefix = dbc.getDatabaseName()+knowledgeBase.getConf().getDbConfig().getDBnameSeparator()+DB_INDICATOR;
Class<AbstractClassifier> weakClassifierClass = trainingParameters.getWeakClassifierTrainingParameters().getTClass(); Class<AbstractClassifier> weakClassifierClass = trainingParameters.getWeakClassifierTrainingParameters().getTClass();
int totalWeakClassifiers = Math.min(modelParameters.getWeakClassifierWeights().size(), trainingParameters.getMaxWeakClassifiers()); int totalWeakClassifiers = Math.min(modelParameters.getWeakClassifierWeights().size(), trainingParameters.getMaxWeakClassifiers());
for(int t=0;t<totalWeakClassifiers;++t) { for(int t=0;t<totalWeakClassifiers;++t) {
Expand Down
Expand Up @@ -168,6 +168,7 @@ protected void _predict(Dataframe newData) {
protected void _fit(Dataframe trainingData) { protected void _fit(Dataframe trainingData) {
TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters(); TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();
Configuration conf = knowledgeBase.getConf(); Configuration conf = knowledgeBase.getConf();
String dbName = knowledgeBase.getDbc().getDatabaseName();


//reset previous entries on the bundle //reset previous entries on the bundle
resetBundle(); resetBundle();
Expand Down Expand Up @@ -254,6 +255,7 @@ private void resetBundle() {
private void initBundle() { private void initBundle() {
TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters(); TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();
Configuration conf = knowledgeBase.getConf(); Configuration conf = knowledgeBase.getConf();
String dbName = knowledgeBase.getDbc().getDatabaseName();


if(!bundle.containsKey("mlregressor")) { if(!bundle.containsKey("mlregressor")) {
AbstractTrainingParameters mlParams = trainingParameters.getRegressionTrainingParameters(); AbstractTrainingParameters mlParams = trainingParameters.getRegressionTrainingParameters();
Expand All @@ -265,7 +267,7 @@ private void initBundle() {
private Map<Object, Double> runRegression(Dataframe trainingData) { private Map<Object, Double> runRegression(Dataframe trainingData) {
AbstractRegressor mlregressor = MLBuilder.create( AbstractRegressor mlregressor = MLBuilder.create(
knowledgeBase.getTrainingParameters().getRegressionTrainingParameters(), knowledgeBase.getTrainingParameters().getRegressionTrainingParameters(),
dbName, knowledgeBase.getDbc().getDatabaseName(),
knowledgeBase.getConf() knowledgeBase.getConf()
); );
mlregressor.fit(trainingData); mlregressor.fit(trainingData);
Expand Down

0 comments on commit 1e76739

Please sign in to comment.