Skip to content

Commit

Permalink
newInstance(), constructors etc of inherited from BaseTrainable pass …
Browse files Browse the repository at this point in the history
…to BaseTrainable and reduce the extra code.
  • Loading branch information
datumbox committed Mar 28, 2015
1 parent 870225f commit 7de9734
Show file tree
Hide file tree
Showing 9 changed files with 32 additions and 136 deletions.
1 change: 0 additions & 1 deletion TODO.txt
Expand Up @@ -9,7 +9,6 @@ NEW ALGORITHMS
CODE IMPROVEMENT CODE IMPROVEMENT
================ ================


- newInstance(), constructors etc of inherited from BaseTrainable could perhaps pass to BaseTrainable and reduce the extra code.
- Make this a method DATA_SAFE_CALL_BY_REFERENCE and IS_BINARIZED? - Make this a method DATA_SAFE_CALL_BY_REFERENCE and IS_BINARIZED?
- should we pass validator class in the constructors of all MLmodels? - should we pass validator class in the constructors of all MLmodels?
- Test a call directly on predict() and modify code to return the correct exception - Test a call directly on predict() and modify code to return the correct exception
Expand Down
Expand Up @@ -19,6 +19,7 @@
import com.datumbox.common.dataobjects.Dataset; import com.datumbox.common.dataobjects.Dataset;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration; import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector; import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.machinelearning.common.bases.BaseTrainable;
import com.datumbox.framework.machinelearning.common.bases.featureselection.FeatureSelection; import com.datumbox.framework.machinelearning.common.bases.featureselection.FeatureSelection;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLmodel; import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLmodel;
import com.datumbox.framework.machinelearning.common.bases.wrappers.BaseWrapper; import com.datumbox.framework.machinelearning.common.bases.wrappers.BaseWrapper;
Expand Down Expand Up @@ -75,7 +76,7 @@ public void _fit(Dataset trainingData) {


boolean transformData = (dtClass!=null); boolean transformData = (dtClass!=null);
if(transformData) { if(transformData) {
dataTransformer = DataTransformer.newInstance(dtClass, dbName, dbConf); dataTransformer = DataTransformer.<DataTransformer>newInstance(dtClass, dbName, dbConf);
dataTransformer.fit_transform(trainingData, trainingParameters.getDataTransformerTrainingParameters()); dataTransformer.fit_transform(trainingData, trainingParameters.getDataTransformerTrainingParameters());
} }


Expand All @@ -85,7 +86,7 @@ public void _fit(Dataset trainingData) {


boolean selectFeatures = (fsClass!=null); boolean selectFeatures = (fsClass!=null);
if(selectFeatures) { if(selectFeatures) {
featureSelection = FeatureSelection.newInstance(fsClass, dbName, dbConf); featureSelection = FeatureSelection.<FeatureSelection>newInstance(fsClass, dbName, dbConf);
featureSelection.fit(trainingData, trainingParameters.getFeatureSelectionTrainingParameters()); featureSelection.fit(trainingData, trainingParameters.getFeatureSelectionTrainingParameters());


featureSelection.transform(trainingData); featureSelection.transform(trainingData);
Expand All @@ -95,7 +96,8 @@ public void _fit(Dataset trainingData) {




//initialize mlmodel //initialize mlmodel
mlmodel = BaseMLmodel.newInstance(trainingParameters.getMLmodelClass(), dbName, dbConf); Class mlClass = trainingParameters.getMLmodelClass();
mlmodel = BaseMLmodel.<BaseMLmodel>newInstance(mlClass, dbName, dbConf);
int k = trainingParameters.getkFolds(); int k = trainingParameters.getkFolds();


//call k-fold cross validation and get the average accuracy //call k-fold cross validation and get the average accuracy
Expand Down Expand Up @@ -142,7 +144,7 @@ private BaseMLmodel.ValidationMetrics evaluateData(Dataset data, boolean estimat
boolean transformData = (dtClass!=null); boolean transformData = (dtClass!=null);
if(transformData) { if(transformData) {
if(dataTransformer==null) { if(dataTransformer==null) {
dataTransformer = DataTransformer.newInstance(dtClass, dbName, dbConf); dataTransformer = DataTransformer.<DataTransformer>newInstance(dtClass, dbName, dbConf);
} }
dataTransformer.transform(data); dataTransformer.transform(data);
} }
Expand All @@ -152,7 +154,7 @@ private BaseMLmodel.ValidationMetrics evaluateData(Dataset data, boolean estimat
boolean selectFeatures = (fsClass!=null); boolean selectFeatures = (fsClass!=null);
if(selectFeatures) { if(selectFeatures) {
if(featureSelection==null) { if(featureSelection==null) {
featureSelection = FeatureSelection.newInstance(fsClass, dbName, dbConf); featureSelection = FeatureSelection.<FeatureSelection>newInstance(fsClass, dbName, dbConf);
} }


//remove unnecessary features //remove unnecessary features
Expand All @@ -162,7 +164,8 @@ private BaseMLmodel.ValidationMetrics evaluateData(Dataset data, boolean estimat


//initialize mlmodel //initialize mlmodel
if(mlmodel==null) { if(mlmodel==null) {
mlmodel = BaseMLmodel.newInstance(trainingParameters.getMLmodelClass(), dbName, dbConf); Class mlClass = trainingParameters.getMLmodelClass();
mlmodel = BaseMLmodel.<BaseMLmodel>newInstance(mlClass, dbName, dbConf);
} }


//call predict of the mlmodel for the new dataset //call predict of the mlmodel for the new dataset
Expand Down
17 changes: 9 additions & 8 deletions src/main/java/com/datumbox/applications/nlp/TextClassifier.java
Expand Up @@ -21,7 +21,7 @@
import com.datumbox.common.dataobjects.Record; import com.datumbox.common.dataobjects.Record;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration; import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector; import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.configuration.GeneralConfiguration; import com.datumbox.framework.machinelearning.common.bases.BaseTrainable;
import com.datumbox.framework.machinelearning.common.bases.featureselection.CategoricalFeatureSelection; import com.datumbox.framework.machinelearning.common.bases.featureselection.CategoricalFeatureSelection;
import com.datumbox.framework.machinelearning.common.bases.featureselection.FeatureSelection; import com.datumbox.framework.machinelearning.common.bases.featureselection.FeatureSelection;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLmodel; import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLmodel;
Expand Down Expand Up @@ -133,15 +133,15 @@ protected void _fit(Dataset trainingDataset) {


boolean transformData = (dtClass!=null); boolean transformData = (dtClass!=null);
if(transformData) { if(transformData) {
dataTransformer = DataTransformer.newInstance(dtClass, dbName, dbConf); dataTransformer = DataTransformer.<DataTransformer>newInstance(dtClass, dbName, dbConf);
dataTransformer.fit_transform(trainingDataset, trainingParameters.getDataTransformerTrainingParameters()); dataTransformer.fit_transform(trainingDataset, trainingParameters.getDataTransformerTrainingParameters());
} }


Class fsClass = trainingParameters.getFeatureSelectionClass(); Class fsClass = trainingParameters.getFeatureSelectionClass();


boolean selectFeatures = (fsClass!=null); boolean selectFeatures = (fsClass!=null);
if(selectFeatures) { if(selectFeatures) {
featureSelection = FeatureSelection.newInstance(fsClass, dbName, dbConf); featureSelection = FeatureSelection.<FeatureSelection>newInstance(fsClass, dbName, dbConf);
FeatureSelection.TrainingParameters featureSelectionParameters = trainingParameters.getFeatureSelectionTrainingParameters(); FeatureSelection.TrainingParameters featureSelectionParameters = trainingParameters.getFeatureSelectionTrainingParameters();
if(CategoricalFeatureSelection.TrainingParameters.class.isAssignableFrom(featureSelectionParameters.getClass())) { if(CategoricalFeatureSelection.TrainingParameters.class.isAssignableFrom(featureSelectionParameters.getClass())) {
((CategoricalFeatureSelection.TrainingParameters)featureSelectionParameters).setIgnoringNumericalFeatures(false); //this should be turned off in feature selection ((CategoricalFeatureSelection.TrainingParameters)featureSelectionParameters).setIgnoringNumericalFeatures(false); //this should be turned off in feature selection
Expand Down Expand Up @@ -251,7 +251,7 @@ public BaseMLmodel.ValidationMetrics validate(Map<Object, URI> dataset) {
boolean transformData = (dtClass!=null); boolean transformData = (dtClass!=null);
if(transformData) { if(transformData) {
if(dataTransformer==null) { if(dataTransformer==null) {
dataTransformer = DataTransformer.newInstance(dtClass, dbName, dbConf); dataTransformer = DataTransformer.<DataTransformer>newInstance(dtClass, dbName, dbConf);
} }


dataTransformer.transform(testDataset); dataTransformer.transform(testDataset);
Expand All @@ -262,7 +262,7 @@ public BaseMLmodel.ValidationMetrics validate(Map<Object, URI> dataset) {
boolean selectFeatures = (fsClass!=null); boolean selectFeatures = (fsClass!=null);
if(selectFeatures) { if(selectFeatures) {
if(featureSelection==null) { if(featureSelection==null) {
featureSelection = FeatureSelection.newInstance(fsClass, dbName, dbConf); featureSelection = FeatureSelection.<FeatureSelection>newInstance(fsClass, dbName, dbConf);
} }


//remove unnecessary features //remove unnecessary features
Expand Down Expand Up @@ -324,7 +324,7 @@ private Dataset getPredictions(List<String> text) {
boolean transformData = (dtClass!=null); boolean transformData = (dtClass!=null);
if(transformData) { if(transformData) {
if(dataTransformer==null) { if(dataTransformer==null) {
dataTransformer = DataTransformer.newInstance(dtClass, dbName, dbConf); dataTransformer = DataTransformer.<DataTransformer>newInstance(dtClass, dbName, dbConf);
} }
dataTransformer.transform(newData); dataTransformer.transform(newData);
} }
Expand All @@ -334,7 +334,7 @@ private Dataset getPredictions(List<String> text) {
boolean selectFeatures = (fsClass!=null); boolean selectFeatures = (fsClass!=null);
if(selectFeatures) { if(selectFeatures) {
if(featureSelection==null) { if(featureSelection==null) {
featureSelection = FeatureSelection.newInstance(fsClass, dbName, dbConf); featureSelection = FeatureSelection.<FeatureSelection>newInstance(fsClass, dbName, dbConf);
} }


//remove unnecessary features //remove unnecessary features
Expand All @@ -344,7 +344,8 @@ private Dataset getPredictions(List<String> text) {


//initialize mlmodel //initialize mlmodel
if(mlmodel==null) { if(mlmodel==null) {
mlmodel = BaseMLmodel.newInstance(trainingParameters.getMLmodelClass(), dbName, dbConf); Class mlClass = trainingParameters.getMLmodelClass();
mlmodel = BaseMLmodel.<BaseMLmodel>newInstance(mlClass, dbName, dbConf);
} }


//call predict of the mlmodel for the new dataset //call predict of the mlmodel for the new dataset
Expand Down
Expand Up @@ -23,6 +23,7 @@
import com.datumbox.framework.machinelearning.common.bases.dataobjects.BaseModelParameters; import com.datumbox.framework.machinelearning.common.bases.dataobjects.BaseModelParameters;
import com.datumbox.framework.machinelearning.common.bases.dataobjects.BaseTrainingParameters; import com.datumbox.framework.machinelearning.common.bases.dataobjects.BaseTrainingParameters;
import com.datumbox.framework.machinelearning.common.dataobjects.KnowledgeBase; import com.datumbox.framework.machinelearning.common.dataobjects.KnowledgeBase;
import java.lang.reflect.InvocationTargetException;


/** /**
* *
Expand All @@ -37,6 +38,19 @@ public abstract class BaseTrainable<MP extends BaseModelParameters, TP extends B


protected KB knowledgeBase; protected KB knowledgeBase;
protected String dbName; protected String dbName;


public static <BT extends BaseTrainable> BT newInstance(Class<BT> aClass, String dbName, DatabaseConfiguration dbConfig) {
BT algorithm = null;
try {
algorithm = aClass.getConstructor(String.class, DatabaseConfiguration.class).newInstance(dbName, dbConfig);;
}
catch (InstantiationException | IllegalAccessException | IllegalArgumentException | InvocationTargetException | NoSuchMethodException | SecurityException ex) {
throw new RuntimeException(ex);
}

return algorithm;
}


protected BaseTrainable(String dbName, DatabaseConfiguration dbConf) { protected BaseTrainable(String dbName, DatabaseConfiguration dbConf) {
String methodName = this.getClass().getSimpleName(); String methodName = this.getClass().getSimpleName();
Expand Down
Expand Up @@ -24,7 +24,6 @@
import com.datumbox.framework.machinelearning.common.bases.dataobjects.BaseModelParameters; import com.datumbox.framework.machinelearning.common.bases.dataobjects.BaseModelParameters;
import com.datumbox.framework.machinelearning.common.bases.dataobjects.BaseTrainingParameters; import com.datumbox.framework.machinelearning.common.bases.dataobjects.BaseTrainingParameters;
import com.datumbox.framework.machinelearning.common.dataobjects.KnowledgeBase; import com.datumbox.framework.machinelearning.common.dataobjects.KnowledgeBase;
import java.lang.reflect.InvocationTargetException;


/** /**
* *
Expand All @@ -50,28 +49,6 @@ public static abstract class TrainingParameters extends BaseTrainingParameters {
} }




/**
* Generates a new instance of a MLmodel by providing the dbName and
the Class of the algorithm.
*
* @param <D>
* @param dbName
* @param aClass
* @param dbConfig
* @return
*/
public static <D extends DataTransformer> D newInstance(Class<D> aClass, String dbName, DatabaseConfiguration dbConfig) {
D algorithm = null;
try {
algorithm = (D) aClass.getConstructor(String.class, DatabaseConfiguration.class).newInstance(dbName, dbConfig);
}
catch (InstantiationException | IllegalAccessException | IllegalArgumentException | InvocationTargetException | NoSuchMethodException | SecurityException ex) {
throw new RuntimeException(ex);
}

return algorithm;
}



/* /*
IMPORTANT METHODS FOR THE FUNCTIONALITY IMPORTANT METHODS FOR THE FUNCTIONALITY
Expand Down
Expand Up @@ -24,7 +24,6 @@
import com.datumbox.framework.machinelearning.common.bases.dataobjects.BaseModelParameters; import com.datumbox.framework.machinelearning.common.bases.dataobjects.BaseModelParameters;
import com.datumbox.framework.machinelearning.common.bases.dataobjects.BaseTrainingParameters; import com.datumbox.framework.machinelearning.common.bases.dataobjects.BaseTrainingParameters;
import com.datumbox.framework.machinelearning.common.dataobjects.KnowledgeBase; import com.datumbox.framework.machinelearning.common.dataobjects.KnowledgeBase;
import java.lang.reflect.InvocationTargetException;


/** /**
* *
Expand All @@ -48,30 +47,6 @@ public static abstract class TrainingParameters extends BaseTrainingParameters {
//here goes public fields that are used as initial training parameters //here goes public fields that are used as initial training parameters
} }



/**
* Generates a new instance of a MLmodel by providing the dbName and
the Class of the algorithm.
*
* @param <F>
* @param dbName
* @param aClass
* @param dbConfig
* @return
*/
public static <F extends FeatureSelection> F newInstance(Class<F> aClass, String dbName, DatabaseConfiguration dbConfig) {
F algorithm = null;
try {
algorithm = (F) aClass.getConstructor(String.class, DatabaseConfiguration.class).newInstance(dbName, dbConfig);
}
catch (InstantiationException | IllegalAccessException | IllegalArgumentException | InvocationTargetException | NoSuchMethodException | SecurityException ex) {
throw new RuntimeException(ex);
}

return algorithm;
}


/* /*
IMPORTANT METHODS FOR THE FUNCTIONALITY IMPORTANT METHODS FOR THE FUNCTIONALITY
*/ */
Expand Down
Expand Up @@ -26,7 +26,6 @@
import com.datumbox.framework.machinelearning.common.bases.dataobjects.BaseTrainingParameters; import com.datumbox.framework.machinelearning.common.bases.dataobjects.BaseTrainingParameters;
import com.datumbox.framework.machinelearning.common.bases.dataobjects.BaseValidationMetrics; import com.datumbox.framework.machinelearning.common.bases.dataobjects.BaseValidationMetrics;
import com.datumbox.framework.machinelearning.common.dataobjects.MLmodelKnowledgeBase; import com.datumbox.framework.machinelearning.common.dataobjects.MLmodelKnowledgeBase;
import java.lang.reflect.InvocationTargetException;


/** /**
* Abstract Class for a Machine Learning algorithm. * Abstract Class for a Machine Learning algorithm.
Expand Down Expand Up @@ -70,36 +69,6 @@ public static abstract class ValidationMetrics extends BaseValidationMetrics {








/**
* Generates a new instance of a BaseMLmodel by providing the dbName and
the Class of the algorithm.
*
* @param <M>
* @param dbName
* @param aClass
* @param dbConfig
* @return
*/
public static <M extends BaseMLmodel> M newInstance(Class<M> aClass, String dbName, DatabaseConfiguration dbConfig) {
M algorithm = null;
try {
algorithm = (M) aClass.getConstructor(String.class, DatabaseConfiguration.class).newInstance(dbName, dbConfig);
}
catch (InstantiationException | IllegalAccessException | IllegalArgumentException | InvocationTargetException | NoSuchMethodException | SecurityException ex) {
throw new RuntimeException(ex);
}

return algorithm;
}







/* /*
IMPORTANT METHODS FOR THE FUNCTIONALITY IMPORTANT METHODS FOR THE FUNCTIONALITY
*/ */
Expand Down
Expand Up @@ -57,35 +57,6 @@ public static abstract class TrainingParameters extends BaseTrainingParameters {








/**
* Generates a new instance of a BaseMLrecommender by providing the dbName and
the Class of the algorithm.
*
* @param <M>
* @param dbName
* @param aClass
* @return
*/
public static <M extends BaseMLrecommender> M newInstance(Class<M> aClass, String dbName, DatabaseConfiguration dbConfig) {
M algorithm = null;
try {
algorithm = (M) aClass.getConstructor(String.class, DatabaseConfiguration.class).newInstance(dbName, dbConfig);;
}
catch (InstantiationException | IllegalAccessException | IllegalArgumentException | InvocationTargetException | NoSuchMethodException | SecurityException ex) {
throw new RuntimeException(ex);
}

return algorithm;
}







/* /*
IMPORTANT METHODS FOR THE FUNCTIONALITY IMPORTANT METHODS FOR THE FUNCTIONALITY
*/ */
Expand Down
Expand Up @@ -25,7 +25,6 @@
import com.datumbox.framework.machinelearning.common.dataobjects.KnowledgeBase; import com.datumbox.framework.machinelearning.common.dataobjects.KnowledgeBase;
import com.datumbox.framework.machinelearning.common.bases.datatransformation.DataTransformer; import com.datumbox.framework.machinelearning.common.bases.datatransformation.DataTransformer;
import com.datumbox.framework.machinelearning.common.bases.featureselection.FeatureSelection; import com.datumbox.framework.machinelearning.common.bases.featureselection.FeatureSelection;
import java.lang.reflect.InvocationTargetException;


/** /**
* The BaseWrapper is a trainable object that uses composition instead of inheritance * The BaseWrapper is a trainable object that uses composition instead of inheritance
Expand Down Expand Up @@ -128,18 +127,6 @@ public void setMLmodelTrainingParameters(ML.TrainingParameters mlmodelTrainingPa
} }




public static <W extends BaseWrapper> W newInstance(Class<W> aClass, String dbName, DatabaseConfiguration dbConfig) {
W algorithm = null;
try {
algorithm = (W) aClass.getConstructor(String.class, DatabaseConfiguration.class).newInstance(dbName, dbConfig);;
}
catch (InstantiationException | IllegalAccessException | IllegalArgumentException | InvocationTargetException | NoSuchMethodException | SecurityException ex) {
throw new RuntimeException(ex);
}

return algorithm;
}



/* /*
IMPORTANT METHODS FOR THE FUNCTIONALITY IMPORTANT METHODS FOR THE FUNCTIONALITY
Expand Down

0 comments on commit 7de9734

Please sign in to comment.