diff --git a/src/model/clustering.rs b/src/model/clustering.rs
index 01ab082..85e4e0c 100644
--- a/src/model/clustering.rs
+++ b/src/model/clustering.rs
@@ -11,6 +11,7 @@ use comfy_table::{
};
use smartcore::linalg::basic::arrays::{Array1, Array2};
use smartcore::numbers::{basenum::Number, floatnum::FloatNumber, realnum::RealNumber};
+use std::collections::BTreeSet;
use std::fmt::{Display, Formatter};
/// Trains clustering models
@@ -53,8 +54,9 @@ where
for algorithm_name in self.settings.selected_algorithms() {
let algorithm = ClusteringAlgorithm::from_name(algorithm_name);
let fitted = algorithm.fit(&self.x_train, &self.settings);
- self.trained_algorithms
- .push(TrainedClusteringAlgorithm::new(algorithm_name, fitted));
+ let mut trained = TrainedClusteringAlgorithm::new(algorithm_name, fitted);
+ trained.compute_baseline(&self.x_train, &self.settings);
+ self.trained_algorithms.push(trained);
}
}
@@ -132,6 +134,8 @@ where
table.apply_modifier(UTF8_SOLID_INNER_BORDERS);
table.set_header(vec![
Cell::new("Model").add_attribute(Attribute::Bold),
+ Cell::new("Clusters").add_attribute(Attribute::Bold),
+ Cell::new("Noise").add_attribute(Attribute::Bold),
Cell::new("Homogeneity").add_attribute(Attribute::Bold),
Cell::new("Completeness").add_attribute(Attribute::Bold),
Cell::new("V-Measure").add_attribute(Attribute::Bold),
@@ -144,6 +148,8 @@ where
"-".to_string(),
"-".to_string(),
"-".to_string(),
+ "-".to_string(),
+ "-".to_string(),
]);
}
} else {
@@ -156,6 +162,22 @@ where
}
}
+/// Aggregate cluster statistics that do not require ground-truth labels.
+#[derive(Debug, Clone, Copy)]
+struct ClusterBaseline {
+ cluster_count: usize,
+ noise_count: usize,
+}
+
+impl ClusterBaseline {
+ const fn new(cluster_count: usize, noise_count: usize) -> Self {
+ Self {
+ cluster_count,
+ noise_count,
+ }
+ }
+}
+
/// Trained clustering algorithm with optional metrics.
struct TrainedClusteringAlgorithm
where
@@ -167,6 +189,7 @@ where
algorithm_name: ClusteringAlgorithmName,
algorithm: ClusteringAlgorithm,
metrics: Option>,
+ baseline: Option,
}
impl
@@ -185,6 +208,7 @@ where
algorithm_name,
algorithm,
metrics: None,
+ baseline: None,
}
}
@@ -192,6 +216,27 @@ where
self.algorithm.predict(x, settings)
}
+ fn compute_baseline(&mut self, x: &InputArray, settings: &ClusteringSettings) {
+ let Ok(predictions) = self.predict(x, settings) else {
+ self.baseline = None;
+ return;
+ };
+
+ let mut unique_clusters: BTreeSet = BTreeSet::new();
+ let mut noise_count = 0_usize;
+
+ for label in predictions.iterator(0) {
+ let value = *label;
+ if self.algorithm_name == ClusteringAlgorithmName::DBSCAN && value == CLUSTER::zero() {
+ noise_count += 1;
+ } else {
+ unique_clusters.insert(value);
+ }
+ }
+
+ self.baseline = Some(ClusterBaseline::new(unique_clusters.len(), noise_count));
+ }
+
fn display_row(&self) -> Vec {
let (homogeneity, completeness, v_measure) = if let Some(scores) = &self.metrics {
let format_score = |s: Option| match s {
@@ -207,8 +252,19 @@ where
("-".to_string(), "-".to_string(), "-".to_string())
};
+ let (clusters, noise) = if let Some(baseline) = &self.baseline {
+ (
+ baseline.cluster_count.to_string(),
+ baseline.noise_count.to_string(),
+ )
+ } else {
+ ("-".to_string(), "-".to_string())
+ };
+
vec![
self.algorithm_name.to_string(),
+ clusters,
+ noise,
homogeneity,
completeness,
v_measure,
diff --git a/tests/clustering.rs b/tests/clustering.rs
index e5e72b3..ce17bd4 100644
--- a/tests/clustering.rs
+++ b/tests/clustering.rs
@@ -89,6 +89,8 @@ fn clustering_model_display_shows_metrics() {
assert!(output.contains("KMeans"));
assert!(output.contains("Agglomerative"));
assert!(output.contains("DBSCAN"));
+ assert!(output.contains("Clusters"));
+ assert!(output.contains("Noise"));
assert!(output.contains("V-Measure"));
assert!(output.contains("1.00"));
}
@@ -126,6 +128,23 @@ fn clustering_model_display_shows_configured_algorithm_when_untrained() {
assert!(output.contains("Homogeneity"));
}
+#[test]
+fn clustering_model_display_shows_baseline_without_ground_truth() {
+ // Arrange
+ let x = clustering_testing_data();
+ let mut model: ClusteringModel, Vec> =
+ ClusteringModel::new(x.clone(), ClusteringSettings::default().with_k(2));
+ model.train();
+
+ // Act
+ let output = format!("{model}");
+
+ // Assert
+ assert!(output.contains("Clusters"));
+ assert!(output.contains("Noise"));
+ assert!(output.contains('2'));
+}
+
#[test]
fn clustering_model_display_clears_metrics_after_retraining() {
// Arrange