Skip to content

Commit

Permalink
add AggregateEvaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
foolnotion committed Mar 31, 2023
1 parent b1d7c6b commit 88a15c3
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 5 deletions.
29 changes: 24 additions & 5 deletions include/operon/operators/evaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ struct EvaluatorBase : public OperatorBase<Operon::Vector<Operon::Scalar>, Indiv
size_t budget_ = DefaultEvaluationBudget;
};

class UserDefinedEvaluator : public EvaluatorBase {
class OPERON_EXPORT UserDefinedEvaluator : public EvaluatorBase {
public:
UserDefinedEvaluator(Problem& problem, std::function<typename EvaluatorBase::ReturnType(Operon::RandomGenerator&, Operon::Individual&)> func)
: EvaluatorBase(problem)
Expand Down Expand Up @@ -179,7 +179,7 @@ class MultiEvaluator : public EvaluatorBase {
{
}

void Add(EvaluatorBase const& evaluator)
auto Add(EvaluatorBase const& evaluator)
{
evaluators_.push_back(std::ref(evaluator));
}
Expand Down Expand Up @@ -225,9 +225,29 @@ class MultiEvaluator : public EvaluatorBase {
std::vector<std::reference_wrapper<EvaluatorBase const>> evaluators_;
};

class OPERON_EXPORT AggregateEvaluator final : public EvaluatorBase {
public:
enum class AggregateType : int { Min, Max, Median, Mean, HarmonicMean, Sum };

explicit AggregateEvaluator(EvaluatorBase& evaluator)
: EvaluatorBase(evaluator.GetProblem()), evaluator_(evaluator)
{
}

auto SetAggregateType(AggregateType type) { aggtype_ = type; }
auto GetAggregateType() const { return aggtype_; }

auto
operator()(Operon::RandomGenerator& rng, Individual& ind, Operon::Span<Operon::Scalar> buf) const -> typename EvaluatorBase::ReturnType override;

private:
std::reference_wrapper<EvaluatorBase const> evaluator_;
AggregateType aggtype_{AggregateType::Mean};
};

// a couple of useful user-defined evaluators (mostly to avoid calling lambdas from python)
// TODO: think about a better design
class LengthEvaluator : public UserDefinedEvaluator {
class OPERON_EXPORT LengthEvaluator : public UserDefinedEvaluator {
public:
explicit LengthEvaluator(Operon::Problem& problem, size_t maxlength = 1)
: UserDefinedEvaluator(problem, [maxlength](Operon::RandomGenerator& /*unused*/, Operon::Individual& ind) {
Expand All @@ -237,7 +257,7 @@ class LengthEvaluator : public UserDefinedEvaluator {
}
};

class ShapeEvaluator : public UserDefinedEvaluator {
class OPERON_EXPORT ShapeEvaluator : public UserDefinedEvaluator {
public:
explicit ShapeEvaluator(Operon::Problem& problem)
: UserDefinedEvaluator(problem, [](Operon::RandomGenerator& /*unused*/, Operon::Individual& ind) {
Expand Down Expand Up @@ -331,6 +351,5 @@ class OPERON_EXPORT AkaikeInformationCriterionEvaluator final : public Evaluator
Operon::MSE mse_;
};


} // namespace Operon
#endif
35 changes: 35 additions & 0 deletions source/operators/evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,41 @@ namespace Operon {
return EvaluatorBase::ReturnType { -distance / static_cast<Operon::Scalar>(sampleSize_) };
}

auto
AggregateEvaluator::operator()(Operon::RandomGenerator& rng, Individual& ind, Operon::Span<Operon::Scalar> buf) const -> typename EvaluatorBase::ReturnType
{
using vstat::univariate::accumulate;
auto f = evaluator_.get()(rng, ind, buf);
switch(aggtype_) {
case AggregateType::Min: {
return { *std::min_element(f.begin(), f.end()) };
}
case AggregateType::Max: {
return { *std::max_element(f.begin(), f.end()) };
}
case AggregateType::Median: {
auto const sz { std::ssize(f) };
auto const a = f.begin() + sz / 2;
std::nth_element(f.begin(), a, f.end());
if (sz % 2 == 0) {
auto const b = std::max_element(f.begin(), a);
return { (*a + *b) / 2 };
}
return { *a };
}
case AggregateType::Mean: {
return { static_cast<Operon::Scalar>(accumulate<Operon::Scalar>(f.begin(), f.end()).mean) };
}
case AggregateType::HarmonicMean: {
auto stats = accumulate<Operon::Scalar>(f.begin(), f.end(), [](auto x) { return 1/x; });
return { static_cast<Operon::Scalar>(stats.count / stats.sum) };
}
case AggregateType::Sum: {
return { static_cast<Operon::Scalar>(vstat::univariate::accumulate<Operon::Scalar>(f.begin(), f.end()).sum) };
}
}
}

auto MinimumDescriptionLengthEvaluator::operator()(Operon::RandomGenerator& rng, Individual& ind, Operon::Span<Operon::Scalar> buf) const -> typename EvaluatorBase::ReturnType {
// call the base method of the evaluator in order to optimize the coefficients
// this also returns the error which we are going to use
Expand Down

0 comments on commit 88a15c3

Please sign in to comment.