Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

Reordered methods in database wrappers.

* Grouped related methods together -- with banner comments :/
* Described which methods are intended to be implemented in backends.
* Added docstrings.
* Used the same order in all wrappers.
  • Loading branch information...
commit d63e55039d42c2ba8bb6c8f8d133ede085b60969 1 parent c5a25c2
Aymeric Augustin authored March 02, 2013
285  django/db/backends/__init__.py
@@ -40,20 +40,24 @@ def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS,
40 40
         self.alias = alias
41 41
         self.use_debug_cursor = None
42 42
 
43  
-        # Transaction related attributes
44  
-        self.transaction_state = []
  43
+        # Savepoint management related attributes
45 44
         self.savepoint_state = 0
  45
+
  46
+        # Transaction management related attributes
  47
+        self.transaction_state = []
46 48
         # Tracks if the connection is believed to be in transaction. This is
47 49
         # set somewhat aggressively, as the DBAPI doesn't make it easy to
48 50
         # deduce if the connection is in transaction or not.
49 51
         self._dirty = False
50  
-        self._thread_ident = thread.get_ident()
51  
-        self.allow_thread_sharing = allow_thread_sharing
52 52
 
53 53
         # Connection termination related attributes
54 54
         self.close_at = None
55 55
         self.errors_occurred = False
56 56
 
  57
+        # Thread-safety related attributes
  58
+        self.allow_thread_sharing = allow_thread_sharing
  59
+        self._thread_ident = thread.get_ident()
  60
+
57 61
     def __eq__(self, other):
58 62
         return self.alias == other.alias
59 63
 
@@ -63,21 +67,26 @@ def __ne__(self, other):
63 67
     def __hash__(self):
64 68
         return hash(self.alias)
65 69
 
66  
-    def wrap_database_errors(self):
67  
-        return DatabaseErrorWrapper(self)
  70
+    ##### Backend-specific methods for creating connections and cursors #####
68 71
 
69 72
     def get_connection_params(self):
  73
+        """Returns a dict of parameters suitable for get_new_connection."""
70 74
         raise NotImplementedError
71 75
 
72 76
     def get_new_connection(self, conn_params):
  77
+        """Opens a connection to the database."""
73 78
         raise NotImplementedError
74 79
 
75 80
     def init_connection_state(self):
  81
+        """Initializes the database connection settings."""
76 82
         raise NotImplementedError
77 83
 
78 84
     def create_cursor(self):
  85
+        """Creates a cursor. Assumes that a connection is established."""
79 86
         raise NotImplementedError
80 87
 
  88
+    ##### Backend-specific wrappers for PEP-249 connection methods #####
  89
+
81 90
     def _cursor(self):
82 91
         with self.wrap_database_errors():
83 92
             if self.connection is None:
@@ -107,20 +116,48 @@ def _close(self):
107 116
             with self.wrap_database_errors():
108 117
                 return self.connection.close()
109 118
 
110  
-    def _enter_transaction_management(self, managed):
  119
+    ##### Generic wrappers for PEP-249 connection methods #####
  120
+
  121
+    def cursor(self):
111 122
         """
112  
-        A hook for backend-specific changes required when entering manual
113  
-        transaction handling.
  123
+        Creates a cursor, opening a connection if necessary.
114 124
         """
115  
-        pass
  125
+        self.validate_thread_sharing()
  126
+        if (self.use_debug_cursor or
  127
+            (self.use_debug_cursor is None and settings.DEBUG)):
  128
+            cursor = self.make_debug_cursor(self._cursor())
  129
+        else:
  130
+            cursor = util.CursorWrapper(self._cursor(), self)
  131
+        return cursor
116 132
 
117  
-    def _leave_transaction_management(self, managed):
  133
+    def commit(self):
118 134
         """
119  
-        A hook for backend-specific changes required when leaving manual
120  
-        transaction handling. Will usually be implemented only when
121  
-        _enter_transaction_management() is also required.
  135
+        Does the commit itself and resets the dirty flag.
122 136
         """
123  
-        pass
  137
+        self.validate_thread_sharing()
  138
