Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

[multi-db] Added preliminary drop-table generation to django.db.backe…

…nds.ansi.sql.SchemaBuilder.

git-svn-id: http://code.djangoproject.com/svn/django/branches/multiple-db-support@3320 bcc190cf-cafb-0310-a4f2-bffc1f526a37
commit c0132e88f6cdb6dbaa843e9075ae97419b8709d9 1 parent 3e56234
JP authored July 11, 2006
77  django/db/backends/ansi/sql.py
@@ -43,8 +43,13 @@ class SchemaBuilder(object):
43 43
     or other constraints.
44 44
     """
45 45
     def __init__(self):
  46
+        # models that I have created
46 47
         self.models_already_seen = set()
47  
-
  48
+        # model references, keyed by the referrent model
  49
+        self.references = {}
  50
+        # table cache; set to short-circuit table lookups
  51
+        self.tables = None
  52
+        
48 53
     def get_create_table(self, model, style=None):
49 54
         """Construct and return the SQL expression(s) needed to create the
50 55
         table for the given model, and any constraints on that
@@ -218,8 +223,8 @@ def get_create_many_to_many(self, model, style=None):
218 223
     def get_drop_table(self, model, cascade=False, style=None):
219 224
         """Construct and return the SQL statment(s) needed to drop a model's
220 225
         table. If cascade is true, then output additional statments to drop any
221  
-        dependant man-many tables and drop any foreign keys that reference
222  
-        this table.
  226
+        many-to-many tables that this table created and any foreign keys that
  227
+        reference this table.
223 228
         """
224 229
         if style is None:
225 230
             style = default_style
@@ -227,16 +232,45 @@ def get_drop_table(self, model, cascade=False, style=None):
227 232
         info = opts.connection_info
228 233
         db_table = opts.db_table
229 234
         backend = info.backend
  235
+        qn = backend.quote_name
230 236
         output = []
231 237
         output.append(BoundStatement(
232 238
                 '%s %s;' % (style.SQL_KEYWORD('DROP TABLE'),
233  
-                            style.SQL_TABLE(backend.quote_name(db_table))),
  239
+                            style.SQL_TABLE(qn(db_table))),
234 240
                 info.connection))
235 241
 
236 242
         if cascade:
237  
-            # FIXME deal with my foreign keys, others that might have a foreign
238  
-            # key TO me, and many-many
239  
-            pass
  243
+            # deal with others that might have a foreign key TO me: alter
  244
+            # their tables to drop the constraint
  245
+            if backend.supports_constraints:
  246
+                references_to_delete = self.get_references()
  247
+                if model in references_to_delete:
  248
+                    for rel_class, f in references_to_delete[model]:
  249
+                        table = rel_class._meta.db_table
  250
+                        if not self.table_exists(info, table):
  251
+                            continue
  252
+                        col = f.column
  253
+                        r_table = opts.db_table
  254
+                        r_col = opts.get_field(f.rel.field_name).column
  255
+                        output.append(BoundStatement(
  256
+                            '%s %s %s %s;' % 
  257
+                            (style.SQL_KEYWORD('ALTER TABLE'),
  258
+                             style.SQL_TABLE(qn(table)),
  259
+                             style.SQL_KEYWORD(
  260
+                                        backend.get_drop_foreignkey_sql()),
  261
+                             style.SQL_FIELD(qn("%s_referencing_%s_%s" %
  262
+                                                (col, r_table, r_col)))),
  263
+                            info.connection))
  264
+                    del references_to_delete[model]
  265
+            # many to many: drop any many-many tables that are my
  266
+            # responsiblity
  267
+            for f in opts.many_to_many:
  268
+                if not isinstance(f.rel, models.GenericRel):
  269
+                    output.append(BoundStatement(
  270
+                            '%s %s;' %
  271
+                            (style.SQL_KEYWORD('DROP TABLE'),
  272
+                             style.SQL_TABLE(qn(f.m2m_db_table()))),
  273
+                            info.connection))
240 274
         # Reverse it, to deal with table dependencies.        
241 275
         output.reverse()
242 276
         return output
@@ -273,11 +307,36 @@ def get_initialdata(self, model):
273 307
     def get_initialdata_path(self, model):
274 308
         """Get the path from which to load sql initial data files for a model.
275 309
         """
276  
-        return os.path.normpath(os.path.join(os.path.dirname(models.get_app(model._meta.app_label).__file__), 'sql'))
  310
+        return os.path.normpath(os.path.join(os.path.dirname(
  311
+                    models.get_app(model._meta.app_label).__file__), 'sql'))
277 312
             
278 313
     def get_rel_data_type(self, f):
279 314
         return (f.get_internal_type() in ('AutoField', 'PositiveIntegerField',
280 315
                                           'PositiveSmallIntegerField')) \
281 316
                                           and 'IntegerField' \
