Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

Fixed #12460 -- Improved inspectdb handling of special field names

Thanks mihail lukin for the report and elijahr and kgibula for their
work on the patch.
  • Loading branch information...
commit 395c6083af93cc37c7d16ef4db451091841cefdc 1 parent 10d3207
Claude Paroz authored August 23, 2012
117  django/core/management/commands/inspectdb.py
... ...
@@ -1,3 +1,5 @@
  1
+from __future__ import unicode_literals
  2
+
1 3
 import keyword
2 4
 from optparse import make_option
3 5
 
@@ -31,6 +33,7 @@ def handle_inspection(self, options):
31 33
         table_name_filter = options.get('table_name_filter')
32 34
 
33 35
         table2model = lambda table_name: table_name.title().replace('_', '').replace(' ', '').replace('-', '')
  36
+        strip_prefix = lambda s: s.startswith("u'") and s[1:] or s
34 37
 
35 38
         cursor = connection.cursor()
36 39
         yield "# This is an auto-generated Django model module."
@@ -41,6 +44,7 @@ def handle_inspection(self, options):
41 44
         yield "#"
42 45
         yield "# Also note: You'll have to insert the output of 'django-admin.py sqlcustom [appname]'"
43 46
         yield "# into your database."
  47
+        yield "from __future__ import unicode_literals"
44 48
         yield ''
45 49
         yield 'from %s import models' % self.db_module
46 50
         yield ''
@@ -59,16 +63,19 @@ def handle_inspection(self, options):
59 63
                 indexes = connection.introspection.get_indexes(cursor, table_name)
60 64
             except NotImplementedError:
61 65
                 indexes = {}
  66
+            used_column_names = [] # Holds column names used in the table so far
62 67
             for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)):
63  
-                column_name = row[0]
64  
-                att_name = column_name.lower()
65 68
                 comment_notes = [] # Holds Field notes, to be displayed in a Python comment.
66 69
                 extra_params = {}  # Holds Field parameters such as 'db_column'.
  70
+                column_name = row[0]
  71
+                is_relation = i in relations
  72
+
  73
+                att_name, params, notes = self.normalize_col_name(
  74
+                    column_name, used_column_names, is_relation)
  75
+                extra_params.update(params)
  76
+                comment_notes.extend(notes)
67 77
 
68  
-                # If the column name can't be used verbatim as a Python
69  
-                # attribute, set the "db_column" for this Field.
70  
-                if ' ' in att_name or '-' in att_name or keyword.iskeyword(att_name) or column_name != att_name:
71  
-                    extra_params['db_column'] = column_name
  78
+                used_column_names.append(att_name)
72 79
 
73 80
                 # Add primary_key and unique, if necessary.
74 81
                 if column_name in indexes:
@@ -77,30 +84,12 @@ def handle_inspection(self, options):
77 84
                     elif indexes[column_name]['unique']:
78 85
                         extra_params['unique'] = True
79 86
 
80  
-                # Modify the field name to make it Python-compatible.
81  
-                if ' ' in att_name:
82  
-                    att_name = att_name.replace(' ', '_')
83  
-                    comment_notes.append('Field renamed to remove spaces.')
84  
-
85  
-                if '-' in att_name:
86  
-                    att_name = att_name.replace('-', '_')
87  
-                    comment_notes.append('Field renamed to remove dashes.')
88  
-
89  
-                if column_name != att_name:
90  
-                    comment_notes.append('Field name made lowercase.')
91  
-
92  
-                if i in relations:
  87
+                if is_relation:
93 88
                     rel_to = relations[i][1] == table_name and "'self'" or table2model(relations[i][1])
94  
-
95 89
                     if rel_to in known_models:
96 90
                         field_type = 'ForeignKey(%s' % rel_to
97 91
                     else:
98 92
                         field_type = "ForeignKey('%s'" % rel_to
99  
-
100  
-                    if att_name.endswith('_id'):
101  
-                        att_name = att_name[:-3]
102  
-                    else:
103  
-                        extra_params['db_column'] = column_name
104 93
                 else:
105 94
                     # Calling `get_field_type` to get the field type string and any
106 95
                     # additional parameters and notes.
@@ -110,16 +99,6 @@ def handle_inspection(self, options):
110 99
 
111 100
                     field_type += '('
112 101
 
113  
-                if keyword.iskeyword(att_name):
114  
-                    att_name += '_field'
115  
-                    comment_notes.append('Field renamed because it was a Python reserved word.')
116  
-
117  
-                if att_name[0].isdigit():
118  
-                    att_name = 'number_%s' % att_name
119  
-                    extra_params['db_column'] = six.text_type(column_name)
120  
-                    comment_notes.append("Field renamed because it wasn't a "
121  
-                        "valid Python identifier.")
122  
-
123 102
                 # Don't output 'id = meta.AutoField(primary_key=True)', because
