Permalink
Browse files

Clean up sgf replay functionality

  • Loading branch information...
1 parent 89783fd commit 52f2bb678c5a7949dc37efe447d54dc248b10c78 @brilee committed Nov 26, 2016
Showing with 51 additions and 57 deletions.
  1. +2 −3 load_data_sets.py
  2. +41 −45 sgf_wrapper.py
  3. +8 −9 tests/test_sgf_wrapper.py
View
@@ -7,7 +7,7 @@
from features import DEFAULT_FEATURES
import go
-import sgf_wrapper
+from sgf_wrapper import replay_sgf
import utils
# Number of data points to store in a chunk on disk
@@ -51,8 +51,7 @@ def find_sgf_files(*dataset_dirs):
def get_positions_from_sgf(file):
with open(file) as f:
- sgf = sgf_wrapper.SgfWrapper(f.read())
- for position_w_context in sgf.get_main_branch():
+ for position_w_context in replay_sgf(f.read()):
if position_w_context.is_usable():
yield position_w_context
View
@@ -15,6 +15,25 @@
from utils import parse_sgf_coords as pc
import sgf
+class GameMetadata(namedtuple("GameMetadata", "result handicap board_size")):
+ pass
+
+class PositionWithContext(namedtuple("SgfPosition", "position next_move metadata")):
+ '''
+ Wrapper around go.Position.
+ Stores a position, the move that came next, and the eventual result.
+ '''
+ def is_usable(self):
+ return all([
+ self.position is not None,
+ self.next_move is not None,
+ self.metadata.result != "Void",
+ self.metadata.handicap <= 4,
+ ])
+
+ def __str__(self):
+ return str(self.position) + '\nNext move: {} Result: {}'.format(self.next_move, self.result)
+
def sgf_prop(value_list):
'Converts raw sgf library output to sensible value'
if value_list is None:
@@ -67,52 +86,29 @@ def maybe_correct_next(pos, next_node):
('W' in next_node.properties and not pos.to_play == go.WHITE)):
pos.flip_playerturn(mutate=True)
-class GameMetadata(namedtuple("GameMetadata", "result handicap board_size")):
- pass
-
-class SgfWrapper(object):
+def replay_sgf(sgf_contents):
'''
- Wrapper for sgf files, exposing contents as go.Position instances
+ Wrapper for sgf files, exposing contents as position_w_context instances
with open(filename) as f:
- sgf = sgf_wrapper.SgfWrapper(f.read())
- for position, move, result in sgf.get_main_branch():
- print(position)
+ for position_w_context in replay_sgf(f.read()):
+ print(position_w_context.position)
'''
+ collection = sgf.parse(sgf_contents)
+ game = collection.children[0]
+ props = game.root.properties
+ assert int(sgf_prop(props.get('GM', ['1']))) == 1, "Not a Go SGF!"
+ komi = float(sgf_prop(props.get('KM')))
+ metadata = GameMetadata(
+ result=sgf_prop(props.get('RE')),
+ handicap=int(sgf_prop(props.get('HA', [0]))),
+ board_size=int(sgf_prop(props.get('SZ'))))
+ go.set_board_size(metadata.board_size)
- def __init__(self, file_contents):
- self.collection = sgf.parse(file_contents)
- self.game = self.collection.children[0]
- props = self.game.root.properties
- assert int(sgf_prop(props.get('GM', ['1']))) == 1, "Not a Go SGF!"
- self.komi = float(sgf_prop(props.get('KM')))
- self.metadata = GameMetadata(
- result=sgf_prop(props.get('RE')),
- handicap=int(sgf_prop(props.get('HA', [0]))),
- board_size=int(sgf_prop(props.get('SZ'))))
- go.set_board_size(self.metadata.board_size)
-
- def get_main_branch(self):
- pos = Position(komi=self.komi)
- current_node = self.game.root
- while pos is not None and current_node is not None:
- pos = handle_node(pos, current_node)
- maybe_correct_next(pos, current_node.next)
- next_move = get_next_move(current_node)
- yield PositionWithContext(pos, next_move, self.metadata)
- current_node = current_node.next
-
-class PositionWithContext(namedtuple("SgfPosition", "position next_move metadata")):
- '''
- Wrapper around go.Position.
- Stores a position, the move that came next, and the eventual result.
- '''
- def is_usable(self):
- return all([
- self.position is not None,
- self.next_move is not None,
- self.metadata.result != "Void",
- self.metadata.handicap <= 4,
- ])
-
- def __str__(self):
- return str(self.position) + '\nNext move: {} Result: {}'.format(self.next_move, self.result)
+ pos = Position(komi=komi)
+ current_node = game.root
+ while pos is not None and current_node is not None:
+ pos = handle_node(pos, current_node)
+ maybe_correct_next(pos, current_node.next)
+ next_move = get_next_move(current_node)
+ yield PositionWithContext(pos, next_move, metadata)
+ current_node = current_node.next
@@ -1,5 +1,5 @@
import go
-import sgf_wrapper
+from sgf_wrapper import replay_sgf
import unittest
from utils import parse_kgs_coords as pc
@@ -12,10 +12,11 @@
class TestSgfWrapper(GoPositionTestCase):
def test_sgf_props(self):
- sgf = sgf_wrapper.SgfWrapper(CHINESE_HANDICAP_SGF)
- self.assertEqual(sgf.metadata.result, 'B+39.50')
- self.assertEqual(sgf.metadata.board_size, 9)
- self.assertEqual(sgf.komi, 5.5)
+ sgf_replayer = replay_sgf(CHINESE_HANDICAP_SGF)
+ initial = next(sgf_replayer)
+ self.assertEqual(initial.metadata.result, 'B+39.50')
+ self.assertEqual(initial.metadata.board_size, 9)
+ self.assertEqual(initial.position.komi, 5.5)
def test_japanese_handicap_handling(self):
intermediate_board = load_board('''
@@ -57,8 +58,7 @@ def test_japanese_handicap_handling(self):
to_play=go.WHITE,
)
- sgf = sgf_wrapper.SgfWrapper(JAPANESE_HANDICAP_SGF)
- positions_w_context = list(sgf.get_main_branch())
+ positions_w_context = list(replay_sgf(JAPANESE_HANDICAP_SGF))
self.assertEqualPositions(intermediate_position, positions_w_context[1].position)
self.assertEqualPositions(final_position, positions_w_context[-1].position)
@@ -102,8 +102,7 @@ def test_chinese_handicap_handling(self):
recent=(pc('E9'), pc('F9')),
to_play=go.WHITE
)
- sgf = sgf_wrapper.SgfWrapper(CHINESE_HANDICAP_SGF)
- positions_w_context = list(sgf.get_main_branch())
+ positions_w_context = list(replay_sgf(CHINESE_HANDICAP_SGF))
self.assertEqualPositions(intermediate_position, positions_w_context[1].position)
self.assertEqual(positions_w_context[1].next_move, pc('C3'))
self.assertEqualPositions(final_position, positions_w_context[-1].position)

0 comments on commit 52f2bb6

Please sign in to comment.