Skip to content

Commit

Permalink
Fix unittest errors and Update encoder naming convention (#1617)
Browse files Browse the repository at this point in the history
* fix unittest
* change query encoder naming convention to python
  • Loading branch information
ArthurChen189 committed Sep 3, 2023
1 parent 7bd7fef commit fff033d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
6 changes: 4 additions & 2 deletions pyserini/search/lucene/_impact_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,11 @@ def from_prebuilt_index(cls, prebuilt_index_name: str, query_encoder: Union[Quer
return cls(index_dir, query_encoder, min_idf, encoder_type, prebuilt_index_name=prebuilt_index_name)

def encode(self, query):
if self.encoder_type == 'pytorch':
if self.encoder_type == 'onnx':
encoded_query = self.object.encode_with_onnx(query)
elif self.encoder_type == 'pytorch':
encoded_query = self.query_encoder.encode(query)
else: raise ValueError(f'Invalid query encoder type: {type(query_encoder)} for encode')
else: raise ValueError(f'Invalid query encoder type: {type(self.query_encoder)} for encode')
return encoded_query

@staticmethod
Expand Down
8 changes: 4 additions & 4 deletions tests/test_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,10 @@ def test_onnx_encode_unicoil(self):
temp_object = LuceneImpactSearcher(f'{self.index_dir}lucene9-index.cacm', 'SpladePlusPlusEnsembleDistil', encoder_type='onnx')

# this function will never be called in _impact_searcher, here to check quantization correctness
results = temp_object.object.encodeWithOnnx("here is a test")
self.assertAlmostEqual(results.get("here"), 156, delta=2e-4)
self.assertAlmostEqual(results.get("a"), 31, delta=2e-4)
self.assertAlmostEqual(results.get("test"), 149, delta=2e-4)
results = temp_object.encode("here is a test")
self.assertEqual(results.get("here"), 156)
self.assertEqual(results.get("a"), 31)
self.assertEqual(results.get("test"), 149)

temp_object.close()
del temp_object
Expand Down

0 comments on commit fff033d

Please sign in to comment.