diff --git a/tenancy/auth/backends.py b/tenancy/auth/backends.py index bb18fd4..b995484 100644 --- a/tenancy/auth/backends.py +++ b/tenancy/auth/backends.py @@ -29,7 +29,7 @@ def __init__(self): "`tenancy.middleware.GlobalTenantMiddleware` does " "just that." % attr_name ) - self.tenant_user_model = user_model.for_tenant(tenant) + self.tenant_user_model = tenant.models[user_model] def authenticate(self, username=None, password=None, **kwargs): if username is None: diff --git a/tenancy/forms.py b/tenancy/forms.py index e7cefb1..75bcf7d 100644 --- a/tenancy/forms.py +++ b/tenancy/forms.py @@ -10,7 +10,7 @@ def _get_tenant_model(tenant, model): raise ImproperlyConfigured( "%s must be an instance of TenantModelBase" % model.__name__ ) - return model.for_tenant(tenant) + return tenant.models[model] def tenant_modelform_factory(tenant, form): diff --git a/tenancy/management/commands/createsuperuser.py b/tenancy/management/commands/createsuperuser.py index 59a8752..f019344 100644 --- a/tenancy/management/commands/createsuperuser.py +++ b/tenancy/management/commands/createsuperuser.py @@ -31,7 +31,7 @@ def __init__(self): def handle(self, *args, **kwargs): tenant = kwargs.get('tenant') if tenant: - self.UserModel = self.UserModel.for_tenant(tenant) + self.UserModel = tenant.models[self.UserModel] elif self.tenant_auth_user_model: raise CommandError( "Since your swapped `AUTH_USER_MODEL` is tenant specific " diff --git a/tenancy/models.py b/tenancy/models.py index 0eca6af..f3499ea 100644 --- a/tenancy/models.py +++ b/tenancy/models.py @@ -18,7 +18,7 @@ from django.db.models.fields.related import add_lazy_relation from django.db.models.loading import get_model from django.dispatch.dispatcher import receiver -from django.utils.six import with_metaclass, string_types +from django.utils.six import itervalues, string_types, with_metaclass from django.utils.six.moves import copyreg from . import get_tenant_model @@ -29,7 +29,23 @@ receivers_for_model, remove_from_app_cache) -class TenantModelsCache(object): +class TenantModels(object): + __slots__ = ['references'] + + def __init__(self, tenant): + self.references = OrderedDict(( + (reference, reference.for_tenant(tenant)) + for reference in TenantModelBase.references + )) + + def __getitem__(self, key): + return self.references[key] + + def __iter__(self, **kwargs): + return itervalues(self.references, **kwargs) + + +class TenantModelsDescriptor(object): def contribute_to_class(self, cls, name): self.name = name setattr(cls, name, self) @@ -43,10 +59,7 @@ def __get__(self, instance, owner): try: models = instance.__dict__[self.name] except KeyError: - models = tuple( - reference.for_tenant(instance) - for reference in TenantModelBase.references - ) + models = TenantModels(instance) self.__set__(instance, models) return models @@ -86,7 +99,7 @@ def delete(self, *args, **kwargs): def natural_key(self): raise NotImplementedError - models = TenantModelsCache() + models = TenantModelsDescriptor() @contextmanager def as_global(self): @@ -142,7 +155,7 @@ def db_schema_table(tenant, db_table): class Reference(object): - __slots__ = ('model', 'bases', 'Meta', 'related_names') + __slots__ = ['model', 'bases', 'Meta', 'related_names'] def __init__(self, model, Meta, related_names=None): self.model = model @@ -551,7 +564,7 @@ def __pickle_tenant_model_base(model): class TenantModelDescriptor(object): - __slots__ = ('model',) + __slots__ = ['model'] def __init__(self, model): self.model = model @@ -559,7 +572,7 @@ def __init__(self, model): def __get__(self, tenant, owner): if not tenant: return self - return self.model.for_tenant(tenant)._default_manager + return tenant.models[self.model]._default_manager class TenantModel(with_metaclass(TenantModelBase, models.Model)): diff --git a/tenancy/tests/test_commands.py b/tenancy/tests/test_commands.py index 3457ff7..1ae5ad2 100644 --- a/tenancy/tests/test_commands.py +++ b/tenancy/tests/test_commands.py @@ -55,13 +55,15 @@ def test_verbosity(self): tenant = Tenant.objects.get(name='tenant') stdout.seek(0) connection = connections[tenant._state.db] - if connection.vendor == 'postgresql': - self.assertIn(tenant.db_schema, stdout.readline()) - for model in TenantModelBase.references: - self.assertIn(model._meta.object_name, stdout.readline()) - self.assertIn(model._meta.db_table, stdout.readline()) - self.assertIn('Installing indexes ...', stdout.readline()) - tenant.delete() + try: + if connection.vendor == 'postgresql': + self.assertIn(tenant.db_schema, stdout.readline()) + for model in TenantModelBase.references: + self.assertIn(model._meta.object_name, stdout.readline()) + self.assertIn(model._meta.db_table, stdout.readline()) + self.assertIn('Installing indexes ...', stdout.readline()) + finally: + tenant.delete() @setup_custom_tenant_user @mock_inputs(( diff --git a/tenancy/tests/test_models.py b/tenancy/tests/test_models.py index c5f54d1..2f94605 100644 --- a/tenancy/tests/test_models.py +++ b/tenancy/tests/test_models.py @@ -108,12 +108,21 @@ def test_model_garbage_collection(self): self.assertIsNone(model_wref()) -class TenantModelsCacheTest(TenancyTestCase): - def test_initialized_models(self): - """ - Make sure models are loaded upon model initialization. - """ - self.assertIn('models', self.tenant.__dict__) +class TenantModelsDescriptorTest(TenancyTestCase): + def setUp(self): + super(TenantModelsDescriptorTest, self).setUp() + Tenant.objects.clear_cache() + + def test_uncached_upon_tenant_initialization(self): + """Make sure models are not created upon model initialization.""" + tenant = Tenant.objects.get(pk=self.tenant.pk) + self.assertNotIn('models', tenant.__dict__) + + def test_created_models_upon_cache_access(self): + """Make sure all tenant models are created upon cache access.""" + tenant = Tenant.objects.get(pk=self.tenant.pk) + for reference in TenantModelBase.references: + tenant.models[reference] class TenantModelBaseTest(TenancyTestCase): diff --git a/tenancy/views.py b/tenancy/views.py index 1cfce7d..885f9fe 100644 --- a/tenancy/views.py +++ b/tenancy/views.py @@ -52,7 +52,8 @@ def get_model(self): ) def get_tenant_model(self): - return self.get_model().for_tenant(self.get_tenant()) + tenant = self.get_tenant() + return tenant.models[self.get_model()] def get_queryset(self): return self.get_tenant_model()._default_manager.all()