Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
branch: master
Fetching contributors…

Octocat-spinner-32-eaf2f5

Cannot retrieve contributors at this time

file 354 lines (274 sloc) 13.298 kb
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353
"""
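Helpers for annotating, aggregating and filtering querysets across
GenericForeignKey relations.

Django does not properly set up CASTs when a model's primary key and a
generic object id column have different types, so each helper falls back
to hand-written SQL that casts the object id column when needed.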
"""

import django
from django.contrib.contenttypes.generic import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.db import connection, models
from django.db.models.query import QuerySet


def get_gfk_field(model):
    """Return the first GenericForeignKey declared on ``model``.

    :param model: a Django model class
    :returns: the first entry of ``model._meta.virtual_fields`` that is a
        GenericForeignKey
    :raises ValueError: if ``model`` has no GenericForeignKey
    """
    candidates = (
        virtual for virtual in model._meta.virtual_fields
        if isinstance(virtual, GenericForeignKey)
    )
    found = next(candidates, None)
    if found is None:
        raise ValueError('Unable to find gfk field on %s' % model)
    return found

def normalize_qs_model(qs_or_model):
    """Coerce a model class or a QuerySet into a QuerySet.

    A QuerySet is passed through unchanged; anything else is treated as a
    model class and replaced by ``_default_manager.all()``.
    """
    if not isinstance(qs_or_model, QuerySet):
        qs_or_model = qs_or_model._default_manager.all()
    return qs_or_model

def get_field_type(f):
    """Return a normalized database column type for the model field ``f``.

    Every integer-like column type is collapsed to ``'integer'``. This covers
    the PostgreSQL ``serial`` pseudo-types, ``bigint``/``smallint``, and types
    whose first word is ``unsigned``. As a result an auto-incrementing primary
    key and an integer ``object_id`` column compare equal. Any other type
    string is returned unchanged.

    :param f: a Django model field
    :returns: the normalized type string
    """
    # before Django 1.4 (per this check) db_type() is called without a connection
    if django.VERSION < (1, 4):
        raw_type = f.db_type()
    else:
        raw_type = f.db_type(connection)

    # 'bigserial'/'smallserial' must be included. Otherwise a PostgreSQL
    # BigAutoField pk never matches an integer object_id, and the fallback
    # path emits CAST(... AS bigserial), which is not a castable type.
    integer_types = (
        'serial', 'smallserial', 'bigserial',
        'integer', 'unsigned', 'bigint', 'smallint',
    )
    words = raw_type.lower().split()
    if words and words[0] in integer_types:
        raw_type = 'integer'
    return raw_type

def prepare_query(qs_model, generic_qs_model, aggregator, gfk_field):
    """Build the filtered queryset used by the fast (pure ORM) code path.

    The first ``__``-separated segment of ``aggregator.lookup`` is taken as the
    name of a GenericRelation on ``qs_model``'s model. The queryset is
    restricted to rows whose related generic objects have ``qs_model``'s
    content type and are contained in ``generic_qs_model``.

    :returns: the filtered QuerySet, or ``False`` when the caller must use the
        fallback SQL path. That happens when the model has no attribute named
        after the relation, or when the pk and object-id column types differ.
    :raises AttributeError: if the GenericRelation targets a model other than
        ``generic_qs_model``'s model
    """
    base_qs = normalize_qs_model(qs_model)
    target_qs = normalize_qs_model(generic_qs_model)
    base_model, target_model = base_qs.model, target_qs.model

    if gfk_field is None:
        gfk_field = get_gfk_field(target_model)

    ctype = ContentType.objects.get_for_model(base_model)
    relation = aggregator.lookup.partition('__')[0]

    try:
        descriptor = getattr(base_model, relation)
    except AttributeError:
        # no GenericRelation with that name -> caller falls back to raw SQL
        return False

    related_model = descriptor.field.rel.to
    if related_model != target_model:
        raise AttributeError('Model %s does not match the GenericRelation "%s" (%s)' % (
            target_model, relation, related_model,
        ))

    pk_type = get_field_type(base_model._meta.pk)
    object_id_type = get_field_type(target_model._meta.get_field(gfk_field.fk_field))
    if pk_type != object_id_type:
        # the ORM join would need a CAST it does not emit
        return False

    lookups = {}
    lookups['%s__%s' % (relation, gfk_field.ct_field)] = ctype
    lookups['%s__pk__in' % relation] = target_qs.values('pk')
    return base_qs.filter(**lookups)

def generic_annotate(qs_model, generic_qs_model, aggregator, gfk_field=None, alias='score'):
    """
    Annotate every object of ``qs_model`` with an aggregate computed over the
    objects of ``generic_qs_model`` that point at it through a GenericForeignKey.

    Find blog entries with the most comments::

        qs = generic_annotate(Entry.objects.public(), Comment.objects.public(), Count('comments__id'))
        for entry in qs:
            print entry.title, entry.score

    Find the highest rated foods::

        qs = generic_annotate(Food, Rating, Avg('ratings__rating'), alias='avg')
        for food in qs:
            print food.name, '- average rating:', food.avg

    .. note::

        In both of the above examples it is assumed that a GenericRelation exists
        on Entry to Comment (named "comments") and also on Food to Rating (named "ratings").
        If a GenericRelation does *not* exist, the query will still return correct
        results but the code path will be different as it will use the fallback method.

    .. warning::

        If the underlying column type differs between the qs_model's primary
        key and the generic_qs_model's foreign key column, it will use the fallback
        method, which can correctly CAST itself.

    :param qs_model: A model or a queryset of objects you want to perform
        annotation on, e.g. blog entries
    :param generic_qs_model: A model or queryset containing a GFK, e.g. comments
    :param aggregator: an aggregation from django.db.models, e.g.
        Count('comments__id') or Avg('ratings__rating'). The first lookup
        segment names the GenericRelation used by the fast path.
    :param gfk_field: explicitly specify the field w/the gfk
    :param alias: attribute name to use for annotation
    :returns: a QuerySet of ``qs_model`` objects, each carrying the aggregate
        value as the attribute named by ``alias``
    """
    prepared_query = prepare_query(qs_model, generic_qs_model, aggregator, gfk_field)
    if prepared_query is False:
        # No usable GenericRelation, or the pk/object_id types differ: use the
        # extra()-based SQL that can CAST the object id column.
        return fallback_generic_annotate(qs_model, generic_qs_model, aggregator, gfk_field, alias)
    return prepared_query.annotate(**{alias: aggregator})


def generic_aggregate(qs_model, generic_qs_model, aggregator, gfk_field=None):
    """
    Compute a single aggregate over the objects of ``generic_qs_model`` that
    point, through a GenericForeignKey, at objects of ``qs_model``.

    Find total number of comments on blog entries::

        generic_aggregate(Entry.objects.public(), Comment.objects.public(), Count('comments__id'))

    Find the average rating for foods starting with 'a'::

        a_foods = Food.objects.filter(name__startswith='a')
        generic_aggregate(a_foods, Rating, Avg('ratings__rating'))

    .. note::

        In both of the above examples it is assumed that a GenericRelation exists
        on Entry to Comment (named "comments") and also on Food to Rating (named "ratings").
        If a GenericRelation does *not* exist, the query will still return correct
        results but the code path will be different as it will use the fallback method.

    .. warning::

        If the underlying column type differs between the qs_model's primary
        key and the generic_qs_model's foreign key column, it will use the fallback
        method, which can correctly CAST itself.

    :param qs_model: A model or a queryset of objects you want to aggregate
        over, e.g. blog entries
    :param generic_qs_model: A model or queryset containing a GFK, e.g. comments
    :param aggregator: an aggregation from django.db.models, e.g.
        Count('comments__id') or Avg('ratings__rating'). The first lookup
        segment names the GenericRelation used by the fast path.
    :param gfk_field: explicitly specify the field w/the gfk
    :returns: the single aggregate value, i.e. the first column of the result
    """
    prepared_query = prepare_query(qs_model, generic_qs_model, aggregator, gfk_field)
    if prepared_query is False:
        # No usable GenericRelation, or the pk/object_id types differ: run the
        # hand-written SQL that can CAST the object id column.
        return fallback_generic_aggregate(qs_model, generic_qs_model, aggregator, gfk_field)
    return prepared_query.aggregate(aggregate=aggregator)['aggregate']


def generic_filter(generic_qs_model, filter_qs_model, gfk_field=None):
    """
    Restrict ``generic_qs_model`` to objects whose GenericForeignKey points at
    an object in ``filter_qs_model``.

    Only show me ratings made on foods that start with "a"::

        a_foods = Food.objects.filter(name__startswith='a')
        generic_filter(Rating.objects.all(), a_foods)

    Only show me comments from entries that are marked as public::

        generic_filter(Comment.objects.public(), Entry.objects.public())

    :param generic_qs_model: A model or queryset containing a GFK, e.g. comments
    :param filter_qs_model: A model or a queryset of objects you want to
        restrict the generic_qs to
    :param gfk_field: explicitly specify the field w/the gfk
    :returns: a QuerySet of ``generic_qs_model`` objects
    """
    generic_qs = normalize_qs_model(generic_qs_model)
    filter_qs = normalize_qs_model(filter_qs_model)

    if not gfk_field:
        gfk_field = get_gfk_field(generic_qs.model)

    pk_field_type = get_field_type(filter_qs.model._meta.pk)
    gfk_field_type = get_field_type(generic_qs.model._meta.get_field(gfk_field.fk_field))
    if pk_field_type != gfk_field_type:
        # the ORM subquery below would compare mismatched column types; the
        # fallback builds the comparison with an explicit CAST instead
        return fallback_generic_filter(generic_qs, filter_qs, gfk_field)

    return generic_qs.filter(**{
        gfk_field.ct_field: ContentType.objects.get_for_model(filter_qs.model),
        '%s__in' % gfk_field.fk_field: filter_qs.values('pk'),
    })


###############################################################################
# fallback methods

def query_as_sql(query):
    """Compile ``query`` into an ``(sql, params)`` pair for the default connection.

    Before Django 1.2 the Query object compiles itself; later releases go
    through a compiler obtained from the query.
    """
    if django.VERSION >= (1, 2):
        compiler = query.get_compiler(connection=connection)
        return compiler.as_sql()
    return query.as_sql()

def query_as_nested_sql(query):
    """Compile ``query`` into an ``(sql, params)`` pair usable as a subquery.

    Before Django 1.2 the Query object compiles itself; later releases go
    through a compiler obtained from the query.
    """
    if django.VERSION >= (1, 2):
        compiler = query.get_compiler(connection=connection)
        return compiler.as_nested_sql()
    return query.as_nested_sql()

def gfk_expression(qs_model, gfk_field):
    """Return the SQL for the GFK's object id column, wrapped in a CAST if needed.

    When the normalized type of ``qs_model``'s primary key differs from the
    normalized type of the GFK's object id column, the column is rendered as
    ``CAST(<column> AS <pk type>)`` so it can be compared with the pk. On MySQL
    an ``'integer'`` pk type is rewritten to ``'unsigned'`` first.

    :param qs_model: the model whose primary key the object id is compared to
    :param gfk_field: the GenericForeignKey field
    :returns: a quoted, unqualified column reference, optionally CAST
    """
    qn = connection.ops.quote_name

    object_id_field = gfk_field.model._meta.get_field(gfk_field.fk_field)
    # Quote the real database column. It differs from the field name when the
    # object id field declares db_column.
    object_id_column = qn(object_id_field.column)

    pk_field_type = get_field_type(qs_model._meta.pk)
    gfk_field_type = get_field_type(object_id_field)
    if 'mysql' in connection.settings_dict['ENGINE'] and pk_field_type == 'integer':
        # presumably because MySQL's CAST() takes SIGNED/UNSIGNED rather than INTEGER
        pk_field_type = 'unsigned'

    if pk_field_type != gfk_field_type:
        # cast the gfk to the pk type
        return "CAST(%s AS %s)" % (object_id_column, pk_field_type)
    return object_id_column

def fallback_generic_annotate(qs_model, generic_qs_model, aggregator, gfk_field=None, alias='score'):
    """SQL-subquery implementation of ``generic_annotate``.

    Adds an ``extra()`` select column named ``alias``. The column holds
    ``COALESCE(<aggregate>(<field>), 0)`` over the rows of the GFK's table that
    have ``qs_model``'s content type and whose object id (CAST when the column
    types differ) equals the outer row's primary key. If ``generic_qs_model``
    carries filters, only its primary keys are considered.

    :returns: ``qs_model`` as a QuerySet with the extra ``alias`` column
    """
    qs = normalize_qs_model(qs_model)
    generic_qs = normalize_qs_model(generic_qs_model)

    content_type = ContentType.objects.get_for_model(qs.model)

    qn = connection.ops.quote_name
    aggregate_field = aggregator.lookup

    # since the aggregate may contain a generic relation, strip it
    if '__' in aggregate_field:
        _, aggregate_field = aggregate_field.rsplit('__', 1)

    if gfk_field is None:
        gfk_field = get_gfk_field(generic_qs.model)

    # use real column names, not field names: they differ for db_column and
    # for multi-table-inheritance primary keys (e.g. "entry_ptr" -> "entry_ptr_id")
    ct_column = gfk_field.model._meta.get_field(gfk_field.ct_field).column
    pk_column = qs.model._meta.pk.column

    # collect the params we'll be using
    params = (
        aggregator.name,  # the function that's doing the aggregation
        qn(aggregate_field),  # the field containing the value to aggregate
        qn(gfk_field.model._meta.db_table),  # table holding gfk'd item info
        qn(ct_column),  # the content_type column on the GFK
        content_type.pk,  # the content_type id we need to match
        gfk_expression(qs.model, gfk_field),
        qn(qs.model._meta.db_table),  # the table and pk from the main
        qn(pk_column),  # part of the query
    )

    sql_template = """
SELECT COALESCE(%s(%s), 0) AS aggregate_score
FROM %s
WHERE
%s=%s AND
%s=%s.%s"""

    extra = sql_template % params

    if generic_qs.query.where.children:
        generic_query = generic_qs.values_list('pk').query
        inner_query, inner_query_params = query_as_sql(generic_query)

        inner_params = (
            qn(generic_qs.model._meta.db_table),
            qn(generic_qs.model._meta.pk.column),
        )
        # inner_query is appended verbatim: its placeholders are bound via
        # select_params below
        extra = extra + (' AND %s.%s IN (' % inner_params) + inner_query + ')'
    else:
        inner_query_params = []

    return qs.extra(
        select={alias: extra},
        select_params=inner_query_params,
    )

def fallback_generic_aggregate(qs_model, generic_qs_model, aggregator, gfk_field=None):
    """Raw-SQL implementation of ``generic_aggregate``.

    Runs ``SELECT <aggregate>(<field>) FROM <gfk table>`` directly on a
    cursor. The query is restricted to rows with ``qs_model``'s content type
    whose object id (CAST when the column types differ) is among the primary
    keys selected by ``qs_model``. If ``generic_qs_model`` carries filters, the
    rows are further restricted to its primary keys.

    :returns: the first column of the single result row
    """
    qs = normalize_qs_model(qs_model)
    generic_qs = normalize_qs_model(generic_qs_model)

    content_type = ContentType.objects.get_for_model(qs.model)

    qn = connection.ops.quote_name
    aggregate_field = aggregator.lookup

    # since the aggregate may contain a generic relation, strip it
    if '__' in aggregate_field:
        _, aggregate_field = aggregate_field.rsplit('__', 1)

    if gfk_field is None:
        gfk_field = get_gfk_field(generic_qs.model)

    qs = qs.values_list('pk')  # just the pks
    query, query_params = query_as_nested_sql(qs.query)
    # normalize to a tuple so the concatenation below cannot mix list and tuple
    query_params = tuple(query_params)

    # use the real column name of the content type FK (honours db_column)
    ct_column = gfk_field.model._meta.get_field(gfk_field.ct_field).column

    # collect the params we'll be using
    params = (
        aggregator.name,  # the function that's doing the aggregation
        qn(aggregate_field),  # the field containing the value to aggregate
        qn(gfk_field.model._meta.db_table),  # table holding gfk'd item info
        qn(ct_column),  # the content_type column on the GFK
        content_type.pk,  # the content_type id we need to match
        gfk_expression(qs.model, gfk_field),  # the object_id field on the GFK
    )

    query_start = """
SELECT %s(%s) AS aggregate_score
FROM %s
WHERE
%s=%s AND
%s IN (
""" % params

    query_end = ")"

    if generic_qs.query.where.children:
        generic_query = generic_qs.values_list('pk').query
        inner_query, inner_query_params = query_as_sql(generic_query)

        query_params += tuple(inner_query_params)

        # pk column, not pk name: they differ for db_column / parent_ptr pks
        inner_start = ' AND %s IN (' % qn(generic_qs.model._meta.pk.column)
        query_end = query_end + inner_start + inner_query + ')'

    # pass in the inner_query unmodified as we will use the cursor to handle
    # quoting the inner parameters correctly
    query = query_start + query + query_end

    cursor = connection.cursor()
    try:
        cursor.execute(query, query_params)
        row = cursor.fetchone()
    finally:
        # release the cursor even if execution fails
        cursor.close()

    return row[0]

def fallback_generic_filter(generic_qs_model, filter_qs_model, gfk_field=None):
    """Raw-SQL implementation of ``generic_filter``.

    Keeps rows of ``generic_qs_model`` whose content type is that of
    ``filter_qs_model``'s model. Their object id (CAST when the column types
    differ) must also appear in the primary keys selected by
    ``filter_qs_model``. The pk condition is added through ``extra(where=...)``.

    :returns: a QuerySet of ``generic_qs_model`` objects
    """
    generic_qs = normalize_qs_model(generic_qs_model)
    filter_qs = normalize_qs_model(filter_qs_model)
    gfk = gfk_field if gfk_field is not None else get_gfk_field(generic_qs.model)

    # the content type of the objects we filter by, e.g. Business
    target_model = filter_qs.model
    ctype = ContentType.objects.get_for_model(target_model)

    restricted = generic_qs.filter(**{gfk.ct_field: ctype})

    # sub-select only the primary keys; its params are bound by extra()
    pk_sql, pk_params = query_as_sql(filter_qs.values_list('pk').query)
    clause = '%s IN (%s)' % (gfk_expression(target_model, gfk), pk_sql)

    return restricted.extra(where=(clause,), params=pk_params)
Something went wrong with that request. Please try again.