Skip to content

Commit

Permalink
Rewriting the logic of validators and validation metrics.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Dec 18, 2016
1 parent 5e74399 commit 5c4f6dd
Show file tree
Hide file tree
Showing 42 changed files with 966 additions and 1,260 deletions.
1 change: 1 addition & 0 deletions TODO.txt
@@ -1,6 +1,7 @@
CODE IMPROVEMENTS CODE IMPROVEMENTS
================= =================


- All ValidationMetrics should hava a serialization number
- Validation (Validators, Validation Metrics, KnowledgeBase etc) and metrics need to move out of the model - Validation (Validators, Validation Metrics, KnowledgeBase etc) and metrics need to move out of the model
- Add save() load() methods in the models - Add save() load() methods in the models
- Support of better Transformers (Zscore, decouple boolean transforming from numeric etc). - Support of better Transformers (Zscore, decouple boolean transforming from numeric etc).
Expand Down
Expand Up @@ -24,7 +24,6 @@
import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractFeatureSelector; import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractFeatureSelector;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractModeler; import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractModeler;
import com.datumbox.framework.core.machinelearning.common.abstracts.wrappers.AbstractWrapper; import com.datumbox.framework.core.machinelearning.common.abstracts.wrappers.AbstractWrapper;
import com.datumbox.framework.core.machinelearning.common.interfaces.ValidationMetrics;


/** /**
* Modeler is a convenience class which can be used to train Machine Learning * Modeler is a convenience class which can be used to train Machine Learning
Expand Down
Expand Up @@ -28,15 +28,14 @@
import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractFeatureSelector; import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractFeatureSelector;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractModeler; import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractModeler;
import com.datumbox.framework.core.machinelearning.common.abstracts.wrappers.AbstractWrapper; import com.datumbox.framework.core.machinelearning.common.abstracts.wrappers.AbstractWrapper;
import com.datumbox.framework.core.machinelearning.common.interfaces.ValidationMetrics;
import com.datumbox.framework.core.utilities.text.extractors.AbstractTextExtractor; import com.datumbox.framework.core.utilities.text.extractors.AbstractTextExtractor;


import java.net.URI; import java.net.URI;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;


/** /**
* TextClassifier is a convenience class which can be used to train Text Classification * TextClassifier is a convenience class which can be used to train Text ClassificationMetrics
* models. It is a wrapper class which automatically takes care of the text parsing, * models. It is a wrapper class which automatically takes care of the text parsing,
tokenization, feature selection and modeler training processes. It takes as input tokenization, feature selection and modeler training processes. It takes as input
either a Dataframe object or multiple text files (one for each category) with either a Dataframe object or multiple text files (one for each category) with
Expand Down
Expand Up @@ -77,7 +77,7 @@ public void testTrainAndValidate() {


/* /*
//TODO: restore this test //TODO: restore this test
ClassifierValidator.ValidationMetrics vm = instance.validate(trainingData); ClassificationMetrics.Metrics vm = instance.validate(trainingData);
double expResult2 = 0.8; double expResult2 = 0.8;
Assert.assertEquals(expResult2, vm.getMacroF1(), Constants.DOUBLE_ACCURACY_HIGH); Assert.assertEquals(expResult2, vm.getMacroF1(), Constants.DOUBLE_ACCURACY_HIGH);
Expand Down
Expand Up @@ -21,13 +21,10 @@
import com.datumbox.framework.core.machinelearning.classification.*; import com.datumbox.framework.core.machinelearning.classification.*;
import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractFeatureSelector; import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractFeatureSelector;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier; import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier;
import com.datumbox.framework.core.machinelearning.common.interfaces.ValidationMetrics;
import com.datumbox.framework.core.machinelearning.featureselection.categorical.ChisquareSelect; import com.datumbox.framework.core.machinelearning.featureselection.categorical.ChisquareSelect;
import com.datumbox.framework.core.machinelearning.featureselection.categorical.MutualInformation; import com.datumbox.framework.core.machinelearning.featureselection.categorical.MutualInformation;
import com.datumbox.framework.core.machinelearning.featureselection.scorebased.TFIDF; import com.datumbox.framework.core.machinelearning.featureselection.scorebased.TFIDF;
import com.datumbox.framework.core.machinelearning.validators.ClassifierValidator;
import com.datumbox.framework.core.utilities.text.extractors.NgramsExtractor; import com.datumbox.framework.core.utilities.text.extractors.NgramsExtractor;
import com.datumbox.framework.tests.Constants;
import com.datumbox.framework.tests.abstracts.AbstractTest; import com.datumbox.framework.tests.abstracts.AbstractTest;
import org.junit.Test; import org.junit.Test;


Expand Down Expand Up @@ -317,7 +314,7 @@ private <ML extends AbstractClassifier, FS extends AbstractFeatureSelector> void


/* /*
//TODO: restore this test //TODO: restore this test
ClassifierValidator.ValidationMetrics vm = instance.validate(dataset); ClassificationMetrics.Metrics vm = instance.validate(dataset);
assertEquals(expectedF1score, vm.getMacroF1(), Constants.DOUBLE_ACCURACY_HIGH); assertEquals(expectedF1score, vm.getMacroF1(), Constants.DOUBLE_ACCURACY_HIGH);
*/ */
instance.close(); instance.close();
Expand Down
Expand Up @@ -30,7 +30,6 @@
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier; import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier;
import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable; import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable;
import com.datumbox.framework.core.machinelearning.common.interfaces.TrainParallelizable; import com.datumbox.framework.core.machinelearning.common.interfaces.TrainParallelizable;
import com.datumbox.framework.core.machinelearning.validators.ClassifierValidator;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives; import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;


