Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

queryset-refactor: Fixed up extra(select=...) calls with parameters s…

…o that the

parameters are substituted in correctly in all cases. This introduces an extra
argument to extra() for this purpose; no alternative there.

Also fixed values() to work if you don't specify *all* the extra select aliases
in the values() call.

Refs #3141.


git-svn-id: http://code.djangoproject.com/svn/django/branches/queryset-refactor@7340 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information...
commit 04da22633fcda983cb9ee69e63b2ebe99301b717 1 parent e2dfad1
Malcolm Tredinnick authored March 20, 2008
27  django/db/models/query.py
@@ -142,16 +142,16 @@ def iterator(self):
142 142
         else:
143 143
             requested = None
144 144
         max_depth = self.query.max_depth
145  
-        index_end = len(self.model._meta.fields)
146 145
         extra_select = self.query.extra_select.keys()
  146
+        index_start = len(extra_select)
147 147
         for row in self.query.results_iter():
148 148
             if fill_cache:
149  
-                obj, index_end = get_cached_row(self.model, row, 0, max_depth,
150  
-                        requested=requested)
  149
+                obj, _ = get_cached_row(self.model, row, index_start,
  150
+                        max_depth, requested=requested)
151 151
             else:
152  
-                obj = self.model(*row[:index_end])
  152
+                obj = self.model(*row[index_start:])
153 153
             for i, k in enumerate(extra_select):
154  
-                setattr(obj, k, row[index_end + i])
  154
+                setattr(obj, k, row[i])
155 155
             yield obj
156 156
 
157 157
     def count(self):
@@ -413,14 +413,14 @@ def distinct(self, true_or_false=True):
413 413
         return obj
414 414
 
415 415
     def extra(self, select=None, where=None, params=None, tables=None,
416  
-            order_by=None):
  416
+            order_by=None, select_params=None):
417 417
         """
418 418
         Add extra SQL fragments to the query.
419 419
         """
420 420
         assert self.query.can_filter(), \
421 421
                 "Cannot change a query once a slice has been taken"
422 422
         clone = self._clone()
423  
-        clone.query.add_extra(select, where, params, tables, order_by)
  423
+        clone.query.add_extra(select, select_params, where, params, tables, order_by)
424 424
         return clone
425 425
 
426 426
     def reverse(self):
@@ -475,9 +475,10 @@ def __iter__(self):
475 475
         return self.iterator()
476 476
 
477 477
     def iterator(self):
478  
-        self.field_names.extend([f for f in self.query.extra_select.keys()])
  478
+        self.query.trim_extra_select(self.extra_names)
  479
+        names = self.query.extra_select.keys() + self.field_names
479 480
         for row in self.query.results_iter():
480  
-            yield dict(zip(self.field_names, row))
  481
+            yield dict(zip(names, row))
481 482
 
482 483
     def _setup_query(self):
483 484
         """
@@ -487,6 +488,7 @@ def _setup_query(self):
487 488
         Called by the _clone() method after initialising the rest of the
488 489
         instance.
489 490
         """
  491
+        self.extra_names = []
490 492
         if self._fields:
491 493
             if not self.query.extra_select:
492 494
                 field_names = list(self._fields)
@@ -496,7 +498,9 @@ def _setup_query(self):
496 498
                 for f in self._fields:
497 499
                     if f in names:
498 500
                         field_names.append(f)
499  
-                    elif not self.query.extra_select.has_key(f):
  501
+                    elif self.query.extra_select.has_key(f):
  502
+                        self.extra_names.append(f)
  503
+                    else:
500 504
                         raise FieldDoesNotExist('%s has no field named %r'
501 505
                                 % (self.model._meta.object_name, f))
502 506
         else:
