Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

Smarter handling of reverse lookups

  • Loading branch information...
commit a6657a9766027e143693d74d1791bea3a21db489 1 parent 987175f
Charles Leifer authored July 12, 2012
144  genericm2m/genericm2m_tests/tests.py
@@ -12,18 +12,18 @@ def setUp(self):
12 12
         self.pizza = Food.objects.create(name='pizza')
13 13
         self.sandwich = Food.objects.create(name='sandwich')
14 14
         self.cereal = Food.objects.create(name='cereal')
15  
-        
  15
+
16 16
         self.soda = Beverage.objects.create(name='soda')
17 17
         self.beer = Beverage.objects.create(name='beer')
18 18
         self.milk = Beverage.objects.create(name='milk')
19  
-        
  19
+
20 20
         self.mario = Person.objects.create(name='mario')
21 21
         self.sam = Person.objects.create(name='sam')
22 22
         self.chocula = Person.objects.create(name='chocula')
23  
-        
  23
+
24 24
         self.table = Boring.objects.create(name='table')
25 25
         self.chair = Boring.objects.create(name='chair')
26  
-    
  26
+
27 27
     def assertRelatedEqual(self, rel_qs, tups, from_field='parent',
28 28
                            to_field='object'):
29 29
         rel_tup = [
@@ -31,7 +31,7 @@ def assertRelatedEqual(self, rel_qs, tups, from_field='parent',
31 31
             for rel_obj in rel_qs
32 32
         ]
33 33
         self.assertEqual(rel_tup, list(tups))
34  
-    
  34
+
35 35
     def test_connect(self):
36 36
         """
37 37
         Connect model instances to various other model instances, then query
@@ -40,44 +40,44 @@ def test_connect(self):
40 40
         self.pizza.related.connect(self.soda)
41 41
         self.pizza.related.connect(self.beer)
42 42
         self.pizza.related.connect(self.mario)
43  
-        
  43
+
44 44
         self.soda.related.connect(self.pizza)
45 45
         self.soda.related.connect(self.beer)
46  
-        
  46
+
47 47
         related = self.pizza.related.all()
48 48
         self.assertRelatedEqual(related, (
49 49
             (self.pizza, self.mario),
50 50
             (self.pizza, self.beer),
51 51
             (self.pizza, self.soda),
52 52
         ))
53  
-        
  53
+
54 54
         self.sandwich.related.connect(self.soda)
55 55
         self.sandwich.related.connect(self.milk)
56  
-        
  56
+
57 57
         related = self.sandwich.related.all()
58 58
         self.assertRelatedEqual(related, (
59 59
             (self.sandwich, self.milk),
60 60
             (self.sandwich, self.soda),
61 61
         ))
62  
-        
  62
+
63 63
         related = self.cereal.related.all()
64 64
         self.assertRelatedEqual(related, ())
65  
-        
  65
+
66 66
         related = self.soda.related.all()
67 67
         self.assertRelatedEqual(related, (
68 68
             (self.soda, self.beer),
69 69
             (self.soda, self.pizza),
70 70
         ))
71  
-        
  71
+
72 72
         self.sandwich.related.connect(self.table)
73  
-        
  73
+
74 74
         related = self.sandwich.related.all()
75 75
         self.assertRelatedEqual(related, (
76 76
             (self.sandwich, self.table),
77 77
             (self.sandwich, self.milk),
78 78
             (self.sandwich, self.soda),
79 79
         ))
80  
-    
  80
+
81 81
     def test_related_to(self):
82 82
         """
83 83
         Check the back-side of the double-GFK, note: this only works on objects
@@ -91,24 +91,24 @@ def test_related_to(self):
91 91
         self.sandwich.related.connect(self.milk)
92 92
         self.mario.related.connect(self.soda)
93 93
         self.soda.related.connect(self.pizza)
94  
-        
  94
+
95 95
         related = self.soda.related.related_to()
96 96
         self.assertRelatedEqual(related, (
97 97
             (self.mario, self.soda),
98 98
             (self.sandwich, self.soda),
99 99
             (self.pizza, self.soda),
100 100
         ))
101  
-        
  101
+
102 102
         related = self.beer.related.related_to()
103 103
         self.assertRelatedEqual(related, (
104 104
             (self.pizza, self.beer),
105 105
         ))
106  
-        
  106
+
107 107
         related = self.milk.related.related_to()
108 108
         self.assertRelatedEqual(related, (
109 109
             (self.sandwich, self.milk),
110 110
         ))
111  
-        
  111
+
112 112
         related = self.pizza.related.related_to()
113 113
         self.assertRelatedEqual(related, (
114 114
             (self.soda, self.pizza),
@@ -124,15 +124,15 @@ def test_manager_methods(self):
124 124
         # connect pizza to soda and grab the newly-created RelatedObject
125 125
         self.pizza.related.connect(self.soda)
126 126
         rel_obj = RelatedObject.objects.all()[0]
127  
-        
  127
+
128 128
         # connect cereal to milk (this is just to make sure that anything
129 129
         # modified on one Food object doesn't affect another Food object
130 130
         self.cereal.related.connect(self.milk)
131  
-        
  131
+
132 132
         # create a new RelatedObject but do not save it yet -- note that it does
133 133
         # not have `parent_object` set
134 134
         new_rel_obj = RelatedObject(object=self.beer)
135  
-        
  135
+
136 136
         # add this related object to pizza, parent_object gets set and it will
137 137
         # show up in the queryset as expected
138 138
         self.pizza.related.add(new_rel_obj)
@@ -140,26 +140,26 @@ def test_manager_methods(self):
140 140
             (self.pizza, self.beer),
141 141
             (self.pizza, self.soda),
142 142
         ))
143  
-        
  143
+
144 144
         # remove the original RelatedObject `rel_obj`, which was the connection
145 145
         # from pizza -> soda
146 146
         self.pizza.related.remove(rel_obj)
147 147
         self.assertRelatedEqual(self.pizza.related.all(), (
148 148
             (self.pizza, self.beer),
149 149
         ))
150  
-        
  150
+
151 151
         # make sure clearing pizza's related queryset works
152 152
         self.pizza.related.clear()
153 153
         self.assertRelatedEqual(self.pizza.related.all(), ())
154  
-        
  154
+
155 155
         # make sure clearing the pizza objects didn't affect cereal
156 156
         self.assertRelatedEqual(self.cereal.related.all(), (
157 157
             (self.cereal, self.milk),
158 158
         ))
159  
-        
  159
+
160 160
         # there should be just one row in the table
161 161
         self.assertEqual(RelatedObject.objects.count(), 1)
162  
-    
  162
+
163 163
     def test_model_level(self):
164 164
         """
165 165
         The RelatedObjectsDescriptor can work at the class-level as well and
@@ -169,25 +169,25 @@ def test_model_level(self):
169 169
         """
170 170
         self.pizza.related.connect(self.beer)
171 171
         self.cereal.related.connect(self.milk)
172  
-        
  172
+
173 173
         self.mario.related.connect(self.pizza)
174 174
         self.sam.related.connect(self.beer)
175 175
         self.soda.related.connect(self.pizza)
176  
-        
  176
+
177 177
         self.assertRelatedEqual(Food.related.all(), (
178 178
             (self.cereal, self.milk),
179 179
             (self.pizza, self.beer),
180 180
         ))
181  
-        
  181
+
182 182
         self.assertRelatedEqual(Beverage.related.all(), (
183 183
             (self.soda, self.pizza),
184 184
         ))
185  
-        
  185
+
186 186
         self.assertRelatedEqual(Person.related.all(), (
187 187
             (self.sam, self.beer),
188 188
             (self.mario, self.pizza),
189 189
         ))
190  
-    
  190
+
191 191
     def test_custom_connect(self):
192 192
         """
193 193
         Mimic the test_connect() method, but instead use the custom descriptor,
@@ -195,22 +195,22 @@ def test_custom_connect(self):
195 195
         """
196 196
         self.pizza.related_beverages.connect(self.soda)
197 197
         self.pizza.related_beverages.connect(self.beer)
198  
-        
  198
+
199 199
         related = self.pizza.related_beverages.all()
200 200
         self.assertRelatedEqual(related, (
201 201
             (self.pizza, self.beer),
202 202
             (self.pizza, self.soda),
203 203
         ), 'food', 'beverage')
204  
-        
  204
+
205 205
         self.sandwich.related_beverages.connect(self.soda)
206 206
         self.sandwich.related_beverages.connect(self.milk)
207  
-        
  207
+
208 208
         related = self.sandwich.related_beverages.all()
209 209
         self.assertRelatedEqual(related, (
210 210
             (self.sandwich, self.milk),
211 211
             (self.sandwich, self.soda),
212 212
         ), 'food', 'beverage')
213  
-        
  213
+
214 214
         related = self.cereal.related_beverages.all()
215 215
         self.assertRelatedEqual(related, ())
216 216
 
@@ -221,32 +221,32 @@ def test_custom_model_manager(self):
221 221
         """
222 222
         self.pizza.related_beverages.connect(self.soda)
223 223
         rel_obj = RelatedBeverage.objects.all()[0] # grab the new related obj
224  
-        
  224
+
225 225
         self.cereal.related_beverages.connect(self.milk)
226  
-        
  226
+
227 227
         new_rel_obj = RelatedBeverage(beverage=self.beer)
228  
-        
  228
+
229 229
         self.pizza.related_beverages.add(new_rel_obj)
230 230
         self.assertRelatedEqual(self.pizza.related_beverages.all(), (
231 231
             (self.pizza, self.beer),
232 232
             (self.pizza, self.soda),
233 233
         ), 'food', 'beverage')
234  
-        
  234
+
235 235
         self.pizza.related_beverages.remove(rel_obj)
236 236
         self.assertRelatedEqual(self.pizza.related_beverages.all(), (
237 237
             (self.pizza, self.beer),
238 238
         ), 'food', 'beverage')
239  
-        
  239
+
240 240
         self.pizza.related_beverages.clear()
241 241
         self.assertRelatedEqual(self.pizza.related_beverages.all(), ())
242  
-        
  242
+
243 243
         # make sure clearing the pizza objects didn't affect cereal
244 244
         self.assertRelatedEqual(self.cereal.related_beverages.all(), (
245 245
             (self.cereal, self.milk),
246 246
         ), 'food', 'beverage')
247  
-        
  247
+
248 248
         self.assertEqual(RelatedBeverage.objects.count(), 1)
249  
-    
  249
+
250 250
     def test_custom_model_level(self):
251 251
         """
252 252
         And lastly, test that the custom descriptor/through-model work as
@@ -256,14 +256,14 @@ def test_custom_model_level(self):
256 256
         self.pizza.related_beverages.connect(self.beer)
257 257
         self.sandwich.related_beverages.connect(self.soda)
258 258
         self.cereal.related_beverages.connect(self.milk)
259  
-        
  259
+
260 260
         self.assertRelatedEqual(Food.related_beverages.all(), (
261 261
             (self.cereal, self.milk),
262 262
             (self.sandwich, self.soda),
263 263
             (self.pizza, self.beer),
264 264
             (self.pizza, self.soda),
265 265
         ), 'food', 'beverage')
266  
-    
  266
+
267 267
     def test_generic_traversal(self):
268 268
         """
269 269
         Ensure that the RelatedObjectsDescriptor returns a GFKOptimizedQuerySet
@@ -273,18 +273,30 @@ def test_generic_traversal(self):
273 273
         self.pizza.related.connect(self.beer)
274 274
         self.pizza.related.connect(self.soda)
275 275
         self.pizza.related.connect(self.mario)
276  
-        
  276
+
277 277
         # the manager returns instances of GFKOptimizedQuerySet
278 278
         related = self.pizza.related.all()
279 279
         self.assertEqual(type(related), GFKOptimizedQuerySet)
280  
-        
  280
+
281 281
         # check the queryset is using the right field
282 282
         self.assertEqual(related.get_gfk().name, 'object')
283  
-        
  283
+
284 284
         # the custom queryset's optimized lookup works correctly
285 285
         objects = related.generic_objects()
286 286
         self.assertEqual(objects, [self.mario, self.soda, self.beer])
287  
-    
  287
+
  288
+        # check the reverse does not hold, documenting existing behavior since
  289
+        # it looks at only the "default" manager on the back-side
  290
+        related = self.soda.related.related_to()
  291
+        self.assertEqual(type(related), GFKOptimizedQuerySet)
  292
+
  293
+        # check the queryset is using the right field
  294
+        self.assertEqual(related.get_gfk().name, 'parent')
  295
+
  296
+        # the custom queryset's optimized lookup works correctly
  297
+        objects = related.generic_objects()
  298
+        self.assertEqual(objects, [self.pizza])
  299
+
288 300
     def test_filtering(self):
289 301
         """
290 302
         Check that filtering on RelatedObject fields (or through model fields)
@@ -293,23 +305,23 @@ def test_filtering(self):
293 305
         self.pizza.related.connect(self.beer, alias='bud lite')
294 306
         self.pizza.related.connect(self.soda, alias='pepsi')
295 307
         self.pizza.related.connect(self.mario)
296  
-        
  308
+
297 309
         rel_qs = self.pizza.related.filter(alias='bud lite')
298 310
         self.assertRelatedEqual(rel_qs, (
299 311
             (self.pizza, self.beer),
300 312
         ))
301  
-        
  313
+
302 314
         rel_qs = self.pizza.related.filter(object_type=ContentType.objects.get_for_model(Beverage))
303 315
         self.assertRelatedEqual(rel_qs, (
304 316
             (self.pizza, self.soda),
305 317
             (self.pizza, self.beer),
306 318
         ))
307  
-        
  319
+
308 320
         rel_qs = self.beer.related.related_to().filter(alias='bud lite')
309 321
         self.assertRelatedEqual(rel_qs, (
310 322
             (self.pizza, self.beer),
311 323
         ))
312  
-        
  324
+
313 325
     def test_custom_model_using_gfks(self):
314 326
         """
315 327
         Check that using a custom through model with GFKs works as expected
@@ -318,57 +330,57 @@ def test_custom_model_using_gfks(self):
318 330
         self.note_a = Note.objects.create(content='a')
319 331
         self.note_b = Note.objects.create(content='b')
320 332
         self.note_c = Note.objects.create(content='c')
321  
-        
  333
+
322 334
         self.note_a.related.connect(self.pizza)
323 335
         self.note_a.related.connect(self.note_b)
324  
-        
  336
+
325 337
         # create some notes with custom attributes
326 338
         self.note_b.related.connect(self.cereal, alias='cereal note', description='lucky charms!')
327 339
         self.note_b.related.connect(self.milk, alias='milk note', description='goes good with cereal')
328  
-        
  340
+
329 341
         # ensure that the queryset is using the correct model and automatically
330 342
         # determines that a GFKOptimizedQuerySet can be used
331 343
         queryset = self.note_a.related.all()
332 344
         self.assertEqual(queryset.model, AnotherRelatedObject)
333 345
         self.assertTrue(isinstance(queryset, GFKOptimizedQuerySet))
334  
-        
  346
+
335 347
         related_a = self.note_a.related.all()
336 348
         self.assertRelatedEqual(related_a, (
337 349
             (self.note_a, self.pizza),
338 350
             (self.note_a, self.note_b),
339 351
         ))
340  
-        
  352
+
341 353
         related_b = self.note_b.related.all()
342 354
         self.assertRelatedEqual(related_b, (
343 355
             (self.note_b, self.cereal),
344 356
             (self.note_b, self.milk),
345 357
         ))
346  
-        
  358
+
347 359
         cereal_rel, milk_rel = related_b
348  
-        
  360
+
349 361
         # check that the custom attributes were saved correctly
350 362
         self.assertEqual(cereal_rel.alias, 'cereal note')
351 363
         self.assertEqual(cereal_rel.description, 'lucky charms!')
352  
-        
  364
+
353 365
         self.assertEqual(milk_rel.alias, 'milk note')
354 366
         self.assertEqual(milk_rel.description, 'goes good with cereal')
355  
-        
  367
+
356 368
         # check that we can filter on fields as expected
357 369
         self.assertRelatedEqual(self.note_b.related.filter(alias='cereal note'), (
358 370
             (self.note_b, self.cereal),
359 371
         ))
360  
-        
  372
+
361 373
         related_c = self.note_c.related.all()
362 374
         self.assertRelatedEqual(related_c, ())
363  
-        
  375
+
364 376
         # lastly, check that the GFKOptimizedQuerySet returns the expected
365 377
         # results when doing the optimized lookup
366 378
         self.assertEqual(related_a.generic_objects(), [
367 379
             self.pizza, self.note_b
368 380
         ])
369  
-        
  381
+
370 382
         self.assertEqual(related_b.generic_objects(), [
371 383
             self.cereal, self.milk
372 384
         ])
373  
-        
  385
+
374 386
         self.assertEqual(related_c.generic_objects(), [])
74  genericm2m/models.py
... ...
@@ -1,6 +1,7 @@
1 1
 from django.contrib.contenttypes.generic import GenericForeignKey
2 2
 from django.contrib.contenttypes.models import ContentType
3 3
 from django.db import models
  4
+from django.db.models import Q
4 5
 from django.db.models.query import QuerySet
5 6
 
6 7
 
@@ -8,67 +9,68 @@ class GFKOptimizedQuerySet(QuerySet):
8 9
     def __init__(self, *args, **kwargs):
9 10
         # pop the gfk_field from the kwargs if its passed in explicitly
10 11
         self._gfk_field = kwargs.pop('gfk_field', None)
11  
-        
  12
+
12 13
         # call the parent class' initializer
13 14
         super(GFKOptimizedQuerySet, self).__init__(*args, **kwargs)
14  
-    
  15
+
15 16
     def _clone(self, *args, **kwargs):
16 17
         clone = super(GFKOptimizedQuerySet, self)._clone(*args, **kwargs)
17 18
         clone._gfk_field = self._gfk_field
18 19
         return clone
19  
-    
  20
+
20 21
     def get_gfk(self):
21 22
         if not self._gfk_field:
22 23
             for field in self.model._meta.virtual_fields:
23 24
                 if isinstance(field, GenericForeignKey):
24 25
                     self._gfk_field = field
25 26
                     break
26  
-        
  27
+
27 28
         return self._gfk_field
28  
-    
  29
+
29 30
     def generic_objects(self):
30 31
         clone = self._clone()
31  
-        
  32
+
32 33
         ctypes_and_fks = {}
33  
-        
  34
+
34 35
         gfk_field = self.get_gfk()
35 36
         ctype_field = '%s_id' % gfk_field.ct_field
36 37
         fk_field = gfk_field.fk_field
37  
-        
  38
+
38 39
         for obj in clone:
39 40
             ctype = ContentType.objects.get_for_id(getattr(obj, ctype_field))
40 41
             obj_id = getattr(obj, fk_field)
41  
-            
  42
+
42 43
             ctypes_and_fks.setdefault(ctype, [])
43 44
             ctypes_and_fks[ctype].append(obj_id)
44  
-        
  45
+
45 46
         gfk_objects = {}
46 47
         for ctype, obj_ids in ctypes_and_fks.items():
47 48
             gfk_objects[ctype.pk] = ctype.model_class()._default_manager.in_bulk(obj_ids)
48  
-        
  49
+
49 50
         obj_list = []
50 51
         for obj in clone:
51 52
             obj_list.append(gfk_objects[getattr(obj, ctype_field)][getattr(obj, fk_field)])
52  
-        
  53
+
53 54
         return obj_list
54 55
 
55 56
 
56 57
 class RelatedObjectsDescriptor(object):
57  
-    def __init__(self, model=None, from_field='parent', to_field='object'):
  58
+    def __init__(self, model=None, from_field='parent', to_field='object', symmetrical=False):
58 59
         self.related_model = model or RelatedObject
59 60
         self.from_field = self.get_related_model_field(from_field)
60 61
         self.to_field = self.get_related_model_field(to_field)
61  
-    
  62
+        self.symmetrical = symmetrical
  63
+
62 64
     def get_related_model_field(self, field_name):
63 65
         opts = self.related_model._meta
64 66
         for virtual_field in opts.virtual_fields:
65 67
             if virtual_field.name == field_name:
66 68
                 return virtual_field
67 69
         return opts.get_field(field_name)
68  
-    
  70
+
69 71
     def is_gfk(self, field):
70 72
         return isinstance(field, GenericForeignKey)
71  
-    
  73
+
72 74
     def get_query_for_field(self, instance, field):
73 75
         if self.is_gfk(field):
74 76
             ctype = ContentType.objects.get_for_model(instance)
@@ -78,20 +80,20 @@ def get_query_for_field(self, instance, field):
78 80
             }
79 81
         elif isinstance(instance, field.rel.to):
80 82
             return {field.name: instance}
81  
-        
  83
+
82 84
         raise TypeError('Unable to query %s with %s' % (field, instance))
83  
-    
  85
+
84 86
     def get_query_from(self, instance):
85 87
         return self.get_query_for_field(instance, self.from_field)
86  
-    
  88
+
87 89
     def get_query_to(self, instance):
88 90
         return self.get_query_for_field(instance, self.to_field)
89  
-    
  91
+
90 92
     def contribute_to_class(self, cls, name):
91 93
         self.name = name
92 94
         self.model_class = cls
93 95
         setattr(cls, self.name, self)
94  
-    
  96
+
95 97
     def __get__(self, instance, cls=None):
96 98
         if instance is None:
97 99
             return self
@@ -110,15 +112,20 @@ def delete_manager(self, instance):
110 112
         return self.create_manager(instance,
111 113
                 self.related_model._base_manager.__class__)
112 114
 
113  
-    def create_manager(self, instance, superclass):
  115
+    def create_manager(self, instance, superclass, cf_from=True):
114 116
         rel_obj = self
115  
-        core_filters = self.get_query_from(instance)
116  
-        uses_gfk = self.is_gfk(self.to_field)
  117
+        if cf_from:
  118
+            core_filters = self.get_query_from(instance)
  119
+            rel_field = self.to_field
  120
+        else:
  121
+            core_filters = self.get_query_to(instance)
  122
+            rel_field = self.from_field
  123
+        uses_gfk = self.is_gfk(rel_field)
117 124
 
118 125
         class RelatedManager(superclass):
119 126
             def get_query_set(self):
120 127
                 if uses_gfk:
121  
-                    qs = GFKOptimizedQuerySet(self.model, gfk_field=rel_obj.to_field)
  128
+                    qs = GFKOptimizedQuerySet(self.model, gfk_field=rel_field)
122 129
                     return qs.filter(**(core_filters))
123 130
                 else:
124 131
                     return superclass.get_query_set(self).filter(**(core_filters))
@@ -155,14 +162,15 @@ def remove(self, *objs):
155 162
             def clear(self):
156 163
                 self.all().delete()
157 164
             clear.alters_data = True
158  
-            
  165
+
159 166
             def connect(self, obj, **kwargs):
160 167
                 kwargs.update(rel_obj.get_query_to(obj))
161 168
                 connection, created = self.get_or_create(**kwargs)
162 169
                 return connection
163  
-            
  170
+
164 171
             def related_to(self):
165  
-                return rel_obj.related_model._default_manager.filter(
  172
+                mgr = rel_obj.create_manager(instance, superclass, False)
  173
+                return mgr.filter(
166 174
                     **rel_obj.get_query_to(instance)
167 175
                 )
168 176
 
@@ -171,7 +179,7 @@ def related_to(self):
171 179
         manager.model = self.related_model
172 180
 
173 181
         return manager
174  
-    
  182
+
175 183
     def all(self):
176 184
         if self.is_gfk(self.from_field):
177 185
             ctype = ContentType.objects.get_for_model(self.model_class)
@@ -195,10 +203,10 @@ class BaseGFKRelatedObject(models.Model):
195 203
     object_type = models.ForeignKey(ContentType, related_name="related_%(class)s")
196 204
     object_id = models.IntegerField(db_index=True)
197 205
     object = GenericForeignKey(ct_field="object_type", fk_field="object_id")
198  
-    
  206
+
199 207
     class Meta:
200 208
         abstract = True
201  
-    
  209
+
202 210
 
203 211
 class RelatedObject(BaseGFKRelatedObject):
204 212
     """
@@ -208,9 +216,9 @@ class RelatedObject(BaseGFKRelatedObject):
208 216
     """
209 217
     alias = models.CharField(max_length=255, blank=True)
210 218
     creation_date = models.DateTimeField(auto_now_add=True)
211  
-    
  219
+
212 220
     class Meta:
213 221
         ordering = ('-creation_date',)
214  
-    
  222
+
215 223
     def __unicode__(self):
216 224
         return '%s related to %s ("%s")' % (self.parent, self.object, self.alias)

0 notes on commit a6657a9

Please sign in to comment.
Something went wrong with that request. Please try again.