Skip to content

Commit

Permalink
Removed the AbstractWrapper and Modeler inherits directly from Abstra…
Browse files Browse the repository at this point in the history
…ctTrainer. Created a TrainableBundle to keep track of the Trainables of Modeler, AbstractBoostingBagging and StepwiseRegression. Removed automatic save after fit, now save() must be called.
  • Loading branch information
datumbox committed Dec 21, 2016
1 parent f5012db commit eb43202
Show file tree
Hide file tree
Showing 36 changed files with 581 additions and 457 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -17,6 +17,9 @@ Version 0.8.0-SNAPSHOT - Build 20161220
- Removed all unnecessary passing of class objects from Stepwise Regression, Wrappers and Ensemble learning classes.
- Removed the ignoringNumericalFeatures variable from AbstractCategoricalFeatureSelector.
- Moved the static methods of Trainable to MLBuilder. This is the preferred method for loading and initializing models.
- Removed the AbstractWrapper and Modeler inherits directly from AbstractTrainer.
- Created a TrainableBundle to keep track of the Trainables of Modeler, AbstractBoostingBagging and StepwiseRegression.
- Removed automatic save after fit, now save() must be called.

Version 0.7.1-SNAPSHOT - Build 20161217
---------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion TODO.txt
@@ -1,8 +1,8 @@
CODE IMPROVEMENTS
=================

- Add save(dbName) method in the models.
- Can we make the two constructors of the Trainers to call a common constructor to eliminate duplicate code?
- Add save(dbName) method in the models

- Support of better Transformers (Zscore, decouple boolean transforming from numeric etc).
- Write a ShuffleSplitValidator class similar to KFold. Perhaps we need a single Validator class and separate Splitters.
Expand Down
Expand Up @@ -17,14 +17,14 @@

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.common.interfaces.Trainable;
import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.core.machinelearning.MLBuilder;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.datatransformers.AbstractTransformer;
import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractFeatureSelector;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractModeler;
import com.datumbox.framework.core.machinelearning.common.abstracts.wrappers.AbstractWrapper;
import com.datumbox.framework.core.machinelearning.common.dataobjects.TrainableBundle;
import com.datumbox.framework.core.machinelearning.common.interfaces.Parallelizable;

