Skip to content

Commit

Permalink
Created observer functions to avoid need to have printing configurati…
Browse files Browse the repository at this point in the history
…on in strategic LQRE
  • Loading branch information
tturocy committed May 17, 2024
1 parent 1af14f1 commit 8228c80
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 112 deletions.
9 changes: 2 additions & 7 deletions src/pygambit/nash.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,7 @@ std::shared_ptr<LogitQREMixedStrategyProfile> LogitStrategyAtLambdaHelper(const
StrategicQREPathTracer alg;
alg.SetMaxDecel(p_maxAccel);
alg.SetStepsize(p_firstStep);
NullBuffer null_buffer;
std::ostream null_stream(&null_buffer);
return make_shared<LogitQREMixedStrategyProfile>(
alg.SolveAtLambda(start, null_stream, p_lambda, 1.0));
return make_shared<LogitQREMixedStrategyProfile>(alg.SolveAtLambda(start, p_lambda, 1.0));
}

List<LogitQREMixedStrategyProfile> logit_principal_branch(const Game &p_game, double p_maxregret,
Expand All @@ -85,7 +82,5 @@ List<LogitQREMixedStrategyProfile> logit_principal_branch(const Game &p_game, do
StrategicQREPathTracer alg;
alg.SetMaxDecel(p_maxAccel);
alg.SetStepsize(p_firstStep);
NullBuffer null_buffer;
std::ostream null_stream(&null_buffer);
return alg.TraceStrategicPath(start, null_stream, p_maxregret, 1.0);
return alg.TraceStrategicPath(start, p_maxregret, 1.0);
}
113 changes: 36 additions & 77 deletions src/solvers/logit/nfglogit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ void GetJacobian(const Game &p_game, const Vector<double> &p_point, Matrix<doubl

class TracingCallbackFunction {
public:
TracingCallbackFunction(std::ostream &p_stream, const Game &p_game, int p_decimals)
: m_stream(p_stream), m_game(p_game), m_decimals(p_decimals)
TracingCallbackFunction(const Game &p_game, MixedStrategyObserverFunctionType p_observer)
: m_game(p_game), m_observer(p_observer)
{
}
~TracingCallbackFunction() = default;
Expand All @@ -179,111 +179,71 @@ class TracingCallbackFunction {
const List<LogitQREMixedStrategyProfile> &GetProfiles() const { return m_profiles; }

private:
std::ostream &m_stream;
Game m_game;
int m_decimals;
MixedStrategyObserverFunctionType m_observer;
List<LogitQREMixedStrategyProfile> m_profiles;
};

void TracingCallbackFunction::AppendPoint(const Vector<double> &p_point)
{
m_stream.setf(std::ios::fixed);
m_stream << std::setprecision(m_decimals) << p_point.back();
m_stream.unsetf(std::ios::fixed);
MixedStrategyProfile<double> profile(m_game->NewMixedStrategyProfile(0.0));
for (int i = 1; i < p_point.Length(); i++) {
profile[i] = exp(p_point[i]);
m_stream << "," << std::setprecision(m_decimals) << profile[i];
}
m_stream << std::endl;
m_profiles.push_back(LogitQREMixedStrategyProfile(profile, p_point.back(), 0.0));
MixedStrategyProfile<double> profile(PointToProfile(m_game, p_point));
m_profiles.push_back(LogitQREMixedStrategyProfile(profile, p_point.back(), 1.0));
m_observer(m_profiles.back());
}

class EstimatorCallbackFunction {
public:
EstimatorCallbackFunction(std::ostream &p_stream, const Game &p_game,
const Vector<double> &p_frequencies, int p_decimals);
EstimatorCallbackFunction(const Game &p_game, const Vector<double> &p_frequencies,
MixedStrategyObserverFunctionType p_observer);
~EstimatorCallbackFunction() = default;

void EvaluatePoint(const Vector<double> &p_point);

LogitQREMixedStrategyProfile GetMaximizer() const
{
return {m_bestProfile, m_bestLambda, m_maxlogL};
}
void PrintMaximizer() const;
LogitQREMixedStrategyProfile GetMaximizer() const { return m_bestProfile; }

private:
void PrintProfile(const MixedStrategyProfile<double> &, double) const;

std::ostream &m_stream;
Game m_game;
const Vector<double> &m_frequencies;
int m_decimals;
MixedStrategyProfile<double> m_bestProfile;
double m_bestLambda{0.0};
double m_maxlogL;
MixedStrategyObserverFunctionType m_observer;
LogitQREMixedStrategyProfile m_bestProfile;
};

EstimatorCallbackFunction::EstimatorCallbackFunction(std::ostream &p_stream, const Game &p_game,
EstimatorCallbackFunction::EstimatorCallbackFunction(const Game &p_game,
const Vector<double> &p_frequencies,
int p_decimals)
: m_stream(p_stream), m_game(p_game), m_frequencies(p_frequencies), m_decimals(p_decimals),
m_bestProfile(p_game->NewMixedStrategyProfile(0.0)),
m_maxlogL(LogLike(p_frequencies, static_cast<const Vector<double> &>(m_bestProfile)))
{
}

void EstimatorCallbackFunction::PrintProfile(const MixedStrategyProfile<double> &p_profile,
double p_logL) const
MixedStrategyObserverFunctionType p_observer)
: m_game(p_game), m_frequencies(p_frequencies), m_observer(p_observer),
m_bestProfile(p_game->NewMixedStrategyProfile(0.0), 0.0,
LogLike(p_frequencies, static_cast<const Vector<double> &>(
p_game->NewMixedStrategyProfile(0.0))))
{
for (size_t i = 1; i <= p_profile.MixedProfileLength(); i++) {
m_stream << "," << std::setprecision(m_decimals) << p_profile[i];
}
m_stream.setf(std::ios::fixed);
m_stream << "," << std::setprecision(m_decimals) << p_logL;
m_stream.unsetf(std::ios::fixed);
}

void EstimatorCallbackFunction::PrintMaximizer() const
void EstimatorCallbackFunction::EvaluatePoint(const Vector<double> &p_point)
{
m_stream.setf(std::ios::fixed);
m_stream << std::setprecision(m_decimals) << m_bestLambda;
m_stream.unsetf(std::ios::fixed);
PrintProfile(m_bestProfile, m_maxlogL);
m_stream << std::endl;
}

void EstimatorCallbackFunction::EvaluatePoint(const Vector<double> &x)
{
m_stream.setf(std::ios::fixed);
m_stream << std::setprecision(m_decimals) << x.back();
m_stream.unsetf(std::ios::fixed);
MixedStrategyProfile<double> profile = PointToProfile(m_game, x);
double logL = LogLike(m_frequencies, static_cast<const Vector<double> &>(profile));
PrintProfile(profile, logL);
m_stream << std::endl;
if (logL > m_maxlogL) {
m_maxlogL = logL;
m_bestLambda = x.back();
m_bestProfile = profile;
MixedStrategyProfile<double> profile(PointToProfile(m_game, p_point));
auto qre = LogitQREMixedStrategyProfile(
profile, p_point.back(),
LogLike(m_frequencies, static_cast<const Vector<double> &>(profile)));
m_observer(qre);
if (qre.GetLogLike() > m_bestProfile.GetLogLike()) {
m_bestProfile = qre;
}
}

} // namespace

List<LogitQREMixedStrategyProfile>
StrategicQREPathTracer::TraceStrategicPath(const LogitQREMixedStrategyProfile &p_start,
std::ostream &p_stream, double p_regret,
double p_omega) const
double p_regret, double p_omega,
MixedStrategyObserverFunctionType p_observer) const
{
double scale = p_start.GetGame()->GetMaxPayoff() - p_start.GetGame()->GetMinPayoff();
if (scale != 0.0) {
p_regret *= scale;
}

Vector<double> x(ProfileToPoint(p_start));
TracingCallbackFunction callback(p_stream, p_start.GetGame(), m_decimals);
TracingCallbackFunction callback(p_start.GetGame(), p_observer);
TracePath([&p_start](const Vector<double> &p_point,
Vector<double> &p_lhs) { GetValue(p_start.GetGame(), p_point, p_lhs); },
[&p_start](const Vector<double> &p_point, Matrix<double> &p_jac) {
Expand All @@ -299,11 +259,11 @@ StrategicQREPathTracer::TraceStrategicPath(const LogitQREMixedStrategyProfile &p

LogitQREMixedStrategyProfile
StrategicQREPathTracer::SolveAtLambda(const LogitQREMixedStrategyProfile &p_start,
std::ostream &p_stream, double p_targetLambda,
double p_omega) const
double p_targetLambda, double p_omega,
MixedStrategyObserverFunctionType p_observer) const
{
Vector<double> x(ProfileToPoint(p_start));
TracingCallbackFunction callback(p_stream, p_start.GetGame(), m_decimals);
TracingCallbackFunction callback(p_start.GetGame(), p_observer);
TracePath([&p_start](const Vector<double> &p_point,
Vector<double> &p_lhs) { GetValue(p_start.GetGame(), p_point, p_lhs); },
[&p_start](const Vector<double> &p_point, Matrix<double> &p_jac) {
Expand All @@ -317,10 +277,9 @@ StrategicQREPathTracer::SolveAtLambda(const LogitQREMixedStrategyProfile &p_star
return callback.GetProfiles().back();
}

LogitQREMixedStrategyProfile
StrategicQREEstimator::Estimate(const LogitQREMixedStrategyProfile &p_start,
const MixedStrategyProfile<double> &p_frequencies,
std::ostream &p_stream, double p_maxLambda, double p_omega)
LogitQREMixedStrategyProfile StrategicQREEstimator::Estimate(
const LogitQREMixedStrategyProfile &p_start, const MixedStrategyProfile<double> &p_frequencies,
double p_maxLambda, double p_omega, MixedStrategyObserverFunctionType p_observer) const
{
if (p_start.GetGame() != p_frequencies.GetGame()) {
throw MismatchException();
Expand All @@ -329,7 +288,7 @@ StrategicQREEstimator::Estimate(const LogitQREMixedStrategyProfile &p_start,
Vector<double> x(ProfileToPoint(p_start));
Vector<double> freq_vector(static_cast<const Vector<double> &>(p_frequencies));
EstimatorCallbackFunction callback(
p_stream, p_start.GetGame(), static_cast<const Vector<double> &>(p_frequencies), m_decimals);
p_start.GetGame(), static_cast<const Vector<double> &>(p_frequencies), p_observer);
while (x.back() < p_maxLambda) {
TracePath(
[&p_start](const Vector<double> &p_point, Vector<double> &p_lhs) {
Expand All @@ -347,7 +306,7 @@ StrategicQREEstimator::Estimate(const LogitQREMixedStrategyProfile &p_start,
return DiffLogLike(freq_vector, p_tangent);
});
}
callback.PrintMaximizer();
p_observer(callback.GetMaximizer());
return callback.GetMaximizer();
}

Expand Down
42 changes: 19 additions & 23 deletions src/solvers/logit/nfglogit.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,39 +49,38 @@ class LogitQREMixedStrategyProfile {
double operator[](int i) const { return m_profile[i]; }

private:
const MixedStrategyProfile<double> m_profile;
MixedStrategyProfile<double> m_profile;
double m_lambda;
double m_logLike;
};

using MixedStrategyObserverFunctionType =
std::function<void(const LogitQREMixedStrategyProfile &)>;

inline void NullMixedStrategyObserver(const LogitQREMixedStrategyProfile &) {}

class StrategicQREPathTracer : public PathTracer {
public:
StrategicQREPathTracer() : m_decimals(6) {}
StrategicQREPathTracer() = default;
~StrategicQREPathTracer() override = default;

List<LogitQREMixedStrategyProfile>
TraceStrategicPath(const LogitQREMixedStrategyProfile &p_start, std::ostream &p_logStream,
double p_maxregret, double p_omega) const;
LogitQREMixedStrategyProfile SolveAtLambda(const LogitQREMixedStrategyProfile &p_start,
std::ostream &p_logStream, double p_targetLambda,
double p_omega) const;

void SetDecimals(int p_decimals) { m_decimals = p_decimals; }
int GetDecimals() const { return m_decimals; }

protected:
int m_decimals;
List<LogitQREMixedStrategyProfile> TraceStrategicPath(
const LogitQREMixedStrategyProfile &p_start, double p_maxregret, double p_omega,
MixedStrategyObserverFunctionType p_observer = NullMixedStrategyObserver) const;
LogitQREMixedStrategyProfile
SolveAtLambda(const LogitQREMixedStrategyProfile &p_start, double p_targetLambda, double p_omega,
MixedStrategyObserverFunctionType p_observer = NullMixedStrategyObserver) const;
};

class StrategicQREEstimator : public StrategicQREPathTracer {
public:
StrategicQREEstimator() = default;
~StrategicQREEstimator() override = default;

LogitQREMixedStrategyProfile Estimate(const LogitQREMixedStrategyProfile &p_start,
const MixedStrategyProfile<double> &p_frequencies,
std::ostream &p_logStream, double p_maxLambda,
double p_omega);
LogitQREMixedStrategyProfile
Estimate(const LogitQREMixedStrategyProfile &p_start,
const MixedStrategyProfile<double> &p_frequencies, double p_maxLambda, double p_omega,
MixedStrategyObserverFunctionType p_observer = NullMixedStrategyObserver) const;
};

inline LogitQREMixedStrategyProfile
Expand All @@ -92,8 +91,7 @@ LogitStrategyEstimate(const MixedStrategyProfile<double> &p_frequencies, double
StrategicQREEstimator alg;
alg.SetMaxDecel(p_maxAccel);
alg.SetStepsize(p_firstStep);
std::ostringstream ostream;
return alg.Estimate(start, p_frequencies, ostream, 1000000.0, 1.0);
return alg.Estimate(start, p_frequencies, 1000000.0, 1.0);
}

inline List<MixedStrategyProfile<double>> LogitStrategySolve(const Game &p_game, double p_regret,
Expand All @@ -102,9 +100,7 @@ inline List<MixedStrategyProfile<double>> LogitStrategySolve(const Game &p_game,
StrategicQREPathTracer tracer;
tracer.SetMaxDecel(p_maxAccel);
tracer.SetStepsize(p_firstStep);
std::ostringstream ostream;
auto result =
tracer.TraceStrategicPath(LogitQREMixedStrategyProfile(p_game), ostream, p_regret, 1.0);
auto result = tracer.TraceStrategicPath(LogitQREMixedStrategyProfile(p_game), p_regret, 1.0);
auto ret = List<MixedStrategyProfile<double>>();
if (!result.empty()) {
ret.push_back(result.back().GetProfile());
Expand Down
31 changes: 26 additions & 5 deletions src/tools/logit/logit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,23 @@ bool ReadProfile(std::istream &p_stream, MixedStrategyProfile<double> &p_profile
return true;
}

void PrintProfile(std::ostream &p_stream, int p_decimals,
const LogitQREMixedStrategyProfile &p_profile)
{
p_stream.setf(std::ios::fixed);
p_stream << std::setprecision(p_decimals) << p_profile.GetLambda();
p_stream.unsetf(std::ios::fixed);
for (size_t i = 1; i <= p_profile.MixedProfileLength(); i++) {
p_stream << "," << std::setprecision(p_decimals) << p_profile[i];
}
if (p_profile.GetLogLike() <= 0.0) {
p_stream.setf(std::ios::fixed);
p_stream << "," << std::setprecision(p_decimals) << p_profile.GetLogLike();
p_stream.unsetf(std::ios::fixed);
}
p_stream << std::endl;
}

int main(int argc, char *argv[])
{
opterr = 0;
Expand Down Expand Up @@ -174,26 +191,30 @@ int main(int argc, char *argv[])
std::ifstream mleData(mleFile.c_str());
ReadProfile(mleData, frequencies);

auto printer = [decimals](const LogitQREMixedStrategyProfile &p) {
PrintProfile(std::cout, decimals, p);
};
LogitQREMixedStrategyProfile start(game);
StrategicQREEstimator tracer;
tracer.SetMaxDecel(maxDecel);
tracer.SetStepsize(hStart);
tracer.SetDecimals(decimals);
tracer.Estimate(start, frequencies, std::cout, maxLambda, 1.0);
tracer.Estimate(start, frequencies, maxLambda, 1.0, printer);
return 0;
}

if (!game->IsTree() || useStrategic) {
auto printer = [decimals](const LogitQREMixedStrategyProfile &p) {
PrintProfile(std::cout, decimals, p);
};
LogitQREMixedStrategyProfile start(game);
StrategicQREPathTracer tracer;
tracer.SetMaxDecel(maxDecel);
tracer.SetStepsize(hStart);
tracer.SetDecimals(decimals);
if (targetLambda > 0.0) {
tracer.SolveAtLambda(start, std::cout, targetLambda, 1.0);
tracer.SolveAtLambda(start, targetLambda, 1.0, printer);
}
else {
tracer.TraceStrategicPath(start, std::cout, maxregret, 1.0);
tracer.TraceStrategicPath(start, maxregret, 1.0, printer);
}
}
else {
Expand Down

0 comments on commit 8228c80

Please sign in to comment.