Adding support for ElasticNet regularization.
datumbox committed Mar 15, 2016
1 parent 8939b68 commit a8b1093
Showing 9 changed files with 668 additions and 24 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -1,7 +1,7 @@
CHANGELOG
=========

Version 0.7.0-SNAPSHOT - Build 20160314
Version 0.7.0-SNAPSHOT - Build 20160315
---------------------------------------

- Rename the erase() method to delete() in all interfaces.
@@ -85,7 +85,7 @@ Version 0.7.0-SNAPSHOT - Build 20160314
- Removed the ClasspathSuite from the tests.
- Created the Extractable which is inherited by AbstractTextExtractor.
- Convert project into multimodule project.
- Added support for L1 and L2 regularization.
- Added support for L1, L2 and ElasticNet regularization.
- Rename library to datumbox-framework-lib.
- Change the structure and config files of persistent storage.

2 changes: 1 addition & 1 deletion README.md
@@ -15,7 +15,7 @@ The code is licensed under the [Apache License, Version 2.0](https://github.com/
Version
-------

The latest version is 0.7.0-SNAPSHOT (Build 20160314).
The latest version is 0.7.0-SNAPSHOT (Build 20160315).

The [master branch](https://github.com/datumbox/datumbox-framework/tree/master) is the latest stable version of the framework. The [devel branch](https://github.com/datumbox/datumbox-framework/tree/devel) is the development branch. All the previous stable versions are marked with [tags](https://github.com/datumbox/datumbox-framework/releases).

@@ -143,6 +143,43 @@ public static Dataframe[] carsCategorical(Configuration conf) {

return new Dataframe[] {trainingData, validationData};
}

/**
* Housing numerical Dataframe.
*
* @param conf the Configuration of the framework
* @return an array with the training and the validation Dataframes
*/
public static Dataframe[] housingNumerical(Configuration conf) {
//Data from https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.names
Dataframe trainingData;
try (Reader fileReader = new InputStreamReader(Datasets.class.getClassLoader().getResourceAsStream("datasets/housing.csv"), "UTF-8")) {
LinkedHashMap<String, TypeInference.DataType> headerDataTypes = new LinkedHashMap<>();
headerDataTypes.put("CRIM", TypeInference.DataType.NUMERICAL);
headerDataTypes.put("ZN", TypeInference.DataType.NUMERICAL);
headerDataTypes.put("INDUS", TypeInference.DataType.NUMERICAL);
headerDataTypes.put("CHAS", TypeInference.DataType.BOOLEAN);
headerDataTypes.put("NOX", TypeInference.DataType.NUMERICAL);
headerDataTypes.put("RM", TypeInference.DataType.NUMERICAL);
headerDataTypes.put("AGE", TypeInference.DataType.NUMERICAL);
headerDataTypes.put("DIS", TypeInference.DataType.NUMERICAL);
headerDataTypes.put("RAD", TypeInference.DataType.ORDINAL);
headerDataTypes.put("TAX", TypeInference.DataType.NUMERICAL);
headerDataTypes.put("PTRATIO", TypeInference.DataType.NUMERICAL);
headerDataTypes.put("B", TypeInference.DataType.NUMERICAL);
headerDataTypes.put("LSTAT", TypeInference.DataType.NUMERICAL);
headerDataTypes.put("MEDV", TypeInference.DataType.NUMERICAL);

trainingData = Dataframe.Builder.parseCSVFile(fileReader, "MEDV", headerDataTypes, ',', '"', "\r\n", null, null, conf);
}
catch(IOException ex) {
throw new UncheckedIOException(ex);
}

Dataframe validationData = trainingData.copy();

return new Dataframe[] {trainingData, validationData};
}
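
A minimal usage sketch of the new helper follows; it mirrors how the updated NLMS test further down consumes it. How conf is obtained is an assumption and not part of this commit:

//Illustrative sketch, not part of this commit: consuming the new helper.
//`conf` is assumed to be a Configuration built elsewhere in the test fixture.
Dataframe[] data = Datasets.housingNumerical(conf);
Dataframe trainingData = data[0];   //CSV parsed with "MEDV" as the response column
Dataframe validationData = data[1]; //a copy of the training data, as built above

//...fit a regressor on trainingData and score it on validationData...

validationData.delete(); //free the resources when finished, as the tests do
trainingData.delete();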

/**
* Wines Ordinal Dataframe.
453 changes: 453 additions & 0 deletions datumbox-framework-common/src/test/resources/datasets/housing.csv

Large diffs are not rendered by default.

@@ -32,6 +32,7 @@
import com.datumbox.framework.core.machinelearning.common.interfaces.TrainParallelizable;
import com.datumbox.framework.core.machinelearning.common.validators.SoftMaxRegressionValidator;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import com.datumbox.framework.core.utilities.regularization.ElasticNetRegularizer;
import com.datumbox.framework.core.utilities.regularization.L1Regularizer;
import com.datumbox.framework.core.utilities.regularization.L2Regularizer;

@@ -339,6 +340,13 @@ protected void _fit(Dataframe trainingData) {
//Drop the temporary Collection
dbc.dropBigMap("tmp_newThitas", tmp_newThitas);
}

//apply the weight correction if we use ElasticNet regularization
double l1 = kb().getTrainingParameters().getL1();
double l2 = kb().getTrainingParameters().getL2();
if(l1>0.0 && l2>0.0) {
ElasticNetRegularizer.weightCorrection(l2, thitas);
}
}
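
A note on the correction above: training with the combined update produces the "naive" elastic net solution, and Zou and Hastie (equation 12 of the paper linked in ElasticNetRegularizer below) recover the elastic net estimate with a single rescaling at the end of training:

$$\hat{\beta}_{enet} = (1 + \lambda_2)\,\hat{\beta}_{naive}$$

This is why weightCorrection() runs only when l1>0 and l2>0; with a single penalty the plain L1 or L2 regularizer is used and no correction is needed.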

/** {@inheritDoc} */
@@ -396,9 +404,18 @@ private void batchGradientDescent(Dataframe trainingData, Map<List<Object>, Doub
}
});

L1Regularizer.updateWeights(kb().getTrainingParameters().getL1(), learningRate, thitas, newThitas);
double l1 = kb().getTrainingParameters().getL1();
double l2 = kb().getTrainingParameters().getL2();

L2Regularizer.updateWeights(kb().getTrainingParameters().getL2(), learningRate, thitas, newThitas);
if(l1>0.0 && l2>0.0) {
ElasticNetRegularizer.updateWeights(l1, l2, learningRate, thitas, newThitas);
}
else if(l1>0.0) {
L1Regularizer.updateWeights(l1, learningRate, thitas, newThitas);
}
else if(l2>0.0) {
L2Regularizer.updateWeights(l2, learningRate, thitas, newThitas);
}

}

@@ -432,9 +449,18 @@ private double calculateError(Dataframe trainingData, Map<List<Object>, Double>

error = -error/kb().getModelParameters().getN();

error += L1Regularizer.estimatePenalty(kb().getTrainingParameters().getL1(), thitas);
double l1 = kb().getTrainingParameters().getL1();
double l2 = kb().getTrainingParameters().getL2();

error += L2Regularizer.estimatePenalty(kb().getTrainingParameters().getL2(), thitas);
if(l1>0.0 && l2>0.0) {
error += ElasticNetRegularizer.estimatePenalty(l1, l2, thitas);
}
else if(l1>0.0) {
error += L1Regularizer.estimatePenalty(l1, thitas);
}
else if(l2>0.0) {
error += L2Regularizer.estimatePenalty(l2, thitas);
}

return error;
}
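
For reference, the ElasticNet branch above adds the combined penalty

$$P(w) = \lambda_1 \sum_j |w_j| + \lambda_2 \sum_j w_j^2,$$

which matches the estimatePenalty() implementation introduced below; as its inline comment notes, the $\lambda_2$ term is deliberately not halved the way L2Regularizer halves it.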
@@ -29,6 +29,7 @@
import com.datumbox.framework.core.machinelearning.common.abstracts.algorithms.AbstractLinearRegression;
import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable;
import com.datumbox.framework.core.machinelearning.common.interfaces.TrainParallelizable;
import com.datumbox.framework.core.utilities.regularization.ElasticNetRegularizer;
import com.datumbox.framework.core.utilities.regularization.L1Regularizer;
import com.datumbox.framework.core.utilities.regularization.L2Regularizer;

Expand Down Expand Up @@ -232,7 +233,6 @@ protected void _fit(Dataframe trainingData) {
tmp_newThitas.putAll(thitas);

batchGradientDescent(trainingData, tmp_newThitas, learningRate);
//stochasticGradientDescent(trainingData, newThitas, learningRate);

double newError = calculateError(trainingData,tmp_newThitas);

@@ -252,6 +252,13 @@ protected void _fit(Dataframe trainingData) {
//Drop the temporary Collection
dbc.dropBigMap("tmp_newThitas", tmp_newThitas);
}

//apply the weight correction if we use ElasticNet regularization
double l1 = kb().getTrainingParameters().getL1();
double l2 = kb().getTrainingParameters().getL2();
if(l1>0.0 && l2>0.0) {
ElasticNetRegularizer.weightCorrection(l2, thitas);
}
}

private void batchGradientDescent(Dataframe trainingData, Map<Object, Double> newThitas, double learningRate) {
@@ -263,7 +270,7 @@ private void batchGradientDescent(Dataframe trainingData, Map<Object, Double> ne
streamExecutor.forEach(StreamMethods.stream(trainingData.stream(), isParallelized()), r -> {
//mind the fact that we use the previous thitas to estimate the new ones! this is because the thitas must be updated simultaneously
double error = TypeInference.toDouble(r.getY()) - hypothesisFunction(r.getX(), thitas);

double errorMultiplier = multiplier*error;

synchronized(newThitas) {
@@ -278,9 +285,18 @@ private void batchGradientDescent(Dataframe trainingData, Map<Object, Double> ne
}
});

L1Regularizer.updateWeights(kb().getTrainingParameters().getL1(), learningRate, thitas, newThitas);
double l1 = kb().getTrainingParameters().getL1();
double l2 = kb().getTrainingParameters().getL2();

L2Regularizer.updateWeights(kb().getTrainingParameters().getL2(), learningRate, thitas, newThitas);
if(l1>0.0 && l2>0.0) {
ElasticNetRegularizer.updateWeights(l1, l2, learningRate, thitas, newThitas);
}
else if(l1>0.0) {
L1Regularizer.updateWeights(l1, learningRate, thitas, newThitas);
}
else if(l2>0.0) {
L2Regularizer.updateWeights(l2, learningRate, thitas, newThitas);
}
}

private double calculateError(Dataframe trainingData, Map<Object, Double> thitas) {
@@ -292,9 +308,18 @@ private double calculateError(Dataframe trainingData, Map<Object, Double> thitas
}));
error /= kb().getModelParameters().getN();

error += L1Regularizer.estimatePenalty(kb().getTrainingParameters().getL1(), thitas);
double l1 = kb().getTrainingParameters().getL1();
double l2 = kb().getTrainingParameters().getL2();

error += L2Regularizer.estimatePenalty(kb().getTrainingParameters().getL2(), thitas);
if(l1>0.0 && l2>0.0) {
error += ElasticNetRegularizer.estimatePenalty(l1, l2, thitas);
}
else if(l1>0.0) {
error += L1Regularizer.estimatePenalty(l1, thitas);
}
else if(l2>0.0) {
error += L2Regularizer.estimatePenalty(l2, thitas);
}

return error;
}
@@ -0,0 +1,103 @@
/**
* Copyright (C) 2013-2016 Vasilis Vryniotis <bbriniotis@datumbox.com>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.datumbox.framework.core.utilities.regularization;

import java.util.Map;

/**
* Utility class for ElasticNet regularization.
*
* https://web.stanford.edu/~hastie/Papers/B67.2%20(2005)%20301-320%20Zou%20&%20Hastie.pdf
* http://web.stanford.edu/~hastie/TALKS/enet_talk.pdf
*
* @author Vasilis Vryniotis <bbriniotis@datumbox.com>
*/
public class ElasticNetRegularizer {

/**
* Updates the weights by applying the ElasticNet regularization.
*
* @param l1 the L1 regularization penalty
* @param l2 the L2 regularization penalty
* @param learningRate the learning rate of the iteration
* @param weights the map with the weights of the previous iteration
* @param newWeights the map with the new weights that will be regularized
* @param <K> the type of the keys of the weight maps
*/
public static <K> void updateWeights(double l1, double l2, double learningRate, Map<K, Double> weights, Map<K, Double> newWeights) {
//Equivalent to Naive elastic net
//L2Regularizer.updateWeights(l2, learningRate, weights, newWeights);
//L1Regularizer.updateWeights(l1, learningRate, weights, newWeights);
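
//In closed form, the loop below applies the soft-thresholding of formula 6:
//net_w = sign(w) * max(0, |w| - l1/2) / (1 + l2)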

for(Map.Entry<K, Double> e : newWeights.entrySet()) {
double w = e.getValue();

//formula 6 in the Zou and Hastie paper
double net_w = Math.abs(w) - l1/2.0; //naive elastic net beta
if(net_w>0) {
net_w /= (1.0 + l2);
net_w *= Math.signum(w);
}
else {
net_w = 0.0;
}

newWeights.put(e.getKey(), net_w);
}
}

/**
* Estimates the penalty by adding the ElasticNet regularization.
*
* @param l1 the L1 regularization penalty
* @param l2 the L2 regularization penalty
* @param weights the map with the weights for which the penalty is estimated
* @param <K> the type of the keys of the weight map
* @return the estimated penalty
*/
public static <K> double estimatePenalty(double l1, double l2, Map<K, Double> weights) {
//Equivalent to Naive elastic net
//double penalty = 0.0;
//penalty += L2Regularizer.estimatePenalty(l2, weights);
//penalty += L1Regularizer.estimatePenalty(l1, weights);
//return penalty;

double sumAbsWeights = 0.0;
double sumWeightsSquared = 0.0;
for(double w : weights.values()) {
sumAbsWeights += Math.abs(w);
sumWeightsSquared += w*w;
}
return l1*sumAbsWeights +
l2*sumWeightsSquared;//not divided by 2 as in the implementation of L2Regularizer
}

/**
* Applies a correction on the weights, turning the naive elastic net weights into elastic net weights.
*
* @param l2 the L2 regularization penalty
* @param finalWeights the map with the final weights of the trained model
* @param <K> the type of the keys of the weight map
*/
public static <K> void weightCorrection(double l2, Map<K, Double> finalWeights) {
for(Map.Entry<K, Double> e : finalWeights.entrySet()) {
double w = e.getValue();
if(w != 0.0) {
finalWeights.put(e.getKey(), (1.0+l2)*w); //equation 12 in the Zou and Hastie paper
}
}
}
}
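
To show how the three static methods fit together, here is a self-contained sketch that mirrors the call sites added to NLMS and SoftMaxRegression above; the weight values, penalties and data error are made up for illustration:

import com.datumbox.framework.core.utilities.regularization.ElasticNetRegularizer;

import java.util.HashMap;
import java.util.Map;

public class ElasticNetSketch {
    public static void main(String[] args) {
        double l1 = 0.001, l2 = 0.001, learningRate = 0.1; //illustrative hyperparameters

        Map<Object, Double> thitas = new HashMap<>(); //weights before the gradient step
        thitas.put("w0", 0.5);
        thitas.put("w1", -1.2);

        Map<Object, Double> newThitas = new HashMap<>(); //weights after the plain gradient step
        newThitas.put("w0", 0.6);
        newThitas.put("w1", -1.1);

        //1. During every iteration: shrink the freshly updated weights (formula 6).
        ElasticNetRegularizer.updateWeights(l1, l2, learningRate, thitas, newThitas);

        //2. When the iteration error is computed: add the combined penalty.
        double dataError = 0.42; //placeholder for the data-fit part of the error
        double totalError = dataError + ElasticNetRegularizer.estimatePenalty(l1, l2, newThitas);

        //3. After convergence: rescale the naive elastic net weights once (equation 12).
        ElasticNetRegularizer.weightCorrection(l2, newThitas);

        System.out.println("error=" + totalError + ", weights=" + newThitas);
    }
}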
@@ -64,7 +64,6 @@ public void testValidate() {

SoftMaxRegression.TrainingParameters param = new SoftMaxRegression.TrainingParameters();
param.setTotalIterations(2000);
param.setL1(0.001);
param.setL2(0.001);

instance.fit(trainingData, param);
@@ -125,8 +124,8 @@ public void testKFoldCrossValidation() {

SoftMaxRegression.TrainingParameters param = new SoftMaxRegression.TrainingParameters();
param.setTotalIterations(30);
param.setL1(0.00000001);
param.setL2(0.001);
param.setL1(0.0001);
param.setL2(0.0001);

SoftMaxRegression.ValidationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k);

@@ -60,6 +60,7 @@ public void testValidate() {

NLMS.TrainingParameters param = new NLMS.TrainingParameters();
param.setTotalIterations(1600);
param.setL1(0.00000001);


instance.fit(trainingData, param);
@@ -101,39 +102,39 @@ public void testKFoldCrossValidation() {

int k = 5;

Dataframe[] data = Datasets.regressionMixed(conf);
Dataframe[] data = Datasets.housingNumerical(conf);
Dataframe trainingData = data[0];
data[1].delete();

String dbName = this.getClass().getSimpleName();
DummyXYMinMaxNormalizer df = new DummyXYMinMaxNormalizer(dbName, conf);
df.fit_transform(trainingData, new DummyXYMinMaxNormalizer.TrainingParameters());




PCA featureSelector = new PCA(dbName, conf);
PCA.TrainingParameters featureSelectorParameters = new PCA.TrainingParameters();
featureSelectorParameters.setMaxDimensions(trainingData.xColumnSize()-1);
featureSelectorParameters.setWhitened(false);
featureSelectorParameters.setVariancePercentageThreshold(0.99999995);
featureSelector.fit_transform(trainingData, featureSelectorParameters);
featureSelector.delete();


NLMS instance = new NLMS(dbName, conf);

NLMS.TrainingParameters param = new NLMS.TrainingParameters();
param.setTotalIterations(500);
param.setL1(0.0001);
param.setL2(0.00001);
param.setL1(0.001);
param.setL2(0.001);

NLMS.ValidationMetrics vm = instance.kFoldCrossValidation(trainingData, param, k);

df.denormalize(trainingData);


double expResult = 0.9996038547117426;
double expResult = 0.7620091462443493;
double result = vm.getRSquare();
assertEquals(expResult, result, Constants.DOUBLE_ACCURACY_HIGH);