import java.util.Arrays; import java.util.Arrays;
Expand Down
Expand Up @@ -29,7 +29,6 @@
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClusterer; import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClusterer;
import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable; import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable;
import com.datumbox.framework.core.machinelearning.common.interfaces.TrainParallelizable; import com.datumbox.framework.core.machinelearning.common.interfaces.TrainParallelizable;
import com.datumbox.framework.core.machinelearning.validators.ClustererValidator;
import com.datumbox.framework.core.mathematics.distances.Distance; import com.datumbox.framework.core.mathematics.distances.Distance;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives; import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;


Expand Down
Expand Up @@ -24,7 +24,6 @@
import com.datumbox.framework.common.utilities.MapMethods; import com.datumbox.framework.common.utilities.MapMethods;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer; import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier; import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier;
import com.datumbox.framework.core.machinelearning.validators.ClassifierValidator;
import com.datumbox.framework.core.machinelearning.ensemblelearning.FixedCombinationRules; import com.datumbox.framework.core.machinelearning.ensemblelearning.FixedCombinationRules;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives; import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import com.datumbox.framework.core.statistics.sampling.SimpleRandomSampling; import com.datumbox.framework.core.statistics.sampling.SimpleRandomSampling;
Expand Down
Expand Up @@ -30,7 +30,6 @@
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier; import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier;
import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable; import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable;
import com.datumbox.framework.core.machinelearning.common.interfaces.TrainParallelizable; import com.datumbox.framework.core.machinelearning.common.interfaces.TrainParallelizable;
import com.datumbox.framework.core.machinelearning.validators.ClassifierValidator;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives; import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;


import java.util.*; import java.util.*;
Expand Down
Expand Up @@ -20,9 +20,8 @@
import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector; import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.common.utilities.MapMethods; import com.datumbox.framework.common.utilities.MapMethods;
import com.datumbox.framework.core.machinelearning.common.abstracts.validators.AbstractValidator; import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClassificationMetrics;
import com.datumbox.framework.core.machinelearning.validators.ClassifierValidator; import com.datumbox.framework.core.machinelearning.modelselection.splitters.TemporaryKFold;
import com.datumbox.framework.core.machinelearning.validators.TemporaryKFold;


import java.util.*; import java.util.*;


Expand Down Expand Up @@ -103,17 +102,17 @@ protected Object getSelectedClassFromClassScores(AssociativeArray predictionScor




//TODO: remove this once we create the save/load //TODO: remove this once we create the save/load
public ClassifierValidator.ValidationMetrics validate(Dataframe testingData) { public ClassificationMetrics validate(Dataframe testingData) {
logger.info("validate()"); logger.info("validate()");


predict(testingData); predict(testingData);


return new ClassifierValidator().validate(testingData); return new ClassificationMetrics(testingData);
} }
//TODO: remove this once we create the save/load //TODO: remove this once we create the save/load
public ClassifierValidator.ValidationMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) { public ClassificationMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) {
logger.info("kFoldCrossValidation()"); logger.info("validate()");


return new TemporaryKFold<>(new ClassifierValidator()).kFoldCrossValidation(trainingData, k, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters); return new TemporaryKFold<>(ClassificationMetrics.class).validate(trainingData, k, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters);
} }
} }
Expand Up @@ -23,8 +23,8 @@
import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector.MapType; import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector.MapType;
import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector.StorageHint; import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector.StorageHint;
import com.datumbox.framework.core.machinelearning.common.interfaces.Cluster; import com.datumbox.framework.core.machinelearning.common.interfaces.Cluster;
import com.datumbox.framework.core.machinelearning.validators.ClustererValidator; import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClusteringMetrics;
import com.datumbox.framework.core.machinelearning.validators.TemporaryKFold; import com.datumbox.framework.core.machinelearning.modelselection.splitters.TemporaryKFold;


