diff --git a/src/model/preprocessing.rs b/src/model/preprocessing.rs index 822f29d..c7dc0ee 100644 --- a/src/model/preprocessing.rs +++ b/src/model/preprocessing.rs @@ -1,12 +1,14 @@ //! Utilities for data preprocessing. -use crate::settings::PreProcessing; +use crate::model::error::ModelError; +use crate::settings::{PreProcessing, SettingsError}; use crate::utils::features::{FeatureError, interaction_features, polynomial_features}; use smartcore::{ decomposition::{ pca::{PCA, PCAParameters}, svd::{SVD, SVDParameters}, }, + error::Failed, linalg::{ basic::arrays::{Array, Array2}, traits::{ @@ -52,72 +54,103 @@ where } } - /// Train preprocessing models based on settings. - pub fn train(&mut self, x: &InputArray, settings: &PreProcessing) { + /// Fit preprocessing state (if required) and return a transformed copy of the + /// training matrix. + pub fn fit_transform( + &mut self, + x: InputArray, + settings: &PreProcessing, + ) -> Result { + self.pca = None; + self.svd = None; match settings { + PreProcessing::None => Ok(x), + PreProcessing::AddInteractions => { + interaction_features(x).map_err(Self::feature_error_to_settings) + } + PreProcessing::AddPolynomial { order } => { + polynomial_features(x, *order).map_err(Self::feature_error_to_settings) + } PreProcessing::ReplaceWithPCA { number_of_components, - } => { - self.train_pca(x, *number_of_components); - } + } => self.fit_pca(&x, *number_of_components), PreProcessing::ReplaceWithSVD { number_of_components, - } => { - self.train_svd(x, *number_of_components); - } - _ => {} + } => self.fit_svd(&x, *number_of_components), } } - /// Apply preprocessing to data. + /// Apply preprocessing to inference data. pub fn preprocess( &self, x: InputArray, settings: &PreProcessing, - ) -> Result { - Ok(match settings { - PreProcessing::None => x, - PreProcessing::AddInteractions => interaction_features(x)?, - PreProcessing::AddPolynomial { order } => polynomial_features(x, *order)?, - PreProcessing::ReplaceWithPCA { - number_of_components: _, - } => self.pca_features(&x), - PreProcessing::ReplaceWithSVD { - number_of_components: _, - } => self.svd_features(&x), - }) + ) -> Result { + match settings { + PreProcessing::None => Ok(x), + PreProcessing::AddInteractions => { + interaction_features(x).map_err(Self::feature_error_to_model) + } + PreProcessing::AddPolynomial { order } => { + polynomial_features(x, *order).map_err(Self::feature_error_to_model) + } + PreProcessing::ReplaceWithPCA { .. } => self.pca_features(&x), + PreProcessing::ReplaceWithSVD { .. } => self.svd_features(&x), + } } - fn train_pca(&mut self, x: &InputArray, n: usize) { + fn fit_pca(&mut self, x: &InputArray, n: usize) -> Result { let pca = PCA::fit( x, PCAParameters::default() .with_n_components(n) .with_use_correlation_matrix(true), ) - .expect("Could not train PCA preprocessor"); + .map_err(|err| Self::failed_to_settings(&err))?; + let transformed = pca + .transform(x) + .map_err(|err| Self::failed_to_settings(&err))?; self.pca = Some(pca); + Ok(transformed) } - fn pca_features(&self, x: &InputArray) -> InputArray { - self.pca + fn pca_features(&self, x: &InputArray) -> Result { + let pca = self + .pca .as_ref() - .expect("PCA model not trained") - .transform(x) - .expect("Could not transform data using PCA") + .ok_or_else(|| ModelError::Inference("PCA model not trained".to_string()))?; + pca.transform(x) + .map_err(|err| ModelError::Inference(err.to_string())) } - fn train_svd(&mut self, x: &InputArray, n: usize) { + fn fit_svd(&mut self, x: &InputArray, n: usize) -> Result { let svd = SVD::fit(x, SVDParameters::default().with_n_components(n)) - .expect("Could not train SVD preprocessor"); + .map_err(|err| Self::failed_to_settings(&err))?; + let transformed = svd + .transform(x) + .map_err(|err| Self::failed_to_settings(&err))?; self.svd = Some(svd); + Ok(transformed) } - fn svd_features(&self, x: &InputArray) -> InputArray { - self.svd + fn svd_features(&self, x: &InputArray) -> Result { + let svd = self + .svd .as_ref() - .expect("SVD model not trained") - .transform(x) - .expect("Could not transform data using SVD") + .ok_or_else(|| ModelError::Inference("SVD model not trained".to_string()))?; + svd.transform(x) + .map_err(|err| ModelError::Inference(err.to_string())) + } + + fn feature_error_to_settings(err: FeatureError) -> SettingsError { + SettingsError::PreProcessingFailed(err.to_string()) + } + + fn feature_error_to_model(err: FeatureError) -> ModelError { + ModelError::Inference(err.to_string()) + } + + fn failed_to_settings(err: &Failed) -> SettingsError { + SettingsError::PreProcessingFailed(err.to_string()) } } diff --git a/src/model/supervised.rs b/src/model/supervised.rs index 198b098..78aaa98 100644 --- a/src/model/supervised.rs +++ b/src/model/supervised.rs @@ -11,13 +11,14 @@ use crate::model::{ preprocessing::Preprocessor, }; use crate::settings::{ - ClassificationSettings, FinalAlgorithm, Metric, RegressionSettings, SupervisedSettings, + ClassificationSettings, FinalAlgorithm, Metric, RegressionSettings, SettingsError, + SupervisedSettings, }; use comfy_table::{ Attribute, Cell, Table, modifiers::UTF8_SOLID_INNER_BORDERS, presets::UTF8_FULL, }; use humantime::format_duration; -use smartcore::error::Failed; +use smartcore::error::{Failed, FailedError}; use smartcore::linalg::{ basic::arrays::{Array, Array1, Array2, MutArrayView1}, traits::{ @@ -111,7 +112,9 @@ where { /// Settings for the model. pub settings: S, - /// Training features. + /// Original training features used to recompute preprocessing steps. + x_train_raw: InputArray, + /// Preprocessed training features fed to algorithms. x_train: InputArray, /// Training targets. y_train: OutputArray, @@ -136,8 +139,10 @@ where { /// Create a new supervised model. pub fn new(x: InputArray, y: OutputArray, settings: S) -> Self { + let x_train_raw = x.clone(); Self { settings, + x_train_raw, x_train: x, y_train: y, comparison: Vec::new(), @@ -152,8 +157,11 @@ where /// Returns [`Failed`] if cross-validation fails for any algorithm. pub fn train(&mut self) -> Result<(), Failed> { let sup = self.settings.supervised(); - self.preprocessor - .train(&self.x_train.clone(), &sup.preprocessing); + let raw = self.x_train_raw.clone(); + self.x_train = self + .preprocessor + .fit_transform(raw, &sup.preprocessing) + .map_err(|err| Self::preprocessing_failed(&err))?; for alg in ::all_algorithms(&self.settings) { let trained = alg.cross_validate_model(&self.x_train, &self.y_train, &self.settings)?; @@ -169,10 +177,7 @@ where /// Returns [`ModelError::NotTrained`] if no algorithm has been trained or if inference fails. pub fn predict(&self, x: InputArray) -> ModelResult { let sup = self.settings.supervised(); - let x = self - .preprocessor - .preprocess(x, &sup.preprocessing) - .map_err(|e| ModelError::Inference(e.to_string()))?; + let x = self.preprocessor.preprocess(x, &sup.preprocessing)?; match sup.final_model_approach { FinalAlgorithm::None => Err(ModelError::NotTrained), @@ -203,6 +208,10 @@ where self.comparison.reverse(); } } + + fn preprocessing_failed(err: &SettingsError) -> Failed { + Failed::because(FailedError::ParametersError, &err.to_string()) + } } impl Display for SupervisedModel diff --git a/src/settings/error.rs b/src/settings/error.rs index dd71ace..d6c02d8 100644 --- a/src/settings/error.rs +++ b/src/settings/error.rs @@ -5,12 +5,14 @@ use std::fmt::{Display, Formatter}; use super::Metric; /// Errors related to model settings. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum SettingsError { /// A required metric was not specified. MetricNotSet, /// The provided metric is not supported for the task. UnsupportedMetric(Metric), + /// Preprocessing configuration failed to run successfully. + PreProcessingFailed(String), } impl Display for SettingsError { @@ -18,6 +20,9 @@ impl Display for SettingsError { match self { Self::MetricNotSet => write!(f, "a metric must be set"), Self::UnsupportedMetric(m) => write!(f, "unsupported metric: {m}"), + Self::PreProcessingFailed(msg) => { + write!(f, "preprocessing configuration failed: {msg}") + } } } } diff --git a/tests/classification.rs b/tests/classification.rs index 127bc11..01e7e4a 100644 --- a/tests/classification.rs +++ b/tests/classification.rs @@ -5,7 +5,7 @@ use automl::algorithms::ClassificationAlgorithm; use automl::model::Algorithm; use automl::settings::{ BernoulliNBParameters, CategoricalNBParameters, ClassificationSettings, - MultinomialNBParameters, RandomForestClassifierParameters, SVCParameters, + MultinomialNBParameters, PreProcessing, RandomForestClassifierParameters, SVCParameters, }; use automl::{DenseMatrix, ModelError, SupervisedModel}; use classification_data::{ @@ -217,6 +217,32 @@ fn bernoulli_nb_rejects_non_binary_without_threshold() { ); } +#[test] +fn classification_pca_preprocessing_predicts() { + type Model = SupervisedModel< + ClassificationAlgorithm, Vec>, + ClassificationSettings, + DenseMatrix, + Vec, + >; + + let (x, y) = classification_testing_data(); + let settings = ClassificationSettings::default() + .with_svc_settings(SVCParameters::default()) + .with_preprocessing(PreProcessing::ReplaceWithPCA { + number_of_components: 2, + }); + + let mut model: Model = SupervisedModel::new(x, y, settings); + model.train().unwrap(); + + let predictions = model + .predict(DenseMatrix::from_2d_array(&[&[0.0, 0.0], &[1.0, 1.0]]).unwrap()) + .expect("PCA-preprocessed model should predict successfully"); + + assert_eq!(predictions.len(), 2); +} + #[test] fn invalid_alpha_returns_error() { // Arrange diff --git a/tests/regression.rs b/tests/regression.rs index 9e9e91e..716e9b9 100644 --- a/tests/regression.rs +++ b/tests/regression.rs @@ -4,7 +4,7 @@ mod regression_data; use automl::algorithms::RegressionAlgorithm; use automl::model::Algorithm; use automl::settings::{ - Distance, ExtraTreesRegressorParameters, KNNParameters, Kernel, SVRParameters, + Distance, ExtraTreesRegressorParameters, KNNParameters, Kernel, PreProcessing, SVRParameters, XGRegressorParameters, }; use automl::{DenseMatrix, RegressionSettings, SupervisedModel}; @@ -234,6 +234,36 @@ fn test_xgboost_skiplist_controls_algorithms() { )); } +#[test] +fn regression_polynomial_preprocessing_predicts() { + type Model = SupervisedModel< + RegressionAlgorithm, Vec>, + RegressionSettings, Vec>, + DenseMatrix, + Vec, + >; + + let (x, y) = regression_testing_data(); + let settings = RegressionSettings::default() + .with_preprocessing(PreProcessing::AddPolynomial { order: 2 }) + .only(&RegressionAlgorithm::default_knn_regressor()); + + let mut regressor: Model = SupervisedModel::new(x, y, settings); + regressor.train().unwrap(); + + let predictions = regressor + .predict( + DenseMatrix::from_2d_array(&[ + &[234.289, 235.6, 159.0, 107.608, 1947., 60.323], + &[259.426, 232.5, 145.6, 108.632, 1948., 61.122], + ]) + .unwrap(), + ) + .expect("Polynomial preprocessing should allow prediction"); + + assert_eq!(predictions.len(), 2); +} + fn test_from_settings(settings: RegressionSettings, Vec>) { // Set up the regressor settings and load data type Model = SupervisedModel<