From 14975f4cedb1fe1a51723b8594a4b77c1e7acce1 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sun, 18 Dec 2016 21:25:44 +0000 Subject: [PATCH] Finalizing the validation mechanism. --- CHANGELOG.md | 2 +- TODO.txt | 3 +- .../modelers/AbstractClassifier.java | 4 +- .../abstracts/modelers/AbstractClusterer.java | 4 +- .../modelers/AbstractRecommender.java | 4 +- .../abstracts/modelers/AbstractRegressor.java | 4 +- .../metrics}/AbstractMetrics.java | 2 +- .../validators/AbstractValidator.java | 61 ++++++++++++++++++ .../metrics/ClassificationMetrics.java | 3 +- .../metrics/ClusteringMetrics.java | 3 +- .../metrics/LinearRegressionMetrics.java | 3 +- .../metrics/RecommendationMetrics.java | 3 +- .../KFoldValidator.java} | 63 +++++++++---------- 13 files changed, 110 insertions(+), 49 deletions(-) rename datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/{validators => modelselection/metrics}/AbstractMetrics.java (98%) create mode 100644 datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelselection/validators/AbstractValidator.java rename datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/{splitters/TemporaryKFold.java => validators/KFoldValidator.java} (71%) diff --git a/CHANGELOG.md b/CHANGELOG.md index ee172ca9..dd8f8321 100755 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,7 @@ Version 0.8.0-SNAPSHOT - Build 20161218 - Improved Validation: - Removed the ValidationMetrics from the Algorithms. Now it is a separate object. - - Removed the kFold validation from Algorithms. Now we offer a splitter mechanism. + - Removed the kFold validation from Algorithms. Now we offer a new validator mechanism. - A single KnowledgeBase implementation is now used. - Removed the unnecessary n & d model parameters from all models. diff --git a/TODO.txt b/TODO.txt index 2df48b41..4ed0a6db 100755 --- a/TODO.txt +++ b/TODO.txt @@ -1,12 +1,11 @@ CODE IMPROVEMENTS ================= -- All ValidationMetrics should hava a serialization number - Add save() load() methods in the models - The method kFoldCrossValidation is removed from AbstractValidator. It should be a separate class instead. - Support of better Transformers (Zscore, decouple boolean transforming from numeric etc). -- Write a ShuffleSplit class similar to KFold. +- Write a ShuffleSplitValidator class similar to KFold. - Write generic optimizers instead of having optimization methods in the algorithms. - Support MapDB 3.0 once a stable version is released. Remove the HOTFIX for MapDB bug #664. - Try to separate persistent storage in a separate module that is inherited by common. diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractClassifier.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractClassifier.java index 849f7c06..ce150203 100755 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractClassifier.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractClassifier.java @@ -21,7 +21,7 @@ import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector; import com.datumbox.framework.common.utilities.MapMethods; import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClassificationMetrics; -import com.datumbox.framework.core.machinelearning.modelselection.splitters.TemporaryKFold; +import com.datumbox.framework.core.machinelearning.modelselection.validators.KFoldValidator; import java.util.*; @@ -113,6 +113,6 @@ public ClassificationMetrics validate(Dataframe testingData) { public ClassificationMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) { logger.info("validate()"); - return new TemporaryKFold<>(ClassificationMetrics.class, k).validate(trainingData, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters); + return new KFoldValidator<>(ClassificationMetrics.class, knowledgeBase.getConf(), k).validate(trainingData, trainingParameters); } } \ No newline at end of file diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractClusterer.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractClusterer.java index e63ab165..7b4e526e 100755 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractClusterer.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractClusterer.java @@ -24,7 +24,7 @@ import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector.StorageHint; import com.datumbox.framework.core.machinelearning.common.interfaces.Cluster; import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClusteringMetrics; -import com.datumbox.framework.core.machinelearning.modelselection.splitters.TemporaryKFold; +import com.datumbox.framework.core.machinelearning.modelselection.validators.KFoldValidator; import java.util.*; @@ -233,6 +233,6 @@ public ClusteringMetrics validate(Dataframe testingData) { public ClusteringMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) { logger.info("validate()"); - return new TemporaryKFold<>(ClusteringMetrics.class, k).validate(trainingData, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters); + return new KFoldValidator<>(ClusteringMetrics.class, knowledgeBase.getConf(), k).validate(trainingData, trainingParameters); } } \ No newline at end of file diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractRecommender.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractRecommender.java index 98b2bcae..c2a22e54 100755 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractRecommender.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractRecommender.java @@ -18,7 +18,7 @@ import com.datumbox.framework.common.Configuration; import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.core.machinelearning.modelselection.metrics.RecommendationMetrics; -import com.datumbox.framework.core.machinelearning.modelselection.splitters.TemporaryKFold; +import com.datumbox.framework.core.machinelearning.modelselection.validators.KFoldValidator; /** * Abstract Class for recommender algorithms. @@ -52,6 +52,6 @@ public RecommendationMetrics validate(Dataframe testingData) { public RecommendationMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) { logger.info("validate()"); - return new TemporaryKFold<>(RecommendationMetrics.class, k).validate(trainingData, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters); + return new KFoldValidator<>(RecommendationMetrics.class, knowledgeBase.getConf(), k).validate(trainingData, trainingParameters); } } diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractRegressor.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractRegressor.java index 15ae2bea..5c7a1cfb 100755 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractRegressor.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelers/AbstractRegressor.java @@ -18,7 +18,7 @@ import com.datumbox.framework.common.Configuration; import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.core.machinelearning.modelselection.metrics.LinearRegressionMetrics; -import com.datumbox.framework.core.machinelearning.modelselection.splitters.TemporaryKFold; +import com.datumbox.framework.core.machinelearning.modelselection.validators.KFoldValidator; /** * Base Class for all the Regression algorithms. @@ -54,6 +54,6 @@ public LinearRegressionMetrics validate(Dataframe testingData) { public LinearRegressionMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) { logger.info("validate()"); - return new TemporaryKFold<>(LinearRegressionMetrics.class, k).validate(trainingData, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters); + return new KFoldValidator<>(LinearRegressionMetrics.class, knowledgeBase.getConf(), k).validate(trainingData, trainingParameters); } } \ No newline at end of file diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/validators/AbstractMetrics.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelselection/metrics/AbstractMetrics.java similarity index 98% rename from datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/validators/AbstractMetrics.java rename to datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelselection/metrics/AbstractMetrics.java index 1cf08b61..36cdb81d 100755 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/validators/AbstractMetrics.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelselection/metrics/AbstractMetrics.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.datumbox.framework.core.machinelearning.common.abstracts.validators; +package com.datumbox.framework.core.machinelearning.common.abstracts.modelselection.metrics; import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.core.machinelearning.common.interfaces.ValidationMetrics; diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelselection/validators/AbstractValidator.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelselection/validators/AbstractValidator.java new file mode 100644 index 00000000..e1c03876 --- /dev/null +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/common/abstracts/modelselection/validators/AbstractValidator.java @@ -0,0 +1,61 @@ +/** + * Copyright (C) 2013-2016 Vasilis Vryniotis + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datumbox.framework.core.machinelearning.common.abstracts.modelselection.validators; + +import com.datumbox.framework.common.Configuration; +import com.datumbox.framework.common.dataobjects.Dataframe; +import com.datumbox.framework.core.machinelearning.common.interfaces.TrainingParameters; +import com.datumbox.framework.core.machinelearning.common.interfaces.ValidationMetrics; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Base class for all Validators. + * + * @param + */ +public abstract class AbstractValidator { + + /** + * The Logger of all Validators. + * We want this to be non-static in order to print the names of the inherited classes. + */ + protected final Logger logger = LoggerFactory.getLogger(getClass()); + + protected final Class vmClass; + protected final Configuration conf; + + /** + * Default constructor. + * + * @param vmClass + * @param conf + */ + public AbstractValidator(Class vmClass, Configuration conf) { + this.vmClass = vmClass; + this.conf = conf; + } + + /** + * Performs a split on the data, trains models and performs validation. + * + * @param dataset + * @param trainingParameters + * @return + */ + public abstract VM validate(Dataframe dataset, TrainingParameters trainingParameters); + +} diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/ClassificationMetrics.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/ClassificationMetrics.java index 4dfa7240..38259566 100755 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/ClassificationMetrics.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/ClassificationMetrics.java @@ -17,7 +17,7 @@ import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.common.dataobjects.Record; -import com.datumbox.framework.core.machinelearning.common.abstracts.validators.AbstractMetrics; +import com.datumbox.framework.core.machinelearning.common.abstracts.modelselection.metrics.AbstractMetrics; import java.util.*; @@ -27,6 +27,7 @@ * @author Vasilis Vryniotis */ public class ClassificationMetrics extends AbstractMetrics { + private static final long serialVersionUID = 1L; /** * Enum that stores the 4 possible Sensitivity Rates. diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/ClusteringMetrics.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/ClusteringMetrics.java index 2f1a104e..94d8ac29 100755 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/ClusteringMetrics.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/ClusteringMetrics.java @@ -17,7 +17,7 @@ import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.common.dataobjects.Record; -import com.datumbox.framework.core.machinelearning.common.abstracts.validators.AbstractMetrics; +import com.datumbox.framework.core.machinelearning.common.abstracts.modelselection.metrics.AbstractMetrics; import java.util.*; @@ -31,6 +31,7 @@ * @author Vasilis Vryniotis */ public class ClusteringMetrics extends AbstractMetrics { + private static final long serialVersionUID = 1L; private double purity = 0.0; private double NMI = 0.0; //Normalized Mutual Information: I(Omega,Gama) calculation diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/LinearRegressionMetrics.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/LinearRegressionMetrics.java index 377318ab..7567d0f7 100755 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/LinearRegressionMetrics.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/LinearRegressionMetrics.java @@ -19,7 +19,7 @@ import com.datumbox.framework.common.dataobjects.FlatDataList; import com.datumbox.framework.common.dataobjects.Record; import com.datumbox.framework.common.dataobjects.TypeInference; -import com.datumbox.framework.core.machinelearning.common.abstracts.validators.AbstractMetrics; +import com.datumbox.framework.core.machinelearning.common.abstracts.modelselection.metrics.AbstractMetrics; import com.datumbox.framework.core.statistics.distributions.ContinuousDistributions; import com.datumbox.framework.core.statistics.nonparametrics.onesample.Lilliefors; import com.datumbox.framework.core.statistics.parametrics.onesample.DurbinWatson; @@ -32,6 +32,7 @@ * @author Vasilis Vryniotis */ public class LinearRegressionMetrics extends AbstractMetrics { + private static final long serialVersionUID = 1L; private double RSquare = 0.0; private double RSquareAdjusted = 0.0; diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/RecommendationMetrics.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/RecommendationMetrics.java index b36f2478..7fe2800f 100644 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/RecommendationMetrics.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/metrics/RecommendationMetrics.java @@ -19,7 +19,7 @@ import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.common.dataobjects.Record; import com.datumbox.framework.common.dataobjects.TypeInference; -import com.datumbox.framework.core.machinelearning.common.abstracts.validators.AbstractMetrics; +import com.datumbox.framework.core.machinelearning.common.abstracts.modelselection.metrics.AbstractMetrics; import java.util.*; @@ -29,6 +29,7 @@ * @author Vasilis Vryniotis */ public class RecommendationMetrics extends AbstractMetrics { + private static final long serialVersionUID = 1L; private double RMSE = 0.0; diff --git a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/splitters/TemporaryKFold.java b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/validators/KFoldValidator.java similarity index 71% rename from datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/splitters/TemporaryKFold.java rename to datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/validators/KFoldValidator.java index cafbd484..e34ca646 100644 --- a/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/splitters/TemporaryKFold.java +++ b/datumbox-framework-core/src/main/java/com/datumbox/framework/core/machinelearning/modelselection/validators/KFoldValidator.java @@ -13,45 +13,43 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.datumbox.framework.core.machinelearning.modelselection.splitters; +package com.datumbox.framework.core.machinelearning.modelselection.validators; import com.datumbox.framework.common.Configuration; import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.common.dataobjects.FlatDataList; import com.datumbox.framework.common.interfaces.Trainable; import com.datumbox.framework.common.utilities.PHPMethods; +import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer; import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractModeler; +import com.datumbox.framework.core.machinelearning.common.abstracts.modelselection.validators.AbstractValidator; +import com.datumbox.framework.core.machinelearning.common.interfaces.TrainingParameters; import com.datumbox.framework.core.machinelearning.common.interfaces.ValidationMetrics; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.util.ArrayList; import java.util.LinkedList; import java.util.List; -public class TemporaryKFold { - - //TODO: remove this temporary class and create a permanent solution, using Splitters. +/** + * Performs K-fold cross validation. + * + * @param + */ +public class KFoldValidator extends AbstractValidator { /** - * The Logger of all Validators. - * We want this to be non-static in order to print the names of the inherited classes. + * Stores the number of folds. */ - protected final Logger logger = LoggerFactory.getLogger(getClass()); - - private static final String DB_INDICATOR="Kfold"; - - private final Class vmClass; private final int k; /** - * The constructor of the Splitter. + * The constructor of the K-Fold cross validator. * * @param vmClass * @param k */ - public TemporaryKFold(Class vmClass, int k) { - this.vmClass = vmClass; + public KFoldValidator(Class vmClass, Configuration conf, int k) { + super(vmClass, conf); this.k = k; } @@ -60,13 +58,11 @@ public TemporaryKFold(Class vmClass, int k) { * of folds and returns the average metrics across all folds. * * @param dataset - * @param dbName - * @param conf - * @param aClass * @param trainingParameters * @return */ - public VM validate(Dataframe dataset, String dbName, Configuration conf, Class aClass, AbstractModeler.AbstractTrainingParameters trainingParameters) { + @Override + public VM validate(Dataframe dataset, TrainingParameters trainingParameters) { int n = dataset.size(); if(k<=0 || n<=k) { throw new IllegalArgumentException("Invalid number of folds."); @@ -84,11 +80,21 @@ public VM validate(Dataframe dataset, String dbName, Configuration conf, Class aClass = null; + try { + //By convertion the training parameters are one level below the algorithm class. This allows us to retrieve the algorithm class from the training parameters. + String className = trainingParameters.getClass().getCanonicalName(); + aClass = (Class) Class.forName(className.substring(0, className.lastIndexOf('.'))); + } + catch (ClassNotFoundException e) { + throw new IllegalArgumentException(e); + } + + //initialize modeler + AbstractModeler modeler = Trainable.newInstance(aClass, "kfold_"+System.nanoTime(), conf); List validationMetricsList = new LinkedList<>(); for(int fold=0;fold