Skip to content

Commit

Permalink
Finalizing the validation mechanism.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Dec 18, 2016
1 parent b8da962 commit 14975f4
Show file tree
Hide file tree
Showing 13 changed files with 110 additions and 49 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Expand Up @@ -6,7 +6,7 @@ Version 0.8.0-SNAPSHOT - Build 20161218

- Improved Validation:
- Removed the ValidationMetrics from the Algorithms. Now it is a separate object.
- Removed the kFold validation from Algorithms. Now we offer a splitter mechanism.
- Removed the kFold validation from Algorithms. Now we offer a new validator mechanism.
- A single KnowledgeBase implementation is now used.
- Removed the unnecessary n & d model parameters from all models.

Expand Down
3 changes: 1 addition & 2 deletions TODO.txt
@@ -1,12 +1,11 @@
CODE IMPROVEMENTS
=================

- All ValidationMetrics should have a serialVersionUID
- Add save() load() methods in the models
- The method kFoldCrossValidation is removed from AbstractValidator. It should be a separate class instead.

- Support of better Transformers (Zscore, decouple boolean transforming from numeric etc).
- Write a ShuffleSplit class similar to KFold.
- Write a ShuffleSplitValidator class similar to KFold.
- Write generic optimizers instead of having optimization methods in the algorithms.
- Support MapDB 3.0 once a stable version is released. Remove the HOTFIX for MapDB bug #664.
- Try to separate persistent storage in a separate module that is inherited by common.
Expand Down
Expand Up @@ -21,7 +21,7 @@
import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.common.utilities.MapMethods;
import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClassificationMetrics;
import com.datumbox.framework.core.machinelearning.modelselection.splitters.TemporaryKFold;
import com.datumbox.framework.core.machinelearning.modelselection.validators.KFoldValidator;

import java.util.*;

Expand Down Expand Up @@ -113,6 +113,6 @@ public ClassificationMetrics validate(Dataframe testingData) {
public ClassificationMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) {
logger.info("validate()");

return new TemporaryKFold<>(ClassificationMetrics.class, k).validate(trainingData, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters);
return new KFoldValidator<>(ClassificationMetrics.class, knowledgeBase.getConf(), k).validate(trainingData, trainingParameters);
}
}
Expand Up @@ -24,7 +24,7 @@
import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector.StorageHint;
import com.datumbox.framework.core.machinelearning.common.interfaces.Cluster;
import com.datumbox.framework.core.machinelearning.modelselection.metrics.ClusteringMetrics;
import com.datumbox.framework.core.machinelearning.modelselection.splitters.TemporaryKFold;
import com.datumbox.framework.core.machinelearning.modelselection.validators.KFoldValidator;

import java.util.*;

Expand Down Expand Up @@ -233,6 +233,6 @@ public ClusteringMetrics validate(Dataframe testingData) {
public ClusteringMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) {
logger.info("validate()");

return new TemporaryKFold<>(ClusteringMetrics.class, k).validate(trainingData, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters);
return new KFoldValidator<>(ClusteringMetrics.class, knowledgeBase.getConf(), k).validate(trainingData, trainingParameters);
}
}
Expand Up @@ -18,7 +18,7 @@
import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.core.machinelearning.modelselection.metrics.RecommendationMetrics;
import com.datumbox.framework.core.machinelearning.modelselection.splitters.TemporaryKFold;
import com.datumbox.framework.core.machinelearning.modelselection.validators.KFoldValidator;

/**
* Abstract Class for recommender algorithms.
Expand Down Expand Up @@ -52,6 +52,6 @@ public RecommendationMetrics validate(Dataframe testingData) {
public RecommendationMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) {
logger.info("validate()");

return new TemporaryKFold<>(RecommendationMetrics.class, k).validate(trainingData, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters);
return new KFoldValidator<>(RecommendationMetrics.class, knowledgeBase.getConf(), k).validate(trainingData, trainingParameters);
}
}
Expand Up @@ -18,7 +18,7 @@
import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.core.machinelearning.modelselection.metrics.LinearRegressionMetrics;
import com.datumbox.framework.core.machinelearning.modelselection.splitters.TemporaryKFold;
import com.datumbox.framework.core.machinelearning.modelselection.validators.KFoldValidator;

/**
* Base Class for all the Regression algorithms.
Expand Down Expand Up @@ -54,6 +54,6 @@ public LinearRegressionMetrics validate(Dataframe testingData) {
public LinearRegressionMetrics kFoldCrossValidation(Dataframe trainingData, TP trainingParameters, int k) {
logger.info("validate()");

return new TemporaryKFold<>(LinearRegressionMetrics.class, k).validate(trainingData, dbName, knowledgeBase.getConf(), this.getClass(), trainingParameters);
return new KFoldValidator<>(LinearRegressionMetrics.class, knowledgeBase.getConf(), k).validate(trainingData, trainingParameters);
}
}
Expand Up @@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.datumbox.framework.core.machinelearning.common.abstracts.validators;
package com.datumbox.framework.core.machinelearning.common.abstracts.modelselection.metrics;

import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.core.machinelearning.common.interfaces.ValidationMetrics;
Expand Down
@@ -0,0 +1,61 @@
/**
* Copyright (C) 2013-2016 Vasilis Vryniotis <bbriniotis@datumbox.com>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.datumbox.framework.core.machinelearning.common.abstracts.modelselection.validators;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.core.machinelearning.common.interfaces.TrainingParameters;
import com.datumbox.framework.core.machinelearning.common.interfaces.ValidationMetrics;

import java.util.Objects;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for all Validators. A Validator splits the provided data into
 * training/validation subsets, trains intermediate models and estimates the
 * average ValidationMetrics across the splits.
 *
 * @param <VM> the type of ValidationMetrics that the validator produces
 */
