Skip to content

Commit

Permalink
feat: add after_evolution method to search / evolution classes
Browse files Browse the repository at this point in the history
Issue #11
  • Loading branch information
morinim committed May 1, 2024
1 parent bb2af27 commit 869b69f
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 18 deletions.
67 changes: 67 additions & 0 deletions src/examples/symbolic_regression02.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* \remark This file is part of VITA.
*
* \copyright Copyright (C) 2024 EOS di Manlio Morini.
*
* \license
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this file,
* You can obtain one at http://mozilla.org/MPL/2.0/
*
* \see
* https://github.com/morinim/ultra/wiki/symbolic_regression
*/

/* CHANGES IN THIS FILE MUST BE APPLIED TO THE LINKED WIKI PAGE */

#include "kernel/ultra.h"

int main()
{
using namespace ultra;

// DATA SAMPLE
// (the target function is `x + sin(x)`)
std::istringstream training(R"(
-9.456,-10.0
-8.989, -8.0
-5.721, -6.0
-3.243, -4.0
-2.909, -2.0
0.000, 0.0
2.909, 2.0
3.243, 4.0
5.721, 6.0
8.989, 8.0
)");

// READING INPUT DATA
src::problem prob(training);

// SETTING UP SYMBOLS
prob.insert<real::sin>();
prob.insert<real::cos>();
prob.insert<real::add>();
prob.insert<real::sub>();
prob.insert<real::div>();
prob.insert<real::mul>();

// SEARCHING
prob.params.evolution.generations = 50;
src::search s(prob);

// This is a callback function invoked at the end of every generation and
// useful to gather statistical data.
s.after_generation([](const auto &pop, const auto &)
{
for (const auto &i : pop)
std::cout << out::python_language << i << '\n';
std::cout << "--------------------------------------\n";
});

const auto result(s.run());

std::cout << "\nCANDIDATE SOLUTION\n"
<< out::c_language << result.best_individual
<< "\n\nFITNESS\n" << *result.best_measurements.fitness << '\n';
}
12 changes: 11 additions & 1 deletion src/kernel/evolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
namespace ultra
{

template<Individual I, Fitness F>
using after_generation_callback_t =
std::function<void(const layered_population<I> &, const summary<I, F> &)>;

///
/// Progressively evolves a population of programs over a series of
/// generations.
Expand All @@ -40,11 +44,15 @@ class evolution
using individual_t = typename S::individual_t;
using fitness_t = typename S::fitness_t;

using after_generation_callback_t =
ultra::after_generation_callback_t<individual_t, fitness_t>;

explicit evolution(const S &);

summary<individual_t, fitness_t> run();

void set_shake_function(const std::function<bool(unsigned)> &);
evolution &after_generation(after_generation_callback_t);
evolution &shake_function(const std::function<bool(unsigned)> &);

[[nodiscard]] bool is_valid() const;

Expand All @@ -59,6 +67,8 @@ class evolution
layered_population<individual_t> pop_;
S es_;
std::function<bool(unsigned)> shake_ {};

after_generation_callback_t after_generation_callback_ {};
};

#include "kernel/evolution.tcc"
Expand Down
25 changes: 23 additions & 2 deletions src/kernel/evolution.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -170,15 +170,33 @@ void evolution<S>::print(bool summary, std::chrono::milliseconds elapsed,
/// Sets the shake function.
///
/// \param[in] f the shaking function
/// \return a reference to *this* object (method chaining / fluent
/// interface)
///
/// The shake function is called every new generation and is used to alter
/// the environment of the evolution (i.e. it could change the points for a
/// symbolic regression problem, the examples for a classification task...).
///
template<Strategy S>
void evolution<S>::set_shake_function(const std::function<bool(unsigned)> &f)
evolution<S> &evolution<S>::shake_function(
const std::function<bool(unsigned)> &f)
{
shake_ = f;
return *this;
}

///
/// Sets a callback function called at the end of every generation.
///
/// \param[in] f callback function
/// \return a reference to *this* object (method chaining / fluent
/// interface)
///
template<Strategy S>
evolution<S> &evolution<S>::after_generation(after_generation_callback_t f)
{
after_generation_callback_ = std::move(f);
return *this;
}

/// The evolutionary core loop.
Expand Down Expand Up @@ -277,7 +295,10 @@ evolution<S>::run()

sum_.az = analyze(pop_, es_.evaluator());
log_evolution();
es_.after_generation(pop_, sum_);

es_.after_generation(pop_, sum_); // strategy-specific bookkeeping
if (after_generation_callback_)
after_generation_callback_(pop_, sum_);
}