282 317
                                           or f.get_internal_type()
283  
-        
  318
+    
  319
+    def get_references(self):
  320
+        """Fill (if needed) and return the reference cache.
  321
+        """
  322
+        if self.references:
  323
+            return self.references
  324
+        for klass in models.get_models():
  325
+            for f in klass._meta.fields:
  326
+                if f.rel:
  327
+                    self.references.setdefault(f.rel.to, []).append((klass, f))
  328
+        return self.references
  329
+
  330
+    def get_table_list(self, connection_info):
  331
+        """Get list of tables accessible via the connection described by
  332
+        connection_info.
  333
+        """
  334
+        if self.tables is not None:
  335
+            return self.tables
  336
+        cursor = info.connection.cursor()
  337
+        introspection = connection_info.get_introspection_module()
  338
+        return introspection.get_table_list(cursor)        
  339
+    
  340
+    def table_exists(self, connection_info, table):
  341
+        tables = self.get_table_list(connection_info)
  342
+        return table in tables
5  django/db/models/manager.py
@@ -159,9 +159,8 @@ def get_table_list(self):
159 159
         """Get list of tables accessible via my model's connection.
160 160
         """
161 161
         info = self.model._meta.connection_info
162  
-        cursor = info.connection.cursor()
163  
-        introspect = info.get_introspection_module()
164  
-        return introspect.get_table_list(cursor)
  162
+        builder = info.get_creation_module.builder()
  163
+        return builder.get_table_list(info)
165 164
     
166 165
 class ManagerDescriptor(object):
167 166
     # This class ensures managers aren't accessible via model instances.
109  tests/othertests/ansi_sql.py
... ...
@@ -1,52 +1,24 @@
1  
-# For Python 2.3
2  
-if not hasattr(__builtins__, 'set'):
3  
-    from sets import Set as set
4  
-
5 1
 """
6  
->>> from django.db import models
7 2
 >>> from django.db.backends.ansi import sql
8 3
 
9  
-# test models
10  
->>> class Car(models.Model):
11  
-...     make = models.CharField(maxlength=32)
12  
-...     model = models.CharField(maxlength=32)
13  
-...     year = models.IntegerField()
14  
-...     condition = models.CharField(maxlength=32)
15  
-...     
16  
-...     class Meta:
17  
-...         app_label = 'ansi_sql'
18  
-
19  
->>> class Collector(models.Model):
20  
-...     name = models.CharField(maxlength=32)
21  
-...     cars = models.ManyToManyField(Car)
22  
-...     
23  
-...     class Meta:
24  
-...         app_label = 'ansi_sql'
25  
-
26  
->>> class Mod(models.Model):
27  
-...     car = models.ForeignKey(Car)
28  
-...     part = models.CharField(maxlength=32, db_index=True)
29  
-...     description = models.TextField()
30  
-...     
31  
-...     class Meta:
32  
-...         app_label = 'ansi_sql'
  4
+# so we can test with a predicatable constraint setting
  5
+>>> real_cnst = Mod._meta.connection_info.backend.supports_constraints
  6
+>>> Mod._meta.connection_info.backend.supports_constraints = True
33 7
     
34 8
 # generate create sql
35 9
 >>> builder = sql.SchemaBuilder()
36 10
 >>> builder.get_create_table(Car)
37  
-([BoundStatement('CREATE TABLE "ansi_sql_car" (...);')], [])
  11
+([BoundStatement('CREATE TABLE "ansi_sql_car" (...);')], {})
38 12
 >>> builder.models_already_seen
39  
-[<class 'othertests.ansi_sql.Car'>]
  13
+Set([<class 'othertests.ansi_sql.Car'>])
40 14
 >>> builder.models_already_seen = set()
41 15
 
42 16
 # test that styles are used
43 17
 >>> builder.get_create_table(Car, style=mockstyle())
44  
-([BoundStatement('SQL_KEYWORD(CREATE TABLE) SQL_TABLE("ansi_sql_car") (...SQL_FIELD("id")...);')], [])
  18
+([BoundStatement('SQL_KEYWORD(CREATE TABLE) SQL_TABLE("ansi_sql_car") (...SQL_FIELD("id")...);')], {})
45 19
 
46 20
 # test pending relationships
47 21
 >>> builder.models_already_seen = set()
48  
->>> real_cnst = Mod._meta.connection_info.backend.supports_constraints
49  
->>> Mod._meta.connection_info.backend.supports_constraints = True
50 22
 >>> builder.get_create_table(Mod)
51 23
 ([BoundStatement('CREATE TABLE "ansi_sql_mod" (..."car_id" integer NOT NULL,...);')], {<class 'othertests.ansi_sql.Car'>: [BoundStatement('ALTER TABLE "ansi_sql_mod" ADD CONSTRAINT ... FOREIGN KEY ("car_id") REFERENCES "ansi_sql_car" ("id");')]})
52 24
 >>> builder.models_already_seen = set()
@@ -54,7 +26,6 @@
54 26
 ([BoundStatement('CREATE TABLE "ansi_sql_car" (...);')], {})
55 27
 >>> builder.get_create_table(Mod)
56 28
 ([BoundStatement('CREATE TABLE "ansi_sql_mod" (..."car_id" integer NOT NULL REFERENCES "ansi_sql_car" ("id"),...);')], {})
57  
->>> Mod._meta.connection_info.backend.supports_constraints = real_cnst
58 29
 
59 30
 # test many-many
60 31
 >>> builder.get_create_table(Collector)
@@ -75,16 +46,82 @@
75 46
 >>> builder.get_initialdata_path = othertests_sql
76 47
 >>> builder.get_initialdata(Car)
77 48
 [BoundStatement('insert into ansi_sql_car (...)...values (...);')]
  49
+
  50
+# test drop
  51
+>>> builder.get_drop_table(Mod)
  52
+[BoundStatement('DROP TABLE "ansi_sql_mod";')]
  53
+>>> builder.get_drop_table(Mod, cascade=True)
  54
+[BoundStatement('DROP TABLE "ansi_sql_mod";')]
  55
+>>> builder.get_drop_table(Car)
  56
+[BoundStatement('DROP TABLE "ansi_sql_car";')]
  57
+>>> builder.get_drop_table(Car, cascade=True)
  58
+[BoundStatement('DROP TABLE "ansi_sql_car";')]
  59
+
  60
+>>> builder.tables = ['ansi_sql_car', 'ansi_sql_mod', 'ansi_sql_collector']
  61
+>>> Mod._meta.connection_info.backend.supports_constraints = False
  62
+>>> builder.get_drop_table(Car, cascade=True)
  63
+[BoundStatement('DROP TABLE "ansi_sql_car";')]
  64
+>>> Mod._meta.connection_info.backend.supports_constraints = True
  65
+>>> builder.get_drop_table(Car, cascade=True)
  66
+[BoundStatement('ALTER TABLE "ansi_sql_mod" ...'), BoundStatement('DROP TABLE "ansi_sql_car";')]
  67
+>>> builder.get_drop_table(Collector)
  68
+[BoundStatement('DROP TABLE "ansi_sql_collector";')]
  69
+>>> builder.get_drop_table(Collector, cascade=True)
  70
+[BoundStatement('DROP TABLE "ansi_sql_collector_cars";'), BoundStatement('DROP TABLE "ansi_sql_collector";')]
  71
+>>> Mod._meta.connection_info.backend.supports_constraints = real_cnst
  72
+
78 73
 """
