Skip to content

Commit

Permalink
Removed the AbstractWrapper and Modeler inherits directly from Abstra…
Browse files Browse the repository at this point in the history
…ctTrainer. Created a TrainableBundle to keep track of the Trainables of Modeler, AbstractBoostingBagging and StepwiseRegression. Removed automatic save after fit, now save() must be called.
  • Loading branch information
datumbox committed Dec 21, 2016
1 parent f5012db commit eb43202
Show file tree
Hide file tree
Showing 36 changed files with 581 additions and 457 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -17,6 +17,9 @@ Version 0.8.0-SNAPSHOT - Build 20161220
- Removed all unnecessary passing of class objects from Stepwise Regression, Wrappers and Ensemble learning classes. - Removed all unnecessary passing of class objects from Stepwise Regression, Wrappers and Ensemble learning classes.
- Removed the ignoringNumericalFeatures variable from AbstractCategoricalFeatureSelector. - Removed the ignoringNumericalFeatures variable from AbstractCategoricalFeatureSelector.
- Moved the static methods of Trainable to MLBuilder. This is the preferred method for loading and initializing models. - Moved the static methods of Trainable to MLBuilder. This is the preferred method for loading and initializing models.
- Removed the AbstractWrapper and Modeler inherits directly from AbstractTrainer.
- Created a TrainableBundle to keep track of the Trainables of Modeler, AbstractBoostingBagging and StepwiseRegression.
- Removed automatic save after fit, now save() must be called.


Version 0.7.1-SNAPSHOT - Build 20161217 Version 0.7.1-SNAPSHOT - Build 20161217
--------------------------------------- ---------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion TODO.txt
@@ -1,8 +1,8 @@
CODE IMPROVEMENTS CODE IMPROVEMENTS
================= =================


- Add save(dbName) method in the models.
- Can we make the two constructors of the Trainers to call a common constructor to eliminate duplicate code? - Can we make the two constructors of the Trainers to call a common constructor to eliminate duplicate code?
- Add save(dbName) method in the models


- Support of better Transformers (Zscore, decouple boolean transforming from numeric etc). - Support of better Transformers (Zscore, decouple boolean transforming from numeric etc).
- Write a ShuffleSplitValidator class similar to KFold. Perhaps we need a single Validator class and separate Splitters. - Write a ShuffleSplitValidator class similar to KFold. Perhaps we need a single Validator class and separate Splitters.
Expand Down
Expand Up @@ -17,14 +17,14 @@


import com.datumbox.framework.common.Configuration; import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.common.interfaces.Trainable;
import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector; import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.core.machinelearning.MLBuilder; import com.datumbox.framework.core.machinelearning.MLBuilder;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer; import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.datatransformers.AbstractTransformer; import com.datumbox.framework.core.machinelearning.common.abstracts.datatransformers.AbstractTransformer;
import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractFeatureSelector; import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractFeatureSelector;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractModeler; import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractModeler;
import com.datumbox.framework.core.machinelearning.common.abstracts.wrappers.AbstractWrapper; import com.datumbox.framework.core.machinelearning.common.dataobjects.TrainableBundle;
import com.datumbox.framework.core.machinelearning.common.interfaces.Parallelizable;


