Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ will perform a comparison of classifier models using cross-validation. Printing

You can then perform inference using the best model with the `predict` method.

## Cookbook

Explore the `automl::cookbook` module for copy-pastable examples that mirror
real-world workflows:

- `cargo run --example breast_cancer_csv` – load the Wisconsin Diagnostic
Breast Cancer dataset from CSV, standardize features, and compare tuned
classifiers.
- `cargo run --example diabetes_regression` – impute, scale, and train
regression models for the diabetes progression dataset.

## Preprocessing pipelines

`automl` now supports composable preprocessing pipelines so you can build
Expand Down
60 changes: 60 additions & 0 deletions examples/breast_cancer_csv.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#![allow(clippy::needless_doctest_main)]
//! Real-world breast cancer classification example.
//!
//! This example trains a multi-model comparison on the Wisconsin Diagnostic
//! Breast Cancer dataset bundled with the repository. The workflow shows how to
//! load a CSV file, wire up a preprocessing pipeline, and customize the
//! algorithms that participate in the comparison.
//!
//! Run with:
//!
//! ```bash
//! cargo run --example breast_cancer_csv
//! ```

#[path = "../tests/fixtures/breast_cancer_dataset.rs"]
mod breast_cancer_dataset;

use std::error::Error;

use automl::settings::{
ClassificationSettings, FinalAlgorithm, PreprocessingPipeline, PreprocessingStep,
RandomForestClassifierParameters, StandardizeParams,
};
use automl::{ClassificationModel, DenseMatrix};
use breast_cancer_dataset::load_breast_cancer_dataset;

/// Entry point of the example: load the bundled dataset, train with a
/// customised configuration, print the trained model, and run one prediction.
///
/// # Errors
///
/// Propagates any error raised while loading the dataset, training, building
/// the sample matrix, or predicting.
fn main() -> Result<(), Box<dyn Error>> {
    let (features, targets) = load_breast_cancer_dataset()?;

    // Random forest hyperparameters handed to the settings builder below.
    let forest_params = RandomForestClassifierParameters::default()
        .with_n_trees(200)
        .with_max_depth(8)
        .with_min_samples_split(4)
        .with_min_samples_leaf(2);

    // Preprocessing pipeline consisting of a single standardization step.
    let standardize = PreprocessingStep::Standardize(StandardizeParams::default());
    let pipeline = PreprocessingPipeline::new().add_step(standardize);

    let settings = ClassificationSettings::default()
        .with_number_of_folds(5)
        .shuffle_data(true)
        .with_final_model(FinalAlgorithm::Best)
        .with_preprocessing(pipeline)
        .with_random_forest_classifier_settings(forest_params);

    let mut model = ClassificationModel::new(features, targets, settings);
    model.train()?;
    println!("{model}");

    // One row of 30 feature values used to demonstrate inference.
    let sample_rows = vec![vec![
        13.540, 14.360, 87.460, 566.300, 0.097, 0.052, 0.024, 0.015, 0.153, 0.055, 0.284, 0.915,
        2.376, 23.420, 0.005, 0.013, 0.010, 0.005, 0.018, 0.002, 14.230, 17.730, 91.760, 618.800,
        0.118, 0.115, 0.068, 0.025, 0.210, 0.062,
    ]];
    let sample_matrix = DenseMatrix::from_2d_vec(&sample_rows)?;
    let predictions = model.predict(sample_matrix)?;
    println!("Predicted class for the evaluation patient: {predictions:?}");

    Ok(())
}
81 changes: 81 additions & 0 deletions examples/diabetes_regression.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#![allow(clippy::needless_doctest_main)]
//! Real-world diabetes progression regression example.
//!
//! The diabetes dataset includes 10 physiological measurements for 442
//! individuals. This example demonstrates how to configure a preprocessing
//! pipeline, tighten algorithm hyperparameters, and evaluate the models via
//! cross-validation before using the best regressor for inference.
//!
//! Run with:
//!
//! ```bash
//! cargo run --example diabetes_regression
//! ```

#[path = "../tests/fixtures/diabetes_dataset.rs"]
mod diabetes_dataset;

use std::error::Error;

use automl::settings::{
ColumnSelector, FinalAlgorithm, ImputeParams, ImputeStrategy, Kernel, PreprocessingPipeline,
PreprocessingStep, RandomForestRegressorParameters, RegressionSettings, SVRParameters,
ScaleParams, ScaleStrategy, StandardizeParams,
};
use automl::{DenseMatrix, RegressionModel};
use diabetes_dataset::load_diabetes_dataset;

