Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

Pass values through get_db_prep_save() in a QuerySet.update() call.

This removes a long-standing FIXME in the update() handling and allows for
greater flexibility in the values passed in. In particular, it brings updates
into line with saves for django.contrib.gis fields, so fixed #10411.

Thanks to Justin Bronn and Russell Keith-Magee for help with this patch.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@10003 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information...
commit 35f934f5a7d9eb79a69f33de84854aa9a506b912 1 parent cee3173
Malcolm Tredinnick authored March 09, 2009
5  django/contrib/gis/db/backend/adaptor.py
@@ -8,7 +8,10 @@ def __init__(self, geom):
8 8
         self.srid = geom.srid
9 9
 
10 10
     def __eq__(self, other):
11  
-        return self.wkt == other.wkt and self.srid == other.srid 
  11
+        return self.wkt == other.wkt and self.srid == other.srid
12 12
 
13 13
     def __str__(self):
14 14
         return self.wkt
  15
+
  16
+    def prepare_database_save(self, unused):
  17
+        return self
3  django/contrib/gis/db/backend/postgis/adaptor.py
@@ -31,3 +31,6 @@ def getquoted(self):
31 31
         "Returns a properly quoted string for use in PostgreSQL/PostGIS."
32 32
         # Want to use WKB, so wrap with psycopg2 Binary() to quote properly.
33 33
         return "%s(%s, %s)" % (GEOM_FROM_WKB, Binary(self.wkb), self.srid or -1)
  34
+
  35
+    def prepare_database_save(self, unused):
  36
+        return self
18  django/contrib/gis/tests/geoapp/test_regress.py
... ...
@@ -0,0 +1,18 @@
  1
+import os, unittest
  2
+from django.contrib.gis.db.backend import SpatialBackend
  3
+from django.contrib.gis.tests.utils import no_mysql, no_oracle, no_postgis
  4
+from models import City
  5
+
  6
+class GeoRegressionTests(unittest.TestCase):
  7
+
  8
+    def test01_update(self):
  9
+        "Testing GeoQuerySet.update(), see #10411."
  10
+        pnt = City.objects.get(name='Pueblo').point
  11
+        bak = pnt.clone()
  12
+        pnt.y += 0.005
  13
+        pnt.x += 0.005
  14
+
  15
+        City.objects.filter(name='Pueblo').update(point=pnt)
  16
+        self.assertEqual(pnt, City.objects.get(name='Pueblo').point)
  17
+        City.objects.filter(name='Pueblo').update(point=bak)
  18
+        self.assertEqual(bak, City.objects.get(name='Pueblo').point)
55  django/contrib/gis/tests/geoapp/tests.py
@@ -13,7 +13,7 @@
13 13
 DISABLE = False
14 14
 
15 15
 class GeoModelTest(unittest.TestCase):
16  
-    
  16
+
17 17
     def test01_initial_sql(self):
18 18
         "Testing geographic initial SQL."
19 19
         if DISABLE: return
@@ -21,7 +21,7 @@ def test01_initial_sql(self):
21 21
             # Oracle doesn't allow strings longer than 4000 characters
22 22
             # in SQL files, and I'm stumped on how to use Oracle BFILE's
23 23
             # in PLSQL, so we set up the larger geometries manually, rather
24  
-            # than relying on the initial SQL. 
  24
+            # than relying on the initial SQL.
25 25
 
26 26
             # Routine for returning the path to the data files.
27 27
             data_dir = os.path.join(os.path.dirname(__file__), 'sql')
@@ -65,7 +65,7 @@ def test02_proxy(self):
65 65
         new = Point(5, 23)
66 66
         nullcity.point = new
67 67
 
68  
-        # Ensuring that the SRID is automatically set to that of the 
  68
+        # Ensuring that the SRID is automatically set to that of the
69 69
         #  field after assignment, but before saving.
70 70
         self.assertEqual(4326, nullcity.point.srid)
71 71
         nullcity.save()
@@ -94,7 +94,7 @@ def test02_proxy(self):
94 94
 
95 95
         ns = State.objects.get(name='NullState')
96 96
         self.assertEqual(ply, ns.poly)
97  
-        
  97
