Skip to content

Commit

Permalink
Accept query options on Statement#execute
Browse files Browse the repository at this point in the history
  • Loading branch information
sodabrew committed Nov 28, 2017
1 parent 245c3d1 commit 6709b2b
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 58 deletions.
26 changes: 19 additions & 7 deletions ext/mysql2/statement.c
Expand Up @@ -2,7 +2,7 @@

VALUE cMysql2Statement;
extern VALUE mMysql2, cMysql2Error, cBigDecimal, cDateTime, cDate;
static VALUE sym_stream, intern_new_with_args, intern_each, intern_to_s;
static VALUE sym_stream, intern_new_with_args, intern_each, intern_to_s, intern_merge_bang;
static VALUE intern_sec_fraction, intern_usec, intern_sec, intern_min, intern_hour, intern_day, intern_month, intern_year;

#define GET_STATEMENT(self) \
Expand Down Expand Up @@ -184,7 +184,7 @@ static void set_buffer_for_string(MYSQL_BIND* bind_buffer, unsigned long *length
* the buffer is a Ruby string pointer and not our memory to manage.
*/
#define FREE_BINDS \
for (i = 0; i < argc; i++) { \
for (i = 0; i < c; i++) { \
if (bind_buffers[i].buffer && NIL_P(params_enc[i])) { \
xfree(bind_buffers[i].buffer); \
} \
Expand Down Expand Up @@ -248,8 +248,10 @@ static VALUE rb_mysql_stmt_execute(int argc, VALUE *argv, VALUE self) {
unsigned long *length_buffers = NULL;
unsigned long bind_count;
long i;
int c;
MYSQL_STMT *stmt;
MYSQL_RES *metadata;
VALUE opts;
VALUE current;
VALUE resultObj;
VALUE *params_enc;
Expand All @@ -261,22 +263,25 @@ static VALUE rb_mysql_stmt_execute(int argc, VALUE *argv, VALUE self) {

conn_enc = rb_to_encoding(wrapper->encoding);

/* Scratch space for string encoding exports, allocate on the stack. */
params_enc = alloca(sizeof(VALUE) * argc);
// Get count of ordinary arguments, and extract hash opts/keyword arguments
c = rb_scan_args(argc, argv, "*:", NULL, &opts);

// Scratch space for string encoding exports, allocate on the stack
params_enc = alloca(sizeof(VALUE) * c);

stmt = stmt_wrapper->stmt;

bind_count = mysql_stmt_param_count(stmt);
if (argc != (long)bind_count) {
rb_raise(cMysql2Error, "Bind parameter count (%ld) doesn't match number of arguments (%d)", bind_count, argc);
if (c != (long)bind_count) {
rb_raise(cMysql2Error, "Bind parameter count (%ld) doesn't match number of arguments (%d)", bind_count, c);
}

// setup any bind variables in the query
if (bind_count > 0) {
bind_buffers = xcalloc(bind_count, sizeof(MYSQL_BIND));
length_buffers = xcalloc(bind_count, sizeof(unsigned long));

for (i = 0; i < argc; i++) {
for (i = 0; i < c; i++) {
bind_buffers[i].buffer = NULL;
params_enc[i] = Qnil;

Expand Down Expand Up @@ -416,10 +421,16 @@ static VALUE rb_mysql_stmt_execute(int argc, VALUE *argv, VALUE self) {
return Qnil;
}

// Important to duplicate the hash, will receive merge! if extra opts
current = rb_hash_dup(rb_iv_get(stmt_wrapper->client, "@query_options"));
(void)RB_GC_GUARD(current);
Check_Type(current, T_HASH);

// Merge in hash opts/keyword arguments
if (!NIL_P(opts)) {
rb_funcall(current, intern_merge_bang, 1, opts);
}

is_streaming = (Qtrue == rb_hash_aref(current, sym_stream));
if (!is_streaming) {
// recieve the whole result set from the server
Expand Down Expand Up @@ -562,4 +573,5 @@ void init_mysql2_statement() {
intern_year = rb_intern("year");

intern_to_s = rb_intern("to_s");
intern_merge_bang = rb_intern("merge!");
}
8 changes: 4 additions & 4 deletions lib/mysql2/statement.rb
Expand Up @@ -5,14 +5,14 @@ class Statement
include Enumerable

if Thread.respond_to?(:handle_interrupt)
def execute(*args)
def execute(*args, **kwargs)
Thread.handle_interrupt(::Mysql2::Util::TIMEOUT_ERROR_CLASS => :never) do
_execute(*args)
_execute(*args, **kwargs)
end
end
else
def execute(*args)
_execute(*args)
def execute(*args, **kwargs)
_execute(*args, **kwargs)
end
end
end
Expand Down
82 changes: 35 additions & 47 deletions spec/mysql2/statement_spec.rb
Expand Up @@ -88,6 +88,20 @@ def stmt_count
expect(result.to_a).to eq(['max1' => int64_max1, 'max2' => int64_max2, 'max3' => int64_max3, 'min1' => int64_min1, 'min2' => int64_min2, 'min3' => int64_min3])
end

it "should accept keyword arguments on statement execute" do
stmt = @client.prepare 'SELECT 1 AS a'

expect(stmt.execute(as: :hash).first).to eq("a" => 1)
expect(stmt.execute(as: :array).first).to eq([1])
end

it "should accept bind arguments and keyword arguments on statement execute" do
stmt = @client.prepare 'SELECT ? AS a'

expect(stmt.execute(1, as: :hash).first).to eq("a" => 1)
expect(stmt.execute(1, as: :array).first).to eq([1])
end

it "should keep its result after other query" do
@client.query 'USE test'
@client.query 'CREATE TABLE IF NOT EXISTS mysql2_stmt_q(a int)'
Expand Down Expand Up @@ -188,10 +202,9 @@ def stmt_count
end

it "should warn but still work if cache_rows is set to false" do
@client.query_options[:cache_rows] = false
statement = @client.prepare 'SELECT 1'
result = nil
expect { result = statement.execute.to_a }.to output(/:cache_rows is forced for prepared statements/).to_stderr
expect { result = statement.execute(cache_rows: false).to_a }.to output(/:cache_rows is forced for prepared statements/).to_stderr
expect(result.length).to eq(1)
end

Expand Down Expand Up @@ -240,10 +253,7 @@ def stmt_count
it "should be able to stream query result" do
n = 1
stmt = @client.prepare("SELECT 1 UNION SELECT 2")

@client.query_options.merge!(stream: true, cache_rows: false, as: :array)

stmt.execute.each do |r|
stmt.execute(stream: true, cache_rows: false, as: :array).each do |r|
case n
when 1
expect(r).to eq([1])
Expand All @@ -269,23 +279,17 @@ def stmt_count
end

it "should yield rows as hash's with symbol keys if :symbolize_keys was set to true" do
@client.query_options[:symbolize_keys] = true
@result = @client.prepare("SELECT 1").execute
@result = @client.prepare("SELECT 1").execute(symbolize_keys: true)
@result.each do |row|
expect(row.keys.first).to be_an_instance_of(Symbol)
end
@client.query_options[:symbolize_keys] = false
end

it "should be able to return results as an array" do
@client.query_options[:as] = :array

@result = @client.prepare("SELECT 1").execute
@result = @client.prepare("SELECT 1").execute(as: :array)
@result.each do |row|
expect(row).to be_an_instance_of(Array)
end

@client.query_options[:as] = :hash
end

it "should cache previously yielded results by default" do
Expand All @@ -294,35 +298,21 @@ def stmt_count
end

it "should yield different value for #first if streaming" do
@client.query_options[:stream] = true
@client.query_options[:cache_rows] = false

result = @client.prepare("SELECT 1 UNION SELECT 2").execute
result = @client.prepare("SELECT 1 UNION SELECT 2").execute(stream: true, cache_rows: true)
expect(result.first).not_to eql(result.first)

@client.query_options[:stream] = false
@client.query_options[:cache_rows] = true
end

it "should yield the same value for #first if streaming is disabled" do
@client.query_options[:stream] = false
result = @client.prepare("SELECT 1 UNION SELECT 2").execute
result = @client.prepare("SELECT 1 UNION SELECT 2").execute(stream: false)
expect(result.first).to eql(result.first)
end

it "should throw an exception if we try to iterate twice when streaming is enabled" do
@client.query_options[:stream] = true
@client.query_options[:cache_rows] = false

result = @client.prepare("SELECT 1 UNION SELECT 2").execute

result = @client.prepare("SELECT 1 UNION SELECT 2").execute(stream: true, cache_rows: false)
expect do
result.each {}
result.each {}
end.to raise_exception(Mysql2::Error)

@client.query_options[:stream] = false
@client.query_options[:cache_rows] = true
end
end

Expand Down Expand Up @@ -371,21 +361,20 @@ def stmt_count

context "cast booleans for TINYINT if :cast_booleans is enabled" do
# rubocop:disable Style/Semicolon
let(:client) { new_client(cast_booleans: true) }
let(:id1) { client.query 'INSERT INTO mysql2_test (bool_cast_test) VALUES ( 1)'; client.last_id }
let(:id2) { client.query 'INSERT INTO mysql2_test (bool_cast_test) VALUES ( 0)'; client.last_id }
let(:id3) { client.query 'INSERT INTO mysql2_test (bool_cast_test) VALUES (-1)'; client.last_id }
let(:id1) { @client.query 'INSERT INTO mysql2_test (bool_cast_test) VALUES ( 1)'; @client.last_id }
let(:id2) { @client.query 'INSERT INTO mysql2_test (bool_cast_test) VALUES ( 0)'; @client.last_id }
let(:id3) { @client.query 'INSERT INTO mysql2_test (bool_cast_test) VALUES (-1)'; @client.last_id }
# rubocop:enable Style/Semicolon

after do
client.query "DELETE from mysql2_test WHERE id IN(#{id1},#{id2},#{id3})"
@client.query "DELETE from mysql2_test WHERE id IN(#{id1},#{id2},#{id3})"
end

it "should return TrueClass or FalseClass for a TINYINT value if :cast_booleans is enabled" do
query = client.prepare 'SELECT bool_cast_test FROM mysql2_test WHERE id = ?'
result1 = query.execute id1
result2 = query.execute id2
result3 = query.execute id3
query = @client.prepare 'SELECT bool_cast_test FROM mysql2_test WHERE id = ?'
result1 = query.execute id1, cast_booleans: true
result2 = query.execute id2, cast_booleans: true
result3 = query.execute id3, cast_booleans: true
expect(result1.first['bool_cast_test']).to be true
expect(result2.first['bool_cast_test']).to be false
expect(result3.first['bool_cast_test']).to be true
Expand All @@ -394,19 +383,18 @@ def stmt_count

context "cast booleans for BIT(1) if :cast_booleans is enabled" do
# rubocop:disable Style/Semicolon
let(:client) { new_client(cast_booleans: true) }
let(:id1) { client.query 'INSERT INTO mysql2_test (single_bit_test) VALUES (1)'; client.last_id }
let(:id2) { client.query 'INSERT INTO mysql2_test (single_bit_test) VALUES (0)'; client.last_id }
let(:id1) { @client.query 'INSERT INTO mysql2_test (single_bit_test) VALUES (1)'; @client.last_id }
let(:id2) { @client.query 'INSERT INTO mysql2_test (single_bit_test) VALUES (0)'; @client.last_id }
# rubocop:enable Style/Semicolon

after do
client.query "DELETE from mysql2_test WHERE id IN(#{id1},#{id2})"
@client.query "DELETE from mysql2_test WHERE id IN(#{id1},#{id2})"
end

it "should return TrueClass or FalseClass for a BIT(1) value if :cast_booleans is enabled" do
query = client.prepare 'SELECT single_bit_test FROM mysql2_test WHERE id = ?'
result1 = query.execute id1
result2 = query.execute id2
query = @client.prepare 'SELECT single_bit_test FROM mysql2_test WHERE id = ?'
result1 = query.execute id1, cast_booleans: true
result2 = query.execute id2, cast_booleans: true
expect(result1.first['single_bit_test']).to be true
expect(result2.first['single_bit_test']).to be false
end
Expand Down

0 comments on commit 6709b2b

Please sign in to comment.