diff --git a/ext/mysql2/client.c b/ext/mysql2/client.c index aa10baed3..b2ac283a4 100644 --- a/ext/mysql2/client.c +++ b/ext/mysql2/client.c @@ -133,7 +133,7 @@ static VALUE rb_raise_mysql2_error(mysql_client_wrapper *wrapper) { static void *nogvl_init(void *ptr) { MYSQL *client; - mysql_client_wrapper *wrapper = (mysql_client_wrapper *)ptr; + mysql_client_wrapper *wrapper = ptr; /* may initialize embedded server and read /etc/services off disk */ client = mysql_init(wrapper->client); @@ -224,7 +224,7 @@ static void *nogvl_close(void *ptr) { /* this is called during GC */ static void rb_mysql_client_free(void *ptr) { - mysql_client_wrapper *wrapper = (mysql_client_wrapper *)ptr; + mysql_client_wrapper *wrapper = ptr; decr_mysql2_client(wrapper); } @@ -437,10 +437,9 @@ static void *nogvl_read_query_result(void *ptr) { } static void *nogvl_do_result(void *ptr, char use_result) { - mysql_client_wrapper *wrapper; + mysql_client_wrapper *wrapper = ptr; MYSQL_RES *result; - wrapper = (mysql_client_wrapper *)ptr; if (use_result) { result = mysql_use_result(wrapper->client); } else { @@ -533,14 +532,13 @@ static VALUE disconnect_and_raise(VALUE self, VALUE error) { } static VALUE do_query(void *args) { - struct async_query_args *async_args; + struct async_query_args *async_args = args; struct timeval tv; - struct timeval* tvp; + struct timeval *tvp; long int sec; int retval; VALUE read_timeout; - async_args = (struct async_query_args *)args; read_timeout = rb_iv_get(async_args->self, "@read_timeout"); tvp = NULL; @@ -578,11 +576,9 @@ static VALUE do_query(void *args) { } #else static VALUE finish_and_mark_inactive(void *args) { - VALUE self; + VALUE self = args; MYSQL_RES *result; - self = (VALUE)args; - GET_CLIENT(self); if (!NIL_P(wrapper->active_thread)) { diff --git a/ext/mysql2/result.c b/ext/mysql2/result.c index 52820faec..040e9d526 100644 --- a/ext/mysql2/result.c +++ b/ext/mysql2/result.c @@ -48,7 +48,7 @@ static rb_encoding *binaryEncoding; #define MYSQL2_MIN_TIME 62171150401ULL #endif -#define GET_RESULT(obj) \ +#define GET_RESULT(self) \ mysql2_result_wrapper *wrapper; \ Data_Get_Struct(self, mysql2_result_wrapper, wrapper); @@ -91,16 +91,18 @@ static void rb_mysql_result_free_result(mysql2_result_wrapper * wrapper) { if (wrapper->resultFreed != 1) { if (wrapper->stmt_wrapper) { - mysql_stmt_free_result(wrapper->stmt_wrapper->stmt); - - /* MySQL BUG? If the statement handle was previously used, and so - * mysql_stmt_bind_result was called, and if that result set and bind buffers were freed, - * MySQL still thinks the result set buffer is available and will prefetch the - * first result in mysql_stmt_execute. This will corrupt or crash the program. - * By setting bind_result_done back to 0, we make MySQL think that a result set - * has never been bound to this statement handle before to prevent the prefetch. - */ - wrapper->stmt_wrapper->stmt->bind_result_done = 0; + if (!wrapper->stmt_wrapper->closed) { + mysql_stmt_free_result(wrapper->stmt_wrapper->stmt); + + /* MySQL BUG? If the statement handle was previously used, and so + * mysql_stmt_bind_result was called, and if that result set and bind buffers were freed, + * MySQL still thinks the result set buffer is available and will prefetch the + * first result in mysql_stmt_execute. This will corrupt or crash the program. + * By setting bind_result_done back to 0, we make MySQL think that a result set + * has never been bound to this statement handle before to prevent the prefetch. + */ + wrapper->stmt_wrapper->stmt->bind_result_done = 0; + } if (wrapper->result_buffers) { unsigned int i; @@ -855,6 +857,10 @@ static VALUE rb_mysql_result_each(int argc, VALUE * argv, VALUE self) { GET_RESULT(self); + if (wrapper->stmt_wrapper && wrapper->stmt_wrapper->closed) { + rb_raise(cMysql2Error, "Statement handle already closed"); + } + defaults = rb_iv_get(self, "@query_options"); Check_Type(defaults, T_HASH); if (rb_scan_args(argc, argv, "01&", &opts, &block) == 1) { diff --git a/ext/mysql2/statement.c b/ext/mysql2/statement.c index 3b83feaf2..fc61e9302 100644 --- a/ext/mysql2/statement.c +++ b/ext/mysql2/statement.c @@ -8,18 +8,19 @@ static VALUE intern_usec, intern_sec, intern_min, intern_hour, intern_day, inter #define GET_STATEMENT(self) \ mysql_stmt_wrapper *stmt_wrapper; \ Data_Get_Struct(self, mysql_stmt_wrapper, stmt_wrapper); \ - if (!stmt_wrapper->stmt) { rb_raise(cMysql2Error, "Invalid statement handle"); } + if (!stmt_wrapper->stmt) { rb_raise(cMysql2Error, "Invalid statement handle"); } \ + if (stmt_wrapper->closed) { rb_raise(cMysql2Error, "Statement handle already closed"); } static void rb_mysql_stmt_mark(void * ptr) { - mysql_stmt_wrapper* stmt_wrapper = (mysql_stmt_wrapper *)ptr; + mysql_stmt_wrapper *stmt_wrapper = ptr; if (!stmt_wrapper) return; rb_gc_mark(stmt_wrapper->client); } static void *nogvl_stmt_close(void * ptr) { - mysql_stmt_wrapper *stmt_wrapper = (mysql_stmt_wrapper *)ptr; + mysql_stmt_wrapper *stmt_wrapper = ptr; if (stmt_wrapper->stmt) { mysql_stmt_close(stmt_wrapper->stmt); stmt_wrapper->stmt = NULL; @@ -28,7 +29,7 @@ static void *nogvl_stmt_close(void * ptr) { } static void rb_mysql_stmt_free(void * ptr) { - mysql_stmt_wrapper* stmt_wrapper = (mysql_stmt_wrapper *)ptr; + mysql_stmt_wrapper *stmt_wrapper = ptr; decr_mysql2_stmt(stmt_wrapper); } @@ -93,7 +94,7 @@ static void *nogvl_prepare_statement(void *ptr) { } VALUE rb_mysql_stmt_new(VALUE rb_client, VALUE sql) { - mysql_stmt_wrapper* stmt_wrapper; + mysql_stmt_wrapper *stmt_wrapper; VALUE rb_stmt; #ifdef HAVE_RUBY_ENCODING_H rb_encoding *conn_enc; @@ -105,6 +106,7 @@ VALUE rb_mysql_stmt_new(VALUE rb_client, VALUE sql) { { stmt_wrapper->client = rb_client; stmt_wrapper->refcount = 1; + stmt_wrapper->closed = 0; stmt_wrapper->stmt = NULL; } @@ -461,6 +463,7 @@ static VALUE rb_mysql_stmt_affected_rows(VALUE self) { */ static VALUE rb_mysql_stmt_close(VALUE self) { GET_STATEMENT(self); + stmt_wrapper->closed = 1; rb_thread_call_without_gvl(nogvl_stmt_close, stmt_wrapper, RUBY_UBF_IO, 0); return Qnil; } diff --git a/ext/mysql2/statement.h b/ext/mysql2/statement.h index 0c1d54c55..63260aab0 100644 --- a/ext/mysql2/statement.h +++ b/ext/mysql2/statement.h @@ -7,6 +7,7 @@ typedef struct { VALUE client; MYSQL_STMT *stmt; int refcount; + int closed; } mysql_stmt_wrapper; void init_mysql2_statement(void);