@@ -513,7 +517,8 @@ def _clone(self, klass=None, setup=False, **kwargs):
513 517
         """
514 518
         c = super(ValuesQuerySet, self)._clone(klass, **kwargs)
515 519
         c._fields = self._fields[:]
516  
-        c.field_names = self.field_names[:]
  520
+        c.field_names = self.field_names
  521
+        c.extra_names = self.extra_names
517 522
         if setup and hasattr(c, '_setup_query'):
518 523
             c._setup_query()
519 524
         return c
40  django/db/models/sql/query.py
@@ -73,6 +73,7 @@ def __init__(self, model, connection, where=WhereNode):
73 73
         # These are for extensions. The contents are more or less appended
74 74
         # verbatim to the appropriate clause.
75 75
         self.extra_select = {}  # Maps col_alias -> col_sql.
  76
+        self.extra_select_params = ()
76 77
         self.extra_tables = ()
77 78
         self.extra_where = ()
78 79
         self.extra_params = ()
@@ -150,6 +151,7 @@ def clone(self, klass=None, **kwargs):
150 151
         obj.select_related = self.select_related
151 152
         obj.max_depth = self.max_depth
152 153
         obj.extra_select = self.extra_select.copy()
  154
+        obj.extra_select_params = self.extra_select_params
153 155
         obj.extra_tables = self.extra_tables
154 156
         obj.extra_where = self.extra_where
155 157
         obj.extra_params = self.extra_params
@@ -214,6 +216,7 @@ def as_sql(self, with_limits=True):
214 216
         # get_from_clause() for details.
215 217
         from_, f_params = self.get_from_clause()
216 218
         where, w_params = self.where.as_sql(qn=self.quote_name_unless_alias)
  219
+        params = list(self.extra_select_params)
217 220
 
218 221
         result = ['SELECT']
219 222
         if self.distinct:
@@ -222,7 +225,7 @@ def as_sql(self, with_limits=True):
222 225
 
223 226
         result.append('FROM')
224 227
         result.extend(from_)
225  
-        params = list(f_params)
  228
+        params.extend(f_params)
226 229
 
227 230
         if where:
228 231
             result.append('WHERE %s' % where)
@@ -351,8 +354,8 @@ def get_columns(self):
351 354
         the model.
352 355
         """
353 356
         qn = self.quote_name_unless_alias
354  
-        result = []
355  
-        aliases = []
  357
+        result = ['(%s) AS %s' % (col, alias) for alias, col in self.extra_select.items()]
  358
+        aliases = self.extra_select.keys()
356 359
         if self.select:
357 360
             for col in self.select:
358 361
                 if isinstance(col, (list, tuple)):
@@ -364,12 +367,9 @@ def get_columns(self):
364 367
                     if hasattr(col, 'alias'):
365 368
                         aliases.append(col.alias)
366 369
         elif self.default_cols:
367  
-            result = self.get_default_columns(True)
368  
-            aliases = result[:]
369  
-
370  
-        result.extend(['(%s) AS %s' % (col, alias)
371  
-                for alias, col in self.extra_select.items()])
372  
-        aliases.extend(self.extra_select.keys())
  370
+            cols = self.get_default_columns(True)
  371
+            result.extend(cols)
  372
+            aliases.extend(cols)
373 373
 
374 374
         self._select_aliases = set(aliases)
375 375
         return result
@@ -403,9 +403,9 @@ def get_default_columns(self, as_str=False):
403 403
     def get_from_clause(self):
404 404
         """
405 405
         Returns a list of strings that are joined together to go after the
406  
-        "FROM" part of the query, as well as any extra parameters that need to
407  
-        be included. Sub-classes, can override this to create a from-clause via
408  
-        a "select", for example (e.g. CountQuery).
  406
+        "FROM" part of the query, as well as a list any extra parameters that
  407
+        need to be included. Sub-classes, can override this to create a
  408
+        from-clause via a "select", for example (e.g. CountQuery).
409 409
 
410 410
         This should only be called after any SQL construction methods that
411 411
         might change the tables we need. This means the select columns and
@@ -1253,6 +1253,7 @@ def add_count_column(self):
1253 1253
             self.distinct = False
1254 1254
         self.select = [select]
1255 1255
         self.extra_select = {}
  1256
+        self.extra_select_params = ()
1256 1257
 
1257 1258
     def add_select_related(self, fields):
1258 1259
         """
@@ -1267,7 +1268,7 @@ def add_select_related(self, fields):
1267 1268
                 d = d.setdefault(part, {})
1268 1269
         self.select_related = field_dict
1269 1270
 
1270  
-    def add_extra(self, select, where, params, tables, order_by):
  1271