import java.util.*; import java.util.*;


Expand Down Expand Up @@ -222,17 +222,17 @@ public Map<Integer, CL> getClusters() {
} }


//TODO: remove this once we create the save/load //TODO: remove this once we create the save/load
public ClustererValidator.ValidationMetrics validate(Dataframe testingData) { public ClusteringMetrics validate(Dataframe testingData) {
logger.info("validate()"); logger.info("validate()");


predict(testingData); predict(testingData);


return new ClustererValidator().validate(testingData); return new ClusteringMetrics(testingData);
} }
//TODO: remove this once we create the save/load //TODO: remove this once we create the save/load
public ClustererValidator.ValidationMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) { public ClusteringMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) {
logger.info("kFoldCrossValidation()"); logger.info("validate()");


return new TemporaryKFold<>(new ClustererValidator()).kFoldCrossValidation(trainingData, k, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters); return new TemporaryKFold<>(ClusteringMetrics.class).validate(trainingData, k, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters);
} }
} }
Expand Up @@ -17,8 +17,8 @@


import com.datumbox.framework.common.Configuration; import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.core.machinelearning.validators.RMSEValidator; import com.datumbox.framework.core.machinelearning.modelselection.metrics.RecommendationMetrics;
import com.datumbox.framework.core.machinelearning.validators.TemporaryKFold; import com.datumbox.framework.core.machinelearning.modelselection.splitters.TemporaryKFold;


/** /**
* Abstract Class for recommender algorithms. * Abstract Class for recommender algorithms.
Expand All @@ -41,17 +41,17 @@ protected AbstractRecommender(String dbName, Configuration conf, Class<MP> mpCla
} }


//TODO: remove this once we create the save/load //TODO: remove this once we create the save/load
public RMSEValidator.ValidationMetrics validate(Dataframe testingData) { public RecommendationMetrics validate(Dataframe testingData) {
logger.info("validate()"); logger.info("validate()");


predict(testingData); predict(testingData);


return new RMSEValidator().validate(testingData); return new RecommendationMetrics(testingData);
} }
//TODO: remove this once we create the save/load //TODO: remove this once we create the save/load
public RMSEValidator.ValidationMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) { public RecommendationMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) {
logger.info("kFoldCrossValidation()"); logger.info("validate()");


return new TemporaryKFold<>(new RMSEValidator()).kFoldCrossValidation(trainingData, k, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters); return new TemporaryKFold<>(RecommendationMetrics.class).validate(trainingData, k, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters);
} }
} }
Expand Up @@ -17,8 +17,8 @@


import com.datumbox.framework.common.Configuration; import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.core.machinelearning.validators.LinearRegressionValidator; import com.datumbox.framework.core.machinelearning.modelselection.metrics.LinearRegressionMetrics;
import com.datumbox.framework.core.machinelearning.validators.TemporaryKFold; import com.datumbox.framework.core.machinelearning.modelselection.splitters.TemporaryKFold;


/** /**
* Base Class for all the Regression algorithms. * Base Class for all the Regression algorithms.
Expand All @@ -41,19 +41,19 @@ protected AbstractRegressor(String dbName, Configuration conf, Class<MP> mpClass
} }


//TODO: remove this once we create the save/load //TODO: remove this once we create the save/load
public LinearRegressionValidator.ValidationMetrics validate(Dataframe testingData) { public LinearRegressionMetrics validate(Dataframe testingData) {
logger.info("validate()"); logger.info("validate()");


knowledgeBase.load(); knowledgeBase.load();


predict(testingData); predict(testingData);


return new LinearRegressionValidator().validate(testingData); return new LinearRegressionMetrics(testingData);
} }
//TODO: remove this once we create the save/load //TODO: remove this once we create the save/load
public LinearRegressionValidator.ValidationMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) { public LinearRegressionMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) {
logger.info("kFoldCrossValidation()"); logger.info("validate()");


return new TemporaryKFold<>(new LinearRegressionValidator()).kFoldCrossValidation(trainingData, k, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters); return new TemporaryKFold<>(LinearRegressionMetrics.class).validate(trainingData, k, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters);
} }
} }
Expand Up @@ -17,50 +17,33 @@


import com.datumbox.framework.common.dataobjects.Dataframe; import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.core.machinelearning.common.interfaces.ValidationMetrics; import com.datumbox.framework.core.machinelearning.common.interfaces.ValidationMetrics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


import java.util.List; import java.util.List;


/** /**
* The AbstractValidator class is an abstract class responsible for the K-fold Cross * The AbstractMetrics class stores and estimates information about the performance of the algorithm.
Validation and for the estimation of the average validation metrics. Given that
* different models use different validation metrics, each model family implements
* its own validator.
* *
* @author Vasilis Vryniotis <bbriniotis@datumbox.com> * @author Vasilis Vryniotis <bbriniotis@datumbox.com>
* @param <VM>
*/ */
public abstract class AbstractValidator<VM extends AbstractValidator.AbstractValidationMetrics> { public abstract class AbstractMetrics implements ValidationMetrics {

/**
* The Logger of all Validators.
* We want this to be non-static in order to print the names of the inherited classes.
*/
protected final Logger logger = LoggerFactory.getLogger(getClass());

/**
* The AbstractValidationMetrics class stores information about the performance of the algorithm.
*/
public static abstract class AbstractValidationMetrics implements ValidationMetrics {

}


/** /**
* Estimates the validation metrics on the predicted data. * Estimates the validation metrics on the predicted data.
* *
* @param predictedData * @param predictedData
* @return
*/ */
public abstract VM validate(Dataframe predictedData); protected AbstractMetrics(Dataframe predictedData) {

}


/** /**
* Calculates the average validation metrics by combining the results of the * Calculates the average validation metrics by combining the results of the
* provided list. * provided list.
* *
* @param validationMetricsList * @param validationMetricsList
* @return
*/ */
public abstract VM average(List<VM> validationMetricsList); protected AbstractMetrics(List<? extends AbstractMetrics> validationMetricsList) {

}


} }
Expand Up @@ -15,29 +15,66 @@
*/ */
package com.datumbox.framework.core.machinelearning.common.interfaces; package com.datumbox.framework.core.machinelearning.common.interfaces;


