Skip to content

Commit

Permalink
Better handling of .models. Thanks to zbyte64 for the report & Ho…
Browse files Browse the repository at this point in the history
…nzaKral for the original patch!
  • Loading branch information
toastdriven committed Jan 24, 2012
1 parent d374d34 commit 33e3217
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 64 deletions.
25 changes: 10 additions & 15 deletions haystack/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def clear(self, models=[], commit=True):
def search(self, query_string, sort_by=None, start_offset=0, end_offset=None,
fields='', highlight=False, facets=None, date_facets=None, query_facets=None,
narrow_queries=None, spelling_query=None, within=None,
dwithin=None, distance_point=None,
dwithin=None, distance_point=None, models=None,
limit_to_registered_models=None, result_class=None, **kwargs):
"""
Takes a query to search on and returns dictionary.
Expand Down Expand Up @@ -380,6 +380,9 @@ def build_params(self, spelling_query=None):
if self.fields:
kwargs['fields'] = self.fields

if self.models:
kwargs['models'] = self.models

return kwargs

def run(self, spelling_query=None, **kwargs):
Expand Down Expand Up @@ -408,6 +411,9 @@ def run_mlt(self, **kwargs):
'result_class': self.result_class,
}

if self.models:
search_kwargs['models'] = self.models

if kwargs:
search_kwargs.update(kwargs)

Expand Down Expand Up @@ -510,22 +516,11 @@ def build_query(self):
Interprets the collected query metadata and builds the final query to
be sent to the backend.
"""
query = self.query_filter.as_query_string(self.build_query_fragment)
final_query = self.query_filter.as_query_string(self.build_query_fragment)

if not query:
if not final_query:
# Match all.
query = self.matching_all_fragment()

if len(self.models):
models = sorted(['%s:%s.%s' % (DJANGO_CT, model._meta.app_label, model._meta.module_name) for model in self.models])
models_clause = ' OR '.join(models)

if query != self.matching_all_fragment():
final_query = '(%s) AND (%s)' % (query, models_clause)
else:
final_query = models_clause
else:
final_query = query
final_query = self.matching_all_fragment()

