Skip to content

Commit

Permalink
Added a AddField operation.
Browse files Browse the repository at this point in the history
  • Loading branch information
charettes committed Jan 8, 2016
1 parent a64d758 commit 692e664
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 3 deletions.
18 changes: 16 additions & 2 deletions tenancy/operations.py
Expand Up @@ -46,10 +46,14 @@ def tenant_operation(self, tenant_model, operation, app_label, schema_editor, fr
tenant_from_state = self.create_tenant_project_state(tenant, from_state, connection)
tenant_to_state = self.create_tenant_project_state(tenant, to_state, connection)
if connection.vendor == 'postgresql':
cursor.execute("SET search_path = %s, public" % tenant.db_schema)
sql = "SET search_path = %s, public" % tenant.db_schema
cursor.execute(sql)
schema_editor.deferred_sql.append(sql)
operation(app_label, schema_editor, tenant_from_state, tenant_to_state)
if connection.vendor == 'postgresql':
cursor.execute('RESET search_path')
sql = 'RESET search_path'
cursor.execute(sql)
schema_editor.deferred_sql.append(sql)

def database_forwards(self, app_label, schema_editor, from_state, to_state):
tenant_model = self.get_tenant_model(app_label, from_state, to_state)
Expand Down Expand Up @@ -90,3 +94,13 @@ class AlterUniqueTogether(TenantModelOperation, operations.AlterUniqueTogether):

class AlterIndexTogether(TenantModelOperation, operations.AlterIndexTogether):
database_backwards = operations.AlterIndexTogether.database_backwards


class TenantModelFieldOperation(TenantModelOperation):
def get_operation_model_state(self, app_label, from_state, to_state):
# XXX: Use self.model_name_lower when dropping support for Django 1.7
return from_state.models[app_label, self.model_name.lower()]


class AddFied(TenantModelFieldOperation, operations.AddField):
pass
103 changes: 102 additions & 1 deletion tests/test_operations.py
Expand Up @@ -29,9 +29,30 @@ def get_tenant_table_names(self, tenant):
cursor.execute('RESET search_path')
return table_names

def get_tenant_table_columns(self, tenant, table_name):
cursor = connection.cursor()
if connection.vendor == 'postgresql':
cursor.execute("SET search_path = %s" % tenant.db_schema)
tenant_table_name = self.get_tenant_table_name(tenant, table_name)
columns = connection.introspection.get_table_description(cursor, tenant_table_name)
if connection.vendor == 'postgresql':
cursor.execute('RESET search_path')
return columns

def get_tenant_table_constraints(self, tenant, table_name):
cursor = connection.cursor()
if connection.vendor == 'postgresql':
cursor.execute("SET search_path = %s" % tenant.db_schema)
tenant_table_name = self.get_tenant_table_name(tenant, table_name)
return connection.introspection.get_constraints(connection.cursor(), tenant_table_name)
constraints = connection.introspection.get_constraints(cursor, tenant_table_name)
if connection.vendor == 'postgresql':
cursor.execute('RESET search_path')
return constraints

def get_column_constraints(self, constraints, column):
return {
name: details for name, details in constraints.items() if details['columns'] == [column]
}

def assertTenantTableExists(self, tenant, table_name):
table_name = self.get_tenant_table_name(tenant, table_name)
Expand Down Expand Up @@ -158,3 +179,83 @@ def test_alter_index_together(self):
break
else:
self.fail('Missing index.')

@override_settings(MIGRATION_MODULES={'tests': 'tests.test_operations_migrations.add_field'})
def test_add_field(self):
call_command('migrate', 'tests', '0001', interactive=False, stdout=StringIO())
for tenant in Tenant.objects.all():
self.assertTenantTableExists(tenant, 'tests_addfield')
self.assertEqual(len(self.get_tenant_table_columns(tenant, 'tests_addfield')), 1)
call_command('migrate', 'tests', interactive=False, stdout=StringIO())
for tenant in Tenant.objects.all():
columns = self.get_tenant_table_columns(tenant, 'tests_addfield')
constraints = self.get_tenant_table_constraints(tenant, 'tests_addfield')
self.assertEqual(columns[1][0], 'charfield')
self.assertEqual(self.get_column_constraints(constraints, 'charfield').values(), [{
'index': connection.vendor != 'postgresql',
'primary_key': False,
# The get_constraints() method doesn't correctly set `foreign_key`
# to `False` on PostgreSQL.
'foreign_key': None if connection.vendor == 'postgresql' else False,
'unique': True,
'check': False,
'columns': ['charfield'],
}])
self.assertEqual(columns[2][0], 'textfield')
self.assertEqual(self.get_column_constraints(constraints, 'textfield').values(), [{
'index': True,
'primary_key': False,
# The get_constraints() method doesn't correctly set `foreign_key`
# to `False` on PostgreSQL.
'foreign_key': None if connection.vendor == 'postgresql' else False,
'unique': False,
'check': False,
'columns': ['textfield'],
}])
self.assertEqual(columns[3][0], 'positiveintegerfield')
if connection.vendor == 'postgresql':
self.assertEqual(self.get_column_constraints(constraints, 'positiveintegerfield').values(), [{
'index': False,
'primary_key': False,
# The get_constraints() method doesn't correctly set `foreign_key`
# to `False` on PostgreSQL.
'foreign_key': None if connection.vendor == 'postgresql' else False,
'unique': False,
'check': True,
'columns': ['positiveintegerfield'],
}])
self.assertEqual(columns[4][0], 'foreign_key_id')
foreign_key_constraints = self.get_column_constraints(constraints, 'foreign_key_id').values()
expected_index = {
'index': True,
'primary_key': False,
# The get_constraints() method doesn't correctly set `foreign_key`
# to `False` on PostgreSQL.
'foreign_key': None if connection.vendor == 'postgresql' else False,
'unique': False,
'check': False,
'columns': ['foreign_key_id'],
}
for constraint in foreign_key_constraints:
if constraint == expected_index:
break
else:
self.fail('Missing fk index.')
if connection.vendor == 'postgresql':
expected_fk = {
'index': False,
'primary_key': False,
'foreign_key': ('tests_addfield', 'id'),
'unique': False,
'check': False,
'columns': ['foreign_key_id'],
}
for constraint in foreign_key_constraints:
if constraint == expected_fk:
break
else:
self.fail('Missing fk.')
call_command('migrate', 'tests', '0001', interactive=False, stdout=StringIO())
for tenant in Tenant.objects.all():
self.assertTenantTableExists(tenant, 'tests_addfield')
self.assertEqual(len(self.get_tenant_table_columns(tenant, 'tests_addfield')), 1)
23 changes: 23 additions & 0 deletions tests/test_operations_migrations/add_field/0001_create_model.py
@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals

from django.db import migrations, models

from tenancy.models import Managed
from tenancy.operations import CreateModel


class Migration(migrations.Migration):

operations = [
CreateModel(
name='AddField',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
],
bases=(models.Model,),
options={
'managed': Managed('tenancy.Tenant'),
}
),
]
20 changes: 20 additions & 0 deletions tests/test_operations_migrations/add_field/0002_add_fields.py
@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals

from django.db import migrations, models

from tenancy.operations import AddFied


class Migration(migrations.Migration):

dependencies = [
('tests', '0001_create_model'),
]

operations = [
AddFied('AddField', 'charfield', models.CharField(max_length=100, default='', unique=True)),
AddFied('AddField', 'textfield', models.TextField(db_index=True)),
AddFied('AddField', 'positiveintegerfield', models.PositiveIntegerField(default=0)),
AddFied('AddField', 'foreign_key', models.ForeignKey('tests.AddField', null=True, on_delete=models.CASCADE)),
]
Empty file.

0 comments on commit 692e664

Please sign in to comment.