Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LibLinear and LibSVM have unmanaged global RNGs #172

Merged
merged 6 commits into from
Sep 10, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@

/**
* A {@link Trainer} which wraps a liblinear-java anomaly detection trainer using a one-class SVM.
*
* <p>
* Note the train method is synchronized on {@code LibLinearTrainer.class} due to a global RNG in liblinear-java.
* This is insufficient to ensure reproducibility if liblinear-java is used directly in the same JVM as Tribuo, but
* avoids locking on classes Tribuo does not control.
* <p>
* See:
* <pre>
* Fan RE, Chang KW, Hsieh CJ, Wang XR, Lin CJ.
Expand Down Expand Up @@ -127,6 +131,9 @@ protected List<Model> trainModels(Parameter curParams, int numFeatures, FeatureN
data.x = features;
data.n = numFeatures;

// Note this isn't sufficient for reproducibility as it doesn't cope with concurrency.
// Concurrency safety is handled by the global lock on LibLinearTrainer.class in LibLinearTrainer.train.
Linear.resetRandom();
return Collections.singletonList(Linear.train(data,curParams));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Trainer;
import org.tribuo.anomaly.Event;
import org.tribuo.anomaly.Event.EventType;
import org.tribuo.common.libsvm.LibSVMModel;
Expand All @@ -38,11 +39,16 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Logger;

/**
* A trainer for anomaly models that uses LibSVM.
* <p>
* Note the train method is synchronized on {@code LibSVMTrainer.class} due to a global RNG in LibSVM.
* This is insufficient to ensure reproducibility if LibSVM is used directly in the same JVM as Tribuo, but
* avoids locking on classes Tribuo does not control.
* <p>
* See:
* <pre>
* Chang CC, Lin CJ.
Expand All @@ -66,11 +72,20 @@ public class LibSVMAnomalyTrainer extends LibSVMTrainer<Event> {
protected LibSVMAnomalyTrainer() {}

/**
* Creates a one-class LibSVM trainer using the supplied parameters.
* @param parameters The training parameters.
* Creates a one-class LibSVM trainer using the supplied parameters and {@link Trainer#DEFAULT_SEED}.
* @param parameters The SVM training parameters.
*/
public LibSVMAnomalyTrainer(SVMParameters<Event> parameters) {
super(parameters);
this(parameters, Trainer.DEFAULT_SEED);
}

/**
* Creates a one-class LibSVM trainer using the supplied parameters and RNG seed.
* @param parameters The SVM parameters.
* @param seed The RNG seed for LibSVM's internal RNG.
*/
public LibSVMAnomalyTrainer(SVMParameters<Event> parameters, long seed) {
super(parameters,seed);
}

