Removing custom patch from AbstractBoostingBagging.
datumbox committed Dec 21, 2016
1 parent 4688b8f commit 70a5e5d
Showing 6 changed files with 40 additions and 37 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -20,7 +20,9 @@ Version 0.8.0-SNAPSHOT - Build 20161221
 - Removed the AbstractWrapper; Modeler now inherits directly from AbstractTrainer.
 - Created a TrainableBundle to keep track of the Trainables of Modeler, AbstractBoostingBagging and StepwiseRegression.
 - Removed the automatic save after fit; save() must now be called explicitly.
-- AbstractTrainer no longer stores a local copy of dbName.
+- AbstractTrainer no longer stores a local copy of dbName. The save() method now accepts the dbName.
+- The DatabaseConfiguration.getDBnameSeparator() method was removed.
+- The closeAndRename() method was created in DatabaseConnectors and is used by KnowledgeBase to saveAs the models.

Version 0.7.1-SNAPSHOT - Build 20161217
---------------------------------------
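Note: taken together, the entries above change the training lifecycle: fit() no longer persists anything, and the storage name is passed explicitly to save(). A minimal usage sketch, assuming the MLBuilder factory calls shown in the diffs below; the configuration, trainingData and the "myModel" name are illustrative, not from this commit:

    Modeler.TrainingParameters params = new Modeler.TrainingParameters();
    // ...configure the data transformer, feature selector and modeler parameters...

    Modeler modeler = MLBuilder.create(params, configuration);
    modeler.fit(trainingData);  // fit() no longer saves automatically
    modeler.save("myModel");    // persistence is now explicit; save() receives the dbName
    modeler.close();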
1 change: 1 addition & 0 deletions TODO.txt
@@ -1,6 +1,7 @@
CODE IMPROVEMENTS
=================

+- Can we add all the files of a model in a single folder?
- Can we make the two constructors of the Trainers call a common constructor to eliminate duplicate code?

- Support better Transformers (Zscore, decouple boolean transforming from numeric, etc.).

33 changes: 18 additions & 15 deletions Modeler.java
@@ -35,6 +35,10 @@
*/
public class Modeler extends AbstractTrainer<Modeler.ModelParameters, Modeler.TrainingParameters> implements Parallelizable {

+private static final String DT_KEY = "dt";
+private static final String FS_KEY = "fs";
+private static final String ML_KEY = "ml";
+
private TrainableBundle bundle = new TrainableBundle();

/**
@@ -173,15 +177,15 @@ public void predict(Dataframe newData) {
bundle.setParallelized(isParallelized());

//run the pipeline
-AbstractTransformer dataTransformer = (AbstractTransformer) bundle.get("dataTransformer");
+AbstractTransformer dataTransformer = (AbstractTransformer) bundle.get(DT_KEY);
if(dataTransformer != null) {
dataTransformer.transform(newData);
}
-AbstractFeatureSelector featureSelector = (AbstractFeatureSelector) bundle.get("featureSelector");
+AbstractFeatureSelector featureSelector = (AbstractFeatureSelector) bundle.get(FS_KEY);
if(featureSelector != null) {
featureSelector.transform(newData);
}
-AbstractModeler modeler = (AbstractModeler) bundle.get("modeler");
+AbstractModeler modeler = (AbstractModeler) bundle.get(ML_KEY);
modeler.predict(newData);
if(dataTransformer != null) {
dataTransformer.denormalize(newData);
@@ -193,7 +197,6 @@ public void predict(Dataframe newData) {
protected void _fit(Dataframe trainingData) {
TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();
Configuration conf = knowledgeBase.getConf();
-String dbName = knowledgeBase.getDbc().getDatabaseName();

//reset previous entries on the bundle
resetBundle();
@@ -203,19 +206,19 @@ protected void _fit(Dataframe trainingData) {
AbstractTransformer dataTransformer = null;
if(dtParams != null) {
dataTransformer = MLBuilder.create(dtParams, conf);
bundle.put("dataTransformer", dataTransformer);
bundle.put(DT_KEY, dataTransformer);
}

AbstractFeatureSelector.AbstractTrainingParameters fsParams = trainingParameters.getFeatureSelectorTrainingParameters();
AbstractFeatureSelector featureSelector = null;
if(fsParams != null) {
featureSelector = MLBuilder.create(fsParams, conf);
bundle.put("featureSelector", featureSelector);
bundle.put(FS_KEY, featureSelector);
}

AbstractModeler.AbstractTrainingParameters mlParams = trainingParameters.getModelerTrainingParameters();
AbstractModeler modeler = MLBuilder.create(mlParams, conf);
bundle.put("modeler", modeler);
bundle.put(ML_KEY, modeler);

//set the parallelized flag to all algorithms
bundle.setParallelized(isParallelized());
@@ -267,30 +270,30 @@ private void initBundle() {
Configuration conf = knowledgeBase.getConf();
String dbName = knowledgeBase.getDbc().getDatabaseName();

if(!bundle.containsKey("dataTransformer")) {
if(!bundle.containsKey(DT_KEY)) {
AbstractTransformer.AbstractTrainingParameters dtParams = trainingParameters.getDataTransformerTrainingParameters();

AbstractTransformer dataTransformer = null;
if(dtParams != null) {
-dataTransformer = MLBuilder.load(dtParams.getTClass(), dbName, conf);
+dataTransformer = MLBuilder.load(dtParams.getTClass(), dbName + "_" + DT_KEY, conf);
}
bundle.put("dataTransformer", dataTransformer);
bundle.put(DT_KEY, dataTransformer);
}

if(!bundle.containsKey("featureSelector")) {
if(!bundle.containsKey(FS_KEY)) {
AbstractFeatureSelector.AbstractTrainingParameters fsParams = trainingParameters.getFeatureSelectorTrainingParameters();

AbstractFeatureSelector featureSelector = null;
if(fsParams != null) {
-featureSelector = MLBuilder.load(fsParams.getTClass(), dbName, conf);
+featureSelector = MLBuilder.load(fsParams.getTClass(), dbName + "_" + FS_KEY, conf);
}
bundle.put("featureSelector", featureSelector);
bundle.put(FS_KEY, featureSelector);
}

if(!bundle.containsKey("modeler")) {
if(!bundle.containsKey(ML_KEY)) {
AbstractModeler.AbstractTrainingParameters mlParams = trainingParameters.getModelerTrainingParameters();

bundle.put("modeler", MLBuilder.load(mlParams.getTClass(), dbName, conf));
bundle.put(ML_KEY, MLBuilder.load(mlParams.getTClass(), dbName + "_" + ML_KEY, conf));
}
}

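Note: the effect of the DT_KEY/FS_KEY/ML_KEY constants is a uniform naming scheme: each pipeline component is stored and reloaded under the dbName suffixed with its bundle key. A sketch under the assumption that the knowledge-base name resolves to "census" (createKnowledgeBaseName() may add its own prefix in practice):

    modeler.save("census");  // bundle entries are persisted roughly as census_dt, census_fs, census_ml

    // initBundle() later restores each component from the same suffixed name:
    dataTransformer = MLBuilder.load(dtParams.getTClass(), dbName + "_" + DT_KEY, conf);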

23 changes: 9 additions & 14 deletions AbstractBoostingBagging.java
@@ -167,20 +167,20 @@ protected void _predict(Dataframe newData) {
//using the weak classifiers
AssociativeArray classifierWeightsArray = new AssociativeArray();
int totalWeakClassifiers = weakClassifierWeights.size();
-for(int t=0;t<totalWeakClassifiers;++t) {
+for(int i=0;i<totalWeakClassifiers;++i) {

-AbstractClassifier mlclassifier = (AbstractClassifier) bundle.get(String.valueOf(t));
+AbstractClassifier mlclassifier = (AbstractClassifier) bundle.get(DB_INDICATOR + i);
mlclassifier.predict(newData);

-classifierWeightsArray.put(t, weakClassifierWeights.get(t));
+classifierWeightsArray.put(i, weakClassifierWeights.get(i));

for(Map.Entry<Integer, Record> e : newData.entries()) {
Integer rId = e.getKey();
Record r = e.getValue();
AssociativeArray classProbabilities = r.getYPredictedProbabilities();

DataTable2D rDecisions = tmp_recordDecisions.get(rId);
-rDecisions.put(t, classProbabilities);
+rDecisions.put(i, classProbabilities);

tmp_recordDecisions.put(rId, rDecisions); //WARNING: Do not remove this! We must put it back to the Map to persist it on Disk-backed maps
}
@@ -258,7 +258,7 @@ protected void _fit(Dataframe trainingData) {
mlclassifier.close();
}
else {
-bundle.put(String.valueOf(i), mlclassifier);
+bundle.put(DB_INDICATOR + i, mlclassifier);
}

if(status==Status.STOP) {
@@ -319,13 +319,7 @@ protected enum Status {
public void save(String dbName) {
initBundle();
String knowledgeBaseName = createKnowledgeBaseName(dbName);
-//bundle.save(knowledgeBaseName);
-for(String i : bundle.keySet()) { //TODO: remove this custom case if possible. try adding the key in the name.
-Trainable t = bundle.get(i);
-if(t != null) {
-t.save(knowledgeBaseName + "_" + DB_INDICATOR + i);
-}
-}
+bundle.save(knowledgeBaseName);
super.save(dbName);
}

@@ -359,8 +353,9 @@ private void initBundle() {
Class<AbstractClassifier> weakClassifierClass = trainingParameters.getWeakClassifierTrainingParameters().getTClass();
int totalWeakClassifiers = Math.min(modelParameters.getWeakClassifierWeights().size(), trainingParameters.getMaxWeakClassifiers());
for(int i=0;i<totalWeakClassifiers;i++) {
-if (!bundle.containsKey(String.valueOf(i))) {
-bundle.put(String.valueOf(i), MLBuilder.load(weakClassifierClass, dbc.getDatabaseName() + "_" + DB_INDICATOR + i, conf));
+String key = DB_INDICATOR + i;
+if (!bundle.containsKey(key)) {
+bundle.put(key, MLBuilder.load(weakClassifierClass, dbc.getDatabaseName() + "_" + key, conf));
}
}
}
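Note: this is the "custom patch" the commit title refers to. The weak classifiers used to be keyed by their bare index, so save() had to append DB_INDICATOR by hand; now the indicator is baked into the key itself, and the generic bundle save covers it. Sketched with the names from the diff above:

    bundle.put(DB_INDICATOR + i, mlclassifier);  // at fit time the key already carries the indicator
    bundle.save(knowledgeBaseName);              // each entry lands under knowledgeBaseName + "_" + key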

5 changes: 3 additions & 2 deletions TrainableBundle.java
@@ -89,9 +89,10 @@ public void setParallelized(boolean parallelized) {

/** {@inheritDoc} */
public void save(String dbName) {
-for(Trainable t : bundle.values()) {
+for(Map.Entry<String, Trainable> e : bundle.entrySet()) {
+Trainable t = e.getValue();
if(t != null) {
-t.save(dbName);
+t.save(dbName + "_" + e.getKey());
}
}
}
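Note: TrainableBundle.save() now appends each entry's key to the supplied dbName, which is what lets the AbstractBoostingBagging override above collapse to a single bundle.save() call. A minimal sketch; the "dt"/"ml" keys and the "myModel" name are hypothetical:

    TrainableBundle bundle = new TrainableBundle();
    bundle.put("dt", dataTransformer);
    bundle.put("ml", modeler);
    bundle.save("myModel");  // persists "myModel_dt" and "myModel_ml"; null entries are skipped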

11 changes: 6 additions & 5 deletions StepwiseRegression.java
@@ -39,6 +39,8 @@
*/
public class StepwiseRegression extends AbstractRegressor<StepwiseRegression.ModelParameters, StepwiseRegression.TrainingParameters> {

+private static final String REG_KEY = "reg";
+
private TrainableBundle bundle = new TrainableBundle();

/** {@inheritDoc} */
@@ -158,7 +160,7 @@ protected void _predict(Dataframe newData) {
initBundle();

//run the pipeline
-AbstractRegressor mlregressor = (AbstractRegressor) bundle.get("mlregressor");
+AbstractRegressor mlregressor = (AbstractRegressor) bundle.get(REG_KEY);
mlregressor.predict(newData);
}

@@ -167,7 +169,6 @@ protected void _predict(Dataframe newData) {
protected void _fit(Dataframe trainingData) {
TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();
Configuration conf = knowledgeBase.getConf();
-String dbName = knowledgeBase.getDbc().getDatabaseName();

//reset previous entries on the bundle
resetBundle();
@@ -217,7 +218,7 @@ protected void _fit(Dataframe trainingData) {
conf
);
mlregressor.fit(copiedTrainingData);
bundle.put("mlregressor", mlregressor);
bundle.put(REG_KEY, mlregressor);

copiedTrainingData.delete();
}
@@ -256,10 +257,10 @@ private void initBundle() {
Configuration conf = knowledgeBase.getConf();
String dbName = knowledgeBase.getDbc().getDatabaseName();

if(!bundle.containsKey("mlregressor")) {
if(!bundle.containsKey(REG_KEY)) {
AbstractTrainingParameters mlParams = trainingParameters.getRegressionTrainingParameters();

bundle.put("mlregressor", MLBuilder.load(mlParams.getTClass(), dbName, conf));
bundle.put(REG_KEY, MLBuilder.load(mlParams.getTClass(), dbName + "_" + REG_KEY, conf));
}
}

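Note: loading follows the same convention. When a saved StepwiseRegression is reloaded, _predict() calls initBundle(), which restores the inner regressor from the "_reg"-suffixed name. A hedged sketch; the "myModel" name is illustrative:

    StepwiseRegression regressor = MLBuilder.load(StepwiseRegression.class, "myModel", conf);
    regressor.predict(newData);  // initBundle() loads the regressor from dbName + "_" + REG_KEY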