+
98 98
         # Testing the `ogr` and `srs` lazy-geometry properties.
99 99
         if gdal.HAS_GDAL:
100 100
             self.assertEqual(True, isinstance(ns.poly.ogr, gdal.OGRGeometry))
@@ -120,7 +120,7 @@ def test03a_kml(self):
120 120
         qs = City.objects.all()
121 121
         self.assertRaises(TypeError, qs.kml, 'name')
122 122
 
123  
-        # The reference KML depends on the version of PostGIS used 
  123
+        # The reference KML depends on the version of PostGIS used
124 124
         # (the output stopped including altitude in 1.3.3).
125 125
         major, minor1, minor2 = SpatialBackend.version
126 126
         ref_kml1 = '<Point><coordinates>-104.609252,38.255001,0</coordinates></Point>'
@@ -204,8 +204,8 @@ def test06_make_line(self):
204 204
         self.assertRaises(TypeError, Country.objects.make_line)
205 205
         # Reference query:
206 206
         # SELECT AsText(ST_MakeLine(geoapp_city.point)) FROM geoapp_city;
207  
-        self.assertEqual(GEOSGeometry('LINESTRING(-95.363151 29.763374,-96.801611 32.782057,-97.521157 34.464642,174.783117 -41.315268,-104.609252 38.255001,-95.23506 38.971823,-87.650175 41.850385,-123.305196 48.462611)', srid=4326),
208  
-                         City.objects.make_line())
  207
+        ref_line = GEOSGeometry('LINESTRING(-95.363151 29.763374,-96.801611 32.782057,-97.521157 34.464642,174.783117 -41.315268,-104.609252 38.255001,-95.23506 38.971823,-87.650175 41.850385,-123.305196 48.462611)', srid=4326)
  208
+        self.assertEqual(ref_line, City.objects.make_line())
209 209
 
210 210
     def test09_disjoint(self):
211 211
         "Testing the `disjoint` lookup type."
@@ -227,7 +227,7 @@ def test10_contains_contained(self):
227 227
         if DISABLE: return
228 228
         # Getting Texas, yes we were a country -- once ;)
229 229
         texas = Country.objects.get(name='Texas')
230  
-        
  230
+
231 231
         # Seeing what cities are in Texas, should get Houston and Dallas,
232 232
         #  and Oklahoma City because 'contained' only checks on the
233 233
         #  _bounding box_ of the Geometries.
@@ -288,15 +288,15 @@ def test11_lookup_insert_transform(self):
288 288
         # `ST_Intersects`, so contains is used instead.
289 289
         nad_pnt = fromstr(nad_wkt, srid=nad_srid)
290 290
         if SpatialBackend.oracle:
291  
-            tx = Country.objects.get(mpoly__contains=nad_pnt) 
  291
+            tx = Country.objects.get(mpoly__contains=nad_pnt)
292 292
         else:
293 293
             tx = Country.objects.get(mpoly__intersects=nad_pnt)
294 294
         self.assertEqual('Texas', tx.name)
295  
-        
  295
+
296 296
         # Creating San Antonio.  Remember the Alamo.
297 297
         sa = City(name='San Antonio', point=nad_pnt)
298 298
         sa.save()
299  
-        
  299
+
300 300
         # Now verifying that San Antonio was transformed correctly
301 301
         sa = City.objects.get(name='San Antonio')
302 302
         self.assertAlmostEqual(wgs_pnt.x, sa.point.x, 6)
@@ -321,7 +321,7 @@ def test12_null_geometries(self):
321 321
         # Puerto Rico should be NULL (it's a commonwealth unincorporated territory)
322 322
         self.assertEqual(1, len(nullqs))
323 323
         self.assertEqual('Puerto Rico', nullqs[0].name)
324  
-        
  324
+
325 325
         # The valid states should be Colorado & Kansas
326 326
         self.assertEqual(2, len(validqs))
327 327
         state_names = [s.name for s in validqs]
@@ -338,18 +338,18 @@ def test13_left_right(self):
338 338
         "Testing the 'left' and 'right' lookup types."
339 339
         if DISABLE: return
340 340
         # Left: A << B => true if xmax(A) < xmin(B)
