In [1]:
%jars ../../../jars/junit-jupiter-api-5.7.0.jar
%jars ../../../jars/opentest4j-1.2.0.jar
%jars ../../../jars/junit-platform-commons-1.7.1.jar
%jars ../../../jars/tribuo-clustering-hdbscan-4.3.0-SNAPSHOT-jar-with-dependencies.jar
%jars ../../../jars/opencsv-5.4.jar

In [2]:
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.*;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.*;
import com.opencsv.CSVReaderHeaderAware;

In [3]:
import org.tribuo.Dataset;
import org.tribuo.Feature;
import org.tribuo.MutableDataset;
import org.tribuo.Prediction;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ClusteringFactory;
import org.tribuo.clustering.hdbscan.HdbscanModel;
import org.tribuo.clustering.hdbscan.HdbscanTrainer;
import org.tribuo.data.columnar.FieldProcessor;
import org.tribuo.data.columnar.ResponseProcessor;
import org.tribuo.data.columnar.RowProcessor;
import org.tribuo.data.columnar.processors.field.DoubleFieldProcessor;
import org.tribuo.data.columnar.processors.response.EmptyResponseProcessor;
import org.tribuo.data.csv.CSVDataSource;
import org.tribuo.data.csv.CSVLoader;
import org.tribuo.math.distance.DistanceType;
import org.tribuo.math.neighbour.NeighboursQueryFactoryType;
import org.tribuo.util.Util;

In [4]:
String train_csv_path = "../../../data/synthetic_data_training_scaled.csv";
// One DoubleFieldProcessor per CSV column, keyed by the column name taken from the header.
Map<String, FieldProcessor> regexMappingProcessors = new HashMap<>();
// try-with-resources closes the header reader as soon as the column names are known
// (the original left it open for the lifetime of the kernel). Files.newBufferedReader
// decodes UTF-8 instead of the platform-default charset used by FileReader.
try (var reader = new CSVReaderHeaderAware(Files.newBufferedReader(Paths.get(train_csv_path)))) {
    // readMap() consumes the header plus the first data row and returns that row keyed by
    // column name; per the opencsv docs it returns null when there is no further input.
    Map<String, String> record = reader.readMap();
    if (record == null) {
        throw new IllegalStateException("No data rows found in " + train_csv_path);
    }
    for (String columnName : record.keySet()) {
        regexMappingProcessors.put(columnName, new DoubleFieldProcessor(columnName));
    }
}

// No response column is read (EmptyResponseProcessor): the data is unlabelled for clustering.
RowProcessor<ClusterID> rowProcessor = new RowProcessor<>(new EmptyResponseProcessor<>(new ClusteringFactory()), regexMappingProcessors);
CSVDataSource<ClusterID> csvSource = new CSVDataSource<>(Paths.get(train_csv_path), rowProcessor, false);
// Dataset creation from CSV
Dataset<ClusterID> dataset = new MutableDataset<>(csvSource);

In [5]:
System.out.println(String.format("Data size = %d, number of features = %d",dataset.size(),dataset.getFeatureMap().size()));


Data size = 2500, number of features = 1013


In [6]:
// First run: HDBSCAN with minClusterSize = 5, L2 distance and brute-force neighbour queries.
// NOTE(review): the two bare ints after the distance type are presumably k (5) and the
// thread count (2) — confirm against the HdbscanTrainer constructor.
HdbscanTrainer trainer = new HdbscanTrainer(
        5,
        DistanceType.L2,
        5,
        2,
        NeighboursQueryFactoryType.BRUTE_FORCE);
HdbscanModel model = trainer.train(dataset);

In [7]:
// Walk the per-point cluster labels, printing every point whose label is non-zero and
// collecting its position; label 0 counts as "not in a cluster" for this report.
// `i` and `indexList` stay notebook-level because later cells reuse them.
int i = 0;
ArrayList<Integer> indexList = new ArrayList<>();

for (Integer label : model.getClusterLabels()) {
    boolean inCluster = label != 0;
    if (inCluster) {
        System.out.println("data point " + i + " has label " + label);
        indexList.add(i);
    }
    i++;
}

if (indexList.size() == 0) {
    System.out.println("No clusters identified");
}

No clusters identified


In [8]:
for (Integer index : indexList) {
    System.out.println("data point " + index + " has outlier score " + model.getOutlierScores().get(index));
}

if (indexList.isEmpty()) {
    System.out.println("No clusters identified");
}

No clusters identified


In [9]:
// Second run: identical settings except minClusterSize drops from 5 to 3.
trainer = new HdbscanTrainer(3, DistanceType.L2, 5, 2,
        NeighboursQueryFactoryType.BRUTE_FORCE);
model = trainer.train(dataset);

In [10]:
System.out.println(model.getClusterLabels().size());

2500


In [11]:
// Re-run the cluster report for the minClusterSize = 3 model: reset the shared counter and
// index list, print every point with a non-zero label and remember its position.
i = 0;
indexList = new ArrayList<Integer>();

for (Integer label : model.getClusterLabels()) {
    if (label != 0) {
        indexList.add(i);
        System.out.println("data point " + i + " has label " + label);
    }
    i++;
}

// Report an empty result the same way the minClusterSize = 5 run does.
if (indexList.isEmpty()) {
    System.out.println("No clusters identified");
}

data point 55 has label 2
data point 89 has label 2
data point 660 has label 2
data point 726 has label 3
data point 898 has label 3
data point 2049 has label 3
data point 2089 has label 3
data point 2351 has label 3
data point 2404 has label 2
data point 2416 has label 3


