diff --git a/examples/minimal_clustering.rs b/examples/minimal_clustering.rs index fd4fc17..495cade 100644 --- a/examples/minimal_clustering.rs +++ b/examples/minimal_clustering.rs @@ -25,6 +25,11 @@ fn main() { let mut model = ClusteringModel::new(x.clone(), settings); model.train(); + // Evaluate clustering quality using the known ground-truth assignments + // for this fixture dataset. + let truth = vec![1_u8, 1, 2, 2]; + model.evaluate(&truth); + // Print trained results println!("{model}"); diff --git a/tests/clustering.rs b/tests/clustering.rs index 11b2fb7..e5e72b3 100644 --- a/tests/clustering.rs +++ b/tests/clustering.rs @@ -5,6 +5,7 @@ use automl::{ ClusteringModel, ModelError, metrics::ClusterMetrics, settings::{ClusteringAlgorithmName, ClusteringSettings}, + utils::load_csv_features, }; use clustering_data::clustering_testing_data; use smartcore::linalg::basic::matrix::DenseMatrix; @@ -92,6 +93,23 @@ fn clustering_model_display_shows_metrics() { assert!(output.contains("1.00")); } +#[test] +fn clustering_model_display_shows_metrics_for_fixture_csv() { + // Arrange + let x = load_csv_features("tests/fixtures/clustering_points.csv").unwrap(); + let mut model: ClusteringModel, Vec> = + ClusteringModel::new(x.clone(), ClusteringSettings::default().with_k(2)); + model.train(); + let truth = vec![1_u8, 1, 2, 2]; + model.evaluate(&truth); + + // Act + let output = format!("{model}"); + + // Assert + assert!(output.contains("1.00")); +} + #[test] fn clustering_model_display_shows_configured_algorithm_when_untrained() { // Arrange