/**
* Modeler is a convenience class which can be used to train Machine Learning
Expand All @@ -33,12 +33,14 @@
*
* @author Vasilis Vryniotis <bbriniotis@datumbox.com>
*/
public class Modeler extends AbstractWrapper<Modeler.ModelParameters, Modeler.TrainingParameters> {

public class Modeler extends AbstractTrainer<Modeler.ModelParameters, Modeler.TrainingParameters> implements Parallelizable {

private TrainableBundle bundle = new TrainableBundle();

/**
* It contains all the Model Parameters which are learned during the training.
*/
public static class ModelParameters extends AbstractWrapper.AbstractModelParameters {
public static class ModelParameters extends AbstractTrainer.AbstractModelParameters {
private static final long serialVersionUID = 1L;

/**
Expand All @@ -48,15 +50,78 @@ public static class ModelParameters extends AbstractWrapper.AbstractModelParamet
protected ModelParameters(DatabaseConnector dbc) {
super(dbc);
}

}

/**
* It contains the Training Parameters of the Modeler.
*/
public static class TrainingParameters extends AbstractWrapper.AbstractTrainingParameters<AbstractTransformer, AbstractFeatureSelector, AbstractModeler> {
public static class TrainingParameters extends AbstractTrainer.AbstractTrainingParameters {
private static final long serialVersionUID = 1L;

//Parameter Objects
private AbstractTransformer.AbstractTrainingParameters dataTransformerTrainingParameters;

private AbstractFeatureSelector.AbstractTrainingParameters featureSelectorTrainingParameters;

private AbstractModeler.AbstractTrainingParameters modelerTrainingParameters;

/**
 * Getter for the Training Parameters of the Data Transformer.
 *
 * @return the Training Parameters of the Data Transformer, or null when no
 * data transformation step is configured.
 */
public AbstractTransformer.AbstractTrainingParameters getDataTransformerTrainingParameters() {
return dataTransformerTrainingParameters;
}

/**
 * Setter for the Training Parameters of the Data Transformer. Pass null
 * for none.
 *
 * @param dataTransformerTrainingParameters the Training Parameters of the
 * Data Transformer, or null to skip the data transformation step.
 */
public void setDataTransformerTrainingParameters(AbstractTransformer.AbstractTrainingParameters dataTransformerTrainingParameters) {
this.dataTransformerTrainingParameters = dataTransformerTrainingParameters;
}

/**
 * Getter for the Training Parameters of the Feature Selector.
 *
 * @return the Training Parameters of the Feature Selector, or null when no
 * feature selection step is configured.
 */
public AbstractFeatureSelector.AbstractTrainingParameters getFeatureSelectorTrainingParameters() {
return featureSelectorTrainingParameters;
}

/**
 * Setter for the Training Parameters of the Feature Selector. Pass null
 * for none.
 *
 * @param featureSelectorTrainingParameters the Training Parameters of the
 * Feature Selector, or null to skip the feature selection step.
 */
public void setFeatureSelectorTrainingParameters(AbstractFeatureSelector.AbstractTrainingParameters featureSelectorTrainingParameters) {
this.featureSelectorTrainingParameters = featureSelectorTrainingParameters;
}

/**
 * Getter for the Training Parameters of the Machine Learning modeler.
 *
 * @return the Training Parameters of the modeler. Unlike the transformer and
 * feature selector, the modeler step is mandatory for the pipeline.
 */
public AbstractModeler.AbstractTrainingParameters getModelerTrainingParameters() {
return modelerTrainingParameters;
}

/**
 * Setter for the Training Parameters of the Machine Learning modeler.
 *
 * @param modelerTrainingParameters the Training Parameters of the modeler;
 * must be provided before fit() since the modeler step is not optional.
 */
public void setModelerTrainingParameters(AbstractModeler.AbstractTrainingParameters modelerTrainingParameters) {
this.modelerTrainingParameters = modelerTrainingParameters;
}

}


Expand All @@ -79,6 +144,21 @@ public Modeler(String dbName, Configuration conf) {
super(dbName, conf);
}


// Whether the algorithms of the pipeline run in parallel. Defaults to true
// and is propagated to every bundled Trainable via
// bundle.setParallelized(isParallelized()) before fit/predict run.
private boolean parallelized = true;

/**
 * {@inheritDoc}
 *
 * @return true if parallel execution is enabled for the pipeline algorithms.
 */
@Override
public boolean isParallelized() {
return parallelized;
}

/**
 * {@inheritDoc}
 *
 * @param parallelized whether the pipeline algorithms should run in parallel.
 */
@Override
public void setParallelized(boolean parallelized) {
this.parallelized = parallelized;
}

/**
* Generates predictions for the given dataset.
*
Expand All @@ -87,36 +167,24 @@ public Modeler(String dbName, Configuration conf) {
public void predict(Dataframe newData) {
logger.info("predict()");

Modeler.TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();
Configuration conf = knowledgeBase.getConf();
//load all trainables on the bundles
initBundle();

AbstractTrainer.AbstractTrainingParameters dtParams = trainingParameters.getDataTransformerTrainingParameters();
boolean transformData = dtParams!=null;
if(transformData) {
if(dataTransformer==null) {
dataTransformer = MLBuilder.load(dtParams.getTClass(), dbName, conf);
}
setParallelized(dataTransformer);
//set the parallelized flag to all algorithms
bundle.setParallelized(isParallelized());

//run the pipeline
AbstractTransformer dataTransformer = (AbstractTransformer) bundle.get("dataTransformer");
if(dataTransformer != null) {
dataTransformer.transform(newData);
}

AbstractTrainer.AbstractTrainingParameters fsParams = trainingParameters.getFeatureSelectorTrainingParameters();
boolean selectFeatures = fsParams!=null;
if(selectFeatures) {
if(featureSelector==null) {
featureSelector = MLBuilder.load(fsParams.getTClass(), dbName, conf);
}
setParallelized(featureSelector);
AbstractFeatureSelector featureSelector = (AbstractFeatureSelector) bundle.get("featureSelector");
if(featureSelector != null) {
featureSelector.transform(newData);
}

if(modeler==null) {
modeler = MLBuilder.load(trainingParameters.getModelerTrainingParameters().getTClass(), dbName, conf);
}
setParallelized(modeler);
AbstractModeler modeler = (AbstractModeler) bundle.get("modeler");
modeler.predict(newData);

if(transformData) {
if(dataTransformer != null) {
dataTransformer.denormalize(newData);
}
}
Expand All @@ -127,29 +195,101 @@ protected void _fit(Dataframe trainingData) {
TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();
Configuration conf = knowledgeBase.getConf();

AbstractTrainer.AbstractTrainingParameters dtParams = trainingParameters.getDataTransformerTrainingParameters();
boolean transformData = dtParams!=null;
if(transformData) {
//reset previous entries on the bundle
resetBundle();

//initialize the parts of the pipeline
AbstractTransformer.AbstractTrainingParameters dtParams = trainingParameters.getDataTransformerTrainingParameters();
AbstractTransformer dataTransformer = null;
if(dtParams != null) {
dataTransformer = MLBuilder.create(dtParams, dbName, conf);
setParallelized(dataTransformer);
dataTransformer.fit_transform(trainingData);
bundle.put("dataTransformer", dataTransformer);
}

AbstractTrainer.AbstractTrainingParameters fsParams = trainingParameters.getFeatureSelectorTrainingParameters();
boolean selectFeatures = fsParams!=null;
if(selectFeatures) {
AbstractFeatureSelector.AbstractTrainingParameters fsParams = trainingParameters.getFeatureSelectorTrainingParameters();
AbstractFeatureSelector featureSelector = null;
if(fsParams != null) {
featureSelector = MLBuilder.create(fsParams, dbName, conf);
setParallelized(featureSelector);
featureSelector.fit_transform(trainingData);
bundle.put("featureSelector", featureSelector);
}

AbstractTrainer.AbstractTrainingParameters mlParams = trainingParameters.getModelerTrainingParameters();
modeler = MLBuilder.create(mlParams, dbName, conf);
setParallelized(modeler);
AbstractModeler.AbstractTrainingParameters mlParams = trainingParameters.getModelerTrainingParameters();
AbstractModeler modeler = MLBuilder.create(mlParams, dbName, conf);
bundle.put("modeler", modeler);

//set the parallelized flag to all algorithms
bundle.setParallelized(isParallelized());

//run the pipeline
if(dataTransformer != null) {
dataTransformer.fit_transform(trainingData);
}
if(featureSelector != null) {
featureSelector.fit_transform(trainingData);
}
modeler.fit(trainingData);

if(transformData) {
if(dataTransformer != null) {
dataTransformer.denormalize(trainingData);
}
}

/**
 * {@inheritDoc}
 *
 * Persists the Modeler. The bundle is initialized first so that any wrapped
 * Trainables not yet in memory are loaded; they are saved before the
 * Modeler's own state is stored by the parent class.
 */
@Override
public void save() {
initBundle();
bundle.save();
super.save();
}

/**
 * {@inheritDoc}
 *
 * Deletes the Modeler. The bundle is initialized first so that every wrapped
 * Trainable is in memory and can be deleted before the parent class removes
 * the Modeler's own state.
 */
@Override
public void delete() {
initBundle();
bundle.delete();
super.delete();
}

/**
 * {@inheritDoc}
 *
 * Releases the resources of the Modeler. The bundle is initialized first so
 * that every wrapped Trainable is closed before the parent class closes the
 * Modeler itself.
 */
@Override
public void close() {
initBundle();
bundle.close();
super.close();
}

/**
 * Clears any Trainables left over from a previous fit() before a new
 * training starts, by delegating to bundle.delete().
 * NOTE(review): presumably this also removes the persisted state of the
 * bundled Trainables, not just the in-memory entries — confirm against
 * TrainableBundle.
 */
private void resetBundle() {
bundle.delete();
}

/**
 * Lazily fills the bundle with the Trainables of the pipeline. Each component
 * that is missing from the bundle is loaded from storage using the training
 * parameters stored in the knowledge base. The data transformer and feature
 * selector are optional: when their parameters are absent, a null entry is
 * registered under their key. The modeler entry is mandatory and always loaded.
 */
private void initBundle() {
TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();
Configuration configuration = knowledgeBase.getConf();

if(!bundle.containsKey("dataTransformer")) {
AbstractTransformer.AbstractTrainingParameters dtParams = trainingParameters.getDataTransformerTrainingParameters();
bundle.put("dataTransformer", (dtParams == null) ? null : MLBuilder.load(dtParams.getTClass(), dbName, configuration));
}

if(!bundle.containsKey("featureSelector")) {
AbstractFeatureSelector.AbstractTrainingParameters fsParams = trainingParameters.getFeatureSelectorTrainingParameters();
bundle.put("featureSelector", (fsParams == null) ? null : MLBuilder.load(fsParams.getTClass(), dbName, configuration));
}

if(!bundle.containsKey("modeler")) {
AbstractModeler.AbstractTrainingParameters mlParams = trainingParameters.getModelerTrainingParameters();
bundle.put("modeler", MLBuilder.load(mlParams.getTClass(), dbName, configuration));
}
}

}
Expand Up @@ -256,10 +256,7 @@ private void performClustering(Dataframe dataset, int numberOfClusters) {

instance.fit(dataset);
instance.predict(dataset);
//Map<Integer, BaseMLclusterer.Cluster> clusters = instance.getClusters();

instance.delete(); //delete immediately the result
//instance = null;
instance.close();
}

private List<Double> calculateTTRlist(List<String> rows) {
Expand Down
Expand Up @@ -20,15 +20,9 @@
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.common.dataobjects.Record;
import com.datumbox.framework.common.interfaces.Trainable;
import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.common.utilities.StringCleaner;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.datatransformers.AbstractTransformer;
import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractCategoricalFeatureSelector;
import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractFeatureSelector;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractModeler;
import com.datumbox.framework.core.machinelearning.common.abstracts.wrappers.AbstractWrapper;
import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClassificationMetrics;
import com.datumbox.framework.core.utilities.text.extractors.AbstractTextExtractor;

Expand Down
Expand Up @@ -74,6 +74,7 @@ public void testTrainAndValidate() {

Modeler instance = MLBuilder.create(trainingParameters, dbName, conf);
instance.fit(trainingData);
instance.save();

instance.close();

Expand All @@ -86,6 +87,7 @@ public void testTrainAndValidate() {
double expResult2 = 0.8;
assertEquals(expResult2, vm.getMacroF1(), Constants.DOUBLE_ACCURACY_HIGH);

trainingData.delete();
instance.close();
//instance = null;

Expand All @@ -107,8 +109,7 @@ public void testTrainAndValidate() {
assertEquals(expResult, result);

instance.delete();

trainingData.delete();

validationData.delete();
}

Expand Down

0 comments on commit eb43202

Please sign in to comment.