Skip to content

Commit

Permalink
Check if object was closed:
Browse files Browse the repository at this point in the history
When accessing methods or properties of connection and cursor objects
we now check if the object or parent object was closed and raise an
exception.
  • Loading branch information
9EOR9 committed Jan 16, 2022
1 parent a839827 commit 811cc1c
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 8 deletions.
52 changes: 44 additions & 8 deletions mariadb/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ class Connection(mariadb._mariadb.connection):
Connections are created using the method mariadb.connect()
"""

def _check_closed(self):
if self._closed:
raise mariadb.ProgrammingError("Connection is closed")

def __init__(self, *args, **kwargs):
"""
Establishes a connection to a database server and returns a connection
Expand All @@ -50,7 +54,6 @@ def __init__(self, *args, **kwargs):
self.__last_used = 0
self.tpc_state= TPC_STATE.NONE
self._xid= None
self.__closed= None

autocommit= kwargs.pop("autocommit", False)
self._converter= kwargs.pop("converter", None)
Expand Down Expand Up @@ -109,24 +112,27 @@ def cursor(self, cursorclass=mariadb.cursors.Cursor, **kwargs):
If cursor_type is set to CURSOR_TYPE.READ_ONLY, a cursor is opened for
the statement invoked with cursors execute() method.
"""

self._check_closed()
cursor= cursorclass(self, **kwargs)
if not isinstance(cursor, mariadb._mariadb.cursor):
raise mariadb.ProgrammingError("%s is not an instance of mariadb.cursor" % cursor)
return cursor

def close(self):
self._check_closed()
if self._Connection__pool:
self._Connection__pool._close_connection(self)
else:
super().close()

def __enter__(self):
self._check_closed()
"Returns a copy of the connection."

return self

def __exit__(self, exc_type, exc_val, exc_tb):
self._check_closed()
"Closes connection."

self.close()
Expand All @@ -136,6 +142,7 @@ def commit(self):
Commit any pending transaction to the database.
"""

self._check_closed()
if self.tpc_state > TPC_STATE.NONE:
raise mariadb.ProgrammingError("commit() is not allowed if a TPC transaction is active")
self._execute_command("COMMIT")
Expand All @@ -151,6 +158,7 @@ def rollback(self):
or the storage engine does not support transactions."
"""

self._check_closed()
if self.tpc_state > TPC_STATE.NONE:
raise mariadb.ProgrammingError("rollback() is not allowed if a TPC transaction is active")
self._execute_command("ROLLBACK")
Expand All @@ -164,6 +172,7 @@ def kill(self, id: int):
The connection id must be retrieved by SHOW PROCESSLIST sql command.
"""

self._check_closed()
if not isinstance(id, int):
raise mariadb.ProgrammingError("id must be of type int.")
stmt= "KILL %s" % id
Expand All @@ -175,6 +184,7 @@ def begin(self):
Start a new transaction which can be committed by .commit() method,
or cancelled by .rollback() method.
"""
self._check_closed()
self._execute_command("BEGIN")
self._read_response()

Expand All @@ -185,6 +195,7 @@ def select_db(self, new_db: str):
The default database can also be obtained or changed by database attribute.
"""

self._check_closed()
self.database= new_db

def get_server_version(self):
Expand All @@ -200,6 +211,7 @@ def show_warnings(self):
Shows error, warning and note messages from last executed command.
"""

self._check_closed()
if (not self.warnings):
return None;

Expand Down Expand Up @@ -252,6 +264,7 @@ def tpc_begin(self, xid):
calls .commit() or .rollback() during an active TPC transaction.
"""

self._check_closed()
if type(xid).__name__ != "xid":
raise TypeError("argument 1 must be xid not %s", type(xid).__name__)
stmt= "XA BEGIN '%s','%s',%s" % (xid[1], xid[2], xid[0])
Expand Down Expand Up @@ -280,6 +293,7 @@ def tpc_commit(self, xid=None):
is intended for use in recovery."
"""

self._check_closed()
if not xid:
xid= self._xid

Expand Down Expand Up @@ -325,6 +339,7 @@ def tpc_prepare(self):
.tpc_commit() or .tpc_rollback() have been called.
"""

self._check_closed()
if self.tpc_state == TPC_STATE.NONE:
raise mariadb.ProgrammingError("Transaction not started.")
if self.tpc_state == TPC_STATE.PREPARE:
Expand Down Expand Up @@ -365,6 +380,7 @@ def tpc_rollback(self, xid=None):
.tpc_commit() or .tpc_rollback() have been called.
"""

self._check_closed()
if self.tpc_state == TPC_STATE.NONE:
raise mariadb.ProgrammingError("Transaction not started.")
if xid and type(xid).__name__ != "xid":
Expand Down Expand Up @@ -400,6 +416,7 @@ def tpc_recover(self):
tpc_commit(xid) or .tpc_rollback(xid).
"""