79 74
 import os
  75
+from django.db import models
  76
+from django.core.management import install
  77
+
  78
+# For Python 2.3
  79
+if not hasattr(__builtins__, 'set'):
  80
+    from sets import Set as set
  81
+
  82
+
  83
+# test models
  84
+class Car(models.Model):
  85
+    make = models.CharField(maxlength=32)
  86
+    model = models.CharField(maxlength=32)
  87
+    year = models.IntegerField()
  88
+    condition = models.CharField(maxlength=32)
  89
+    
  90
+    class Meta:
  91
+        app_label = 'ansi_sql'
  92
+
  93
+        
  94
+class Collector(models.Model):
  95
+    name = models.CharField(maxlength=32)
  96
+    cars = models.ManyToManyField(Car)
  97
+    
  98
+    class Meta:
  99
+        app_label = 'ansi_sql'
  100
+
  101
+        
  102
+class Mod(models.Model):
  103
+    car = models.ForeignKey(Car)
  104
+    part = models.CharField(maxlength=32, db_index=True)
  105
+    description = models.TextField()
  106
+    
  107
+    class Meta:
  108
+        app_label = 'ansi_sql'
  109
+
80 110
 
81  
-# mock style that wraps text in STYLE(text), for testing
82 111
 class mockstyle:
  112
+    """mock style that wraps text in STYLE(text), for testing"""
83 113
     def __getattr__(self, attr):
84 114
         if attr in ('ERROR', 'ERROR_OUTPUT', 'SQL_FIELD', 'SQL_COLTYPE',
85 115
                     'SQL_KEYWORD', 'SQL_TABLE'):
86 116
             return lambda text: "%s(%s)" % (attr, text)
87 117
 
  118
+        
88 119
 def othertests_sql(mod):
89 120
     """Look in othertests/sql for sql initialdata"""
90 121
     return os.path.normpath(os.path.join(os.path.dirname(__file__), 'sql'))
  122
+
  123
+
  124
+# install my stuff
  125
+Car.objects.install()
  126
+Collector.objects.install()
  127
+Mod.objects.install()

0 notes on commit c0132e8

Please sign in to comment.
Something went wrong with that request. Please try again.