public abstract class AbstractValidator<VM extends ValidationMetrics> {

    /**
     * The Logger of all Validators.
     * We want this to be non-static in order to print the names of the inherited classes.
     */
    protected final Logger logger = LoggerFactory.getLogger(getClass());

    /** The class of the ValidationMetrics that will be estimated. */
    protected final Class<VM> vmClass;

    /** The Configuration used to train the intermediate models during validation. */
    protected final Configuration conf;

    /**
     * Default constructor.
     *
     * @param vmClass the class of the ValidationMetrics that the validator estimates; must not be null
     * @param conf the Configuration used to train the intermediate models; must not be null
     * @throws NullPointerException if vmClass or conf is null
     */
    public AbstractValidator(Class<VM> vmClass, Configuration conf) {
        // Fail fast here rather than with an opaque NPE later inside a subclass' validate().
        this.vmClass = Objects.requireNonNull(vmClass, "vmClass");
        this.conf = Objects.requireNonNull(conf, "conf");
    }

    /**
     * Performs a split on the data, trains models and performs validation.
     *
     * @param dataset the Dataframe on which the validation is performed
     * @param trainingParameters the parameters used to train each intermediate model
     * @return the estimated average ValidationMetrics across the splits
     */
    public abstract VM validate(Dataframe dataset, TrainingParameters trainingParameters);

}
Expand Up @@ -17,7 +17,7 @@

import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.common.dataobjects.Record;
import com.datumbox.framework.core.machinelearning.common.abstracts.validators.AbstractMetrics;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelselection.metrics.AbstractMetrics;

import java.util.*;

