Skip to content
This repository was archived by the owner on Feb 22, 2020. It is now read-only.

Commit 57cc95f

Browse files
committed
feat(encoder): add quantizer
1 parent bbf4283 commit 57cc95f

File tree

3 files changed

+18
-9
lines changed

3 files changed

+18
-9
lines changed

gnes/encoder/numeric/quantizer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _get_centroids(self):
5959
endpoint=False, retstep=False, dtype=None)[1:]
6060
coordinates = np.tile(axis_point, (self.dim_per_byte, 1))
6161
elif self.partition_method == 'random':
62-
coordinates = np.random.randint(self.lower_bound, self.upper_bound,
62+
coordinates = np.random.uniform(self.lower_bound, self.upper_bound,
6363
size=[self.dim_per_byte, num_sample_per_dim])
6464
else:
6565
raise NotImplementedError
@@ -95,12 +95,12 @@ def _get_max_min_value(vecs):
9595

9696
def _check_bound(self, max_value, min_value):
9797
if self.upper_bound < max_value:
98-
self.logger.warning("upper bound (=%.3f) is smaller than max value of input data (=%.3f), you should choose"
98+
raise Warning("upper bound (=%.3f) is smaller than max value of input data (=%.3f), you should choose"
9999
"a bigger value for upper bound" % (self.upper_bound, max_value))
100100
if self.lower_bound > min_value:
101-
self.logger.warning("lower bound (=%.3f) is bigger than min value of input data (=%.3f), you should choose"
101+
raise Warning("lower bound (=%.3f) is bigger than min value of input data (=%.3f), you should choose"
102102
"a smaller value for lower bound" % (self.lower_bound, min_value))
103103
if (self.upper_bound-self.lower_bound) >= 10*(max_value - min_value):
104-
self.logger.warning("(upper bound - lower_bound) (=%.3f) is 10 times larger than (max value - min value) "
104+
raise Warning("(upper bound - lower_bound) (=%.3f) is 10 times larger than (max value - min value) "
105105
"(=%.3f) of data, maybe you should choose a suitable bound" %
106106
((self.upper_bound-self.lower_bound), (max_value - min_value)))

tests/test_quantizer_encoder.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,22 @@
77

88
class TestQuantizerEncoder(unittest.TestCase):
99
def setUp(self):
10-
self.vecs = np.random.randint(-150, 150, size=[1000, 160]).astype('float32')
1110
dirname = os.path.dirname(__file__)
1211
self.vanilla_quantizer_yaml = os.path.join(dirname, 'yaml', 'quantizer_encoder.yml')
1312

1413
def test_vanilla_quantizer(self):
1514
encoder = BaseNumericEncoder.load_yaml(self.vanilla_quantizer_yaml)
1615
encoder.train()
17-
out = encoder.encode(self.vecs)
18-
print(out.shape)
16+
17+
vecs_1 = np.random.uniform(-150, 150, size=[1000, 160]).astype('float32')
18+
out = encoder.encode(vecs_1)
19+
self.assertEqual(len(out.shape), 2)
20+
self.assertEqual(out.shape[0], 1000)
21+
self.assertEqual(out.shape[1], 16)
22+
23+
vecs_2 = np.random.uniform(-1, 1, size=[1000, 160]).astype('float32')
24+
self.assertRaises(Warning, encoder.encode, vecs_2)
25+
26+
vecs_3 = np.random.uniform(-1, 1000, size=[1000, 160]).astype('float32')
27+
self.assertRaises(Warning, encoder.encode, vecs_3)
1928

tests/yaml/quantizer_encoder.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
!QuantizerEncoder
22
parameters:
3-
upper_bound: 1000000
4-
lower_bound: -100
3+
upper_bound: 500
4+
lower_bound: -200
55
partition_method: 'random'
66
cluster_per_byte: 255
77
dim_per_byte: 10

0 commit comments

Comments
 (0)