From 5c4f6dd4ddd312e5d8aa98158538de03e091e2b5 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sun, 18 Dec 2016 20:03:33 +0000 Subject: [PATCH] Rewriting the logic of validators and validation metrics. --- TODO.txt | 1 + .../applications/datamodeling/Modeler.java | 1 - .../applications/nlp/TextClassifier.java | 3 +- .../datamodeling/ModelerTest.java | 2 +- .../applications/nlp/TextClassifierTest.java | 5 +- .../classification/MaximumEntropy.java | 1 - .../clustering/HierarchicalAgglomerative.java | 1 - .../algorithms/AbstractBoostingBagging.java | 1 - .../algorithms/AbstractNaiveBayes.java | 1 - .../modelers/AbstractClassifier.java | 15 +- .../abstracts/modelers/AbstractClusterer.java | 14 +- .../modelers/AbstractRecommender.java | 14 +- .../abstracts/modelers/AbstractRegressor.java | 14 +- ...actValidator.java => AbstractMetrics.java} | 33 +- .../common/interfaces/ValidationMetrics.java | 55 ++- .../metrics/ClassificationMetrics.java | 291 +++++++++++++ .../metrics/ClusteringMetrics.java | 171 ++++++++ .../metrics/LinearRegressionMetrics.java | 253 +++++++++++ .../metrics/RecommendationMetrics.java | 80 ++++ .../splitters}/TemporaryKFold.java | 29 +- .../validators/ClassifierValidator.java | 383 ----------------- .../validators/ClustererValidator.java | 215 ---------- .../validators/LinearRegressionValidator.java | 392 ------------------ .../validators/RMSEValidator.java | 99 ----- .../text/extractors/NgramsExtractor.java | 2 +- .../BernoulliNaiveBayesTest.java | 8 +- .../BinarizedNaiveBayesTest.java | 8 +- .../classification/MaximumEntropyTest.java | 8 +- .../MultinomialNaiveBayesTest.java | 8 +- .../classification/OrdinalRegressionTest.java | 8 +- .../classification/SoftMaxRegressionTest.java | 8 +- .../SupportVectorMachineTest.java | 8 +- .../clustering/GaussianDPMMTest.java | 14 +- .../HierarchicalAgglomerativeTest.java | 14 +- .../clustering/KmeansTest.java | 13 +- .../clustering/MultinomialDPMMTest.java | 13 +- .../ensemblelearning/AdaboostTest.java | 
8 +- .../BootstrapAggregatingTest.java | 8 +- .../CollaborativeFilteringTest.java | 4 +- .../MatrixLinearRegressionTest.java | 8 +- .../machinelearning/regression/NLMSTest.java | 8 +- .../LatentDirichletAllocationTest.java | 4 +- 42 files changed, 966 insertions(+), 1260 deletions(-) rename datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/validators/{AbstractValidator.java => AbstractMetrics.java} (56%) create mode 100755 datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/ClassificationMetrics.java create mode 100755 datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/ClusteringMetrics.java create mode 100755 datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/LinearRegressionMetrics.java create mode 100644 datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/RecommendationMetrics.java rename datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/{validators => modelselection/splitters}/TemporaryKFold.java (81%) delete mode 100755 datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/validators/ClassifierValidator.java delete mode 100755 datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/validators/ClustererValidator.java delete mode 100755 datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/validators/LinearRegressionValidator.java delete mode 100644 datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/validators/RMSEValidator.java diff --git a/TODO.txt b/TODO.txt index 3362e481..6606d129 100755 --- a/TODO.txt +++ b/TODO.txt @@ -1,6 +1,7 @@ CODE IMPROVEMENTS ================= +- All ValidationMetrics should hava a serialization number - Validation (Validators, 
Validation Metrics, KnowledgeBase etc) and metrics need to move out of the model - Add save() load() methods in the models - Support of better Transformers (Zscore, decouple boolean transforming from numeric etc). diff --git a/datumbox-framework-applications/src/main/java/com/datumbox/framework/applications/datamodeling/Modeler.java b/datumbox-framework-applications/src/main/java/com/datumbox/framework/applications/datamodeling/Modeler.java index 6e202da9..311ed772 100755 --- a/datumbox-framework-applications/src/main/java/com/datumbox/framework/applications/datamodeling/Modeler.java +++ b/datumbox-framework-applications/src/main/java/com/datumbox/framework/applications/datamodeling/Modeler.java @@ -24,7 +24,6 @@ import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractFeatureSelector; import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractModeler; import com.datumbox.framework.core.machinelearning.common.abstracts.wrappers.AbstractWrapper; -import com.datumbox.framework.core.machinelearning.common.interfaces.ValidationMetrics; /** * Modeler is a convenience class which can be used to train Machine Learning diff --git a/datumbox-framework-applications/src/main/java/com/datumbox/framework/applications/nlp/TextClassifier.java b/datumbox-framework-applications/src/main/java/com/datumbox/framework/applications/nlp/TextClassifier.java index af4ce1cd..ea157505 100755 --- a/datumbox-framework-applications/src/main/java/com/datumbox/framework/applications/nlp/TextClassifier.java +++ b/datumbox-framework-applications/src/main/java/com/datumbox/framework/applications/nlp/TextClassifier.java @@ -28,7 +28,6 @@ import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractFeatureSelector; import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractModeler; import com.datumbox.framework.core.machinelearning.common.abstracts.wrappers.AbstractWrapper; -import 
com.datumbox.framework.core.machinelearning.common.interfaces.ValidationMetrics; import com.datumbox.framework.core.utilities.text.extractors.AbstractTextExtractor; import java.net.URI; @@ -36,7 +35,7 @@ import java.util.Map; /** - * TextClassifier is a convenience class which can be used to train Text Classification + * TextClassifier is a convenience class which can be used to train Text ClassificationMetrics * models. It is a wrapper class which automatically takes care of the text parsing, tokenization, feature selection and modeler training processes. It takes as input either a Dataframe object or multiple text files (one for each category) with diff --git a/datumbox-framework-applications/src/test/java/com/datumbox/framework/applications/datamodeling/ModelerTest.java b/datumbox-framework-applications/src/test/java/com/datumbox/framework/applications/datamodeling/ModelerTest.java index 746c4dd3..8ef669f3 100755 --- a/datumbox-framework-applications/src/test/java/com/datumbox/framework/applications/datamodeling/ModelerTest.java +++ b/datumbox-framework-applications/src/test/java/com/datumbox/framework/applications/datamodeling/ModelerTest.java @@ -77,7 +77,7 @@ public void testTrainAndValidate() { /* //TODO: restore this test - ClassifierValidator.ValidationMetrics vm = instance.validate(trainingData); + ClassificationMetrics.Metrics vm = instance.validate(trainingData); double expResult2 = 0.8; Assert.assertEquals(expResult2, vm.getMacroF1(), Constants.DOUBLE_ACCURACY_HIGH); diff --git a/datumbox-framework-applications/src/test/java/com/datumbox/framework/applications/nlp/TextClassifierTest.java b/datumbox-framework-applications/src/test/java/com/datumbox/framework/applications/nlp/TextClassifierTest.java index ab9fe1c3..8d585405 100755 --- a/datumbox-framework-applications/src/test/java/com/datumbox/framework/applications/nlp/TextClassifierTest.java +++ 
b/datumbox-framework-applications/src/test/java/com/datumbox/framework/applications/nlp/TextClassifierTest.java @@ -21,13 +21,10 @@ import com.datumbox.framework.core.machinelearning.classification.*; import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractFeatureSelector; import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier; -import com.datumbox.framework.core.machinelearning.common.interfaces.ValidationMetrics; import com.datumbox.framework.core.machinelearning.featureselection.categorical.ChisquareSelect; import com.datumbox.framework.core.machinelearning.featureselection.categorical.MutualInformation; import com.datumbox.framework.core.machinelearning.featureselection.scorebased.TFIDF; -import com.datumbox.framework.core.machinelearning.validators.ClassifierValidator; import com.datumbox.framework.core.utilities.text.extractors.NgramsExtractor; -import com.datumbox.framework.tests.Constants; import com.datumbox.framework.tests.abstracts.AbstractTest; import org.junit.Test; @@ -317,7 +314,7 @@ private void /* //TODO: restore this test - ClassifierValidator.ValidationMetrics vm = instance.validate(dataset); + ClassificationMetrics.Metrics vm = instance.validate(dataset); assertEquals(expectedF1score, vm.getMacroF1(), Constants.DOUBLE_ACCURACY_HIGH); */ instance.close(); diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/classification/MaximumEntropy.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/classification/MaximumEntropy.java index 97dbe98a..551fa2d7 100755 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/classification/MaximumEntropy.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/classification/MaximumEntropy.java @@ -30,7 +30,6 @@ import 
com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier; import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable; import com.datumbox.framework.core.machinelearning.common.interfaces.TrainParallelizable; -import com.datumbox.framework.core.machinelearning.validators.ClassifierValidator; import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives; import java.util.Arrays; diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/clustering/HierarchicalAgglomerative.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/clustering/HierarchicalAgglomerative.java index 1fd0eb8b..a5f4c9d9 100755 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/clustering/HierarchicalAgglomerative.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/clustering/HierarchicalAgglomerative.java @@ -29,7 +29,6 @@ import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClusterer; import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable; import com.datumbox.framework.core.machinelearning.common.interfaces.TrainParallelizable; -import com.datumbox.framework.core.machinelearning.validators.ClustererValidator; import com.datumbox.framework.core.mathematics.distances.Distance; import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives; diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/algorithms/AbstractBoostingBagging.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/algorithms/AbstractBoostingBagging.java index 3ab391e5..9ef2eb8f 100755 --- 
a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/algorithms/AbstractBoostingBagging.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/algorithms/AbstractBoostingBagging.java @@ -24,7 +24,6 @@ import com.datumbox.framework.common.utilities.MapMethods; import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer; import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier; -import com.datumbox.framework.core.machinelearning.validators.ClassifierValidator; import com.datumbox.framework.core.machinelearning.ensemblelearning.FixedCombinationRules; import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives; import com.datumbox.framework.core.statistics.sampling.SimpleRandomSampling; diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/algorithms/AbstractNaiveBayes.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/algorithms/AbstractNaiveBayes.java index 1a5370dd..24e0dd37 100755 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/algorithms/AbstractNaiveBayes.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/algorithms/AbstractNaiveBayes.java @@ -30,7 +30,6 @@ import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier; import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable; import com.datumbox.framework.core.machinelearning.common.interfaces.TrainParallelizable; -import com.datumbox.framework.core.machinelearning.validators.ClassifierValidator; import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives; import java.util.*; diff --git 
a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractClassifier.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractClassifier.java index 6b1f855a..1a29da4b 100755 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractClassifier.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractClassifier.java @@ -20,9 +20,8 @@ import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector; import com.datumbox.framework.common.utilities.MapMethods; -import com.datumbox.framework.core.machinelearning.common.abstracts.validators.AbstractValidator; -import com.datumbox.framework.core.machinelearning.validators.ClassifierValidator; -import com.datumbox.framework.core.machinelearning.validators.TemporaryKFold; +import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClassificationMetrics; +import com.datumbox.framework.core.machinelearning.modelselection.splitters.TemporaryKFold; import java.util.*; @@ -103,17 +102,17 @@ protected Object getSelectedClassFromClassScores(AssociativeArray predictionScor //TODO: remove this once we create the save/load - public ClassifierValidator.ValidationMetrics validate(Dataframe testingData) { + public ClassificationMetrics validate(Dataframe testingData) { logger.info("validate()"); predict(testingData); - return new ClassifierValidator().validate(testingData); + return new ClassificationMetrics(testingData); } //TODO: remove this once we create the save/load - public ClassifierValidator.ValidationMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) { - logger.info("kFoldCrossValidation()"); + public ClassificationMetrics 
kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) { + logger.info("validate()"); - return new TemporaryKFold<>(new ClassifierValidator()).kFoldCrossValidation(trainingData, k, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters); + return new TemporaryKFold<>(ClassificationMetrics.class).validate(trainingData, k, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters); } } \ No newline at end of file diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractClusterer.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractClusterer.java index ddbfbacf..e0811269 100755 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractClusterer.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractClusterer.java @@ -23,8 +23,8 @@ import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector.MapType; import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector.StorageHint; import com.datumbox.framework.core.machinelearning.common.interfaces.Cluster; -import com.datumbox.framework.core.machinelearning.validators.ClustererValidator; -import com.datumbox.framework.core.machinelearning.validators.TemporaryKFold; +import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClusteringMetrics; +import com.datumbox.framework.core.machinelearning.modelselection.splitters.TemporaryKFold; import java.util.*; @@ -222,17 +222,17 @@ public Map getClusters() { } //TODO: remove this once we create the save/load - public ClustererValidator.ValidationMetrics validate(Dataframe testingData) { + public ClusteringMetrics validate(Dataframe testingData) { logger.info("validate()"); predict(testingData); - return 
new ClustererValidator().validate(testingData); + return new ClusteringMetrics(testingData); } //TODO: remove this once we create the save/load - public ClustererValidator.ValidationMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) { - logger.info("kFoldCrossValidation()"); + public ClusteringMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) { + logger.info("validate()"); - return new TemporaryKFold<>(new ClustererValidator()).kFoldCrossValidation(trainingData, k, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters); + return new TemporaryKFold<>(ClusteringMetrics.class).validate(trainingData, k, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters); } } \ No newline at end of file diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractRecommender.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractRecommender.java index 67f68d77..cba82d5c 100755 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractRecommender.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractRecommender.java @@ -17,8 +17,8 @@ import com.datumbox.framework.common.Configuration; import com.datumbox.framework.common.dataobjects.Dataframe; -import com.datumbox.framework.core.machinelearning.validators.RMSEValidator; -import com.datumbox.framework.core.machinelearning.validators.TemporaryKFold; +import com.datumbox.framework.core.machinelearning.modelselection.metrics.RecommendationMetrics; +import com.datumbox.framework.core.machinelearning.modelselection.splitters.TemporaryKFold; /** * Abstract Class for recommender algorithms. 
@@ -41,17 +41,17 @@ protected AbstractRecommender(String dbName, Configuration conf, Class mpCla } //TODO: remove this once we create the save/load - public RMSEValidator.ValidationMetrics validate(Dataframe testingData) { + public RecommendationMetrics validate(Dataframe testingData) { logger.info("validate()"); predict(testingData); - return new RMSEValidator().validate(testingData); + return new RecommendationMetrics(testingData); } //TODO: remove this once we create the save/load - public RMSEValidator.ValidationMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) { - logger.info("kFoldCrossValidation()"); + public RecommendationMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) { + logger.info("validate()"); - return new TemporaryKFold<>(new RMSEValidator()).kFoldCrossValidation(trainingData, k, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters); + return new TemporaryKFold<>(RecommendationMetrics.class).validate(trainingData, k, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters); } } diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractRegressor.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractRegressor.java index 93c5bcdb..fcefcf71 100755 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractRegressor.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractRegressor.java @@ -17,8 +17,8 @@ import com.datumbox.framework.common.Configuration; import com.datumbox.framework.common.dataobjects.Dataframe; -import com.datumbox.framework.core.machinelearning.validators.LinearRegressionValidator; -import com.datumbox.framework.core.machinelearning.validators.TemporaryKFold; +import 
com.datumbox.framework.core.machinelearning.modelselection.metrics.LinearRegressionMetrics; +import com.datumbox.framework.core.machinelearning.modelselection.splitters.TemporaryKFold; /** * Base Class for all the Regression algorithms. @@ -41,19 +41,19 @@ protected AbstractRegressor(String dbName, Configuration conf, Class mpClass } //TODO: remove this once we create the save/load - public LinearRegressionValidator.ValidationMetrics validate(Dataframe testingData) { + public LinearRegressionMetrics validate(Dataframe testingData) { logger.info("validate()"); knowledgeBase.load(); predict(testingData); - return new LinearRegressionValidator().validate(testingData); + return new LinearRegressionMetrics(testingData); } //TODO: remove this once we create the save/load - public LinearRegressionValidator.ValidationMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) { - logger.info("kFoldCrossValidation()"); + public LinearRegressionMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) { + logger.info("validate()"); - return new TemporaryKFold<>(new LinearRegressionValidator()).kFoldCrossValidation(trainingData, k, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters); + return new TemporaryKFold<>(LinearRegressionMetrics.class).validate(trainingData, k, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters); } } \ No newline at end of file diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/validators/AbstractValidator.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/validators/AbstractMetrics.java similarity index 56% rename from datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/validators/AbstractValidator.java rename to 
datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/validators/AbstractMetrics.java index b3af32c5..1cf08b61 100755 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/validators/AbstractValidator.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/validators/AbstractMetrics.java @@ -17,50 +17,33 @@ import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.core.machinelearning.common.interfaces.ValidationMetrics; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.util.List; /** - * The AbstractValidator class is an abstract class responsible for the K-fold Cross - Validation and for the estimation of the average validation metrics. Given that - * different models use different validation metrics, each model family implements - * its own validator. + * The AbstractMetrics class stores and estimates information about the performance of the algorithm. * * @author Vasilis Vryniotis - * @param */ -public abstract class AbstractValidator { - - /** - * The Logger of all Validators. - * We want this to be non-static in order to print the names of the inherited classes. - */ - protected final Logger logger = LoggerFactory.getLogger(getClass()); - - /** - * The AbstractValidationMetrics class stores information about the performance of the algorithm. - */ - public static abstract class AbstractValidationMetrics implements ValidationMetrics { - - } +public abstract class AbstractMetrics implements ValidationMetrics { /** * Estimates the validation metrics on the predicted data. * * @param predictedData - * @return */ - public abstract VM validate(Dataframe predictedData); + protected AbstractMetrics(Dataframe predictedData) { + + } /** * Calculates the average validation metrics by combining the results of the * provided list. 
* * @param validationMetricsList - * @return */ - public abstract VM average(List validationMetricsList); + protected AbstractMetrics(List validationMetricsList) { + + } } diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/interfaces/ValidationMetrics.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/interfaces/ValidationMetrics.java index 0ce7eeec..2c168f91 100755 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/interfaces/ValidationMetrics.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/interfaces/ValidationMetrics.java @@ -15,11 +15,14 @@ */ package com.datumbox.framework.core.machinelearning.common.interfaces; +import com.datumbox.framework.common.dataobjects.Dataframe; + import java.io.Serializable; import java.lang.reflect.InvocationTargetException; +import java.util.List; /** - * Interface for every ValidationMetrics class in the framework. + * Interface for every Metrics class in the framework. * * @author Vasilis Vryniotis */ @@ -27,17 +30,51 @@ public interface ValidationMetrics extends Serializable { /** - * This method allows us to create a new empty Validation Metrics object - * from an existing object. Casting to the appropriate type is required. - * - * @return + * Creates a new empty Validation Metrics object. + * + * @return + */ + public static VM newInstance(Class vmClass) { + try { + return vmClass.getConstructor().newInstance(); + } + catch (InstantiationException | IllegalAccessException | IllegalArgumentException | InvocationTargetException | NoSuchMethodException | SecurityException ex) { + throw new RuntimeException(ex); + } + } + + /** + * Estimates the Validation Metrics object from predictions. 
+ * + * @param vmClass + * @param predictedData + * @param + * @return */ - default public ValidationMetrics getEmptyObject() { + public static VM newInstance(Class vmClass, Dataframe predictedData) { try { - return this.getClass().getConstructor().newInstance(); - } + return vmClass.getConstructor(Dataframe.class).newInstance(predictedData); + } catch (InstantiationException | IllegalAccessException | IllegalArgumentException | InvocationTargetException | NoSuchMethodException | SecurityException ex) { throw new RuntimeException(ex); } - } + } + + /** + * Estimates the average Validation Metrics object from a list of metrics. + * + * @param vmClass + * @param validationMetricsList + * @param + * @return + */ + public static VM newInstance(Class vmClass, List validationMetricsList) { + try { + return vmClass.getConstructor(List.class).newInstance(validationMetricsList); + } + catch (InstantiationException | IllegalAccessException | IllegalArgumentException | InvocationTargetException | NoSuchMethodException | SecurityException ex) { + throw new RuntimeException(ex); + } + } + } diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/ClassificationMetrics.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/ClassificationMetrics.java new file mode 100755 index 00000000..4dfa7240 --- /dev/null +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/ClassificationMetrics.java @@ -0,0 +1,291 @@ +/** + * Copyright (C) 2013-2016 Vasilis Vryniotis + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datumbox.framework.core.machinelearning.modelselection.metrics; + +import com.datumbox.framework.common.dataobjects.Dataframe; +import com.datumbox.framework.common.dataobjects.Record; +import com.datumbox.framework.core.machinelearning.common.abstracts.validators.AbstractMetrics; + +import java.util.*; + +/** + * Estimates validation metrics for Classifiers. + * + * @author Vasilis Vryniotis + */ +public class ClassificationMetrics extends AbstractMetrics { + + /** + * Enum that stores the 4 possible Sensitivity Rates. + */ + public enum SensitivityRates { + /** + * True Positive. + */ + TRUE_POSITIVE, + + /** + * True Negative. + */ + TRUE_NEGATIVE, + + /** + * False Positive. + */ + FALSE_POSITIVE, + + /** + * False Negative. + */ + FALSE_NEGATIVE; + } + + //validation metrics + private double accuracy = 0.0; + + private double macroPrecision = 0.0; + private double macroRecall = 0.0; + private double macroF1 = 0.0; + + private Map microPrecision = new HashMap<>(); //this is small. Size equal to 4*class numbers + + private Map microRecall = new HashMap<>(); //this is small. Size equal to 4*class numbers + + private Map microF1 = new HashMap<>(); //this is small. Size equal to 4*class numbers + + private Map, Double> contingencyTable = new HashMap<>(); //this is small. Size equal to 4*class numbers + + /** + * Getter for Accuracy. + * + * @return + */ + public double getAccuracy() { + return accuracy; + } + + /** + * Getter for Macro Precision. 
+ * + * @return + */ + public double getMacroPrecision() { + return macroPrecision; + } + + /** + * Getter for Macro Recall. + * + * @return + */ + public double getMacroRecall() { + return macroRecall; + } + + /** + * Getter for Macro F1. + * + * @return + */ + public double getMacroF1() { + return macroF1; + } + + /** + * Getter for Micro Precision. + * + * @return + */ + public Map getMicroPrecision() { + return microPrecision; + } + + /** + * Getter for Micro Recall. + * + * @return + */ + public Map getMicroRecall() { + return microRecall; + } + + /** + * Getter for Micro F1. + * + * @return + */ + public Map getMicroF1() { + return microF1; + } + + /** + * Getter for Contingency Table. + * + * @return + */ + public Map, Double> getContingencyTable() { + return contingencyTable; + } + + /** + * @param predictedData + * @see AbstractMetrics#AbstractMetrics(Dataframe) + */ + public ClassificationMetrics(Dataframe predictedData) { + super(predictedData); + + //retrieve the classes from the dataset + Set classesSet = new HashSet<>(); + for(Record r : predictedData) { + classesSet.add(r.getY()); + classesSet.add(r.getYPredicted()); + } + + for(Object theClass : classesSet) { + contingencyTable.put(Arrays.asList(theClass, SensitivityRates.TRUE_POSITIVE), 0.0); //true possitive + contingencyTable.put(Arrays.asList(theClass, SensitivityRates.FALSE_POSITIVE), 0.0); //false possitive + contingencyTable.put(Arrays.asList(theClass, SensitivityRates.TRUE_NEGATIVE), 0.0); //true negative + contingencyTable.put(Arrays.asList(theClass, SensitivityRates.FALSE_NEGATIVE), 0.0); //false negative + } + + int n = predictedData.size(); + int c = classesSet.size(); + + int correctCount=0; + for(Record r : predictedData) { + Object yPred = r.getYPredicted(); + if(yPred.equals(r.getY())) { + ++correctCount; + + for(Object cl : classesSet) { + if(cl.equals(yPred)) { + List tpk = Arrays.asList(cl, SensitivityRates.TRUE_POSITIVE); + contingencyTable.put(tpk, contingencyTable.get(tpk) + 
1.0); + } + else { + List tpk = Arrays.asList(cl, SensitivityRates.TRUE_NEGATIVE); + contingencyTable.put(tpk, contingencyTable.get(tpk) + 1.0); + } + } + } + else { + for(Object cl : classesSet) { + if(cl.equals(yPred)) { + List tpk = Arrays.asList(cl, SensitivityRates.FALSE_POSITIVE); + contingencyTable.put(tpk, contingencyTable.get(tpk) + 1.0); + } + else if(cl.equals(r.getY())) { + List tpk = Arrays.asList(cl, SensitivityRates.FALSE_NEGATIVE); + contingencyTable.put(tpk, contingencyTable.get(tpk) + 1.0); + } + else { + List tpk = Arrays.asList(cl, SensitivityRates.TRUE_NEGATIVE); + contingencyTable.put(tpk, contingencyTable.get(tpk) + 1.0); + } + } + } + } + + accuracy = correctCount/(double)n; + + //Average Precision, Recall and F1: http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.104.8244&rep=rep1&type=pdf + int activeClasses = c; + for(Object theClass : classesSet) { + double tp = contingencyTable.get(Arrays.asList(theClass, SensitivityRates.TRUE_POSITIVE)); + double fp = contingencyTable.get(Arrays.asList(theClass, SensitivityRates.FALSE_POSITIVE)); + double fn = contingencyTable.get(Arrays.asList(theClass, SensitivityRates.FALSE_NEGATIVE)); + + + double classPrecision=0.0; + double classRecall=0.0; + double classF1=0.0; + if(tp>0.0) { + classPrecision = tp/(tp+fp); + classRecall = tp/(tp+fn); + classF1 = 2.0*classPrecision*classRecall/(classPrecision+classRecall); + } + else if(tp==0.0 && fp==0.0 && fn==0.0) { + //if this category did not appear in the dataset reduce the number of classes + --activeClasses; + } + + + microPrecision.put(theClass, classPrecision); + microRecall.put(theClass, classRecall); + microF1.put(theClass, classF1); + + macroPrecision += classPrecision; + macroRecall += classRecall; + macroF1 += classF1; + } + + macroPrecision /= activeClasses; + macroRecall /= activeClasses; + macroF1 /= activeClasses; + } + + /** + * @param validationMetricsList + * @see AbstractMetrics#AbstractMetrics(List) + */ + public 
ClassificationMetrics(List validationMetricsList) { + super(validationMetricsList); + + int k = validationMetricsList.size(); //number of samples + + for(ClassificationMetrics vmSample : validationMetricsList) { + + //fetch the classes from the keys of one of the micro metrics. This way if a class is not included in a fold, we don't get null exceptions + Set classesSet = vmSample.getMicroPrecision().keySet(); + + for(Object theClass : classesSet) { + + Map, Double> ctEntryMap = vmSample.getContingencyTable(); + + //get the values of all SensitivityRates and average them + for(SensitivityRates sr : SensitivityRates.values()) { + List tpk = Arrays.asList(theClass, sr); + + Double previousValue = contingencyTable.get(tpk); + if(previousValue==null) { + previousValue=0.0; + } + + contingencyTable.put(tpk, previousValue + ctEntryMap.get(tpk)/k); + } + + //update micro metrics of class + Double previousPrecision = microPrecision.getOrDefault(theClass, 0.0); + microPrecision.put(theClass, previousPrecision + vmSample.getMicroPrecision().get(theClass)/k); + + + Double previousRecall = microRecall.getOrDefault(theClass, 0.0); + microRecall.put(theClass, previousRecall + vmSample.getMicroRecall().get(theClass)/k); + + + Double previousF1 = microF1.getOrDefault(theClass, 0.0); + microF1.put(theClass, previousF1 + vmSample.getMicroF1().get(theClass)/k); + + } + + //update macro metrics + accuracy += vmSample.getAccuracy()/k; + macroPrecision += vmSample.getMacroPrecision()/k; + macroRecall += vmSample.getMacroRecall()/k; + macroF1 += vmSample.getMacroF1()/k; + } + } +} diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/ClusteringMetrics.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/ClusteringMetrics.java new file mode 100755 index 00000000..2f1a104e --- /dev/null +++ 
b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/ClusteringMetrics.java @@ -0,0 +1,171 @@ +/** + * Copyright (C) 2013-2016 Vasilis Vryniotis + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datumbox.framework.core.machinelearning.modelselection.metrics; + +import com.datumbox.framework.common.dataobjects.Dataframe; +import com.datumbox.framework.common.dataobjects.Record; +import com.datumbox.framework.core.machinelearning.common.abstracts.validators.AbstractMetrics; + +import java.util.*; + +/** + * Estimates validation metrics for Clustering models. + * + * References: + * http://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html + * http://thesis.neminis.org/wp-content/plugins/downloads-manager/upload/masterThesis-VR.pdf + * + * @author Vasilis Vryniotis + */ +public class ClusteringMetrics extends AbstractMetrics { + + private double purity = 0.0; + private double NMI = 0.0; //Normalized Mutual Information: I(Omega,Gama) calculation + + /** + * Getter for Purity. + * + * @return + */ + public Double getPurity() { + return purity; + } + + /** + * Getter for NMI. 
+ * + * @return + */ + public Double getNMI() { + return NMI; + } + + /** + * @param predictedData + * @see AbstractMetrics#AbstractMetrics(Dataframe) + */ + public ClusteringMetrics(Dataframe predictedData) { + super(predictedData); + + int n = predictedData.size(); + + Set clusterIdSet = new HashSet<>(); + Set goldStandardClassesSet = new HashSet<>(); + for(Record r : predictedData) { + Object y = r.getY(); + if(y != null) { + goldStandardClassesSet.add(y); + } + clusterIdSet.add(r.getYPredicted()); + } + + if(!goldStandardClassesSet.isEmpty()) { + //We don't store the Contingency Table because we can't average it with + //k-cross fold validation. Each clustering produces a different number + //of clusters and thus different enumeration. Thus averaging the results + //is impossible and that is why we don't store it in the validation object. + + //List = [Clusterid,GoldStandardClass] + Map, Double> ctMap = new HashMap<>(); + + //frequency tables + Map countOfW = new HashMap<>(); //this is small equal to number of clusters + Map countOfC = new HashMap<>(); //this is small equal to number of classes + + //initialize the tables with zeros + for(Object clusterId : clusterIdSet) { + countOfW.put(clusterId, 0.0); + for(Object theClass : goldStandardClassesSet) { + ctMap.put(Arrays.asList(clusterId, theClass), 0.0); + + countOfC.put(theClass, 0.0); + } + } + + //count the co-occurrences of ClusterId-GoldStandardClass + for(Record r : predictedData) { + Object clusterId = r.getYPredicted(); //fetch cluster assignment + Object goldStandardClass = r.getY(); //the original class of the observation + List tpk = Arrays.asList(clusterId, goldStandardClass); + ctMap.put(tpk, ctMap.get(tpk) + 1.0); + + //update cluster and class counts + countOfW.put(clusterId, countOfW.get(clusterId)+1.0); + countOfC.put(goldStandardClass, countOfC.get(goldStandardClass)+1.0); + } + + double logN = Math.log((double)n); + double Iwc=0.0; 
//http://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html + for(Object clusterId : clusterIdSet) { + double maxCounts=Double.NEGATIVE_INFINITY; + + //loop through the possible classes and find the most popular one + for(Object goldStandardClass : goldStandardClassesSet) { + List tpk = Arrays.asList(clusterId, goldStandardClass); + double Nwc = ctMap.get(tpk); + if(Nwc>maxCounts) { + maxCounts=Nwc; + } + + if(Nwc>0) { + Iwc+= (Nwc/n)*(Math.log(Nwc) -Math.log(countOfC.get(goldStandardClass)) + -Math.log(countOfW.get(clusterId)) + logN); + } + } + purity += maxCounts; + } + //ctMap = null; + + double entropyW=0.0; + for(Double Nw : countOfW.values()) { + entropyW-=(Nw/n)*(Math.log(Nw)-logN); + } + + double entropyC=0.0; + for(Double Nc : countOfC.values()) { + entropyC-=(Nc/n)*(Math.log(Nc)-logN); + } + + purity /= n; + NMI = Iwc/((entropyW+entropyC)/2.0); + } + } + + /** + * @param validationMetricsList + * @see AbstractMetrics#AbstractMetrics(List) + */ + public ClusteringMetrics(List validationMetricsList) { + super(validationMetricsList); + + //estimate average values + int k = 0; + for(ClusteringMetrics vmSample : validationMetricsList) { + if(vmSample.getNMI()==null) { //it is null when we don't have goldStandardClass information + continue; + } + + NMI += vmSample.getNMI(); + purity += vmSample.getPurity(); + k++; + } + + if(k>0) { + NMI /= k; + purity /= k; + } + } +} diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/LinearRegressionMetrics.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/LinearRegressionMetrics.java new file mode 100755 index 00000000..377318ab --- /dev/null +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/LinearRegressionMetrics.java @@ -0,0 +1,253 @@ +/** + * Copyright (C) 2013-2016 Vasilis Vryniotis + * + * Licensed under the 
Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datumbox.framework.core.machinelearning.modelselection.metrics; + +import com.datumbox.framework.common.dataobjects.Dataframe; +import com.datumbox.framework.common.dataobjects.FlatDataList; +import com.datumbox.framework.common.dataobjects.Record; +import com.datumbox.framework.common.dataobjects.TypeInference; +import com.datumbox.framework.core.machinelearning.common.abstracts.validators.AbstractMetrics; +import com.datumbox.framework.core.statistics.distributions.ContinuousDistributions; +import com.datumbox.framework.core.statistics.nonparametrics.onesample.Lilliefors; +import com.datumbox.framework.core.statistics.parametrics.onesample.DurbinWatson; + +import java.util.List; + +/** + * Validation class for Linear Regression. + * + * @author Vasilis Vryniotis + */ +public class LinearRegressionMetrics extends AbstractMetrics { + + private double RSquare = 0.0; + private double RSquareAdjusted = 0.0; + private double SSE = 0.0; + private double SSR = 0.0; + private double SST = 0.0; + private double dfRegression = 0.0; + private double dfResidual = 0.0; + private double dfTotal = 0.0; + private double F = 0.0; + private double FPValue = 0.0; + private Double StdErrorOfEstimate = 0.0; //this can have null value if dfResidual is 0 + private double DW = 0.0; //Durbin–Watson statistic + private double NormalResiduals = 0.0; //Test on whether the residuals can be considered Normal + + /** + * Getter for the R Square. 
+ * + * @return + */ + public double getRSquare() { + return RSquare; + } + + /** + * Getter for the R Square Adjusted. + * + * @return + */ + public double getRSquareAdjusted() { + return RSquareAdjusted; + } + + /** + * Getter for the Sum of Squared Errors. + * + * @return + */ + public double getSSE() { + return SSE; + } + + /** + * Getter for the Sum of Squared due to Regression. + * + * @return + */ + public double getSSR() { + return SSR; + } + + /** + * Getter for the Sum of Squared Total. + * + * @return + */ + public double getSST() { + return SST; + } + + /** + * Getter for the degrees of freedom of Regression. + * + * @return + */ + public double getDfRegression() { + return dfRegression; + } + + /** + * Getter for the degrees of freedom of Residual. + * + * @return + */ + public double getDfResidual() { + return dfResidual; + } + + /** + * Getter for the degrees of freedom of Total. + * + * @return + */ + public double getDfTotal() { + return dfTotal; + } + + /** + * Getter for F score. + * + * @return + */ + public double getF() { + return F; + } + + /** + * Getter for F p-value. + * + * @return + */ + public double getFPValue() { + return FPValue; + } + + /** + * Getter for Standard Error of Estimate. + * + * @return + */ + public Double getStdErrorOfEstimate() { + return StdErrorOfEstimate; + } + + /** + * Getter of Durbin Watson statistic. + * + * @return + */ + public double getDW() { + return DW; + } + + /** + * Getter for Normal Residuals. 
+ * + * @return + */ + public double getNormalResiduals() { + return NormalResiduals; + } + + /** + * @param predictedData + * @see AbstractMetrics#AbstractMetrics(Dataframe) + */ + public LinearRegressionMetrics(Dataframe predictedData) { + super(predictedData); + + int n = predictedData.size(); + + FlatDataList errorList = new FlatDataList(); + double Ybar = 0.0; + for(Record r : predictedData) { + Ybar += TypeInference.toDouble(r.getY())/n; + errorList.add(TypeInference.toDouble(r.getY())-TypeInference.toDouble(r.getYPredicted())); + } + + DW = DurbinWatson.calculateScore(errorList); + + for(Record r : predictedData) { + SSE += Math.pow(TypeInference.toDouble(r.getY())-TypeInference.toDouble(r.getYPredicted()), 2.0); + } + + boolean normalResiduals = Lilliefors.test(errorList.toFlatDataCollection(), "normalDistribution", 0.05); + NormalResiduals = (normalResiduals)?0.0:1.0; //if the Lilliefors validate rejects the H0 means that the normality hypothesis is rejected thus the residuals are not normal + //errorList = null; + + for(Record r : predictedData) { + SSR += Math.pow(TypeInference.toDouble(r.getY()) - Ybar, 2); + } + + SST = SSR+SSE; + RSquare = SSR/SST; + + int d = predictedData.xColumnSize()+1;//add one for the constant + int p = d - 1; //exclude constant + + RSquareAdjusted = 1.0 - ((n-1.0)/(n-p-1.0))*(1.0-RSquare); + + //degrees of freedom + dfTotal = n-1.0; + dfRegression = d-1.0; + dfResidual = Math.max(n-d, 0.0); + + F = (SSR/dfRegression)/(SSE/dfResidual); + + FPValue = 1.0; + if(n>d) { + FPValue = ContinuousDistributions.fCdf(F, (int)dfRegression, (int)dfResidual); + } + + StdErrorOfEstimate = null; + if(dfResidual > 0.0) { + StdErrorOfEstimate = Math.sqrt(SSE/dfResidual); + } + } + + /** + * @param validationMetricsList + * @see AbstractMetrics#AbstractMetrics(List) + */ + public LinearRegressionMetrics(List validationMetricsList) { + super(validationMetricsList); + + if(!validationMetricsList.isEmpty()) { + int k = validationMetricsList.size(); 
//number of samples + for (LinearRegressionMetrics vmSample : validationMetricsList) { + RSquare += vmSample.getRSquare() / k; + RSquareAdjusted += vmSample.getRSquareAdjusted() / k; + SSE += vmSample.getSSE() / k; + SSR += vmSample.getSSR() / k; + SST += vmSample.getSST() / k; + dfRegression += vmSample.getDfRegression() / k; + dfResidual += vmSample.getDfResidual() / k; + dfTotal += vmSample.getDfTotal() / k; + F += vmSample.getF() / k; + FPValue += vmSample.getFPValue() / k; + Double stdErrorOfEstimate = vmSample.getStdErrorOfEstimate(); + if (stdErrorOfEstimate == null) { + stdErrorOfEstimate = 0.0; + } + StdErrorOfEstimate += stdErrorOfEstimate / k; + DW += vmSample.getDW() / k; + NormalResiduals += vmSample.getNormalResiduals() / k; //percentage of samples that found the residuals to be normal + } + } + } +} diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/RecommendationMetrics.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/RecommendationMetrics.java new file mode 100644 index 00000000..b36f2478 --- /dev/null +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/RecommendationMetrics.java @@ -0,0 +1,80 @@ +/** + * Copyright (C) 2013-2016 Vasilis Vryniotis + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.datumbox.framework.core.machinelearning.modelselection.metrics; + +import com.datumbox.framework.common.dataobjects.AssociativeArray; +import com.datumbox.framework.common.dataobjects.Dataframe; +import com.datumbox.framework.common.dataobjects.Record; +import com.datumbox.framework.common.dataobjects.TypeInference; +import com.datumbox.framework.core.machinelearning.common.abstracts.validators.AbstractMetrics; + +import java.util.*; + +/** + * Estimates RMSE for Recommendation models. + * + * @author Vasilis Vryniotis + */ +public class RecommendationMetrics extends AbstractMetrics { + + private double RMSE = 0.0; + + /** + * Getter for RMSE. + * + * @return + */ + public double getRMSE() { + return RMSE; + } + + + /** + * @param predictedData + * @see AbstractMetrics#AbstractMetrics(Dataframe) + */ + public RecommendationMetrics(Dataframe predictedData) { + super(predictedData); + + int i = 0; + for(Record r : predictedData) { + AssociativeArray predictions = r.getYPredictedProbabilities(); + for(Map.Entry entry : r.getX().entrySet()) { + Object column = entry.getKey(); + Object value = entry.getValue(); + RMSE += Math.pow(TypeInference.toDouble(value)-TypeInference.toDouble(predictions.get(column)), 2.0); + ++i; + } + } + + RMSE = Math.sqrt(RMSE/i); + } + + /** + * @param validationMetricsList + * @see AbstractMetrics#AbstractMetrics(List) + */ + public RecommendationMetrics(List validationMetricsList) { + super(validationMetricsList); + + if(!validationMetricsList.isEmpty()) { + int k = validationMetricsList.size(); //number of samples + for(RecommendationMetrics vmSample : validationMetricsList) { + RMSE += vmSample.getRMSE()/k; + } + } + } +} diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/validators/TemporaryKFold.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/splitters/TemporaryKFold.java similarity index 81% rename from 
datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/validators/TemporaryKFold.java rename to datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/splitters/TemporaryKFold.java index 4d71b4a6..504ae169 100644 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/validators/TemporaryKFold.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/splitters/TemporaryKFold.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.datumbox.framework.core.machinelearning.validators; +package com.datumbox.framework.core.machinelearning.modelselection.splitters; import com.datumbox.framework.common.Configuration; import com.datumbox.framework.common.dataobjects.Dataframe; @@ -22,7 +22,7 @@ import com.datumbox.framework.common.utilities.PHPMethods; import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer; import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractModeler; -import com.datumbox.framework.core.machinelearning.common.abstracts.validators.AbstractValidator; +import com.datumbox.framework.core.machinelearning.common.interfaces.ValidationMetrics; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -30,7 +30,7 @@ import java.util.LinkedList; import java.util.List; -public class TemporaryKFold { +public class TemporaryKFold { //TODO: remove this temporary class and create a permanent solution @@ -42,10 +42,15 @@ public class TemporaryKFold validator; + private final Class vmClass; - public TemporaryKFold(AbstractValidator validator) { - this.validator = validator; + /** + * The constructor of the Splitter. 
+ * + * @param vmClass + */ + public TemporaryKFold(Class vmClass) { + this.vmClass = vmClass; } /** @@ -60,7 +65,7 @@ public TemporaryKFold(AbstractValidator validator) { * @param trainingParameters * @return */ - public VM kFoldCrossValidation(Dataframe dataset, int k, String dbName, Configuration conf, Class aClass, AbstractTrainer.AbstractTrainingParameters trainingParameters) { + public VM validate(Dataframe dataset, int k, String dbName, Configuration conf, Class aClass, AbstractTrainer.AbstractTrainingParameters trainingParameters) { int n = dataset.size(); if(k<=0 || n<=k) { throw new IllegalArgumentException("Invalid number of folds."); @@ -114,11 +119,11 @@ public VM kFoldCrossValidation(Dataframe dataset, int k, String dbName, Configur //initialize modeler - AbstractModeler modeler = Trainable.newInstance((Class)aClass, foldDBname+(fold+1), conf); + AbstractModeler modeler = Trainable.newInstance((Class)aClass, foldDBname+(fold+1), conf); Dataframe trainingData = dataset.getSubset(foldTrainingIds); - modeler.fit(trainingData, (AbstractTrainer.AbstractTrainingParameters) trainingParameters); + modeler.fit(trainingData, trainingParameters); trainingData.delete(); //trainingData = null; @@ -128,9 +133,7 @@ public VM kFoldCrossValidation(Dataframe dataset, int k, String dbName, Configur //fetch validation metrics modeler.predict(validationData); - - - VM entrySample = validator.validate(validationData); + VM entrySample = ValidationMetrics.newInstance(vmClass, validationData); validationData.delete(); //validationData = null; @@ -142,7 +145,7 @@ public VM kFoldCrossValidation(Dataframe dataset, int k, String dbName, Configur validationMetricsList.add(entrySample); } - VM avgValidationMetrics = validator.average(validationMetricsList); + VM avgValidationMetrics = ValidationMetrics.newInstance(vmClass, validationMetricsList); return avgValidationMetrics; } diff --git 
a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/validators/ClassifierValidator.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/validators/ClassifierValidator.java deleted file mode 100755 index 4915ca01..00000000 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/validators/ClassifierValidator.java +++ /dev/null @@ -1,383 +0,0 @@ -/** - * Copyright (C) 2013-2016 Vasilis Vryniotis - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.datumbox.framework.core.machinelearning.validators; - -import com.datumbox.framework.common.dataobjects.Dataframe; -import com.datumbox.framework.common.dataobjects.Record; -import com.datumbox.framework.core.machinelearning.common.abstracts.validators.AbstractValidator; - -import java.util.*; - -/** - * Estimates validation metrics for Classifiers. - * - * @author Vasilis Vryniotis - */ -public class ClassifierValidator extends AbstractValidator { - - /** - * Enum that stores the 4 possible Sensitivity Rates. - */ - public enum SensitivityRates { - /** - * True Positive. - */ - TRUE_POSITIVE, - - /** - * True Negative. - */ - TRUE_NEGATIVE, - - /** - * False Positive. - */ - FALSE_POSITIVE, - - /** - * False Negative. 
- */ - FALSE_NEGATIVE; - } - - /** {@inheritDoc} */ - public static class ValidationMetrics extends AbstractValidator.AbstractValidationMetrics { - - //validation metrics - private double accuracy = 0.0; - - private double macroPrecision = 0.0; - private double macroRecall = 0.0; - private double macroF1 = 0.0; - - private Map microPrecision = new HashMap<>(); //this is small. Size equal to 4*class numbers - - private Map microRecall = new HashMap<>(); //this is small. Size equal to 4*class numbers - - private Map microF1 = new HashMap<>(); //this is small. Size equal to 4*class numbers - - private Map, Double> ContingencyTable = new HashMap<>(); //this is small. Size equal to 4*class numbers - - /** - * Getter for Accuracy. - * - * @return - */ - public double getAccuracy() { - return accuracy; - } - - /** - * Setter for Accuracy. - * - * @param accuracy - */ - public void setAccuracy(double accuracy) { - this.accuracy = accuracy; - } - - /** - * Getter for Macro Precision. - * - * @return - */ - public double getMacroPrecision() { - return macroPrecision; - } - - /** - * Setter for Macro Precision. - * - * @param macroPrecision - */ - public void setMacroPrecision(double macroPrecision) { - this.macroPrecision = macroPrecision; - } - - /** - * Getter for Macro Recall. - * - * @return - */ - public double getMacroRecall() { - return macroRecall; - } - - /** - * Setter for Macro Recall. - * - * @param macroRecall - */ - public void setMacroRecall(double macroRecall) { - this.macroRecall = macroRecall; - } - - /** - * Getter for Macro F1. - * - * @return - */ - public double getMacroF1() { - return macroF1; - } - - /** - * Setter for Macro F1. - * - * @param macroF1 - */ - public void setMacroF1(double macroF1) { - this.macroF1 = macroF1; - } - - /** - * Getter for Micro Precision. - * - * @return - */ - public Map getMicroPrecision() { - return microPrecision; - } - - /** - * Setter for Micro Precision. 
- * - * @param microPrecision - */ - public void setMicroPrecision(Map microPrecision) { - this.microPrecision = microPrecision; - } - - /** - * Getter for Micro Recall. - * - * @return - */ - public Map getMicroRecall() { - return microRecall; - } - - /** - * Setter for Micro Recall. - * - * @param microRecall - */ - public void setMicroRecall(Map microRecall) { - this.microRecall = microRecall; - } - - /** - * Getter for Micro F1. - * - * @return - */ - public Map getMicroF1() { - return microF1; - } - - /** - * Setter for Micro F1. - * - * @param microF1 - */ - public void setMicroF1(Map microF1) { - this.microF1 = microF1; - } - - /** - * Getter for Contingency Table. - * - * @return - */ - public Map, Double> getContingencyTable() { - return ContingencyTable; - } - - /** - * Setter for Contingency Table. - * - * @param ContingencyTable - */ - public void setContingencyTable(Map, Double> ContingencyTable) { - this.ContingencyTable = ContingencyTable; - } - } - - /** {@inheritDoc} */ - @Override - public ValidationMetrics validate(Dataframe predictedData) { - //retrieve the classes from the dataset - Set classesSet = new HashSet<>(); - for(Record r : predictedData) { - classesSet.add(r.getY()); - classesSet.add(r.getYPredicted()); - } - - //create new validation metrics object - ValidationMetrics validationMetrics = new ValidationMetrics(); - - Map, Double> ctMap = validationMetrics.getContingencyTable(); - for(Object theClass : classesSet) { - ctMap.put(Arrays.asList(theClass, SensitivityRates.TRUE_POSITIVE), 0.0); //true possitive - ctMap.put(Arrays.asList(theClass, SensitivityRates.FALSE_POSITIVE), 0.0); //false possitive - ctMap.put(Arrays.asList(theClass, SensitivityRates.TRUE_NEGATIVE), 0.0); //true negative - ctMap.put(Arrays.asList(theClass, SensitivityRates.FALSE_NEGATIVE), 0.0); //false negative - } - - int n = predictedData.size(); - int c = classesSet.size(); - - int correctCount=0; - for(Record r : predictedData) { - 
if(r.getYPredicted().equals(r.getY())) { - ++correctCount; - - for(Object cl : classesSet) { - if(cl.equals(r.getYPredicted())) { - List tpk = Arrays.asList(cl, SensitivityRates.TRUE_POSITIVE); - ctMap.put(tpk, ctMap.get(tpk) + 1.0); - } - else { - List tpk = Arrays.asList(cl, SensitivityRates.TRUE_NEGATIVE); - ctMap.put(tpk, ctMap.get(tpk) + 1.0); - } - } - } - else { - for(Object cl : classesSet) { - if(cl.equals(r.getYPredicted())) { - List tpk = Arrays.asList(cl, SensitivityRates.FALSE_POSITIVE); - ctMap.put(tpk, ctMap.get(tpk) + 1.0); - } - else if(cl.equals(r.getY())) { - List tpk = Arrays.asList(cl, SensitivityRates.FALSE_NEGATIVE); - ctMap.put(tpk, ctMap.get(tpk) + 1.0); - } - else { - List tpk = Arrays.asList(cl, SensitivityRates.TRUE_NEGATIVE); - ctMap.put(tpk, ctMap.get(tpk) + 1.0); - } - } - } - } - - validationMetrics.setAccuracy(correctCount/(double)n); - - //Average Precision, Recall and F1: http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.104.8244&rep=rep1&type=pdf - int activeClasses = c; - for(Object theClass : classesSet) { - - - double tp = ctMap.get(Arrays.asList(theClass, SensitivityRates.TRUE_POSITIVE)); - double fp = ctMap.get(Arrays.asList(theClass, SensitivityRates.FALSE_POSITIVE)); - double fn = ctMap.get(Arrays.asList(theClass, SensitivityRates.FALSE_NEGATIVE)); - - - double classPrecision=0.0; - double classRecall=0.0; - double classF1=0.0; - if(tp>0.0) { - classPrecision = tp/(tp+fp); - classRecall = tp/(tp+fn); - classF1 = 2.0*classPrecision*classRecall/(classPrecision+classRecall); - } - else if(tp==0.0 && fp==0.0 && fn==0.0) { - //if this category did not appear in the dataset reduce the number of classes - --activeClasses; - } - - - validationMetrics.getMicroPrecision().put(theClass, classPrecision); - validationMetrics.getMicroRecall().put(theClass, classRecall); - validationMetrics.getMicroF1().put(theClass, classF1); - - validationMetrics.setMacroPrecision(validationMetrics.getMacroPrecision() + classPrecision); - 
validationMetrics.setMacroRecall(validationMetrics.getMacroRecall() + classRecall); - validationMetrics.setMacroF1(validationMetrics.getMacroF1() + classF1); - } - - validationMetrics.setMacroPrecision(validationMetrics.getMacroPrecision()/activeClasses); - validationMetrics.setMacroRecall(validationMetrics.getMacroRecall()/activeClasses); - validationMetrics.setMacroF1(validationMetrics.getMacroF1()/activeClasses); - - return validationMetrics; - } - - /** {@inheritDoc} */ - @Override - public ValidationMetrics average(List validationMetricsList) { - if(validationMetricsList.isEmpty()) { - return null; - } - - int k = validationMetricsList.size(); //number of samples - - ValidationMetrics avgValidationMetrics = new ValidationMetrics(); - for(ValidationMetrics vmSample : validationMetricsList) { - - //fetch the classes from the keys of one of the micro metrics. This way if a class is not included in a fold, we don't get null exceptions - Set classesSet = vmSample.getMicroPrecision().keySet(); - - for(Object theClass : classesSet) { - - Map, Double> ctEntryMap = vmSample.getContingencyTable(); - - //get the values of all SensitivityRates and average them - for(SensitivityRates sr : SensitivityRates.values()) { - List tpk = Arrays.asList(theClass, sr); - - Double previousValue = avgValidationMetrics.getContingencyTable().get(tpk); - if(previousValue==null) { - previousValue=0.0; - } - - avgValidationMetrics.getContingencyTable().put(tpk, previousValue + ctEntryMap.get(tpk)/k); - } - - //update micro metrics of class - Double previousPrecision = avgValidationMetrics.getMicroPrecision().get(theClass); - if(previousPrecision==null) { - previousPrecision=0.0; - } - avgValidationMetrics.getMicroPrecision().put(theClass, previousPrecision + vmSample.getMicroPrecision().get(theClass)/k); - - - Double previousRecall = avgValidationMetrics.getMicroRecall().get(theClass); - if(previousRecall==null) { - previousRecall=0.0; - } - 
avgValidationMetrics.getMicroRecall().put(theClass, previousRecall + vmSample.getMicroRecall().get(theClass)/k); - - - Double previousF1 = avgValidationMetrics.getMicroF1().get(theClass); - if(previousF1==null) { - previousF1=0.0; - } - avgValidationMetrics.getMicroF1().put(theClass, previousF1 + vmSample.getMicroF1().get(theClass)/k); - - } - - //update macro metrics - avgValidationMetrics.setAccuracy(avgValidationMetrics.getAccuracy() + vmSample.getAccuracy()/k); - avgValidationMetrics.setMacroPrecision(avgValidationMetrics.getMacroPrecision() + vmSample.getMacroPrecision()/k); - avgValidationMetrics.setMacroRecall(avgValidationMetrics.getMacroRecall() + vmSample.getMacroRecall()/k); - avgValidationMetrics.setMacroF1(avgValidationMetrics.getMacroF1() + vmSample.getMacroF1()/k); - } - - - return avgValidationMetrics; - } -} diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/validators/ClustererValidator.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/validators/ClustererValidator.java deleted file mode 100755 index 36e4bdc8..00000000 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/validators/ClustererValidator.java +++ /dev/null @@ -1,215 +0,0 @@ -/** - * Copyright (C) 2013-2016 Vasilis Vryniotis - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.datumbox.framework.core.machinelearning.validators; - -import com.datumbox.framework.common.dataobjects.Dataframe; -import com.datumbox.framework.common.dataobjects.Record; -import com.datumbox.framework.core.machinelearning.common.abstracts.validators.AbstractValidator; - -import java.util.*; - -/** - * Estimates validation metrics for Clustering models. - * - * @author Vasilis Vryniotis - */ -public class ClustererValidator extends AbstractValidator { - - /** - * - * {@inheritDoc} - * - * References: - * http://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html - * http://thesis.neminis.org/wp-content/plugins/downloads-manager/upload/masterThesis-VR.pdf - */ - public static class ValidationMetrics extends AbstractValidator.AbstractValidationMetrics { - - private Double purity = null; - private Double NMI = null; //Normalized Mutual Information: I(Omega,Gama) calculation - - /** - * Getter for Purity. - * - * @return - */ - public Double getPurity() { - return purity; - } - - /** - * Setter for Purity. - * - * @param purity - */ - public void setPurity(Double purity) { - this.purity = purity; - } - - /** - * Getter for NMI. - * - * @return - */ - public Double getNMI() { - return NMI; - } - - /** - * Setter for NMI. 
- * - * @param NMI - */ - public void setNMI(Double NMI) { - this.NMI = NMI; - } - - } - - - /** {@inheritDoc} */ - @Override - public ValidationMetrics validate(Dataframe predictedData) { - int n = predictedData.size(); - - Set clusterIdSet = new HashSet<>(); - Set goldStandardClassesSet = new HashSet<>(); - for(Record r : predictedData) { - Object y = r.getY(); - if(y != null) { - goldStandardClassesSet.add(y); - } - clusterIdSet.add(r.getYPredicted()); - } - - //create new validation metrics object - ValidationMetrics validationMetrics = new ValidationMetrics(); - - if(goldStandardClassesSet.isEmpty()) { - return validationMetrics; - } - - //We don't store the Contingency Table because we can't average it with - //k-cross fold validation. Each clustering produces a different number - //of clusters and thus different enumeration. Thus averaging the results - //is impossible and that is why we don't store it in the validation object. - - //List = [Clusterid,GoldStandardClass] - Map, Double> ctMap = new HashMap<>(); - - //frequency tables - Map countOfW = new HashMap<>(); //this is small equal to number of clusters - Map countOfC = new HashMap<>(); //this is small equal to number of classes - - //initialize the tables with zeros - for(Object clusterId : clusterIdSet) { - countOfW.put(clusterId, 0.0); - for(Object theClass : goldStandardClassesSet) { - ctMap.put(Arrays.asList(clusterId, theClass), 0.0); - - countOfC.put(theClass, 0.0); - } - } - - //count the co-occurrences of ClusterId-GoldStanardClass - for(Record r : predictedData) { - Object clusterId = r.getYPredicted(); //fetch cluster assignment - Object goldStandardClass = r.getY(); //the original class of the objervation - List tpk = Arrays.asList(clusterId, goldStandardClass); - ctMap.put(tpk, ctMap.get(tpk) + 1.0); - - //update cluster and class counts - countOfW.put(clusterId, countOfW.get(clusterId)+1.0); - countOfC.put(goldStandardClass, countOfC.get(goldStandardClass)+1.0); - } - - double logN = 
Math.log((double)n); - double purity=0.0; - double Iwc=0.0; //http://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html - for(Object clusterId : clusterIdSet) { - double maxCounts=Double.NEGATIVE_INFINITY; - - //loop through the possible classes and find the most popular one - for(Object goldStandardClass : goldStandardClassesSet) { - List tpk = Arrays.asList(clusterId, goldStandardClass); - double Nwc = ctMap.get(tpk); - if(Nwc>maxCounts) { - maxCounts=Nwc; - } - - if(Nwc>0) { - Iwc+= (Nwc/n)*(Math.log(Nwc) -Math.log(countOfC.get(goldStandardClass)) - -Math.log(countOfW.get(clusterId)) + logN); - } - } - purity += maxCounts; - } - //ctMap = null; - purity/=n; - - validationMetrics.setPurity(purity); - - double entropyW=0.0; - for(Double Nw : countOfW.values()) { - entropyW-=(Nw/n)*(Math.log(Nw)-logN); - } - - double entropyC=0.0; - for(Double Nc : countOfW.values()) { - entropyC-=(Nc/n)*(Math.log(Nc)-logN); - } - - validationMetrics.setNMI(Iwc/((entropyW+entropyC)/2.0)); - - return validationMetrics; - } - - /** {@inheritDoc} */ - @Override - public ValidationMetrics average(List validationMetricsList) { - if(validationMetricsList.isEmpty()) { - return null; - } - - int k = validationMetricsList.size(); //number of samples - - - //create a new empty ValidationMetrics Object - ValidationMetrics avgValidationMetrics = new ValidationMetrics(); - - //estimate average values - for(ValidationMetrics vmSample : validationMetricsList) { - if(vmSample.getNMI()==null) { //it is null when we don't have goldStandardClass information - continue; // - } - - //update metrics - Double NMI = avgValidationMetrics.getNMI(); - if(NMI==null) { - NMI = 0.0; - } - avgValidationMetrics.setNMI(NMI+ vmSample.getNMI()/k); - Double purity = avgValidationMetrics.getPurity(); - if(purity==null) { - purity = 0.0; - } - avgValidationMetrics.setPurity(purity+ vmSample.getPurity()/k); - } - - - return avgValidationMetrics; - } -} diff --git 
a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/validators/LinearRegressionValidator.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/validators/LinearRegressionValidator.java deleted file mode 100755 index baae1726..00000000 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/validators/LinearRegressionValidator.java +++ /dev/null @@ -1,392 +0,0 @@ -/** - * Copyright (C) 2013-2016 Vasilis Vryniotis - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.datumbox.framework.core.machinelearning.validators; - -import com.datumbox.framework.common.dataobjects.Dataframe; -import com.datumbox.framework.common.dataobjects.FlatDataList; -import com.datumbox.framework.common.dataobjects.Record; -import com.datumbox.framework.common.dataobjects.TypeInference; -import com.datumbox.framework.core.machinelearning.common.abstracts.validators.AbstractValidator; -import com.datumbox.framework.core.statistics.distributions.ContinuousDistributions; -import com.datumbox.framework.core.statistics.nonparametrics.onesample.Lilliefors; -import com.datumbox.framework.core.statistics.parametrics.onesample.DurbinWatson; - -import java.util.List; - -/** - * Validation class for Linear Regression. 
- * - * @author Vasilis Vryniotis - */ -public class LinearRegressionValidator extends AbstractValidator { - - /** {@inheritDoc} */ - public static class ValidationMetrics extends AbstractValidator.AbstractValidationMetrics { - private double RSquare = 0.0; - private double RSquareAdjusted = 0.0; - private double SSE = 0.0; - private double SSR = 0.0; - private double SST = 0.0; - private double dfRegression = 0.0; - private double dfResidual = 0.0; - private double dfTotal = 0.0; - private double F = 0.0; - private double FPValue = 0.0; - private Double StdErrorOfEstimate = 0.0; //this can have null value if dfResidual is 0 - private double DW = 0.0; //Durbin–Watson statistic - private double NormalResiduals = 0.0; //Test on whether the residuals can be considered Normal - - /** - * Getter for the R Square. - * - * @return - */ - public double getRSquare() { - return RSquare; - } - - /** - * Setter for the R Square. - * - * @param RSquare - */ - public void setRSquare(double RSquare) { - this.RSquare = RSquare; - } - - /** - * Getter for the R Square Adjusted. - * - * @return - */ - public double getRSquareAdjusted() { - return RSquareAdjusted; - } - - /** - * Setter for the R Square Adjusted. - * - * @param RSquareAdjusted - */ - public void setRSquareAdjusted(double RSquareAdjusted) { - this.RSquareAdjusted = RSquareAdjusted; - } - - /** - * Getter for the Sum of Squared Errors. - * - * @return - */ - public double getSSE() { - return SSE; - } - - /** - * Setter for the Sum of Squared Errors. - * - * @param SSE - */ - public void setSSE(double SSE) { - this.SSE = SSE; - } - - /** - * Getter for the Sum of Squared due to Regression. - * - * @return - */ - public double getSSR() { - return SSR; - } - - /** - * Setter for the Sum of Squared due to Regression. - * - * @param SSR - */ - public void setSSR(double SSR) { - this.SSR = SSR; - } - - /** - * Getter for the Sum of Squared Total. 
- * - * @return - */ - public double getSST() { - return SST; - } - - /** - * Setter for the Sum of Squared Total. - * - * @param SST - */ - public void setSST(double SST) { - this.SST = SST; - } - - /** - * Getter for the degrees of freedom of Regression. - * - * @return - */ - public double getDfRegression() { - return dfRegression; - } - - /** - * Setter for the degrees of freedom of Regression. - * - * @param dfRegression - */ - public void setDfRegression(double dfRegression) { - this.dfRegression = dfRegression; - } - - /** - * Getter for the degrees of freedom of Residual. - * - * @return - */ - public double getDfResidual() { - return dfResidual; - } - - /** - * Setter for the degrees of freedom of Residual. - * - * @param dfResidual - */ - public void setDfResidual(double dfResidual) { - this.dfResidual = dfResidual; - } - - /** - * Getter for the degrees of freedom of Total. - * - * @return - */ - public double getDfTotal() { - return dfTotal; - } - - /** - * Setter for the degrees of freedom of Total. - * - * @param dfTotal - */ - public void setDfTotal(double dfTotal) { - this.dfTotal = dfTotal; - } - - /** - * Getter for F score. - * - * @return - */ - public double getF() { - return F; - } - - /** - * Setter for F score. - * - * @param F - */ - public void setF(double F) { - this.F = F; - } - - /** - * Getter for F p-value. - * - * @return - */ - public double getFPValue() { - return FPValue; - } - - /** - * Setter for F p-value. - * - * @param FPValue - */ - public void setFPValue(double FPValue) { - this.FPValue = FPValue; - } - - /** - * Getter for Standard Error of Estimate. - * - * @return - */ - public Double getStdErrorOfEstimate() { - return StdErrorOfEstimate; - } - - /** - * Setter for Standard Error of Estimate. - * - * @param StdErrorOfEstimate - */ - public void setStdErrorOfEstimate(Double StdErrorOfEstimate) { - this.StdErrorOfEstimate = StdErrorOfEstimate; - } - - /** - * Getter of Durbin Watson statistic. 
- * - * @return - */ - public double getDW() { - return DW; - } - - /** - * Setter of Durbin Watson statistic. - * - * @param DW - */ - public void setDW(double DW) { - this.DW = DW; - } - - /** - * Getter for Normal Residuals. - * - * @return - */ - public double getNormalResiduals() { - return NormalResiduals; - } - - /** - * Setter for Normal Residuals. - * - * @param NormalResiduals - */ - public void setNormalResiduals(double NormalResiduals) { - this.NormalResiduals = NormalResiduals; - } - - } - - /** {@inheritDoc} */ - @Override - public ValidationMetrics validate(Dataframe predictedData) { - //create new validation metrics object - ValidationMetrics validationMetrics = new ValidationMetrics(); - - int n = predictedData.size(); - - FlatDataList errorList = new FlatDataList(); - double Ybar = 0.0; - for(Record r : predictedData) { - Ybar += TypeInference.toDouble(r.getY())/n; - errorList.add(TypeInference.toDouble(r.getY())-TypeInference.toDouble(r.getYPredicted())); - } - - validationMetrics.setDW(DurbinWatson.calculateScore(errorList)); //autocorrelation metric (around 2 no autocorrelation) - - double SSE = 0.0; - for(Record r : predictedData) { - SSE += Math.pow(TypeInference.toDouble(r.getY())-TypeInference.toDouble(r.getYPredicted()), 2.0); - } - validationMetrics.setSSE(SSE); - - boolean normalResiduals = Lilliefors.test(errorList.toFlatDataCollection(), "normalDistribution", 0.05); - validationMetrics.setNormalResiduals( (normalResiduals)?0.0:1.0 ); //if the Lilliefors validate rejects the H0 means that the normality hypothesis is rejected thus the residuals are not normal - //errorList = null; - - double SSR = 0.0; - for(Record r : predictedData) { - SSR += Math.pow(TypeInference.toDouble(r.getY()) - Ybar, 2); - } - validationMetrics.setSSR(SSR); - - double SST = SSR+SSE; - validationMetrics.setSST(SST); - - double RSquare = SSR/SST; - validationMetrics.setRSquare(RSquare); - - int d = predictedData.xColumnSize()+1;//add one for the constant - int p 
= d - 1; //exclude constant - - double RSquareAdjusted = 1.0 - ((n-1.0)/(n-p-1.0))*(1.0-RSquare); - validationMetrics.setRSquareAdjusted(RSquareAdjusted); - - //degrees of freedom - double dfTotal = n-1.0; - validationMetrics.setDfTotal(dfTotal); - double dfRegression = d-1.0; - validationMetrics.setDfRegression(dfRegression); - double dfResidual = Math.max(n-d, 0.0); - validationMetrics.setDfResidual(dfResidual); - - double F = (SSR/dfRegression)/(SSE/dfResidual); - validationMetrics.setF(F); - - double FPValue = 1.0; - if(n>d) { - FPValue = ContinuousDistributions.fCdf(F, (int)dfRegression, (int)dfResidual); - } - validationMetrics.setFPValue(FPValue); - - Double StdErrorOfEstimate = null; - if(dfResidual>0) { - StdErrorOfEstimate = Math.sqrt(SSE/dfResidual); - } - validationMetrics.setStdErrorOfEstimate(StdErrorOfEstimate); - - return validationMetrics; - } - - /** {@inheritDoc} */ - @Override - public ValidationMetrics average(List validationMetricsList) { - - if(validationMetricsList.isEmpty()) { - return null; - } - - ValidationMetrics avgValidationMetrics = new ValidationMetrics(); - - int k = validationMetricsList.size(); //number of samples - for(ValidationMetrics vmSample : validationMetricsList) { - avgValidationMetrics.setRSquare(avgValidationMetrics.getRSquare() + vmSample.getRSquare()/k); - avgValidationMetrics.setRSquareAdjusted(avgValidationMetrics.getRSquareAdjusted() + vmSample.getRSquareAdjusted()/k); - avgValidationMetrics.setSSE(avgValidationMetrics.getSSE() + vmSample.getSSE()/k); - avgValidationMetrics.setSSR(avgValidationMetrics.getSSR() + vmSample.getSSR()/k); - avgValidationMetrics.setSST(avgValidationMetrics.getSST() + vmSample.getSST()/k); - avgValidationMetrics.setDfRegression(avgValidationMetrics.getDfRegression() + vmSample.getDfRegression()/k); - avgValidationMetrics.setDfResidual(avgValidationMetrics.getDfResidual() + vmSample.getDfResidual()/k); - avgValidationMetrics.setDfTotal(avgValidationMetrics.getDfTotal() + 
vmSample.getDfTotal()/k); - avgValidationMetrics.setF(avgValidationMetrics.getF() + vmSample.getF()/k); - avgValidationMetrics.setFPValue(avgValidationMetrics.getFPValue() + vmSample.getFPValue()/k); - Double stdErrorOfEstimate = vmSample.getStdErrorOfEstimate(); - if(stdErrorOfEstimate==null) { - stdErrorOfEstimate=0.0; - } - avgValidationMetrics.setStdErrorOfEstimate(avgValidationMetrics.getStdErrorOfEstimate() + stdErrorOfEstimate/k); - avgValidationMetrics.setDW(avgValidationMetrics.getDW() + vmSample.getDW()/k); - avgValidationMetrics.setNormalResiduals(avgValidationMetrics.getNormalResiduals() + vmSample.getNormalResiduals()/k); //percentage of samples that found the residuals to be normal - } - - return avgValidationMetrics; - } -} diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/validators/RMSEValidator.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/validators/RMSEValidator.java deleted file mode 100644 index 197329b3..00000000 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/validators/RMSEValidator.java +++ /dev/null @@ -1,99 +0,0 @@ -/** - * Copyright (C) 2013-2016 Vasilis Vryniotis - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.datumbox.framework.core.machinelearning.validators; - -import com.datumbox.framework.common.dataobjects.AssociativeArray; -import com.datumbox.framework.common.dataobjects.Dataframe; -import com.datumbox.framework.common.dataobjects.Record; -import com.datumbox.framework.common.dataobjects.TypeInference; -import com.datumbox.framework.core.machinelearning.common.abstracts.validators.AbstractValidator; -import com.datumbox.framework.core.machinelearning.recommendersystem.CollaborativeFiltering; - -import java.util.*; - -/** - * Estimates RMSE for ML models. - * - * @author Vasilis Vryniotis - */ -public class RMSEValidator extends AbstractValidator { - - /** {@inheritDoc} */ - public static class ValidationMetrics extends AbstractValidator.AbstractValidationMetrics { - - //validation metrics - private double RMSE = 0.0; - - /** - * Getter for RMSE. - * - * @return - */ - public double getRMSE() { - return RMSE; - } - - /** - * Setter for RMSE. - * - * @param RMSE - */ - public void setRMSE(double RMSE) { - this.RMSE = RMSE; - } - - } - - /** {@inheritDoc} */ - @Override - public ValidationMetrics validate(Dataframe predictedData) { - ValidationMetrics validationMetrics = new ValidationMetrics(); - - double RMSE = 0.0; - int i = 0; - for(Record r : predictedData) { - AssociativeArray predictions = r.getYPredictedProbabilities(); - for(Map.Entry entry : r.getX().entrySet()) { - Object column = entry.getKey(); - Object value = entry.getValue(); - RMSE += Math.pow(TypeInference.toDouble(value)-TypeInference.toDouble(predictions.get(column)), 2.0); - ++i; - } - } - - RMSE = Math.sqrt(RMSE/i); - validationMetrics.setRMSE(RMSE); - - return validationMetrics; - } - - /** {@inheritDoc} */ - @Override - public ValidationMetrics average(List validationMetricsList) { - if(validationMetricsList.isEmpty()) { - return null; - } - - ValidationMetrics avgValidationMetrics = new ValidationMetrics(); - - int k = validationMetricsList.size(); //number of samples - 
for(ValidationMetrics vmSample : validationMetricsList) { - avgValidationMetrics.setRMSE(avgValidationMetrics.getRMSE() + vmSample.getRMSE()/k); - } - - return avgValidationMetrics; - } -} diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/utilities/text/extractors/NgramsExtractor.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/utilities/text/extractors/NgramsExtractor.java index 4ad8cfa5..070f52f4 100755 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/utilities/text/extractors/NgramsExtractor.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/utilities/text/extractors/NgramsExtractor.java @@ -20,7 +20,7 @@ /** * The NgramsExtractor class can be used to tokenize a string, extract its keyword * combinations and estimate their occurrence scores in the original string. This - * extractor is ideal for the feature extraction phase of Text Classification. + * extractor is ideal for the feature extraction phase of Text Classification.
* * @author Vasilis Vryniotis */ diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/BernoulliNaiveBayesTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/BernoulliNaiveBayesTest.java index a7297e3f..90406afb 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/BernoulliNaiveBayesTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/BernoulliNaiveBayesTest.java @@ -18,7 +18,7 @@ import com.datumbox.framework.common.Configuration; import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.common.dataobjects.Record; -import com.datumbox.framework.core.machinelearning.validators.ClassifierValidator; +import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClassificationMetrics; import com.datumbox.framework.tests.Constants; import com.datumbox.framework.tests.Datasets; import com.datumbox.framework.tests.abstracts.AbstractTest; @@ -83,11 +83,11 @@ public void testValidate() { } /** - * Test of kFoldCrossValidation method, of class BernoulliNaiveBayes. + * Test of validate method, of class BernoulliNaiveBayes. 
*/ @Test public void testKFoldCrossValidation() { - logger.info("kFoldCrossValidation"); + logger.info("validate"); Configuration conf = Configuration.getConfiguration(); @@ -103,7 +103,7 @@ public void testKFoldCrossValidation() { BernoulliNaiveBayes.TrainingParameters param = new BernoulliNaiveBayes.TrainingParameters(); - ClassifierValidator.ValidationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); + ClassificationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); double expResult = 0.6631318681318682; double result = vm.getMacroF1(); diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/BinarizedNaiveBayesTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/BinarizedNaiveBayesTest.java index f152f7b6..66855ac9 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/BinarizedNaiveBayesTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/BinarizedNaiveBayesTest.java @@ -18,7 +18,7 @@ import com.datumbox.framework.common.Configuration; import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.common.dataobjects.Record; -import com.datumbox.framework.core.machinelearning.validators.ClassifierValidator; +import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClassificationMetrics; import com.datumbox.framework.tests.Constants; import com.datumbox.framework.tests.Datasets; import com.datumbox.framework.tests.abstracts.AbstractTest; @@ -84,11 +84,11 @@ public void testValidate() { /** - * Test of kFoldCrossValidation method, of class BinarizedNaiveBayes. + * Test of validate method, of class BinarizedNaiveBayes. 
*/ @Test public void testKFoldCrossValidation() { - logger.info("kFoldCrossValidation"); + logger.info("validate"); Configuration conf = Configuration.getConfiguration(); @@ -104,7 +104,7 @@ public void testKFoldCrossValidation() { BinarizedNaiveBayes.TrainingParameters param = new BinarizedNaiveBayes.TrainingParameters(); - ClassifierValidator.ValidationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); + ClassificationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); double expResult = 0.6631318681318682; double result = vm.getMacroF1(); diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/MaximumEntropyTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/MaximumEntropyTest.java index a271757d..7d6c2117 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/MaximumEntropyTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/MaximumEntropyTest.java @@ -18,7 +18,7 @@ import com.datumbox.framework.common.Configuration; import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.common.dataobjects.Record; -import com.datumbox.framework.core.machinelearning.validators.ClassifierValidator; +import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClassificationMetrics; import com.datumbox.framework.tests.Constants; import com.datumbox.framework.tests.Datasets; import com.datumbox.framework.tests.abstracts.AbstractTest; @@ -84,11 +84,11 @@ public void testValidate() { /** - * Test of kFoldCrossValidation method, of class MaximumEntropy. + * Test of validate method, of class MaximumEntropy. 
*/ @Test public void testKFoldCrossValidation() { - logger.info("kFoldCrossValidation"); + logger.info("validate"); Configuration conf = Configuration.getConfiguration(); @@ -105,7 +105,7 @@ public void testKFoldCrossValidation() { MaximumEntropy.TrainingParameters param = new MaximumEntropy.TrainingParameters(); param.setTotalIterations(10); - ClassifierValidator.ValidationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); + ClassificationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); double expResult = 0.6051098901098901; double result = vm.getMacroF1(); diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/MultinomialNaiveBayesTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/MultinomialNaiveBayesTest.java index 31d8daba..de59b94c 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/MultinomialNaiveBayesTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/MultinomialNaiveBayesTest.java @@ -19,7 +19,7 @@ import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.common.dataobjects.Record; import com.datumbox.framework.core.machinelearning.datatransformation.DummyXYMinMaxNormalizer; -import com.datumbox.framework.core.machinelearning.validators.ClassifierValidator; +import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClassificationMetrics; import com.datumbox.framework.tests.Constants; import com.datumbox.framework.tests.Datasets; import com.datumbox.framework.tests.abstracts.AbstractTest; @@ -100,11 +100,11 @@ public void testValidate() { /** - * Test of kFoldCrossValidation method, of class MultinomialNaiveBayes. + * Test of validate method, of class MultinomialNaiveBayes. 
*/ @Test public void testKFoldCrossValidation() { - logger.info("kFoldCrossValidation"); + logger.info("validate"); Configuration conf = Configuration.getConfiguration(); @@ -121,7 +121,7 @@ public void testKFoldCrossValidation() { MultinomialNaiveBayes.TrainingParameters param = new MultinomialNaiveBayes.TrainingParameters(); param.setMultiProbabilityWeighted(true); - ClassifierValidator.ValidationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); + ClassificationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); double expResult = 0.6631318681318682; double result = vm.getMacroF1(); diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/OrdinalRegressionTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/OrdinalRegressionTest.java index d086cbe8..ee3cbd80 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/OrdinalRegressionTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/OrdinalRegressionTest.java @@ -19,7 +19,7 @@ import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.common.dataobjects.Record; import com.datumbox.framework.core.machinelearning.datatransformation.DummyXMinMaxNormalizer; -import com.datumbox.framework.core.machinelearning.validators.ClassifierValidator; +import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClassificationMetrics; import com.datumbox.framework.tests.Constants; import com.datumbox.framework.tests.Datasets; import com.datumbox.framework.tests.abstracts.AbstractTest; @@ -100,11 +100,11 @@ public void testValidate() { /** - * Test of kFoldCrossValidation method, of class OrdinalRegression. + * Test of validate method, of class OrdinalRegression. 
*/ @Test public void testKFoldCrossValidation() { - logger.info("kFoldCrossValidation"); + logger.info("validate"); Configuration conf = Configuration.getConfiguration(); @@ -126,7 +126,7 @@ public void testKFoldCrossValidation() { param.setTotalIterations(100); param.setL2(0.001); - ClassifierValidator.ValidationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); + ClassificationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); df.denormalize(trainingData); diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/SoftMaxRegressionTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/SoftMaxRegressionTest.java index 0d20346f..85478caf 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/SoftMaxRegressionTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/SoftMaxRegressionTest.java @@ -20,7 +20,7 @@ import com.datumbox.framework.common.dataobjects.Record; import com.datumbox.framework.core.machinelearning.datatransformation.DummyXYMinMaxNormalizer; import com.datumbox.framework.core.machinelearning.datatransformation.XMinMaxNormalizer; -import com.datumbox.framework.core.machinelearning.validators.ClassifierValidator; +import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClassificationMetrics; import com.datumbox.framework.tests.Constants; import com.datumbox.framework.tests.Datasets; import com.datumbox.framework.tests.abstracts.AbstractTest; @@ -102,11 +102,11 @@ public void testValidate() { /** - * Test of kFoldCrossValidation method, of class SoftMaxRegression. + * Test of validate method, of class SoftMaxRegression. 
*/ @Test public void testKFoldCrossValidation() { - logger.info("kFoldCrossValidation"); + logger.info("validate"); Configuration conf = Configuration.getConfiguration(); @@ -128,7 +128,7 @@ public void testKFoldCrossValidation() { param.setL1(0.0001); param.setL2(0.0001); - ClassifierValidator.ValidationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); + ClassificationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); df.denormalize(trainingData); diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/SupportVectorMachineTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/SupportVectorMachineTest.java index f8f14d25..e4f25ae9 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/SupportVectorMachineTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/classification/SupportVectorMachineTest.java @@ -19,7 +19,7 @@ import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.common.dataobjects.Record; import com.datumbox.framework.core.machinelearning.datatransformation.DummyXYMinMaxNormalizer; -import com.datumbox.framework.core.machinelearning.validators.ClassifierValidator; +import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClassificationMetrics; import com.datumbox.framework.tests.Constants; import com.datumbox.framework.tests.Datasets; import com.datumbox.framework.tests.abstracts.AbstractTest; @@ -100,11 +100,11 @@ public void testValidate() { /** - * Test of kFoldCrossValidation method, of class SupportVectorMachine. + * Test of validate method, of class SupportVectorMachine. 
*/ @Test public void testKFoldCrossValidation() { - logger.info("kFoldCrossValidation"); + logger.info("validate"); Configuration conf = Configuration.getConfiguration(); int k = 5; @@ -120,7 +120,7 @@ public void testKFoldCrossValidation() { SupportVectorMachine.TrainingParameters param = new SupportVectorMachine.TrainingParameters(); param.getSvmParameter().kernel_type = svm_parameter.LINEAR; - ClassifierValidator.ValidationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); + ClassificationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); double expResult = 0.6473992673992675; double result = vm.getMacroF1(); diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/GaussianDPMMTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/GaussianDPMMTest.java index 369f878d..ca1fc504 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/GaussianDPMMTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/GaussianDPMMTest.java @@ -17,16 +17,12 @@ import com.datumbox.framework.common.Configuration; import com.datumbox.framework.common.dataobjects.Dataframe; -import com.datumbox.framework.common.dataobjects.Record; -import com.datumbox.framework.core.machinelearning.validators.ClustererValidator; +import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClusteringMetrics; import com.datumbox.framework.tests.Constants; import com.datumbox.framework.tests.Datasets; import com.datumbox.framework.tests.abstracts.AbstractTest; import org.junit.Test; -import java.util.HashMap; -import java.util.Map; - import static org.junit.Assert.assertEquals; /** @@ -69,7 +65,7 @@ public void testValidate() { //instance = null; instance = new GaussianDPMM(dbName, conf); - ClustererValidator.ValidationMetrics vm = 
instance.validate(validationData); + ClusteringMetrics vm = instance.validate(validationData); double expResult = 1.0; double result = vm.getPurity(); @@ -83,11 +79,11 @@ public void testValidate() { /** - * Test of kFoldCrossValidation method, of class GaussianDPMM. + * Test of validate method, of class GaussianDPMM. */ @Test public void testKFoldCrossValidation() { - logger.info("kFoldCrossValidation"); + logger.info("validate"); Configuration conf = Configuration.getConfiguration(); @@ -110,7 +106,7 @@ public void testKFoldCrossValidation() { param.setMu0(new double[]{0.0, 0.0}); param.setPsi0(new double[][]{{1.0,0.0},{0.0,1.0}}); - ClustererValidator.ValidationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); + ClusteringMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); double expResult = 1.0; diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/HierarchicalAgglomerativeTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/HierarchicalAgglomerativeTest.java index 0ffe05ab..49eb57b3 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/HierarchicalAgglomerativeTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/HierarchicalAgglomerativeTest.java @@ -17,17 +17,13 @@ import com.datumbox.framework.common.Configuration; import com.datumbox.framework.common.dataobjects.Dataframe; -import com.datumbox.framework.common.dataobjects.Record; import com.datumbox.framework.core.machinelearning.datatransformation.DummyXYMinMaxNormalizer; -import com.datumbox.framework.core.machinelearning.validators.ClustererValidator; +import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClusteringMetrics; import com.datumbox.framework.tests.Constants; import com.datumbox.framework.tests.Datasets; import 
com.datumbox.framework.tests.abstracts.AbstractTest; import org.junit.Test; -import java.util.HashMap; -import java.util.Map; - import static org.junit.Assert.assertEquals; /** @@ -78,7 +74,7 @@ public void testValidate() { df = new DummyXYMinMaxNormalizer(dbName, conf); instance = new HierarchicalAgglomerative(dbName, conf); - ClustererValidator.ValidationMetrics vm = instance.validate(validationData); + ClusteringMetrics vm = instance.validate(validationData); df.denormalize(trainingData); df.denormalize(validationData); @@ -96,11 +92,11 @@ public void testValidate() { /** - * Test of kFoldCrossValidation method, of class HierarchicalAgglomerative. + * Test of validate method, of class HierarchicalAgglomerative. */ @Test public void testKFoldCrossValidation() { - logger.info("kFoldCrossValidation"); + logger.info("validate"); Configuration conf = Configuration.getConfiguration(); @@ -126,7 +122,7 @@ public void testKFoldCrossValidation() { param.setMinClustersThreshold(2); param.setMaxDistanceThreshold(Double.MAX_VALUE); - ClustererValidator.ValidationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); + ClusteringMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); df.denormalize(trainingData); diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/KmeansTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/KmeansTest.java index 106042b7..a85acc4b 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/KmeansTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/KmeansTest.java @@ -18,15 +18,12 @@ import com.datumbox.framework.common.Configuration; import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.core.machinelearning.datatransformation.DummyXYMinMaxNormalizer; -import 
com.datumbox.framework.core.machinelearning.validators.ClustererValidator; +import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClusteringMetrics; import com.datumbox.framework.tests.Constants; import com.datumbox.framework.tests.Datasets; import com.datumbox.framework.tests.abstracts.AbstractTest; import org.junit.Test; -import java.util.HashMap; -import java.util.Map; - import static org.junit.Assert.assertEquals; /** @@ -81,7 +78,7 @@ public void testValidate() { df = new DummyXYMinMaxNormalizer(dbName, conf); instance = new Kmeans(dbName, conf); - ClustererValidator.ValidationMetrics vm = instance.validate(validationData); + ClusteringMetrics vm = instance.validate(validationData); df.denormalize(trainingData); df.denormalize(validationData); @@ -99,11 +96,11 @@ public void testValidate() { /** - * Test of kFoldCrossValidation method, of class Kmeans. + * Test of validate method, of class Kmeans. */ @Test public void testKFoldCrossValidation() { - logger.info("kFoldCrossValidation"); + logger.info("validate"); Configuration conf = Configuration.getConfiguration(); @@ -133,7 +130,7 @@ public void testKFoldCrossValidation() { param.setCategoricalGamaMultiplier(1.0); param.setSubsetFurthestFirstcValue(2.0); - ClustererValidator.ValidationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); + ClusteringMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); df.denormalize(trainingData); diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/MultinomialDPMMTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/MultinomialDPMMTest.java index 11b586b3..a799b4cc 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/MultinomialDPMMTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/clustering/MultinomialDPMMTest.java @@ -17,16 
+17,13 @@ import com.datumbox.framework.common.Configuration; import com.datumbox.framework.common.dataobjects.Dataframe; -import com.datumbox.framework.core.machinelearning.validators.ClustererValidator; +import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClusteringMetrics; import com.datumbox.framework.tests.Constants; import com.datumbox.framework.tests.Datasets; import com.datumbox.framework.tests.abstracts.AbstractTest; import org.junit.Assert; import org.junit.Test; -import java.util.HashMap; -import java.util.Map; - import static org.junit.Assert.assertEquals; /** @@ -66,7 +63,7 @@ public void testValidate() { //instance = null; instance = new MultinomialDPMM(dbName, conf); - ClustererValidator.ValidationMetrics vm = instance.validate(validationData); + ClusteringMetrics vm = instance.validate(validationData); double expResult = 1.0; double result = vm.getPurity(); @@ -80,11 +77,11 @@ public void testValidate() { /** - * Test of kFoldCrossValidation method, of class MultinomialDPMM. + * Test of validate method, of class MultinomialDPMM. 
*/ @Test public void testKFoldCrossValidation() { - logger.info("kFoldCrossValidation"); + logger.info("validate"); Configuration conf = Configuration.getConfiguration(); @@ -104,7 +101,7 @@ public void testKFoldCrossValidation() { param.setInitializationMethod(MultinomialDPMM.TrainingParameters.Initialization.ONE_CLUSTER_PER_RECORD); param.setAlphaWords(1); - ClustererValidator.ValidationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); + ClusteringMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); double expResult = 1.0; diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/ensemblelearning/AdaboostTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/ensemblelearning/AdaboostTest.java index 4a054f80..53fb9a9e 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/ensemblelearning/AdaboostTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/ensemblelearning/AdaboostTest.java @@ -20,7 +20,7 @@ import com.datumbox.framework.common.dataobjects.Record; import com.datumbox.framework.core.machinelearning.classification.MultinomialNaiveBayes; import com.datumbox.framework.core.machinelearning.datatransformation.DummyXYMinMaxNormalizer; -import com.datumbox.framework.core.machinelearning.validators.ClassifierValidator; +import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClassificationMetrics; import com.datumbox.framework.tests.Constants; import com.datumbox.framework.tests.Datasets; import com.datumbox.framework.tests.abstracts.AbstractTest; @@ -113,11 +113,11 @@ public void testValidate() { /** - * Test of kFoldCrossValidation method, of class Adaboost. + * Test of validate method, of class Adaboost. 
*/ @Test public void testKFoldCrossValidation() { - logger.info("kFoldCrossValidation"); + logger.info("validate"); Configuration conf = Configuration.getConfiguration(); @@ -142,7 +142,7 @@ public void testKFoldCrossValidation() { param.setWeakClassifierTrainingParameters(trainingParameters); - ClassifierValidator.ValidationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); + ClassificationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); double expResult = 0.6923992673992675; double result = vm.getMacroF1(); diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/ensemblelearning/BootstrapAggregatingTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/ensemblelearning/BootstrapAggregatingTest.java index 2bce8c61..437ffc94 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/ensemblelearning/BootstrapAggregatingTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/ensemblelearning/BootstrapAggregatingTest.java @@ -20,7 +20,7 @@ import com.datumbox.framework.common.dataobjects.Record; import com.datumbox.framework.core.machinelearning.classification.MultinomialNaiveBayes; import com.datumbox.framework.core.machinelearning.datatransformation.DummyXYMinMaxNormalizer; -import com.datumbox.framework.core.machinelearning.validators.ClassifierValidator; +import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClassificationMetrics; import com.datumbox.framework.tests.Constants; import com.datumbox.framework.tests.Datasets; import com.datumbox.framework.tests.abstracts.AbstractTest; @@ -109,11 +109,11 @@ public void testValidate() { /** - * Test of kFoldCrossValidation method, of class BootstrapAggregating. + * Test of validate method, of class BootstrapAggregating. 
*/ @Test public void testKFoldCrossValidation() { - logger.info("kFoldCrossValidation"); + logger.info("validate"); Configuration conf = Configuration.getConfiguration(); @@ -138,7 +138,7 @@ public void testKFoldCrossValidation() { param.setWeakClassifierTrainingParameters(trainingParameters); - ClassifierValidator.ValidationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); + ClassificationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); double expResult = 0.6609432234432234; double result = vm.getMacroF1(); diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/recommendersystem/CollaborativeFilteringTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/recommendersystem/CollaborativeFilteringTest.java index 4919a397..877b5a88 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/recommendersystem/CollaborativeFilteringTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/recommendersystem/CollaborativeFilteringTest.java @@ -19,7 +19,7 @@ import com.datumbox.framework.common.dataobjects.AssociativeArray; import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.common.dataobjects.TypeInference; -import com.datumbox.framework.core.machinelearning.validators.RMSEValidator; +import com.datumbox.framework.core.machinelearning.modelselection.metrics.RecommendationMetrics; import com.datumbox.framework.tests.Constants; import com.datumbox.framework.tests.Datasets; import com.datumbox.framework.tests.abstracts.AbstractTest; @@ -64,7 +64,7 @@ public void testValidate() { //instance = null; instance = new CollaborativeFiltering(dbName, conf); - RMSEValidator.ValidationMetrics vm = instance.validate(validationData); + RecommendationMetrics vm = instance.validate(validationData); Map expResult = new HashMap<>(); expResult.put("pitta", 
4.686394033077408); diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/regression/MatrixLinearRegressionTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/regression/MatrixLinearRegressionTest.java index aa5c33dc..03e0a0ce 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/regression/MatrixLinearRegressionTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/regression/MatrixLinearRegressionTest.java @@ -21,7 +21,7 @@ import com.datumbox.framework.common.dataobjects.TypeInference; import com.datumbox.framework.core.machinelearning.datatransformation.DummyXYMinMaxNormalizer; import com.datumbox.framework.core.machinelearning.datatransformation.XYMinMaxNormalizer; -import com.datumbox.framework.core.machinelearning.validators.LinearRegressionValidator; +import com.datumbox.framework.core.machinelearning.modelselection.metrics.LinearRegressionMetrics; import com.datumbox.framework.tests.Constants; import com.datumbox.framework.tests.Datasets; import com.datumbox.framework.tests.abstracts.AbstractTest; @@ -92,11 +92,11 @@ public void testValidate() { /** - * Test of kFoldCrossValidation method, of class MatrixLinearRegression. + * Test of validate method, of class MatrixLinearRegression. 
*/ @Test public void testKFoldCrossValidation() { - logger.info("kFoldCrossValidation"); + logger.info("validate"); Configuration conf = Configuration.getConfiguration(); @@ -116,7 +116,7 @@ public void testKFoldCrossValidation() { MatrixLinearRegression.TrainingParameters param = new MatrixLinearRegression.TrainingParameters(); - LinearRegressionValidator.ValidationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); + LinearRegressionMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); df.denormalize(trainingData); diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/regression/NLMSTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/regression/NLMSTest.java index 7e31b9d3..8a62588d 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/regression/NLMSTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/regression/NLMSTest.java @@ -21,7 +21,7 @@ import com.datumbox.framework.common.dataobjects.TypeInference; import com.datumbox.framework.core.machinelearning.datatransformation.DummyXYMinMaxNormalizer; import com.datumbox.framework.core.machinelearning.featureselection.continuous.PCA; -import com.datumbox.framework.core.machinelearning.validators.LinearRegressionValidator; +import com.datumbox.framework.core.machinelearning.modelselection.metrics.LinearRegressionMetrics; import com.datumbox.framework.tests.Constants; import com.datumbox.framework.tests.Datasets; import com.datumbox.framework.tests.abstracts.AbstractTest; @@ -93,11 +93,11 @@ public void testValidate() { /** - * Test of kFoldCrossValidation method, of class NLMS. + * Test of validate method, of class NLMS. 
*/ @Test public void testKFoldCrossValidation() { - logger.info("kFoldCrossValidation"); + logger.info("validate"); Configuration conf = Configuration.getConfiguration(); @@ -130,7 +130,7 @@ public void testKFoldCrossValidation() { param.setL1(0.001); param.setL2(0.001); - LinearRegressionValidator.ValidationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); + LinearRegressionMetrics vm = instance.kFoldCrossValidation(trainingData, param, k); df.denormalize(trainingData); diff --git a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/topicmodeling/LatentDirichletAllocationTest.java b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/topicmodeling/LatentDirichletAllocationTest.java index 8cec5970..78172c92 100755 --- a/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/topicmodeling/LatentDirichletAllocationTest.java +++ b/datumbox-framework-core/src/test/java/com/datumbox/framework/core/machinelearning/topicmodeling/LatentDirichletAllocationTest.java @@ -19,7 +19,7 @@ import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.common.dataobjects.Record; import com.datumbox.framework.core.machinelearning.classification.SoftMaxRegression; -import com.datumbox.framework.core.machinelearning.validators.ClassifierValidator; +import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClassificationMetrics; import com.datumbox.framework.core.utilities.text.extractors.UniqueWordSequenceExtractor; import com.datumbox.framework.tests.Constants; import com.datumbox.framework.tests.abstracts.AbstractTest; @@ -91,7 +91,7 @@ public void testValidate() { tp.setLearningRate(1.0); tp.setTotalIterations(50); - ClassifierValidator.ValidationMetrics vm = smr.kFoldCrossValidation(reducedTrainingData, tp, 1); + ClassificationMetrics vm = smr.kFoldCrossValidation(reducedTrainingData, tp, 1); double expResult = 0.6843125117743629; 
double result = vm.getMacroF1();