Hooking up the policy network to main.py, along with load/save/train

1 parent 797df65 · commit ac532f8ee44d044696a5d6142c3070194a73d3cb · @brilee committed Jul 1, 2016
Showing with 69 additions and 33 deletions.
  1. +9 −1 README.md
  2. +1 −1 go.py
  3. +10 −6 load_data_sets.py
  4. +31 −11 main.py
  5. +2 −14 policy.py
  6. +1 −0 requirements.txt
  7. +15 −0 strategies.py
README.md
@@ -15,5 +15,13 @@ MuGo uses the GTP protocol, and you can use any gtp-compliant program with it.
For example, to play against MuGo using gogui, you can do:
```
-gogui-twogtp -black 'python main.py random' -white 'gogui-display' -size 9 -komi 7.5 -verbose -auto
+gogui-twogtp -black 'python main.py gtp policy --read-file=/tmp/mymodel' -white 'gogui-display' -size 19 -komi 7.5 -verbose -auto
```
+
+Training
+========
+To train, run
+```
+python main.py train --read-file=/tmp/savedmodel --save-file=/tmp/savedmodel --epochs=10 data/kgs_data data/pro_data
+```
+where `data/kgs_data` and `data/pro_data` are directories of SGF files to be used for training.
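For reference, the new `train` command is roughly the following Python, based on the `train()` function added to main.py in this commit (file paths and directories are examples only):
```
from load_data_sets import load_data_sets
from policy import PolicyNetwork

# Extract features from the SGF files in the given directories.
processed_data = load_data_sets("data/kgs_data", "data/pro_data")
n = PolicyNetwork(processed_data.input_planes)
n.initialize_variables("/tmp/savedmodel")  # pass None to start from scratch
for i in range(10):                        # --epochs=10
    n.train(processed_data.training)
    n.check_accuracy(processed_data.test)
n.save_variables("/tmp/savedmodel")
```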
go.py
@@ -303,4 +303,4 @@ def score(self):
        return np.count_nonzero(working_board == BLACK) - np.count_nonzero(working_board == WHITE) - self.komi
-set_board_size(9)
+set_board_size(19)
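The default board size changes from 9 to 19 to match the updated 19x19 gogui command in the README and the full-size KGS/pro game records. A small hedged note: code that still wants a smaller board can presumably set it explicitly, just as the old default did:
```
import go

go.set_board_size(9)  # e.g. drop back to a 9x9 board for quick experiments
```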
load_data_sets.py
@@ -1,6 +1,7 @@
from collections import namedtuple
import os
import numpy as np
+import sys
from features import DEFAULT_FEATURES
import go
@@ -15,10 +16,10 @@ def make_onehot(dense_labels, num_classes):
    labels_one_hot.flat[index_offset + dense_labels.ravel()] = 1
    return labels_one_hot
-def load_sgf_positions(*dataset_names):
-    for dataset in dataset_names:
-        dataset_dir = os.path.join(os.getcwd(), 'data', dataset)
-        dataset_files = [os.path.join(dataset_dir, name) for name in os.listdir(dataset_dir)]
+def load_sgf_positions(*dataset_dirs):
+    for dataset_dir in dataset_dirs:
+        full_dir = os.path.join(os.getcwd(), dataset_dir)
+        dataset_files = [os.path.join(full_dir, name) for name in os.listdir(full_dir)]
        all_datafiles = filter(os.path.isfile, dataset_files)
        for file in all_datafiles:
            with open(file) as f:
@@ -68,9 +69,12 @@ def get_batch(self, batch_size):
DataSets = namedtuple("DataSets", "test validation training input_planes")
-def load_data_sets(*dataset_names, feature_extractor=DEFAULT_FEATURES):
-    positions_w_context = list(load_sgf_positions(*dataset_names))
+def load_data_sets(*dataset_dirs, feature_extractor=DEFAULT_FEATURES):
+    print("Extracting positions from sgfs...", file=sys.stderr)
+    positions_w_context = list(load_sgf_positions(*dataset_dirs))
+    print("Partitioning %s positions into test, validation, training datasets" % len(positions_w_context))
    test, validation, training = partition_sets(positions_w_context)
+    print("Processing positions to extract features")
    datasets = []
    for dataset in (test, validation, training):
        positions, next_moves, results = zip(*dataset)
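A minimal usage sketch of the new signature (directory paths are examples): `load_data_sets` now takes directories of SGF files directly rather than dataset names resolved under a fixed `data/` folder, and still returns the `DataSets` namedtuple defined above:
```
from load_data_sets import load_data_sets

datasets = load_data_sets("data/kgs_data", "data/pro_data")
# DataSets fields: test, validation, training, input_planes
print(datasets.input_planes)
# Each split exposes .input and .labels arrays, as used by
# PolicyNetwork.check_accuracy in policy.py.
```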
main.py
@@ -1,11 +1,23 @@
import argparse
+import argh
import sys
-import gtp
+import gtp as gtp_lib
-from strategies import RandomPlayer
+from features import DEFAULT_FEATURES
+from strategies import RandomPlayer, PolicyNetworkBestMovePlayer
+from policy import PolicyNetwork
+from load_data_sets import load_data_sets
-def run_gtp(strategy):
-    gtp_engine = gtp.Engine(strategy())
+def gtp(strategy, read_file=None):
+    if strategy == 'random':
+        instance = RandomPlayer()
+    elif strategy == 'policy':
+        policy_network = PolicyNetwork(DEFAULT_FEATURES.planes)
+        policy_network.initialize_variables(read_file)
+        instance = PolicyNetworkBestMovePlayer(policy_network)
+    else:
+        sys.exit("Unknown strategy: " + strategy)
+    gtp_engine = gtp_lib.Engine(instance)
    sys.stderr.write("GTP engine ready\n")
    sys.stderr.flush()
    while not gtp_engine.disconnect:
@@ -21,13 +33,21 @@ def run_gtp(strategy):
            sys.stdout.write(engine_reply)
            sys.stdout.flush()
-strategies = {
-    'random': RandomPlayer,
-}
+def train(read_file=None, save_file=None, epochs=10, *data_sets):
+    processed_data = load_data_sets(*data_sets)
+    n = PolicyNetwork(processed_data.input_planes)
+    n.initialize_variables(read_file)
+    for i in range(epochs):
+        n.train(processed_data.training)
+        n.check_accuracy(processed_data.test)
+    if save_file is not None:
+        n.save_variables(save_file)
+        print("Finished training. New model saved to %s" % save_file, file=sys.stderr)
+
+
parser = argparse.ArgumentParser()
-parser.add_argument('strategy', choices=strategies.keys())
+argh.add_commands(parser, [gtp, train])
+
if __name__ == '__main__':
-    args = parser.parse_args()
-    strategy = strategies[args.strategy]
-    run_gtp(strategy)
+    argh.dispatch(parser)
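The hand-rolled strategy table is gone: argh builds the CLI directly from the `gtp` and `train` function signatures, so keyword arguments with defaults become `--flags` and the `*data_sets` varargs become positional arguments, which is exactly how the README spells the `train` command. A stripped-down sketch of the same pattern, with a hypothetical demo body:
```
import argparse

import argh

def train(read_file=None, save_file=None, epochs=10, *data_sets):
    # argh exposes this as `train --read-file ... --save-file ... --epochs ...`
    # followed by any number of positional data_sets arguments.
    print(read_file, save_file, epochs, data_sets)

parser = argparse.ArgumentParser()
argh.add_commands(parser, [train])

if __name__ == '__main__':
    # e.g. python demo.py train --epochs=10 data/kgs_data data/pro_data
    argh.dispatch(parser)
```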
policy.py
@@ -24,12 +24,9 @@
import features
import go
-from load_data_sets import load_data_sets
-
-kgs = load_data_sets("kgs-micro")
class PolicyNetwork(object):
-    def __init__(self, num_input_planes, k=64, num_int_conv_layers=3):
+    def __init__(self, num_input_planes, k=32, num_int_conv_layers=3):
        self.num_input_planes = num_input_planes
        self.k = k
        self.num_int_conv_layers = num_int_conv_layers
@@ -78,7 +75,6 @@ def conv2d(x, W):
    def initialize_variables(self, save_file=None):
        self.session = tf.Session()
-        # put loading functionality here
        if save_file is None:
            self.session.run(tf.initialize_all_variables())
        else:
@@ -98,17 +94,9 @@ def train(self, training_data, batch_size=16):
    def run(self, position):
        processed_position = features.DEFAULT_FEATURES.extract(position)
-        return self.session.run(self.output, feed_dict={self.x: processed_position[None, :]})
+        return self.session.run(self.output, feed_dict={self.x: processed_position[None, :]})[0]
    def check_accuracy(self, test_data):
        test_accuracy = self.session.run(self.accuracy, feed_dict={self.x: test_data.input, self.y: test_data.labels})
        print("Test data accuracy: %g" % test_accuracy)
-n = PolicyNetwork(kgs.input_planes)
-n.initialize_variables("/tmp/mymodel")
-# for i in range(10):
-# n.train(kgs.training)
-# n.check_accuracy(kgs.test)
-n.check_accuracy(kgs.test)
-#n.save_variables("/tmp/mymodel")
-# best_moves = [unparse_kgs_coords(utils.unflatten_coords(c)) for c in sorted(range(361), key=lambda f: suggestion_probs[0, f])]
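The other behavioural change worth noting: `run()` now returns the probability vector for the single evaluated position rather than a 1xN batch, so callers index it directly. A hedged sketch of a caller (the helper function is ours, not part of the commit; the model path is an example):
```
from features import DEFAULT_FEATURES
from policy import PolicyNetwork

def move_probabilities(network, position):
    # run() extracts features for one position, adds a batch dimension for
    # the feed, and the new trailing [0] strips it again, leaving one
    # probability per flattened board coordinate.
    return network.run(position)

network = PolicyNetwork(DEFAULT_FEATURES.planes)  # k now defaults to 32 filters
network.initialize_variables("/tmp/savedmodel")   # or None for fresh weights
```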
requirements.txt
@@ -1,3 +1,4 @@
+argh==0.26.2
numpy==1.11.0
protobuf==3.0.0b2
pygtp==0.2
strategies.py
@@ -4,6 +4,7 @@
import go
import utils
+import policy, features
class GtpInterface(object):
    def __init__(self):
@@ -46,3 +47,17 @@ def suggest_move(self, position):
class RandomPlayer(GtpInterface):
    def suggest_move(self, position):
        return random.choice(position.possible_moves())
+
+class PolicyNetworkBestMovePlayer(GtpInterface):
+    def __init__(self, network):
+        super().__init__()
+        self.network = network
+
+    def suggest_move(self, position):
+        probabilities = self.network.run(position)
+        move_probabilities = {
+            utils.unflatten_coords(x): probabilities[x]
+            for x in range(361)
+        }
+        best_move = max(move_probabilities.keys(), key=lambda k: move_probabilities[k])
+        return best_move
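A hedged sketch of wiring the new player up by hand, mirroring what the `gtp policy` path in main.py now does (the model path is an example):
```
from features import DEFAULT_FEATURES
from policy import PolicyNetwork
from strategies import PolicyNetworkBestMovePlayer

network = PolicyNetwork(DEFAULT_FEATURES.planes)
network.initialize_variables("/tmp/savedmodel")
player = PolicyNetworkBestMovePlayer(network)
# player.suggest_move(position) returns the coordinate with the highest
# predicted probability; the argmax runs over all 361 points, so there is
# no legality filtering yet.
```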
