Skip to content

Commit

Permalink
feat: add oracle member function to search class
Browse files Browse the repository at this point in the history
Issue #11
  • Loading branch information
morinim committed Apr 28, 2024
1 parent 29b49c0 commit d402ad4
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
7 changes: 6 additions & 1 deletion src/kernel/gp/src/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,14 @@ class search
using individual_t = P;
using fitness_t = double;

using reg_evaluator_t = rmae_evaluator<P>;

search(problem &p, metric_flags m = metric_flags::nothing);

search_stats<P, fitness_t> run(unsigned = 1);
search_stats<P, fitness_t> run(unsigned = 1,
const model_measurements<fitness_t> & = {});

std::unique_ptr<basic_oracle> oracle(const P &) const;

private:
// *** Private data members ***
Expand Down
18 changes: 14 additions & 4 deletions src/kernel/gp/src/search.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -353,19 +353,29 @@ search<P>::search(problem &p, metric_flags m) : prob_(p), metrics_(m)
}

template<IndividualOrTeam P>
search_stats<P, typename search<P>::fitness_t> search<P>::run(unsigned n)
search_stats<P, typename search<P>::fitness_t> search<P>::run(
unsigned n, const model_measurements<fitness_t> &threshold)
{
if (prob_.classification())
{
return {};
}
else
{
basic_search<alps_es, rmae_evaluator<P>> reg_search(
prob_, rmae_evaluator<P>(prob_.data()), metrics_);
basic_search<alps_es, reg_evaluator_t> reg_search(
prob_, reg_evaluator_t(prob_.data()), metrics_);

return reg_search.run(n);
return reg_search.run(n, threshold);
}
}

template<IndividualOrTeam P>
std::unique_ptr<basic_oracle> search<P>::oracle(const P &prg) const
{
if (prob_.classification())
return nullptr;
else
return reg_evaluator_t(prob_.data()).oracle(prg);
}

#endif // include guard
2 changes: 2 additions & 0 deletions src/kernel/search.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ void search<ES, E>::tune_parameters()
Ensures(prob_.params.is_valid(true));
}

///
/// Executes a given number of evolutionary-runs possibly saving good runs.
///
/// \param[in] n number of runs
/// \param[in] threshold used to identify successfully learned (matched,
Expand Down

0 comments on commit d402ad4

Please sign in to comment.