Commit ffc51b9: Unified highlighting.

coady committed Dec 23, 2017
1 parent 122a7fe, commit ffc51b9
Showing 8 changed files with 54 additions and 143 deletions.
docs/engine.rst: 0 additions & 12 deletions
@@ -196,18 +196,6 @@ SpanQuery

     <SpanOrQuery: spanOr(spans)>

-Highlighter
-^^^^^^^^^^^^^
-.. autoclass:: Highlighter
-   :show-inheritance:
-   :members:
-
-FastVectorHighlighter
-^^^^^^^^^^^^^^^^^^^^^
-.. autoclass:: FastVectorHighlighter
-   :show-inheritance:
-   :members:
-
 SpellParser
 ^^^^^^^^^^^^^
 .. autoclass:: SpellParser
lupyne/engine/analyzers.py: 12 additions & 0 deletions
@@ -7,6 +7,7 @@
 from java.lang import Float
 from java.util import HashMap
 from org.apache.lucene import analysis, queryparser, util
+from org.apache.lucene.search import uhighlight
 from org.apache.pylucene.analysis import PythonAnalyzer, PythonTokenFilter
 from org.apache.pylucene.queryparser.classic import PythonQueryParser
 from six import string_types
@@ -160,3 +161,14 @@ def parse(self, query, field='', op='', parser=None, **attrs):
         finally:
             if isinstance(parser, PythonQueryParser):
                 parser.finalize()
+
+    @method
+    def highlight(self, query, field, content, count=1):
+        """Return highlighted content.
+
+        :param query: lucene Query
+        :param field: field name
+        :param content: text
+        :param count: optional maximum number of passages
+        """
+        return uhighlight.UnifiedHighlighter(None, self).highlightWithoutSearcher(field, query, content, count).toString()
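
For reference, a usage sketch of the new Analyzer.highlight method (not part of the diff), mirroring the assertion in tests/test_engine.py below. It assumes lupyne at this commit and an initialized PyLucene VM; the @method decorator also lets it be called unbound on any lucene analyzer:

    import lucene
    from lupyne import engine

    lucene.initVM()  # the PyLucene VM must be running before any lucene calls
    indexer = engine.Indexer()  # in-memory index; its analyzer is a plain lucene analyzer
    query = engine.Query.term('text', 'right')
    # the UnifiedHighlighter wraps matches in <b> tags by default
    assert engine.Analyzer.highlight(indexer.analyzer, query, 'text', "word right word") == "word <b>right</b> word"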
lupyne/engine/documents.py: 14 additions & 2 deletions
@@ -336,13 +336,25 @@ def items(self):
         """Generate zipped ids and scores."""
         return map(operator.attrgetter('doc', 'score'), self.scoredocs)

+    def highlights(self, query, **fields):
+        """Generate highlighted fields for each hit.
+
+        :param query: lucene Query
+        :param fields: mapping of fields to maximum number of passages
+        """
+        mapping = self.searcher.highlighter.highlightFields(list(fields), query, list(self.ids), list(fields.values()))
+        mapping = {field: lucene.JArray_string.cast_(mapping.get(field)) for field in fields}
+        return (dict(zip(mapping, values)) for values in zip(*mapping.values()))
+
     def docvalues(self, field, type=None):
-        """Return mappoing of docs to docvalues."""
+        """Return mapping of docs to docvalues."""
         return self.searcher.docvalues(field, type).select(self.ids)

     def groupby(self, func, count=None, docs=None):
         """Return ordered list of `Hits`_ grouped by value of function applied to doc ids.
-        Optionally limit the number of groups and docs per group."""
+
+        Optionally limit the number of groups and docs per group.
+        """
         groups = collections.OrderedDict()
         for scoredoc in self.scoredocs:
             value = func(scoredoc.doc)
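
Continuing the sketch above: Hits.highlights generates one dict per hit, mapping each requested field to its highlighted passages. The field settings are illustrative assumptions; content must be stored for the UnifiedHighlighter to retrieve it:

    indexer.set('text', stored=True)  # assumed field settings: stored text
    indexer.add(text="the right of the people to keep and bear Arms")
    indexer.commit()

    hits = indexer.search(query)
    for highlight in hits.highlights(query, text=1):  # text=1: at most one passage
        assert '<b>right</b>' in highlight['text']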
lupyne/engine/indexers.py: 9 additions & 10 deletions
@@ -13,10 +13,11 @@
 from java.io import File, IOException, StringReader
 from java.util import Arrays, HashSet
 from org.apache.lucene import analysis, document, index, queries, search, store, util
+from org.apache.lucene.search import uhighlight
 from six import string_types
 from six.moves import filter, map, range, zip
 from .analyzers import Analyzer
-from .queries import suppress, Query, DocValues, Highlighter, FastVectorHighlighter, SpellParser
+from .queries import suppress, Query, DocValues, SpellParser
 from .documents import Field, Document, Hits, GroupingSearch
 from .spatial import Distances
 from ..utils import long, Atomic, SpellChecker
@@ -356,12 +356,10 @@ def parse(self, query, spellcheck=False, **kwargs):
             kwargs['parser'], kwargs['searcher'] = SpellParser, self
         return Analyzer.parse(self.analyzer, query, **kwargs)

-    def highlighter(self, query, field, **kwargs):
-        """Return `Highlighter`_ or if applicable `FastVectorHighlighter`_ specific to searcher and query."""
-        query = self.parse(query, field=field)
-        fieldinfo = self.fieldinfos.get(field)
-        vector = fieldinfo and fieldinfo.hasVectors()
-        return (FastVectorHighlighter if vector else Highlighter)(self, query, field, **kwargs)
+    @property
+    def highlighter(self):
+        """lucene UnifiedHighlighter"""
+        return uhighlight.UnifiedHighlighter(self, self.analyzer)

     def count(self, *query, **options):
         """Return number of hits for given query or term.
@@ -374,7 +373,7 @@ def count(self, *query, **options):
         query = self.parse(*query, **options) if query else Query.alldocs()
         return super(IndexSearcher, self).count(query)

-    def collector(self, query, count=None, sort=None, reverse=False, scores=False, maxscore=False):
+    def collector(self, count=None, sort=None, reverse=False, scores=False, maxscore=False):
         if count is None:
             return search.CachingCollector.create(True, float('inf'))
         count = min(count, self.maxDoc() or 1)
@@ -401,7 +400,7 @@ def search(self, query=None, count=None, sort=None, reverse=False, scores=False,
         :param parser: :meth:`Analyzer.parse` options
         """
         query = Query.alldocs() if query is None else self.parse(query, **parser)
-        cache = collector = self.collector(query, count, sort, reverse, scores, maxscore)
+        cache = collector = self.collector(count, sort, reverse, scores, maxscore)
         counter = search.TimeLimitingCollector.getGlobalCounter()
         results = collector if timeout is None else search.TimeLimitingCollector(collector, counter, long(timeout * 1000))
         with suppress(search.TimeLimitingCollector.TimeExceededException):
@@ -410,7 +409,7 @@ def search(self, query=None, count=None, sort=None, reverse=False, scores=False,
         if isinstance(cache, search.CachingCollector):
             collector = search.TotalHitCountCollector()
             cache.replay(collector)
-            collector = self.collector(query, collector.totalHits or 1, sort, reverse, scores, maxscore)
+            collector = self.collector(collector.totalHits or 1, sort, reverse, scores, maxscore)
             cache.replay(collector)
         topdocs = collector.topDocs()
         stats = (topdocs.totalHits, topdocs.maxScore) * (timeout is None)
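
The highlighter property returns a fresh lucene UnifiedHighlighter bound to the searcher and its analyzer; Hits.highlights above is a thin wrapper over its highlightFields call. A direct-use sketch, continuing the example above and subject to the same assumptions:

    uhl = indexer.highlighter  # a new UnifiedHighlighter on each access
    # Java Map of field name -> String[] of formatted passages, one per doc id
    mapping = uhl.highlightFields(['text'], query, list(hits.ids), [1])
    passages = lucene.JArray_string.cast_(mapping.get('text'))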
lupyne/engine/queries.py: 3 additions & 66 deletions
@@ -5,12 +5,12 @@
 import contextlib
 import lucene
 from java.lang import Integer
-from java.util import Arrays, HashSet
+from java.util import Arrays
 from org.apache.lucene import index, search, util
-from org.apache.lucene.search import highlight, spans, vectorhighlight
+from org.apache.lucene.search import spans
 from org.apache.pylucene.queryparser.classic import PythonQueryParser
 from six import string_types
-from six.moves import filter, map, range
+from six.moves import map, range
 from ..utils import method

@@ -251,69 +251,6 @@ def __getitem__(self, id):
         return tuple(self.type(self.docvalues.lookupOrd(ord)) for ord in ords)


-class Highlighter(highlight.Highlighter):
-    """Inherited lucene Highlighter with stored analysis options.
-
-    :param searcher: `IndexSearcher`_ used for analysis, scoring, and optionally text retrieval
-    :param query: lucene Query
-    :param field: field name of text
-    :param terms: highlight any matching term in query regardless of position
-    :param fields: highlight matching terms from any field
-    :param tag: optional html tag name
-    :param formatter: optional lucene Formatter
-    :param encoder: optional lucene Encoder
-    """
-    def __init__(self, searcher, query, field, terms=False, fields=False, tag='', formatter=None, encoder=None):
-        if tag:
-            formatter = highlight.SimpleHTMLFormatter('<{}>'.format(tag), '</{}>'.format(tag))
-        scorer = (highlight.QueryTermScorer if terms else highlight.QueryScorer)(query, *(searcher.indexReader, field) * (not fields))
-        highlight.Highlighter.__init__(self, *filter(None, [formatter, encoder, scorer]))
-        self.searcher, self.field = searcher, field
-        self.selector = HashSet(Arrays.asList([field]))
-
-    def fragments(self, doc, count=1):
-        """Return highlighted text fragments.
-
-        :param doc: text string or doc id to be highlighted
-        :param count: maximum number of fragments
-        """
-        if not isinstance(doc, string_types):
-            doc = self.searcher.doc(doc, self.selector)[self.field]
-        return doc and list(self.getBestFragments(self.searcher.analyzer, self.field, doc, count))
-
-
-class FastVectorHighlighter(vectorhighlight.FastVectorHighlighter):
-    """Inherited lucene FastVectorHighlighter with stored query.
-
-    Fields must be stored and have term vectors with offsets and positions.
-
-    :param searcher: `IndexSearcher`_ with stored term vectors
-    :param query: lucene Query
-    :param field: field name of text
-    :param terms: highlight any matching term in query regardless of position
-    :param fields: highlight matching terms from any field
-    :param tag: optional html tag name
-    :param fragListBuilder: optional lucene FragListBuilder
-    :param fragmentsBuilder: optional lucene FragmentsBuilder
-    """
-    def __init__(self, searcher, query, field, terms=False, fields=False, tag='', fragListBuilder=None, fragmentsBuilder=None):
-        if tag:
-            fragmentsBuilder = vectorhighlight.SimpleFragmentsBuilder(['<{}>'.format(tag)], ['</{}>'.format(tag)])
-        args = fragListBuilder or vectorhighlight.SimpleFragListBuilder(), fragmentsBuilder or vectorhighlight.SimpleFragmentsBuilder()
-        vectorhighlight.FastVectorHighlighter.__init__(self, not terms, not fields, *args)
-        self.searcher, self.field = searcher, field
-        self.query = self.getFieldQuery(query)
-
-    def fragments(self, id, count=1, size=100):
-        """Return highlighted text fragments.
-
-        :param id: document id
-        :param count: maximum number of fragments
-        :param size: maximum number of characters in fragment
-        """
-        return list(self.getBestFragments(self.query, self.searcher.indexReader, id, self.field, size, count))
-
-
 class SpellParser(PythonQueryParser):
     """Inherited lucene QueryParser which corrects spelling.
lupyne/server.py: 7 additions & 11 deletions
@@ -398,10 +398,9 @@ def search(self, q=None, count=None, start=0, fields=None, sort=None, facets='',
         .. versionchanged:: 1.6 grouping searches use count and start options

-        &hl=\ *chars*,... &hl.count=1&hl.tag=strong&hl.enable=[fields|terms]
+        &hl=\ *chars*,... &hl.count=1

         | stored fields to return highlighted
-        | optional maximum fragment count and html tag name
-        | optionally enable matching any field or any term
+        | optional maximum fragment count

         &mlt=\ *int*\ &mlt.fields=\ *chars*,... &mlt.\ *chars*\ =...,
         | doc index (or id without a query) to find MoreLikeThis
@@ -455,25 +454,22 @@ def search(self, q=None, count=None, start=0, fields=None, sort=None, facets='',
             hits = searcher.search(q, sort=sort, count=count, timeout=timeout, **scores)
             groups = engine.documents.Groups(searcher, [hits[start:]], hits.count, hits.maxscore)
         result = {'query': q and str(q), 'count': groups.count, 'maxscore': groups.maxscore}
-        tag, enable = options.get('hl.tag', 'strong'), options.get('hl.enable', '')
-        hlcount = options.get('hl.count', 1)
-        if hl:
-            hl = {name: searcher.highlighter(q, name, terms='terms' in enable, fields='fields' in enable, tag=tag) for name in hl}
         fields, multi, docvalues = parse.fields(searcher, fields, **options)
         if fields is None:
             fields = {}
         else:
             groups.select(*itertools.chain(fields, multi))
+        hl = dict.fromkeys(hl, options.get('hl.count', 1))
         result['groups'] = []
         for hits in groups:
             docs = []
-            for hit in hits:
+            highlights = hits.highlights(q, **hl) if hl else ([{}] * len(hits))
+            for hit, highlight in zip(hits, highlights):
                 doc = hit.dict(*multi, **fields)
                 with HTTPError(TypeError):
                     doc.update((name, docvalues[name][hit.id]) for name in docvalues)
-                fragments = (hl[name].fragments(hit.id, hlcount) for name in hl)  # pragma: no branch
-                if hl:
-                    doc['__highlights__'] = {name: value for name, value in zip(hl, fragments) if value is not None}
+                if highlight:
+                    doc['__highlights__'] = highlight
                 docs.append(doc)
             result['groups'].append({'docs': docs, 'count': hits.count, 'value': getattr(hits, 'value', None)})
         if not group:
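
Over HTTP, highlighting is now requested with just the stored field names and an optional passage count, since hl.tag and hl.enable are gone. An illustrative client call (host, port, and the requests library are assumptions, not part of the diff), matching the expected output in tests/test_server.py below:

    import requests  # hypothetical client; assumes a lupyne server on localhost:8080

    params = {'q': 'amendment:1', 'hl': 'amendment', 'hl.count': 2}
    doc, = requests.get('http://localhost:8080/search', params=params).json()['docs']
    assert doc['__highlights__'] == {'amendment': '<b>1</b>'}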
tests/test_engine.py: 7 additions & 27 deletions
@@ -5,7 +5,6 @@
 import pytest
 import lucene
 from org.apache.lucene import analysis, document, search, store, util
-from org.apache.lucene.search import highlight, vectorhighlight
 from six.moves import map
 from lupyne import engine
 from lupyne.utils import long, suppress
@@ -525,32 +524,13 @@ def test_highlighting(constitution):
         if 'amendment' in doc:
             indexer.add(text=doc['text'])
     indexer.commit()
-    highlighter = indexer.highlighter('persons', 'text')
-    for id in indexer:
-        fragments = highlighter.fragments(id)
-        assert len(fragments) == ('persons' in indexer[id]['text'])
-        assert all('<b>persons</b>' in fragment.lower() for fragment in fragments)
-    id = 3
-    text = indexer[id]['text']
-    query = '"persons, houses, papers"'
-    highlighter = indexer.highlighter(query, '', terms=True, fields=True, formatter=highlight.SimpleHTMLFormatter('*', '*'))
-    fragments = highlighter.fragments(text, count=3)
-    assert len(fragments) == 2 and fragments[0].count('*') == 2 * 3 and '*persons*' in fragments[1]
-    highlighter = indexer.highlighter(query, '', terms=True)
-    highlighter.textFragmenter = highlight.SimpleFragmenter(200)
-    fragment, = highlighter.fragments(text, count=3)
-    assert len(fragment) > len(text) and fragment.count('<B>persons</B>') == 2
-    fragment, = indexer.highlighter(query, 'text', tag='em').fragments(id, count=3)
-    assert len(fragment) < len(text) and fragment.index('<em>persons') < fragment.index('papers</em>')
-    fragment, = indexer.highlighter(query, 'text').fragments(id)
-    assert fragment.count('<b>') == fragment.count('</b>') == 1
-    highlighter = indexer.highlighter(query, 'text', fragListBuilder=vectorhighlight.SingleFragListBuilder())
-    text, = highlighter.fragments(id)
-    assert fragment in text and len(text) > len(fragment)
-    query = indexer.parse(query, field=text)
-    highlighter = engine.queries.Highlighter(indexer.indexSearcher, query, 'text', terms=True, fields=True, tag='em')
-    fragment, = highlighter.fragments(id)
-    assert fragment and '<em>' in fragment
+    query = Q.term('text', 'right')
+    assert engine.Analyzer.highlight(indexer.analyzer, query, 'text', "word right word") == "word <b>right</b> word"
+    hits = indexer.search(query)
+    highlights = list(hits.highlights(query, text=1))
+    assert len(hits) == len(highlights)
+    for highlight in highlights:
+        assert '<b>right</b>' in highlight.pop('text') and not highlight


 def test_nrt():
tests/test_server.py: 2 additions & 15 deletions

@@ -157,23 +157,10 @@ def test_search(resource):

 def test_highlights(resource):
     doc, = resource.search(q='amendment:1', hl='amendment', fields='article')['docs']
-    assert doc['__highlights__'] == {'amendment': ['<strong>1</strong>']}
-    doc, = resource.search(q='amendment:1', hl='amendment,article', **{'hl.count': 2, 'hl.tag': 'em'})['docs']
-    assert doc['__highlights__'] == {'amendment': ['<em>1</em>']}
+    assert doc['__highlights__'] == {'amendment': '<b>1</b>'}
     result = resource.search(q='text:1', hl='amendment,article')
     highlights = [doc['__highlights__'] for doc in result['docs']]
-    assert all(highlight and not any(highlight.values()) for highlight in highlights)
-    result = resource.search(q='text:1', hl='article', **{'hl.enable': 'fields'})
-    highlights = [doc['__highlights__'] for doc in result['docs']]
-    highlight, = [highlight['article'] for highlight in highlights if highlight.get('article')]
-    assert highlight == ['<strong>1</strong>']
-    result = resource.search(q='text:"section 1"', hl='amendment,article', **{'hl.enable': 'fields'})
-    highlights = [doc['__highlights__'] for doc in result['docs']]
-    assert all(highlight and not any(highlight.values()) for highlight in highlights)
-    result = resource.search(q='text:"section 1"', hl='amendment,article', **{'hl.enable': ['fields', 'terms']})
-    highlights = [doc['__highlights__'] for doc in result['docs']]
-    highlight, = [highlight['article'] for highlight in highlights if highlight.get('article')]
-    assert highlight == ['<strong>1</strong>']
+    assert all(highlight['amendment'] or highlight['article'] for highlight in highlights)
     result = resource.search(mlt=0)
     assert result['count'] == 25 and set(result['query'].split()) == {'text:united', 'text:states'}
     result = resource.search(q='amendment:2', mlt=0, **{'mlt.fields': 'text', 'mlt.minTermFreq': 1, 'mlt.minWordLen': 6})
