# Multilabel Classification

In [1]:
%jars ../tribuo/MultiLabel/SGD/target/tribuo-multilabel-sgd-4.2.0-SNAPSHOT-jar-with-dependencies.jar
%jars ../tribuo/Common/NearestNeighbour/target/tribuo-common-nearest-neighbour-4.2.0-SNAPSHOT.jar
%jars ../tribuo/Classification/SGD/target/tribuo-classification-sgd-4.2.0-SNAPSHOT-jar-with-dependencies.jar
%jars ../tribuo/Classification/DecisionTree/target/tribuo-classification-tree-4.2.0-SNAPSHOT-jar-with-dependencies.jar

%jars ../tribuo/Reproducibility/target/tribuo-reproducibility-4.2.0-SNAPSHOT-jar-with-dependencies.jar

In [2]:
import org.tribuo.*;
import org.tribuo.classification.Label;
import org.tribuo.classification.dtree.CARTClassificationTrainer;
import org.tribuo.classification.dtree.impurity.*;
import org.tribuo.datasource.*;
import org.tribuo.math.optimisers.*;
import org.tribuo.multilabel.*;
import org.tribuo.multilabel.baseline.*;
import org.tribuo.multilabel.ensemble.*;
import org.tribuo.multilabel.evaluation.*;
import org.tribuo.multilabel.sgd.linear.*;
import org.tribuo.multilabel.sgd.objectives.*;
import org.tribuo.util.Util;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.ConfigurationData;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;

import java.nio.file.Paths;

import org.tribuo.reproducibility.ReproUtil;

In [3]:
// Start a fresh results file that contains only the CSV header row. Opening the
// FileWriter without the append flag truncates any file left by an earlier run.
// try-with-resources closes (and therefore flushes) the writer even if the
// write itself throws, so the file handle is never leaked.
try (FileWriter fw = new FileWriter("./configMultilabelResults.csv")) {
    fw.append("Task, Trainer, Model, Equivalent Evaluation, Model Prov Diff, Dataset Name, Datasource\n");
}

/**
 * Prepares a value for use as a single field of a one-line CSV record.
 * <p>
 * Every line break in the value is replaced by a single space. If the
 * flattened value contains a comma, a double quote or a single quote, it is
 * wrapped in double quotes and each embedded double quote is doubled.
 *
 * @param data the raw field value; must not be null
 * @return the value flattened to one line and quoted when necessary
 */
public String escapeSpecialCharacters(String data) {
    // Flatten first and keep working on the flattened text, so the line-break
    // replacement is not lost when the value also needs quoting.
    String escapedData = data.replaceAll("\\R", " ");
    boolean needsQuoting = escapedData.contains(",")
            || escapedData.contains("\"")
            || escapedData.contains("'");
    if (needsQuoting) {
        escapedData = "\"" + escapedData.replace("\"", "\"\"") + "\"";
    }
    return escapedData;
}

/**
 * Appends one result row to {@code ./configMultilabelResults.csv}.
 * <p>
 * Every column is passed through {@code escapeSpecialCharacters}, so a comma,
 * quote or line break in any value cannot break the row layout.
 *
 * @param task     value for the "Task" column
 * @param trainer  value for the "Trainer" column
 * @param model    value for the "Model" column
 * @param equal    value for the "Equivalent Evaluation" column
 * @param diff     value for the "Model Prov Diff" column
 * @param dataset  value for the "Dataset Name" column
 * @param datatype value for the "Datasource" column
 * @throws Exception if the results file cannot be opened or written
 */
public void addToCSV(String task, String trainer, String model, String equal, String diff, String dataset, String datatype) throws Exception{
    // Build the whole row before opening the file so an escaping failure
    // never leaves a partially written line behind.
    String row = String.join(",",
            escapeSpecialCharacters(task),
            escapeSpecialCharacters(trainer),
            escapeSpecialCharacters(model),
            escapeSpecialCharacters(equal),
            escapeSpecialCharacters(diff),
            escapeSpecialCharacters(dataset),
            escapeSpecialCharacters(datatype)) + "\n";

    // The second constructor argument opens the file in append mode;
    // try-with-resources flushes and closes it even when the write fails.
    try (FileWriter fw = new FileWriter("./configMultilabelResults.csv", true)) {
        fw.append(row);
    }
}

In [4]:
// Location of the XML configuration file, relative to the working directory.
var configPath = Paths.get("configs","all-multilabel-config.xml");

// Configuration manager backed by that file.
ConfigurationManager cm = new ConfigurationManager(configPath.toString());

// Every Trainer found in the configuration, keyed by a String name
// (the keys are what the training loop filters on and prints).
HashMap<String,Trainer> mlTrainers =
        (HashMap<String,Trainer>) cm.lookupAllMap(Trainer.class);

In [5]:
// Output factory shared by both data sources.
MultiLabelFactory factory = new MultiLabelFactory();

// Training split in LibSVM format.
LibSVMDataSource<MultiLabel> trainSource =
        new LibSVMDataSource<>(Paths.get("yeast_train.svm"),factory);

// Test split: reuses the training source's zero-indexing flag and maximum
// feature id so that both splits are read with the same settings.
LibSVMDataSource<MultiLabel> testSource =
        new LibSVMDataSource<>(Paths.get("yeast_test.svm"),factory,trainSource.isZeroIndexed(),trainSource.getMaxFeatureID());

// Datasets built from the two sources.
MutableDataset<MultiLabel> train = new MutableDataset<>(trainSource);
MutableDataset<MultiLabel> test = new MutableDataset<>(testSource);

In [6]:
// For every configured trainer whose key starts with "ml": train a model on the
// training set, rebuild a second model from the first one's provenance, evaluate
// both on the test set, and append a comparison row to the results CSV.
for (var trainerEntry : mlTrainers.entrySet()){
    String trainerKey = trainerEntry.getKey();

    // startsWith is safe for keys shorter than two characters, where
    // substring(0, 2) would throw StringIndexOutOfBoundsException.
    if(!trainerKey.startsWith("ml")){
        continue;
    }
    System.out.println(trainerKey);

    // Taken from the entry so the map is not queried again for each use.
    Trainer trainer = trainerEntry.getValue();
    Model<MultiLabel> model = trainer.train(train);
    ReproUtil repro = new ReproUtil(model);
    Model<MultiLabel> newModel = repro.reproduceFromProvenance();

    MultiLabelEvaluator eval = new MultiLabelEvaluator();
    var oldEvaluation = eval.evaluate(model,test);
    var newEvaluation = eval.evaluate(newModel,test);

    // The two evaluations count as equivalent when their string forms match.
    boolean sameEvaluation = oldEvaluation.toString().equals(newEvaluation.toString());

    addToCSV("Multilabel Classification",
             trainer.getClass().toString(),
             model.getProvenance().getClassName(),
             String.valueOf(sameEvaluation),
             ReproUtil.diffProvenance(model.getProvenance(), newModel.getProvenance()),
             "Yeast", "LibSVM");
}

ml-logistic
ml-cc-ensemble
ml-3-nn
ml-cc
ml-br