+        self._commit()
  139
+        self.set_clean()
  140
+
  141
+    def rollback(self):
  142
+        """
  143
+        Does the rollback itself and resets the dirty flag.
  144
+        """
  145
+        self.validate_thread_sharing()
  146
+        self._rollback()
  147
+        self.set_clean()
  148
+
  149
+    def close(self):
  150
+        """
  151
+        Closes the connection to the database.
  152
+        """
  153
+        self.validate_thread_sharing()
  154
+        try:
  155
+            self._close()
  156
+        finally:
  157
+            self.connection = None
  158
+        self.set_clean()
  159
+
  160
+    ##### Backend-specific savepoint management methods #####
124 161
 
125 162
     def _savepoint(self, sid):
126 163
         if not self.features.uses_savepoints:
@@ -137,15 +174,65 @@ def _savepoint_commit(self, sid):
137 174
             return
138 175
         self.cursor().execute(self.ops.savepoint_commit_sql(sid))
139 176
 
140  
-    def abort(self):
  177
+    ##### Generic savepoint management methods #####
  178
+
  179
+    def savepoint(self):
141 180
         """
142  
-        Roll back any ongoing transaction and clean the transaction state
143  
-        stack.
  181
+        Creates a savepoint (if supported and required by the backend) inside the
  182
+        current transaction. Returns an identifier for the savepoint that will be
  183
+        used for the subsequent rollback or commit.
144 184
         """
145  
-        if self._dirty:
146  
-            self.rollback()
147  
-        while self.transaction_state:
148  
-            self.leave_transaction_management()
  185
+        thread_ident = thread.get_ident()
  186
+
  187
+        self.savepoint_state += 1
  188
+
  189
+        tid = str(thread_ident).replace('-', '')
  190
+        sid = "s%s_x%d" % (tid, self.savepoint_state)
  191
+        self._savepoint(sid)
  192
+        return sid
  193
+
  194
+    def savepoint_rollback(self, sid):
  195
+        """
  196
+        Rolls back the most recent savepoint (if one exists). Does nothing if
  197
+        savepoints are not supported.
  198
+        """
  199
+        self.validate_thread_sharing()
  200
+        if self.savepoint_state:
  201
+            self._savepoint_rollback(sid)
  202
+
  203
+    def savepoint_commit(self, sid):
  204
+        """
  205
+        Commits the most recent savepoint (if one exists). Does nothing if
  206
+        savepoints are not supported.
  207
+        """
  208
+        self.validate_thread_sharing()
  209
+        if self.savepoint_state:
  210
+            self._savepoint_commit(sid)
  211
+
  212
+    def clean_savepoints(self):
  213
+        """
  214
+        Resets the counter used to generate unique savepoint ids in this thread.
  215
+        """
  216
+        self.savepoint_state = 0
  217
+
  218
+    ##### Backend-specific transaction management methods #####
  219
+
  220
+    def _enter_transaction_management(self, managed):
  221
+        """
  222
+        A hook for backend-specific changes required when entering manual
  223
+        transaction handling.
  224
+        """
  225
+        pass
  226
+
  227
+    def _leave_transaction_management(self, managed):
  228
+        """
  229
+        A hook for backend-specific changes required when leaving manual
  230
+        transaction handling. Will usually be implemented only when
  231
+        _enter_transaction_management() is also required.
  232
+        """
  233
+        pass
  234
+
  235
+    ##### Generic transaction management methods #####
149 236
 
150 237
     def enter_transaction_management(self, managed=True):
151 238
         """
@@ -185,20 +272,15 @@ def leave_transaction_management(self):
185 272
             raise TransactionManagementError(
186 273
                 "Transaction managed block ended with pending COMMIT/ROLLBACK")
187 274
 
188  
-    def validate_thread_sharing(self):
  275
+    def abort(self):
189 276
         """
190  
-        Validates that the connection isn't accessed by another thread than the
191  
-        one which originally created it, unless the connection was explicitly
192  
-        authorized to be shared between threads (via the `allow_thread_sharing`
193  
-        property). Raises an exception if the validation fails.
  277
