From d3530ea75d762038911abaf298a4b44bc80c534f Mon Sep 17 00:00:00 2001 From: Manlio Morini Date: Wed, 1 May 2024 19:18:05 +0200 Subject: [PATCH] feat: add support for classification to `src::search` Issue #11 --- src/kernel/gp/src/evaluator.h | 4 ++-- src/kernel/gp/src/evaluator.tcc | 20 +++++++++++++++----- src/kernel/gp/src/search.h | 2 ++ src/kernel/gp/src/search.tcc | 4 ++-- 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/kernel/gp/src/evaluator.h b/src/kernel/gp/src/evaluator.h index dcdcbf7..536692d 100644 --- a/src/kernel/gp/src/evaluator.h +++ b/src/kernel/gp/src/evaluator.h @@ -258,8 +258,8 @@ class gaussian_evaluator : public evaluator public: explicit gaussian_evaluator(dataframe &); - [[nodiscard]] double operator()(const P &) override; - std::unique_ptr oracle(const P &) const; + [[nodiscard]] double operator()(const P &) const; + [[nodiscard]] std::unique_ptr oracle(const P &) const; }; #include "kernel/gp/src/evaluator.tcc" diff --git a/src/kernel/gp/src/evaluator.tcc b/src/kernel/gp/src/evaluator.tcc index c234525..418a677 100644 --- a/src/kernel/gp/src/evaluator.tcc +++ b/src/kernel/gp/src/evaluator.tcc @@ -108,10 +108,9 @@ auto sum_of_errors_evaluator::fast(const P &prg) const } /// -/// \param[in] prg program(individual/team) to be transformed in a lambda -/// function -/// \return the lambda function associated with `prg` (`nullptr` in case -/// of errors). +/// \param[in] prg program (individual/team) to be transformed in an oracle +/// \return the oracle associated with `prg` (`nullptr` in case of +/// errors). /// template class ERRF, class DAT> requires ErrorFunction, DAT> @@ -259,7 +258,7 @@ gaussian_evaluator

::gaussian_evaluator(dataframe &d) : evaluator(d) /// \return the fitness (greater is better, max is `0`) /// template -double gaussian_evaluator

::operator()(const P &prg) +double gaussian_evaluator

::operator()(const P &prg) const { Expects(this->dat_->classes() >= 2); basic_gaussian_oracle oracle(prg, *this->dat_); @@ -287,4 +286,15 @@ double gaussian_evaluator

::operator()(const P &prg) return d; } +/// +/// \param[in] prg program (individual/team) to be transformed in an oracle +/// \return the oracle associated with `prg` (`nullptr` in case of +/// errors). +/// +template +std::unique_ptr gaussian_evaluator

::oracle(const P &prg) const +{ + return std::make_unique>(prg, *this->dat_); +} + #endif // include guard diff --git a/src/kernel/gp/src/search.h b/src/kernel/gp/src/search.h index 75b9b4b..ab21399 100644 --- a/src/kernel/gp/src/search.h +++ b/src/kernel/gp/src/search.h @@ -72,9 +72,11 @@ class search public: using after_generation_callback_t = ultra::after_generation_callback_t; + using individual_t = P; using fitness_t = double; + using class_evaluator_t = gaussian_evaluator

; using reg_evaluator_t = rmae_evaluator

; search(problem &p, metric_flags m = metric_flags::nothing); diff --git a/src/kernel/gp/src/search.tcc b/src/kernel/gp/src/search.tcc index e57b536..f0df144 100644 --- a/src/kernel/gp/src/search.tcc +++ b/src/kernel/gp/src/search.tcc @@ -366,7 +366,7 @@ search_stats::fitness_t> search

::run( }); if (prob_.classification()) - return {}; + return search_scheme.template operator()(); else return search_scheme.template operator()(); } @@ -375,7 +375,7 @@ template std::unique_ptr search

::oracle(const P &prg) const { if (prob_.classification()) - return nullptr; + return class_evaluator_t(prob_.data()).oracle(prg); else return reg_evaluator_t(prob_.data()).oracle(prg); }