## This notebook demonstrates Tribuo classification, for comparison with scikit-learn classification

In [1]:
%jars ./jars/tribuo-classification-experiments-4.1.0-jar-with-dependencies.jar
%jars ./jars/tribuo-classification-liblinear-4.1.0-jar-with-dependencies.jar
%jars ./jars/tribuo-json-4.1.0-jar-with-dependencies.jar

In [2]:
import java.nio.file.Paths;
import java.nio.file.Files;
import java.util.logging.Level;
import java.util.logging.Logger;

In [3]:
import org.tribuo.*;
import org.tribuo.evaluation.TrainTestSplitter;
import org.tribuo.data.csv.CSVLoader;
import org.tribuo.math.optimisers.*;
import org.tribuo.classification.*;
import org.tribuo.classification.evaluation.*;
import org.tribuo.classification.sgd.linear.LogisticRegressionTrainer;
import org.tribuo.classification.sgd.linear.LinearSGDTrainer;
import org.tribuo.classification.liblinear.LibLinearClassificationTrainer;
import org.tribuo.classification.sgd.objectives.Hinge;
import org.tribuo.classification.dtree.CARTClassificationTrainer;
import org.tribuo.classification.xgboost.XGBoostClassificationTrainer;
import org.tribuo.Trainer;
import org.tribuo.util.Util;

In [4]:
// The factory turns the response column's values into classification Labels,
// and the loader uses it when reading CSV rows.
LabelFactory labelFactory = new LabelFactory();
CSVLoader<Label> csvLoader = new CSVLoader<>(labelFactory);

In [5]:
var rainHeaders = new String[]{"Month", "MinTemp", "MaxTemp", "Rainfall", "WindGustSpeed", "WindSpeed9am", 
                               "WindSpeed3pm", "Humidity9am", "Humidity3pm", "Pressure9am", "Pressure3pm", 
                               "Temp9am", "Temp3pm", "RainToday", "WindGustDir_E", "WindGustDir_ENE", 
                               "WindGustDir_ESE", "WindGustDir_N", "WindGustDir_NE", "WindGustDir_NNE", 
                               "WindGustDir_NNW", "WindGustDir_NW", "WindGustDir_S", "WindGustDir_SE", 
                               "WindGustDir_SSE", "WindGustDir_SSW", "WindGustDir_SW", "WindGustDir_W",
                               "WindGustDir_WNW", "WindGustDir_WSW", "WindDir9am_E", "WindDir9am_ENE",
                               "WindDir9am_ESE", "WindDir9am_N", "WindDir9am_NE", "WindDir9am_NNE", "WindDir9am_NNW",
                               "WindDir9am_NW", "WindDir9am_S", "WindDir9am_SE", "WindDir9am_SSE", "WindDir9am_SSW", 
                               "WindDir9am_SW", "WindDir9am_W", "WindDir9am_WNW", "WindDir9am_WSW", "WindDir3pm_E",
                               "WindDir3pm_ENE", "WindDir3pm_ESE", "WindDir3pm_N", "WindDir3pm_NE", "WindDir3pm_NNE",
                               "WindDir3pm_NNW", "WindDir3pm_NW", "WindDir3pm_S", "WindDir3pm_SE", "WindDir3pm_SSE", 
                               "WindDir3pm_SSW", "WindDir3pm_SW", "WindDir3pm_W", "WindDir3pm_WNW", "WindDir3pm_WSW",
                               "RainTomorrowN"};
// This dataset is prepared in the notebook: scikit-learn Classifier - Data Cleanup
var weatherSource = csvLoader.loadDataSource(Paths.get("data/cleanedWeatherAUS.csv"),"RainTomorrowN",rainHeaders);
var weatherSplitter = new TrainTestSplitter<>(weatherSource,0.8,1L);

In [6]:
// Copy the train/test splits into MutableDatasets and report their dimensions.
var trainingDataset = new MutableDataset<>(weatherSplitter.getTrain());
var testingDataset = new MutableDataset<>(weatherSplitter.getTest());
System.out.printf("Training data size = %d, number of features = %d, number of classes = %d%n",
        trainingDataset.size(),
        trainingDataset.getFeatureMap().size(),
        trainingDataset.getOutputInfo().size());