/// Entry point of the example: load the bundled dataset, train with a
/// customised configuration, print the trained model, and run one prediction.
///
/// # Errors
///
/// Propagates any error raised while loading the dataset, training, building
/// the sample matrix, or predicting.
fn main() -> Result<(), Box<dyn Error>> {
    let (features, targets) = load_diabetes_dataset()?;

    // Two preprocessing steps over all columns: median imputation, then
    // standard scaling.
    let impute_step = PreprocessingStep::Impute(ImputeParams {
        strategy: ImputeStrategy::Median,
        selector: ColumnSelector::All,
    });
    let scale_step = PreprocessingStep::Scale(ScaleParams {
        selector: ColumnSelector::All,
        strategy: ScaleStrategy::Standard(StandardizeParams::default()),
    });
    let pipeline = PreprocessingPipeline::new()
        .add_step(impute_step)
        .add_step(scale_step);

    // Random forest hyperparameters handed to the settings builder below.
    let forest_params = RandomForestRegressorParameters::default()
        .with_n_trees(250)
        .with_max_depth(6)
        .with_min_samples_leaf(2)
        .with_min_samples_split(4);

    // Support vector regression hyperparameters with an RBF kernel.
    let svr_params = SVRParameters::default()
        .with_c(12.5)
        .with_eps(0.05)
        .with_kernel(Kernel::RBF(0.35));

    let settings = RegressionSettings::default()
        .with_number_of_folds(8)
        .shuffle_data(true)
        .with_final_model(FinalAlgorithm::Best)
        .with_preprocessing(pipeline)
        .with_random_forest_regressor_settings(forest_params)
        .with_svr_settings(svr_params);

    let mut model = RegressionModel::new(features, targets, settings);
    model.train()?;
    println!("{model}");

    // One row of 10 feature values used to demonstrate inference.
    let sample_rows = vec![vec![
        0.038_075_906,
        0.050_680_119,
        0.061_696_207,
        0.021_872_355,
        -0.044_223_498,
        -0.034_820_763,
        -0.043_400_846,
        -0.002_592_262,
        0.019_908_421,
        -0.017_646_125,
    ]];
    let sample_matrix = DenseMatrix::from_2d_vec(&sample_rows)?;
    let predicted_progression = model.predict(sample_matrix)?;
    println!("Predicted disease progression: {predicted_progression:?}");

    Ok(())
}
17 changes: 17 additions & 0 deletions src/cookbook.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,20 @@
//! ```rust,ignore
#![doc = include_str!("../examples/maximal_regression.rs")]
//! ```
//!
//! ## Wisconsin Breast Cancer Classification
//!
//! Demonstrates loading data from `data/breast_cancer.csv`, standardizing every
//! feature, and customizing the random forest search space before running the
//! leaderboard comparison.
//! ```rust,ignore
#![doc = include_str!("../examples/breast_cancer_csv.rs")]
//! ```
//!
//! ## Diabetes Progression Regression
//!
//! Shows how to impute, standardize, and tune regression algorithms on the
//! diabetes dataset that ships with the repository.
//! ```rust,ignore
#![doc = include_str!("../examples/diabetes_regression.rs")]
//! ```
61 changes: 61 additions & 0 deletions tests/fixtures/breast_cancer_dataset.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use std::error::Error;
use std::path::Path;

use csv::ReaderBuilder;
use smartcore::linalg::basic::matrix::DenseMatrix;

type CsvRows = (Vec<Vec<f64>>, Vec<String>);
type CsvResult = Result<CsvRows, Box<dyn Error>>;

/// Read a headered CSV file into numeric feature rows and raw target strings.
///
/// For each record, every field except the last is parsed as `f64`; the last
/// field is returned unparsed as the target.
///
/// # Errors
///
/// Returns an error if the file cannot be opened, a record cannot be read, a
/// record has fewer than two fields, or a feature field is not a valid number.
fn load_feature_rows<P: AsRef<Path>>(path: P) -> CsvResult {
    let mut csv_reader = ReaderBuilder::new().has_headers(true).from_path(path)?;
    let mut feature_rows: Vec<Vec<f64>> = Vec::new();
    let mut labels: Vec<String> = Vec::new();

    for line in csv_reader.records() {
        let line = line?;
        if line.len() < 2 {
            return Err("dataset requires at least one feature and a target column".into());
        }

        // The final field is the target; everything before it is a feature.
        let target_index = line.len() - 1;
        let parsed = line
            .iter()
            .take(target_index)
            .map(str::parse::<f64>)
            .collect::<Result<Vec<f64>, _>>()?;
        let label = line
            .get(target_index)
            .ok_or("dataset missing target column")?;

        feature_rows.push(parsed);
        labels.push(label.to_string());
    }

    Ok((feature_rows, labels))
}

