Skip to content
Permalink
Browse files

fix(encoder): fix vald in numeric encoder

  • Loading branch information...
Larryjianfeng committed Sep 9, 2019
1 parent 2c9d5ce commit f8e18d067722a9454e13a978cf7be41ea7241ed3
Showing with 26 additions and 0 deletions.
  1. +26 −0 tests/test_vlad.py
@@ -0,0 +1,26 @@
import os
import unittest
import numpy as np
from gnes.encoder.numeric.vlad import VladEncoder


class TestVladEncoder(unittest.TestCase):
def setUp(self):
self.mock_train_data = np.random.random([200, 128])
self.mock_eval_data = np.random.random([2, 2, 128])
self.dump_path = os.path.join(os.path.dirname(__file__), 'vlad.bin')

def tearDown(self):
if os.path.exists(self.dump_path):
os.remove(self.dump_path)

def test_vlad_train(self):
model = VladEncoder(20)
model.train(self.mock_train_data)
self.assertEqual(model.centroids.shape, (20, 128))
model.dump(self.dump_path)

def test_vlad_encode(self):
model = VladEncoder.load(self.dump_path)
v = model.encode(self.mock_eval_data)
self.assertEqual(v.shape, (2, 2560))

0 comments on commit f8e18d0

Please sign in to comment.
You can’t perform that action at this time.