diff --git a/src/model/clustering.rs b/src/model/clustering.rs index 01ab082..85e4e0c 100644 --- a/src/model/clustering.rs +++ b/src/model/clustering.rs @@ -11,6 +11,7 @@ use comfy_table::{ }; use smartcore::linalg::basic::arrays::{Array1, Array2}; use smartcore::numbers::{basenum::Number, floatnum::FloatNumber, realnum::RealNumber}; +use std::collections::BTreeSet; use std::fmt::{Display, Formatter}; /// Trains clustering models @@ -53,8 +54,9 @@ where for algorithm_name in self.settings.selected_algorithms() { let algorithm = ClusteringAlgorithm::from_name(algorithm_name); let fitted = algorithm.fit(&self.x_train, &self.settings); - self.trained_algorithms - .push(TrainedClusteringAlgorithm::new(algorithm_name, fitted)); + let mut trained = TrainedClusteringAlgorithm::new(algorithm_name, fitted); + trained.compute_baseline(&self.x_train, &self.settings); + self.trained_algorithms.push(trained); } } @@ -132,6 +134,8 @@ where table.apply_modifier(UTF8_SOLID_INNER_BORDERS); table.set_header(vec![ Cell::new("Model").add_attribute(Attribute::Bold), + Cell::new("Clusters").add_attribute(Attribute::Bold), + Cell::new("Noise").add_attribute(Attribute::Bold), Cell::new("Homogeneity").add_attribute(Attribute::Bold), Cell::new("Completeness").add_attribute(Attribute::Bold), Cell::new("V-Measure").add_attribute(Attribute::Bold), @@ -144,6 +148,8 @@ where "-".to_string(), "-".to_string(), "-".to_string(), + "-".to_string(), + "-".to_string(), ]); } } else { @@ -156,6 +162,22 @@ where } } +/// Aggregate cluster statistics that do not require ground-truth labels. +#[derive(Debug, Clone, Copy)] +struct ClusterBaseline { + cluster_count: usize, + noise_count: usize, +} + +impl ClusterBaseline { + const fn new(cluster_count: usize, noise_count: usize) -> Self { + Self { + cluster_count, + noise_count, + } + } +} + /// Trained clustering algorithm with optional metrics. struct TrainedClusteringAlgorithm where @@ -167,6 +189,7 @@ where algorithm_name: ClusteringAlgorithmName, algorithm: ClusteringAlgorithm, metrics: Option>, + baseline: Option, } impl @@ -185,6 +208,7 @@ where algorithm_name, algorithm, metrics: None, + baseline: None, } } @@ -192,6 +216,27 @@ where self.algorithm.predict(x, settings) } + fn compute_baseline(&mut self, x: &InputArray, settings: &ClusteringSettings) { + let Ok(predictions) = self.predict(x, settings) else { + self.baseline = None; + return; + }; + + let mut unique_clusters: BTreeSet = BTreeSet::new(); + let mut noise_count = 0_usize; + + for label in predictions.iterator(0) { + let value = *label; + if self.algorithm_name == ClusteringAlgorithmName::DBSCAN && value == CLUSTER::zero() { + noise_count += 1; + } else { + unique_clusters.insert(value); + } + } + + self.baseline = Some(ClusterBaseline::new(unique_clusters.len(), noise_count)); + } + fn display_row(&self) -> Vec { let (homogeneity, completeness, v_measure) = if let Some(scores) = &self.metrics { let format_score = |s: Option| match s { @@ -207,8 +252,19 @@ where ("-".to_string(), "-".to_string(), "-".to_string()) }; + let (clusters, noise) = if let Some(baseline) = &self.baseline { + ( + baseline.cluster_count.to_string(), + baseline.noise_count.to_string(), + ) + } else { + ("-".to_string(), "-".to_string()) + }; + vec![ self.algorithm_name.to_string(), + clusters, + noise, homogeneity, completeness, v_measure, diff --git a/tests/clustering.rs b/tests/clustering.rs index e5e72b3..ce17bd4 100644 --- a/tests/clustering.rs +++ b/tests/clustering.rs @@ -89,6 +89,8 @@ fn clustering_model_display_shows_metrics() { assert!(output.contains("KMeans")); assert!(output.contains("Agglomerative")); assert!(output.contains("DBSCAN")); + assert!(output.contains("Clusters")); + assert!(output.contains("Noise")); assert!(output.contains("V-Measure")); assert!(output.contains("1.00")); } @@ -126,6 +128,23 @@ fn clustering_model_display_shows_configured_algorithm_when_untrained() { assert!(output.contains("Homogeneity")); } +#[test] +fn clustering_model_display_shows_baseline_without_ground_truth() { + // Arrange + let x = clustering_testing_data(); + let mut model: ClusteringModel, Vec> = + ClusteringModel::new(x.clone(), ClusteringSettings::default().with_k(2)); + model.train(); + + // Act + let output = format!("{model}"); + + // Assert + assert!(output.contains("Clusters")); + assert!(output.contains("Noise")); + assert!(output.contains('2')); +} + #[test] fn clustering_model_display_clears_metrics_after_retraining() { // Arrange