Updated all the methods on MLmodels to follow a Python-like API.
datumbox committed Mar 28, 2015
1 parent 740c1d4 commit 53c4090
Showing 45 changed files with 324 additions and 565 deletions.
11 changes: 4 additions & 7 deletions TODO.txt
@@ -8,16 +8,13 @@ NEW ALGORITHMS

CODE IMPROVEMENT
================
- use fit/predict/transform methods instead of train, validate etc., like Python? (see the sketch after this list)
- DataTransformer: merge transform and normalize into a transform method, remove the trainingMode var and use a fit() instead


- BaseMLmodel: change train(trainingData, validationData) to fit(trainingData) and add a predict() method. Rename test() to validate(), which takes a dataset and returns its ValidationMetrics

- eliminate _fit() or estimateModelParameters() and move the save() call into fit()
- Remove initializeTrainingConfiguration(), pass the parameters to the train() method
- newInstance(), constructors etc. of classes inherited from BaseTrainable could perhaps move to BaseTrainable and reduce the extra code.
- Turn DATA_SAFE_CALL_BY_REFERENCE into a method?
- Turn DATA_SAFE_CALL_BY_REFERENCE and IS_BINARIZED into methods?
- should we pass the validator class in the constructors of all MLmodels?
- DataTransformer: refactor the individual implementations of _process()


- Serialization improvements
- create the required exceptions
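
To make the target API concrete, here is a minimal sketch of how the new surface reads once this commit lands. The method names (fit/predict/validate) and the Modeler constructor come straight from the diff below; the TrainingParameters setup is abbreviated, so treat this as an illustration rather than a complete program.

// Minimal sketch of the Python-like API; the TrainingParameters setup is abbreviated.
Modeler modeler = new Modeler(dbName, dbConf);
Modeler.TrainingParameters trainingParameters = new Modeler.TrainingParameters();
// ... set the dataTransformer/featureSelection/mlmodel classes and their parameters ...

modeler.fit(trainingData, trainingParameters);                  // replaces train(trainingData, validationData)
modeler.predict(newData);                                       // annotates newData in place
BaseMLmodel.ValidationMetrics vm = modeler.validate(testData);  // replaces test(testData)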
54 changes: 19 additions & 35 deletions src/main/java/com/datumbox/applications/datamodeling/Modeler.java
@@ -19,7 +19,6 @@
import com.datumbox.common.dataobjects.Dataset;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.common.utilities.DeepCopy;
import com.datumbox.framework.machinelearning.common.bases.featureselection.FeatureSelection;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLmodel;
import com.datumbox.framework.machinelearning.common.bases.wrappers.BaseWrapper;
@@ -40,7 +39,6 @@ public ModelParameters(DatabaseConnector dbc) {
}

public static class TrainingParameters extends BaseWrapper.TrainingParameters<DataTransformer, FeatureSelection, BaseMLmodel> {

//primitives/wrappers
private Integer kFolds = 5;

@@ -64,7 +62,8 @@ public Modeler(String dbName, DatabaseConfiguration dbConf) {



public void train(Dataset trainingData) {
@Override
public void _fit(Dataset trainingData) {

//get the training parameters
Modeler.TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();
@@ -77,9 +76,7 @@ public void train(Dataset trainingData) {
boolean transformData = (dtClass!=null);
if(transformData) {
dataTransformer = DataTransformer.newInstance(dtClass, dbName, dbConf);
dataTransformer.initializeTrainingConfiguration(knowledgeBase.getTrainingParameters().getDataTransformerTrainingParameters());
dataTransformer.transform(trainingData, true);
dataTransformer.normalize(trainingData);
dataTransformer.fit_transform(trainingData, trainingParameters.getDataTransformerTrainingParameters());
}
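
The old three-call sequence on the transformer (initializeTrainingConfiguration, transform with a training flag, normalize) collapses into a single fit_transform call. The body of fit_transform is not shown in this diff, so the following is only a plausible reading, assuming it simply chains the new fit() and transform():

// Assumed composition of DataTransformer.fit_transform; the diff does not
// show its body, so this is a sketch, not the committed implementation.
public void fit_transform(Dataset trainingData, TP trainingParameters) {
    fit(trainingData, trainingParameters); // learn the transformation/normalization statistics
    transform(trainingData);               // apply them to the same dataset in place
}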


@@ -89,40 +86,28 @@ public void train(Dataset trainingData) {
boolean selectFeatures = (fsClass!=null);
if(selectFeatures) {
featureSelection = FeatureSelection.newInstance(fsClass, dbName, dbConf);
featureSelection.initializeTrainingConfiguration(trainingParameters.getFeatureSelectionTrainingParameters());
featureSelection.fit(trainingData);

//remove unnecessary features
featureSelection.transform(trainingData);
featureSelection.fit(trainingData, trainingParameters.getFeatureSelectionTrainingParameters());

featureSelection.transform(trainingData);
}




//initialize mlmodel
mlmodel = BaseMLmodel.newInstance(trainingParameters.getMLmodelClass(), dbName, dbConf);
mlmodel.initializeTrainingConfiguration(trainingParameters.getMLmodelTrainingParameters());

int k = trainingParameters.getkFolds();
if(k>1) {
//call k-fold cross validation and get the average accuracy
BaseMLmodel.ValidationMetrics averageValidationMetrics = (BaseMLmodel.ValidationMetrics) mlmodel.kFoldCrossValidation(trainingData, k);

//call k-fold cross validation and get the average accuracy
BaseMLmodel.ValidationMetrics averageValidationMetrics = (BaseMLmodel.ValidationMetrics) mlmodel.kFoldCrossValidation(trainingData, trainingParameters.getMLmodelTrainingParameters(), k);

//train the mlmodel on the whole dataset and pass as ValidationDataset the empty set
mlmodel.train(trainingData, new Dataset());
//train the mlmodel on the whole dataset
mlmodel.fit(trainingData, trainingParameters.getMLmodelTrainingParameters());

//set its ValidationMetrics to the average VP from k-fold cross validation
mlmodel.setValidationMetrics(averageValidationMetrics);
}
else { //k==1
Dataset validationDataset = trainingData;

boolean algorithmModifiesDataset = mlmodel.modifiesData();
if(algorithmModifiesDataset) {
validationDataset = DeepCopy.<Dataset>cloneObject(validationDataset);
}
mlmodel.train(trainingData, validationDataset);
}
//set its ValidationMetrics to the average VP from k-fold cross validation
mlmodel.setValidationMetrics(averageValidationMetrics);


if(transformData) {
dataTransformer.denormalize(trainingData); //optional denormalization
}
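
Because the excerpt interleaves old and new lines without +/- markers, the reworked block is easier to read reassembled. Every call below appears in the added lines of this hunk; the surrounding control flow is inferred from the excerpt, whose deleted else-branch (the one that deep-copied the dataset into a validation set) is gone:

// Condensed view of the new training flow in Modeler._fit():
mlmodel = BaseMLmodel.newInstance(trainingParameters.getMLmodelClass(), dbName, dbConf);

int k = trainingParameters.getkFolds();

//k-fold cross validation yields the average metrics across folds
BaseMLmodel.ValidationMetrics averageValidationMetrics =
        (BaseMLmodel.ValidationMetrics) mlmodel.kFoldCrossValidation(
                trainingData, trainingParameters.getMLmodelTrainingParameters(), k);

//train on the whole dataset, then keep the k-fold average as the model's metrics
mlmodel.fit(trainingData, trainingParameters.getMLmodelTrainingParameters());
mlmodel.setValidationMetrics(averageValidationMetrics);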
@@ -135,7 +120,7 @@ public void predict(Dataset newData) {
evaluateData(newData, false);
}

public BaseMLmodel.ValidationMetrics test(Dataset testData) {
public BaseMLmodel.ValidationMetrics validate(Dataset testData) {
return evaluateData(testData, true);
}

@@ -162,8 +147,7 @@ private BaseMLmodel.ValidationMetrics evaluateData(Dataset data, boolean estimat
if(dataTransformer==null) {
dataTransformer = DataTransformer.newInstance(dtClass, dbName, dbConf);
}
dataTransformer.transform(data, false);
dataTransformer.normalize(data);
dataTransformer.transform(data);
}

Class fsClass = trainingParameters.getFeatureSelectionClass();
@@ -188,8 +172,8 @@ private BaseMLmodel.ValidationMetrics evaluateData(Dataset data, boolean estimat

BaseMLmodel.ValidationMetrics vm = null;
if(estimateValidationMetrics) {
//run test which calculates validation metrics. It is used by test() method
vm = mlmodel.test(data);
//run validate() which calculates validation metrics. It is used by the validate() method
vm = mlmodel.validate(data);
}
else {
//run predict() which does not calculate validation metrics. It is used by the predict() method
5 changes: 2 additions & 3 deletions src/main/java/com/datumbox/applications/nlp/CETR.java
@@ -188,9 +188,8 @@ private void performClustering(Dataset dataset, DatabaseConfiguration dbConf, in
param.setCategoricalGamaMultiplier(1.0);
//param.setSubsetFurthestFirstcValue(2.0);

instance.initializeTrainingConfiguration(param);
instance.train(dataset, dataset);

instance.fit(dataset, param);
instance.predict(dataset);
//Map<Integer, BaseMLclusterer.Cluster> clusters = instance.getClusters();

instance.erase(); //erase immediately the result
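
The rename also makes the clusterer's two responsibilities explicit: the old train(dataset, dataset) both built the clusters and annotated the records, while the new API separates the steps, as the two calls above show:

// fit() learns the clusters; predict() writes the cluster
// assignments back onto the same records (pattern from the diff above).
instance.fit(dataset, param);
instance.predict(dataset);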
70 changes: 31 additions & 39 deletions src/main/java/com/datumbox/applications/nlp/TextClassifier.java
@@ -21,7 +21,7 @@
import com.datumbox.common.dataobjects.Record;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConfiguration;
import com.datumbox.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.common.utilities.DeepCopy;
import com.datumbox.configuration.GeneralConfiguration;
import com.datumbox.framework.machinelearning.common.bases.featureselection.CategoricalFeatureSelection;
import com.datumbox.framework.machinelearning.common.bases.featureselection.FeatureSelection;
import com.datumbox.framework.machinelearning.common.bases.mlmodels.BaseMLmodel;
@@ -103,30 +103,36 @@ public void setTextExtractorTrainingParameters(TextExtractor.Parameters textExtr
public TextClassifier(String dbName, DatabaseConfiguration dbConf) {
super(dbName, dbConf, TextClassifier.ModelParameters.class, TextClassifier.TrainingParameters.class);
}


@Deprecated
@Override
public void fit(Dataset trainingData, TrainingParameters trainingParameters) {
super.fit(trainingData, trainingParameters);
}

public void train(Map<Object, URI> dataset) {
public void fit(Map<Object, URI> dataset, TrainingParameters trainingParameters) {

//get the training parameters
TextClassifier.TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();
DatabaseConfiguration dbConf = knowledgeBase.getDbConf();
initializeTrainingConfiguration(trainingParameters);

TextExtractor textExtractor = TextExtractor.newInstance(trainingParameters.getTextExtractorClass());
textExtractor.setParameters(trainingParameters.getTextExtractorTrainingParameters());

//build trainingDataset
Dataset trainingDataset = DatasetBuilder.parseFromTextFiles(dataset, textExtractor);
//Dataset trainingDataset = DatasetBuilder.parseFromTextLists(DatasetBuilder.stringListsFromTextFiles(dataset), textExtractor);

_fit(trainingDataset);
}
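
The new overload takes the raw corpus directly: a map from class label to the URI of a text source for that class. A hypothetical usage follows, with invented file paths and labels; only the fit(Map, TrainingParameters) signature comes from the diff.

// Hypothetical usage; requires java.io.File, java.net.URI, java.util.HashMap, java.util.Map.
Map<Object, URI> corpus = new HashMap<>();
corpus.put("positive", new File("/data/positive.txt").toURI());
corpus.put("negative", new File("/data/negative.txt").toURI());

TextClassifier classifier = new TextClassifier(dbName, dbConf);
TextClassifier.TrainingParameters trainingParameters = new TextClassifier.TrainingParameters();
// ... configure the transformer/feature-selection/mlmodel classes as usual ...
classifier.fit(corpus, trainingParameters); // builds the Dataset internally via DatasetBuilder.parseFromTextFiles()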

@Override
protected void _fit(Dataset trainingDataset) {
TextClassifier.TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();
DatabaseConfiguration dbConf = knowledgeBase.getDbConf();
Class dtClass = trainingParameters.getDataTransformerClass();

boolean transformData = (dtClass!=null);
if(transformData) {
dataTransformer = DataTransformer.newInstance(dtClass, dbName, dbConf);
dataTransformer.initializeTrainingConfiguration(trainingParameters.getDataTransformerTrainingParameters());
dataTransformer.transform(trainingDataset, true);
dataTransformer.normalize(trainingDataset);
dataTransformer.fit_transform(trainingDataset, trainingParameters.getDataTransformerTrainingParameters());
}

Class fsClass = trainingParameters.getFeatureSelectionClass();
@@ -138,39 +144,28 @@ public void train(Map<Object, URI> dataset) {
if(CategoricalFeatureSelection.TrainingParameters.class.isAssignableFrom(featureSelectionParameters.getClass())) {
((CategoricalFeatureSelection.TrainingParameters)featureSelectionParameters).setIgnoringNumericalFeatures(false); //this should be turned off in feature selection
}
featureSelection.initializeTrainingConfiguration(trainingParameters.getFeatureSelectionTrainingParameters());

//find the most popular features
featureSelection.fit(trainingDataset);
featureSelection.fit(trainingDataset, trainingParameters.getFeatureSelectionTrainingParameters());

//remove unnecessary features
featureSelection.transform(trainingDataset);
}

//initialize mlmodel
mlmodel = BaseMLmodel.newInstance(trainingParameters.getMLmodelClass(), dbName, dbConf);
mlmodel.initializeTrainingConfiguration(trainingParameters.getMLmodelTrainingParameters());


int k = trainingParameters.getkFolds();
if(k>1) {
//call k-fold cross validation and get the average accuracy
BaseMLmodel.ValidationMetrics averageValidationMetrics = (BaseMLmodel.ValidationMetrics) mlmodel.kFoldCrossValidation(trainingDataset, k);
//call k-fold cross validation and get the average accuracy
BaseMLmodel.ValidationMetrics averageValidationMetrics = (BaseMLmodel.ValidationMetrics) mlmodel.kFoldCrossValidation(trainingDataset, trainingParameters.getMLmodelTrainingParameters(), k);

//train the mlmodel on the whole dataset and pass as ValidationDataset the empty set
mlmodel.train(trainingDataset, new Dataset());
//train the mlmodel on the whole dataset
mlmodel.fit(trainingDataset, trainingParameters.getMLmodelTrainingParameters());

//set its ValidationMetrics to the average VP from k-fold cross validation
mlmodel.setValidationMetrics(averageValidationMetrics);
}
else { //k==1
Dataset validationDataset = trainingDataset;

boolean algorithmModifiesDataset = mlmodel.modifiesData();
if(algorithmModifiesDataset) {
validationDataset = DeepCopy.<Dataset>cloneObject(validationDataset);
}
mlmodel.train(trainingDataset, validationDataset);
}
//set its ValidationMetrics to the average VP from k-fold cross validation
mlmodel.setValidationMetrics(averageValidationMetrics);


if(transformData) {
dataTransformer.denormalize(trainingDataset); //optional denormalization
@@ -238,7 +233,7 @@ public List<AssociativeArray> predictProbabilities(List<String> text) {
return predictedClassProbabilities;
}

public BaseMLmodel.ValidationMetrics test(Map<Object, URI> dataset) {
public BaseMLmodel.ValidationMetrics validate(Map<Object, URI> dataset) {

//ensure db loaded
knowledgeBase.load();
@@ -251,7 +246,6 @@ public BaseMLmodel.ValidationMetrics test(Map<Object, URI> dataset) {

//build the testDataset
Dataset testDataset = DatasetBuilder.parseFromTextFiles(dataset, textExtractor);
//Dataset testDataset = DatasetBuilder.parseFromTextLists(DatasetBuilder.stringListsFromTextFiles(dataset), textExtractor);

Class dtClass = trainingParameters.getDataTransformerClass();

@@ -261,8 +255,7 @@ public BaseMLmodel.ValidationMetrics test(Map<Object, URI> dataset) {
dataTransformer = DataTransformer.newInstance(dtClass, dbName, dbConf);
}

dataTransformer.transform(testDataset, false);
dataTransformer.normalize(testDataset);
dataTransformer.transform(testDataset);
}

Class fsClass = trainingParameters.getFeatureSelectionClass();
@@ -284,7 +277,7 @@ public BaseMLmodel.ValidationMetrics test(Map<Object, URI> dataset) {
}

//call predict of the mlmodel for the new dataset
BaseMLmodel.ValidationMetrics vm = mlmodel.test(testDataset);
BaseMLmodel.ValidationMetrics vm = mlmodel.validate(testDataset);

if(transformData) {
dataTransformer.denormalize(testDataset); //optional denormalization
@@ -295,7 +288,7 @@

public BaseMLmodel.ValidationMetrics getValidationMetrics() {
if(mlmodel==null) {
test(new HashMap<>()); //this forces the loading of the algorithm
validate(new HashMap<>()); //this forces the loading of the algorithm
}
BaseMLmodel.ValidationMetrics vm = mlmodel.getValidationMetrics();

@@ -334,8 +327,7 @@ private Dataset getPredictions(List<String> text) {
if(dataTransformer==null) {
dataTransformer = DataTransformer.newInstance(dtClass, dbName, dbConf);
}
dataTransformer.transform(newData, false);
dataTransformer.normalize(newData);
dataTransformer.transform(newData);
}

Class fsClass = trainingParameters.getFeatureSelectionClass();
@@ -60,10 +60,14 @@ public void fit(Dataset trainingData, TP trainingParameters) {
System.out.println("fit()");
}

initializeTrainingConfiguration(trainingParameters);
_fit(trainingData);
}

protected void initializeTrainingConfiguration(TP trainingParameters) {
//reset knowledge base
knowledgeBase.reinitialize();
knowledgeBase.setTrainingParameters(trainingParameters);
_fit(trainingData);
}

protected abstract void _fit(Dataset trainingData);
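
Reassembled from the interleaved old and new lines above, fit() is now a template method: it stores the parameters, resets the knowledge base, and delegates the actual learning to the subclass hook. This is an interpretation of the excerpt, not extra code; it is also the same initializeTrainingConfiguration() that TextClassifier's Map-based fit() calls directly.

// fit() as the public entry point...
public void fit(Dataset trainingData, TP trainingParameters) {
    initializeTrainingConfiguration(trainingParameters);
    _fit(trainingData);
}

// ...which resets the knowledge base and stores the parameters...
protected void initializeTrainingConfiguration(TP trainingParameters) {
    knowledgeBase.reinitialize();
    knowledgeBase.setTrainingParameters(trainingParameters);
}

// ...and delegates the actual training to the subclass hook.
protected abstract void _fit(Dataset trainingData);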
@@ -213,20 +213,26 @@ protected void estimateModelParameters(Dataset trainingData) {
}

BaseMLclassifier mlclassifier = BaseMLmodel.newInstance(weakClassifierClass, dbName+knowledgeBase.getDbConf().getDBnameSeparator()+DB_INDICATOR+String.valueOf(t), knowledgeBase.getDbConf());
mlclassifier.initializeTrainingConfiguration(weakClassifierTrainingParameters);
boolean copyData = mlclassifier.modifiesData();

Dataset validationDataset = trainingData;
if(mlclassifier.modifiesData()) {

if(copyData) {
sampledTrainingDataset = DeepCopy.<Dataset>cloneObject(sampledTrainingDataset);
}
mlclassifier.fit(sampledTrainingDataset, weakClassifierTrainingParameters);
sampledTrainingDataset = null;


Dataset validationDataset = trainingData;
if(copyData) {
validationDataset = DeepCopy.<Dataset>cloneObject(validationDataset);
}
mlclassifier.train(sampledTrainingDataset, validationDataset);
mlclassifier.predict(validationDataset);
mlclassifier = null;


boolean stop = updateObservationAndClassifierWeights(validationDataset, observationWeights);

mlclassifier = null;
sampledTrainingDataset = null;
validationDataset = null;

if(stop==true) {
@@ -225,7 +225,7 @@ protected VM validateModel(Dataset validationData) {
} catch (IllegalArgumentException | SecurityException | NoSuchMethodException | IllegalAccessException | InvocationTargetException ex) {
throw new RuntimeException(ex);
}
validationMetrics.setNormalResiduals( (normalResiduals)?0.0:1.0 ); //if the Lilliefors test rejects the H0 means that the normality hypothesis is rejected thus the residuals are not normal
validationMetrics.setNormalResiduals( (normalResiduals)?0.0:1.0 ); //if the Lilliefors test rejects H0, the normality hypothesis is rejected and thus the residuals are not normal
errorList = null;

double SSR = 0.0;
