Skip to content
Permalink
Browse files

fix(encoder): fix vlad to speed up centroids calculation

  • Loading branch information...
Larryjianfeng committed Sep 9, 2019
1 parent c62fa3f commit 654a5ba40a30ef51d57ab6ff0942c77d68d5a102
Showing with 8 additions and 4 deletions.
  1. +8 −4 gnes/encoder/numeric/vlad.py
@@ -24,7 +24,7 @@ class VladEncoder(BaseNumericEncoder):
batch_size = 2048

def __init__(self, num_clusters: int,
using_faiss_pred: bool=True,
using_faiss_pred: bool = False,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.num_clusters = num_clusters
@@ -37,6 +37,8 @@ def kmeans_train(self, vecs):
kmeans = faiss.Kmeans(vecs.shape[1], self.num_clusters, niter=5, verbose=False)
kmeans.train(vecs)
self.centroids = kmeans.centroids
self.centroids_l2 = np.sum(self.centroids**2, axis=1).reshape([1, -1])
self.centroids_trans = np.transpose(self.centroids)
if self.using_faiss_pred:
self.faiss_index()

@@ -50,9 +52,9 @@ def kmeans_pred(self, vecs):
_, pred = self.index_flat.search(vecs.astype(np.float32), 1)
return np.reshape(pred, [-1])
else:
vecs = np.reshape(vecs, [vecs.shape[0], 1, 1, vecs.shape[1]])
dist = np.sum(np.square(vecs - self.centroids), -1)
return np.argmax(-dist, axis=-1).reshape([-1]).astype(np.int32)
vecs_l2 = np.sum(vecs**2, axis=1).reshape([-1, 1])
dist = vecs_l2 + self.centroids_l2 - 2 * np.matmul(vecs, self.centroids_trans)
return np.argmax(dist, axis=-1).reshape([-1]).astype(np.int32)

@batching
def train(self, vecs: np.ndarray, *args, **kwargs):
@@ -79,6 +81,8 @@ def encode(self, vecs: np.ndarray, *args, **kwargs) -> np.ndarray:
def _copy_from(self, x: 'VladEncoder') -> None:
self.num_clusters = x.num_clusters
self.centroids = x.centroids
self.centroids_l2 = x.centroids_l2
self.centroids_trans = np.transpose(self.centroids)
self.using_faiss_pred = x.using_faiss_pred
if self.using_faiss_pred:
self.faiss_index()

0 comments on commit 654a5ba

Please sign in to comment.
You can’t perform that action at this time.