Skip to content

Commit

Permalink
Correctly handle virtual related fields.
Browse files Browse the repository at this point in the history
  • Loading branch information
charettes committed Oct 21, 2013
1 parent dce7e00 commit 9e0fa5a
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 17 deletions.
36 changes: 22 additions & 14 deletions tenancy/models.py
Expand Up @@ -231,14 +231,17 @@ def __new__(cls, name, bases, attrs):
cls.references[model] = cls.reference(model, Meta, related_names)
opts = model._meta
# Validate related name of related fields.
for field in opts.local_fields:
if field.rel:
for field in (opts.local_fields + opts.virtual_fields):
rel = getattr(field, 'rel', None)
if rel:
cls.validate_related_name(field, field.rel.to, model)
# Replace and store the current `on_delete` value to
# make sure non-tenant models are not collected on
# deletion.
field.rel._on_delete = field.rel.on_delete
field.rel.on_delete = DO_NOTHING
on_delete = rel.on_delete
if on_delete is not DO_NOTHING:
rel._on_delete = on_delete
rel.on_delete = DO_NOTHING
for m2m in opts.local_many_to_many:
rel = m2m.rel
to = rel.to
Expand Down Expand Up @@ -389,12 +392,16 @@ def abstract_tenant_model_factory(self, tenant):
local_ptr = self._meta.parents[parent._for_tenant_model]
ptr.name = None
ptr.set_attributes_from_name(local_ptr.name)

# Add the local fields of this class
local_fields = self._meta.local_fields + self._meta.local_many_to_many
for local_field in local_fields:
field = copy.deepcopy(local_field)
rel = field.rel
# Add copy of the fields to cloak the inherited ones.
fields = (
copy.deepcopy(field) for field in (
self._meta.local_fields +
self._meta.local_many_to_many +
self._meta.virtual_fields
)
)
for field in fields:
rel = getattr(field, 'rel', None)
if rel:
# Make sure related fields pointing to tenant models are
# pointing to their tenant specific counterpart.
Expand All @@ -419,10 +426,11 @@ def abstract_tenant_model_factory(self, tenant):
if isinstance(field, models.ManyToManyField):
through = field.rel.through
rel.through = self.references[through].for_tenant(tenant)
else:
# Re-assign the correct `on_delete` that was swapped for
# `DO_NOTHING` to prevent non-tenant model collection.
rel.on_delete = rel._on_delete
# Re-assign the correct `on_delete` that was swapped for
# `DO_NOTHING` to prevent non-tenant model collection.
on_delete = getattr(rel, '_on_delete', None)
if on_delete:
rel.on_delete = on_delete
field.contribute_to_class(model, field.name)

return model
Expand Down
14 changes: 13 additions & 1 deletion tenancy/tests/models.py
Expand Up @@ -137,7 +137,7 @@ class TenantMeta:


class M2MSpecific(TenantModel):
related = models.ForeignKey('RelatedTenantModel')
related = models.ForeignKey('RelatedTenantModel', null=True)
specific = models.ForeignKey(
SpecificModel, related_name="%(app_label)s_%(class)s_related"
)
Expand Down Expand Up @@ -230,3 +230,15 @@ class MutableModelSubclass(MutableModel):

class NonMutableModel(TenantModel):
mutable_fk = models.ForeignKey(MutableModel, related_name='non_mutables')


try:
from django.db.models.fields.related import ForeignObject
except ImportError:
pass
else:
ForeignObject(
RelatedTenantModel, ['specific'], ['fk'], related_name='+'
).contribute_to_class(
M2MSpecific, 'specific_related_fk', virtual_only=True
)
17 changes: 15 additions & 2 deletions tenancy/tests/test_models.py
Expand Up @@ -6,9 +6,9 @@
import sys
# TODO: Remove when support for Python 2.6 is dropped
if sys.version_info >= (2, 7):
from unittest import skipIf
from unittest import skipIf, skipUnless
else:
from django.utils.unittest import skipIf
from django.utils.unittest import skipIf, skipUnless
import weakref

from django.contrib.contenttypes.models import ContentType
Expand Down Expand Up @@ -525,6 +525,19 @@ def test_signals(self):
]
)

@skipUnless(hasattr(django_models, 'ForeignObject'),
'Foreign object is not present is this version of Django.')
def test_virtual_foreign_object(self):
"""
Make sure a virtual foreign object pointing to a tenant specific
model is correctly handled.
"""
for tenant in Tenant.objects.all():
specific = tenant.specificmodels.create()
related = tenant.related_tenant_models.create(fk=specific)
m2m_specific = tenant.m2m_specifics.create(specific=specific)
self.assertEqual(m2m_specific.specific_related_fk, related)


class NonTenantModelTest(TransactionTestCase):
def test_fk_to_tenant(self):
Expand Down

0 comments on commit 9e0fa5a

Please sign in to comment.