+    def add_extra(self, select, select_params, where, params, tables, order_by):
1271 1272
         """
1272 1273
         Adds data to the various extra_* attributes for user-created additions
1273 1274
         to the query.
@@ -1279,6 +1280,8 @@ def add_extra(self, select, where, params, tables, order_by):
1279 1280
                     not isinstance(self.extra_select, SortedDict)):
1280 1281
                 self.extra_select = SortedDict(self.extra_select)
1281 1282
             self.extra_select.update(select)
  1283
+        if select_params:
  1284
+            self.extra_select_params += tuple(select_params)
1282 1285
         if where:
1283 1286
             self.extra_where += tuple(where)
1284 1287
         if params:
@@ -1288,6 +1291,17 @@ def add_extra(self, select, where, params, tables, order_by):
1288 1291
         if order_by:
1289 1292
             self.extra_order_by = order_by
1290 1293
 
  1294
+    def trim_extra_select(self, names):
  1295
+        """
  1296
+        Removes any aliases in the extra_select dictionary that aren't in
  1297
+        'names'.
  1298
+
  1299
+        This is needed if we are selecting certain values that don't incldue
  1300
+        all of the extra_select names.
  1301
+        """
  1302
+        for key in set(self.extra_select).difference(set(names)):
  1303
+            del self.extra_select[key]
  1304
+
1291 1305
     def set_start(self, start):
1292 1306
         """
1293 1307
         Sets the table from which to start joining. The start position is
54  docs/db-api.txt
@@ -841,8 +841,9 @@ You can only refer to ``ForeignKey`` relations in the list of fields passed to
841 841
 list of fields and the ``depth`` parameter in the same ``select_related()``
842 842
 call, since they are conflicting options.
843 843
 
844  
-``extra(select=None, where=None, params=None, tables=None, order_by=None)``
845  
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  844
+``extra(select=None, where=None, params=None, tables=None, order_by=None,
  845
+select_params=None)``
  846
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
846 847
 
847 848
 Sometimes, the Django query syntax by itself can't easily express a complex
848 849
 ``WHERE`` clause. For these edge cases, Django provides the ``extra()``
@@ -901,31 +902,18 @@ of the arguments is required, but you should use at least one of them.
901 902
 
902 903
     **New in Django development version**
903 904
     In some rare cases, you might wish to pass parameters to the SQL fragments
904  
-    in ``extra(select=...)```. Since the ``params`` attribute is a sequence
905  
-    and the ``select`` attribute is a dictionary, some care is required so
906  
-    that the parameters are matched up correctly with the extra select pieces.
907  
-    Firstly, in this situation, you should use a
908  
-    ``django.utils.datastructures.SortedDict`` for the ``select`` value, not
909  
-    just a normal Python dictionary. Secondly, make sure that your parameters
910  
-    for the ``select`` come first in the list and that you have not passed any
911  
-    parameters to an earlier ``extra()`` call for this queryset.
  905
+    in ``extra(select=...)```. For this purpose, use the ``select_params``
  906
+    parameter. Since ``select_params`` is a sequence and the ``select``
  907
+    attribute is a dictionary, some care is required so that the parameters
  908
+    are matched up correctly with the extra select pieces.  In this situation,
  909
+    you should use a ``django.utils.datastructures.SortedDict`` for the
  910
+    ``select`` value, not just a normal Python dictionary.
912 911
 
913  
-    This will work::
  912
+    This will work, for example::
914 913
 
915 914
         Blog.objects.extra(
916 915
             select=SortedDict(('a', '%s'), ('b', '%s')),
917  
-            params=('one', 'two'))
918  
-
919  
-    ... while this won't::
920  
-
921  
-        # Will not work!
922  
-        Blog.objects.extra(where=['foo=%s'], params=('bar',)).extra(
923  
-            select=SortedDict(('a', '%s'), ('b', '%s')),
924  
-            params=('one', 'two'))
925  
-
926  
-    In the second example, the earlier ``params`` usage will mess up the later
927  
-    one. So always put your extra select pieces in the first ``extra()`` call
928  
-    if you need to use parameters in them.
  916
+            select_params=('one', 'two'))
929 917
 
930 918
 ``where`` / ``tables``
931 919
     You can define explicit SQL ``WHERE`` clauses -- perhaps to perform
@@ -965,19 +953,18 @@ of the arguments is required, but you should use at least one of them.
965 953
     time).
966 954
 
967 955
 ``params``