124 103
                 # that's assumed if it doesn't exist.
125 104
                 if att_name == 'id' and field_type == 'AutoField(' and extra_params == {'primary_key': True}:
@@ -136,7 +115,9 @@ def handle_inspection(self, options):
136 115
                 if extra_params:
137 116
                     if not field_desc.endswith('('):
138 117
                         field_desc += ', '
139  
-                    field_desc += ', '.join(['%s=%r' % (k, v) for k, v in extra_params.items()])
  118
+                    field_desc += ', '.join([
  119
+                        '%s=%s' % (k, strip_prefix(repr(v)))
  120
+                        for k, v in extra_params.items()])
140 121
                 field_desc += ')'
141 122
                 if comment_notes:
142 123
                     field_desc += ' # ' + ' '.join(comment_notes)
@@ -144,6 +125,64 @@ def handle_inspection(self, options):
144 125
             for meta_line in self.get_meta(table_name):
145 126
                 yield meta_line
146 127
 
  128
+    def normalize_col_name(self, col_name, used_column_names, is_relation):
  129
+        """
  130
+        Modify the column name to make it Python-compatible as a field name
  131
+        """
  132
+        field_params = {}
  133
+        field_notes = []
  134
+
  135
+        new_name = col_name.lower()
  136
+        if new_name != col_name:
  137
+            field_notes.append('Field name made lowercase.')
  138
+
  139
+        if is_relation:
  140
+            if new_name.endswith('_id'):
  141
+                new_name = new_name[:-3]
  142
+            else:
  143
+                field_params['db_column'] = col_name
  144
+
  145
+        if ' ' in new_name:
  146
+            new_name = new_name.replace(' ', '_')
  147
+            field_notes.append('Field renamed to remove spaces.')
  148
+
  149
+        if '-' in new_name:
  150
+            new_name = new_name.replace('-', '_')
  151
+            field_notes.append('Field renamed to remove dashes.')
  152
+
  153
+        if new_name.find('__') >= 0:
  154
+            while new_name.find('__') >= 0:
  155
+                new_name = new_name.replace('__', '_')
  156
+            field_notes.append("Field renamed because it contained more than one '_' in a row.")
  157
+
  158
+        if new_name.startswith('_'):
  159
+            new_name = 'field%s' % new_name
  160
+            field_notes.append("Field renamed because it started with '_'.")
  161
+
  162
+        if new_name.endswith('_'):
  163
+            new_name = '%sfield' % new_name
  164
+            field_notes.append("Field renamed because it ended with '_'.")
  165
+
  166
+        if keyword.iskeyword(new_name):
  167
+            new_name += '_field'
  168
+            field_notes.append('Field renamed because it was a Python reserved word.')
  169
+
  170
+        if new_name[0].isdigit():
  171
+            new_name = 'number_%s' % new_name
  172
+            field_notes.append("Field renamed because it wasn't a valid Python identifier.")
  173
+
  174
+        if new_name in used_column_names:
  175
+            num = 0
  176
+            while '%s_%d' % (new_name, num) in used_column_names:
  177
+                num += 1
  178
+            new_name = '%s_%d' % (new_name, num)
  179
+            field_notes.append('Field renamed because of name conflict.')
  180
+
  181
+        if col_name != new_name and field_notes:
  182
+            field_params['db_column'] = col_name
  183
+
  184
+        return new_name, field_params, field_notes
  185
+
147 186
     def get_field_type(self, connection, table_name, row):
148 187
         """
149 188
         Given the database connection, the table name, and the cursor row
@@ -181,6 +220,6 @@ def get_meta(self, table_name):
181 220
         to construct the inner Meta class for the model corresponding
182 221
         to the given database table name.
183 222
         """
184  
-        return ['    class Meta:',
185  
-                '        db_table = %r' % table_name,
186  
-                '']
  223
+        return ["    class Meta:",
  224
+                "        db_table = '%s'" % table_name,
  225
+                ""]
6  tests/regressiontests/inspectdb/models.py
@@ -19,3 +19,9 @@ class DigitsInColumnName(models.Model):
19 19
     all_digits = models.CharField(max_length=11, db_column='123')
20 20
     leading_digit = models.CharField(max_length=11, db_column='4extra')
21 21
     leading_digits = models.CharField(max_length=11, db_column='45extra')
  22
+
  23
+class UnderscoresInColumnName(models.Model):
  24
+    field = models.IntegerField(db_column='field')
  25
