Permalink
Browse files

Execute policy network entirely on CPU for playback purposes

  • Loading branch information...
1 parent ad08d3c commit 9bdbe80913f8c421949f4af3ab8a1e9d43464a3f @brilee committed Oct 25, 2016
Showing with 9 additions and 5 deletions.
  1. +3 −3 main.py
  2. +6 −2 policy.py
View
@@ -17,11 +17,11 @@ def gtp(strategy, read_file=None):
if strategy == 'random':
instance = RandomPlayer()
elif strategy == 'policy':
- policy_network = PolicyNetwork(DEFAULT_FEATURES.planes)
+ policy_network = PolicyNetwork(DEFAULT_FEATURES.planes, use_cpu=True)
policy_network.initialize_variables(read_file)
instance = PolicyNetworkBestMovePlayer(policy_network)
elif strategy == 'mcts':
- policy_network = PolicyNetwork(DEFAULT_FEATURES.planes)
+ policy_network = PolicyNetwork(DEFAULT_FEATURES.planes, use_cpu=True)
policy_network.initialize_variables(read_file)
instance = MCTS(policy_network)
else:
@@ -50,7 +50,7 @@ def preprocess(*data_sets, processed_dir="processed_data"):
process_raw_data(*data_sets, processed_dir=processed_dir)
-def train(processed_dir, read_file=None, save_file=None, epochs=10, logdir=None, checkpoint_freq=1000):
+def train(processed_dir, read_file=None, save_file=None, epochs=10, logdir=None, checkpoint_freq=10000):
test_dataset = DataSet.read(os.path.join(processed_dir, "test.chunk.gz"))
train_chunk_files = [os.path.join(processed_dir, fname)
for fname in os.listdir(processed_dir)
View
@@ -31,7 +31,7 @@
EPSILON = 1e-35
class PolicyNetwork(object):
- def __init__(self, num_input_planes, k=32, num_int_conv_layers=3):
+ def __init__(self, num_input_planes, k=32, num_int_conv_layers=3, use_cpu=False):
self.num_input_planes = num_input_planes
self.k = k
self.num_int_conv_layers = num_int_conv_layers
@@ -40,7 +40,11 @@ def __init__(self, num_input_planes, k=32, num_int_conv_layers=3):
self.test_stats = StatisticsCollector()
self.training_stats = StatisticsCollector()
self.session = tf.Session()
- self.set_up_network()
+ if use_cpu:
+ with tf.device("/cpu:0"):
+ self.set_up_network()
+ else:
+ self.set_up_network()
def set_up_network(self):
# a global_step variable allows epoch counts to persist through multiple training sessions

0 comments on commit 9bdbe80

Please sign in to comment.