/// Convert a raw target cell into a binary class label.
///
/// The cell is parsed as a floating-point number, so both `1` and `1.0`
/// spellings are accepted. Surrounding whitespace is ignored.
///
/// # Errors
///
/// Returns an error if the cell is not numeric or holds any value other than
/// `0` or `1`. The message quotes the offending cell to make bad rows easy to
/// locate.
fn parse_label(raw: &str) -> Result<u32, Box<dyn Error>> {
    let cell = raw.trim();
    let numeric = cell
        .parse::<f64>()
        .map_err(|err| format!("invalid label {cell:?}: {err}"))?;

    if (numeric - 1.0).abs() < f64::EPSILON {
        Ok(1)
    } else if numeric.abs() < f64::EPSILON {
        Ok(0)
    } else {
        Err(format!("unexpected label {cell:?}: expected 0 or 1").into())
    }
}

/// Load the Wisconsin Diagnostic Breast Cancer dataset from `data/breast_cancer.csv`.
///
/// # Errors
///
/// Returns an error if the CSV file cannot be read or parsed into numeric data.
pub fn load_breast_cancer_dataset() -> Result<(DenseMatrix<f64>, Vec<u32>), Box<dyn Error>> {
let (feature_rows, raw_targets) = load_feature_rows("data/breast_cancer.csv")?;
let features = DenseMatrix::from_2d_vec(&feature_rows)?;
let targets = raw_targets
.into_iter()
.map(|value| parse_label(&value))
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;

Ok((features, targets))
}
50 changes: 50 additions & 0 deletions tests/fixtures/diabetes_dataset.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use std::error::Error;
use std::path::Path;

use csv::ReaderBuilder;
use smartcore::linalg::basic::matrix::DenseMatrix;

type CsvRows = (Vec<Vec<f64>>, Vec<String>);
type CsvResult = Result<CsvRows, Box<dyn Error>>;

/// Read a headered CSV file into numeric feature rows and raw target strings.
///
/// For each record, every field except the last is parsed as `f64`; the last
/// field is returned unparsed as the target.
///
/// # Errors
///
/// Returns an error if the file cannot be opened, a record cannot be read, a
/// record has fewer than two fields, or a feature field is not a valid number.
fn load_feature_rows<P: AsRef<Path>>(path: P) -> CsvResult {
    let mut csv_reader = ReaderBuilder::new().has_headers(true).from_path(path)?;
    let mut feature_rows: Vec<Vec<f64>> = Vec::new();
    let mut raw_targets: Vec<String> = Vec::new();

    for line in csv_reader.records() {
        let line = line?;
        if line.len() < 2 {
            return Err("dataset requires at least one feature and a target column".into());
        }

        // The final field is the target; everything before it is a feature.
        let target_index = line.len() - 1;
        let parsed = line
            .iter()
            .take(target_index)
            .map(str::parse::<f64>)
            .collect::<Result<Vec<f64>, _>>()?;
        let target = line
            .get(target_index)
            .ok_or("dataset missing target column")?;

        feature_rows.push(parsed);
        raw_targets.push(target.to_string());
    }

    Ok((feature_rows, raw_targets))
}

/// Load the diabetes progression dataset from `data/diabetes.csv`.
///
/// # Errors
///
/// Returns an error if the CSV file cannot be read or parsed into numeric data.
pub fn load_diabetes_dataset() -> Result<(DenseMatrix<f64>, Vec<f64>), Box<dyn Error>> {
let (feature_rows, raw_targets) = load_feature_rows("data/diabetes.csv")?;
let features = DenseMatrix::from_2d_vec(&feature_rows)?;
let targets = raw_targets
.into_iter()
.map(|value| -> Result<f64, Box<dyn Error>> { Ok(value.parse()?) })
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;

Ok((features, targets))
}
29 changes: 29 additions & 0 deletions tests/real_world_datasets.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#[path = "fixtures/breast_cancer_dataset.rs"]
mod breast_cancer_dataset;
#[path = "fixtures/diabetes_dataset.rs"]
mod diabetes_dataset;

use breast_cancer_dataset::load_breast_cancer_dataset;
use diabetes_dataset::load_diabetes_dataset;
use smartcore::linalg::basic::arrays::Array;

/// The bundled breast cancer CSV must load as 569 rows by 30 features, with
/// one label per row and exactly 212 rows labelled `1`.
#[test]
fn breast_cancer_dataset_has_expected_shape() {
    let (features, labels) = load_breast_cancer_dataset().expect("dataset should load");

    let (n_rows, n_cols) = features.shape();
    assert_eq!(n_rows, 569);
    assert_eq!(n_cols, 30);
    assert_eq!(labels.len(), n_rows);

    let positive_count = labels.iter().filter(|&&label| label == 1).count();
    assert_eq!(positive_count, 212);
}

/// The bundled diabetes CSV must load as 442 rows by 10 features, with one
/// finite target value per row.
#[test]
fn diabetes_dataset_has_expected_shape() {
    let (features, targets) = load_diabetes_dataset().expect("dataset should load");

    let (n_rows, n_cols) = features.shape();
    assert_eq!(n_rows, 442);
    assert_eq!(n_cols, 10);
    assert_eq!(targets.len(), n_rows);

    assert!(targets.iter().copied().all(f64::is_finite));
}