sum_.elapsed = from_start.elapsed();
Expand Down
6 changes: 6 additions & 0 deletions src/kernel/gp/src/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ template<IndividualOrTeam P = gp::individual>
class search
{
public:
using after_generation_callback_t =
ultra::after_generation_callback_t<P, double>;
using individual_t = P;
using fitness_t = double;

Expand All @@ -82,10 +84,14 @@ class search

std::unique_ptr<basic_oracle> oracle(const P &) const;

search &after_generation(after_generation_callback_t);

private:
// *** Private data members ***
problem &prob_; // problem we're working on
metric_flags metrics_; // metrics we have to calculate during the search

after_generation_callback_t after_generation_callback_ {};
};


Expand Down
32 changes: 24 additions & 8 deletions src/kernel/gp/src/search.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -356,17 +356,19 @@ template<IndividualOrTeam P>
search_stats<P, typename search<P>::fitness_t> search<P>::run(
unsigned n, const model_measurements<fitness_t> &threshold)
{
if (prob_.classification())
const auto search_scheme([&]<Evaluator E>()
{
basic_search<alps_es, E> search(prob_, E(prob_.data()), metrics_);

search.after_generation(after_generation_callback_);

return search.run(n, threshold);
});

if (prob_.classification())
return {};
}
else
{
basic_search<alps_es, reg_evaluator_t> reg_search(
prob_, reg_evaluator_t(prob_.data()), metrics_);

return reg_search.run(n, threshold);
}
return search_scheme.template operator()<reg_evaluator_t>();
}

template<IndividualOrTeam P>
Expand All @@ -378,4 +380,18 @@ std::unique_ptr<basic_oracle> search<P>::oracle(const P &prg) const
return reg_evaluator_t(prob_.data()).oracle(prg);
}

///
/// Sets a callback function executed at the end of every generation.
///
/// \param[in] f callback function
/// \return a reference to *this* object (method chaining / fluent
/// interface)
///
template<IndividualOrTeam P>
search<P> &search<P>::after_generation(after_generation_callback_t f)
{
after_generation_callback_ = std::move(f);
return *this;
}

#endif // include guard
7 changes: 7 additions & 0 deletions src/kernel/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class search
public:
using individual_t = evaluator_individual_t<E>;
using fitness_t = evaluator_fitness_t<E>;
using after_generation_callback_t =
ultra::after_generation_callback_t<individual_t, fitness_t>;

search(problem &, E);

Expand All @@ -60,6 +62,8 @@ class search

template<class V, class... Args> search &validation_strategy(Args && ...);

search &after_generation(after_generation_callback_t);

[[nodiscard]] virtual bool is_valid() const;

protected:
Expand All @@ -73,6 +77,9 @@ class search

problem &prob_; // problem we're working on

// Callback functions.
after_generation_callback_t after_generation_callback_ {};

private:
// Template method of the `search::run` member function called exactly one
// time just before the first run.
Expand Down
17 changes: 16 additions & 1 deletion src/kernel/search.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,20 @@ void search<ES, E>::init()
load();
}

///
/// Sets a callback function executed at the end of every generation.
///
/// \param[in] f callback function
/// \return a reference to *this* object (method chaining / fluent
/// interface)
///
template<template<class> class ES, Evaluator E>
search<ES, E> &search<ES, E>::after_generation(after_generation_callback_t f)
{
after_generation_callback_ = std::move(f);
return *this;
}

///
/// Tries to tune search parameters for the current problem.
///
Expand Down Expand Up @@ -132,7 +146,8 @@ search<ES, E>::run(unsigned n, const model_measurements<fitness_t> &threshold)
vs_->training_setup(r);

evolution evo(es_);
evo.set_shake_function(shake);
evo.after_generation(after_generation_callback_);
evo.shake_function(shake);
const auto run_summary(evo.run());

vs_->validation_setup(r);
Expand Down
12 changes: 6 additions & 6 deletions src/test/evolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ TEST_CASE_FIXTURE(fixture1, "Shake function")
test_evaluator<gp::individual> eva(test_evaluator_type::realistic);

evolution evo(std_es(prob, eva));
evo.set_shake_function([i = 0](unsigned gen) mutable
{
CHECK(gen == i);
++i;
return true;
});
evo.shake_function([i = 0](unsigned gen) mutable
{
CHECK(gen == i);
++i;
return true;
});

evo.run();
}
Expand Down

0 comments on commit 869b69f

Please sign in to comment.