Skip to content
Permalink
Browse files

fix(encoder): fix vlad unittest

  • Loading branch information...
Larryjianfeng committed Sep 9, 2019
1 parent ddf13ff commit ffc822b39e39dad05fbfa84f28a4844a31d3e785
Showing with 19 additions and 10 deletions.
  1. +17 −8 gnes/encoder/numeric/vlad.py
  2. +2 −2 tests/test_vlad.py
@@ -14,8 +14,6 @@
# limitations under the License.


import copy

import numpy as np

from ..base import BaseNumericEncoder
@@ -39,6 +37,11 @@ def kmeans_train(self, vecs):
kmeans = faiss.Kmeans(vecs.shape[1], self.num_clusters, niter=5, verbose=False)
kmeans.train(vecs)
self.centroids = kmeans.centroids
if self.using_faiss_pred:
self.faiss_index()

def faiss_index(self):
import faiss
self.index_flat = faiss.IndexFlatL2(self.centroids.shape[1])
self.index_flat.add(self.centroids)

@@ -53,10 +56,9 @@ def kmeans_pred(self, vecs):

@batching
def train(self, vecs: np.ndarray, *args, **kwargs):
vecs = vecs.reshape([-1, vecs.shape[-1]])
assert len(vecs) > self.num_clusters, 'number of data should be larger than number of clusters'
vecs_ = copy.deepcopy(vecs)
vecs_ = np.concatenate((list(vecs_[i] for i in range(len(vecs_)))), axis=0)
self.kmeans_train(vecs_)
self.kmeans_train(vecs)

@train_required
@batching
@@ -79,7 +81,14 @@ def _copy_from(self, x: 'VladEncoder') -> None:
self.centroids = x.centroids
self.using_faiss_pred = x.using_faiss_pred
if self.using_faiss_pred:
import faiss
self.index_flat = faiss.IndexFlatL2(self.centroids.shape[1])
self.index_flat.add(self.centroids)
self.faiss_index()

def __setstate__(self, state):
super().__setstate__(state)
if self.using_faiss_pred:
self.faiss_index()

def __getstate__(self):
state = super().__getstate__()
del state['index_flat']
return state
@@ -6,8 +6,8 @@

class TestVladEncoder(unittest.TestCase):
def setUp(self):
self.mock_train_data = np.random.random([200, 128])
self.mock_eval_data = np.random.random([2, 2, 128])
self.mock_train_data = np.random.random([1, 200, 128]).astype(np.float32)
self.mock_eval_data = np.random.random([2, 2, 128]).astype(np.float32)
self.dump_path = os.path.join(os.path.dirname(__file__), 'vlad.bin')

def tearDown(self):

0 comments on commit ffc822b

Please sign in to comment.
You can’t perform that action at this time.