Skip to content

Commit

Permalink
First implementation for status callback
Browse files Browse the repository at this point in the history
  • Loading branch information
9EOR9 committed Aug 5, 2022
1 parent 09e5cad commit a9ad1fc
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 66 deletions.
45 changes: 38 additions & 7 deletions include/mariadb_python.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
#include <docs/common.h>
#include <limits.h>

#define CHECK_TYPE(obj, type) \
(Py_TYPE((obj)) == type || PyType_IsSubtype(Py_TYPE((obj)), type))

#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_TYPE)
static inline void _Py_SET_TYPE(PyObject *ob, PyTypeObject *type)
{ ob->ob_type = type; }
Expand All @@ -55,8 +58,6 @@ typedef CRITICAL_SECTION pthread_mutex_t;
#include <limits.h>
#endif /* defined(_WIN32) */

#define CHECK_TYPE(obj, type) \
(Py_TYPE((obj)) == type || PyType_IsSubtype(Py_TYPE((obj)), type))

#ifndef MIN
#define MIN(a,b) (a) < (b) ? (a) : (b)
Expand Down Expand Up @@ -181,6 +182,7 @@ typedef struct st_parser {
/* PEP-249: Connection object */
typedef struct {
PyObject_HEAD
PyThreadState *thread_state;
MYSQL *mysql;
int open;
uint8_t is_buffered;
Expand All @@ -199,10 +201,13 @@ typedef struct {
uint8_t status;
uint8_t asynchronous;
struct timespec last_used;
PyThreadState *thread_state;
unsigned long thread_id;
char *server_info;
uint8_t closed;
#if MARIADB_PACKAGE_VERSION_ID > 30301
PyObject *status_callback;
#endif
PyObject *last_executed_stmt;
} MrdbConnection;

typedef struct {
Expand Down Expand Up @@ -275,7 +280,6 @@ typedef struct {
uint8_t fetched;
uint8_t closed;
uint8_t reprepare;
PyThreadState *thread_state;
enum enum_paramstyle paramstyle;
} MrdbCursor;

Expand Down Expand Up @@ -736,6 +740,33 @@ MrdbParser_parse(MrdbParser *p, uint8_t is_batch, char *errmsg, size_t errmsg_le

#endif /* __i386__ OR _WIN32 */

#ifdef _WIN32
//#define alloca _malloca
#endif
/* Due to callback functions we cannot use PY_BEGIN/END_ALLOW_THREADS */

#define MARIADB_BEGIN_ALLOW_THREADS(obj)\
{\
(obj)->thread_state= PyEval_SaveThread();\
}

#define MARIADB_END_ALLOW_THREADS(obj)\
if ((obj)->thread_state)\
{\
PyEval_RestoreThread((obj)->thread_state);\
(obj)->thread_state= NULL;\
}

#define MARIADB_UNBLOCK_THREADS(obj)\
{\
if ((obj)->thread_state)\
{\
_save= (obj)->thread_state;\
PyEval_RestoreThread(_save);\
(obj)->thread_state= NULL;\
}\
}

#define MARIADB_BLOCK_THREADS(obj)\
if (_save)\
{\
(obj)->thread_state= PyEval_SaveThread();\
_save= NULL;\
}
28 changes: 14 additions & 14 deletions mariadb/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self, *args, **kwargs):
Establishes a connection to a database server and returns a connection
object.
"""

self._last_executed_statement= None
self._socket= None
self.__in_use= 0
self.__pool = None
Expand Down Expand Up @@ -440,7 +440,7 @@ def database(self):
"""Get default database for connection."""

self._check_closed()
return self._mariadb_get_info(INFO.SCHEMA, str)
return self._mariadb_get_info(INFO.SCHEMA)

@database.setter
def database(self, schema):
Expand All @@ -462,7 +462,7 @@ def user(self):
"""
self._check_closed()

return self._mariadb_get_info(INFO.USER, str)
return self._mariadb_get_info(INFO.USER)

@property
def character_set(self):
Expand All @@ -479,42 +479,42 @@ def client_capabilities(self):
"""Client capability flags."""

self._check_closed()
return self._mariadb_get_info(INFO.CLIENT_CAPABILITIES, int)
return self._mariadb_get_info(INFO.CLIENT_CAPABILITIES)

@property
def server_capabilities(self):
"""Server capability flags."""

self._check_closed()
return self._mariadb_get_info(INFO.SERVER_CAPABILITIES, int)
return self._mariadb_get_info(INFO.SERVER_CAPABILITIES)

@property
def extended_server_capabilities(self):
"""Extended server capability flags (only for MariaDB database servers)."""

self._check_closed()
return self._mariadb_get_info(INFO.EXTENDED_SERVER_CAPABILITIES, int)
return self._mariadb_get_info(INFO.EXTENDED_SERVER_CAPABILITIES)

@property
def server_port(self):
"""Database server TCP/IP port. This value will be 0 in case of a unix socket connection."""

self._check_closed()
return self._mariadb_get_info(INFO.PORT, int)
return self._mariadb_get_info(INFO.PORT)

@property
def unix_socket(self):
"""Unix socket name."""

self._check_closed()
return self._mariadb_get_info(INFO.UNIX_SOCKET, str)
return self._mariadb_get_info(INFO.UNIX_SOCKET)

@property
def server_name(self):
"""Name or IP address of database server."""

self._check_closed()
return self._mariadb_get_info(INFO.HOST, str)
return self._mariadb_get_info(INFO.HOST)

@property
def collation(self):
Expand All @@ -527,21 +527,21 @@ def server_info(self):
"""Server version in alphanumerical format (str)"""

self._check_closed()
return self._mariadb_get_info(INFO.SERVER_VERSION, str)
return self._mariadb_get_info(INFO.SERVER_VERSION)

@property
def tls_cipher(self):
"""TLS cipher suite if a secure connection is used."""

self._check_closed()
return self._mariadb_get_info(INFO.SSL_CIPHER, str)
return self._mariadb_get_info(INFO.SSL_CIPHER)

@property
def tls_version(self):
"""TLS protocol version if a secure connection is used."""

self._check_closed()
return self._mariadb_get_info(INFO.TLS_VERSION, str)
return self._mariadb_get_info(INFO.TLS_VERSION)

@property
def server_status(self):
Expand All @@ -550,7 +550,7 @@ def server_status(self):
"""

self._check_closed()
return self._mariadb_get_info(INFO.SERVER_STATUS, int)
return self._mariadb_get_info(INFO.SERVER_STATUS)

@property
def server_version(self):
Expand All @@ -562,7 +562,7 @@ def server_version(self):
"""

self._check_closed()
return self._mariadb_get_info(INFO.SERVER_VERSION_ID, int)
return self._mariadb_get_info(INFO.SERVER_VERSION_ID)

@property
def server_version_info(self):
Expand Down
7 changes: 6 additions & 1 deletion mariadb/cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,8 @@ def execute(self, statement: str, data: Sequence =(), buffered=None):

self.check_closed()

self.connection._last_executed_statement= statement

# Parse statement
do_parse= True
self._rowcount= 0
Expand Down Expand Up @@ -314,6 +316,8 @@ def executemany(self, statement, parameters):
if not parameters or not len(parameters):
raise mariadb.ProgrammingError("No data provided")

self.connection._last_executed_statement= statement

# clear pending results
if self.field_count:
self._clear_result()
Expand Down Expand Up @@ -373,7 +377,8 @@ def close(self):
The cursor will be unusable from this point forward; an Error (or subclass)
exception will be raised if any operation is attempted with the cursor."
"""
super().close()
if not self.connection.is_closed:
super().close()

def fetchone(self):
"""
Expand Down
Loading

0 comments on commit a9ad1fc

Please sign in to comment.