/** /**
* Modeler is a convenience class which can be used to train Machine Learning * Modeler is a convenience class which can be used to train Machine Learning
Expand All @@ -33,12 +33,14 @@
* *
* @author Vasilis Vryniotis <bbriniotis@datumbox.com> * @author Vasilis Vryniotis <bbriniotis@datumbox.com>
*/ */
public class Modeler extends AbstractWrapper<Modeler.ModelParameters, Modeler.TrainingParameters> { public class Modeler extends AbstractTrainer<Modeler.ModelParameters, Modeler.TrainingParameters> implements Parallelizable {


private TrainableBundle bundle = new TrainableBundle();

/** /**
* It contains all the Model Parameters which are learned during the training. * It contains all the Model Parameters which are learned during the training.
*/ */
public static class ModelParameters extends AbstractWrapper.AbstractModelParameters { public static class ModelParameters extends AbstractTrainer.AbstractModelParameters {
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;


/** /**
Expand All @@ -48,15 +50,78 @@ public static class ModelParameters extends AbstractWrapper.AbstractModelParamet
protected ModelParameters(DatabaseConnector dbc) { protected ModelParameters(DatabaseConnector dbc) {
super(dbc); super(dbc);
} }

} }


/** /**
* It contains the Training Parameters of the Modeler. * It contains the Training Parameters of the Modeler.
*/ */
public static class TrainingParameters extends AbstractWrapper.AbstractTrainingParameters<AbstractTransformer, AbstractFeatureSelector, AbstractModeler> { public static class TrainingParameters extends AbstractTrainer.AbstractTrainingParameters {
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;


//Parameter Objects
private AbstractTransformer.AbstractTrainingParameters dataTransformerTrainingParameters;

private AbstractFeatureSelector.AbstractTrainingParameters featureSelectorTrainingParameters;

private AbstractModeler.AbstractTrainingParameters modelerTrainingParameters;

/**
* Getter for the Training Parameters of the Data Transformer.
*
* @return
*/
public AbstractTransformer.AbstractTrainingParameters getDataTransformerTrainingParameters() {
return dataTransformerTrainingParameters;
}

/**
* Setter for the Training Parameters of the Data Transformer. Pass null
* for none.
*
* @param dataTransformerTrainingParameters
*/
public void setDataTransformerTrainingParameters(AbstractTransformer.AbstractTrainingParameters dataTransformerTrainingParameters) {
this.dataTransformerTrainingParameters = dataTransformerTrainingParameters;
}

/**
* Getter for the Training Parameters of the Feature Selector.
*
* @return
*/
public AbstractFeatureSelector.AbstractTrainingParameters getFeatureSelectorTrainingParameters() {
return featureSelectorTrainingParameters;
}

/**
* Setter for the Training Parameters of the Feature Selector. Pass null
* for none.
*
* @param featureSelectorTrainingParameters
*/
public void setFeatureSelectorTrainingParameters(AbstractFeatureSelector.AbstractTrainingParameters featureSelectorTrainingParameters) {
this.featureSelectorTrainingParameters = featureSelectorTrainingParameters;
}

/**
* Getter for the Training Parameters of the Machine Learning modeler.
*
* @return
*/
public AbstractModeler.AbstractTrainingParameters getModelerTrainingParameters() {
return modelerTrainingParameters;
}

/**
* Setter for the Training Parameters of the Machine Learning modeler.
*
* @param modelerTrainingParameters
*/
public void setModelerTrainingParameters(AbstractModeler.AbstractTrainingParameters modelerTrainingParameters) {
this.modelerTrainingParameters = modelerTrainingParameters;
}

} }




Expand All @@ -79,6 +144,21 @@ public Modeler(String dbName, Configuration conf) {
super(dbName, conf); super(dbName, conf);
} }



private boolean parallelized = true;

/** {@inheritDoc} */
@Override
public boolean isParallelized() {
return parallelized;
}

/** {@inheritDoc} */
@Override
public void setParallelized(boolean parallelized) {
this.parallelized = parallelized;
}

/** /**
* Generates predictions for the given dataset. * Generates predictions for the given dataset.
* *
Expand All @@ -87,36 +167,24 @@ public Modeler(String dbName, Configuration conf) {
public void predict(Dataframe newData) { public void predict(Dataframe newData) {
logger.info("predict()"); logger.info("predict()");


Modeler.TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters(); //load all trainables on the bundles
Configuration conf = knowledgeBase.getConf(); initBundle();


AbstractTrainer.AbstractTrainingParameters dtParams = trainingParameters.getDataTransformerTrainingParameters(); //set the parallelized flag to all algorithms
boolean transformData = dtParams!=null; bundle.setParallelized(isParallelized());
if(transformData) {
if(dataTransformer==null) { //run the pipeline
dataTransformer = MLBuilder.load(dtParams.getTClass(), dbName, conf); AbstractTransformer dataTransformer = (AbstractTransformer) bundle.get("dataTransformer");
} if(dataTransformer != null) {
setParallelized(dataTransformer);
dataTransformer.transform(newData); dataTransformer.transform(newData);
} }

AbstractFeatureSelector featureSelector = (AbstractFeatureSelector) bundle.get("featureSelector");
AbstractTrainer.AbstractTrainingParameters fsParams = trainingParameters.getFeatureSelectorTrainingParameters(); if(featureSelector != null) {
boolean selectFeatures = fsParams!=null;
if(selectFeatures) {
if(featureSelector==null) {
featureSelector = MLBuilder.load(fsParams.getTClass(), dbName, conf);
}
setParallelized(featureSelector);
featureSelector.transform(newData); featureSelector.transform(newData);
} }

AbstractModeler modeler = (AbstractModeler) bundle.get("modeler");
if(modeler==null) {
modeler = MLBuilder.load(trainingParameters.getModelerTrainingParameters().getTClass(), dbName, conf);
}
setParallelized(modeler);
modeler.predict(newData); modeler.predict(newData);

if(dataTransformer != null) {
if(transformData) {
dataTransformer.denormalize(newData); dataTransformer.denormalize(newData);
} }
} }
Expand All @@ -127,29 +195,101 @@ protected void _fit(Dataframe trainingData) {
TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters(); TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();
Configuration conf = knowledgeBase.getConf(); Configuration conf = knowledgeBase.getConf();


AbstractTrainer.AbstractTrainingParameters dtParams = trainingParameters.getDataTransformerTrainingParameters(); //reset previous entries on the bundle
boolean transformData = dtParams!=null; resetBundle();
if(transformData) {
//initialize the parts of the pipeline
AbstractTransformer.AbstractTrainingParameters dtParams = trainingParameters.getDataTransformerTrainingParameters();
AbstractTransformer dataTransformer = null;
if(dtParams != null) {
dataTransformer = MLBuilder.create(dtParams, dbName, conf); dataTransformer = MLBuilder.create(dtParams, dbName, conf);
setParallelized(dataTransformer); bundle.put("dataTransformer", dataTransformer);
dataTransformer.fit_transform(trainingData);
} }


AbstractTrainer.AbstractTrainingParameters fsParams = trainingParameters.getFeatureSelectorTrainingParameters(); AbstractFeatureSelector.AbstractTrainingParameters fsParams = trainingParameters.getFeatureSelectorTrainingParameters();
boolean selectFeatures = fsParams!=null; AbstractFeatureSelector featureSelector = null;
if(selectFeatures) { if(fsParams != null) {
featureSelector = MLBuilder.create(fsParams, dbName, conf); featureSelector = MLBuilder.create(fsParams, dbName, conf);
setParallelized(featureSelector); bundle.put("featureSelector", featureSelector);
featureSelector.fit_transform(trainingData);
} }


AbstractTrainer.AbstractTrainingParameters mlParams = trainingParameters.getModelerTrainingParameters(); AbstractModeler.AbstractTrainingParameters mlParams = trainingParameters.getModelerTrainingParameters();
modeler = MLBuilder.create(mlParams, dbName, conf); AbstractModeler modeler = MLBuilder.create(mlParams, dbName, conf);
setParallelized(modeler); bundle.put("modeler", modeler);

//set the parallelized flag to all algorithms
bundle.setParallelized(isParallelized());

//run the pipeline
if(dataTransformer != null) {
dataTransformer.fit_transform(trainingData);
}
if(featureSelector != null) {
featureSelector.fit_transform(trainingData);
}
modeler.fit(trainingData); modeler.fit(trainingData);

if(dataTransformer != null) {
if(transformData) {
dataTransformer.denormalize(trainingData); dataTransformer.denormalize(trainingData);
} }
} }

/** {@inheritDoc} */
@Override
public void save() {
initBundle();
bundle.save();
super.save();
}

/** {@inheritDoc} */
@Override
public void delete() {
initBundle();
bundle.delete();
super.delete();
}

/** {@inheritDoc} */
@Override
public void close() {
initBundle();
bundle.close();
super.close();
}

private void resetBundle() {
bundle.delete();
}

private void initBundle() {
TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();
Configuration conf = knowledgeBase.getConf();

if(!bundle.containsKey("dataTransformer")) {
AbstractTransformer.AbstractTrainingParameters dtParams = trainingParameters.getDataTransformerTrainingParameters();

AbstractTransformer dataTransformer = null;
if(dtParams != null) {
dataTransformer = MLBuilder.load(dtParams.getTClass(), dbName, conf);
}
bundle.put("dataTransformer", dataTransformer);
}

if(!bundle.containsKey("featureSelector")) {
AbstractFeatureSelector.AbstractTrainingParameters fsParams = trainingParameters.getFeatureSelectorTrainingParameters();

AbstractFeatureSelector featureSelector = null;
if(fsParams != null) {
featureSelector = MLBuilder.load(fsParams.getTClass(), dbName, conf);
}
bundle.put("featureSelector", featureSelector);
}

if(!bundle.containsKey("modeler")) {
AbstractModeler.AbstractTrainingParameters mlParams = trainingParameters.getModelerTrainingParameters();

bundle.put("modeler", MLBuilder.load(mlParams.getTClass(), dbName, conf));
}
}

} }
Expand Up @@ -256,10 +256,7 @@ private void performClustering(Dataframe dataset, int numberOfClusters) {


instance.fit(dataset); instance.fit(dataset);
instance.predict(dataset); instance.predict(dataset);
//Map<Integer, BaseMLclusterer.Cluster> clusters = instance.getClusters(); instance.close();

instance.delete(); //delete immediately the result
//instance = null;
} }


