Skip to content

Commit

Permalink
feat: add support for classification to src::search
Browse files Browse the repository at this point in the history
Issue #11
  • Loading branch information
morinim committed May 1, 2024
1 parent 804b401 commit d3530ea
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 9 deletions.
4 changes: 2 additions & 2 deletions src/kernel/gp/src/evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,8 @@ class gaussian_evaluator : public evaluator<dataframe>
public:
explicit gaussian_evaluator(dataframe &);

[[nodiscard]] double operator()(const P &) override;
std::unique_ptr<basic_oracle> oracle(const P &) const;
[[nodiscard]] double operator()(const P &) const;
[[nodiscard]] std::unique_ptr<basic_oracle> oracle(const P &) const;
};

#include "kernel/gp/src/evaluator.tcc"
Expand Down
20 changes: 15 additions & 5 deletions src/kernel/gp/src/evaluator.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,9 @@ auto sum_of_errors_evaluator<P, ERRF, DAT>::fast(const P &prg) const
}

///
/// \param[in] prg program(individual/team) to be transformed in a lambda
/// function
/// \return the lambda function associated with `prg` (`nullptr` in case
/// of errors).
/// \param[in] prg program (individual/team) to be transformed in an oracle
/// \return the oracle associated with `prg` (`nullptr` in case of
/// errors).
///
template<IndividualOrTeam P, template<class> class ERRF, class DAT>
requires ErrorFunction<ERRF<P>, DAT>
Expand Down Expand Up @@ -259,7 +258,7 @@ gaussian_evaluator<P>::gaussian_evaluator(dataframe &d) : evaluator(d)
/// \return the fitness (greater is better, max is `0`)
///
template<IndividualOrTeam P>
double gaussian_evaluator<P>::operator()(const P &prg)
double gaussian_evaluator<P>::operator()(const P &prg) const
{
Expects(this->dat_->classes() >= 2);
basic_gaussian_oracle<P, false, false> oracle(prg, *this->dat_);
Expand Down Expand Up @@ -287,4 +286,15 @@ double gaussian_evaluator<P>::operator()(const P &prg)
return d;
}

///
/// \param[in] prg program (individual/team) to be transformed in an oracle
/// \return the oracle associated with `prg` (`nullptr` in case of
/// errors).
///
template<IndividualOrTeam P>
std::unique_ptr<basic_oracle> gaussian_evaluator<P>::oracle(const P &prg) const
{
return std::make_unique<gaussian_oracle<P>>(prg, *this->dat_);
}

#endif // include guard
2 changes: 2 additions & 0 deletions src/kernel/gp/src/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,11 @@ class search
public:
using after_generation_callback_t =
ultra::after_generation_callback_t<P, double>;

using individual_t = P;
using fitness_t = double;

using class_evaluator_t = gaussian_evaluator<P>;
using reg_evaluator_t = rmae_evaluator<P>;

search(problem &p, metric_flags m = metric_flags::nothing);
Expand Down
4 changes: 2 additions & 2 deletions src/kernel/gp/src/search.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ search_stats<P, typename search<P>::fitness_t> search<P>::run(
});

if (prob_.classification())
return {};
return search_scheme.template operator()<class_evaluator_t>();
else
return search_scheme.template operator()<reg_evaluator_t>();
}
Expand All @@ -375,7 +375,7 @@ template<IndividualOrTeam P>
std::unique_ptr<basic_oracle> search<P>::oracle(const P &prg) const
{
if (prob_.classification())
return nullptr;
return class_evaluator_t(prob_.data()).oracle(prg);
else
return reg_evaluator_t(prob_.data()).oracle(prg);
}
Expand Down

0 comments on commit d3530ea

Please sign in to comment.