System.out.printf("Testing data size = %d, number of features = %d, number of classes = %d%n",
        testingDataset.size(),
        testingDataset.getFeatureMap().size(),
        testingDataset.getOutputInfo().size());

Training data size = 112629, number of features = 62, number of classes = 2
Testing data size = 28158, number of features = 62, number of classes = 2


In [7]:
/**
 * Trains a classifier on the given data and prints how long training took.
 *
 * Training-set scores are intentionally not reported; models are scored on the
 * held-out test data instead.
 *
 * @param name      human-readable label used in the timing message
 * @param trainer   the classification trainer to fit
 * @param trainData the training dataset
 * @return the trained model
 */
public Model<Label> train(String name, Trainer<Label> trainer, Dataset<Label> trainData) {
    var startTime = System.currentTimeMillis();
    var model = trainer.train(trainData);
    var endTime = System.currentTimeMillis();
    // Util.formatDuration takes start/end timestamps in milliseconds.
    System.out.println("Training " + name + " took " + Util.formatDuration(startTime, endTime));
    return model;
}

In [8]:
/**
 * Evaluates a classification model on held-out data and prints the per-class report
 * (n, tp, fn, fp, recall, precision, F1, plus accuracy and averages).
 * For more detail, {@code evaluation.getConfusionMatrix()} can also be printed.
 *
 * @param model    the trained classification model
 * @param testData the dataset to score the model on
 */
public void evaluate(Model<Label> model, Dataset<Label> testData) {
    var evaluator = new LabelEvaluator();
    var evaluation = evaluator.evaluate(model, testData);
    System.out.println(evaluation.toString());
}

In [9]:
// Linear classifier fitted by SGD.
// It uses a Hinge objective (i.e. a linear SVM) and an AdaGrad optimiser
// (initial learning rate 0.1, epsilon 0.1), for 5 epochs with the default seed.
// Alternative optimiser: SGD.getLinearDecaySGD(0.01)
LinearSGDTrainer lrsgd = new LinearSGDTrainer(new Hinge(), new AdaGrad(0.1, 0.1), 5, Trainer.DEFAULT_SEED);

// LibLinear with its default configuration (solver L2R_L2LOSS_SVC_DUAL, per the printout below).
LibLinearClassificationTrainer lr = new LibLinearClassificationTrainer();

// CART decision tree with default settings.
CARTClassificationTrainer cart = new CARTClassificationTrainer();

// XGBoost gradient-boosted trees: 100 trees, remaining parameters at their defaults.
XGBoostClassificationTrainer xgb = new XGBoostClassificationTrainer(100);

In [10]:
// Show the full configuration of every trainer being compared.
for (Trainer<?> trainer : new Trainer<?>[]{lrsgd, lr, cart, xgb}) {
    System.out.println(trainer);
}

LinearSGDTrainer(objective=Hinge(margin=1.0),optimiser=AdaGrad(initialLearningRate=0.1,epsilon=0.1,initialValue=0.0),epochs=5,minibatchSize=1,seed=12345)
LibLinearTrainer(solver=L2R_L2LOSS_SVC_DUAL,cost=1.0,terminationCriterion=0.1,maxIterations=1000,regression-epsilon=0.1)
CARTClassificationTrainer(maxDepth=2147483647,minChildWeight=5.0,minImpurityDecrease=0.0,fractionFeaturesInSplit=1.0,useRandomSplitPoints=false,impurity=GiniIndex,seed=12345)
XGBoostTrainer(numTrees=100,parameters{colsample_bytree=1.0, tree_method=auto, seed=12345, max_depth=6, booster=gbtree, objective=multi:softprob, lambda=1.0, eta=0.3, nthread=4, alpha=0.0, subsample=1.0, gamma=0.0, min_child_weight=1.0, verbosity=0})


In [11]:
// Turn off the SGD trainers' logging, because it affects training performance.
// Keep the logger in a variable: the java.util.logging.Logger docs warn that a logger
// without a strong reference may be garbage collected, which would lose this level setting.
var sgdLoggerName = org.tribuo.common.sgd.AbstractSGDTrainer.class.getName();
var logger = Logger.getLogger(sgdLoggerName);
logger.setLevel(Level.OFF);

