Skip to content

Commit

Permalink
Look for symmetrical position in cache.
Browse files Browse the repository at this point in the history
  • Loading branch information
TFiFiE committed Apr 30, 2018
1 parent 8fc9165 commit c91fffb
Show file tree
Hide file tree
Showing 10 changed files with 74 additions and 54 deletions.
10 changes: 10 additions & 0 deletions src/FastBoard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,16 @@ void FastBoard::reset_board(int size) {
m_parent[MAXSQ] = MAXSQ;
m_libs[MAXSQ] = 16384; /* we will subtract from this */
m_next[MAXSQ] = MAXSQ;

for (auto symmetry = 0; symmetry < 8; ++symmetry) {
m_symmetry_idx[symmetry][0] = 0; // Make sure the special value for the lack of ko stays unchanged.
for (auto y = 0; y < size; ++y) {
for (auto x = 0; x < size; ++x) {
const auto newvtx = get_symmetry({x, y}, symmetry, size);
m_symmetry_idx[symmetry][get_vertex(x, y)] = get_vertex(newvtx.first, newvtx.second);
}
}
}
}

bool FastBoard::is_suicide(int i, int color) const {
Expand Down
4 changes: 3 additions & 1 deletion src/FastBoard.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class FastBoard {
int get_boardsize(void) const;
square_t get_square(int x, int y) const;
square_t get_square(int vertex) const ;
int get_vertex(int i, int j) const;
int get_vertex(int x, int y) const;
void set_square(int x, int y, square_t content);
void set_square(int vertex, square_t content);
std::pair<int, int> get_xy(int vertex) const;
Expand Down Expand Up @@ -125,6 +125,8 @@ class FastBoard {
int m_boardsize;
int m_squaresize;

std::array<std::array<unsigned short, MAXSQ>, 8> m_symmetry_idx;

int calc_reach_color(int color) const;

int count_neighbours(const int color, const int i) const;
Expand Down
4 changes: 4 additions & 0 deletions src/FastState.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ class FastState {
void display_state();
std::string move_to_text(int move);

std::uint64_t get_hash(const int symmetry) const {
return board.calc_hash(m_komove, symmetry);
}

FullBoard board;

float m_komi;
Expand Down
19 changes: 8 additions & 11 deletions src/FullBoard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ int FullBoard::remove_string(int i) {
return removed;
}

std::uint64_t FullBoard::calc_ko_hash(void) {
std::uint64_t FullBoard::calc_ko_hash() const {
auto res = Zobrist::zobrist_empty;

for (int i = 0; i < m_maxsq; i++) {
Expand All @@ -65,16 +65,15 @@ std::uint64_t FullBoard::calc_ko_hash(void) {
}

/* Tromp-Taylor has positional superko */
m_ko_hash = res;
return res;
}

std::uint64_t FullBoard::calc_hash(int komove) {
std::uint64_t FullBoard::calc_hash(const int komove, const int symmetry) const {
auto res = Zobrist::zobrist_empty;

for (int i = 0; i < m_maxsq; i++) {
if (m_square[i] != INVAL) {
res ^= Zobrist::zobrist[m_square[i]][i];
res ^= Zobrist::zobrist[m_square[i]][m_symmetry_idx[symmetry][i]];
}
}

Expand All @@ -86,18 +85,16 @@ std::uint64_t FullBoard::calc_hash(int komove) {
res ^= Zobrist::zobrist_blacktomove;
}

res ^= Zobrist::zobrist_ko[komove];

m_hash = res;
res ^= Zobrist::zobrist_ko[m_symmetry_idx[symmetry][komove]];

return res;
}

std::uint64_t FullBoard::get_hash(void) const {
std::uint64_t FullBoard::get_hash() const {
return m_hash;
}

std::uint64_t FullBoard::get_ko_hash(void) const {
std::uint64_t FullBoard::get_ko_hash() const {
return m_ko_hash;
}

Expand Down Expand Up @@ -191,6 +188,6 @@ void FullBoard::display_board(int lastmove) {
void FullBoard::reset_board(int size) {
FastBoard::reset_board(size);

calc_hash();
calc_ko_hash();
m_hash = calc_hash();
m_ko_hash = calc_ko_hash();
}
8 changes: 4 additions & 4 deletions src/FullBoard.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ class FullBoard : public FastBoard {
int remove_string(int i);
int update_board(const int color, const int i);

std::uint64_t calc_hash(int komove = 0);
std::uint64_t calc_ko_hash(void);
std::uint64_t get_hash(void) const;
std::uint64_t get_ko_hash(void) const;
std::uint64_t calc_hash(const int komove = 0, const int symmetry = 0) const;
std::uint64_t calc_ko_hash() const;
std::uint64_t get_hash() const;
std::uint64_t get_ko_hash() const;
void set_to_move(int tomove);

void reset_board(int size);
Expand Down
4 changes: 3 additions & 1 deletion src/Leela.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,9 @@ void init_global_objects() {

void benchmark(GameState& game) {
game.set_timecontrol(0, 1, 0, 0); // Set infinite time.
game.play_textmove("b", "q16");
game.play_textmove("b", "r16");
game.play_textmove("w", "d4");
game.play_textmove("b", "c3");
auto search = std::make_unique<UCTSearch>(game);
game.set_to_move(FastBoard::WHITE);
search->think(FastBoard::WHITE);
Expand Down
52 changes: 16 additions & 36 deletions src/Network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,9 @@ void Network::initialize() {
// Prepare symmetry table
for(auto s = 0; s < 8; s++) {
for(auto v = 0; v < BOARD_SQUARES; v++) {
symmetry_nn_idx_table[s][v] = get_nn_idx_symmetry(v, s);
const auto newvtx = get_symmetry({v % BOARD_SIZE, v / BOARD_SIZE}, s);
symmetry_nn_idx_table[s][v] = (newvtx.second * BOARD_SIZE) + newvtx.first;
assert(symmetry_nn_idx_table[s][v] >= 0 && symmetry_nn_idx_table[s][v] < BOARD_SQUARES);
}
}

Expand Down Expand Up @@ -868,8 +870,19 @@ Network::Netresult Network::get_scored_moves(

if (!skip_cache) {
// See if we already have this in the cache.
if (NNCache::get_NNCache().lookup(state->board.get_hash(), result)) {
return result;
for (auto symmetry = 0; symmetry < 8; ++symmetry) {
const auto hash = state->get_hash(symmetry);
assert(symmetry != 0 || hash == state->board.get_hash());
if (NNCache::get_NNCache().lookup(hash, result)) {
decltype(result.policy) corrected_policy;
corrected_policy.reserve(BOARD_SQUARES);
for (auto idx = size_t{0}; idx < BOARD_SQUARES; ++idx) {
const auto sym_idx = symmetry_nn_idx_table[symmetry][idx];
corrected_policy.emplace_back(result.policy[sym_idx]);
}
result.policy = std::move(corrected_policy);
return result;
}
}
}

Expand Down Expand Up @@ -1070,36 +1083,3 @@ void Network::gather_features(const GameState* const state, NNPlanes & planes) {
planes[white_offset + h]);
}
}

int Network::get_nn_idx_symmetry(const int vertex, int symmetry) {
assert(vertex >= 0 && vertex < BOARD_SQUARES);
assert(symmetry >= 0 && symmetry < 8);
auto x = vertex % BOARD_SIZE;
auto y = vertex / BOARD_SIZE;
int newx;
int newy;

if (symmetry >= 4) {
std::swap(x, y);
symmetry -= 4;
}

if (symmetry == 0) {
newx = x;
newy = y;
} else if (symmetry == 1) {
newx = x;
newy = BOARD_SIZE - y - 1;
} else if (symmetry == 2) {
newx = BOARD_SIZE - x - 1;
newy = y;
} else {
assert(symmetry == 3);
newx = BOARD_SIZE - x - 1;
newy = BOARD_SIZE - y - 1;
}

const auto newvtx = (newy * BOARD_SIZE) + newx;
assert(newvtx >= 0 && newvtx < BOARD_SQUARES);
return newvtx;
}
1 change: 0 additions & 1 deletion src/Network.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ class Network {
static void winograd_sgemm(const std::vector<float>& U,
const std::vector<float>& V,
std::vector<float>& M, const int C, const int K);
static int get_nn_idx_symmetry(const int vertex, int symmetry);
static void fill_input_plane_pair(
const FullBoard& board, BoardPlane& black, BoardPlane& white);
static Netresult get_scored_moves_internal(
Expand Down
25 changes: 25 additions & 0 deletions src/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,28 @@ size_t Utils::ceilMultiple(size_t a, size_t b) {
auto ret = a + (b - a % b);
return ret;
}

std::pair<int, int> Utils::get_symmetry(const std::pair<int, int>& vertex, const int symmetry, const int board_size) {
assert(vertex.first >= 0 && vertex.first < board_size);
assert(vertex.second >= 0 && vertex.second < board_size);
assert(symmetry >= 0 && symmetry < 8);
auto x = vertex.first;
auto y = vertex.second;

if ((symmetry & 4) != 0) {
std::swap(x, y);
}

if ((symmetry & 2) != 0) {
x = board_size - x - 1;
}

if ((symmetry & 1) != 0) {
y = board_size - y - 1;
}

assert(x >= 0 && x < board_size);
assert(y >= 0 && y < board_size);
assert(symmetry != 0 || vertex == std::make_pair(x, y));
return {x, y};
}
1 change: 1 addition & 0 deletions src/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ namespace Utils {
}

size_t ceilMultiple(size_t a, size_t b);
std::pair<int, int> get_symmetry(const std::pair<int, int>& vertex, const int symmetry, const int board_size = BOARD_SIZE);
}

#endif

0 comments on commit c91fffb

Please sign in to comment.