
Commit

half of the API changes are completed. (broken code)
datumbox committed Mar 28, 2015
1 parent 3aff0b4 commit 740c1d4
Showing 18 changed files with 48 additions and 75 deletions.
7 changes: 4 additions & 3 deletions TODO.txt
@@ -10,12 +10,13 @@ CODE IMPROVEMENT
================
- use fit/predict/transform methods instead of training, validate etc. like Python? (see the sketch after this list)
- DataTransformer: merge transform and normalize into a transform method, remove the trainingMode var and use a fit() instead
- FeatureSelection: change evaluateFeatures() to fit() and clearFeatures() to transform()
- BaseMLrecommender: change train to fit()
- BaseMLmodel: change train(trainingData, validationData) to fit(trainingData) and a predict() method. Eliminate the test() and rename it to validate(), which takes a dataset and returns its VMs
- Add the new methods in the BaseTrainable and the Trainable interface

- eliminate _fit() or estimateModelParameters() and pass the save() in fit()
- Remove initializeTrainingConfiguration(), pass the parameters in the train() method
- newInstance(), constructors etc. of classes inheriting from BaseTrainable could perhaps be passed up to BaseTrainable to reduce the extra code
- Make this a method DATA_SAFE_CALL_BY_REFERENCE?

- Serialization improvements
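The items above describe one target API. A minimal Java sketch of that shape, assuming the framework's existing Dataset, Learnable and Parameterizable types; the two-argument fit() matches the method being uncommented in Trainable.java below, while predict(), validate() and transform() are assumptions drawn purely from the TODO wording:

import com.datumbox.common.dataobjects.Dataset;

// Hedged sketch only: these interfaces do not exist in the codebase.
public interface TrainableSketch<MP extends Learnable, TP extends Parameterizable> {
    void fit(Dataset trainingData, TP trainingParameters); // replaces initializeTrainingConfiguration() + train()
}

interface ModelSketch<VM> {
    void predict(Dataset newData);       // the prediction half of the old test()
    VM validate(Dataset validationData); // renamed test(): returns the ValidationMetrics (VMs)
}

interface FeatureSelectionSketch {
    void fit(Dataset trainingData);  // was evaluateFeatures()
    void transform(Dataset newData); // was clearFeatures()
}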
@@ -90,10 +90,10 @@ public void train(Dataset trainingData) {
 if(selectFeatures) {
     featureSelection = FeatureSelection.newInstance(fsClass, dbName, dbConf);
     featureSelection.initializeTrainingConfiguration(trainingParameters.getFeatureSelectionTrainingParameters());
-    featureSelection.evaluateFeatures(trainingData);
+    featureSelection.fit(trainingData);

     //remove unnecessary features
-    featureSelection.clearFeatures(trainingData);
+    featureSelection.transform(trainingData);
 }

@@ -175,7 +175,7 @@ private BaseMLmodel.ValidationMetrics evaluateData(Dataset data, boolean estimat
 }

     //remove unnecessary features
-    featureSelection.clearFeatures(data);
+    featureSelection.transform(data);
 }
@@ -141,10 +141,10 @@ public void train(Map<Object, URI> dataset) {
 featureSelection.initializeTrainingConfiguration(trainingParameters.getFeatureSelectionTrainingParameters());

     //find the most popular features
-    featureSelection.evaluateFeatures(trainingDataset);
+    featureSelection.fit(trainingDataset);

     //remove unnecessary features
-    featureSelection.clearFeatures(trainingDataset);
+    featureSelection.transform(trainingDataset);
 }

 //initialize mlmodel
@@ -274,7 +274,7 @@ public BaseMLmodel.ValidationMetrics test(Map<Object, URI> dataset) {
 }

     //remove unnecessary features
-    featureSelection.clearFeatures(testDataset);
+    featureSelection.transform(testDataset);
 }

@@ -347,7 +347,7 @@ private Dataset getPredictions(List<String> text) {
 }

     //remove unnecessary features
-    featureSelection.clearFeatures(newData);
+    featureSelection.transform(newData);
 }
4 changes: 1 addition & 3 deletions src/main/java/com/datumbox/common/objecttypes/Trainable.java
@@ -56,8 +56,6 @@ public interface Trainable<MP extends Learnable, TP extends Parameterizable> {
  * @param trainingData
  * @param trainingParameters
  */
-//public void fit(Dataset trainingData,TP trainingParameters);
+public void fit(Dataset trainingData, TP trainingParameters);

-public void initializeTrainingConfiguration(TP trainingParameters);

 }
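For calling code this trades the two-step configure-then-train sequence for a single call. A hedged before/after sketch: the "before" half mirrors the tests further down, while the "after" half assumes the fit() signature above and does not work yet at this commit.

// Before: configure, then fit (the pattern the tests below still use)
ChisquareSelect selector = new ChisquareSelect(dbName, TestConfiguration.getDBConfig());
selector.initializeTrainingConfiguration(param);
selector.fit(trainingData);

// After (assumed end-state): the parameters travel with the fit() call
ChisquareSelect selector2 = new ChisquareSelect(dbName, TestConfiguration.getDBConfig());
selector2.fit(trainingData, param);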
@@ -78,9 +78,9 @@ public BernoulliNaiveBayes(String dbName, DatabaseConfiguration dbConf) {
 }

 @Override
-public void train(Dataset trainingData, Dataset validationData) {
+protected void _fit(Dataset trainingData) {
     knowledgeBase.getTrainingParameters().setMultiProbabilityWeighted(false);
-    super.train(trainingData, validationData);
+    super._fit(trainingData);
 }

 @Override
@@ -58,9 +58,9 @@ public BinarizedNaiveBayes(String dbName, DatabaseConfiguration dbConf) {
 }

 @Override
-public void train(Dataset trainingData, Dataset validationData) {
+protected void _fit(Dataset trainingData) {
     knowledgeBase.getTrainingParameters().setMultiProbabilityWeighted(false);
-    super.train(trainingData, validationData);
+    super._fit(trainingData);
 }

 }
@@ -138,10 +138,11 @@ public SupportVectorMachine(String dbName, DatabaseConfiguration dbConf) {
 }

 @Override
-public void train(Dataset trainingData, Dataset validationData) {
+protected void _fit(Dataset trainingData) {
     knowledgeBase.getTrainingParameters().getSvmParameter().probability=1; //probabilities are required by the algorithm
-    super.train(trainingData, validationData);
+    super._fit(trainingData);
 }

 @Override
@@ -18,6 +18,7 @@
 import com.datumbox.common.dataobjects.Dataset;
 import com.datumbox.common.objecttypes.Trainable;
+import com.datumbox.configuration.GeneralConfiguration;
 import com.datumbox.framework.machinelearning.common.bases.dataobjects.BaseModelParameters;
 import com.datumbox.framework.machinelearning.common.bases.dataobjects.BaseTrainingParameters;
 import com.datumbox.framework.machinelearning.common.dataobjects.KnowledgeBase;
@@ -50,24 +51,21 @@ public void erase() {
 public MP getModelParameters() {
     return knowledgeBase.getModelParameters();
 }

-@Override
-public void initializeTrainingConfiguration(TP trainingParameters) {
-    //reset knowledge base
-    knowledgeBase.reinitialize();
-    knowledgeBase.setTrainingParameters(trainingParameters);
-}
-
-/*
 @Override
 public void fit(Dataset trainingData, TP trainingParameters) {
+    if(GeneralConfiguration.DEBUG) {
+        System.out.println("fit()");
+    }
+
     //reset knowledge base
     knowledgeBase.reinitialize();
     knowledgeBase.setTrainingParameters(trainingParameters);
     _fit(trainingData);
 }

-public abstract void _fit(Dataset trainingData);
-*/
+protected abstract void _fit(Dataset trainingData);
 }
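The result is a classic template method: the public fit() handles the bookkeeping (reinitialize the knowledge base, store the parameters) and delegates the learning itself to the protected _fit() that each algorithm overrides, exactly as the naive-Bayes and SVM diffs above do. A self-contained toy illustration of the split, with stand-in types because the real classes carry generics and a KnowledgeBase:

// Toy sketch: DatasetStub, TrainableBaseSketch and CountingModel are invented
// for illustration and do not exist in the codebase.
abstract class TrainableBaseSketch<TP> {
    protected TP trainingParameters;

    public final void fit(DatasetStub trainingData, TP trainingParameters) {
        this.trainingParameters = trainingParameters; // replaces initializeTrainingConfiguration()
        _fit(trainingData);                           // algorithm-specific learning
    }

    protected abstract void _fit(DatasetStub trainingData);
}

class DatasetStub {
    int size;
}

// A subclass overrides only the learning step, like BernoulliNaiveBayes above:
class CountingModel extends TrainableBaseSketch<Integer> {
    @Override
    protected void _fit(DatasetStub trainingData) {
        System.out.println("fitting on " + trainingData.size + " records");
    }
}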
@@ -39,15 +39,12 @@ public abstract class BaseMinMaxNormalizer extends DataTransformer<BaseMinMaxNor
 public static class ModelParameters extends DataTransformer.ModelParameters {

     @BigMap
-
     protected Map<Object, Double> minColumnValues;

     @BigMap
-
     protected Map<Object, Double> maxColumnValues;

     @BigMap
-
     protected Map<Object, Object> referenceLevels;

     public ModelParameters(DatabaseConnector dbc) {
@@ -56,6 +56,7 @@ public static abstract class TrainingParameters extends BaseTrainingParameters {
  * @param <F>
  * @param dbName
  * @param aClass
+ * @param dbConfig
  * @return
  */
 public static <F extends FeatureSelection> F newInstance(Class<F> aClass, String dbName, DatabaseConfiguration dbConfig) {
@@ -84,10 +85,11 @@ protected FeatureSelection(String dbName, DatabaseConfiguration dbConf, Class<MP
 }

-public void evaluateFeatures(Dataset trainingData) {
+@Override
+protected void _fit(Dataset trainingData) {
     if(GeneralConfiguration.DEBUG) {
-        System.out.println("evaluateFeatures()");
+        System.out.println("fit()");
     }

     estimateModelParameters(trainingData);
@@ -99,9 +101,9 @@ public void evaluateFeatures(Dataset trainingData) {
 }

-public void clearFeatures(Dataset newData) {
+public void transform(Dataset newData) {
     if(GeneralConfiguration.DEBUG) {
-        System.out.println("clearFeatures()");
+        System.out.println("transform()");
     }

     knowledgeBase.load();
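Combined with the wrapper diffs earlier, the feature-selection lifecycle is now: fit() learns and persists which features matter, transform() reloads that knowledge and drops the rest. A hedged usage sketch mirroring the ChisquareSelect test below (dbName, param and trainingData are assumed to be set up as in that test):

void fitThenTransformRoundTrip() {
    ChisquareSelect selector = new ChisquareSelect(dbName, TestConfiguration.getDBConfig());
    selector.initializeTrainingConfiguration(param); // still required at this commit
    selector.fit(trainingData);  // was evaluateFeatures(): learns the selected features
    selector = null;             // discard the fitted instance entirely

    // A fresh instance can still transform(), because transform() begins with
    // knowledgeBase.load(), as the FeatureSelection diff above shows.
    ChisquareSelect reloaded = new ChisquareSelect(dbName, TestConfiguration.getDBConfig());
    reloaded.transform(trainingData); // was clearFeatures(): removes unselected columns
}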
@@ -137,37 +137,17 @@ public VM kFoldCrossValidation(Dataset trainingData, int k) {
 * Trains a model with the trainingData and validates it with the validationData.
 *
 * @param trainingData
-* @param validationData
 */
 @SuppressWarnings("unchecked")
-public void train(Dataset trainingData, Dataset validationData) {
-    if(GeneralConfiguration.DEBUG) {
-        System.out.println("train()");
-    }
-
+@Override
+protected void _fit(Dataset trainingData) {
     if(GeneralConfiguration.DEBUG) {
         System.out.println("estimateModelParameters()");
     }

     //train the model to get the parameters
     estimateModelParameters(trainingData);

-    if(validationData != null && !validationData.isEmpty()) {
-
-        if(GeneralConfiguration.DEBUG) {
-            System.out.println("validateModel()");
-        }
-
-        //validate the model with the validation data and update the validationMetrics
-        VM validationMetrics = validateModel(validationData);
-        knowledgeBase.setValidationMetrics(validationMetrics);
-
-    }
-
     if(GeneralConfiguration.DEBUG) {
         System.out.println("Saving model");
     }
@@ -105,12 +105,8 @@ protected BaseMLrecommender(String dbName, DatabaseConfiguration dbConf, Class<M
 * @param trainingData
 */
 @SuppressWarnings("unchecked")
-public void train(Dataset trainingData) {
-    if(GeneralConfiguration.DEBUG) {
-        System.out.println("train()");
-    }
-
+@Override
+public void _fit(Dataset trainingData) {

     if(GeneralConfiguration.DEBUG) {
         System.out.println("estimateModelParameters()");
@@ -95,13 +95,13 @@ public void testSelectFeatures() {
 ChisquareSelect instance = new ChisquareSelect(dbName, TestConfiguration.getDBConfig());
 instance.initializeTrainingConfiguration(param);

-instance.evaluateFeatures(trainingData);
+instance.fit(trainingData);
 instance = null;

 instance = new ChisquareSelect(dbName, TestConfiguration.getDBConfig());

-instance.clearFeatures(trainingData);
+instance.transform(trainingData);

 Set<Object> expResult = new HashSet<>(Arrays.asList("high_paid", "has_boat", "has_luxury_car", "has_butler", "has_pool"));
 Set<Object> result = trainingData.getColumns().keySet();
@@ -53,13 +53,13 @@ public void testSelectFeatures() {
 MutualInformation instance = new MutualInformation(dbName, TestConfiguration.getDBConfig());
 instance.initializeTrainingConfiguration(param);

-instance.evaluateFeatures(trainingData);
+instance.fit(trainingData);
 instance = null;

 instance = new MutualInformation(dbName, TestConfiguration.getDBConfig());

-instance.clearFeatures(trainingData);
+instance.transform(trainingData);

 Set<Object> expResult = new HashSet<>(Arrays.asList("high_paid", "has_boat", "has_luxury_car", "has_butler", "has_pool"));
 Set<Object> result = trainingData.getColumns().keySet();
@@ -57,7 +57,7 @@ public void testCalculateParameters() {
 param.setMaxDimensions(null);
 instance.initializeTrainingConfiguration(param);

-instance.evaluateFeatures(originaldata);
+instance.fit(originaldata);
 instance=null;

 Dataset newdata = originaldata;
@@ -72,7 +72,7 @@ public void testCalculateParameters() {
 expResult.add(Record.<Double>newDataVector(new Double[]{-14.1401, -6.4677, 1.4920}, null));
 expResult.add(Record.<Double>newDataVector(new Double[]{-23.8837, 3.7408, -2.3614}, null));

-instance.clearFeatures(newdata);
+instance.transform(newdata);

 assertEquals(newdata.size(), expResult.size());
@@ -80,13 +80,13 @@ public void testSelectFeatures() {
 TFIDF instance = new TFIDF(dbName, TestConfiguration.getDBConfig());
 instance.initializeTrainingConfiguration(param);

-instance.evaluateFeatures(trainingData);
+instance.fit(trainingData);
 instance = null;

 instance = new TFIDF(dbName, TestConfiguration.getDBConfig());

-instance.clearFeatures(trainingData);
+instance.transform(trainingData);

 Set<Object> expResult = new HashSet<>(Arrays.asList("important1", "important2", "important3"));
 Set<Object> result = trainingData.getColumns().keySet();
@@ -159,7 +159,7 @@ public void testPredict() {
 CollaborativeFiltering.TrainingParameters param = new CollaborativeFiltering.TrainingParameters();
 param.setSimilarityMethod(CollaborativeFiltering.TrainingParameters.SimilarityMeasure.PEARSONS_CORRELATION);
 instance.initializeTrainingConfiguration(param);
-instance.train(trainingData);
+instance._fit(trainingData);

 instance = null;
@@ -182,14 +182,14 @@ public void testKFoldCrossValidation() {
 featureSelectionParameters.setWhitened(false);
 featureSelectionParameters.setVarianceThreshold(0.99999995);
 featureSelection.initializeTrainingConfiguration(featureSelectionParameters);
-featureSelection.evaluateFeatures(trainingData);
+featureSelection.fit(trainingData);
 /*
 featureSelection=null;
 featureSelection = new PCA(dbName, TestConfiguration.getDBConfig());
 */
-featureSelection.clearFeatures(trainingData);
+featureSelection.transform(trainingData);
 featureSelection.erase();