+        Roll back any ongoing transaction and clean the transaction state
  278
+        stack.
194 279
         """
195  
-        if (not self.allow_thread_sharing
196  
-            and self._thread_ident != thread.get_ident()):
197  
-                raise DatabaseError("DatabaseWrapper objects created in a "
198  
-                    "thread can only be used in that same thread. The object "
199  
-                    "with alias '%s' was created in thread id %s and this is "
200  
-                    "thread id %s."
201  
-                    % (self.alias, self._thread_ident, thread.get_ident()))
  280
+        if self._dirty:
  281
+            self.rollback()
  282
+        while self.transaction_state:
  283
+            self.leave_transaction_management()
202 284
 
203 285
     def is_dirty(self):
204 286
         """
@@ -224,12 +306,6 @@ def set_clean(self):
224 306
         self._dirty = False
225 307
         self.clean_savepoints()
226 308
 
227  
-    def clean_savepoints(self):
228  
-        """
229  
-        Resets the counter used to generate unique savepoint ids in this thread.
230  
-        """
231  
-        self.savepoint_state = 0
232  
-
233 309
     def is_managed(self):
234 310
         """
235 311
         Checks whether the transaction manager is in manual or in auto state.
@@ -275,57 +351,13 @@ def rollback_unless_managed(self):
275 351
         else:
276 352
             self.set_dirty()
277 353
 
278  
-    def commit(self):
279  
-        """
280  
-        Does the commit itself and resets the dirty flag.
281  
-        """
282  
-        self.validate_thread_sharing()
283  
-        self._commit()
284  
-        self.set_clean()
285  
-
286  
-    def rollback(self):
287  
-        """
288  
-        This function does the rollback itself and resets the dirty flag.
289  
-        """
290  
-        self.validate_thread_sharing()
291  
-        self._rollback()
292  
-        self.set_clean()
293  
-
294  
-    def savepoint(self):
295  
-        """
296  
-        Creates a savepoint (if supported and required by the backend) inside the
297  
-        current transaction. Returns an identifier for the savepoint that will be
298  
-        used for the subsequent rollback or commit.
299  
-        """
300  
-        thread_ident = thread.get_ident()
301  
-
302  
-        self.savepoint_state += 1
303  
-
304  
-        tid = str(thread_ident).replace('-', '')
305  
-        sid = "s%s_x%d" % (tid, self.savepoint_state)
306  
-        self._savepoint(sid)
307  
-        return sid
308  
-
309  
-    def savepoint_rollback(self, sid):
310  
-        """
311  
-        Rolls back the most recent savepoint (if one exists). Does nothing if
312  
-        savepoints are not supported.
313  
-        """
314  
-        self.validate_thread_sharing()
315  
-        if self.savepoint_state:
316  
-            self._savepoint_rollback(sid)
317  
-
318  
-    def savepoint_commit(self, sid):
319  
-        """
320  
-        Commits the most recent savepoint (if one exists). Does nothing if
321  
-        savepoints are not supported.
322  
-        """
323  
-        self.validate_thread_sharing()
324  
-        if self.savepoint_state:
325  
-            self._savepoint_commit(sid)
  354
+    ##### Foreign key constraints checks handling #####
326 355
 
327 356
     @contextmanager
328 357
     def constraint_checks_disabled(self):
  358
+        """
  359
+        Context manager that disables foreign key constraint checking.
  360
+        """
329 361
         disabled = self.disable_constraint_checking()
330 362
         try:
331 363
             yield
@@ -335,33 +367,40 @@ def constraint_checks_disabled(self):
335 367
 
336 368
     def disable_constraint_checking(self):
337 369
         """
338  
-        Backends can implement as needed to temporarily disable foreign key constraint
339  
-        checking.
  370
+        Backends can implement as needed to temporarily disable foreign key
  371
+        constraint checking.
340 372
         """
341 373
         pass
342 374
 
343 375
     def enable_constraint_checking(self):
344 376
         """
345  
-        Backends can implement as needed to re-enable foreign key constraint checking.
  377
