Skip to content
Permalink
Browse files

fix(encoder): fix vlad encoder and add unittest

  • Loading branch information...
Larryjianfeng committed Sep 9, 2019
1 parent f8e18d0 commit 1ba4e11cb7f18b97cb35faed61b7d82fb512cd84
Showing with 25 additions and 15 deletions.
  1. +25 −15 gnes/encoder/numeric/vlad.py
@@ -25,22 +25,31 @@
class VladEncoder(BaseNumericEncoder):
batch_size = 2048

def __init__(self, num_clusters: int,
             using_faiss_pred: bool = True,
             *args, **kwargs):
    """Set up an untrained VLAD encoder.

    :param num_clusters: number of k-means centroids fitted by
        ``kmeans_train``.
    :param using_faiss_pred: when true, ``kmeans_pred`` finds the nearest
        centroid through the faiss index; otherwise it computes the
        distances with NumPy. Defaults to ``True`` so callers that only
        pass ``num_clusters`` keep working.
    :param args: forwarded unchanged to the parent constructor.
    :param kwargs: forwarded unchanged to the parent constructor.
    """
    super().__init__(*args, **kwargs)
    self.num_clusters = num_clusters
    self.using_faiss_pred = using_faiss_pred
    # Both stay ``None`` until ``kmeans_train`` or ``_copy_from`` fills them.
    self.centroids = None
    self.index_flat = None

def kmeans_train(self, vecs):
    """Fit the k-means codebook on ``vecs`` and index the centroids.

    Stores the fitted centroids in ``self.centroids`` and a flat L2 faiss
    index over them in ``self.index_flat``.

    :param vecs: 2-D array with one training vector per row. It is
        converted to C-contiguous float32 before being handed to faiss,
        mirroring the cast ``kmeans_pred`` applies before searching.
    """
    import faiss

    vecs = np.ascontiguousarray(vecs, dtype=np.float32)
    kmeans = faiss.Kmeans(vecs.shape[1], self.num_clusters, niter=5, verbose=False)
    kmeans.train(vecs)
    self.centroids = kmeans.centroids
    self.index_flat = faiss.IndexFlatL2(self.centroids.shape[1])
    self.index_flat.add(self.centroids)

def kmeans_pred(self, vecs):
    """Return the index of the nearest centroid for every row of ``vecs``.

    :param vecs: 2-D array with one vector per row.
    :return: 1-D integer array with one centroid index per input row. The
        faiss path returns the dtype faiss produces; the NumPy path
        returns ``int32``.
    """
    if self.using_faiss_pred:
        # faiss needs C-contiguous float32; ``astype`` alone keeps the
        # memory layout of the input, so force contiguity explicitly.
        queries = np.ascontiguousarray(vecs, dtype=np.float32)
        _, nearest = self.index_flat.search(queries, 1)
        return np.reshape(nearest, [-1])

    vecs = np.asarray(vecs)
    # Squared L2 distance of every row to every centroid: (rows, clusters).
    diff = vecs[:, np.newaxis, :] - self.centroids[np.newaxis, :, :]
    dist = np.sum(np.square(diff), axis=-1)
    return np.argmin(dist, axis=-1).astype(np.int32)

@batching
def train(self, vecs: np.ndarray, *args, **kwargs):
@@ -52,24 +61,25 @@ def train(self, vecs: np.ndarray, *args, **kwargs):
@train_required
@batching
def encode(self, vecs: np.ndarray, *args, **kwargs) -> np.ndarray:
    """Encode every chunk of frames into one L2-normalised VLAD vector.

    For each chunk, every frame is assigned to its nearest centroid via
    ``kmeans_pred`` and the residual ``frame - centroid`` is accumulated
    per centroid. The per-centroid sums are flattened and L2-normalised.

    :param vecs: iterable of chunks, each a 2-D array of frames
        (assumed ``(chunks, frames, dim)`` overall -- TODO confirm
        against callers).
    :return: float32 array with one row of length
        ``num_centroids * dim`` per chunk. A chunk whose residuals are all
        zero yields a zero row rather than NaN.
    """
    encoded = []
    for chunk in vecs:
        chunk = np.asarray(chunk)
        assignments = self.kmeans_pred(chunk)
        residuals = np.zeros_like(self.centroids, dtype=np.float64)
        # Unbuffered add so frames sharing a centroid all contribute.
        np.add.at(residuals, assignments, chunk - self.centroids[assignments])
        flat = residuals.reshape(-1)
        norm = np.linalg.norm(flat)
        # Guard the normalisation: an all-zero residual would divide by 0.
        encoded.append(flat / norm if norm > 0 else flat)

    # Explicit width keeps the result 2-D even for an empty batch.
    return np.array(encoded, dtype=np.float32).reshape(len(encoded), self.centroids.size)

def _copy_from(self, x: 'VladEncoder') -> None:
    """Copy the configuration and trained centroids of ``x`` into ``self``.

    The faiss index itself is not copied; when faiss prediction is enabled
    and ``x`` holds centroids, a fresh flat L2 index is rebuilt from them.

    :param x: encoder to copy ``num_clusters``, ``centroids`` and
        ``using_faiss_pred`` from.
    """
    self.num_clusters = x.num_clusters
    self.centroids = x.centroids
    self.using_faiss_pred = x.using_faiss_pred
    # An untrained source has ``centroids is None``; there is nothing to
    # index yet, so skip the rebuild instead of failing on ``.shape``.
    if self.using_faiss_pred and self.centroids is not None:
        import faiss
        self.index_flat = faiss.IndexFlatL2(self.centroids.shape[1])
        self.index_flat.add(self.centroids)

0 comments on commit 1ba4e11

Please sign in to comment.
You can’t perform that action at this time.