Super ko #21

Open · wants to merge 18 commits into master
README.md (28 additions, 15 deletions)
@@ -1,5 +1,5 @@
# About
An environment for the board game Go. It is implemented using OpenAI's Gym API.
It is also optimized to be as efficient as possible in order to train ML models quickly.

# Installation
@@ -13,8 +13,10 @@ pip install -e .
### Coding example
```python
import gym
import gym_go

-go_env = gym.make('gym_go:go-v0', size=7, komi=0, reward_method='real')
+go_env = gym.make('go-v0', size=7, komi=0, reward_method='real')
+go_env.reset()

first_action = (2,5)
second_action = (5,2)
@@ -23,7 +25,7 @@ go_env.render('terminal')
```

```
  0 1 2 3 4 5 6
0 ╔═╤═╤═╤═╤═╤═╗
1 ╟─┼─┼─┼─┼─┼─╢
2 ╟─┼─┼─┼─┼─○─╢
@@ -41,7 +43,7 @@ go_env.render('terminal')
```

```
  0 1 2 3 4 5 6
0 ╔═╤═╤═╤═╤═╤═╗
1 ╟─┼─┼─┼─┼─┼─╢
2 ╟─┼─┼─┼─┼─○─╢
@@ -62,23 +64,34 @@ python3 demo.py
![alt text](screenshots/human_ui.png)

### High level API
[GoEnv](gym_go/envs/go_env.py) defines the Gym environment for Go.
It contains the highest level API for basic Go usage.

### Low level API
[GoGame](gym_go/gogame.py) is the set of low-level functions that defines all the game logic of Go.
`GoEnv`'s high level API is built on `GoGame`.
These functions are intended for more detailed and fine-tuned usage of Go.
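
As a quick illustration, here is a minimal sketch of driving a game through `GoGame` directly. It assumes the signatures visible in this diff (`init_state`, `next_state` taking a flattened `action1d`, plus the `turn`, `invalid_moves`, and `game_ended` helpers):

```python
import numpy as np
from gym_go import gogame

state = gogame.init_state(7)                 # blank 7x7 board, 6 channels
action1d = 2 * 7 + 5                         # (2, 5) flattened row-major
state = gogame.next_state(state, action1d)   # returns a new state array
print(gogame.turn(state))                    # 1: white to move
print(gogame.game_ended(state))              # 0: game still in progress
print(np.sum(gogame.invalid_moves(state)))   # count of illegal actions
```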

# Scoring
We use Trump Taylor scoring, a simple area scoring, to determine the winner. A player's _area_ is defined as the number of empty points the
player's pieces surround plus the number of the player's pieces on the board. The _winner_ is the player with the larger
area (a game is tied if both players have an equal amount of area on the board).

There is also support for `komi`, a bias score constant to balance the advantage of black going first.
By default `komi` is set to 0.
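
For a concrete sense of the rule, a small sketch (assuming `gogame.areas` returns `(black_area, white_area)` as in the diff further down): a lone black stone surrounds every empty point, so black's area is the whole board.

```python
from gym_go import gogame

state = gogame.init_state(7)
state = gogame.next_state(state, 2 * 7 + 5)  # a single black stone at (2, 5)
black_area, white_area = gogame.areas(state)
print(black_area, white_area)                # 49 0: 48 surrounded empties + 1 stone
```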

# Ko and super ko
The game supports a simple implementation of the ko rule by default, which prevents single-move take-back scenarios. In addition, an optional
super ko rule can be enabled when initializing the gym:

```python
go_env = gym.make('go-v0', size=7, super_ko=True)
```

This rule implements positional super ko by tracking play history, which catches repeating positions not detected by the regular ko rule
at the price of a performance overhead.
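
A sketch of how the flag surfaces during play (the channel index follows the State section below; the printed value is illustrative):

```python
import gym
import gym_go

go_env = gym.make('go-v0', size=7, komi=0, super_ko=True)
go_env.reset()

state, reward, done, info = go_env.step((2, 5))
# the fourth channel flags illegal moves, now including moves that
# would recreate a position already seen in the game's history
print(state[3].sum())
```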

# Game ending
A game ends when both players pass consecutively.
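
For example (a sketch; per `GoEnv.step`, an action of `None` is treated as a pass):

```python
import gym
import gym_go

go_env = gym.make('go-v0', size=7)
go_env.reset()

state, reward, done, info = go_env.step(None)  # black passes
assert not done
state, reward, done, info = go_env.step(None)  # white passes
assert done                                    # two consecutive passes end the game
```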

@@ -90,16 +103,16 @@ Reward methods are in _black_'s perspective
* **Real**:
  * `0` - Game is tied
  * `1` - Black won
  * `-1` - White won
  * `0` - Otherwise
* **Heuristic**: If the game is ongoing, the reward is `black area - white area`.
If black won, the reward is `BOARD_SIZE**2`.
If white won, the reward is `-BOARD_SIZE**2`.
If tied, the reward is `0`.
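
A sketch of reading the `real` reward from `step` (black's perspective), assuming area scoring gives black the win here:

```python
import gym
import gym_go

go_env = gym.make('go-v0', size=7, komi=0, reward_method='real')
go_env.reset()

_, reward, done, _ = go_env.step((2, 5))  # black plays
print(reward, done)                       # 0 False: game is ongoing
go_env.step(None)                         # white passes
_, reward, done, _ = go_env.step(None)    # black passes, ending the game
print(reward, done)                       # 1 True: black owns all the area
```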

# State
The `state` object that is returned by the `reset` and `step` functions of the environment is a
`6 x BOARD_SIZE x BOARD_SIZE` numpy array. All values in the array are either `0` or `1`
* **First and second channel:** represent the black and white pieces respectively.
* **Third channel:** Indicator layer for whose turn it is
* **Fourth channel:** Invalid moves (including ko-protection) for the next action
* **Fifth channel:** Indicator layer for whether the previous move was a pass
* **Sixth channel:** Indicator layer for whether the game is over
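
Putting the list above together, a sketch of unpacking the channels after one move:

```python
import gym
import gym_go

go_env = gym.make('go-v0', size=7)
go_env.reset()
state, _, _, _ = go_env.step((2, 5))

black, white = state[0], state[1]  # piece layers
turn = int(state[2].max())         # 0: black to move, 1: white to move
invalid = state[3]                 # illegal moves (including ko) for the next action
prev_pass = bool(state[4].max())   # previous move was a pass
game_over = bool(state[5].max())   # game has ended
assert black[2, 5] == 1 and turn == 1
```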
demo.py (3 additions, 1 deletion)
@@ -1,6 +1,7 @@
import argparse

import gym
import gym_go

# Arguments
parser = argparse.ArgumentParser(description='Demo Go Environment')
@@ -9,7 +10,8 @@
args = parser.parse_args()

# Initialize environment
-go_env = gym.make('gym_go:go-v0', size=args.boardsize, komi=args.komi)
+go_env = gym.make('go-v0', size=args.boardsize, komi=args.komi)
+go_env.reset()

# Game loop
done = False
gym_go/envs/go_env.py (9 additions, 2 deletions)
@@ -21,8 +21,9 @@ class GoEnv(gym.Env):
govars = govars
gogame = gogame

-def __init__(self, size, komi=0, reward_method='real'):
+def __init__(self, size, komi=0, super_ko=False, reward_method='real'):
'''
@param super_ko: whether to enable super-ko rule (history tracking)
@param reward_method: either 'heuristic' or 'real'
heuristic: gives # black pieces - # white pieces.
real: gives 0 for in-game move, 1 for winning, -1 for losing,
@@ -31,6 +32,7 @@ def __init__(self, size, komi=0, reward_method='real'):
self.size = size
self.komi = komi
self.state_ = gogame.init_state(size)
self.history = [] if super_ko else None
self.reward_method = RewardMethod(reward_method)
self.observation_space = gym.spaces.Box(np.float32(0), np.float32(govars.NUM_CHNLS),
shape=(govars.NUM_CHNLS, size, size))
@@ -43,6 +45,8 @@ def reset(self):
done, return state
'''
self.state_ = gogame.init_state(self.size)
if self.history is not None:
self.history = []
self.done = False
return np.copy(self.state_)

@@ -59,7 +63,10 @@ def step(self, action):
elif action is None:
action = self.size ** 2

-self.state_ = gogame.next_state(self.state_, action, canonical=False)
+self.old_state = self.state()
+self.state_ = gogame.next_state(self.state_, action, canonical=False, history=self.history)
+if self.history is not None:
+    self.history.append(self.old_state)
self.done = gogame.game_ended(self.state_)
return np.copy(self.state_), self.reward(), self.done, self.info()

gym_go/gogame.py (13 additions, 13 deletions)
@@ -31,7 +31,7 @@ def batch_init_state(batch_size, board_size):
return batch_state


-def next_state(state, action1d, canonical=False):
+def next_state(state, action1d, canonical=False, history=None):
# Deep copy the state to modify
state = np.copy(state)

@@ -74,20 +74,20 @@ def next_state(state, action1d, canonical=False):
if len(killed_group) == 1:
ko_protect = killed_group[0]

-# Update invalid moves
-state[govars.INVD_CHNL] = state_utils.compute_invalid_moves(state, player, ko_protect)
-
# Switch turn
state_utils.set_turn(state)

+# Update invalid moves
+state[govars.INVD_CHNL] = state_utils.compute_invalid_moves(state, player, ko_protect, history)

if canonical:
# Set canonical form
state = canonical_form(state)

return state


-def batch_next_states(batch_states, batch_action1d, canonical=False):
+def batch_next_states(batch_states, batch_action1d, canonical=False, batch_histories=None):
# Deep copy the state to modify
batch_states = np.copy(batch_states)

@@ -136,13 +136,13 @@ def batch_next_states(batch_states, batch_action1d, canonical=False):
if len(killed_group) == 1:
batch_ko_protect[batch_non_pass[i]] = killed_group[0]

-# Update invalid moves
-batch_states[:, govars.INVD_CHNL] = state_utils.batch_compute_invalid_moves(batch_states, batch_players,
-                                                                            batch_ko_protect)
-
# Switch turn
state_utils.batch_set_turn(batch_states)

+# Update invalid moves
+batch_states[:, govars.INVD_CHNL] = state_utils.batch_compute_invalid_moves(batch_states, batch_players,
+                                                                            batch_ko_protect, batch_histories)

if canonical:
# Set canonical form
batch_states = batch_canonical_form(batch_states)
@@ -153,7 +153,7 @@ def batch_next_states(batch_states, batch_action1d, canonical=False):
def invalid_moves(state):
# return a fixed size binary vector
if game_ended(state):
-return np.zeros(action_size(state))
+return np.ones(action_size(state))
return np.append(state[govars.INVD_CHNL].flatten(), 0)


@@ -247,7 +247,7 @@ def turn(state):


def batch_turn(batch_state):
-return np.max(batch_state[:, govars.TURN_CHNL], axis=(1, 2)).astype(np.int)
+return np.max(batch_state[:, govars.TURN_CHNL], axis=(1, 2)).astype(int)


def liberties(state: np.ndarray):
@@ -258,7 +258,7 @@ def liberties(state: np.ndarray):
liberty_list = []
for player_pieces in [blacks, whites]:
liberties = ndimage.binary_dilation(player_pieces, state_utils.surround_struct)
-liberties *= (1 - all_pieces).astype(np.bool)
+liberties *= (1 - all_pieces).astype(bool)
liberty_list.append(liberties)

return liberty_list[0], liberty_list[1]
@@ -280,7 +280,7 @@ def areas(state):
all_pieces = np.sum(state[[govars.BLACK, govars.WHITE]], axis=0)
empties = 1 - all_pieces

-empty_labels, num_empty_areas = ndimage.measurements.label(empties)
+empty_labels, num_empty_areas = ndimage.label(empties)

black_area, white_area = np.sum(state[govars.BLACK]), np.sum(state[govars.WHITE])
for label in range(1, num_empty_areas + 1):
gym_go/state_utils.py (57 additions, 10 deletions)
@@ -1,8 +1,7 @@
import numpy as np
from scipy import ndimage
-from scipy.ndimage import measurements

-from gym_go import govars
+from gym_go import govars, gogame

group_struct = np.array([[[0, 0, 0],
[0, 0, 0],
@@ -21,7 +20,7 @@
neighbor_deltas = np.array([[-1, 0], [1, 0], [0, -1], [0, 1]])


-def compute_invalid_moves(state, player, ko_protect=None):
+def compute_invalid_moves(state, player, ko_protect=None, history=None):
"""
Updates invalid moves in the OPPONENT's perspective
1.) Opponent cannot move at a location
@@ -42,11 +41,12 @@ def compute_invalid_moves(state, player, ko_protect=None):

# Setup invalid and valid arrays
possible_invalid_array = np.zeros(state.shape[1:])
super_ko_invalid_array = np.zeros(state.shape[1:])
definite_valids_array = np.zeros(state.shape[1:])

# Get all groups
-all_own_groups, num_own_groups = measurements.label(state[player])
-all_opp_groups, num_opp_groups = measurements.label(state[1 - player])
+all_own_groups, num_own_groups = ndimage.label(state[player])
+all_opp_groups, num_opp_groups = ndimage.label(state[1 - player])
expanded_own_groups = np.zeros((num_own_groups, *state.shape[1:]))
expanded_opp_groups = np.zeros((num_opp_groups, *state.shape[1:]))

@@ -80,10 +80,32 @@
# Ko-protection
if ko_protect is not None:
invalid_moves[ko_protect[0], ko_protect[1]] = 1

# Super ko-protection
if history is not None and len(history) > 0:
# Create a new state with updated invalid moves so we can calculate child moves
updated_state = np.copy(state)
updated_state[govars.INVD_CHNL] = (invalid_moves > 0)

children = gogame.children(updated_state)
board_size = np.prod(state.shape[1:])
children = children[:board_size]

trunc_history = np.array(history)[:, :2]
for action1d, child_state in enumerate(children):
# Skip children that don't represent a valid move
if (child_state[:2] == 0).all():
continue
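# A position repeats when the child's black/white channels exactly match a past state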
if (trunc_history == child_state[:2]).all(axis=1).all(axis=1).all(axis=1).any():
action2d = action1d // state.shape[1:][0], action1d % state.shape[1:][1]
super_ko_invalid_array[action2d[0], action2d[1]] = 1

invalid_moves = invalid_moves + super_ko_invalid_array

return invalid_moves > 0


-def batch_compute_invalid_moves(batch_state, batch_player, batch_ko_protect):
+def batch_compute_invalid_moves(batch_state, batch_player, batch_ko_protect, batch_history=None):
"""
Updates invalid moves in the OPPONENT's perspective
1.) Opponent cannot move at a location
@@ -105,11 +127,12 @@ def batch_compute_invalid_moves(batch_state, batch_player, batch_ko_protect):

# Setup invalid and valid arrays
batch_possible_invalid_array = np.zeros(batch_state.shape[:1] + batch_state.shape[2:])
batch_super_ko_invalid_array = np.zeros(batch_state.shape[:1] + batch_state.shape[2:])
batch_definite_valids_array = np.zeros(batch_state.shape[:1] + batch_state.shape[2:])

# Get all groups
-batch_all_own_groups, _ = measurements.label(batch_state[batch_idcs, batch_player], group_struct)
-batch_all_opp_groups, _ = measurements.label(batch_state[batch_idcs, 1 - batch_player], group_struct)
+batch_all_own_groups, _ = ndimage.label(batch_state[batch_idcs, batch_player], group_struct)
+batch_all_opp_groups, _ = ndimage.label(batch_state[batch_idcs, 1 - batch_player], group_struct)

batch_data = enumerate(zip(batch_all_own_groups, batch_all_opp_groups, batch_empties))
for i, (all_own_groups, all_opp_groups, empties) in batch_data:
@@ -153,6 +176,30 @@ def batch_compute_invalid_moves(batch_state, batch_player, batch_ko_protect):
for i, ko_protect in enumerate(batch_ko_protect):
if ko_protect is not None:
invalid_moves[i, ko_protect[0], ko_protect[1]] = 1

# Super ko-protection
if batch_history is not None:
# Create a new state with updated invalid moves so we can calculate child moves
updated_states = np.copy(batch_state)
updated_states[:, govars.INVD_CHNL] = (invalid_moves > 0)

board_size = np.prod(batch_state.shape[2:])
batch_children = np.array(
[gogame.children(s)[:board_size] for s in updated_states]
)

trunc_history = batch_history[:, :, :2]
for i, state in enumerate(batch_state):
for action1d, child_state in enumerate(batch_children[i]):
# Skip children that don't represent a valid move
if (child_state[:2] == 0).all():
continue
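# As above: a repeat occurs when the black/white channels match a position in this board's history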
if (trunc_history[i] == child_state[:2]).all(axis=1).all(axis=1).all(axis=1).any():
action2d = action1d // state.shape[1:][0], action1d % state.shape[1:][1]
batch_super_ko_invalid_array[i, action2d[0], action2d[1]] = 1

invalid_moves = invalid_moves + batch_super_ko_invalid_array

return invalid_moves > 0


@@ -163,7 +210,7 @@ def update_pieces(state, adj_locs, player):
all_pieces = np.sum(state[[govars.BLACK, govars.WHITE]], axis=0)
empties = 1 - all_pieces

-all_opp_groups, _ = ndimage.measurements.label(state[opponent])
+all_opp_groups, _ = ndimage.label(state[opponent])

# Go through opponent groups
all_adj_labels = all_opp_groups[adj_locs[:, 0], adj_locs[:, 1]]
@@ -187,7 +234,7 @@ def batch_update_pieces(batch_non_pass, batch_state, batch_adj_locs, batch_playe
batch_all_pieces = np.sum(batch_state[:, [govars.BLACK, govars.WHITE]], axis=1)
batch_empties = 1 - batch_all_pieces

-batch_all_opp_groups, _ = ndimage.measurements.label(batch_state[batch_non_pass, batch_opponent],
+batch_all_opp_groups, _ = ndimage.label(batch_state[batch_non_pass, batch_opponent],
group_struct)

batch_data = enumerate(zip(batch_all_opp_groups, batch_all_pieces, batch_empties, batch_adj_locs, batch_opponent))