+        Backends can implement as needed to re-enable foreign key constraint
  378
+        checking.
346 379
         """
347 380
         pass
348 381
 
349 382
     def check_constraints(self, table_names=None):
350 383
         """
351  
-        Backends can override this method if they can apply constraint checking (e.g. via "SET CONSTRAINTS
352  
-        ALL IMMEDIATE"). Should raise an IntegrityError if any invalid foreign key references are encountered.
  384
+        Backends can override this method if they can apply constraint
  385
+        checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE"). Should raise an
  386
+        IntegrityError if any invalid foreign key references are encountered.
353 387
         """
354 388
         pass
355 389
 
356  
-    def close(self):
357  
-        self.validate_thread_sharing()
358  
-        try:
359  
-            self._close()
360  
-        finally:
361  
-            self.connection = None
362  
-        self.set_clean()
  390
+    ##### Connection termination handling #####
  391
+
  392
+    def is_usable(self):
  393
+        """
  394
+        Tests if the database connection is usable.
  395
+        This function may assume that self.connection is not None.
  396
+        """
  397
+        raise NotImplementedError
363 398
 
364 399
     def close_if_unusable_or_obsolete(self):
  400
+        """
  401
+        Closes the current connection if unrecoverable errors have occurred,
  402
+        or if it outlived its maximum age.
  403
+        """
365 404
         if self.connection is not None:
366 405
             if self.errors_occurred:
367 406
                 if self.is_usable():
@@ -373,30 +412,45 @@ def close_if_unusable_or_obsolete(self):
373 412
                 self.close()
374 413
                 return
375 414
 
376  
-    def is_usable(self):
377  
-        """
378  
-        Test if the database connection is usable.
  415
+    ##### Thread safety handling #####
379 416
 
380  
-        This function may assume that self.connection is not None.
  417
+    def validate_thread_sharing(self):
381 418
         """
382  
-        raise NotImplementedError
  419
+        Validates that the connection isn't accessed by another thread than the
  420
+        one which originally created it, unless the connection was explicitly
  421
+        authorized to be shared between threads (via the `allow_thread_sharing`
  422
+        property). Raises an exception if the validation fails.
  423
+        """
  424
+        if not (self.allow_thread_sharing
  425
+                or self._thread_ident == thread.get_ident()):
  426
+            raise DatabaseError("DatabaseWrapper objects created in a "
  427
+                "thread can only be used in that same thread. The object "
  428
+                "with alias '%s' was created in thread id %s and this is "
  429
+                "thread id %s."
  430
+                % (self.alias, self._thread_ident, thread.get_ident()))
383 431
 
384  
-    def cursor(self):
385  
-        self.validate_thread_sharing()
386  
-        if (self.use_debug_cursor or
387  
-            (self.use_debug_cursor is None and settings.DEBUG)):
388  
-            cursor = self.make_debug_cursor(self._cursor())
389  
-        else:
390  
-            cursor = util.CursorWrapper(self._cursor(), self)
391  
-        return cursor
  432
+    ##### Miscellaneous #####
  433
+
  434
+    def wrap_database_errors(self):
  435
+        """
  436
+        Context manager and decorator that re-throws backend-specific database
  437
+        exceptions using Django's common wrappers.
  438
+        """
  439
+        return DatabaseErrorWrapper(self)
392 440
 
393 441
     def make_debug_cursor(self, cursor):
  442
+        """
  443
+        Creates a cursor that logs all queries in self.queries.
  444
+        """
394 445
         return util.CursorDebugWrapper(cursor, self)
395 446
 
396 447
     @contextmanager
397 448
     def temporary_connection(self):
398  
-        # Ensure a connection is established, and avoid leaving a dangling
399  
-        # connection, for operations outside of the request-response cycle.
  449
+        """
  450
+        Context manager that ensures that a connection is established, and
  451
+        if it opened one, closes it to avoid leaving a dangling connection.
  452
+        This is useful for operations outside of the request-response cycle.
  453
+        """
400 454
         must_close = self.connection is None
401 455
         cursor = self.cursor()
402 456
         try:
@@ -406,6 +460,7 @@ def temporary_connection(self):
406 460
             if must_close:
407 461
                 self.close()
408 462
 
  463
+
409 464
 class BaseDatabaseFeatures(object):
410 465
     allows_group_by_pk = False
411 466
     # True if django.db.backend.utils.typecast_timestamp is used on values
14  django/db/backends/dummy/base.py
@@ -48,19 +48,19 @@ class DatabaseWrapper(BaseDatabaseWrapper):
48 48
     # implementations. Anything that tries to actually
49 49
     # do something raises complain; anything that tries
50 50
     # to rollback or undo something raises ignore.
  51
+    _cursor = complain
51 52
     _commit = complain
52 53
     _rollback = ignore
53  
-    enter_transaction_management = complain
54  
-    leave_transaction_management = ignore
  54
+    _close = ignore
  55
+    _savepoint = ignore
  56
+    _savepoint_commit = complain
  57
+    _savepoint_rollback = ignore
  58
+    _enter_transaction_management = complain
  59
+    _leave_transaction_management = ignore
55 60
     set_dirty = complain
56 61
     set_clean = complain
57 62
     commit_unless_managed = complain
58 63
     rollback_unless_managed = ignore
59  
-    savepoint = ignore
60  
-    savepoint_commit = complain
61  
-    savepoint_rollback = ignore
62  
-    close = ignore
63  
-    cursor = complain
64 64
 
65 65
     def __init__(self, *args, **kwargs):
66 66
         super(DatabaseWrapper, self).__init__(*args, **kwargs)
34  django/db/backends/mysql/base.py
@@ -439,29 +439,12 @@ def create_cursor(self):
439 439
         cursor = self.connection.cursor()
440 440
         return CursorWrapper(cursor)
441 441
 
442  
-    def is_usable(self):
443  
-        try:
444  
-            self.connection.ping()
445  
-        except DatabaseError:
446  
-            return False
447  
-        else:
448  
-            return True
449  
-
450 442
     def _rollback(self):
451 443
         try:
452 444
             BaseDatabaseWrapper._rollback(self)
453 445
         except Database.NotSupportedError:
454 446
             pass
455 447
 
456  
-    @cached_property
457  
-    def mysql_version(self):
458  
-        with self.temporary_connection():
459  
-            server_info = self.connection.get_server_info()
460  
-        match = server_version_re.match(server_info)
461  
-        if not match:
462  
-            raise Exception('Unable to determine MySQL version from version string %r' % server_info)
463  
-        return tuple([int(x) for x in match.groups()])
464  
-
465 448
     def disable_constraint_checking(self):
466 449
         """
467 450
         Disables foreign key checks, primarily for use in adding rows with forward references. Always returns True,
@@ -510,3 +493,20 @@ def check_constraints(self, table_names=None):
510 493
                         % (table_name, bad_row[0],
511 494
                         table_name, column_name, bad_row[1],
512 495
                         referenced_table_name, referenced_column_name))
  496
+
  497
+    def is_usable(self):
  498
+        try:
  499
+            self.connection.ping()
  500
+        except DatabaseError:
  501
+            return False
  502
+        else:
  503
+            return True
  504
+
  505
+    @cached_property
  506
+    def mysql_version(self):
  507
+        with self.temporary_connection():
  508
+            server_info = self.connection.get_server_info()
  509
+        match = server_version_re.match(server_info)
  510
+        if not match:
  511
+            raise Exception('Unable to determine MySQL version from version string %r' % server_info)
  512
+        return tuple([int(x) for x in match.groups()])
52  django/db/backends/oracle/base.py
@@ -515,14 +515,6 @@ def __init__(self, *args, **kwargs):
515 515
         self.introspection = DatabaseIntrospection(self)
516 516
         self.validation = BaseDatabaseValidation(self)
517 517
 
518  
-    def check_constraints(self, table_names=None):
519  
-        """
520  
-        To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they
521  
-        are returned to deferred.
522  
-        """
523  
-        self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
524  
-        self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
525  
-
526 518
     def _connect_string(self):
527 519
         settings_dict = self.settings_dict
528 520
         if not settings_dict['HOST'].strip():
@@ -536,9 +528,6 @@ def _connect_string(self):
536 528
         return "%s/%s@%s" % (settings_dict['USER'],
537 529
                              settings_dict['PASSWORD'], dsn)
538 530
 
539  
-    def create_cursor(self):
540  
-        return FormatStylePlaceholderCursor(self.connection)
541  
-
542 531
     def get_connection_params(self):
543 532
         conn_params = self.settings_dict['OPTIONS'].copy()
544 533
         if 'use_returning_into' in conn_params:
@@ -598,21 +587,8 @@ def init_connection_state(self):
598 587
             # stmtcachesize is available only in 4.3.2 and up.
599 588
             pass
600 589
 
601  
-    def is_usable(self):
602  
-        try:
603  
-            if hasattr(self.connection, 'ping'):    # Oracle 10g R2 and higher
604  
-                self.connection.ping()
605  
-            else:
606  
-                # Use a cx_Oracle cursor directly, bypassing Django's utilities.
607  
-                self.connection.cursor().execute("SELECT 1 FROM DUAL")
608  
-        except DatabaseError:
609  
-            return False
610  
-        else:
611  
-            return True
612  
-
613  
-    # Oracle doesn't support savepoint commits.  Ignore them.
614  
-    def _savepoint_commit(self, sid):
615  
-        pass
  590
+    def create_cursor(self):
  591
+        return FormatStylePlaceholderCursor(self.connection)
616 592
 
617 593
     def _commit(self):
618 594
         if self.connection is not None:
@@ -632,6 +608,30 @@ def _commit(self):
632 608
                     six.reraise(utils.IntegrityError, utils.IntegrityError(*tuple(e.args)), sys.exc_info()[2])
633 609
                 raise
634 610
 
  611
+    # Oracle doesn't support savepoint commits.  Ignore them.
  612
+    def _savepoint_commit(self, sid):
  613
+        pass
  614
+
  615
+    def check_constraints(self, table_names=None):
  616
+        """
  617
+        To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they
  618
+        are returned to deferred.
  619
+        """
  620
+        self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
  621
+        self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
  622
+
  623
+    def is_usable(self):
  624
+        try:
  625
+            if hasattr(self.connection, 'ping'):    # Oracle 10g R2 and higher
  626
+                self.connection.ping()
  627
+            else:
  628
+                # Use a cx_Oracle cursor directly, bypassing Django's utilities.
  629
+                self.connection.cursor().execute("SELECT 1 FROM DUAL")
  630
+        except DatabaseError:
  631
+            return False
  632
+        else:
  633
+            return True
  634
+
635 635
     @cached_property
636 636
     def oracle_version(self):
637 637
         with self.temporary_connection():
82  django/db/backends/postgresql_psycopg2/base.py
@@ -91,40 +91,6 @@ def __init__(self, *args, **kwargs):
91 91
         self.introspection = DatabaseIntrospection(self)
92 92
         self.validation = BaseDatabaseValidation(self)
93 93
 
94  
-    def check_constraints(self, table_names=None):
95  
-        """
96  
-        To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they
97  
-        are returned to deferred.
98  
-        """
99  
-        self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
100  
-        self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
101  
-
102  
-    def close(self):
103  
-        self.validate_thread_sharing()
104  
-        if self.connection is None:
105  
-            return
106  
-
107  
-        try:
108  
-            self.connection.close()
109  
-            self.connection = None
110  
-        except Database.Error:
111  
-            # In some cases (database restart, network connection lost etc...)
112  
-            # the connection to the database is lost without giving Django a
113  
-            # notification. If we don't set self.connection to None, the error
114  
-            # will occur a every request.
115  
-            self.connection = None
116  
-            logger.warning('psycopg2 error while closing the connection.',
117  
-                exc_info=sys.exc_info()
118  
-            )
119  
-            raise
120  
-        finally:
121  
-            self.set_clean()
122  
-
123  
-    @cached_property
124  
-    def pg_version(self):
125  
-        with self.temporary_connection():
126  
-            return get_version(self.connection)
127  
-
128 94
     def get_connection_params(self):
129 95
         settings_dict = self.settings_dict
130 96
         if not settings_dict['NAME']:
@@ -177,14 +143,26 @@ def create_cursor(self):
177 143
         cursor.tzinfo_factory = utc_tzinfo_factory if settings.USE_TZ else None
178 144
         return cursor
179 145
 
180  
-    def is_usable(self):
  146
+    def close(self):
  147
+        self.validate_thread_sharing()
  148
+        if self.connection is None:
  149
+            return
  150
+
181 151
         try:
182  
-            # Use a psycopg cursor directly, bypassing Django's utilities.
183  
-            self.connection.cursor().execute("SELECT 1")
184  
-        except DatabaseError:
185  
-            return False
186  
-        else:
187  
-            return True
  152
+            self.connection.close()
  153
+            self.connection = None
  154
+        except Database.Error:
  155
+            # In some cases (database restart, network connection lost etc...)
  156
+            # the connection to the database is lost without giving Django a
  157
+            # notification. If we don't set self.connection to None, the error
  158
+            # will occur a every request.
  159
+            self.connection = None
  160
+            logger.warning('psycopg2 error while closing the connection.',
  161
+                exc_info=sys.exc_info()
  162
+            )
  163
+            raise
  164
+        finally:
  165
+            self.set_clean()
188 166
 
189 167
     def _enter_transaction_management(self, managed):
190 168
         """
@@ -222,3 +200,25 @@ def set_dirty(self):
222 200
         if ((self.transaction_state and self.transaction_state[-1]) or
223 201
                 not self.features.uses_autocommit):
224 202
             super(DatabaseWrapper, self).set_dirty()
  203
+
  204
+    def check_constraints(self, table_names=None):
  205
+        """
  206
+        To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they
  207
+        are returned to deferred.
  208
+        """
  209
+        self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
  210
+        self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
  211
+
  212
+    def is_usable(self):
  213
+        try:
  214
+            # Use a psycopg cursor directly, bypassing Django's utilities.
  215
+            self.connection.cursor().execute("SELECT 1")
  216
+        except DatabaseError:
  217
+            return False
  218
+        else:
  219
+            return True
  220
+
  221
+    @cached_property
  222
+    def pg_version(self):
  223
+        with self.temporary_connection():
  224
+            return get_version(self.connection)
19  django/db/backends/sqlite3/base.py
@@ -347,8 +347,13 @@ def init_connection_state(self):
347 347
     def create_cursor(self):
348 348
         return self.connection.cursor(factory=SQLiteCursorWrapper)
349 349
 
350  
-    def is_usable(self):
351  
-        return True
  350
+    def close(self):
  351
+        self.validate_thread_sharing()
  352
+        # If database is in memory, closing the connection destroys the
  353
+        # database. To prevent accidental data loss, ignore close requests on
  354
+        # an in-memory db.
  355
+        if self.settings_dict['NAME'] != ":memory:":
  356
+            BaseDatabaseWrapper.close(self)
352 357
 
353 358
     def check_constraints(self, table_names=None):
354 359
         """
@@ -384,13 +389,9 @@ def check_constraints(self, table_names=None):
384 389
                         % (table_name, bad_row[0], table_name, column_name, bad_row[1],
385 390
                         referenced_table_name, referenced_column_name))
386 391
 
387  
-    def close(self):
388  
-        self.validate_thread_sharing()
389  
-        # If database is in memory, closing the connection destroys the
390  
-        # database. To prevent accidental data loss, ignore close requests on
391  
-        # an in-memory db.
392  
-        if self.settings_dict['NAME'] != ":memory:":
393  
-            BaseDatabaseWrapper.close(self)
  392
+    def is_usable(self):
  393
+        return True
  394
+
394 395
 
395 396
 FORMAT_QMARK_REGEX = re.compile(r'(?<!%)%s')
396 397
 

0 notes on commit d63e550

Please sign in to comment.
Something went wrong with that request. Please try again.