## This notebook generates Tribuo HDBSCAN cluster predictions for a dataset of 5000 records (5 centers, 4 features), split into 4000 training and 1000 prediction records

In [1]:
%jars ../../../jars/junit-jupiter-api-5.7.0.jar
%jars ../../../jars/opentest4j-1.2.0.jar
%jars ../../../jars/junit-platform-commons-1.7.1.jar
%jars ../../../jars/tribuo-clustering-hdbscan-4.3.0-jar-with-dependencies.jar

In [2]:
import java.nio.file.Paths;
import java.util.*;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.*;

In [3]:
import org.tribuo.Dataset;
import org.tribuo.Feature;
import org.tribuo.MutableDataset;
import org.tribuo.Prediction;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ClusteringFactory;
import org.tribuo.clustering.hdbscan.HdbscanModel;
import org.tribuo.clustering.hdbscan.HdbscanTrainer;
import org.tribuo.data.columnar.FieldProcessor;
import org.tribuo.data.columnar.ResponseProcessor;
import org.tribuo.data.columnar.RowProcessor;
import org.tribuo.data.columnar.processors.field.DoubleFieldProcessor;
import org.tribuo.data.columnar.processors.response.EmptyResponseProcessor;
import org.tribuo.data.csv.CSVDataSource;
import org.tribuo.data.csv.CSVLoader;
import org.tribuo.math.distance.DistanceType;
import org.tribuo.math.neighbour.NeighboursQueryFactoryType;
import org.tribuo.util.Util;

In [4]:
// Rows are parsed with an EmptyResponseProcessor (no output column is read) and one
// DoubleFieldProcessor per feature column, keyed by the CSV column name.
ClusteringFactory clusteringFactory = new ClusteringFactory();
ResponseProcessor<ClusterID> emptyResponseProcessor = new EmptyResponseProcessor<>(clusteringFactory);

Map<String, FieldProcessor> regexMappingProcessors = new HashMap<>();
for (String featureColumn : List.of("Feature1", "Feature2", "Feature3", "Feature4")) {
    regexMappingProcessors.put(featureColumn, new DoubleFieldProcessor(featureColumn));
}
RowProcessor<ClusterID> rowProcessor = new RowProcessor<>(emptyResponseProcessor, regexMappingProcessors);

// The final `false` is CSVDataSource's outputRequired flag: rows are accepted without an output value.
CSVDataSource<ClusterID> csvDataSource =
        new CSVDataSource<>(Paths.get("../../../data/big-gaussians-5centers-train.csv"), rowProcessor, false);
Dataset<ClusterID> dataset = new MutableDataset<>(csvDataSource);

// Held-out records that the trained model will assign to clusters.
CSVDataSource<ClusterID> csvTestSource =
        new CSVDataSource<>(Paths.get("../../../data/big-gaussians-5centers-predict.csv"), rowProcessor, false);
Dataset<ClusterID> predictSet = new MutableDataset<>(csvTestSource);


In [5]:
System.out.println(String.format("Data size = %d, number of features = %d",dataset.size(),dataset.getFeatureMap().size()));
System.out.println(String.format("Predict Data size = %d, number of features = %d",predictSet.size(),predictSet.getFeatureMap().size()));


Data size = 4000, number of features = 4
Predict Data size = 1000, number of features = 4


In [6]:
// Positional arguments of HdbscanTrainer, named for readability.
// NOTE(review): names follow the Tribuo 4.3 constructor order
// (minClusterSize, distance, k, numThreads, neighbour-query factory) — confirm against the javadoc.
int minClusterSize = 8;
int k = 8;
int numThreads = 2;
var trainer = new HdbscanTrainer(minClusterSize, DistanceType.L2.getDistance(), k, numThreads,
        NeighboursQueryFactoryType.KD_TREE);
var model = trainer.train(dataset);

In [7]:
System.out.println(model.getClusterLabels().size());

4000


In [8]:
// System.out.println(model.getClusterLabels());

In [9]:
// System.out.println(model.getOutlierScores());

In [10]:
List<Prediction<ClusterID>> predictions = model.predict(predictSet);
int i = 0;
List<Integer> actualLabelPredictions = new ArrayList<>();
for (Prediction<ClusterID> pred : predictions) {
    actualLabelPredictions.add(pred.getOutput().getID());
    i++;
}

In [11]:
// Print the predicted cluster ID of each prediction-set record.
System.out.println(actualLabelPredictions);

[8, 6, 6, 6, 3, 9, 5, 6, 6, 3, 3, 8, 5, 9, 5, 3, 3, 5, 5, 8, 8, 8, 3, 8, 8, 8, 3, 6, 6, 9, 3, 5, 6, 6, 6, 8, 8, 9, 5, 9, 9, 8, 5, 6, 8, 3, 6, 3, 3, 3, 9, 9, 5, 6, 3, 9, 9, 8, 8, 3, 3, 9, 6, 9, 9, 5, 3, 3, 6, 5, 3, 9, 8, 3, 6, 6, 9, 8, 8, 9, 6, 3, 8, 6, 9, 8, 6, 5, 8, 3, 3, 9, 8, 9, 6, 9, 6, 5, 5, 9, 5, 8, 8, 8, 8, 3, 6, 5, 3, 8, 9, 8, 6, 9, 6, 6, 9, 8, 9, 9, 5, 9, 6, 6, 6, 3, 6, 3, 5, 5, 6, 3, 9, 5, 6, 3, 8, 9, 9, 8, 9, 5, 8, 8, 3, 6, 3, 3, 8, 3, 8, 3, 5, 6, 6, 3, 5, 5, 5, 3, 8, 8, 8, 3, 3, 5, 9, 5, 9, 9, 9, 5, 5, 8, 3, 8, 3, 3, 9, 9, 6, 3, 8, 5, 9, 6, 9, 3, 5, 9, 6, 6, 6, 3, 8, 3, 9, 5, 5, 8, 8, 8, 9, 3, 6, 6, 9, 3, 9, 6, 9, 6, 9, 5, 8, 3, 8, 8, 8, 8, 8, 5, 5, 8, 9, 6, 8, 8, 3, 5, 6, 9, 8, 3, 3, 6, 6, 6, 6, 6, 6, 6, 8, 9, 6, 8, 9, 5, 9, 5, 8, 5, 8, 5, 6, 9, 9, 5, 9, 5, 6, 9, 9, 6, 8, 6, 6, 8, 3, 9, 3, 6, 6, 9, 6, 8, 6, 5, 5, 6, 8, 8, 6, 8, 9, 6, 5, 6, 8, 5, 9, 8, 5, 8, 9, 9, 5, 8, 3, 3, 9, 8, 6, 8, 6, 3, 8, 9, 5, 9, 9, 3, 8, 5, 8, 8, 9, 5, 9, 8, 8, 9, 6, 5, 5, 9, 9, 9, 5, 3, 6, 9, 3, 