self._check_closed()
cursor= self.cursor()
cursor.execute("XA RECOVER")
result= cursor.fetchall()
Expand All @@ -410,17 +427,19 @@ def tpc_recover(self):
def database(self):
"""Get default database for connection."""

self._check_closed()
return self._mariadb_get_info(INFO.SCHEMA, str)

@database.setter
def database(self, schema):
"""Set default database."""
"""Set default database."""
self._check_closed()

try:
self._execute_command("USE %s" % str(schema))
self._read_response()
except:
raise
try:
self._execute_command("USE %s" % str(schema))
self._read_response()
except:
raise

@property
def user(self):
Expand All @@ -429,6 +448,7 @@ def user(self):
string if it can't be determined, e.g. when using socket
authentication.
"""
self._check_closed()

return self._mariadb_get_info(INFO.USER, str)

Expand All @@ -446,36 +466,42 @@ def character_set(self):
def client_capabilities(self):
"""Client capability flags."""

self._check_closed()
return self._mariadb_get_info(INFO.CLIENT_CAPABILITIES, int)

@property
def server_capabilities(self):
"""Server capability flags."""

self._check_closed()
return self._mariadb_get_info(INFO.SERVER_CAPABILITIES, int)

@property
def extended_server_capabilities(self):
"""Extended server capability flags (only for MariaDB database servers)."""

self._check_closed()
return self._mariadb_get_info(INFO.EXTENDED_SERVER_CAPABILITIES, int)

@property
def server_port(self):
"""Database server TCP/IP port. This value will be 0 in case of a unix socket connection."""

self._check_closed()
return self._mariadb_get_info(INFO.PORT, int)

@property
def unix_socket(self):
"""Unix socket name."""

self._check_closed()
return self._mariadb_get_info(INFO.UNIX_SOCKET, str)

@property
def server_name(self):
"""Name or IP address of database server."""

self._check_closed()
return self._mariadb_get_info(INFO.HOST, str)

@property
Expand All @@ -488,18 +514,21 @@ def collation(self):
def server_info(self):
"""Server version in alphanumerical format (str)"""

self._check_closed()
return self._mariadb_get_info(INFO.SERVER_VERSION, str)

@property
def tls_cipher(self):
"""TLS cipher suite if a secure connection is used."""

self._check_closed()
return self._mariadb_get_info(INFO.SSL_CIPHER, str)

@property
def tls_version(self):
"""TLS protocol version if a secure connection is used."""

self._check_closed()
return self._mariadb_get_info(INFO.TLS_VERSION, str)

@property
Expand All @@ -508,6 +537,7 @@ def server_status(self):
Return server status flags
"""

self._check_closed()
return self._mariadb_get_info(INFO.SERVER_STATUS, int)

@property
Expand All @@ -519,6 +549,7 @@ def server_version(self):
VERSION_MAJOR * 10000 + VERSION_MINOR * 100 + VERSION_PATCH
"""

self._check_closed()
return self._mariadb_get_info(INFO.SERVER_VERSION_ID, int)

@property
Expand All @@ -527,6 +558,7 @@ def server_version_info(self):
Returns numeric version of connected database server in tuple format.
"""

self._check_closed()
version= self.server_version
return (int(version / 10000), int((version % 10000) / 100), version % 100)

Expand All @@ -542,10 +574,12 @@ def autocommit(self):
By default autocommit mode is set to False."
"""

self._check_closed()
return bool(self.server_status & STATUS.AUTOCOMMIT)

@autocommit.setter
def autocommit(self, mode):
self._check_closed()
if bool(mode) == self.autocommit:
return
try:
Expand Down Expand Up @@ -576,6 +610,7 @@ def open(self):
non processed pending result sets.
"""

self._check_closed()
try:
self.ping()
except:
Expand All @@ -591,4 +626,5 @@ def thread_id(self):
Alias for connection_id
"""

self._check_closed()
return self.connection_id
2 changes: 2 additions & 0 deletions mariadb/constants/INFO.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,5 @@
SERVER_CAPABILITIES= 31
EXTENDED_SERVER_CAPABILITIES= 32
CLIENT_CAPABILITIES= 33
BYTES_READ= 34
BYTES_SENT= 35
Loading

0 comments on commit 811cc1c

Please sign in to comment.