Permalink
Browse files

rewrite onehot methods for small improvement

  • Loading branch information...
1 parent 014a8b3 commit 3f09a5c3431a0b259bd7310d8d3c9737ff514fa7 @brilee committed Oct 24, 2016
Showing with 51 additions and 16 deletions.
  1. +9 −3 features.py
  2. +7 −8 load_data_sets.py
  3. +2 −5 policy.py
  4. +12 −0 tests/test_datasets.py
  5. +21 −0 utils.py
View
@@ -20,15 +20,21 @@
import numpy as np
import go
+from utils import product
# Resolution/truncation limit for one-hot features
P = 8
def make_onehot(feature, planes):
    """Encode an integer feature array as truncated one-hot planes.

    For each cell of `feature`:
      * a value <= 0 sets no plane (all zeros),
      * a value v with 1 <= v < planes sets plane v - 1,
      * a value >= planes sets the last plane.

    Returns a uint8 array of shape `feature.shape + (planes,)`.
    """
    onehot_features = np.zeros(feature.shape + (planes,), dtype=np.uint8)
    capped = np.minimum(feature, planes).ravel()
    # Flat index of plane 0 for every cell of `feature`.
    cell_offsets = np.arange(0, onehot_features.size, planes)
    # Only strictly positive values get a plane. Testing `> 0` rather than
    # `!= 0` keeps a negative value from spilling into the planes of the
    # preceding cell.
    positive = capped > 0
    # A 1 is encoded as [1,0,0,0], not [0,1,0,0], so subtract 1.
    hot_indices = cell_offsets[positive] + capped[positive] - 1
    # `.flat` always writes into the array itself, whereas assigning through
    # `.ravel()` only works when ravel() happens to return a view.
    onehot_features.flat[hot_indices] = 1
    return onehot_features
def planes(num_planes):
View
@@ -34,13 +34,12 @@ def iter_chunks(chunk_size, iterator):
else:
break
-def make_onehot(dense_labels, num_classes):
- dense_labels = np.fromiter(dense_labels, dtype=np.int16)
- num_labels = dense_labels.shape[0]
- index_offset = np.arange(num_labels) * num_classes
- labels_one_hot = np.zeros((num_labels, num_classes), dtype=np.int16)
- labels_one_hot.flat[index_offset + dense_labels.ravel()] = 1
- return labels_one_hot
def make_onehot(coords):
    """Return a uint8 array with one one-hot row per entry of `coords`.

    Each row has go.N ** 2 columns; the column selected by
    utils.flatten_coords(coord) is set to 1 and all others stay 0.
    """
    board_area = go.N ** 2
    onehot = np.zeros([len(coords), board_area], dtype=np.uint8)
    for row_index, move in enumerate(coords):
        hot_column = utils.flatten_coords(move)
        onehot[row_index, hot_column] = 1
    return onehot
def find_sgf_files(*dataset_dirs):
for dataset_dir in dataset_dirs:
@@ -106,7 +105,7 @@ def get_batch(self, batch_size):
def from_positions_w_context(positions_w_context, is_test=False):
    """Build a DataSet from (position, next_move, result) triples.

    The triples are split into three parallel sequences; positions are run
    through extract_features, next moves are one-hot encoded, and results
    are passed to DataSet unchanged.
    """
    positions, next_moves, results = zip(*positions_w_context)
    return DataSet(
        extract_features(positions),
        make_onehot(next_moves),
        results,
        is_test=is_test)
def write(self, filename):
View
@@ -51,15 +51,12 @@ def set_up_network(self):
y = tf.placeholder(tf.float32, shape=[None, go.N ** 2])
#convenience functions for initializing weights and biases
- # http://neuralnetworksanddeeplearning.com/chap3.html#weight_initialization
- def _product(numbers):
- return functools.reduce(operator.mul, numbers)
-
def _weight_variable(shape, name):
    """Create a named tf.Variable drawn from a truncated normal.

    The standard deviation is 1 / sqrt(fan_in), where fan_in is the product
    of every dimension of `shape` except the last. See
    http://neuralnetworksanddeeplearning.com/chap3.html#weight_initialization
    """
    # If shape is [5, 5, 20, 32], then each of the 32 output planes
    # has 5 * 5 * 20 inputs.
    fan_in = utils.product(shape[:-1])
    stddev = 1 / math.sqrt(fan_in)
    initial_values = tf.truncated_normal(shape, stddev=stddev)
    return tf.Variable(initial_values, name=name)
def _conv2d(x, W):
@@ -1,5 +1,7 @@
+import numpy as np
import os
from test_utils import GoPositionTestCase
+import go
import load_data_sets
TEST_DIR = os.path.dirname(os.path.realpath(__file__))
@@ -26,3 +28,13 @@ def test_dataset_serialization(self):
self.assertEqual(dataset.next_moves.shape, recovered.next_moves.shape)
self.assertEqualNPArray(dataset.next_moves, recovered.next_moves)
self.assertEqualNPArray(dataset.pos_features, recovered.pos_features)
+
class TestDataSetHelpers(GoPositionTestCase):
    """Tests for the helper functions in load_data_sets."""

    def test_onehot(self):
        """make_onehot sets exactly the flattened column of each move."""
        go.set_board_size(9)
        moves = [(1, 2), (3, 4)]
        # Expected flattened column of each move on the 9x9 board.
        hot_columns = [11, 31]
        expected = np.zeros([2, 81], dtype=np.uint8)
        for row, column in enumerate(hot_columns):
            expected[row, column] = 1
        actual = load_data_sets.make_onehot(moves)
        self.assertEqualNPArray(actual, expected)
View
@@ -1,3 +1,6 @@
+from collections import defaultdict
+import time
+import functools, operator
import gtp
import go
@@ -40,3 +43,21 @@ def unparse_pygtp_coords(c):
if c is None:
return gtp.PASS
return c[1] + 1, go.N - c[0]
+
def product(numbers):
    """Return the product of all items in `numbers`.

    An empty iterable yields 1 (the multiplicative identity) instead of
    raising TypeError.
    """
    return functools.reduce(operator.mul, numbers, 1)
+
+
class timer(object):
    """Context manager that accumulates elapsed wall-clock time per label.

    Usage:
        with timer("load"):
            ...
        timer.print_times()

    Time spent inside every `with` block using the same label is summed in
    the class-level `all_times` mapping.
    """

    # Total elapsed seconds per label, shared by all timer instances.
    all_times = defaultdict(float)

    def __init__(self, label):
        self.label = label

    def __enter__(self):
        self.tick = time.time()
        # Return the instance so `with timer(label) as t:` binds the timer
        # rather than None.
        return self

    def __exit__(self, type, value, traceback):
        # Runs even when the block raises; the exception is not suppressed
        # because this method returns None.
        self.tock = time.time()
        self.all_times[self.label] += self.tock - self.tick

    @classmethod
    def print_times(cls):
        """Print the accumulated seconds for every label seen so far."""
        for k, v in cls.all_times.items():
            print("%s: %.3f" % (k, v))

0 comments on commit 3f09a5c

Please sign in to comment.