Skip to content
This repository was archived by the owner on Feb 22, 2020. It is now read-only.

Commit 2fd8dab

Browse files
author
Han Xiao
authored
style: minor fix on the styling
1 parent 57cc95f commit 2fd8dab

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

gnes/encoder/numeric/quantizer.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ def __init__(self, dim_per_byte: int, cluster_per_byte: int = 255,
3636
self.upper_bound = upper_bound
3737
self.lower_bound = lower_bound
3838
self.partition_method = partition_method
39-
self.centroids = None
40-
self._get_centroids()
39+
self.centroids = self._get_centroids()
4140

4241
def _get_centroids(self):
4342
"""
@@ -52,7 +51,7 @@ def _get_centroids(self):
5251
if self.upper_bound < self.lower_bound:
5352
raise ValueError("upper bound is smaller than lower bound")
5453

55-
self.centroids = []
54+
centroids = []
5655
num_sample_per_dim = np.ceil(pow(self.num_clusters, 1 / self.dim_per_byte)).astype(np.uint8)
5756
if self.partition_method == 'average':
5857
axis_point = np.linspace(self.lower_bound, self.upper_bound, num=num_sample_per_dim+1,
@@ -65,16 +64,13 @@ def _get_centroids(self):
6564
raise NotImplementedError
6665

6766
for item in product(*coordinates):
68-
self.centroids.append(list(item))
69-
self.centroids = self.centroids[:self.num_clusters]
67+
centroids.append(list(item))
68+
return centroids[:self.num_clusters]
7069

7170
@batching
7271
def encode(self, vecs: np.ndarray, *args, **kwargs) -> np.ndarray:
72+
self._check_bound(vecs)
7373
num_bytes = self._get_num_bytes(vecs)
74-
max_value, min_value = self._get_max_min_value(vecs)
75-
76-
self._check_bound(max_value, min_value)
77-
7874
x = np.reshape(vecs, [vecs.shape[0], num_bytes, 1, self.dim_per_byte])
7975
x = np.sum(np.square(x - self.centroids), -1)
8076
# start from 1
@@ -93,7 +89,8 @@ def _get_num_bytes(self, vecs: np.ndarray):
9389
def _get_max_min_value(vecs):
9490
return np.amax(vecs, axis=None), np.amin(vecs, axis=None)
9591

96-
def _check_bound(self, max_value, min_value):
92+
def _check_bound(self, vecs):
93+
max_value, min_value = self._get_max_min_value(vecs)
9794
if self.upper_bound < max_value:
9895
raise Warning("upper bound (=%.3f) is smaller than max value of input data (=%.3f), you should choose"
9996
"a bigger value for upper bound" % (self.upper_bound, max_value))

0 commit comments

Comments
 (0)