Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pick chosen move based on normal distribution LCB #2290

Merged
merged 11 commits into from Mar 26, 2019
2 changes: 2 additions & 0 deletions src/GTP.cpp
Expand Up @@ -89,6 +89,7 @@ float cfg_logconst;
float cfg_softmax_temp;
float cfg_fpu_reduction;
float cfg_fpu_root_reduction;
float cfg_ci_alpha;
std::string cfg_weightsfile;
std::string cfg_logfile;
FILE* cfg_logfile_handle;
Expand Down Expand Up @@ -347,6 +348,7 @@ void GTP::setup_default_parameters() {
cfg_resignpct = -1;
cfg_noise = false;
cfg_fpu_root_reduction = cfg_fpu_reduction;
cfg_ci_alpha = 1e-5f;
cfg_random_cnt = 0;
cfg_random_min_visits = 1;
cfg_random_temp = 1.0f;
Expand Down
1 change: 1 addition & 0 deletions src/GTP.h
Expand Up @@ -114,6 +114,7 @@ extern float cfg_logconst;
extern float cfg_softmax_temp;
extern float cfg_fpu_reduction;
extern float cfg_fpu_root_reduction;
extern float cfg_ci_alpha;
extern std::string cfg_logfile;
extern std::string cfg_weightsfile;
extern FILE* cfg_logfile_handle;
Expand Down
6 changes: 6 additions & 0 deletions src/Leela.cpp
Expand Up @@ -203,6 +203,7 @@ static void parse_commandline(int argc, char *argv[]) {
("logconst", po::value<float>())
("softmax_temp", po::value<float>())
("fpu_reduction", po::value<float>())
("ci_alpha", po::value<float>())
;
#endif
// These won't be shown, we use them to catch incorrect usage of the
Expand Down Expand Up @@ -278,6 +279,9 @@ static void parse_commandline(int argc, char *argv[]) {
if (vm.count("fpu_reduction")) {
cfg_fpu_reduction = vm["fpu_reduction"].as<float>();
}
if (vm.count("ci_alpha")) {
cfg_ci_alpha = vm["ci_alpha"].as<float>();
}
#endif

if (vm.count("logfile")) {
Expand Down Expand Up @@ -494,6 +498,8 @@ void init_global_objects() {
// improves reproducibility across platforms.
Random::get_Rng().seedrandom(cfg_rng_seed);

Utils::create_z_table();

initialize_network();
}

Expand Down
41 changes: 41 additions & 0 deletions src/UCTNode.cpp
Expand Up @@ -180,8 +180,16 @@ void UCTNode::virtual_loss_undo() {
}

void UCTNode::update(float eval) {
// Cache values to avoid race conditions.
auto old_eval = static_cast<float>(m_blackevals);
auto old_visits = static_cast<int>(m_visits);
auto old_delta = old_visits > 0 ? eval - old_eval / old_visits : 0.0f;
m_visits++;
accumulate_eval(eval);
auto new_delta = eval - (old_eval + eval) / (old_visits + 1);
// Welford's online algorithm for calculating variance.
auto delta = old_delta * new_delta;
atomic_add(m_squared_diff, delta);
}

bool UCTNode::has_children() const {
Expand All @@ -207,10 +215,32 @@ void UCTNode::set_policy(float policy) {
m_policy = policy;
}

float UCTNode::get_variance(float default_var) const {
Ttl marked this conversation as resolved.
Show resolved Hide resolved
return m_visits > 1 ? m_squared_diff / (m_visits - 1) : default_var;
}

float UCTNode::get_stddev(float default_stddev) const {
Ttl marked this conversation as resolved.
Show resolved Hide resolved
return m_visits > 1 ? std::sqrt(get_variance()) : default_stddev;
}

int UCTNode::get_visits() const {
return m_visits;
}

float UCTNode::get_lcb(int color) const {
// Lower confidence bound of winrate.
Ttl marked this conversation as resolved.
Show resolved Hide resolved
auto visits = get_visits();
if (visits < 2) {
return 0.0f;
}
auto mean = get_raw_eval(color);

auto stddev = std::sqrt(get_variance(1.0f) / visits);
Hersmunch marked this conversation as resolved.
Show resolved Hide resolved
auto z = cached_t_quantile(visits - 1);

return mean - z * stddev;
}

float UCTNode::get_raw_eval(int tomove, int virtual_loss) const {
auto visits = get_visits() + virtual_loss;
assert(visits > 0);
Expand Down Expand Up @@ -313,6 +343,17 @@ class NodeComp : public std::binary_function<UCTNodePointer&,
auto a_visit = a.get_visits();
auto b_visit = b.get_visits();

// Calculate the lower confidence bound for each node.
if (a_visit && b_visit) {
auto a_lcb = a.get_lcb(m_color);
auto b_lcb = b.get_lcb(m_color);
Ttl marked this conversation as resolved.
Show resolved Hide resolved

// Sort on lower confidence bounds
if (a_lcb != b_lcb) {
return a_lcb < b_lcb;
}
}
Ttl marked this conversation as resolved.
Show resolved Hide resolved

// if visits are not same, sort on visits
if (a_visit != b_visit) {
return a_visit < b_visit;
Expand Down
6 changes: 6 additions & 0 deletions src/UCTNode.h
Expand Up @@ -76,12 +76,15 @@ class UCTNode {
int get_visits() const;
float get_policy() const;
void set_policy(float policy);
float get_variance(float default_var = 0.0f) const;
float get_stddev(float default_stddev = 0.0f) const;
float get_eval(int tomove) const;
float get_raw_eval(int tomove, int virtual_loss = 0) const;
float get_net_eval(int tomove) const;
void virtual_loss();
void virtual_loss_undo();
void update(float eval);
float get_lcb(int color) const;

// Defined in UCTNodeRoot.cpp, only to be called on m_root in UCTSearch
void randomize_first_proportionally();
Expand Down Expand Up @@ -122,6 +125,9 @@ class UCTNode {
float m_policy;
// Original net eval for this node (not children).
float m_net_eval{0.0f};
// Initialize to prior of variance. Avoids accidental zero variances
// at low visits.
std::atomic<float> m_squared_diff{0.01f};
Ttl marked this conversation as resolved.
Show resolved Hide resolved
std::atomic<double> m_blackevals{0.0};
std::atomic<Status> m_status{ACTIVE};

Expand Down
12 changes: 12 additions & 0 deletions src/UCTNodePointer.cpp
Expand Up @@ -137,6 +137,18 @@ float UCTNodePointer::get_policy() const {
return read_policy(v);
}

float UCTNodePointer::get_stddev(float default_stddev) const {
Ttl marked this conversation as resolved.
Show resolved Hide resolved
auto v = m_data.load();
if (is_inflated(v)) return read_ptr(v)->get_stddev(default_stddev);
return default_stddev;
}

float UCTNodePointer::get_lcb(int color) const {
assert(is_inflated());
auto v = m_data.load();
return read_ptr(v)->get_lcb(color);
}

bool UCTNodePointer::active() const {
auto v = m_data.load();
if (is_inflated(v)) return read_ptr(v)->active();
Expand Down
2 changes: 2 additions & 0 deletions src/UCTNodePointer.h
Expand Up @@ -129,6 +129,8 @@ class UCTNodePointer {
bool valid() const;
int get_visits() const;
float get_policy() const;
float get_stddev(float default_stddev) const;
float get_lcb(int color) const;
bool active() const;
int get_move() const;
// this can only be called if it is an inflated pointer
Expand Down
28 changes: 20 additions & 8 deletions src/UCTSearch.cpp
Expand Up @@ -283,10 +283,11 @@ void UCTSearch::dump_stats(FastState & state, UCTNode & parent) {
tmpstate.play_move(node->get_move());
auto pv = move + " " + get_pv(tmpstate, *node);

myprintf("%4s -> %7d (V: %5.2f%%) (N: %5.2f%%) PV: %s\n",
myprintf("%4s -> %7d (V: %5.2f%%) (LCB: %5.2f%%) (N: %5.2f%%) PV: %s\n",
move.c_str(),
node->get_visits(),
node->get_visits() ? node->get_raw_eval(color)*100.0f : 0.0f,
std::max(0.0f, node->get_lcb(color) * 100.0f),
node->get_policy() * 100.0f,
pv.c_str());
}
Expand Down Expand Up @@ -615,28 +616,39 @@ int UCTSearch::est_playouts_left(int elapsed_centis, int time_for_move) const {
static_cast<int>(std::ceil(playout_rate * time_left)));
}

size_t UCTSearch::prune_noncontenders(int elapsed_centis, int time_for_move, bool prune) {
size_t UCTSearch::prune_noncontenders(int color, int elapsed_centis, int time_for_move, bool prune) {
auto lcb_max = 0.0f;
auto Nfirst = 0;
// There are no cases where the root's children vector gets modified
// during a multithreaded search, so it is safe to walk it here without
// taking the (root) node lock.
for (const auto& node : m_root->get_children()) {
if (node->valid()) {
Nfirst = std::max(Nfirst, node->get_visits());
const auto visits = node->get_visits();
if (visits > 0) {
lcb_max = std::max(lcb_max, node->get_lcb(color));
}
Nfirst = std::max(Nfirst, visits);
Ttl marked this conversation as resolved.
Show resolved Hide resolved
}
}
const auto min_required_visits =
Nfirst - est_playouts_left(elapsed_centis, time_for_move);
auto pruned_nodes = size_t{0};
for (const auto& node : m_root->get_children()) {
if (node->valid()) {
const auto visits = node->get_visits();
const auto has_enough_visits =
node->get_visits() >= min_required_visits;
visits >= min_required_visits;
// Avoid pruning moves that could have the best lower confidence
// bound.
const auto high_winrate = visits > 0 ?
node->get_raw_eval(color) >= lcb_max : false;
const auto prune_this_node = !(has_enough_visits || high_winrate);

if (prune) {
node->set_active(has_enough_visits);
node->set_active(!prune_this_node);
}
if (!has_enough_visits) {
if (prune_this_node) {
++pruned_nodes;
}
}
Expand All @@ -650,9 +662,10 @@ bool UCTSearch::have_alternate_moves(int elapsed_centis, int time_for_move) {
if (cfg_timemanage == TimeManagement::OFF) {
return true;
}
auto my_color = m_rootstate.get_to_move();
// For self play use. Disables pruning of non-contenders to not bias the training data.
auto prune = cfg_timemanage != TimeManagement::NO_PRUNING;
auto pruned = prune_noncontenders(elapsed_centis, time_for_move, prune);
auto pruned = prune_noncontenders(my_color, elapsed_centis, time_for_move, prune);
if (pruned < m_root->get_children().size() - 1) {
return true;
}
Expand All @@ -661,7 +674,6 @@ bool UCTSearch::have_alternate_moves(int elapsed_centis, int time_for_move) {
// which will cause Leela to quickly respond to obvious/forced moves.
// That comes at the cost of some playing strength as she now cannot
// think ahead about her next moves in the remaining time.
auto my_color = m_rootstate.get_to_move();
auto tc = m_rootstate.get_timecontrol();
if (!tc.can_accumulate_time(my_color)
|| m_maxplayouts < UCTSearch::UNLIMITED_PLAYOUTS) {
Expand Down
2 changes: 1 addition & 1 deletion src/UCTSearch.h
Expand Up @@ -126,7 +126,7 @@ class UCTSearch {
bool should_resign(passflag_t passflag, float besteval);
bool have_alternate_moves(int elapsed_centis, int time_for_move);
int est_playouts_left(int elapsed_centis, int time_for_move) const;
size_t prune_noncontenders(int elapsed_centis = 0, int time_for_move = 0,
size_t prune_noncontenders(int color, int elapsed_centis = 0, int time_for_move = 0,
bool prune = true);
bool stop_thinking(int elapsed_centis = 0, int time_for_move = 0) const;
int get_best_move(passflag_t passflag);
Expand Down
22 changes: 22 additions & 0 deletions src/Utils.cpp
Expand Up @@ -35,6 +35,7 @@
#include <cstdio>

#include <boost/filesystem.hpp>
#include <boost/math/distributions/students_t.hpp>

#ifdef _WIN32
#include <windows.h>
Expand All @@ -49,6 +50,27 @@

Utils::ThreadPool thread_pool;

std::array<float, z_entries> z_lookup;

void Utils::create_z_table() {
for (auto i = 1; i < z_entries + 1; i++) {
boost::math::students_t dist(i);
auto z = boost::math::quantile(boost::math::complement(dist, cfg_ci_alpha));
z_lookup[i - 1] = z;
}
}

float Utils::cached_t_quantile(int v) {
if (v < 1) {
return z_lookup[0];
}
if (v < z_entries) {
return z_lookup[v - 1];
} else {
Ttl marked this conversation as resolved.
Show resolved Hide resolved
return z_lookup[z_entries - 1];
}
}

bool Utils::input_pending() {
#ifdef HAVE_SELECT
fd_set read_fds;
Expand Down
5 changes: 5 additions & 0 deletions src/Utils.h
Expand Up @@ -40,6 +40,8 @@

extern Utils::ThreadPool thread_pool;

auto constexpr z_entries = 1000;
gcp marked this conversation as resolved.
Show resolved Hide resolved

namespace Utils {
void myprintf_error(const char *fmt, ...);
void myprintf(const char *fmt, ...);
Expand Down Expand Up @@ -67,6 +69,9 @@ namespace Utils {
size_t ceilMultiple(size_t a, size_t b);

const std::string leelaz_file(std::string file);

void create_z_table();
float cached_t_quantile(int v);
}

#endif