341  
-        # Right: A >> B => true if xmin(A) > xmax(B) 
  341
+        # Right: A >> B => true if xmin(A) > xmax(B)
342 342
         #  See: BOX2D_left() and BOX2D_right() in lwgeom_box2dfloat4.c in PostGIS source.
343  
-        
  343
+
344 344
         # Getting the borders for Colorado & Kansas
345 345
         co_border = State.objects.get(name='Colorado').poly
346 346
         ks_border = State.objects.get(name='Kansas').poly
347 347
 
348 348
         # Note: Wellington has an 'X' value of 174, so it will not be considered
349 349
         #  to the left of CO.
350  
-        
  350
+
351 351
         # These cities should be strictly to the right of the CO border.
352  
-        cities = ['Houston', 'Dallas', 'San Antonio', 'Oklahoma City', 
  352
+        cities = ['Houston', 'Dallas', 'San Antonio', 'Oklahoma City',
353 353
                   'Lawrence', 'Chicago', 'Wellington']
354 354
         qs = City.objects.filter(point__right=co_border)
355 355
         self.assertEqual(7, len(qs))
@@ -365,7 +365,7 @@ def test13_left_right(self):
365 365
         #  to the left of CO.
366 366
         vic = City.objects.get(point__left=co_border)
367 367
         self.assertEqual('Victoria', vic.name)
368  
-        
  368
+
369 369
         cities = ['Pueblo', 'Victoria']
370 370
         qs = City.objects.filter(point__left=ks_border)
371 371
         self.assertEqual(2, len(qs))
@@ -383,12 +383,12 @@ def test14_equals(self):
383 383
     def test15_relate(self):
384 384
         "Testing the 'relate' lookup type."
385 385
         if DISABLE: return
386  
-        # To make things more interesting, we will have our Texas reference point in 
  386
+        # To make things more interesting, we will have our Texas reference point in
387 387
         # different SRIDs.
388 388
         pnt1 = fromstr('POINT (649287.0363174 4177429.4494686)', srid=2847)
389 389
         pnt2 = fromstr('POINT(-98.4919715741052 29.4333344025053)', srid=4326)
390 390
 
391  
-        # Not passing in a geometry as first param shoud 
  391
+        # Not passing in a geometry as first param shoud
392 392
         # raise a type error when initializing the GeoQuerySet
393 393
         self.assertRaises(TypeError, Country.objects.filter, mpoly__relate=(23, 'foo'))
394 394
         # Making sure the right exception is raised for the given
@@ -440,7 +440,7 @@ def test17_unionagg(self):
440 440
         # Using `field_name` keyword argument in one query and specifying an
441 441
         # order in the other (which should not be used because this is
442 442
         # an aggregate method on a spatial column)
443  
-        u1 = qs.unionagg(field_name='point') 
  443
+        u1 = qs.unionagg(field_name='point')
444 444
         u2 = qs.order_by('name').unionagg()
445 445
         tol = 0.00001
446 446
         if SpatialBackend.oracle:
@@ -458,8 +458,8 @@ def test18_geometryfield(self):
458 458
         Feature(name='Point', geom=Point(1, 1)).save()
459 459
         Feature(name='LineString', geom=LineString((0, 0), (1, 1), (5, 5))).save()
460 460
         Feature(name='Polygon', geom=Polygon(LinearRing((0, 0), (0, 5), (5, 5), (5, 0), (0, 0)))).save()
461  
-        Feature(name='GeometryCollection', 
462  
-                geom=GeometryCollection(Point(2, 2), LineString((0, 0), (2, 2)), 
  461
+        Feature(name='GeometryCollection',
  462
+                geom=GeometryCollection(Point(2, 2), LineString((0, 0), (2, 2)),
463 463
                                         Polygon(LinearRing((0, 0), (0, 5), (5, 5), (5, 0), (0, 0))))).save()
464 464
 
465 465
         f_1 = Feature.objects.get(name='Point')
@@ -474,7 +474,7 @@ def test18_geometryfield(self):
474 474
         f_4 = Feature.objects.get(name='GeometryCollection')
475 475
         self.assertEqual(True, isinstance(f_4.geom, GeometryCollection))
476 476
         self.assertEqual(f_3.geom, f_4.geom[2])
477  
-    
  477
+
478 478
     def test19_centroid(self):
479 479
         "Testing the `centroid` GeoQuerySet method."
480 480
         if DISABLE: return
@@ -494,7 +494,7 @@ def test20_pointonsurface(self):
494 494
                    'Texas' : fromstr('POINT (-103.002434 36.500397)', srid=4326),
495 495
                    }
496 496
         elif SpatialBackend.postgis:
497  
-            # Using GEOSGeometry to compute the reference point on surface values 
  497
+            # Using GEOSGeometry to compute the reference point on surface values
498 498
             # -- since PostGIS also uses GEOS these should be the same.
499 499
             ref = {'New Zealand' : Country.objects.get(name='New Zealand').mpoly.point_on_surface,
500 500
                    'Texas' : Country.objects.get(name='Texas').mpoly.point_on_surface
@@ -533,7 +533,7 @@ def test23_numgeom(self):
533 533
         if DISABLE: return
534 534
         # Both 'countries' only have two geometries.
535 535
         for c in Country.objects.num_geom(): self.assertEqual(2, c.num_geom)
536  
-        for c in City.objects.filter(point__isnull=False).num_geom(): 
  536
+        for c in City.objects.filter(point__isnull=False).num_geom():
537 537
             # Oracle will return 1 for the number of geometries on non-collections,
538 538
             # whereas PostGIS will return None.
539 539
             if SpatialBackend.postgis: self.assertEqual(None, c.num_geom)
@@ -566,15 +566,18 @@ def test26_inherited_geofields(self):
566 566
         # All transformation SQL will need to be performed on the
567 567
         # _parent_ table.
568 568
         qs = PennsylvaniaCity.objects.transform(32128)
569  
-        
  569
+
570 570
         self.assertEqual(1, qs.count())
571 571
         for pc in qs: self.assertEqual(32128, pc.point.srid)
572 572
 
573 573
 from test_feeds import GeoFeedTest
  574
+from test_regress import GeoRegressionTests
574 575
 from test_sitemaps import GeoSitemapTest
  576
+
575 577
 def suite():
576 578
     s = unittest.TestSuite()
577 579
     s.addTest(unittest.makeSuite(GeoModelTest))
578 580
     s.addTest(unittest.makeSuite(GeoFeedTest))
579 581
     s.addTest(unittest.makeSuite(GeoSitemapTest))
  582
+    s.addTest(unittest.makeSuite(GeoRegressionTests))
580 583
     return s
17  django/contrib/gis/tests/geoapp/tests_mysql.py
@@ -8,7 +8,7 @@
8 8
 from django.core.exceptions import ImproperlyConfigured
9 9
 
10 10
 class GeoModelTest(unittest.TestCase):
11  
-    
  11
+
12 12
     def test01_initial_sql(self):
13 13
         "Testing geographic initial SQL."
14 14
         # Ensuring that data was loaded from initial SQL.
@@ -38,7 +38,7 @@ def test02_proxy(self):
38 38
         new = Point(5, 23)
39 39
         nullcity.point = new
40 40
 
41  
-        # Ensuring that the SRID is automatically set to that of the 
  41
+        # Ensuring that the SRID is automatically set to that of the
42 42
         #  field after assignment, but before saving.
43 43
         self.assertEqual(4326, nullcity.point.srid)
44 44
         nullcity.save()
@@ -67,7 +67,7 @@ def test02_proxy(self):
67 67
 
68 68
         ns = State.objects.get(name='NullState')
69 69
         self.assertEqual(ply, ns.poly)
70  
-        
  70
+
71 71
         # Testing the `ogr` and `srs` lazy-geometry properties.
72 72
         if gdal.HAS_GDAL:
73 73
             self.assertEqual(True, isinstance(ns.poly.ogr, gdal.OGRGeometry))
@@ -88,7 +88,7 @@ def test03_contains_contained(self):
88 88
         "Testing the 'contained', 'contains', and 'bbcontains' lookup types."
89 89
         # Getting Texas, yes we were a country -- once ;)
90 90
         texas = Country.objects.get(name='Texas')
91  
-        
  91
+
92 92
         # Seeing what cities are in Texas, should get Houston and Dallas,
93 93
         #  and Oklahoma City because MySQL 'within' only checks on the
94 94
         #  _bounding box_ of the Geometries.
@@ -146,8 +146,8 @@ def test06_geometryfield(self):
146 146
         f1 = Feature(name='Point', geom=Point(1, 1))
147 147
         f2 = Feature(name='LineString', geom=LineString((0, 0), (1, 1), (5, 5)))
148 148
         f3 = Feature(name='Polygon', geom=Polygon(LinearRing((0, 0), (0, 5), (5, 5), (5, 0), (0, 0))))
149  
-        f4 = Feature(name='GeometryCollection', 
150  
-                     geom=GeometryCollection(Point(2, 2), LineString((0, 0), (2, 2)), 
  149
+        f4 = Feature(name='GeometryCollection',
  150
+                     geom=GeometryCollection(Point(2, 2), LineString((0, 0), (2, 2)),
151 151
                                              Polygon(LinearRing((0, 0), (0, 5), (5, 5), (5, 0), (0, 0)))))
152 152
         f1.save()
153 153
         f2.save()
@@ -166,7 +166,7 @@ def test06_geometryfield(self):
166 166
         f_4 = Feature.objects.get(name='GeometryCollection')
167 167
         self.assertEqual(True, isinstance(f_4.geom, GeometryCollection))
168 168
         self.assertEqual(f_3.geom, f_4.geom[2])
169  
-    
  169
+
170 170
     def test07_mysql_limitations(self):
171 171
         "Testing that union(), kml(), gml() raise exceptions."
172 172
         self.assertRaises(ImproperlyConfigured, City.objects.union, Point(5, 23), field_name='point')
@@ -174,10 +174,13 @@ def test07_mysql_limitations(self):
174 174
         self.assertRaises(ImproperlyConfigured, Country.objects.all().gml, field_name='mpoly')
175 175
 
176 176
 from test_feeds import GeoFeedTest
  177
+from test_regress import GeoRegressionTests
177 178
 from test_sitemaps import GeoSitemapTest
  179
+
178 180
 def suite():
179 181
     s = unittest.TestSuite()
180 182
     s.addTest(unittest.makeSuite(GeoModelTest))
181 183
     s.addTest(unittest.makeSuite(GeoFeedTest))
182 184
     s.addTest(unittest.makeSuite(GeoSitemapTest))
  185
+    s.addTest(unittest.makeSuite(GeoRegressionTests))
183 186
     return s
2  django/db/models/base.py
@@ -499,6 +499,8 @@ def _get_next_or_previous_in_order(self, is_next):
499 499
             setattr(self, cachename, obj)
500 500
         return getattr(self, cachename)
501 501
 
  502
+    def prepare_database_save(self, unused):
  503
+        return self.pk
502 504
 
503 505
 
504 506
 ############################################
3  django/db/models/expressions.py
@@ -90,6 +90,9 @@ def __rand__(self, other):
90 90
     def __ror__(self, other):
91 91
         return self._combine(other, self.OR, True)
92 92
 
  93
+    def prepare_database_save(self, unused):
  94
+        return self
  95
+
93 96
 class F(ExpressionNode):
94 97
     """
95 98
     An expression representing the value of the given field.
7  django/db/models/sql/subqueries.py
@@ -239,9 +239,10 @@ def add_update_fields(self, values_seq):
239 239
         """
240 240
         from django.db.models.base import Model
241 241
         for field, model, val in values_seq:
242  
-            # FIXME: Some sort of db_prep_* is probably more appropriate here.
243  
-            if field.rel and isinstance(val, Model):
244  
-                val = val.pk
  242
+            if hasattr(val, 'prepare_database_save'):
  243
+                val = val.prepare_database_save(field)
  244
+            else:
  245
+                val = field.get_db_prep_save(val)
245 246
 
246 247
             # Getting the placeholder for the field.
247 248
             if hasattr(field, 'get_placeholder'):

0 notes on commit 35f934f

Please sign in to comment.
Something went wrong with that request. Please try again.