Skip to content

Commit

Permalink
Run RegressionTest with different bias values
Browse files Browse the repository at this point in the history
  • Loading branch information
bwaldvogel committed Oct 22, 2020
1 parent b8a6bd3 commit 98fdb45
Show file tree
Hide file tree
Showing 55 changed files with 53,915 additions and 37 deletions.
176 changes: 139 additions & 37 deletions src/test/java/de/bwaldvogel/liblinear/RegressionTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package de.bwaldvogel.liblinear;

import static de.bwaldvogel.liblinear.SolverType.*;
import static org.assertj.core.api.Assertions.*;

import java.nio.charset.StandardCharsets;
Expand All @@ -11,6 +12,7 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand All @@ -36,94 +38,194 @@ private static Collection<TestParams> data() {
List<TestParams> params = new ArrayList<>();
for (String dataset : new String[] {"splice", "dna.scale"}) {
for (SolverType solverType : SOLVERS) {
params.add(new TestParams(dataset, solverType, getExpectedAccuracy(dataset, solverType)));
for (int bias : new int[] {-1, 1}) {
for (boolean regularizeBias : new boolean[] {false, true}) {
if (!regularizeBias && bias != 1) {
continue;
}
if (!regularizeBias) {
// -R option supported only for solver L2R_LR, L2R_L2LOSS_SVC, L1R_L2LOSS_SVC, L1R_LR, and L2R_L2LOSS_SVR
if (!EnumSet.of(L2R_LR, L2R_L2LOSS_SVC, L1R_L2LOSS_SVC, L1R_LR, L2R_L2LOSS_SVR).contains(solverType)) {
continue;
}
}
params.add(new TestParams(dataset, solverType, bias, regularizeBias, getExpectedAccuracy(dataset, solverType, bias, regularizeBias)));
}
}
}
}
return params;
}