In [12]:
for (Integer index : indexList) {
    System.out.println("data point " + index + " has outlier score " + model.getOutlierScores().get(index));
}

data point 55 has outlier score 0.0
data point 89 has outlier score 1.0298608123904263E-4
data point 660 has outlier score 0.0
data point 726 has outlier score 0.0
data point 898 has outlier score 0.00657807853522574
data point 2049 has outlier score 0.0
data point 2089 has outlier score 0.0051976652402137
data point 2351 has outlier score 0.006594760113519671
data point 2404 has outlier score 0.0
data point 2416 has outlier score 0.0


In [13]:
System.out.println(model.getOutlierScores());

[0.04370600067141073, 0.06170271994717813, 0.05768754110232288, 0.06358733474269485, 0.06612283965439336, 0.02848406936645631, 0.0371003336836222, 0.054453997475938265, 0.05516287583571433, 0.0764498067109608, 0.05032605184197203, 0.05608685893458165, 0.01973009974812201, 0.07275131026065895, 0.11013072258910428, 0.06228760084146823, 0.05631215604250661, 0.04364332739642929, 0.04340325890380048, 0.04165607664698834, 0.040744553560097985, 0.04768510432447437, 0.055211605690372534, 0.04502316721526034, 0.06355652195450401, 0.04528713029125653, 0.03959533939608528, 0.02970385789368246, 0.052385172238500366, 0.06274974089341445, 0.05380670678084487, 0.038948483224989894, 0.06444395366012967, 0.05446837408845395, 0.07967502420749273, 0.02309926520681327, 0.052219636442236705, 0.05809830622702428, 0.03855252589648728, 0.0477262191439326, 0.07578383508343112, 0.03280480662965857, 0.03634470495576447, 0.04148909855864047, 0.05005897852943153, 0.06516458069501574, 0.026702597075294765, 0.067816

680054287, 0.027514415039316176, 0.03565335774997347, 0.05247169428814302, 0.04420276363800646, 0.06193081021128066, 0.05077844505882212, 0.036424639507253054, 0.040981664700412646, 0.050428849631916406, 0.032110812662432875, 0.09659951572715475, 0.03982050390803227, 0.07901583940327528, 0.04437223759321529, 0.04217653329103288, 0.06237848380644506, 0.09190718193971659, 0.05811433599989668, 0.05561518924261166, 0.056961971461188754, 0.033106105927066865, 0.05676476294466015, 0.03269639714114858, 0.043616484856089155, 0.044695783341390904, 0.03122480642515102, 0.050540829984592306, 0.05877091408236079, 0.013943982237902608, 0.06271477828175454, 0.03131736091035531, 0.04975527507686073, 0.026815327455338545, 0.045856655161370385, 0.05565185828847463, 0.05154137837086992, 0.01800840510586632, 0.06264013568395765, 0.060013590418363294, 0.024319682991367997, 0.04465344847790642, 0.02131729331190213, 0.035017615845195915, 0.05539991925873067, 0.06017722086408772, 0.025624625730331085, 0.0520

587860754, 0.06135356850796725, 0.05278078242384043, 0.04519458155839484, 0.05939504119727601, 0.0475818171474065, 0.0541575674439867, 0.06292491309099879, 0.05922504590336475, 0.06768320451532839, 0.045969240626757224, 0.05298992581934592, 0.021862235621462278, 0.04601107308724062, 0.051128516659990075, 0.0533778072817197, 0.03809623347135549, 0.04156944183246136, 0.044816645122416454, 0.013240365834809809, 0.03739020108349822, 0.05516418469331241, 0.07092677657426427, 0.06486845545679776, 0.04046266016589706, 0.045377416247933344, 0.024903079074797385, 0.03224519743194709, 0.031345800280183966, 0.046623349720296026, 0.03776963118740895, 0.047878950003297005, 0.053520096297947184, 0.07503811043104858, 0.06821840005371149, 0.010167221068000387, 0.05925871271051619, 0.031104651452736443, 0.03192528811031348, 0.059281962105048325, 0.023427805316783124, 0.06351596144920968, 0.0693635346719026, 0.045080394590996, 0.08595852088739242, 0.03561607159953717, 0.07297256212362657, 0.081717854874

5571327935, 0.056976177012074136, 0.04818521377146301, 0.05747089173438902, 0.026634737356925342, 0.03500958641902685, 0.03700109678323538, 0.027368953216127645, 0.06585089465891802, 0.06946096327215634, 0.01904610399434592, 0.056697845542403424, 0.02563565679405444, 0.050146626664409166, 0.08753791976740577, 0.05146034943286859, 0.0548894780382152, 0.07284677840454612, 0.06927867512887187, 0.059125767507222426, 0.034773697463826414, 0.06459703580155518, 0.03213289088638471, 0.07317457751943135, 0.03677497974513133, 0.05389870329138435, 0.04423144899253462, 0.04853306740367025, 0.048859824649504224, 0.006594760113519671, 0.056226491918556554, 0.03181894602988633, 0.05118893248421286, 0.08034995182306781, 0.06288758342009682, 0.07038087940526905, 0.04896340716066816, 0.0670284940339354, 0.06781046312229022, 0.050786151535848645, 0.04178427441135957, 0.04827317203814796, 0.058437440363382165, 0.04860672655998144, 0.036875517368259336, 0.03135806874784586, 0.0602039586349391, 0.0535391749