diff --git a/django_tenants/postgresql_backend/base.py b/django_tenants/postgresql_backend/base.py index 69f0f67b..c833e537 100644 --- a/django_tenants/postgresql_backend/base.py +++ b/django_tenants/postgresql_backend/base.py @@ -74,17 +74,6 @@ def set_tenant(self, tenant, include_public=True): self.include_public_schema = include_public self.set_settings_schema(self.schema_name) self.search_path_set = False - - def set_schema(self, schema_name, include_public=True): - """ - Main API method to current database schema, - but it does not actually modify the db connection. - """ - self.tenant = FakeTenant(schema_name=schema_name) - self.schema_name = schema_name - self.include_public_schema = include_public - self.set_settings_schema(schema_name) - self.search_path_set = False # Content type can no longer be cached as public and tenant schemas # have different models. If someone wants to change this, the cache # needs to be separated between public and shared schemas. If this @@ -94,14 +83,18 @@ def set_schema(self, schema_name, include_public=True): # wrong model will be fetched. ContentType.objects.clear_cache() + def set_schema(self, schema_name, include_public=True): + """ + Main API method to current database schema, + but it does not actually modify the db connection. + """ + self.set_tenant(FakeTenant(schema_name=schema_name), include_public) + def set_schema_to_public(self): """ Instructs to stay in the common 'public' schema. """ - self.tenant = FakeTenant(schema_name=get_public_schema_name()) - self.schema_name = get_public_schema_name() - self.set_settings_schema(self.schema_name) - self.search_path_set = False + self.set_tenant(FakeTenant(schema_name=get_public_schema_name())) def set_settings_schema(self, schema_name): self.settings_dict['SCHEMA'] = schema_name diff --git a/django_tenants/utils.py b/django_tenants/utils.py index 31c56424..65f937e3 100644 --- a/django_tenants/utils.py +++ b/django_tenants/utils.py @@ -81,21 +81,9 @@ def __exit__(self, *exc): self.connection.set_tenant(self.previous_tenant) -class tenant_context(ContextDecorator): +class tenant_context(schema_context): def __init__(self, *args, **kwargs): - self.tenant = args[0] - super().__init__() - - def __enter__(self): - self.connection = connections[get_tenant_database_alias()] - self.previous_tenant = connection.tenant - self.connection.set_tenant(self.tenant) - - def __exit__(self, *exc): - if self.previous_tenant is None: - self.connection.set_schema_to_public() - else: - self.connection.set_tenant(self.previous_tenant) + super().__init__(args[0].schema_name, **kwargs) def clean_tenant_url(url_string):