import com.datumbox.framework.common.dataobjects.Dataframe;

import java.io.Serializable; import java.io.Serializable;
import java.lang.reflect.InvocationTargetException; import java.lang.reflect.InvocationTargetException;
import java.util.List;


/** /**
* Interface for every ValidationMetrics class in the framework. * Interface for every Metrics class in the framework.
* *
* @author Vasilis Vryniotis <bbriniotis@datumbox.com> * @author Vasilis Vryniotis <bbriniotis@datumbox.com>
*/ */


public interface ValidationMetrics extends Serializable { public interface ValidationMetrics extends Serializable {


/** /**
* This method allows us to create a new empty Validation Metrics object * Creates a new empty Validation Metrics object.
* from an existing object. Casting to the appropriate type is required. *
* * @return
* @return */
public static <VM extends ValidationMetrics> VM newInstance(Class<VM> vmClass) {
try {
return vmClass.getConstructor().newInstance();
}
catch (InstantiationException | IllegalAccessException | IllegalArgumentException | InvocationTargetException | NoSuchMethodException | SecurityException ex) {
throw new RuntimeException(ex);
}
}

/**
* Estimates the Validation Metrics object from predictions.
*
* @param vmClass
* @param predictedData
* @param <VM>
* @return
*/ */
default public ValidationMetrics getEmptyObject() { public static <VM extends ValidationMetrics> VM newInstance(Class<VM> vmClass, Dataframe predictedData) {
try { try {
return this.getClass().getConstructor().newInstance(); return vmClass.getConstructor(Dataframe.class).newInstance(predictedData);
} }
catch (InstantiationException | IllegalAccessException | IllegalArgumentException | InvocationTargetException | NoSuchMethodException | SecurityException ex) { catch (InstantiationException | IllegalAccessException | IllegalArgumentException | InvocationTargetException | NoSuchMethodException | SecurityException ex) {
throw new RuntimeException(ex); throw new RuntimeException(ex);
} }
} }

/**
* Estimates the average Validation Metrics object from a list of metrics.
*
* @param vmClass
* @param validationMetricsList
* @param <VM>
* @return
*/
public static <VM extends ValidationMetrics> VM newInstance(Class<VM> vmClass, List<VM> validationMetricsList) {
try {
return vmClass.getConstructor(List.class).newInstance(validationMetricsList);
}
catch (InstantiationException | IllegalAccessException | IllegalArgumentException | InvocationTargetException | NoSuchMethodException | SecurityException ex) {
throw new RuntimeException(ex);
}
}

} }

0 comments on commit 5c4f6dd

Please sign in to comment.