Commit

PyLucene 8 supported.
coady committed Jun 23, 2019
1 parent 9d6ef1b commit fcf60c9
Showing 7 changed files with 56 additions and 37 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -70,6 +70,7 @@ Optional server extras:
# Changes
dev
* PyLucene >=7.7 required
* PyLucene 8 supported

2.2
* PyLucene 7.6 supported
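
The changelog entries above state that PyLucene >=7.7 is now required and PyLucene 8 is supported; internally the commit branches on the installed Lucene version once at import time. A minimal sketch of that check (the `LU7` name comes from the lupyne/engine/indexers.py change further down):

```python
# Minimal sketch of the version gate introduced in lupyne/engine/indexers.py below;
# assumes PyLucene is installed and exposes the Lucene version string as lucene.VERSION.
import lucene

LU7 = lucene.VERSION < '8'  # True on PyLucene 7.x, False on PyLucene 8.x
```
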
4 changes: 1 addition & 3 deletions lupyne/engine/analyzers.py
@@ -106,9 +106,7 @@ def __init__(self, tokenizer, *filters):
@classmethod
def standard(cls, *filters):
"""Return equivalent of StandardAnalyzer with additional filters."""
def stop(tokens):
return analysis.StopFilter(tokens, analysis.standard.StandardAnalyzer.STOP_WORDS_SET)
return cls(analysis.standard.StandardTokenizer, analysis.standard.StandardFilter, analysis.LowerCaseFilter, stop, *filters)
return cls(analysis.standard.StandardTokenizer, analysis.LowerCaseFilter, *filters)

@classmethod
def whitespace(cls, *filters):
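
Lucene 8 removes `StandardFilter` and the built-in `StandardAnalyzer` stop set, so `Analyzer.standard` now builds only `StandardTokenizer` plus `LowerCaseFilter`. Callers who still want stop-word removal can pass a filter explicitly. A hedged sketch, not part of this commit; the choice of stop set and the `engine.Analyzer` entry point are assumptions based on lupyne's usual usage:

```python
# Hedged sketch: restoring stop-word filtering on top of the slimmed-down Analyzer.standard.
# Assumes PyLucene 8; EnglishAnalyzer.ENGLISH_STOP_WORDS_SET is an illustrative stop set.
import lucene
lucene.initVM()
from org.apache.lucene import analysis
from org.apache.lucene.analysis.en import EnglishAnalyzer
from lupyne import engine

def stop(tokens):
    # explicit stop set, since StandardAnalyzer.STOP_WORDS_SET is gone in Lucene 8
    return analysis.StopFilter(tokens, EnglishAnalyzer.ENGLISH_STOP_WORDS_SET)

analyzer = engine.Analyzer.standard(stop)  # StandardTokenizer + LowerCaseFilter + stop
```
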
9 changes: 6 additions & 3 deletions lupyne/engine/documents.py
@@ -280,11 +280,14 @@ class Hits(object):
:param searcher: `IndexSearcher`_ which can retrieve documents
:param scoredocs: lucene ScoreDocs
:param count: total number of hits
:param count: total number of hits; float indicates estimate
:param fields: optional field selectors
"""
def __init__(self, searcher, scoredocs, count=None, fields=None):
def __init__(self, searcher, scoredocs, count=0, fields=None):
self.searcher, self.scoredocs = searcher, scoredocs
if hasattr(count, 'relation'): # pragma: no cover
cls = int if count.relation == search.TotalHits.Relation.EQUAL_TO else float
count = cls(count.value)
self.count, self.fields = count, fields

def select(self, *fields):
@@ -369,7 +372,7 @@ class Groups(object):
"""Sequence of grouped `Hits`_."""
select = Hits.__dict__['select']

def __init__(self, searcher, groupdocs, count=None, fields=None):
def __init__(self, searcher, groupdocs, count=0, fields=None):
self.searcher, self.groupdocs = searcher, groupdocs
self.count, self.fields = count, fields

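
In Lucene 8, `TopDocs.totalHits` is a `TotalHits` object carrying a `value` and a `relation` rather than a plain integer; the `Hits.__init__` change above converts it so that exact counts stay `int` and lower-bound estimates become `float`. A minimal caller-side sketch of that convention (query and field names are illustrative):

```python
# Minimal sketch, assuming `indexer` is a lupyne engine.Indexer with a 'text' field.
hits = indexer.search('text:people', count=10)
if isinstance(hits.count, float):
    print('at least {} matching documents (lower-bound estimate)'.format(int(hits.count)))
else:
    print('exactly {} matching documents'.format(hits.count))
```
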
41 changes: 28 additions & 13 deletions lupyne/engine/indexers.py
@@ -15,6 +15,8 @@
from .documents import Field, Document, Hits, GroupingSearch
from .utils import long, suppress, Atomic, SpellChecker

LU7 = lucene.VERSION < '8'


class closing(set):
"""Manage lifespan of registered objects, similar to contextlib.closing."""
@@ -102,7 +104,8 @@ def __iter__(self):

@property
def bits(self):
return index.MultiFields.getLiveDocs(self.indexReader)
cls = index.MultiFields if LU7 else index.MultiBits
return cls.getLiveDocs(self.indexReader)

@property
def directory(self):
@@ -132,7 +135,8 @@ def segments(self):
@property
def fieldinfos(self):
"""mapping of field names to lucene FieldInfos"""
fieldinfos = index.MultiFields.getMergedFieldInfos(self.indexReader)
cls = index.MultiFields if LU7 else index.FieldInfos
fieldinfos = cls.getMergedFieldInfos(self.indexReader)
return {fieldinfo.name: fieldinfo for fieldinfo in fieldinfos.iterator()}

def suggest(self, name, value, count=1, **attrs):
@@ -205,7 +209,8 @@ def terms(self, name, value='', stop='', counts=False, distance=0, prefix=0):
:param distance: maximum edit distance for fuzzy terms
:param prefix: prefix length for fuzzy terms
"""
terms = index.MultiFields.getTerms(self.indexReader, name)
cls = index.MultiFields if LU7 else index.MultiTerms
terms = cls.getTerms(self.indexReader, name)
if not terms:
return iter([])
term, termsenum = index.Term(name, value), terms.iterator()
@@ -222,13 +227,15 @@ def terms(self, name, value='', stop='', counts=False, distance=0, prefix=0):

def docs(self, name, value, counts=False):
"""Generate doc ids which contain given term, optionally with frequency counts."""
docsenum = index.MultiFields.getTermDocsEnum(self.indexReader, name, util.BytesRef(value))
func = index.MultiFields.getTermDocsEnum if LU7 else index.MultiTerms.getTermPostingsEnum
docsenum = func(self.indexReader, name, util.BytesRef(value))
docs = iter(docsenum.nextDoc, index.PostingsEnum.NO_MORE_DOCS) if docsenum else ()
return ((doc, docsenum.freq()) for doc in docs) if counts else iter(docs)

def positions(self, name, value, payloads=False, offsets=False):
"""Generate doc ids and positions which contain given term, optionally with offsets, or only ones with payloads."""
docsenum = index.MultiFields.getTermPositionsEnum(self.indexReader, name, util.BytesRef(value))
func = index.MultiFields.getTermPositionsEnum if LU7 else index.MultiTerms.getTermPostingsEnum
docsenum = func(self.indexReader, name, util.BytesRef(value))
for doc in (iter(docsenum.nextDoc, index.PostingsEnum.NO_MORE_DOCS) if docsenum else ()):
positions = (docsenum.nextPosition() for _ in range(docsenum.freq()))
if payloads:
@@ -332,7 +339,8 @@ def spans(self, query, positions=False):
:param positions: optionally include slice positions instead of counts
"""
offset = 0
weight = query.createWeight(self, False, 1.0)
scores = False if LU7 else search.ScoreMode.COMPLETE_NO_SCORES
weight = query.createWeight(self, scores, 1.0)
postings = search.spans.SpanWeight.Postings.POSITIONS
for reader in self.readers:
try:
@@ -371,19 +379,22 @@ def count(self, *query, **options):
query = self.parse(*query, **options) if query else Query.alldocs()
return super(IndexSearcher, self).count(query)

def collector(self, count=None, sort=None, reverse=False, scores=False):
def collector(self, count=None, sort=None, reverse=False, scores=False, mincount=1000):
if count is None:
return search.CachingCollector.create(True, float('inf'))
count = min(count, self.maxDoc() or 1)
mincount = max(count, mincount)
args = [] if LU7 else [mincount]
if sort is None:
return search.TopScoreDocCollector.create(count)
return search.TopScoreDocCollector.create(count, *args)
if isinstance(sort, string_types):
sort = self.sortfield(sort, reverse=reverse)
if not isinstance(sort, search.Sort):
sort = search.Sort(sort)
return search.TopFieldCollector.create(sort, count, True, scores, False)
args = [True, scores, False] if LU7 else [mincount]
return search.TopFieldCollector.create(sort, count, *args)

def search(self, query=None, count=None, sort=None, reverse=False, scores=False, timeout=None, **parser):
def search(self, query=None, count=None, sort=None, reverse=False, scores=False, mincount=1000, timeout=None, **parser):
"""Run query and return `Hits`_.
.. versionchanged:: 1.4 sort param for lucene only; use Hits.sorted with a callable
@@ -394,11 +405,12 @@ def search(self, query=None, count=None, sort=None, reverse=False, scores=False,
:param sort: lucene Sort parameters
:param reverse: reverse flag used with sort
:param scores: compute scores for candidate results when sorting
:param mincount: total hit count accuracy threshold
:param timeout: stop search after elapsed number of seconds
:param parser: :meth:`Analyzer.parse` options
"""
query = Query.alldocs() if query is None else self.parse(query, **parser)
cache = collector = self.collector(count, sort, reverse, scores)
cache = collector = self.collector(count, sort, reverse, scores, mincount)
counter = search.TimeLimitingCollector.getGlobalCounter()
results = collector if timeout is None else search.TimeLimitingCollector(collector, counter, long(timeout * 1000))
with suppress(search.TimeLimitingCollector.TimeExceededException):
@@ -407,10 +419,13 @@ def search(self, query=None, count=None, sort=None, reverse=False, scores=False,
if isinstance(cache, search.CachingCollector):
collector = search.TotalHitCountCollector()
cache.replay(collector)
collector = self.collector(collector.totalHits or 1, sort, reverse, scores)
count = collector.totalHits or 1
collector = self.collector(count, sort, reverse, scores, count)
cache.replay(collector)
topdocs = collector.topDocs()
return Hits(self, topdocs.scoreDocs, topdocs.totalHits if timeout is None else None)
if not LU7 and scores: # pragma: no cover
search.TopFieldCollector.populateScores(topdocs.scoreDocs, self, query)
return Hits(self, topdocs.scoreDocs, topdocs.totalHits)

def facets(self, query, *fields, **query_map):
"""Return mapping of document counts for the intersection with each facet.
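
The new `mincount` keyword is lupyne's name for Lucene 8's total-hits threshold: `TopScoreDocCollector.create(count, mincount)` and `TopFieldCollector.create(sort, count, mincount)` count hits exactly only up to the threshold, after which the total is a lower bound (hence the float counts above). A hedged usage sketch (values and field names are illustrative):

```python
# Hedged sketch: request a page of results but force an exact total by raising
# mincount above any plausible hit count; assumes `indexer` is a lupyne Indexer.
hits = indexer.search('text:right', count=5, mincount=10000)
assert isinstance(hits.count, int)  # threshold not reached, so the count is exact

# With the default mincount=1000, a large result set may report only an estimate.
estimate = indexer.search('text:right', count=5)
```
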
11 changes: 6 additions & 5 deletions lupyne/server.py
@@ -374,7 +374,7 @@ def search(self, q=None, count: int = None, start: int = 0, fields: multi = None
&q=\ *chars*\ &q.type=[term|prefix|wildcard]&q.spellcheck=true&q.\ *chars*\ =...,
query, optional type to skip parsing, spellcheck, and parser settings: q.field, q.op,...
&count=\ *int*\ &start=0
&count=\ *int*\ &start=0&count.min=1000
maximum number of docs to return and offset to start at
&fields=\ *chars*,... &fields.multi=\ *chars*,... &fields.docvalues=\ *chars*\ [:*chars*],...
@@ -434,16 +434,17 @@ def search(self, q=None, count: int = None, start: int = 0, fields: multi = None
start = count = 1
scores = 'sort.scores' in options
gcount = options.get('group.count', 1)
mincount = options.get('count.min', 1000)
if ':' in group:
hits = searcher.search(q, sort=sort, timeout=timeout, scores=scores)
hits = searcher.search(q, sort=sort, timeout=timeout, scores=scores, mincount=mincount)
name, docvalues = parse.docvalues(searcher, group)
with HTTPError(TypeError):
groups = hits.groupby(docvalues.select(hits.ids).__getitem__, count=count, docs=gcount)
groups.groupdocs = groups.groupdocs[start:]
elif group:
groups = searcher.groupby(group, q, count, start=start, sort=sort, groupDocsLimit=gcount, includeScores=scores)
groups = searcher.groupby(group, q, count, start=start, sort=sort, groupDocsLimit=gcount)
else:
hits = searcher.search(q, sort=sort, count=count, timeout=timeout, scores=scores)
hits = searcher.search(q, sort=sort, count=count, timeout=timeout, scores=scores, mincount=mincount)
groups = engine.documents.Groups(searcher, [hits[start:]], hits.count)
result = {'query': q and str(q), 'count': groups.count}
fields, multi, docvalues = parse.fields(searcher, fields, **options)
@@ -480,7 +481,7 @@ def search(self, q=None, count: int = None, start: int = 0, fields: multi = None
facets[name] = {term: counts[term] for term in heapq.nlargest(options['facets.count'], counts, key=counts.__getitem__)}
return result
search.__annotations__.update({'fields.multi': multi, 'fields.docvalues': multi, 'facets.count': int, 'facets.min': int,
'group.count': int, 'hl.count': int, 'mlt.fields': multi})
'group.count': int, 'hl.count': int, 'mlt.fields': multi, 'count.min': int})

@cherrypy.expose
@cherrypy.tools.params()
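
On the HTTP side the same threshold is exposed as the `count.min` query parameter and forwarded to `searcher.search(..., mincount=...)`. A hedged sketch of a client request against a running lupyne server; the host, port, and index contents are assumptions, not part of this commit:

```python
# Hedged sketch: querying the /search endpoint of a lupyne server assumed to be
# running on localhost:8080; requires the third-party `requests` package.
import requests

params = {'q': 'text:people', 'count': 5, 'count.min': 10000}
result = requests.get('http://localhost:8080/search', params=params).json()
print(result['count'], [doc['__id__'] for doc in result['docs']])
```
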
13 changes: 7 additions & 6 deletions tests/test_engine.py
@@ -159,13 +159,14 @@ def test_searcher(tempdir, fields, constitution):
assert set(map(type, hits.ids)) == {int} and set(map(type, hits.scores)) == {float}
assert hits.maxscore == next(hits.scores)
ids = list(hits.ids)
hits = indexer.search('people', count=5, field='text')
hits = indexer.search('people', count=5, mincount=5, field='text')
assert list(hits.ids) == ids[:len(hits)]
assert len(hits) == 5 and hits.count == 8
assert not any(map(math.isnan, hits.scores))
assert hits.maxscore == next(hits.scores)
hits = indexer.search('text:people', count=5, sort=search.Sort.INDEXORDER)
hits = indexer.search('text:people', count=5, sort=search.Sort.INDEXORDER, scores=True)
assert sorted(hits.ids) == list(hits.ids)
assert all(score > 0 for score in hits.scores)
hit, = indexer.search('freedom', field='text')
assert hit['amendment'] == '1'
assert sorted(hit.dict()) == ['__id__', '__score__', 'amendment', 'date']
@@ -186,10 +187,10 @@ def test_searcher(tempdir, fields, constitution):
assert dict(indexer.positionvector(id, 'text', offsets=True))['persons'] == [(46, 53), (301, 308)]
analyzer = analysis.core.WhitespaceAnalyzer()
query = indexer.morelikethis(0, analyzer=analyzer)
assert set(str(query).split()) == {'text:united', 'text:states'}
assert {'text:united', 'text:states'} <= set(str(query).split())
assert str(indexer.morelikethis(0, 'article', analyzer=analyzer)) == ''
query = indexer.morelikethis(0, minDocFreq=3, analyzer=analyzer)
assert set(str(query).split()) == {'text:establish', 'text:united', 'text:states'}
assert {'text:establish', 'text:united', 'text:states'} <= set(str(query).split())
assert str(indexer.morelikethis('jury', 'text', minDocFreq=4, minTermFreq=1, analyzer=analyzer)) == 'text:jury'
assert str(indexer.morelikethis('jury', 'article', analyzer=analyzer)) == ''

@@ -201,7 +202,7 @@ def test_spellcheck(fields, constitution):
indexer.add(doc)
indexer.commit()
assert indexer.complete('missing', '') == []
assert indexer.complete('text', '')[:8] == ['shall', 'states', 'any', 'have', 'united', 'congress', 'state', 'constitution']
assert {'shall', 'states'} <= set(indexer.complete('text', '')[:8])
assert indexer.complete('text', 'con')[:2] == ['congress', 'constitution']
assert indexer.complete('text', 'congress') == indexer.complete('text', 'con', count=1) == ['congress']
assert indexer.complete('text', 'congresses') == []
@@ -378,7 +379,7 @@ def test_grouping(tempdir, indexer, zipcodes):
assert len(grouping) == len(list(grouping)) > 100
assert set(grouping) > set(facets)
hits = indexer.search(query, timeout=-1)
assert not hits and hits.count is None and math.isnan(hits.maxscore)
assert not hits and not hits.count and math.isnan(hits.maxscore)
hits = indexer.search(query, timeout=10)
assert len(hits) == hits.count == indexer.count(query) and hits.maxscore == 1.0
directory = store.RAMDirectory()
14 changes: 7 additions & 7 deletions tests/test_server.py
@@ -76,7 +76,7 @@ def test_docs(resource):
doc = resource.docs('0', **{'fields.vector.counts': 'text'})
resource.client.patch('docs/amendment/1', {'amendment': '1'}).status_code == http.client.CONFLICT
assert not resource.patch('docs/amendment/1')
assert sorted(term for term, count in doc['text'].items() if count > 1) == ['establish', 'states', 'united']
assert {term for term, count in doc['text'].items() if count > 1} >= {'establish', 'states', 'united'}
assert resource.client.put('docs/article/0', {'article': '-1'}).status_code == http.client.BAD_REQUEST
assert resource.delete('docs/article/0') is None

@@ -95,7 +95,7 @@ def test_terms(resource):
assert resource.terms('text/right~1') == ['eight', 'right', 'rights']
assert resource.terms('text/right~') == ['eight', 'high', 'right', 'rights']
assert resource.terms('text/right~?count=3') == []
assert resource.terms('text/write~?count=5') == ['writs', 'writ', 'written']
assert resource.terms('text/write~?count=3') == ['writs', 'writ', 'written']
docs = resource.terms('text/people/docs')
assert resource.terms('text/people') == len(docs) == 8
counts = dict(resource.terms('text/people/docs/counts'))
@@ -117,7 +117,7 @@ def test_search(resource):
assert result['count'] == 1 and result['query'] == 'article:Preamble*'
result = resource.search(q='text:"We the People"', **{'q.phraseSlop': 3})
assert result['count'] == 1
assert result['query'] == 'text:"we ? people"~3'
assert result['query'].startswith('text:"we ') and result['query'].endswith(' people"~3')
doc, = result['docs']
assert sorted(doc) == ['__id__', '__score__', 'article']
assert doc['article'] == 'Preamble' and doc['__id__'] >= 0 and 0 < doc['__score__']
@@ -156,14 +156,14 @@ def test_highlights(resource):
highlights = [doc['__highlights__'] for doc in result['docs']]
assert all(highlight['amendment'] or highlight['article'] for highlight in highlights)
result = resource.search(mlt=0)
assert result['count'] == 25 and set(result['query'].split()) == {'text:united', 'text:states'}
assert result['count'] >= 25 and set(result['query'].split()) >= {'text:united', 'text:states'}
result = resource.search(q='amendment:2', mlt=0, **{'mlt.fields': 'text', 'mlt.minTermFreq': 1, 'mlt.minWordLen': 6})
assert result['count'] == 11 and set(result['query'].split()) == {'text:necessary', 'text:people'}
assert [doc['amendment'] for doc in result['docs'][:3]] == ['2', '9', '10']
result = resource.search(q='text:people', count=1, timeout=-1)
assert result == {'query': 'text:people', 'count': None, 'docs': []}
assert result == {'query': 'text:people', 'count': 0, 'docs': []}
result = resource.search(q='text:people', timeout=0.01)
assert result['count'] in (None, 8)
assert result['count'] in (0, 8)
result = resource.search(q='+text:right +text:people')
assert result['count'] == 4
assert resource.search(q='hello', **{'q.field': 'body.title^2.0'})['query'] == '(body.title:hello)^2.0'
@@ -199,7 +199,7 @@ def test_facets(tempdir, servers, zipcodes):
writer.add(zipcode=int(doc['zipcode']), location='{}.{}'.format(doc['county'], doc['city']))
writer.commit()
assert resource.post('update') == resource().popitem()[1] == len(writer)
result = resource.search(count=0, facets='county')
result = resource.search(count=0, facets='county', **{'count.min': 10000})
facets = result['facets']['county']
assert result['count'] == sum(facets.values()) and 'Los Angeles' in facets
result = resource.search(q='Los Angeles', count=0, facets='county.city', **{'q.type': 'term', 'q.field': 'county'})
