Skip to content
This repository has been archived by the owner on Jun 16, 2021. It is now read-only.

Commit

Permalink
Switch to a less bulky way of annotating features with their planes
Browse files Browse the repository at this point in the history
  • Loading branch information
brilee committed Jul 26, 2016
1 parent 2d3f1ef commit 3504375
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 57 deletions.
87 changes: 36 additions & 51 deletions features.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,61 +28,46 @@ def make_onehot(feature, planes):
onehot_features[:, :, planes-1] = (feature >= planes)
return onehot_features

class FeatureExtractor(object):
def __init__(self, features):
self.features = features
self.planes = sum(f.planes for f in features)

def extract(self, position):
return np.concatenate([feature.extract(position) for feature in self.features], axis=2)

class Feature(object):
planes = 1

@staticmethod
def extract(position):
return np.zeros([go.N, go.N, Feature.planes], dtype=np.float32)
def planes(num_planes):
def deco(f):
f.planes = num_planes
return f
return deco

class StoneColorFeature(Feature):
planes = 3
@planes(3)
def stone_color_feature(position):
board = position.board
features = np.zeros([go.N, go.N, 3], dtype=np.float32)
features[board == go.BLACK, 0] = 1
features[board == go.WHITE, 1] = 1
features[board == go.EMPTY, 2] = 1
return features

@staticmethod
def extract(position):
board = position.board
features = np.zeros([go.N, go.N, 3], dtype=np.float32)
features[board == go.BLACK, 0] = 1
features[board == go.WHITE, 1] = 1
features[board == go.EMPTY, 2] = 1
return features

class RecentMoveFeature(Feature):
planes = 8
@planes(8)
def recent_move_feature(position):
p = 8
onehot_features = np.zeros([go.N, go.N, p], dtype=np.float32)
for i, move in enumerate(reversed(position.recent[-p:])):
if move is not None:
onehot_features[move[0], move[1], i] = 1
return onehot_features

@staticmethod
def extract(position):
p = RecentMoveFeature.planes
onehot_features = np.zeros([go.N, go.N, p], dtype=np.float32)
for i, move in enumerate(position.recent[-1:-1 - p:-1]):
if move is not None:
onehot_features[move[0], move[1], i] = 1
return onehot_features
@planes(8)
def liberty_feature(position):
features = np.zeros([go.N, go.N], dtype=np.float32)
for g in itertools.chain(*position.groups):
libs = len(g.liberties)
for s in g.stones:
features[s] = libs
return make_onehot(features, 8)


class FeatureExtractor(object):
def __init__(self, features):
self.features = features
self.planes = sum(f.planes for f in features)

class LibertyFeature(Feature):
'''
From the AlphaGo paper:
Each integer feature value is split into multiple 19 × 19 planes of binary values (one-hot encoding). For example, separate binary feature planes are used to represent whether an intersection has 1 liberty, 2 liberties,..., >=8 liberties.
'''
planes = 8

@staticmethod
def extract(position):
features = np.zeros([go.N, go.N], dtype=np.float32)
for g in itertools.chain(*position.groups):
libs = len(g.liberties)
for s in g.stones:
features[s] = libs
return make_onehot(features, LibertyFeature.planes)
def extract(self, position):
return np.concatenate([feature(position) for feature in self.features], axis=2)

DEFAULT_FEATURES = FeatureExtractor([StoneColorFeature, LibertyFeature, RecentMoveFeature])
DEFAULT_FEATURES = FeatureExtractor([stone_color_feature, liberty_feature, recent_move_feature])
12 changes: 6 additions & 6 deletions tests/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

class TestFeatureExtraction(GoPositionTestCase):
def test_stone_color_feature(self):
f = features.StoneColorFeature.extract(TEST_POSITION)
f = features.stone_color_feature(TEST_POSITION)
self.assertEqual(f.shape, (9, 9, 3))
# plane 0 is B
self.assertEqual(f[0, 1, 0], 1)
Expand All @@ -39,8 +39,8 @@ def test_stone_color_feature(self):
self.assertEqual(f[0, 5, 1], 0)

def test_liberty_feature(self):
f = features.LibertyFeature.extract(TEST_POSITION)
self.assertEqual(f.shape, (9, 9, features.LibertyFeature.planes))
f = features.liberty_feature(TEST_POSITION)
self.assertEqual(f.shape, (9, 9, features.liberty_feature.planes))

self.assertEqual(f[0, 0, 0], 0)
# the stone at 0, 1 has 3 liberties.
Expand All @@ -53,8 +53,8 @@ def test_liberty_feature(self):
self.assertEqual(f[1, 0, 7], 1)

def test_recent_moves_feature(self):
f = features.RecentMoveFeature.extract(TEST_POSITION)
self.assertEqual(f.shape, (9, 9, features.RecentMoveFeature.planes))
f = features.recent_move_feature(TEST_POSITION)
self.assertEqual(f.shape, (9, 9, features.recent_move_feature.planes))
# most recent move at (1, 0)
self.assertEqual(f[1, 0, 0], 1)
self.assertEqual(f[1, 0, 3], 0)
Expand All @@ -65,4 +65,4 @@ def test_recent_moves_feature(self):
self.assertEqual(f[0, 1, 2], 1)
# no more older moves
self.assertEqualNPArray(f[:, :, 3], np.zeros([9, 9]))
self.assertEqualNPArray(f[:, :, features.RecentMoveFeature.planes - 1], np.zeros([9, 9]))
self.assertEqualNPArray(f[:, :, features.recent_move_feature.planes - 1], np.zeros([9, 9]))

0 comments on commit 3504375

Please sign in to comment.