@@ -36,8 +36,7 @@ def __init__(self, dim_per_byte: int, cluster_per_byte: int = 255,
36
36
self .upper_bound = upper_bound
37
37
self .lower_bound = lower_bound
38
38
self .partition_method = partition_method
39
- self .centroids = None
40
- self ._get_centroids ()
39
+ self .centroids = self ._get_centroids ()
41
40
42
41
def _get_centroids (self ):
43
42
"""
@@ -52,7 +51,7 @@ def _get_centroids(self):
52
51
if self .upper_bound < self .lower_bound :
53
52
raise ValueError ("upper bound is smaller than lower bound" )
54
53
55
- self . centroids = []
54
+ centroids = []
56
55
num_sample_per_dim = np .ceil (pow (self .num_clusters , 1 / self .dim_per_byte )).astype (np .uint8 )
57
56
if self .partition_method == 'average' :
58
57
axis_point = np .linspace (self .lower_bound , self .upper_bound , num = num_sample_per_dim + 1 ,
@@ -65,16 +64,13 @@ def _get_centroids(self):
65
64
raise NotImplementedError
66
65
67
66
for item in product (* coordinates ):
68
- self . centroids .append (list (item ))
69
- self . centroids = self . centroids [:self .num_clusters ]
67
+ centroids .append (list (item ))
68
+ return centroids [:self .num_clusters ]
70
69
71
70
@batching
72
71
def encode (self , vecs : np .ndarray , * args , ** kwargs ) -> np .ndarray :
72
+ self ._check_bound (vecs )
73
73
num_bytes = self ._get_num_bytes (vecs )
74
- max_value , min_value = self ._get_max_min_value (vecs )
75
-
76
- self ._check_bound (max_value , min_value )
77
-
78
74
x = np .reshape (vecs , [vecs .shape [0 ], num_bytes , 1 , self .dim_per_byte ])
79
75
x = np .sum (np .square (x - self .centroids ), - 1 )
80
76
# start from 1
@@ -93,7 +89,8 @@ def _get_num_bytes(self, vecs: np.ndarray):
93
89
def _get_max_min_value (vecs ):
94
90
return np .amax (vecs , axis = None ), np .amin (vecs , axis = None )
95
91
96
- def _check_bound (self , max_value , min_value ):
92
+ def _check_bound (self , vecs ):
93
+ max_value , min_value = self ._get_max_min_value (vecs )
97
94
if self .upper_bound < max_value :
98
95
raise Warning ("upper bound (=%.3f) is smaller than max value of input data (=%.3f), you should choose"
99
96
"a bigger value for upper bound" % (self .upper_bound , max_value ))
0 commit comments