From eb4320230d0dd6392c028affa51f85e44cf977a3 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 21 Dec 2016 18:41:39 +0000 Subject: [PATCH] Removed the AbstractWrapper and Modeler inherits directly from AbstractTrainer. Created a TrainableBundle to keep track of the Trainables of Modeler, AbstractBoostingBagging and StepwiseRegression. Removed automatic save after fit, now save() must be called. --- CHANGELOG.md | 3 + TODO.txt | 2 +- .../applications/datamodeling/Modeler.java | 232 ++++++++++++++---- .../framework/applications/nlp/CETR.java | 5 +- .../applications/nlp/TextClassifier.java | 6 - .../datamodeling/ModelerTest.java | 5 +- .../applications/nlp/TextClassifierTest.java | 25 +- .../common/abstracts/AbstractTrainer.java | 3 - .../algorithms/AbstractBoostingBagging.java | 110 ++++++--- .../abstracts/wrappers/AbstractWrapper.java | 202 --------------- .../common/dataobjects/TrainableBundle.java | 114 +++++++++ .../validators/KFoldValidator.java | 2 +- .../regression/StepwiseRegression.java | 96 +++++--- .../BernoulliNaiveBayesTest.java | 1 + .../BinarizedNaiveBayesTest.java | 1 + .../classification/MaximumEntropyTest.java | 1 + .../MultinomialNaiveBayesTest.java | 15 +- .../classification/OrdinalRegressionTest.java | 19 +- .../classification/SoftMaxRegressionTest.java | 20 +- .../SupportVectorMachineTest.java | 15 +- .../clustering/GaussianDPMMTest.java | 7 +- .../HierarchicalAgglomerativeTest.java | 16 +- .../clustering/KmeansTest.java | 17 +- .../clustering/MultinomialDPMMTest.java | 1 + .../ensemblelearning/AdaboostTest.java | 18 +- .../BayesianEnsembleMethodTest.java | 18 +- .../BootstrapAggregatingTest.java | 16 +- .../categorical/ChisquareSelectTest.java | 2 + .../categorical/MutualInformationTest.java | 2 + .../featureselection/continuous/PCATest.java | 6 +- .../scorebased/TFIDFTest.java | 6 +- .../CollaborativeFilteringTest.java | 1 + .../MatrixLinearRegressionTest.java | 20 +- .../machinelearning/regression/NLMSTest.java | 20 +- .../regression/StepwiseRegressionTest.java | 10 +- .../LatentDirichletAllocationTest.java | 1 + 36 files changed, 581 insertions(+), 457 deletions(-) delete mode 100755 datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/wrappers/AbstractWrapper.java create mode 100755 datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/dataobjects/TrainableBundle.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 40339a41..78986cc3 100755 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,9 @@ Version 0.8.0-SNAPSHOT - Build 20161220 - Removed all unnecessary passing of class objects from Stepwise Regression, Wrappers and Ensumble learning classes. - Removed the ignoringNumericalFeatures variable from AbstractCategoricalFeatureSelector. - Moved the static methods of Trainable to MLBuilder. This is the prefered method for loading and initializing models. +- Removed the AbstractWrapper and Modeler inherits directly from AbstractTrainer. +- Created a TrainableBundle to keep track of the Trainables of Modeler, AbstractBoostingBagging and StepwiseRegression. +- Removed automatic save after fit, now save() must be called. Version 0.7.1-SNAPSHOT - Build 20161217 --------------------------------------- diff --git a/TODO.txt b/TODO.txt index 7844bce1..5f6f8c11 100755 --- a/TODO.txt +++ b/TODO.txt @@ -1,8 +1,8 @@ CODE IMPROVEMENTS ================= +- Add save(dbName) method in the models. - Can we make the two constructors of the Trainers to call a common constructor to eliminate duplicate code? -- Add save(dbName) method in the models - Support of better Transformers (Zscore, decouple boolean transforming from numeric etc). - Write a ShuffleSplitValidator class similar to KFold. Perhaps we need a single Validator class and separate Splitters. diff --git a/datumbox-framework-applications/src/main/java/com/datumbox/framework/applications/datamodeling/Modeler.java b/datumbox-framework-applications/src/main/java/com/datumbox/framework/applications/datamodeling/Modeler.java index a631af9e..d79d41d7 100755 --- a/datumbox-framework-applications/src/main/java/com/datumbox/framework/applications/datamodeling/Modeler.java +++ b/datumbox-framework-applications/src/main/java/com/datumbox/framework/applications/datamodeling/Modeler.java @@ -17,14 +17,14 @@ import com.datumbox.framework.common.Configuration; import com.datumbox.framework.common.dataobjects.Dataframe; -import com.datumbox.framework.common.interfaces.Trainable; import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector; import com.datumbox.framework.core.machinelearning.MLBuilder; import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer; import com.datumbox.framework.core.machinelearning.common.abstracts.datatransformers.AbstractTransformer; import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractFeatureSelector; import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractModeler; -import com.datumbox.framework.core.machinelearning.common.abstracts.wrappers.AbstractWrapper; +import com.datumbox.framework.core.machinelearning.common.dataobjects.TrainableBundle; +import com.datumbox.framework.core.machinelearning.common.interfaces.Parallelizable; /** * Modeler is a convenience class which can be used to train Machine Learning @@ -33,12 +33,14 @@ * * @author Vasilis Vryniotis */ -public class Modeler extends AbstractWrapper { - +public class Modeler extends AbstractTrainer implements Parallelizable { + + private TrainableBundle bundle = new TrainableBundle(); + /** * It contains all the Model Parameters which are learned during the training. */ - public static class ModelParameters extends AbstractWrapper.AbstractModelParameters { + public static class ModelParameters extends AbstractTrainer.AbstractModelParameters { private static final long serialVersionUID = 1L; /** @@ -48,15 +50,78 @@ public static class ModelParameters extends AbstractWrapper.AbstractModelParamet protected ModelParameters(DatabaseConnector dbc) { super(dbc); } - + } /** * It contains the Training Parameters of the Modeler. */ - public static class TrainingParameters extends AbstractWrapper.AbstractTrainingParameters { + public static class TrainingParameters extends AbstractTrainer.AbstractTrainingParameters { private static final long serialVersionUID = 1L; + //Parameter Objects + private AbstractTransformer.AbstractTrainingParameters dataTransformerTrainingParameters; + + private AbstractFeatureSelector.AbstractTrainingParameters featureSelectorTrainingParameters; + + private AbstractModeler.AbstractTrainingParameters modelerTrainingParameters; + + /** + * Getter for the Training Parameters of the Data Transformer. + * + * @return + */ + public AbstractTransformer.AbstractTrainingParameters getDataTransformerTrainingParameters() { + return dataTransformerTrainingParameters; + } + + /** + * Setter for the Training Parameters of the Data Transformer. Pass null + * for none. + * + * @param dataTransformerTrainingParameters + */ + public void setDataTransformerTrainingParameters(AbstractTransformer.AbstractTrainingParameters dataTransformerTrainingParameters) { + this.dataTransformerTrainingParameters = dataTransformerTrainingParameters; + } + + /** + * Getter for the Training Parameters of the Feature Selector. + * + * @return + */ + public AbstractFeatureSelector.AbstractTrainingParameters getFeatureSelectorTrainingParameters() { + return featureSelectorTrainingParameters; + } + + /** + * Setter for the Training Parameters of the Feature Selector. Pass null + * for none. + * + * @param featureSelectorTrainingParameters + */ + public void setFeatureSelectorTrainingParameters(AbstractFeatureSelector.AbstractTrainingParameters featureSelectorTrainingParameters) { + this.featureSelectorTrainingParameters = featureSelectorTrainingParameters; + } + + /** + * Getter for the Training Parameters of the Machine Learning modeler. + * + * @return + */ + public AbstractModeler.AbstractTrainingParameters getModelerTrainingParameters() { + return modelerTrainingParameters; + } + + /** + * Setter for the Training Parameters of the Machine Learning modeler. + * + * @param modelerTrainingParameters + */ + public void setModelerTrainingParameters(AbstractModeler.AbstractTrainingParameters modelerTrainingParameters) { + this.modelerTrainingParameters = modelerTrainingParameters; + } + } @@ -79,6 +144,21 @@ public Modeler(String dbName, Configuration conf) { super(dbName, conf); } + + private boolean parallelized = true; + + /** {@inheritDoc} */ + @Override + public boolean isParallelized() { + return parallelized; + } + + /** {@inheritDoc} */ + @Override + public void setParallelized(boolean parallelized) { + this.parallelized = parallelized; + } + /** * Generates predictions for the given dataset. * @@ -87,36 +167,24 @@ public Modeler(String dbName, Configuration conf) { public void predict(Dataframe newData) { logger.info("predict()"); - Modeler.TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters(); - Configuration conf = knowledgeBase.getConf(); + //load all trainables on the bundles + initBundle(); - AbstractTrainer.AbstractTrainingParameters dtParams = trainingParameters.getDataTransformerTrainingParameters(); - boolean transformData = dtParams!=null; - if(transformData) { - if(dataTransformer==null) { - dataTransformer = MLBuilder.load(dtParams.getTClass(), dbName, conf); - } - setParallelized(dataTransformer); + //set the parallized flag to all algorithms + bundle.setParallelized(isParallelized()); + + //run the pipeline + AbstractTransformer dataTransformer = (AbstractTransformer) bundle.get("dataTransformer"); + if(dataTransformer != null) { dataTransformer.transform(newData); } - - AbstractTrainer.AbstractTrainingParameters fsParams = trainingParameters.getFeatureSelectorTrainingParameters(); - boolean selectFeatures = fsParams!=null; - if(selectFeatures) { - if(featureSelector==null) { - featureSelector = MLBuilder.load(fsParams.getTClass(), dbName, conf); - } - setParallelized(featureSelector); + AbstractFeatureSelector featureSelector = (AbstractFeatureSelector) bundle.get("featureSelector"); + if(featureSelector != null) { featureSelector.transform(newData); } - - if(modeler==null) { - modeler = MLBuilder.load(trainingParameters.getModelerTrainingParameters().getTClass(), dbName, conf); - } - setParallelized(modeler); + AbstractModeler modeler = (AbstractModeler) bundle.get("modeler"); modeler.predict(newData); - - if(transformData) { + if(dataTransformer != null) { dataTransformer.denormalize(newData); } } @@ -127,29 +195,101 @@ protected void _fit(Dataframe trainingData) { TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters(); Configuration conf = knowledgeBase.getConf(); - AbstractTrainer.AbstractTrainingParameters dtParams = trainingParameters.getDataTransformerTrainingParameters(); - boolean transformData = dtParams!=null; - if(transformData) { + //reset previous entries on the bundle + resetBundle(); + + //initialize the parts of the pipeline + AbstractTransformer.AbstractTrainingParameters dtParams = trainingParameters.getDataTransformerTrainingParameters(); + AbstractTransformer dataTransformer = null; + if(dtParams != null) { dataTransformer = MLBuilder.create(dtParams, dbName, conf); - setParallelized(dataTransformer); - dataTransformer.fit_transform(trainingData); + bundle.put("dataTransformer", dataTransformer); } - AbstractTrainer.AbstractTrainingParameters fsParams = trainingParameters.getFeatureSelectorTrainingParameters(); - boolean selectFeatures = fsParams!=null; - if(selectFeatures) { + AbstractFeatureSelector.AbstractTrainingParameters fsParams = trainingParameters.getFeatureSelectorTrainingParameters(); + AbstractFeatureSelector featureSelector = null; + if(fsParams != null) { featureSelector = MLBuilder.create(fsParams, dbName, conf); - setParallelized(featureSelector); - featureSelector.fit_transform(trainingData); + bundle.put("featureSelector", featureSelector); } - AbstractTrainer.AbstractTrainingParameters mlParams = trainingParameters.getModelerTrainingParameters(); - modeler = MLBuilder.create(mlParams, dbName, conf); - setParallelized(modeler); + AbstractModeler.AbstractTrainingParameters mlParams = trainingParameters.getModelerTrainingParameters(); + AbstractModeler modeler = MLBuilder.create(mlParams, dbName, conf); + bundle.put("modeler", modeler); + + //set the parallized flag to all algorithms + bundle.setParallelized(isParallelized()); + + //run the pipeline + if(dataTransformer != null) { + dataTransformer.fit_transform(trainingData); + } + if(featureSelector != null) { + featureSelector.fit_transform(trainingData); + } modeler.fit(trainingData); - - if(transformData) { + if(dataTransformer != null) { dataTransformer.denormalize(trainingData); } } + + /** {@inheritDoc} */ + @Override + public void save() { + initBundle(); + bundle.save(); + super.save(); + } + + /** {@inheritDoc} */ + @Override + public void delete() { + initBundle(); + bundle.delete(); + super.delete(); + } + + /** {@inheritDoc} */ + @Override + public void close() { + initBundle(); + bundle.close(); + super.close(); + } + + private void resetBundle() { + bundle.delete(); + } + + private void initBundle() { + TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters(); + Configuration conf = knowledgeBase.getConf(); + + if(!bundle.containsKey("dataTransformer")) { + AbstractTransformer.AbstractTrainingParameters dtParams = trainingParameters.getDataTransformerTrainingParameters(); + + AbstractTransformer dataTransformer = null; + if(dtParams != null) { + dataTransformer = MLBuilder.load(dtParams.getTClass(), dbName, conf); + } + bundle.put("dataTransformer", dataTransformer); + } + + if(!bundle.containsKey("featureSelector")) { + AbstractFeatureSelector.AbstractTrainingParameters fsParams = trainingParameters.getFeatureSelectorTrainingParameters(); + + AbstractFeatureSelector featureSelector = null; + if(fsParams != null) { + featureSelector = MLBuilder.load(fsParams.getTClass(), dbName, conf); + } + bundle.put("featureSelector", featureSelector); + } + + if(!bundle.containsKey("modeler")) { + AbstractModeler.AbstractTrainingParameters mlParams = trainingParameters.getModelerTrainingParameters(); + + bundle.put("modeler", MLBuilder.load(mlParams.getTClass(), dbName, conf)); + } + } + } diff --git a/datumbox-framework-applications/src/main/java/com/datumbox/framework/applications/nlp/CETR.java b/datumbox-framework-applications/src/main/java/com/datumbox/framework/applications/nlp/CETR.java index 9bfd3032..e0b52943 100755 --- a/datumbox-framework-applications/src/main/java/com/datumbox/framework/applications/nlp/CETR.java +++ b/datumbox-framework-applications/src/main/java/com/datumbox/framework/applications/nlp/CETR.java @@ -256,10 +256,7 @@ private void performClustering(Dataframe dataset, int numberOfClusters) { instance.fit(dataset); instance.predict(dataset); - //Map clusters = instance.getClusters(); - - instance.delete(); //delete immediately the result - //instance = null; + instance.close(); } private List calculateTTRlist(List rows) { diff --git a/datumbox-framework-applications/src/main/java/com/datumbox/framework/applications/nlp/TextClassifier.java b/datumbox-framework-applications/src/main/java/com/datumbox/framework/applications/nlp/TextClassifier.java index 73bf1c48..51fde4cd 100755 --- a/datumbox-framework-applications/src/main/java/com/datumbox/framework/applications/nlp/TextClassifier.java +++ b/datumbox-framework-applications/src/main/java/com/datumbox/framework/applications/nlp/TextClassifier.java @@ -20,15 +20,9 @@ import com.datumbox.framework.common.dataobjects.AssociativeArray; import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.common.dataobjects.Record; -import com.datumbox.framework.common.interfaces.Trainable; import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector; import com.datumbox.framework.common.utilities.StringCleaner; import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer; -import com.datumbox.framework.core.machinelearning.common.abstracts.datatransformers.AbstractTransformer; -import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractCategoricalFeatureSelector; -import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractFeatureSelector; -import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractModeler; -import com.datumbox.framework.core.machinelearning.common.abstracts.wrappers.AbstractWrapper; import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClassificationMetrics; import com.datumbox.framework.core.utilities.text.extractors.AbstractTextExtractor; diff --git a/datumbox-framework-applications/src/test/java/com/datumbox/framework/applications/datamodeling/ModelerTest.java b/datumbox-framework-applications/src/test/java/com/datumbox/framework/applications/datamodeling/ModelerTest.java index d0d13813..9bdd3725 100755 --- a/datumbox-framework-applications/src/test/java/com/datumbox/framework/applications/datamodeling/ModelerTest.java +++ b/datumbox-framework-applications/src/test/java/com/datumbox/framework/applications/datamodeling/ModelerTest.java @@ -74,6 +74,7 @@ public void testTrainAndValidate() { Modeler instance = MLBuilder.create(trainingParameters, dbName, conf); instance.fit(trainingData); + instance.save(); instance.close(); @@ -86,6 +87,7 @@ public void testTrainAndValidate() { double expResult2 = 0.8; assertEquals(expResult2, vm.getMacroF1(), Constants.DOUBLE_ACCURACY_HIGH); + trainingData.delete(); instance.close(); //instance = null; @@ -107,8 +109,7 @@ public void testTrainAndValidate() { assertEquals(expResult, result); instance.delete(); - - trainingData.delete(); + validationData.delete(); } diff --git a/datumbox-framework-applications/src/test/java/com/datumbox/framework/applications/nlp/TextClassifierTest.java b/datumbox-framework-applications/src/test/java/com/datumbox/framework/applications/nlp/TextClassifierTest.java index a83e2c52..ffc3888b 100755 --- a/datumbox-framework-applications/src/test/java/com/datumbox/framework/applications/nlp/TextClassifierTest.java +++ b/datumbox-framework-applications/src/test/java/com/datumbox/framework/applications/nlp/TextClassifierTest.java @@ -63,9 +63,7 @@ public void testTrainAndValidateBernoulliNaiveBayes() { fsParams.setRareFeatureThreshold(3); trainAndValidate( - BernoulliNaiveBayes.class, mlParams, - ChisquareSelect.class, fsParams, 0.8393075950598075 ); @@ -86,9 +84,7 @@ public void testTrainAndValidateBinarizedNaiveBayes() { fsParams.setRareFeatureThreshold(3); trainAndValidate( - BinarizedNaiveBayes.class, mlParams, - ChisquareSelect.class, fsParams, 0.8413587159387832 ); @@ -109,9 +105,7 @@ public void testTrainAndValidateMaximumEntropy() { fsParams.setRareFeatureThreshold(3); trainAndValidate( - MaximumEntropy.class, mlParams, - ChisquareSelect.class, fsParams, 0.9411031042128604 ); @@ -132,9 +126,7 @@ public void testTrainAndValidateMultinomialNaiveBayes() { fsParams.setRareFeatureThreshold(3); trainAndValidate( - MultinomialNaiveBayes.class, mlParams, - ChisquareSelect.class, fsParams, 0.8685865263692268 ); @@ -155,9 +147,7 @@ public void testTrainAndValidateOrdinalRegression() { fsParams.setRareFeatureThreshold(3); trainAndValidate( - OrdinalRegression.class, mlParams, - ChisquareSelect.class, fsParams, 0.8290058479532163 ); @@ -178,9 +168,7 @@ public void testTrainAndValidateSoftMaxRegression() { fsParams.setRareFeatureThreshold(3); trainAndValidate( - SoftMaxRegression.class, mlParams, - ChisquareSelect.class, fsParams, 0.7663106693454584 ); @@ -201,9 +189,7 @@ public void testTrainAndValidateSupportVectorMachine() { fsParams.setRareFeatureThreshold(3); trainAndValidate( - SupportVectorMachine.class, mlParams, - ChisquareSelect.class, fsParams, 0.9803846153846154 ); @@ -223,9 +209,7 @@ public void testTrainAndValidateMutualInformation() { fsParams.setRareFeatureThreshold(3); trainAndValidate( - MultinomialNaiveBayes.class, mlParams, - MutualInformation.class, fsParams, 0.8954671493044679 ); @@ -244,9 +228,7 @@ public void testTrainAndValidateTFIDF() { fsParams.setMaxFeatures(1000); trainAndValidate( - MultinomialNaiveBayes.class, mlParams, - TFIDF.class, fsParams, 0.80461962936161 ); @@ -257,15 +239,11 @@ public void testTrainAndValidateTFIDF() { * * @param * @param - * @param modelerClass * @param modelerTrainingParameters - * @param featureSelectorClass * @param featureSelectorTrainingParameters */ private void trainAndValidate( - Class modelerClass, ML.AbstractTrainingParameters modelerTrainingParameters, - Class featureSelectorClass, FS.AbstractTrainingParameters featureSelectorTrainingParameters, double expectedF1score) { Configuration conf = Configuration.getConfiguration(); @@ -302,6 +280,7 @@ private void TextClassifier instance = MLBuilder.create(trainingParameters, dbName, conf); instance.fit(dataset); + instance.save(); ClassificationMetrics vm = instance.validate(dataset); @@ -313,7 +292,7 @@ private void instance = MLBuilder.load(TextClassifier.class, dbName, conf); - Dataframe validationData = null; + Dataframe validationData; try { validationData = instance.predict(this.getClass().getClassLoader().getResource("datasets/sentimentAnalysis.unlabelled.txt").toURI()); } diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/AbstractTrainer.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/AbstractTrainer.java index 435c7551..80a890dc 100755 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/AbstractTrainer.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/AbstractTrainer.java @@ -161,9 +161,6 @@ public void fit(Dataframe trainingData) { knowledgeBase.clear(); _fit(trainingData); - - logger.info("Saving model"); - knowledgeBase.save(); //TODO: remove this } /** {@inheritDoc} */ diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/algorithms/AbstractBoostingBagging.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/algorithms/AbstractBoostingBagging.java index bdaf148c..b538cfc9 100755 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/algorithms/AbstractBoostingBagging.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/algorithms/AbstractBoostingBagging.java @@ -17,7 +17,6 @@ import com.datumbox.framework.common.Configuration; import com.datumbox.framework.common.dataobjects.*; -import com.datumbox.framework.common.interfaces.Trainable; import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector; import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector.MapType; import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector.StorageHint; @@ -25,6 +24,7 @@ import com.datumbox.framework.core.machinelearning.MLBuilder; import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer; import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier; +import com.datumbox.framework.core.machinelearning.common.dataobjects.TrainableBundle; import com.datumbox.framework.core.machinelearning.ensemblelearning.FixedCombinationRules; import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives; import com.datumbox.framework.core.statistics.sampling.SimpleRandomSampling; @@ -43,6 +43,8 @@ */ public abstract class AbstractBoostingBagging extends AbstractClassifier { + private TrainableBundle bundle = new TrainableBundle(); + private static final String DB_INDICATOR = "Cmp"; private static final int MAX_NUM_OF_RETRIES = 2; @@ -148,9 +150,10 @@ protected AbstractBoostingBagging(String dbName, Configuration conf) { /** {@inheritDoc} */ @Override protected void _predict(Dataframe newData) { - Class weakClassifierClass = knowledgeBase.getTrainingParameters().getWeakClassifierTrainingParameters().getTClass(); + //load all trainables on the bundles + initBundle(); + List weakClassifierWeights = knowledgeBase.getModelParameters().getWeakClassifierWeights(); - String prefix = dbName+knowledgeBase.getConf().getDbConfig().getDBnameSeparator()+DB_INDICATOR; //create a temporary map for the observed probabilities in training set DatabaseConnector dbc = knowledgeBase.getDbc(); @@ -165,9 +168,9 @@ protected void _predict(Dataframe newData) { AssociativeArray classifierWeightsArray = new AssociativeArray(); int totalWeakClassifiers = weakClassifierWeights.size(); for(int t=0;t classesSet = modelParameters.getClasses(); - - //first we need to find all the classes for(Record r : trainingData) { Object theClass=r.getY(); @@ -231,25 +236,31 @@ protected void _fit(Dataframe trainingData) { //training the weak classifiers int t=0; int retryCounter = 0; + String prefix = dbName+conf.getDbConfig().getDBnameSeparator()+DB_INDICATOR; while(t weakClassifierClass = trainingParameters.getWeakClassifierTrainingParameters().getTClass(); - //the number of weak classifiers is the minimum between the classifiers that were defined in training parameters AND the number of the weak classifiers that were kept +1 for the one that was abandoned due to high error - int totalWeakClassifiers = Math.min(modelParameters.getWeakClassifierWeights().size()+1, trainingParameters.getMaxWeakClassifiers()); + int totalWeakClassifiers = Math.min(modelParameters.getWeakClassifierWeights().size(), trainingParameters.getMaxWeakClassifiers()); for(int t=0;t - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.datumbox.framework.core.machinelearning.common.abstracts.wrappers; - -import com.datumbox.framework.common.Configuration; -import com.datumbox.framework.common.interfaces.Trainable; -import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer; -import com.datumbox.framework.core.machinelearning.common.abstracts.datatransformers.AbstractTransformer; -import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractFeatureSelector; -import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractModeler; -import com.datumbox.framework.core.machinelearning.common.interfaces.Parallelizable; - -/** - * The AbstractWrapper is a trainable object that uses composition instead of inheritance - to extend the functionality of a AbstractModeler. It includes various internal objects - * such as Data Transformers, Feature Selectors and Machine Learning models which - * are combined in the training and prediction process. - * - * @author Vasilis Vryniotis - * @param - * @param - */ -public abstract class AbstractWrapper extends AbstractTrainer implements Parallelizable { - - /** - * The AbstractTransformer instance of the wrapper. - */ - protected AbstractTransformer dataTransformer = null; - - /** - * The AbstractFeatureSelector instance of the wrapper. - */ - protected AbstractFeatureSelector featureSelector = null; - - /** - * The Machine Learning model instance of the wrapper. - */ - protected AbstractModeler modeler = null; - - /** - * The AbstractTrainingParameters class stores the parameters that can be changed - before training the algorithm. - * - * @param
- * @param - * @param - */ - public static abstract class AbstractTrainingParameters
extends AbstractTrainer.AbstractTrainingParameters { - - //Parameter Objects - private DT.AbstractTrainingParameters dataTransformerTrainingParameters; - - private FS.AbstractTrainingParameters featureSelectorTrainingParameters; - - private ML.AbstractTrainingParameters modelerTrainingParameters; - - /** - * Getter for the Training Parameters of the Data Transformer. - * - * @return - */ - public DT.AbstractTrainingParameters getDataTransformerTrainingParameters() { - return dataTransformerTrainingParameters; - } - - /** - * Setter for the Training Parameters of the Data Transformer. Pass null - * for none. - * - * @param dataTransformerTrainingParameters - */ - public void setDataTransformerTrainingParameters(DT.AbstractTrainingParameters dataTransformerTrainingParameters) { - this.dataTransformerTrainingParameters = dataTransformerTrainingParameters; - } - - /** - * Getter for the Training Parameters of the Feature Selector. - * - * @return - */ - public FS.AbstractTrainingParameters getFeatureSelectorTrainingParameters() { - return featureSelectorTrainingParameters; - } - - /** - * Setter for the Training Parameters of the Feature Selector. Pass null - * for none. - * - * @param featureSelectorTrainingParameters - */ - public void setFeatureSelectorTrainingParameters(FS.AbstractTrainingParameters featureSelectorTrainingParameters) { - this.featureSelectorTrainingParameters = featureSelectorTrainingParameters; - } - - /** - * Getter for the Training Parameters of the Machine Learning modeler. - * - * @return - */ - public ML.AbstractTrainingParameters getModelerTrainingParameters() { - return modelerTrainingParameters; - } - - /** - * Setter for the Training Parameters of the Machine Learning modeler. - * - * @param modelerTrainingParameters - */ - public void setModelerTrainingParameters(ML.AbstractTrainingParameters modelerTrainingParameters) { - this.modelerTrainingParameters = modelerTrainingParameters; - } - - } - - /** - * @param dbName - * @param conf - * @param trainingParameters - * @see AbstractTrainer#AbstractTrainer(String, Configuration, AbstractTrainer.AbstractTrainingParameters) - */ - protected AbstractWrapper(String dbName, Configuration conf, TP trainingParameters) { - super(dbName, conf, trainingParameters); - } - - /** - * @param dbName - * @param conf - * @see AbstractTrainer#AbstractTrainer(String, Configuration) - */ - protected AbstractWrapper(String dbName, Configuration conf) { - super(dbName, conf); - } - - private boolean parallelized = true; - - /** {@inheritDoc} */ - @Override - public boolean isParallelized() { - return parallelized; - } - - /** {@inheritDoc} */ - @Override - public void setParallelized(boolean parallelized) { - this.parallelized = parallelized; - } - - /** {@inheritDoc} */ - @Override - public void delete() { - if(dataTransformer!=null) { - dataTransformer.delete(); - } - if(featureSelector!=null) { - featureSelector.delete(); - } - if(modeler!=null) { - modeler.delete(); - } - knowledgeBase.delete(); - } - - /** {@inheritDoc} */ - @Override - public void close() { - if(dataTransformer!=null) { - dataTransformer.close(); - } - if(featureSelector!=null) { - featureSelector.close(); - } - if(modeler!=null) { - modeler.close(); - } - knowledgeBase.close(); - } - - /** - * Updates the parallelized flag of the component if it supports it. This is - * done just before the train and predict methods. - * - * @param algorithm - */ - protected void setParallelized(Trainable algorithm) { - if (algorithm != null && algorithm instanceof Parallelizable) { - ((Parallelizable)algorithm).setParallelized(isParallelized()); - } - } -} diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/dataobjects/TrainableBundle.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/dataobjects/TrainableBundle.java new file mode 100755 index 00000000..f97c099f --- /dev/null +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/dataobjects/TrainableBundle.java @@ -0,0 +1,114 @@ +/** + * Copyright (C) 2013-2016 Vasilis Vryniotis + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datumbox.framework.core.machinelearning.common.dataobjects; + +import com.datumbox.framework.common.interfaces.Trainable; +import com.datumbox.framework.core.machinelearning.common.interfaces.Parallelizable; + +import java.util.HashMap; +import java.util.Map; + +/** + * This object stores a bundle of Trainables and it is used by algorithms that have other Trainables internally. + * + * @author Vasilis Vryniotis + */ +public class TrainableBundle implements AutoCloseable { + + /** + * Keeps a reference of all the wrapped algorithms. + */ + private final Map bundle = new HashMap<>(); + + /** + * Returns whether the bundle contains the specified key. + * + * @param key + * @return + */ + public boolean containsKey(String key) { + return bundle.containsKey(key); + } + + /** + * Returns the trainable with the specific key or null if the key is missing. + * + * @param key + * @return + */ + public Trainable get(String key) { + return bundle.get(key); + } + + /** + * Puts the trainable in the bundle using a specific key and returns the previous entry or null. + * + * @param key + * @param value + * @return + */ + public Trainable put(String key, Trainable value) { + return bundle.put(key, value); + } + + /** + * Updates the parallelized flag of all wrapped algorithms. + * + * @param parallelized + */ + public void setParallelized(boolean parallelized) { + for(Trainable t : bundle.values()) { + if (t !=null && t instanceof Parallelizable) { + ((Parallelizable)t).setParallelized(parallelized); + } + } + } + + /** {@inheritDoc} */ + public void save() { + for(Trainable t : bundle.values()) { + if(t != null) { + t.save(); + } + } + } + + /** {@inheritDoc} */ + public void delete() { + for(Trainable t : bundle.values()) { + if(t != null) { + t.delete(); + } + } + bundle.clear(); + } + + /** {@inheritDoc} */ + @Override + public void close() { + for(Trainable t : bundle.values()) { + if(t != null) { + try { + t.close(); + } + catch (Exception ex) { + throw new RuntimeException(ex); + } + } + } + bundle.clear(); + } +} diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/validators/KFoldValidator.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/validators/KFoldValidator.java index bf63d69b..6d824351 100644 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/validators/KFoldValidator.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/validators/KFoldValidator.java @@ -133,7 +133,7 @@ public VM validate(Dataframe dataset, TrainingParameters trainingParameters) { //add the validationMetrics in the list validationMetricsList.add(entrySample); } - modeler.delete(); + modeler.close(); VM avgValidationMetrics = ValidationMetrics.newInstance(vmClass, validationMetricsList); diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/regression/StepwiseRegression.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/regression/StepwiseRegression.java index 44e0a0a6..3b1baecc 100755 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/regression/StepwiseRegression.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/regression/StepwiseRegression.java @@ -17,12 +17,12 @@ import com.datumbox.framework.common.Configuration; import com.datumbox.framework.common.dataobjects.Dataframe; -import com.datumbox.framework.common.interfaces.Trainable; import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector; import com.datumbox.framework.common.utilities.MapMethods; import com.datumbox.framework.core.machinelearning.MLBuilder; import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer; import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractRegressor; +import com.datumbox.framework.core.machinelearning.common.dataobjects.TrainableBundle; import com.datumbox.framework.core.machinelearning.common.interfaces.StepwiseCompatible; import java.util.HashSet; @@ -39,6 +39,8 @@ */ public class StepwiseRegression extends AbstractRegressor { + private TrainableBundle bundle = new TrainableBundle(); + /** {@inheritDoc} */ public static class ModelParameters extends AbstractRegressor.AbstractModelParameters { private static final long serialVersionUID = 1L; @@ -150,31 +152,27 @@ protected StepwiseRegression(String dbName, Configuration conf) { super(dbName, conf); } - /** {@inheritDoc} */ - @Override - public void delete() { - loadRegressor().delete(); - super.delete(); - } - - /** {@inheritDoc} */ - @Override - public void close() { - loadRegressor().close(); - super.close(); - } - /** {@inheritDoc} */ @Override protected void _predict(Dataframe newData) { - loadRegressor().predict(newData); + //load all trainables on the bundles + initBundle(); + + //run the pipeline + AbstractRegressor mlregressor = (AbstractRegressor) bundle.get("mlregressor"); + mlregressor.predict(newData); } /** {@inheritDoc} */ @Override protected void _fit(Dataframe trainingData) { TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters(); - + Configuration conf = knowledgeBase.getConf(); + + //reset previous entries on the bundle + resetBundle(); + + //perform stepwise Integer maxIterations = trainingParameters.getMaxIterations(); if(maxIterations==null) { maxIterations = Integer.MAX_VALUE; @@ -214,35 +212,67 @@ protected void _fit(Dataframe trainingData) { } //once we have the dataset has been cleared from the unnecessary columns train the model once again - AbstractRegressor mlregressor = createRegressor(); - + AbstractRegressor mlregressor = MLBuilder.create( + knowledgeBase.getTrainingParameters().getRegressionTrainingParameters(), + dbName, + conf + ); mlregressor.fit(copiedTrainingData); + bundle.put("mlregressor", mlregressor); + copiedTrainingData.delete(); } - private AbstractRegressor createRegressor() { - return MLBuilder.create( - knowledgeBase.getTrainingParameters().getRegressionTrainingParameters(), - dbName, - knowledgeBase.getConf() - ); + /** {@inheritDoc} */ + @Override + public void save() { + initBundle(); + bundle.save(); + super.save(); } - private AbstractRegressor loadRegressor() { - return MLBuilder.load( - knowledgeBase.getTrainingParameters().getRegressionTrainingParameters().getTClass(), - dbName, - knowledgeBase.getConf() - ); + /** {@inheritDoc} */ + @Override + public void delete() { + initBundle(); + bundle.delete(); + super.delete(); + } + + /** {@inheritDoc} */ + @Override + public void close() { + initBundle(); + bundle.close(); + super.close(); + } + + private void resetBundle() { + bundle.delete(); + } + + private void initBundle() { + TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters(); + Configuration conf = knowledgeBase.getConf(); + + if(!bundle.containsKey("mlregressor")) { + AbstractTrainingParameters mlParams = trainingParameters.getRegressionTrainingParameters(); + + bundle.put("mlregressor", MLBuilder.load(mlParams.getTClass(), dbName, conf)); + } } private Map runRegression(Dataframe trainingData) { - AbstractRegressor mlregressor = createRegressor(); + AbstractRegressor mlregressor = MLBuilder.create( + knowledgeBase.getTrainingParameters().getRegressionTrainingParameters(), + dbName, + knowledgeBase.getConf() + ); mlregressor.fit(trainingData); Map pvalues = ((StepwiseCompatible)mlregressor).getFeaturePvalues(); - mlregressor.delete(); + mlregressor.close(); return pvalues; } diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/BernoulliNaiveBayesTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/BernoulliNaiveBayesTest.java index 109f8c5c..83b1aa08 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/BernoulliNaiveBayesTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/BernoulliNaiveBayesTest.java @@ -58,6 +58,7 @@ public void testPredict() { BernoulliNaiveBayes instance = MLBuilder.create(new BernoulliNaiveBayes.TrainingParameters(), dbName, conf); instance.fit(trainingData); + instance.save(); instance.close(); //instance = null; diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/BinarizedNaiveBayesTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/BinarizedNaiveBayesTest.java index 843cb3db..ab480f25 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/BinarizedNaiveBayesTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/BinarizedNaiveBayesTest.java @@ -58,6 +58,7 @@ public void testPredict() { BinarizedNaiveBayes instance = MLBuilder.create(new BinarizedNaiveBayes.TrainingParameters(), dbName, conf); instance.fit(trainingData); + instance.save(); instance.close(); //instance = null; diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/MaximumEntropyTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/MaximumEntropyTest.java index aeef37eb..de4d4f4b 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/MaximumEntropyTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/MaximumEntropyTest.java @@ -62,6 +62,7 @@ public void testPredict() { MaximumEntropy instance = MLBuilder.create(param, dbName, conf); instance.fit(trainingData); + instance.save(); instance.close(); //instance = null; diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/MultinomialNaiveBayesTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/MultinomialNaiveBayesTest.java index 750868e7..66f7c1fa 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/MultinomialNaiveBayesTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/MultinomialNaiveBayesTest.java @@ -59,7 +59,7 @@ public void testPredict() { DummyXYMinMaxNormalizer df = MLBuilder.create(new DummyXYMinMaxNormalizer.TrainingParameters(), dbName, conf); df.fit_transform(trainingData); - df.transform(validationData); + df.save(); MultinomialNaiveBayes.TrainingParameters param = new MultinomialNaiveBayes.TrainingParameters(); param.setMultiProbabilityWeighted(true); @@ -67,6 +67,10 @@ public void testPredict() { MultinomialNaiveBayes instance = MLBuilder.create(param, dbName, conf); instance.fit(trainingData); + instance.save(); + + df.denormalize(trainingData); + trainingData.delete(); instance.close(); df.close(); @@ -75,10 +79,10 @@ public void testPredict() { df = MLBuilder.load(DummyXYMinMaxNormalizer.class, dbName, conf); instance = MLBuilder.load(MultinomialNaiveBayes.class, dbName, conf); - + + df.transform(validationData); instance.predict(validationData); - - df.denormalize(trainingData); + df.denormalize(validationData); @@ -95,8 +99,7 @@ public void testPredict() { df.delete(); instance.delete(); - - trainingData.delete(); + validationData.delete(); } diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/OrdinalRegressionTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/OrdinalRegressionTest.java index 492683af..70b8e6b0 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/OrdinalRegressionTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/OrdinalRegressionTest.java @@ -59,7 +59,7 @@ public void testPredict() { DummyXMinMaxNormalizer df = MLBuilder.create(new DummyXMinMaxNormalizer.TrainingParameters(), dbName, conf); df.fit_transform(trainingData); - df.transform(validationData); + df.save(); OrdinalRegression.TrainingParameters param = new OrdinalRegression.TrainingParameters(); @@ -69,7 +69,11 @@ public void testPredict() { OrdinalRegression instance = MLBuilder.create(param, dbName, conf); instance.fit(trainingData); - + instance.save(); + + df.denormalize(trainingData); + trainingData.delete(); + instance.close(); df.close(); //instance = null; @@ -77,11 +81,11 @@ public void testPredict() { df = MLBuilder.load(DummyXMinMaxNormalizer.class, dbName, conf); instance = MLBuilder.load(OrdinalRegression.class, dbName, conf); - + + df.transform(validationData); instance.predict(validationData); - - df.denormalize(trainingData); + df.denormalize(validationData); Map expResult = new HashMap<>(); @@ -96,8 +100,7 @@ public void testPredict() { df.delete(); instance.delete(); - - trainingData.delete(); + validationData.delete(); } @@ -138,7 +141,7 @@ public void testKFoldCrossValidation() { double result = vm.getMacroF1(); assertEquals(expResult, result, Constants.DOUBLE_ACCURACY_HIGH); - df.delete(); + df.close(); trainingData.delete(); } diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/SoftMaxRegressionTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/SoftMaxRegressionTest.java index f5ddc71b..095fe851 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/SoftMaxRegressionTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/SoftMaxRegressionTest.java @@ -60,7 +60,7 @@ public void testPredict() { DummyXYMinMaxNormalizer df = MLBuilder.create(new DummyXYMinMaxNormalizer.TrainingParameters(), dbName, conf); df.fit_transform(trainingData); - df.transform(validationData); + df.save(); SoftMaxRegression.TrainingParameters param = new SoftMaxRegression.TrainingParameters(); @@ -70,7 +70,12 @@ public void testPredict() { SoftMaxRegression instance = MLBuilder.create(param, dbName, conf); instance.fit(trainingData); - + instance.save(); + + df.denormalize(trainingData); + trainingData.delete(); + + instance.close(); df.close(); //instance = null; @@ -78,10 +83,10 @@ public void testPredict() { df = MLBuilder.load(DummyXYMinMaxNormalizer.class, dbName, conf); instance = MLBuilder.load(SoftMaxRegression.class, dbName, conf); - + + df.transform(validationData); instance.predict(validationData); - - df.denormalize(trainingData); + df.denormalize(validationData); @@ -97,8 +102,7 @@ public void testPredict() { df.delete(); instance.delete(); - - trainingData.delete(); + validationData.delete(); } @@ -135,7 +139,7 @@ public void testKFoldCrossValidation() { double expResult = 0.7557492507492508; double result = vm.getMacroF1(); assertEquals(expResult, result, Constants.DOUBLE_ACCURACY_HIGH); - df.delete(); + df.close(); trainingData.delete(); } diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/SupportVectorMachineTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/SupportVectorMachineTest.java index 0f24375d..4d5c4206 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/SupportVectorMachineTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/SupportVectorMachineTest.java @@ -60,7 +60,7 @@ public void testPredict() { String dbName = this.getClass().getSimpleName(); DummyXYMinMaxNormalizer df = MLBuilder.create(new DummyXYMinMaxNormalizer.TrainingParameters(), dbName, conf); df.fit_transform(trainingData); - df.transform(validationData); + df.save(); SupportVectorMachine.TrainingParameters param = new SupportVectorMachine.TrainingParameters(); param.getSvmParameter().kernel_type = svm_parameter.RBF; @@ -68,6 +68,10 @@ public void testPredict() { SupportVectorMachine instance = MLBuilder.create(param, dbName, conf); instance.fit(trainingData); + instance.save(); + + df.denormalize(trainingData); + trainingData.delete(); instance.close(); df.close(); @@ -76,11 +80,11 @@ public void testPredict() { df = MLBuilder.load(DummyXYMinMaxNormalizer.class, dbName, conf); instance = MLBuilder.load(SupportVectorMachine.class, dbName, conf); - + + df.transform(validationData); instance.predict(validationData); - - df.denormalize(trainingData); + df.denormalize(validationData); @@ -96,8 +100,7 @@ public void testPredict() { df.delete(); instance.delete(); - - trainingData.delete(); + validationData.delete(); } diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/GaussianDPMMTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/GaussianDPMMTest.java index a98417b6..eebe45fe 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/GaussianDPMMTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/GaussianDPMMTest.java @@ -62,7 +62,9 @@ public void testPredict() { GaussianDPMM instance = MLBuilder.create(param, dbName, conf); instance.fit(trainingData); - + instance.save(); + + trainingData.delete(); instance.close(); //instance = null; instance = MLBuilder.load(GaussianDPMM.class, dbName, conf); @@ -75,8 +77,7 @@ public void testPredict() { assertEquals(expResult, result, Constants.DOUBLE_ACCURACY_HIGH); instance.delete(); - - trainingData.delete(); + validationData.delete(); } diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/HierarchicalAgglomerativeTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/HierarchicalAgglomerativeTest.java index 7a57e25b..8eb2e516 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/HierarchicalAgglomerativeTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/HierarchicalAgglomerativeTest.java @@ -55,7 +55,7 @@ public void testPredict() { DummyXYMinMaxNormalizer df = MLBuilder.create(new DummyXYMinMaxNormalizer.TrainingParameters(), dbName, conf); df.fit_transform(trainingData); - df.transform(validationData); + df.save(); HierarchicalAgglomerative.TrainingParameters param = new HierarchicalAgglomerative.TrainingParameters(); @@ -66,7 +66,10 @@ public void testPredict() { HierarchicalAgglomerative instance = MLBuilder.create(param, dbName, conf); instance.fit(trainingData); - + instance.save(); + + df.denormalize(trainingData); + trainingData.delete(); instance.close(); df.close(); @@ -76,10 +79,10 @@ public void testPredict() { df = MLBuilder.load(DummyXYMinMaxNormalizer.class, dbName, conf); instance = MLBuilder.load(HierarchicalAgglomerative.class, dbName, conf); + df.transform(validationData); instance.predict(validationData); ClusteringMetrics vm = new ClusteringMetrics(validationData); - - df.denormalize(trainingData); + df.denormalize(validationData); double expResult = 1.0; @@ -88,8 +91,7 @@ public void testPredict() { df.delete(); instance.delete(); - - trainingData.delete(); + validationData.delete(); } @@ -133,7 +135,7 @@ public void testKFoldCrossValidation() { double result = vm.getPurity(); assertEquals(expResult, result, Constants.DOUBLE_ACCURACY_HIGH); - df.delete(); + df.close(); trainingData.delete(); } diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/KmeansTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/KmeansTest.java index d84cb18a..04f734d3 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/KmeansTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/KmeansTest.java @@ -54,9 +54,7 @@ public void testPredict() { String dbName = this.getClass().getSimpleName(); DummyXYMinMaxNormalizer df = MLBuilder.create(new DummyXYMinMaxNormalizer.TrainingParameters(), dbName, conf); df.fit_transform(trainingData); - - df.transform(validationData); - + df.save(); Kmeans.TrainingParameters param = new Kmeans.TrainingParameters(); @@ -70,7 +68,10 @@ public void testPredict() { Kmeans instance = MLBuilder.create(param, dbName, conf); instance.fit(trainingData); - + instance.save(); + + df.denormalize(trainingData); + trainingData.delete(); instance.close(); df.close(); @@ -80,10 +81,11 @@ public void testPredict() { df = MLBuilder.load(DummyXYMinMaxNormalizer.class, dbName, conf); instance = MLBuilder.load(Kmeans.class, dbName, conf); + + df.transform(validationData); instance.predict(validationData); ClusteringMetrics vm = new ClusteringMetrics(validationData); - df.denormalize(trainingData); df.denormalize(validationData); double expResult = 1.0; @@ -92,8 +94,7 @@ public void testPredict() { df.delete(); instance.delete(); - - trainingData.delete(); + validationData.delete(); } @@ -141,7 +142,7 @@ public void testKFoldCrossValidation() { double result = vm.getPurity(); assertEquals(expResult, result, Constants.DOUBLE_ACCURACY_HIGH); - df.delete(); + df.close(); trainingData.delete(); } diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/MultinomialDPMMTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/MultinomialDPMMTest.java index 57b720c9..c5bb4546 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/MultinomialDPMMTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/MultinomialDPMMTest.java @@ -60,6 +60,7 @@ public void testPredict() { MultinomialDPMM instance = MLBuilder.create(param, dbName, conf); instance.fit(trainingData); + instance.save(); instance.close(); //instance = null; diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/ensemblelearning/AdaboostTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/ensemblelearning/AdaboostTest.java index edc995a3..da08c72e 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/ensemblelearning/AdaboostTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/ensemblelearning/AdaboostTest.java @@ -59,9 +59,7 @@ public void testPredict() { String dbName = this.getClass().getSimpleName(); DummyXYMinMaxNormalizer df = MLBuilder.create(new DummyXYMinMaxNormalizer.TrainingParameters(), dbName, conf); df.fit_transform(trainingData); - - df.transform(validationData); - + df.save(); Adaboost.TrainingParameters param = new Adaboost.TrainingParameters(); @@ -77,7 +75,10 @@ public void testPredict() { Adaboost instance = MLBuilder.create(param, dbName, conf); instance.fit(trainingData); - + instance.save(); + + df.denormalize(trainingData); + trainingData.delete(); instance.close(); df.close(); @@ -86,13 +87,13 @@ public void testPredict() { df = MLBuilder.load(DummyXYMinMaxNormalizer.class, dbName, conf); instance = MLBuilder.load(Adaboost.class, dbName, conf); - + + df.transform(validationData); instance.predict(validationData); - - df.denormalize(trainingData); + df.denormalize(validationData); Map expResult = new HashMap<>(); @@ -107,8 +108,7 @@ public void testPredict() { df.delete(); instance.delete(); - - trainingData.delete(); + validationData.delete(); } diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/ensemblelearning/BayesianEnsembleMethodTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/ensemblelearning/BayesianEnsembleMethodTest.java index 6ee191ce..3cd40a35 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/ensemblelearning/BayesianEnsembleMethodTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/ensemblelearning/BayesianEnsembleMethodTest.java @@ -54,15 +54,17 @@ public void testPredict() { String dbName = this.getClass().getSimpleName(); DummyXYMinMaxNormalizer df = MLBuilder.create(new DummyXYMinMaxNormalizer.TrainingParameters(), dbName, conf); df.fit_transform(trainingData); - - df.transform(validationData); + df.save(); BayesianEnsembleMethod instance = MLBuilder.create(new BayesianEnsembleMethod.TrainingParameters(), dbName, conf); instance.fit(trainingData); - - + instance.save(); + + df.denormalize(trainingData); + + trainingData.delete(); instance.close(); df.close(); //instance = null; @@ -70,10 +72,11 @@ public void testPredict() { df = MLBuilder.load(DummyXYMinMaxNormalizer.class, dbName, conf); instance = MLBuilder.load(BayesianEnsembleMethod.class, dbName, conf); + + df.transform(validationData); instance.predict(validationData); - - df.denormalize(trainingData); + df.denormalize(validationData); Map expResult = new HashMap<>(); @@ -88,8 +91,7 @@ public void testPredict() { df.delete(); instance.delete(); - - trainingData.delete(); + validationData.delete(); } diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/ensemblelearning/BootstrapAggregatingTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/ensemblelearning/BootstrapAggregatingTest.java index 345a121c..c59da5b9 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/ensemblelearning/BootstrapAggregatingTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/ensemblelearning/BootstrapAggregatingTest.java @@ -59,7 +59,7 @@ public void testPredict() { String dbName = this.getClass().getSimpleName(); DummyXYMinMaxNormalizer df = MLBuilder.create(new DummyXYMinMaxNormalizer.TrainingParameters(), dbName, conf); df.fit_transform(trainingData); - df.transform(validationData); + df.save(); BootstrapAggregating.TrainingParameters param = new BootstrapAggregating.TrainingParameters(); @@ -76,7 +76,10 @@ public void testPredict() { BootstrapAggregating instance = MLBuilder.create(param, dbName, conf); instance.fit(trainingData); - + instance.save(); + + df.denormalize(trainingData); + trainingData.delete(); instance.close(); df.close(); @@ -85,10 +88,10 @@ public void testPredict() { df = MLBuilder.load(DummyXYMinMaxNormalizer.class, dbName, conf); instance = MLBuilder.load(BootstrapAggregating.class, dbName, conf); - + + df.transform(validationData); instance.predict(validationData); - - df.denormalize(trainingData); + df.denormalize(validationData); Map expResult = new HashMap<>(); @@ -103,8 +106,7 @@ public void testPredict() { df.delete(); instance.delete(); - - trainingData.delete(); + validationData.delete(); } diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/featureselection/categorical/ChisquareSelectTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/featureselection/categorical/ChisquareSelectTest.java index 0e167e70..297375c3 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/featureselection/categorical/ChisquareSelectTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/featureselection/categorical/ChisquareSelectTest.java @@ -58,6 +58,8 @@ public void testSelectFeatures() { instance.fit_transform(trainingData); + instance.save(); + instance.close(); //instance = null; diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/featureselection/categorical/MutualInformationTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/featureselection/categorical/MutualInformationTest.java index df5e9846..82c6274f 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/featureselection/categorical/MutualInformationTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/featureselection/categorical/MutualInformationTest.java @@ -57,6 +57,8 @@ public void testSelectFeatures() { instance.fit_transform(trainingData); + instance.save(); + instance.close(); //instance = null; diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/featureselection/continuous/PCATest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/featureselection/continuous/PCATest.java index 2850b8fe..19a79850 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/featureselection/continuous/PCATest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/featureselection/continuous/PCATest.java @@ -59,6 +59,9 @@ public void testSelectFeatures() { PCA instance = MLBuilder.create(param, dbName, conf); instance.fit_transform(originalData); + instance.save(); + + originalData.delete(); instance.close(); //instance = null; @@ -85,8 +88,7 @@ public void testSelectFeatures() { } instance.delete(); - - originalData.delete(); + validationdata.delete(); expResult.delete(); } diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/featureselection/scorebased/TFIDFTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/featureselection/scorebased/TFIDFTest.java index b83576b8..831b5b44 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/featureselection/scorebased/TFIDFTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/featureselection/scorebased/TFIDFTest.java @@ -57,6 +57,9 @@ public void testSelectFeatures() { TFIDF instance = MLBuilder.create(param, dbName, conf); instance.fit_transform(trainingData); + instance.save(); + + trainingData.delete(); instance.close(); //instance = null; @@ -69,8 +72,7 @@ public void testSelectFeatures() { Set result = validationData.getXDataTypes().keySet(); assertEquals(expResult, result); instance.delete(); - - trainingData.delete(); + validationData.delete(); } diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/recommendersystem/CollaborativeFilteringTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/recommendersystem/CollaborativeFilteringTest.java index 0450cf95..7578c4a0 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/recommendersystem/CollaborativeFilteringTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/recommendersystem/CollaborativeFilteringTest.java @@ -60,6 +60,7 @@ public void testPredict() { CollaborativeFiltering instance = MLBuilder.create(param, dbName, conf); instance.fit(trainingData); + instance.save(); instance.close(); //instance = null; diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/regression/MatrixLinearRegressionTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/regression/MatrixLinearRegressionTest.java index 52b6f010..644d5804 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/regression/MatrixLinearRegressionTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/regression/MatrixLinearRegressionTest.java @@ -55,13 +55,15 @@ public void testPredict() { String dbName = this.getClass().getSimpleName(); XYMinMaxNormalizer df = MLBuilder.create(new XYMinMaxNormalizer.TrainingParameters(), dbName, conf); df.fit_transform(trainingData); - - df.transform(validationData); + df.save(); MatrixLinearRegression instance = MLBuilder.create(new MatrixLinearRegression.TrainingParameters(), dbName, conf); instance.fit(trainingData); - + instance.save(); + + df.denormalize(trainingData); + trainingData.delete(); instance.close(); df.close(); @@ -70,12 +72,13 @@ public void testPredict() { df = MLBuilder.load(XYMinMaxNormalizer.class, dbName, conf); instance = MLBuilder.load(MatrixLinearRegression.class, dbName, conf); - + + + df.transform(validationData); instance.predict(validationData); - - df.denormalize(trainingData); + df.denormalize(validationData); @@ -85,8 +88,7 @@ public void testPredict() { df.delete(); instance.delete(); - - trainingData.delete(); + validationData.delete(); } @@ -122,7 +124,7 @@ public void testKFoldCrossValidation() { double result = vm.getRSquare(); assertEquals(expResult, result, Constants.DOUBLE_ACCURACY_HIGH); - df.delete(); + df.close(); trainingData.delete(); } diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/regression/NLMSTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/regression/NLMSTest.java index 975f4908..91cf60ba 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/regression/NLMSTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/regression/NLMSTest.java @@ -55,8 +55,7 @@ public void testPredict() { String dbName = this.getClass().getSimpleName(); DummyXYMinMaxNormalizer df = MLBuilder.create(new DummyXYMinMaxNormalizer.TrainingParameters(), dbName, conf); df.fit_transform(trainingData); - - df.transform(validationData); + df.save(); NLMS.TrainingParameters param = new NLMS.TrainingParameters(); @@ -67,7 +66,10 @@ public void testPredict() { NLMS instance = MLBuilder.create(param, dbName, conf); instance.fit(trainingData); - + instance.save(); + + df.denormalize(trainingData); + trainingData.delete(); instance.close(); df.close(); @@ -76,10 +78,11 @@ public void testPredict() { df = MLBuilder.load(DummyXYMinMaxNormalizer.class, dbName, conf); instance = MLBuilder.load(NLMS.class, dbName, conf); + + df.transform(validationData); instance.predict(validationData); - - df.denormalize(trainingData); + df.denormalize(validationData); for(Record r : validationData) { @@ -88,8 +91,7 @@ public void testPredict() { df.delete(); instance.delete(); - - trainingData.delete(); + validationData.delete(); } @@ -123,7 +125,7 @@ public void testKFoldCrossValidation() { PCA featureSelector = MLBuilder.create(featureSelectorParameters, dbName, conf); featureSelector.fit_transform(trainingData); - featureSelector.delete(); + featureSelector.close(); @@ -141,7 +143,7 @@ public void testKFoldCrossValidation() { double result = vm.getRSquare(); assertEquals(expResult, result, Constants.DOUBLE_ACCURACY_HIGH); - df.delete(); + df.close(); trainingData.delete(); } diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/regression/StepwiseRegressionTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/regression/StepwiseRegressionTest.java index d53f1879..4d8e07a8 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/regression/StepwiseRegressionTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/regression/StepwiseRegressionTest.java @@ -52,6 +52,7 @@ public void testPredict() { DummyXYMinMaxNormalizer df = MLBuilder.create(new DummyXYMinMaxNormalizer.TrainingParameters(), dbName, conf); df.fit_transform(trainingData); + df.save(); StepwiseRegression.TrainingParameters param = new StepwiseRegression.TrainingParameters(); param.setAout(0.05); @@ -62,8 +63,10 @@ public void testPredict() { StepwiseRegression instance = MLBuilder.create(param, dbName, conf); instance.fit(trainingData); + instance.save(); df.denormalize(trainingData); + trainingData.delete(); instance.close(); @@ -72,9 +75,9 @@ public void testPredict() { //df = null; df = MLBuilder.load(DummyXYMinMaxNormalizer.class, dbName, conf); - df.transform(validationData); - instance = MLBuilder.load(StepwiseRegression.class, dbName, conf); + + df.transform(validationData); instance.predict(validationData); df.denormalize(validationData); @@ -85,8 +88,7 @@ public void testPredict() { df.delete(); instance.delete(); - - trainingData.delete(); + validationData.delete(); } diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/topicmodeling/LatentDirichletAllocationTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/topicmodeling/LatentDirichletAllocationTest.java index ebec9c06..b1a5a269 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/topicmodeling/LatentDirichletAllocationTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/topicmodeling/LatentDirichletAllocationTest.java @@ -79,6 +79,7 @@ public void testPredict() { LatentDirichletAllocation lda = MLBuilder.create(trainingParameters, dbName, conf); lda.fit(trainingData); + lda.save(); lda.close(); lda = MLBuilder.load(LatentDirichletAllocation.class, dbName, conf);