## This simple classifier was used to make the initial comparison between scikit-learn and Tribuo

In [1]:
%jars ./jars/tribuo-classification-experiments-4.1.0-jar-with-dependencies.jar
%jars ./jars/tribuo-json-4.1.0-jar-with-dependencies.jar

In [2]:
import java.nio.file.Paths;
import java.nio.file.Files;

In [3]:
import org.tribuo.*;
import org.tribuo.evaluation.TrainTestSplitter;
import org.tribuo.data.csv.CSVLoader;
import org.tribuo.classification.*;
import org.tribuo.classification.evaluation.*;
import org.tribuo.classification.sgd.linear.LogisticRegressionTrainer;
import org.tribuo.classification.sgd.linear.LinearSGDTrainer;
import org.tribuo.classification.sgd.objectives.LogMulticlass;
import org.tribuo.math.optimisers.AdaGrad;
import org.tribuo.Trainer;

In [4]:
import com.fasterxml.jackson.databind.*;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.config.json.*;

In [5]:
// Classification setup: the CSV loader is constructed with a LabelFactory,
// so everything it loads is typed with Label outputs.
LabelFactory labelFactory = new LabelFactory();
CSVLoader<Label> csvLoader = new CSVLoader<>(labelFactory);

In [6]:
// The classic Iris dataset. Column names are passed explicitly through irisHeaders,
// and "species" is named as the response column.
String[] irisHeaders = {"sepalLength", "sepalWidth", "petalLength", "petalWidth", "species"};
var irisesSource = csvLoader.loadDataSource(
        Paths.get("data/bezdekIris.data"),
        "species",
        irisHeaders);
// Train/test split with a 0.7 training proportion and a fixed seed of 1.
TrainTestSplitter<Label> irisSplitter = new TrainTestSplitter<>(irisesSource, 0.7, 1L);

In [7]:
// Materialise the training split and report its size, feature count and class count.
MutableDataset<Label> trainingDataset = new MutableDataset<>(irisSplitter.getTrain());
System.out.println(String.format(
        "Training data size = %d, number of features = %d, number of classes = %d",
        trainingDataset.size(),
        trainingDataset.getFeatureMap().size(),
        trainingDataset.getOutputInfo().size()));

// Same for the held-out testing split.
MutableDataset<Label> testingDataset = new MutableDataset<>(irisSplitter.getTest());
System.out.println(String.format(
        "Testing data size = %d, number of features = %d, number of classes = %d",
        testingDataset.size(),
        testingDataset.getFeatureMap().size(),
        testingDataset.getOutputInfo().size()));

Training data size = 105, number of features = 4, number of classes = 3
Testing data size = 45, number of features = 4, number of classes = 3


In [8]:
// Disabled alternative: the library's LogisticRegressionTrainer.
// Trainer<Label> lrTrainer = new LogisticRegressionTrainer();
// System.out.println(lrTrainer.toString());

// Linear SGD trainer with an explicit configuration: LogMulticlass objective,
// AdaGrad optimiser (initial learning rate 1.0, epsilon 0.1), 1000 epochs,
// and the library's default seed.
Trainer<Label> sgdTrainer = new LinearSGDTrainer(
        new LogMulticlass(),
        new AdaGrad(1.0, 0.1),
        1000,
        Trainer.DEFAULT_SEED);
// Printing the trainer shows its full configuration.
System.out.println(sgdTrainer);

LinearSGDTrainer(objective=LogMulticlass,optimiser=AdaGrad(initialLearningRate=1.0,epsilon=0.1,initialValue=0.0),epochs=1000,minibatchSize=1,seed=12345)


In [9]:
// Train a model on the training split with the SGD trainer.
// Disabled alternative (needs an lrTrainer to be defined first):
// Model<Label> sgdModel = lrTrainer.train(trainingDataset);
Model<Label> sgdModel = sgdTrainer.train(trainingDataset);

In [10]:
// Score the trained model against the held-out testing split and print the
// evaluation's textual summary.
LabelEvaluator evaluator = new LabelEvaluator();
LabelEvaluation evaluation = evaluator.evaluate(sgdModel, testingDataset);
System.out.println(evaluation);

Class                           n          tp          fn          fp      recall        prec          f1
Iris-versicolor                16          16           0           1       1.000       0.941       0.970
Iris-virginica                 15          14           1           0       0.933       1.000       0.966
Iris-setosa                    14          14           0           0       1.000       1.000       1.000
Total                          45          44           1           1
Accuracy                                                                    0.978
Micro Average                                                               0.978       0.978       0.978
Macro Average                                                               0.978       0.980       0.978
Balanced Error Rate                                                         0.022


In [11]:
// Print the confusion matrix from the evaluation computed earlier.
System.out.println(evaluation.getConfusionMatrix());

                   Iris-versicolor   Iris-virginica      Iris-setosa
Iris-versicolor                 16                0                0
Iris-virginica                   1               14                0
Iris-setosa                      0                0               14