/**
Expand Down Expand Up @@ -100,7 +115,7 @@ protected LibSVMModel<Event> createModel(ModelProvenance provenance, ImmutableFe
}

@Override
protected List<svm_model> trainModels(svm_parameter curParams, int numFeatures, svm_node[][] features, double[][] outputs) {
protected List<svm_model> trainModels(svm_parameter curParams, int numFeatures, svm_node[][] features, double[][] outputs, SplittableRandom localRNG) {
svm_problem problem = new svm_problem();
problem.l = outputs[0].length;
problem.x = features;
Expand All @@ -112,6 +127,9 @@ protected List<svm_model> trainModels(svm_parameter curParams, int numFeatures,
if(checkString != null) {
throw new IllegalArgumentException("Error checking SVM parameters: " + checkString);
}
// This is safe because we synchronize on LibSVMTrainer.class in the train method to
// ensure there is no concurrent use of the rng.
svm.rand.setSeed(localRNG.nextLong());
return Collections.singletonList(svm.svm_train(problem, curParams));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public SVMAnomalyType(SVMMode type) {

@Override
public boolean isClassification() {
return true;
return false;
}

@Override
Expand All @@ -74,7 +74,7 @@ public boolean isRegression() {

@Override
public boolean isAnomaly() {
return false;
return true;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@

/**
* A {@link Trainer} which wraps a liblinear-java classifier trainer.
*
* <p>
* Note the train method is synchronized on {@code LibLinearTrainer.class} due to a global RNG in liblinear-java.
* This is insufficient to ensure reproducibility if liblinear-java is used directly in the same JVM as Tribuo, but
* avoids locking on classes Tribuo does not control.
* <p>
* See:
* <pre>
* Fan RE, Chang KW, Hsieh CJ, Wang XR, Lin CJ.
Expand Down Expand Up @@ -114,6 +118,9 @@ protected List<Model> trainModels(Parameter curParams, int numFeatures, FeatureN
data.n = numFeatures;
data.bias = 1.0;

// Note this isn't sufficient for reproducibility as it doesn't cope with concurrency.
// Concurrency safety is handled by the global lock on LibLinearTrainer.class in LibLinearTrainer.train.
Linear.resetRandom();
return Collections.singletonList(Linear.train(data,curParams));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ public void testMulticlass() throws IOException, ClassNotFoundException {
for (Example<Label> example : examples) {
model.getExcuse(example);
}
//System.out.println("*** PASSED: " + prefix);
}

private void checkModelType(LinearType modelType) throws IOException, ClassNotFoundException {
Expand All @@ -129,7 +128,6 @@ private void checkModelType(LinearType modelType) throws IOException, ClassNotFo
for (Example<Label> example : examples) {
model.getExcuse(example);
}
//System.out.println("*** PASSED: " + prefix);
}


Expand Down Expand Up @@ -184,6 +182,19 @@ public Model<Label> testLibLinear(Pair<Dataset<Label>,Dataset<Label>> p) {
return m;
}

@Test
public void testReproducible() {
// Note this test will need to change if LibLinearTrainer grows a per Problem RNG.
Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.denseTrainTest();
Model<Label> m = t.train(p.getA());
Map<String, List<Pair<String,Double>>> mFeatures = m.getTopFeatures(-1);

Model<Label> mTwo = t.train(p.getA());
Map<String, List<Pair<String,Double>>> mTwoFeatures = mTwo.getTopFeatures(-1);

assertEquals(mFeatures,mTwoFeatures);
}

@Test
public void testDenseData() {
Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.denseTrainTest();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ public class LibSVMClassificationModel extends LibSVMModel<Label> {
}
}

/**
* Returns the number of support vectors.
* @return The number of support vectors.
*/
public int getNumberOfSupportVectors() {
return models.get(0).SV.length;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.classification.WeightedLabels;
import org.tribuo.common.libsvm.LibSVMModel;
Expand All @@ -39,11 +40,16 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Logger;

/**
* A trainer for classification models that uses LibSVM.
* <p>
* Note the train method is synchronized on {@code LibSVMTrainer.class} due to a global RNG in LibSVM.
* This is insufficient to ensure reproducibility if LibSVM is used directly in the same JVM as Tribuo, but
* avoids locking on classes Tribuo does not control.
* <p>
* See:
* <pre>
* Chang CC, Lin CJ.
Expand All @@ -69,10 +75,27 @@ public class LibSVMClassificationTrainer extends LibSVMTrainer<Label> implements
@Config(description="Use Label specific weights.")
private Map<String,Float> labelWeights = Collections.emptyMap();

/**
* For OLCUT.
*/
protected LibSVMClassificationTrainer() {}

/**
* Constructs a classification LibSVM trainer using the specified parameters
* and {@link Trainer#DEFAULT_SEED}.
* @param parameters The SVM parameters.
*/
public LibSVMClassificationTrainer(SVMParameters<Label> parameters) {
super(parameters);
this(parameters, Trainer.DEFAULT_SEED);
}

/**
* Constructs a classification LibSVM trainer using the specified parameters and seed.
* @param parameters The SVM parameters.
* @param seed The RNG seed for LibSVM's internal RNG.
*/
public LibSVMClassificationTrainer(SVMParameters<Label> parameters, long seed) {
super(parameters,seed);
}

/**
Expand All @@ -92,7 +115,7 @@ protected LibSVMModel<Label> createModel(ModelProvenance provenance, ImmutableFe
}

@Override
protected List<svm_model> trainModels(svm_parameter curParams, int numFeatures, svm_node[][] features, double[][] outputs) {
protected List<svm_model> trainModels(svm_parameter curParams, int numFeatures, svm_node[][] features, double[][] outputs, SplittableRandom localRNG) {
svm_problem problem = new svm_problem();
problem.l = outputs[0].length;
problem.x = features;
Expand All @@ -104,6 +127,9 @@ protected List<svm_model> trainModels(svm_parameter curParams, int numFeatures,
if(checkString != null) {
throw new IllegalArgumentException("Error checking SVM parameters: " + checkString);
}
// This is safe because we synchronize on LibSVMTrainer.class in the train method to
// ensure there is no concurrent use of the rng.
svm.rand.setSeed(localRNG.nextLong());
return Collections.singletonList(svm.svm_train(problem, curParams));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.tribuo.classification.libsvm;

import com.oracle.labs.mlrg.olcut.util.Pair;
import libsvm.svm_model;
import org.tribuo.CategoricalIDInfo;
import org.tribuo.CategoricalInfo;
import org.tribuo.Dataset;
Expand Down Expand Up @@ -56,12 +57,14 @@
import java.io.ObjectInputStream;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
Expand Down Expand Up @@ -186,6 +189,34 @@ public Model<Label> testLibSVM(Pair<Dataset<Label>,Dataset<Label>> p) {
return m;
}

@Test
public void testReproducibility() {
Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.denseTrainTest();
long seed = 42L;
SVMParameters<Label> params = new SVMParameters<>(new SVMClassificationType(SVMMode.NU_SVC),KernelType.RBF);
params.setProbability();
LibSVMTrainer<Label> first = new LibSVMClassificationTrainer(params,seed);
LibSVMModel<Label> firstModel = first.train(p.getA());

LibSVMTrainer<Label> second = new LibSVMClassificationTrainer(params,seed);
LibSVMModel<Label> secondModel = second.train(p.getA());

LibSVMModel<Label> thirdModel = second.train(p.getA());

svm_model m = firstModel.getInnerModels().get(0);
svm_model mTwo = secondModel.getInnerModels().get(0);
svm_model mThree = thirdModel.getInnerModels().get(0);

// One and two use the same RNG seed and should be identical
assertArrayEquals(m.sv_coef,mTwo.sv_coef);
assertArrayEquals(m.probA,mTwo.probA);
assertArrayEquals(m.probB,mTwo.probB);

// The RNG state of three has diverged and should produce a different model.
assertFalse(Arrays.equals(mTwo.probA,mThree.probA));
assertFalse(Arrays.equals(mTwo.probB,mThree.probB));
}

@Test
public void testDenseData() {
Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.denseTrainTest();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@
/**
* A {@link Trainer} which wraps a liblinear-java trainer.
* <p>
* Note the train method is synchronized on {@code LibLinearTrainer.class} due to a global RNG in liblinear-java.
* This is insufficient to ensure reproducibility if liblinear-java is used directly in the same JVM as Tribuo, but
* avoids locking on classes Tribuo does not control.
* <p>
* See:
* <pre>
* Fan RE, Chang KW, Hsieh CJ, Wang XR, Lin CJ.
Expand Down Expand Up @@ -139,7 +143,10 @@ public LibLinearModel<T> train(Dataset<T> examples, Map<String, Provenance> runP

Pair<FeatureNode[][],double[][]> data = extractData(examples,outputIDInfo,featureIDMap);

List<de.bwaldvogel.liblinear.Model> models = trainModels(curParams,featureIDMap.size()+1,data.getA(),data.getB());
List<de.bwaldvogel.liblinear.Model> models;
synchronized (LibLinearTrainer.class) {
models = trainModels(curParams, featureIDMap.size() + 1, data.getA(), data.getB());
}

return createModel(provenance,featureIDMap,outputIDInfo,models);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ protected LibSVMModel(String name, ModelProvenance description, ImmutableFeature
/**
* Returns an unmodifiable copy of the underlying list of libsvm models.
* <p>
* Deprecated to unify the names across LibLinear, LibSVM and XGBoost.
* @deprecated Deprecated to unify the names across LibLinear, LibSVM and XGBoost.
* @return The underlying model list.
*/
@Deprecated
Expand Down
Loading