Expand All @@ -27,6 +27,7 @@
* @author Vasilis Vryniotis <bbriniotis@datumbox.com>
*/
public class ClassificationMetrics extends AbstractMetrics {
private static final long serialVersionUID = 1L;

/**
* Enum that stores the 4 possible Sensitivity Rates.
Expand Down
Expand Up @@ -17,7 +17,7 @@

import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.common.dataobjects.Record;
import com.datumbox.framework.core.machinelearning.common.abstracts.validators.AbstractMetrics;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelselection.metrics.AbstractMetrics;

import java.util.*;

Expand All @@ -31,6 +31,7 @@
* @author Vasilis Vryniotis <bbriniotis@datumbox.com>
*/
public class ClusteringMetrics extends AbstractMetrics {
private static final long serialVersionUID = 1L;

private double purity = 0.0;
private double NMI = 0.0; //Normalized Mutual Information: I(Omega,Gama) calculation
Expand Down
Expand Up @@ -19,7 +19,7 @@
import com.datumbox.framework.common.dataobjects.FlatDataList;
import com.datumbox.framework.common.dataobjects.Record;
import com.datumbox.framework.common.dataobjects.TypeInference;
import com.datumbox.framework.core.machinelearning.common.abstracts.validators.AbstractMetrics;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelselection.metrics.AbstractMetrics;
import com.datumbox.framework.core.statistics.distributions.ContinuousDistributions;
import com.datumbox.framework.core.statistics.nonparametrics.onesample.Lilliefors;
import com.datumbox.framework.core.statistics.parametrics.onesample.DurbinWatson;
Expand All @@ -32,6 +32,7 @@
* @author Vasilis Vryniotis <bbriniotis@datumbox.com>
*/
public class LinearRegressionMetrics extends AbstractMetrics {
private static final long serialVersionUID = 1L;

private double RSquare = 0.0;
private double RSquareAdjusted = 0.0;
Expand Down
Expand Up @@ -19,7 +19,7 @@
import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.common.dataobjects.Record;
import com.datumbox.framework.common.dataobjects.TypeInference;
import com.datumbox.framework.core.machinelearning.common.abstracts.validators.AbstractMetrics;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelselection.metrics.AbstractMetrics;

import java.util.*;

Expand All @@ -29,6 +29,7 @@
* @author Vasilis Vryniotis <bbriniotis@datumbox.com>
*/
public class RecommendationMetrics extends AbstractMetrics {
private static final long serialVersionUID = 1L;

private double RMSE = 0.0;

Expand Down
Expand Up @@ -13,45 +13,43 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.datumbox.framework.core.machinelearning.modelselection.splitters;
package com.datumbox.framework.core.machinelearning.modelselection.validators;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.common.dataobjects.FlatDataList;
import com.datumbox.framework.common.interfaces.Trainable;
import com.datumbox.framework.common.utilities.PHPMethods;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractModeler;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelselection.validators.AbstractValidator;
import com.datumbox.framework.core.machinelearning.common.interfaces.TrainingParameters;
import com.datumbox.framework.core.machinelearning.common.interfaces.ValidationMetrics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

public class TemporaryKFold<VM extends ValidationMetrics> {

//TODO: remove this temporary class and create a permanent solution, using Splitters.
/**
* Performs K-fold cross validation.
*
* @param <VM>
*/
public class KFoldValidator<VM extends ValidationMetrics> extends AbstractValidator<VM> {

/**
* The Logger of all Validators.
* We want this to be non-static in order to print the names of the inherited classes.
* Stores the number of folds.
*/
protected final Logger logger = LoggerFactory.getLogger(getClass());

private static final String DB_INDICATOR="Kfold";

private final Class<VM> vmClass;
private final int k;

/**
* The constructor of the Splitter.
* The constructor of the K-Fold cross validator.
*
* @param vmClass
* @param k
*/
public TemporaryKFold(Class<VM> vmClass, int k) {
this.vmClass = vmClass;
public KFoldValidator(Class<VM> vmClass, Configuration conf, int k) {
super(vmClass, conf);
this.k = k;
}

Expand All @@ -60,13 +58,11 @@ public TemporaryKFold(Class<VM> vmClass, int k) {
* of folds and returns the average metrics across all folds.
*
* @param dataset
* @param dbName
* @param conf
* @param aClass
* @param trainingParameters
* @return
*/
public VM validate(Dataframe dataset, String dbName, Configuration conf, Class<? extends AbstractModeler> aClass, AbstractModeler.AbstractTrainingParameters trainingParameters) {
@Override
public VM validate(Dataframe dataset, TrainingParameters trainingParameters) {
int n = dataset.size();
if(k<=0 || n<=k) {
throw new IllegalArgumentException("Invalid number of folds.");
Expand All @@ -84,11 +80,21 @@ public VM validate(Dataframe dataset, String dbName, Configuration conf, Class<?
}
PHPMethods.shuffle(ids);

String foldDBname=dbName+conf.getDbConfig().getDBnameSeparator()+DB_INDICATOR;
Class<? extends AbstractModeler> aClass = null;
try {
//By convention the training parameters are one level below the algorithm class. This allows us to retrieve the algorithm class from the training parameters.
String className = trainingParameters.getClass().getCanonicalName();
aClass = (Class<? extends AbstractModeler>) Class.forName(className.substring(0, className.lastIndexOf('.')));
}
catch (ClassNotFoundException e) {
throw new IllegalArgumentException(e);
}

//initialize modeler
AbstractModeler modeler = Trainable.newInstance(aClass, "kfold_"+System.nanoTime(), conf);

List<VM> validationMetricsList = new LinkedList<>();
for(int fold=0;fold<k;++fold) {

logger.info("Kfold {}", fold);

//as fold window we consider the part of the ids that are used for validation
Expand Down Expand Up @@ -119,14 +125,9 @@ public VM validate(Dataframe dataset, String dbName, Configuration conf, Class<?
}


//initialize modeler
AbstractModeler modeler = Trainable.newInstance(aClass, foldDBname+(fold+1), conf);


Dataframe trainingData = dataset.getSubset(foldTrainingIds);
modeler.fit(trainingData, trainingParameters);
modeler.fit(trainingData, (AbstractTrainer.AbstractTrainingParameters) trainingParameters);
trainingData.delete();
//trainingData = null;


Dataframe validationData = dataset.getSubset(foldValidationIds);
Expand All @@ -136,15 +137,11 @@ public VM validate(Dataframe dataset, String dbName, Configuration conf, Class<?

VM entrySample = ValidationMetrics.newInstance(vmClass, validationData);
validationData.delete();
//validationData = null;

//delete algorithm
modeler.delete();
//modeler = null;

//add the validationMetrics in the list
validationMetricsList.add(entrySample);
}
modeler.delete();

VM avgValidationMetrics = ValidationMetrics.newInstance(vmClass, validationMetricsList);

Expand Down

0 comments on commit 14975f4

Please sign in to comment.