private List<Double> calculateTTRlist(List<String> rows) { private List<Double> calculateTTRlist(List<String> rows) {
Expand Down
Expand Up @@ -20,15 +20,9 @@
import com.datumbox.framework.common.dataobjects.AssociativeArray; import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.common.dataobjects.Record; import com.datumbox.framework.common.dataobjects.Record;
import com.datumbox.framework.common.interfaces.Trainable;
import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector; import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.common.utilities.StringCleaner; import com.datumbox.framework.common.utilities.StringCleaner;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer; import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.datatransformers.AbstractTransformer;
import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractCategoricalFeatureSelector;
import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractFeatureSelector;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractModeler;
import com.datumbox.framework.core.machinelearning.common.abstracts.wrappers.AbstractWrapper;
import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClassificationMetrics; import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClassificationMetrics;
import com.datumbox.framework.core.utilities.text.extractors.AbstractTextExtractor; import com.datumbox.framework.core.utilities.text.extractors.AbstractTextExtractor;


Expand Down
Expand Up @@ -74,6 +74,7 @@ public void testTrainAndValidate() {


Modeler instance = MLBuilder.create(trainingParameters, dbName, conf); Modeler instance = MLBuilder.create(trainingParameters, dbName, conf);
instance.fit(trainingData); instance.fit(trainingData);
instance.save();


instance.close(); instance.close();


Expand All @@ -86,6 +87,7 @@ public void testTrainAndValidate() {
double expResult2 = 0.8; double expResult2 = 0.8;
assertEquals(expResult2, vm.getMacroF1(), Constants.DOUBLE_ACCURACY_HIGH); assertEquals(expResult2, vm.getMacroF1(), Constants.DOUBLE_ACCURACY_HIGH);


trainingData.delete();
instance.close(); instance.close();
//instance = null; //instance = null;


Expand All @@ -107,8 +109,7 @@ public void testTrainAndValidate() {
assertEquals(expResult, result); assertEquals(expResult, result);


instance.delete(); instance.delete();


trainingData.delete();
validationData.delete(); validationData.delete();
} }


Expand Down

0 comments on commit eb43202

Please sign in to comment.