if self.boost:
boost_list = []
Expand Down
22 changes: 16 additions & 6 deletions haystack/backends/elasticsearch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def clear(self, models=[], commit=True):
def search(self, query_string, sort_by=None, start_offset=0, end_offset=None,
fields='', highlight=False, facets=None, date_facets=None, query_facets=None,
narrow_queries=None, spelling_query=None, within=None,
dwithin=None, distance_point=None,
dwithin=None, distance_point=None, models=None,
limit_to_registered_models=None, result_class=None, **kwargs):
if len(query_string) == 0:
return {
Expand Down Expand Up @@ -374,13 +374,20 @@ def search(self, query_string, sort_by=None, start_offset=0, end_offset=None,
if limit_to_registered_models is None:
limit_to_registered_models = getattr(settings, 'HAYSTACK_LIMIT_TO_REGISTERED_MODELS', True)

if limit_to_registered_models:
if models and len(models):
model_choices = sorted(['%s.%s' % (model._meta.app_label, model._meta.module_name) for model in models])
elif limit_to_registered_models:
# Using narrow queries, limit the results to only models handled
# with the current routers.
registered_models = self.build_models_list()
model_choices = self.build_models_list()
else:
model_choices = []

if len(model_choices) > 0:
if narrow_queries is None:
narrow_queries = set()

if len(registered_models) > 0:
narrow_queries.add('%s:(%s)' % (DJANGO_CT, ' OR '.join(registered_models)))
narrow_queries.add('%s:(%s)' % (DJANGO_CT, ' OR '.join(model_choices)))

if narrow_queries:
kwargs['query'].setdefault('filtered', {})
Expand Down Expand Up @@ -456,7 +463,7 @@ def search(self, query_string, sort_by=None, start_offset=0, end_offset=None,
return self._process_results(raw_results, highlight=highlight, result_class=result_class)

def more_like_this(self, model_instance, additional_query_string=None,
start_offset=0, end_offset=None,
start_offset=0, end_offset=None, models=None,
limit_to_registered_models=None, result_class=None, **kwargs):
from haystack import connections

Expand Down Expand Up @@ -779,6 +786,9 @@ def run(self, spelling_query=None, **kwargs):
if self.fields:
search_kwargs['fields'] = self.fields

if self.models:
search_kwargs['models'] = self.models

if spelling_query:
search_kwargs['spelling_query'] = spelling_query

Expand Down
2 changes: 1 addition & 1 deletion haystack/backends/simple_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def clear(self, models=[], commit=True):
def search(self, query_string, sort_by=None, start_offset=0, end_offset=None,
fields='', highlight=False, facets=None, date_facets=None, query_facets=None,
narrow_queries=None, spelling_query=None, within=None,
dwithin=None, distance_point=None,
dwithin=None, distance_point=None, models=None,
limit_to_registered_models=None, result_class=None, **kwargs):
hits = 0
results = []
Expand Down
35 changes: 23 additions & 12 deletions haystack/backends/solr_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def clear(self, models=[], commit=True):
def search(self, query_string, sort_by=None, start_offset=0, end_offset=None,
fields='', highlight=False, facets=None, date_facets=None, query_facets=None,
narrow_queries=None, spelling_query=None, within=None,
dwithin=None, distance_point=None,
dwithin=None, distance_point=None, models=None,
limit_to_registered_models=None, result_class=None, **kwargs):
if len(query_string) == 0:
return {
Expand Down Expand Up @@ -192,16 +192,20 @@ def search(self, query_string, sort_by=None, start_offset=0, end_offset=None,
if limit_to_registered_models is None:
limit_to_registered_models = getattr(settings, 'HAYSTACK_LIMIT_TO_REGISTERED_MODELS', True)

if limit_to_registered_models:
if models and len(models):
model_choices = sorted(['%s.%s' % (model._meta.app_label, model._meta.module_name) for model in models])
elif limit_to_registered_models:
# Using narrow queries, limit the results to only models handled
# with the current routers.
model_choices = self.build_models_list()
else:
model_choices = []

if len(model_choices) > 0:
if narrow_queries is None:
narrow_queries = set()

registered_models = self.build_models_list()

if len(registered_models) > 0:
narrow_queries.add('%s:(%s)' % (DJANGO_CT, ' OR '.join(registered_models)))
narrow_queries.add('%s:(%s)' % (DJANGO_CT, ' OR '.join(model_choices)))

if narrow_queries is not None:
kwargs['fq'] = list(narrow_queries)
Expand Down Expand Up @@ -245,7 +249,7 @@ def search(self, query_string, sort_by=None, start_offset=0, end_offset=None,
return self._process_results(raw_results, highlight=highlight, result_class=result_class, distance_point=distance_point)

def more_like_this(self, model_instance, additional_query_string=None,
start_offset=0, end_offset=None,
start_offset=0, end_offset=None, models=None,
limit_to_registered_models=None, result_class=None, **kwargs):
from haystack import connections

Expand All @@ -272,16 +276,20 @@ def more_like_this(self, model_instance, additional_query_string=None,
if limit_to_registered_models is None:
limit_to_registered_models = getattr(settings, 'HAYSTACK_LIMIT_TO_REGISTERED_MODELS', True)

if limit_to_registered_models:
if models and len(models):
model_choices = sorted(['%s.%s' % (model._meta.app_label, model._meta.module_name) for model in models])
elif limit_to_registered_models:
# Using narrow queries, limit the results to only models handled
# with the current routers.
model_choices = self.build_models_list()
else:
model_choices = []

if len(model_choices) > 0:
if narrow_queries is None:
narrow_queries = set()

registered_models = self.build_models_list()

if len(registered_models) > 0:
narrow_queries.add('%s:(%s)' % (DJANGO_CT, ' OR '.join(registered_models)))
narrow_queries.add('%s:(%s)' % (DJANGO_CT, ' OR '.join(model_choices)))

if additional_query_string:
narrow_queries.add(additional_query_string)
Expand Down Expand Up @@ -630,6 +638,9 @@ def run(self, spelling_query=None, **kwargs):
if self.fields:
search_kwargs['fields'] = self.fields

if self.models:
search_kwargs['models'] = self.models

if spelling_query:
search_kwargs['spelling_query'] = spelling_query

Expand Down
52 changes: 36 additions & 16 deletions haystack/backends/whoosh_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def optimize(self):
def search(self, query_string, sort_by=None, start_offset=0, end_offset=None,
fields='', highlight=False, facets=None, date_facets=None, query_facets=None,
narrow_queries=None, spelling_query=None, within=None,
dwithin=None, distance_point=None,
dwithin=None, distance_point=None, models=None,
limit_to_registered_models=None, result_class=None, **kwargs):
if not self.setup_complete:
self.setup()
Expand Down Expand Up @@ -324,16 +324,20 @@ def search(self, query_string, sort_by=None, start_offset=0, end_offset=None,
if limit_to_registered_models is None:
limit_to_registered_models = getattr(settings, 'HAYSTACK_LIMIT_TO_REGISTERED_MODELS', True)

if limit_to_registered_models:
if models and len(models):
model_choices = sorted(['%s.%s' % (model._meta.app_label, model._meta.module_name) for model in models])
elif limit_to_registered_models:
# Using narrow queries, limit the results to only models handled
# with the current routers.
model_choices = self.build_models_list()
else:
model_choices = []

if len(model_choices) > 0:
if narrow_queries is None:
narrow_queries = set()

registered_models = self.build_models_list()

if len(registered_models) > 0:
narrow_queries.add(' OR '.join(['%s:%s' % (DJANGO_CT, rm) for rm in registered_models]))
narrow_queries.add(' OR '.join(['%s:%s' % (DJANGO_CT, rm) for rm in model_choices]))

narrow_searcher = None

Expand All @@ -344,6 +348,12 @@ def search(self, query_string, sort_by=None, start_offset=0, end_offset=None,
for nq in narrow_queries:
recent_narrowed_results = narrow_searcher.search(self.parser.parse(force_unicode(nq)))

if len(recent_narrowed_results) <= 0:
return {
'results': [],
'hits': 0,
}

if narrowed_results:
narrowed_results.filter(recent_narrowed_results)
else:
Expand All @@ -370,7 +380,7 @@ def search(self, query_string, sort_by=None, start_offset=0, end_offset=None,
raw_results = searcher.search(parsed_query, limit=end_offset, sortedby=sort_by, reverse=reverse)

# Handle the case where the results have been narrowed.
if narrowed_results:
if narrowed_results is not None:
raw_results.filter(narrowed_results)

# Determine the page.
Expand Down Expand Up @@ -425,7 +435,7 @@ def search(self, query_string, sort_by=None, start_offset=0, end_offset=None,
}

def more_like_this(self, model_instance, additional_query_string=None,
start_offset=0, end_offset=None,
start_offset=0, end_offset=None, models=None,
limit_to_registered_models=None, result_class=None, **kwargs):
if not self.setup_complete:
self.setup()
Expand All @@ -444,16 +454,20 @@ def more_like_this(self, model_instance, additional_query_string=None,
if limit_to_registered_models is None:
limit_to_registered_models = getattr(settings, 'HAYSTACK_LIMIT_TO_REGISTERED_MODELS', True)

if limit_to_registered_models:
# Using narrow queries, limit the results to only models registered
# with the current site.
if models and len(models):
model_choices = sorted(['%s.%s' % (model._meta.app_label, model._meta.module_name) for model in models])
elif limit_to_registered_models:
# Using narrow queries, limit the results to only models handled
# with the current routers.
model_choices = self.build_models_list()
else:
model_choices = []

if len(model_choices) > 0:
if narrow_queries is None:
narrow_queries = set()

registered_models = self.build_models_list()

if len(registered_models) > 0:
narrow_queries.add(' OR '.join(['%s:%s' % (DJANGO_CT, rm) for rm in registered_models]))
narrow_queries.add(' OR '.join(['%s:%s' % (DJANGO_CT, rm) for rm in model_choices]))

if additional_query_string and additional_query_string != '*':
narrow_queries.add(additional_query_string)
Expand All @@ -467,6 +481,12 @@ def more_like_this(self, model_instance, additional_query_string=None,
for nq in narrow_queries:
recent_narrowed_results = narrow_searcher.search(self.parser.parse(force_unicode(nq)))

if len(recent_narrowed_results) <= 0:
return {
'results': [],
'hits': 0,
}

if narrowed_results:
narrowed_results.filter(recent_narrowed_results)
else:
Expand Down Expand Up @@ -507,7 +527,7 @@ def more_like_this(self, model_instance, additional_query_string=None,
raw_results = results[0].more_like_this(field_name, top=end_offset)

# Handle the case where the results have been narrowed.
if narrowed_results and hasattr(raw_results, 'filter'):
if narrowed_results is not None and hasattr(raw_results, 'filter'):
raw_results.filter(narrowed_results)

try:
Expand Down
4 changes: 2 additions & 2 deletions tests/elasticsearch_tests/tests/elasticsearch_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ def test_clean(self):
def test_build_query_with_models(self):
self.sq.add_filter(SQ(content='hello'))
self.sq.add_model(MockModel)
self.assertEqual(self.sq.build_query(), '(hello) AND (django_ct:core.mockmodel)')
self.assertEqual(self.sq.build_query(), 'hello')

self.sq.add_model(AnotherMockModel)
self.assertEqual(self.sq.build_query(), u'(hello) AND (django_ct:core.anothermockmodel OR django_ct:core.mockmodel)')
self.assertEqual(self.sq.build_query(), u'hello')

def test_set_result_class(self):
# Assert that we're defaulting to ``SearchResult``.
Expand Down
4 changes: 2 additions & 2 deletions tests/overrides/tests/altered_internal_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ def test_altered_names(self):

sq.add_filter(SQ(content='hello'))
sq.add_model(MockModel)
self.assertEqual(sq.build_query(), u'(hello) AND (my_django_ct:core.mockmodel)')
self.assertEqual(sq.build_query(), u'hello')

sq.add_model(AnotherMockModel)
self.assertEqual(sq.build_query(), u'(hello) AND (my_django_ct:core.anothermockmodel OR my_django_ct:core.mockmodel)')
self.assertEqual(sq.build_query(), u'hello')

def test_solr_schema(self):
command = Command()
Expand Down
4 changes: 2 additions & 2 deletions tests/solr_tests/tests/solr_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,10 @@ def test_clean(self):
def test_build_query_with_models(self):
self.sq.add_filter(SQ(content='hello'))
self.sq.add_model(MockModel)
self.assertEqual(self.sq.build_query(), '(hello) AND (django_ct:core.mockmodel)')
self.assertEqual(self.sq.build_query(), 'hello')

self.sq.add_model(AnotherMockModel)
self.assertEqual(self.sq.build_query(), u'(hello) AND (django_ct:core.anothermockmodel OR django_ct:core.mockmodel)')
self.assertEqual(self.sq.build_query(), u'hello')

def test_set_result_class(self):
# Assert that we're defaulting to ``SearchResult``.
Expand Down
10 changes: 4 additions & 6 deletions tests/whoosh_tests/tests/whoosh_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,15 +474,13 @@ def test_boost(self):
self.sb.update(self.wmmi, self.sample_objs)
self.raw_whoosh = self.raw_whoosh.refresh()
searcher = self.raw_whoosh.searcher()
self.assertEqual(len(searcher.search(self.parser.parse(u'*'), limit=1000)), 4)
self.assertEqual(len(searcher.search(self.parser.parse(u'*'), limit=1000)), 2)

results = SearchQuerySet().filter(SQ(author='daniel') | SQ(editor='daniel'))

self.assertEqual([result.id for result in results], [
'core.afourthmockmodel.1',
'core.afourthmockmodel.3',
'core.afourthmockmodel.2',
'core.afourthmockmodel.4'
])
self.assertEqual(results[0].boost, 1.1)

Expand Down Expand Up @@ -648,7 +646,7 @@ def test_various_searchquerysets(self):
self.assertEqual(len(sqs), 0)

sqs = self.sqs.models(MockModel)
self.assertEqual(sqs.query.build_query(), u'django_ct:core.mockmodel')
self.assertEqual(sqs.query.build_query(), u'*')
self.assertEqual(len(sqs), 3)

def test_all_regression(self):
Expand Down Expand Up @@ -813,11 +811,11 @@ def test_searchquerysets_with_models(self):
self.assertEqual(len(sqs), 25)

sqs = self.sqs.models(MockModel)
self.assertEqual(sqs.query.build_query(), u'django_ct:core.mockmodel')
self.assertEqual(sqs.query.build_query(), u'*')
self.assertEqual(len(sqs), 23)

sqs = self.sqs.models(AnotherMockModel)
self.assertEqual(sqs.query.build_query(), u'django_ct:core.anothermockmodel')
self.assertEqual(sqs.query.build_query(), u'*')
self.assertEqual(len(sqs), 2)


Expand Down
Loading

0 comments on commit 33e3217

Please sign in to comment.