968  
-    The ``select`` and ``where`` parameters described above may use standard
969  
-    Python database string placeholders -- ``'%s'`` to indicate parameters the
970  
-    database engine should automatically quote. The ``params`` argument is a
971  
-    list of any extra parameters to be substituted.
  956
+    The ``where`` parameter described above may use standard Python database
  957
+    string placeholders -- ``'%s'`` to indicate parameters the database engine
  958
+    should automatically quote. The ``params`` argument is a list of any extra
  959
+    parameters to be substituted.
972 960
 
973 961
     Example::
974 962
 
975 963
         Entry.objects.extra(where=['headline=%s'], params=['Lennon'])
976 964
 
977  
-    Always use ``params`` instead of embedding values directly into ``select``
978  
-    or ``where`` because ``params`` will ensure values are quoted correctly
979  
-    according to your particular backend. (For example, quotes will be escaped
980  
-    correctly.)
  965
+    Always use ``params`` instead of embedding values directly into ``where``
  966
+    because ``params`` will ensure values are quoted correctly according to
  967
+    your particular backend. (For example, quotes will be escaped correctly.)
981 968
 
982 969
     Bad::
983 970
 
@@ -987,8 +974,9 @@ of the arguments is required, but you should use at least one of them.
987 974
 
988 975
         Entry.objects.extra(where=['headline=%s'], params=['Lennon'])
989 976
 
990  
-    The combined number of placeholders in the list of strings for ``select``
991  
-    or ``where`` should equal the number of values in the ``params`` list.
  977
+**New in Django development version** The ``select_params`` argument to
  978
+``extra()`` is new. Previously, you could attempt to pass parameters for
  979
+``select`` in the ``params`` argument, but it worked very unreliably.
992 980
 
993 981
 QuerySet methods that do not return QuerySets
994 982
 ---------------------------------------------
13  tests/regressiontests/queries/models.py
@@ -282,6 +282,10 @@ class Meta:
282 282
 >>> xx.save()
283 283
 >>> Item.objects.exclude(name='two').values('creator', 'name').distinct().count()
284 284
 4
  285
+>>> Item.objects.exclude(name='two').extra(select={'foo': '%s'}, select_params=(1,)).values('creator', 'name', 'foo').distinct().count()
  286
+4
  287
+>>> Item.objects.exclude(name='two').extra(select={'foo': '%s'}, select_params=(1,)).values('creator', 'name').distinct().count()
  288
+4
285 289
 >>> xx.delete()
286 290
 
287 291
 Bug #2253
@@ -386,6 +390,8 @@ class Meta:
386 390
 Bug #3141
387 391
 >>> Author.objects.extra(select={'foo': '1'}).count()
388 392
 4
  393
+>>> Author.objects.extra(select={'foo': '%s'}, select_params=(1,)).count()
  394
+4
389 395
 
390 396
 Bug #2400
391 397
 >>> Author.objects.filter(item__isnull=True)
@@ -462,6 +468,11 @@ class Meta:
462 468
 >>> qs.extra(order_by=('-good', 'id'))
463 469
 [<Ranking: 3: a1>, <Ranking: 2: a2>, <Ranking: 1: a3>]
464 470
 
  471
+# Despite having some extra aliases in the query, we can still omit them in a
  472
+# values() query.
  473
+>>> qs.values('id', 'rank').order_by('id')
  474
+[{'id': 1, 'rank': 2}, {'id': 2, 'rank': 1}, {'id': 3, 'rank': 3}]
  475
+
465 476
 Bugs #2874, #3002
466 477
 >>> qs = Item.objects.select_related().order_by('note__note', 'name')
467 478
 >>> list(qs)
@@ -533,7 +544,7 @@ class Meta:
533 544
 # This slightly odd comparison works aorund the fact that PostgreSQL will
534 545
 # return 'one' and 'two' as strings, not Unicode objects. It's a side-effect of
535 546
 # using constants here and not a real concern.
536  
->>> d = Item.objects.extra(select=SortedDict(s), params=params).values('a', 'b')[0]
  547
+>>> d = Item.objects.extra(select=SortedDict(s), select_params=params).values('a', 'b')[0]
537 548
 >>> d == {'a': u'one', 'b': u'two'}
538 549
 True
539 550
 

0 notes on commit 04da226

Please sign in to comment.
Something went wrong with that request. Please try again.