+    field_field_0 = models.IntegerField(db_column='Field_')
  26
+    field_field_1 = models.IntegerField(db_column='Field__')
  27
+    field_field_2 = models.IntegerField(db_column='__field')
37  tests/regressiontests/inspectdb/tests.py
... ...
@@ -1,3 +1,5 @@
  1
+from __future__ import unicode_literals
  2
+
1 3
 from django.core.management import call_command
2 4
 from django.test import TestCase, skipUnlessDBFeature
3 5
 from django.utils.six import StringIO
@@ -17,7 +19,6 @@ def test_stealth_table_name_filter_option(self):
17 19
         # the Django test suite, check that one of its tables hasn't been
18 20
         # inspected
19 21
         self.assertNotIn("class DjangoContentType(models.Model):", out.getvalue(), msg=error_message)
20  
-        out.close()
21 22
 
22 23
     @skipUnlessDBFeature('can_introspect_foreign_keys')
23 24
     def test_attribute_name_not_python_keyword(self):
@@ -27,15 +28,16 @@ def test_attribute_name_not_python_keyword(self):
27 28
         call_command('inspectdb',
28 29
                      table_name_filter=lambda tn:tn.startswith('inspectdb_'),
29 30
                      stdout=out)
  31
+        output = out.getvalue()
30 32
         error_message = "inspectdb generated an attribute name which is a python keyword"
31  
-        self.assertNotIn("from = models.ForeignKey(InspectdbPeople)", out.getvalue(), msg=error_message)
  33
+        self.assertNotIn("from = models.ForeignKey(InspectdbPeople)", output, msg=error_message)
32 34
         # As InspectdbPeople model is defined after InspectdbMessage, it should be quoted
33  
-        self.assertIn("from_field = models.ForeignKey('InspectdbPeople')", out.getvalue())
  35
+        self.assertIn("from_field = models.ForeignKey('InspectdbPeople', db_column='from_id')",
  36
+            output)
34 37
         self.assertIn("people_pk = models.ForeignKey(InspectdbPeople, primary_key=True)",
35  
-            out.getvalue())
  38
+            output)
36 39
         self.assertIn("people_unique = models.ForeignKey(InspectdbPeople, unique=True)",
37  
-            out.getvalue())
38  
-        out.close()
  40
+            output)
39 41
 
40 42
     def test_digits_column_name_introspection(self):
41 43
         """Introspection of column names that consist of or start with digits (#16536/#17676)"""
@@ -45,13 +47,24 @@ def test_digits_column_name_introspection(self):
45 47
         call_command('inspectdb',
46 48
                      table_name_filter=lambda tn:tn.startswith('inspectdb_'),
47 49
                      stdout=out)
  50
+        output = out.getvalue()
48 51
         error_message = "inspectdb generated a model field name which is a number"
49  
-        self.assertNotIn("    123 = models.CharField", out.getvalue(), msg=error_message)
50  
-        self.assertIn("number_123 = models.CharField", out.getvalue())
  52
+        self.assertNotIn("    123 = models.CharField", output, msg=error_message)
  53
+        self.assertIn("number_123 = models.CharField", output)
51 54
 
52 55
         error_message = "inspectdb generated a model field name which starts with a digit"
53  
-        self.assertNotIn("    4extra = models.CharField", out.getvalue(), msg=error_message)
54  
-        self.assertIn("number_4extra = models.CharField", out.getvalue())
  56
+        self.assertNotIn("    4extra = models.CharField", output, msg=error_message)
  57
+        self.assertIn("number_4extra = models.CharField", output)
  58
+
  59
+        self.assertNotIn("    45extra = models.CharField", output, msg=error_message)
  60
+        self.assertIn("number_45extra = models.CharField", output)
55 61
 
56  
-        self.assertNotIn("    45extra = models.CharField", out.getvalue(), msg=error_message)
57  
-        self.assertIn("number_45extra = models.CharField", out.getvalue())
  62
+    def test_underscores_column_name_introspection(self):
  63
+        """Introspection of column names containing underscores (#12460)"""
  64
+        out = StringIO()
  65
+        call_command('inspectdb', stdout=out)
  66
+        output = out.getvalue()
  67
+        self.assertIn("field = models.IntegerField()", output)
  68
+        self.assertIn("field_field = models.IntegerField(db_column='Field_')", output)
  69
+        self.assertIn("field_field_0 = models.IntegerField(db_column='Field__')", output)
  70
+        self.assertIn("field_field_1 = models.IntegerField(db_column='__field')", output)

0 notes on commit 395c608

Please sign in to comment.
Something went wrong with that request. Please try again.