Skip to content

Commit

Permalink
Supports RawSQL.
Browse files Browse the repository at this point in the history
  • Loading branch information
BertrandBordage committed Jun 3, 2017
1 parent d0ee580 commit 4522106
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions cachalot/utils.py
Expand Up @@ -9,8 +9,10 @@

from django import VERSION as django_version
from django.db import connections
from django.db.models.expressions import RawSQL
from django.db.models.sql import Query
from django.db.models.sql.where import ExtraWhere, SubqueryConstraint
from django.db.models.sql.where import (
ExtraWhere, SubqueryConstraint, WhereNode)
from django.utils.module_loading import import_string
from django.utils.six import text_type, binary_type

Expand All @@ -22,6 +24,10 @@ class UncachableQuery(Exception):
pass


class IsRawQuery(Exception):
pass


TUPLE_OR_LIST = {tuple, list}

CACHABLE_PARAM_TYPES = {
Expand Down Expand Up @@ -112,25 +118,30 @@ def _get_tables_from_sql(connection, lowercased_sql):

def _find_subqueries(children):
for child in children:
if child.__class__ is SubqueryConstraint:
child_class = child.__class__
if child_class is WhereNode:
for grand_child in _find_subqueries(child.children):
yield grand_child
elif child_class is SubqueryConstraint:
if child.query_object.__class__ is Query:
yield child.query_object
else:
yield child.query_object.query
elif child_class is ExtraWhere:
raise IsRawQuery
else:
rhs = None
if hasattr(child, 'rhs'):
rhs = child.rhs
rhs_class = rhs.__class__
if rhs_class is RawSQL:
raise IsRawQuery
if rhs_class is Query:
yield rhs
elif hasattr(rhs, 'query'):
yield rhs.query
elif rhs_class in UNCACHABLE_FUNCS:
raise UncachableQuery
if hasattr(child, 'children'):
for grand_child in _find_subqueries(child.children):
yield grand_child


def is_cachable(table):
Expand Down Expand Up @@ -161,16 +172,16 @@ def _get_tables(db_alias, query):
and not cachalot_settings.CACHALOT_CACHE_RANDOM):
raise UncachableQuery

if query.extra_select or getattr(query, 'subquery', False) \
or any(c.__class__ is ExtraWhere for c in query.where.children):
sql = query.get_compiler(db_alias).as_sql()[0].lower()
tables = _get_tables_from_sql(connections[db_alias], sql)
else:
try:
if query.extra_select or getattr(query, 'subquery', False):
raise IsRawQuery
tables = set(query.table_map)
tables.add(query.get_meta().db_table)
subquery_constraints = _find_subqueries(query.where.children)
for subquery in subquery_constraints:
for subquery in _find_subqueries(query.where.children):
tables.update(_get_tables(db_alias, subquery))
except IsRawQuery:
sql = query.get_compiler(db_alias).as_sql()[0].lower()
tables = _get_tables_from_sql(connections[db_alias], sql)

if not are_all_cachable(tables):
raise UncachableQuery
Expand Down

0 comments on commit 4522106

Please sign in to comment.