Change implementation of ElasticNet regularization.
datumbox committed Mar 15, 2016
1 parent: a8b1093 · commit: 5ad97e0
Showing 4 changed files with 7 additions and 63 deletions.
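The commit drops the hand-rolled "naive" elastic net in favor of simply composing the existing L1 and L2 regularizers. For orientation, a sketch of the math that the removed code cites (Zou and Hastie, 2005), with \lambda_1 and \lambda_2 standing for the code's l1 and l2 parameters:

    % Naive elastic net estimate (formula 6 in Zou and Hastie),
    % as implemented by the removed updateWeights body:
    \hat{\beta}_i^{\text{naive}} = \frac{\operatorname{sign}(\beta_i)\,\bigl(|\beta_i| - \lambda_1/2\bigr)_+}{1 + \lambda_2}

    % Post-training rescaling (their equation 12),
    % as implemented by the removed weightCorrection method:
    \hat{\beta}_i = (1 + \lambda_2)\,\hat{\beta}_i^{\text{naive}}

After this change, updateWeights chains the plain L2 and L1 updates instead, and the trainers no longer call weightCorrection at the end of _fit.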
@@ -340,13 +340,6 @@ protected void _fit(Dataframe trainingData) {
             //Drop the temporary Collection
             dbc.dropBigMap("tmp_newThitas", tmp_newThitas);
         }
-
-        //apply the weight correction if we use ElasticNet Regularization
-        double l1 = kb().getTrainingParameters().getL1();
-        double l2 = kb().getTrainingParameters().getL2();
-        if(l1>0.0 && l2>0.0) {
-            ElasticNetRegularizer.weightCorrection(l2, thitas);
-        }
     }
 
     /** {@inheritDoc} */

@@ -252,13 +252,6 @@ protected void _fit(Dataframe trainingData) {
             //Drop the temporary Collection
             dbc.dropBigMap("tmp_newThitas", tmp_newThitas);
         }
-
-        //apply the weight correction if we use ElasticNet Regularization
-        double l1 = kb().getTrainingParameters().getL1();
-        double l2 = kb().getTrainingParameters().getL2();
-        if(l1>0.0 && l2>0.0) {
-            ElasticNetRegularizer.weightCorrection(l2, thitas);
-        }
     }
 
     private void batchGradientDescent(Dataframe trainingData, Map<Object, Double> newThitas, double learningRate) {

@@ -38,25 +38,8 @@ public class ElasticNetRegularizer {
      * @param <K>
      */
     public static <K> void updateWeights(double l1, double l2, double learningRate, Map<K, Double> weights, Map<K, Double> newWeights) {
-        //Equivalent to Naive elastic net
-        //L2Regularizer.updateWeights(l2, learningRate, weights, newWeights);
-        //L1Regularizer.updateWeights(l1, learningRate, weights, newWeights);
-
-        for(Map.Entry<K, Double> e : newWeights.entrySet()) {
-            double w = e.getValue();
-
-            //formula 6 on Zou and Hastie paper
-            double net_w = Math.abs(w) - l1/2.0; //naive elastic net beta
-            if(net_w>0) {
-                net_w /= (1.0 + l2);
-                net_w *= Math.signum(w);
-            }
-            else {
-                net_w = 0.0;
-            }
-
-            newWeights.put(e.getKey(), net_w);
-        }
+        L2Regularizer.updateWeights(l2, learningRate, weights, newWeights);
+        L1Regularizer.updateWeights(l1, learningRate, weights, newWeights);
     }
 
     /**
@@ -69,35 +52,10 @@ public static <K> void updateWeights(double l1, double l2, double learningRate,
      * @return
      */
     public static <K> double estimatePenalty(double l1, double l2, Map<K, Double> weights) {
-        //Equivalent to Naive elastic net
-        //double penalty = 0.0;
-        //penalty += L2Regularizer.estimatePenalty(l2, weights);
-        //penalty += L1Regularizer.estimatePenalty(l1, weights);
-        //return penalty;
-
-        double sumAbsWeights = 0.0;
-        double sumWeightsSquared = 0.0;
-        for(double w : weights.values()) {
-            sumAbsWeights += Math.abs(w);
-            sumWeightsSquared += w*w;
-        }
-        return l1*sumAbsWeights +
-               l2*sumWeightsSquared; //not divided by 2 as in the implementation of L2Regularizer
+        double penalty = 0.0;
+        penalty += L2Regularizer.estimatePenalty(l2, weights);
+        penalty += L1Regularizer.estimatePenalty(l1, weights);
+        return penalty;
     }
 
-    /**
-     * Applies a correction on the weights, turning the naive elastic net weights to elastic net weights.
-     *
-     * @param l2
-     * @param finalWeights
-     * @param <K>
-     */
-    public static <K> void weightCorrection(double l2, Map<K, Double> finalWeights) {
-        for(Map.Entry<K, Double> e : finalWeights.entrySet()) {
-            double w = e.getValue();
-            if(w != 0.0) {
-                finalWeights.put(e.getKey(), (1.0+l2)*w); //equation 12 on Zou and Hastie paper
-            }
-        }
-    }
 }
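
For readers unfamiliar with the framework, the following is a minimal, self-contained sketch of what the new two-line updateWeights body amounts to, assuming L2Regularizer applies conventional weight decay against the pre-update weights and L1Regularizer applies a signed subgradient step; the stand-in methods l2Update and l1Update and the driver class are hypothetical illustrations, not Datumbox's actual API:

    import java.util.HashMap;
    import java.util.Map;

    public final class ElasticNetUpdateSketch {

        // Hypothetical stand-in for L2Regularizer.updateWeights:
        // ridge-style weight decay using the pre-update weight values.
        static <K> void l2Update(double l2, double learningRate,
                                 Map<K, Double> weights, Map<K, Double> newWeights) {
            if (l2 <= 0.0) return;
            for (Map.Entry<K, Double> e : newWeights.entrySet()) {
                double w = weights.get(e.getKey());
                e.setValue(e.getValue() - learningRate * l2 * w);
            }
        }

        // Hypothetical stand-in for L1Regularizer.updateWeights:
        // lasso-style signed (subgradient) step.
        static <K> void l1Update(double l1, double learningRate,
                                 Map<K, Double> weights, Map<K, Double> newWeights) {
            if (l1 <= 0.0) return;
            for (Map.Entry<K, Double> e : newWeights.entrySet()) {
                double w = weights.get(e.getKey());
                e.setValue(e.getValue() - learningRate * l1 * Math.signum(w));
            }
        }

        public static void main(String[] args) {
            Map<String, Double> weights = new HashMap<>();
            weights.put("x1", 0.8);
            weights.put("x2", -0.3);
            Map<String, Double> newWeights = new HashMap<>(weights); // weights after the gradient step

            // Same shape as the new ElasticNetRegularizer.updateWeights body:
            l2Update(0.01, 0.1, weights, newWeights);
            l1Update(0.05, 0.1, weights, newWeights);

            System.out.println(newWeights); // both penalties applied additively
        }
    }

The composed form applies the two penalties independently and additively; the removed code instead soft-thresholded each weight and rescaled it in a single pass.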

@@ -134,7 +134,7 @@ public void testKFoldCrossValidation() {
         df.denormalize(trainingData);
 
 
-        double expResult = 0.7620091462443493;
+        double expResult = 0.7748106446239166;
         double result = vm.getRSquare();
         assertEquals(expResult, result, Constants.DOUBLE_ACCURACY_HIGH);
 
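
A note on the test update above: the removed estimatePenalty computed l1·Σ|wᵢ| + l2·Σwᵢ², with an explicit comment that the squared term was deliberately not halved, while L2Regularizer's own penalty (per that same comment) divides it by 2. Assuming that reading, the effective penalty after this commit becomes

    P(w) = \lambda_1 \sum_i |w_i| + \frac{\lambda_2}{2} \sum_i w_i^2

which, together with the changed update rule and the dropped weight correction, plausibly accounts for the new expected R² of 0.7748106446239166.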