In [12]:
// LinearSGDTrainer with a Hinge objective is a linear SVM classifier, not linear
// regression, so it is labelled as such in the timing message.
var lrsgdModel = train("Linear SVM (SGD, hinge loss)", lrsgd, trainingDataset);

// Observed training times:
// run 1: 2.22 s
// run 2: 1.56 s
// run 3: 1.58 s

Training Linear Regression (SGD) took (00:00:01:578)


In [13]:
evaluate(lrsgdModel, testingDataset);

// Test-set results (identical in runs 1, 2 and 3):
// Class     recall    prec      f1
// No        0.959     0.854     0.903
// Yes       0.426     0.748     0.543

Class                           n          tp          fn          fp      recall        prec          f1
No                         21,889      20,988         901       3,596       0.959       0.854       0.903
Yes                         6,269       2,673       3,596         901       0.426       0.748       0.543
Total                      28,158      23,661       4,497       4,497
Accuracy                                                                    0.840
Micro Average                                                               0.840       0.840       0.840
Macro Average                                                               0.693       0.801       0.723
Balanced Error Rate                                                         0.307


In [14]:
// LibLinear's default solver (L2R_L2LOSS_SVC_DUAL, shown in the trainer printout) is an
// L2-loss linear SVM classifier, not linear regression, so it is labelled as such.
var lrModel = train("Linear SVM (LibLinear)", lr, trainingDataset);

// Observed training times:
// run 1: 8.21 s
// run 2: 7.21 s
// run 3: 6.51 s

Training Linear Regression took (00:00:06:506)


In [15]:
evaluate(lrModel, testingDataset);

// Test-set results (identical in runs 1, 2 and 3; the run 3 values are taken from the recorded output):
// Class     recall    prec      f1
// No        0.955     0.858     0.904
// Yes       0.449     0.740     0.559


Class                           n          tp          fn          fp      recall        prec          f1
No                         21,889      20,903         986       3,457       0.955       0.858       0.904
Yes                         6,269       2,812       3,457         986       0.449       0.740       0.559
Total                      28,158      23,715       4,443       4,443
Accuracy                                                                    0.842
Micro Average                                                               0.842       0.842       0.842
Macro Average                                                               0.702       0.799       0.731
Balanced Error Rate                                                         0.298


In [16]:
var cartModel = train("Decision Tree", cart, trainingDataset);

// Observed training times:
// run 1: 5.43 s
// run 2: 5.19 s
// run 3: 4.00 s

Training Decision Tree took (00:00:03:998)


In [17]:
evaluate(cartModel, testingDataset);

// Test-set results (identical in runs 1, 2 and 3):
// Class     recall    prec      f1
// No        0.896     0.861     0.878
// Yes       0.495     0.576     0.532

Class                           n          tp          fn          fp      recall        prec          f1
No                         21,889      19,604       2,285       3,168       0.896       0.861       0.878
Yes                         6,269       3,101       3,168       2,285       0.495       0.576       0.532
Total                      28,158      22,705       5,453       5,453
Accuracy                                                                    0.806
Micro Average                                                               0.806       0.806       0.806
Macro Average                                                               0.695       0.718       0.705
Balanced Error Rate                                                         0.305


In [18]:
var xgbModel = train("XGBoost", xgb, trainingDataset);

// Observed training times:
// run 1: 1 min 25 s
// run 2: 1 min 18 s
// run 3: 1 min 21 s

Training XGBoost took (00:01:21:134)


In [19]:
evaluate(xgbModel, testingDataset);

// Test-set results (identical in runs 1, 2 and 3):
// Class     recall    prec      f1
// No        0.948     0.876     0.910
// Yes       0.531     0.745     0.620

Class                           n          tp          fn          fp      recall        prec          f1
No                         21,889      20,748       1,141       2,938       0.948       0.876       0.910
Yes                         6,269       3,331       2,938       1,141       0.531       0.745       0.620
Total                      28,158      24,079       4,079       4,079
Accuracy                                                                    0.855
Micro Average                                                               0.855       0.855       0.855
Macro Average                                                               0.740       0.810       0.765
Balanced Error Rate                                                         0.260