private static Double getExpectedAccuracy(String dataset, SolverType solverType) {
private static Double getExpectedAccuracy(String dataset, SolverType solverType, int bias, boolean regularizeBias) {
if (solverType.isSupportVectorRegression() || solverType.isOneClass()) {
return null;
}
switch (dataset) {
case "splice":
switch (solverType) {
case L2R_LR:
return 0.8423;
switch (bias) {
case -1:
return 0.8423;
case 1:
return 0.84322;
}
case L2R_L2LOSS_SVC_DUAL:
return 0.8386;
switch (bias) {
case -1:
return 0.8386;
case 1:
return 0.85057;
}
case L2R_L2LOSS_SVC:
return 0.8432;
switch (bias) {
case -1:
return 0.8432;
case 1:
return regularizeBias ? 0.85149 : 0.85241;
}
case L2R_L1LOSS_SVC_DUAL:
return 0.8382;
switch (bias) {
case -1:
return 0.8382;
case 1:
return 0.83448;
}
case MCSVM_CS:
return 0.8377;
switch (bias) {
case -1:
return 0.8377;
case 1:
return 0.83862;
}
case L1R_L2LOSS_SVC:
return 0.8478;
switch (bias) {
case -1:
return 0.8478;
case 1:
return regularizeBias ? 0.8478 : 0.84782;
}
case L1R_LR:
return 0.8473;
switch (bias) {
case -1:
return 0.8473;
case 1:
return regularizeBias ? 0.84782 : 0.84598;
}
case L2R_LR_DUAL:
return 0.8423;
switch (bias) {
case -1:
return 0.8423;
case 1:
return 0.8492;
}
}
case "dna.scale":
switch (solverType) {
case L2R_LR:
return 0.9511;
switch (bias) {
case -1:
return 0.9511;
case 1:
return 0.94941;
}
case L2R_L2LOSS_SVC_DUAL:
return 0.9452;
switch (bias) {
case -1:
case 1:
return 0.9452;
}
case L2R_L2LOSS_SVC:
return 0.9469;
switch (bias) {
case -1:
return 0.9469;
case 1:
return 0.94604;
}
case L2R_L1LOSS_SVC_DUAL:
return 0.9477;
switch (bias) {
case -1:
return 0.9477;
case 1:
return 0.94604;
}
case MCSVM_CS:
return 0.9292;
switch (bias) {
case -1:
return 0.9292;
case 1:
return 0.92749;
}
case L1R_L2LOSS_SVC:
return 0.9553;
switch (bias) {
case -1:
return 0.9553;
case 1:
return regularizeBias ? 0.956998 : 0.95363;
}
case L1R_LR:
return 0.9536;
switch (bias) {
case -1:
return 0.9536;
case 1:
return regularizeBias ? 0.95194 : 0.95278;
}
case L2R_LR_DUAL:
return 0.9486;
switch (bias) {
case -1:
return 0.9486;
case 1:
return 0.94941;
}
}
default:
throw new IllegalArgumentException("Unknown expectation: " + dataset + ", " + solverType);
throw new IllegalArgumentException("Unknown expectation: " + dataset + ", " + solverType + ", " + bias);
}
}

private static class TestParams {

private final String dataset;
private final SolverType solverType;
private final int bias;
private final boolean regularizeBias;
private final Double expectedAccuracy;

private TestParams(String dataset, SolverType solverType, Double expectedAccuracy) {
private TestParams(String dataset, SolverType solverType, int bias, boolean regularizeBias, Double expectedAccuracy) {
this.dataset = dataset;
this.solverType = solverType;
this.bias = bias;
this.regularizeBias = regularizeBias;
this.expectedAccuracy = expectedAccuracy;
}

@Override
public String toString() {
return "dataset: " + dataset + ", solver: " + solverType;
return "dataset: " + dataset + ", solver: " + solverType + ", bias: " + bias + (!regularizeBias ? " (not regularized)" : "");
}
}

@ParameterizedTest
@MethodSource("data")
void regressionTest(TestParams params) throws Exception {
log.info("Running regression test for '{}'", params);
runRegressionTest(params.dataset, params.solverType, params.expectedAccuracy);
}

private void runRegressionTest(String dataset, SolverType solverType, Double expectedAccuracy) throws Exception {
Linear.resetRandom();
Path trainingFile = Paths.get("src/test/datasets", dataset, dataset);
Problem problem = Train.readProblem(trainingFile, -1);
Model model = Linear.train(problem, new Parameter(solverType, 1, 0.1));
Path testFile = Paths.get("src/test/datasets", dataset, dataset + ".t");
Problem testProblem = Train.readProblem(testFile, -1);

Path expectedFile = Paths.get("src/test/resources/regression", dataset, "predictions_" + solverType.name());
Path trainingFile = Paths.get("src/test/datasets", params.dataset, params.dataset);
Problem problem = Train.readProblem(trainingFile, params.bias);
Parameter parameter = new Parameter(params.solverType, 1, 0.1);
parameter.setRegularizeBias(params.regularizeBias);
Model model = Linear.train(problem, parameter);
Path testFile = Paths.get("src/test/datasets", params.dataset, params.dataset + ".t");
Problem testProblem = Train.readProblem(testFile, params.bias);

String filename = "predictions_" + params.solverType.name();
if (!params.regularizeBias) {
filename += "_notRegularizedBias";
} else {
filename += "_bias_" + params.bias;
}
Path expectedFile = Paths.get("src/test/resources/regression", params.dataset, filename);
final List<String> expectedPredictions;
if (!Files.exists(expectedFile)) {
expectedPredictions = Collections.emptyList();
Expand All @@ -140,13 +242,13 @@ private void runRegressionTest(String dataset, SolverType solverType, Double exp
Feature[] x = testProblem.x[i];
double[] predictedValues = new double[model.getNrClass()];
final double prediction;
if (solverType.isLogisticRegressionSolver()) {
if (params.solverType.isLogisticRegressionSolver()) {
prediction = Linear.predictProbability(model, x, predictedValues);
} else {
prediction = Linear.predictValues(model, x, predictedValues);
}

if (expectedAccuracy != null) {
if (params.expectedAccuracy != null) {
int expectation = (int)testProblem.y[i];
int actual = (int)prediction;
if (actual == expectation) {
Expand Down Expand Up @@ -182,9 +284,9 @@ private void runRegressionTest(String dataset, SolverType solverType, Double exp
}
}

if (expectedAccuracy != null) {
if (params.expectedAccuracy != null) {
double accuracy = correctPredictions / (double)testProblem.l;
assertThat(accuracy).isEqualTo(expectedAccuracy.doubleValue(), Offset.offset(1e-4));
assertThat(accuracy).isEqualTo(params.expectedAccuracy.doubleValue(), Offset.offset(1e-4));
}
}

Expand All @@ -207,7 +309,7 @@ void testOneClass(@TempDir Path tempDir) throws Exception {
}

Problem problem1 = Train.readProblem(spliceClass1, StandardCharsets.UTF_8, -1);
Parameter param = new Parameter(SolverType.ONECLASS_SVM, 1, 0.01);
Parameter param = new Parameter(ONECLASS_SVM, 1, 0.01);
param.setNu(0.1);
Model model = Linear.train(problem1, param);

Expand Down
Loading

0 comments on commit 98fdb45

Please sign in to comment.