diff --git a/ext/mysql2/result.c b/ext/mysql2/result.c index 0ab38aabe..aa99cad9f 100644 --- a/ext/mysql2/result.c +++ b/ext/mysql2/result.c @@ -22,7 +22,7 @@ static rb_encoding *binaryEncoding; typedef struct { int symbolizeKeys; - int asArray; + int rowsAs; int castBool; int cacheRows; int cast; @@ -41,6 +41,13 @@ static VALUE sym_symbolize_keys, sym_as, sym_array, sym_database_timezone, sym_application_timezone, sym_local, sym_utc, sym_cast_booleans, sym_cache_rows, sym_cast, sym_stream, sym_name; +static VALUE sym_struct; + +/* internal rowsAs constants */ +#define AS_HASH 0 +#define AS_ARRAY 1 +#define AS_STRUCT 2 + /* Mark any VALUEs that are only referenced in C, so the GC won't get them. */ static void rb_mysql_result_mark(void * wrapper) { mysql2_result_wrapper * w = wrapper; @@ -48,6 +55,7 @@ static void rb_mysql_result_mark(void * wrapper) { rb_gc_mark(w->fields); rb_gc_mark(w->rows); rb_gc_mark(w->encoding); + rb_gc_mark(w->rowStruct); rb_gc_mark(w->client); rb_gc_mark(w->statement); } @@ -287,6 +295,21 @@ static void rb_mysql_result_alloc_result_buffers(VALUE self, MYSQL_FIELD *fields } } +static VALUE cast_row_as_struct(VALUE self, VALUE rowVal, mysql2_result_wrapper *wrapper) +{ + /* create struct from intermediate array */ + if (wrapper->rowStruct == Qnil) { + unsigned int i; + VALUE *argv_fields = ALLOCA_N(VALUE, wrapper->numberOfFields); + for (i = 0; i < wrapper->numberOfFields; i++) { + argv_fields[i] = rb_mysql_result_fetch_field(self, i, 1); + } + wrapper->rowStruct = rb_funcall2(rb_cStruct, intern_new, (int) wrapper->numberOfFields, argv_fields); + } + + return rb_struct_alloc(wrapper->rowStruct, rowVal); +} + static VALUE rb_mysql_result_fetch_row_stmt(VALUE self, MYSQL_FIELD * fields, const result_each_args *args) { VALUE rowVal; @@ -303,11 +326,6 @@ static VALUE rb_mysql_result_fetch_row_stmt(VALUE self, MYSQL_FIELD * fields, co wrapper->numberOfFields = mysql_num_fields(wrapper->result); wrapper->fields = rb_ary_new2(wrapper->numberOfFields); } - if (args->asArray) { - rowVal = rb_ary_new2(wrapper->numberOfFields); - } else { - rowVal = rb_hash_new(); - } if (wrapper->result_buffers == NULL) { rb_mysql_result_alloc_result_buffers(self, fields); @@ -336,6 +354,12 @@ static VALUE rb_mysql_result_fetch_row_stmt(VALUE self, MYSQL_FIELD * fields, co } } + if (args->rowsAs == AS_HASH) { + rowVal = rb_hash_new(); + } else /* array or struct */ { + rowVal = rb_ary_new2(wrapper->numberOfFields); + } + for (i = 0; i < wrapper->numberOfFields; i++) { VALUE field = rb_mysql_result_fetch_field(self, i, args->symbolizeKeys); VALUE val = Qnil; @@ -464,13 +488,17 @@ static VALUE rb_mysql_result_fetch_row_stmt(VALUE self, MYSQL_FIELD * fields, co } } - if (args->asArray) { - rb_ary_push(rowVal, val); - } else { + if (args->rowsAs == AS_HASH) { rb_hash_aset(rowVal, field, val); + } else /* array or struct */ { + rb_ary_push(rowVal, val); } } + if (args->rowsAs == AS_STRUCT) { + rowVal = cast_row_as_struct(self, rowVal, wrapper); + } + return rowVal; } @@ -498,10 +526,12 @@ static VALUE rb_mysql_result_fetch_row(VALUE self, MYSQL_FIELD * fields, const r wrapper->numberOfFields = mysql_num_fields(wrapper->result); wrapper->fields = rb_ary_new2(wrapper->numberOfFields); } - if (args->asArray) { - rowVal = rb_ary_new2(wrapper->numberOfFields); - } else { + + if (args->rowsAs == AS_HASH) { rowVal = rb_hash_new(); + } else /* array or struct */ { + /* struct uses array as an intermediary */ + rowVal = rb_ary_new2(wrapper->numberOfFields); } fieldLengths = mysql_fetch_lengths(wrapper->result); @@ -671,19 +701,24 @@ static VALUE rb_mysql_result_fetch_row(VALUE self, MYSQL_FIELD * fields, const r break; } } - if (args->asArray) { - rb_ary_push(rowVal, val); - } else { + if (args->rowsAs == AS_HASH) { rb_hash_aset(rowVal, field, val); + } else /* array or struct */ { + rb_ary_push(rowVal, val); } } else { - if (args->asArray) { - rb_ary_push(rowVal, Qnil); - } else { + if (args->rowsAs == AS_HASH) { rb_hash_aset(rowVal, field, Qnil); + } else /* array or struct */ { + rb_ary_push(rowVal, Qnil); } } } + + if (args->rowsAs == AS_STRUCT) { + rowVal = cast_row_as_struct(self, rowVal, wrapper); + } + return rowVal; } @@ -696,7 +731,7 @@ static VALUE rb_mysql_result_fetch_fields(VALUE self) { defaults = rb_iv_get(self, "@query_options"); Check_Type(defaults, T_HASH); - if (rb_hash_aref(defaults, sym_symbolize_keys) == Qtrue) { + if (rb_hash_aref(defaults, sym_symbolize_keys) == Qtrue || rb_hash_aref(defaults, sym_as) == sym_struct) { symbolizeKeys = 1; } @@ -807,9 +842,9 @@ static VALUE rb_mysql_result_each_(VALUE self, static VALUE rb_mysql_result_each(int argc, VALUE * argv, VALUE self) { result_each_args args; - VALUE defaults, opts, block, (*fetch_row_func)(VALUE, MYSQL_FIELD *fields, const result_each_args *args); + VALUE defaults, opts, as_opt, block, (*fetch_row_func)(VALUE, MYSQL_FIELD *fields, const result_each_args *args); ID db_timezone, app_timezone, dbTz, appTz; - int symbolizeKeys, asArray, castBool, cacheRows, cast; + int symbolizeKeys, rowsAs, castBool, cacheRows, cast; GET_RESULT(self); @@ -826,11 +861,20 @@ static VALUE rb_mysql_result_each(int argc, VALUE * argv, VALUE self) { } symbolizeKeys = RTEST(rb_hash_aref(opts, sym_symbolize_keys)); - asArray = rb_hash_aref(opts, sym_as) == sym_array; castBool = RTEST(rb_hash_aref(opts, sym_cast_booleans)); cacheRows = RTEST(rb_hash_aref(opts, sym_cache_rows)); cast = RTEST(rb_hash_aref(opts, sym_cast)); + as_opt = rb_hash_aref(opts, sym_as); + if (as_opt == sym_array) { + rowsAs = AS_ARRAY; + } else if (as_opt == sym_struct) { + rowsAs = AS_STRUCT; + symbolizeKeys = 1; /* force */ + } else { + rowsAs = AS_HASH; + } + if (wrapper->is_streaming && cacheRows) { rb_warn(":cache_rows is ignored if :stream is true"); } @@ -879,7 +923,7 @@ static VALUE rb_mysql_result_each(int argc, VALUE * argv, VALUE self) { // Backward compat args.symbolizeKeys = symbolizeKeys; - args.asArray = asArray; + args.rowsAs = rowsAs; args.castBool = castBool; args.cacheRows = cacheRows; args.cast = cast; @@ -930,6 +974,7 @@ VALUE rb_mysql_result_to_obj(VALUE client, VALUE encoding, VALUE options, MYSQL_ wrapper->result = r; wrapper->fields = Qnil; wrapper->rows = Qnil; + wrapper->rowStruct = Qnil; wrapper->encoding = encoding; wrapper->streamingComplete = 0; wrapper->client = client; @@ -983,6 +1028,7 @@ void init_mysql2_result() { sym_symbolize_keys = ID2SYM(rb_intern("symbolize_keys")); sym_as = ID2SYM(rb_intern("as")); sym_array = ID2SYM(rb_intern("array")); + sym_struct = ID2SYM(rb_intern("struct")); sym_local = ID2SYM(rb_intern("local")); sym_utc = ID2SYM(rb_intern("utc")); sym_cast_booleans = ID2SYM(rb_intern("cast_booleans")); diff --git a/ext/mysql2/result.h b/ext/mysql2/result.h index 0c25b24b6..afbed4a56 100644 --- a/ext/mysql2/result.h +++ b/ext/mysql2/result.h @@ -10,6 +10,7 @@ typedef struct { VALUE client; VALUE encoding; VALUE statement; + VALUE rowStruct; my_ulonglong numberOfFields; my_ulonglong numberOfRows; unsigned long lastRowProcessed; diff --git a/spec/mysql2/result_spec.rb b/spec/mysql2/result_spec.rb index a70b38ef0..669037d3c 100644 --- a/spec/mysql2/result_spec.rb +++ b/spec/mysql2/result_spec.rb @@ -72,6 +72,12 @@ end end + it "should be able to return results as a struct" do + @result.each(as: :struct) do |row| + expect(row).to be_kind_of(Struct) + end + end + it "should cache previously yielded results by default" do expect(@result.first.object_id).to eql(@result.first.object_id) end @@ -117,6 +123,11 @@ result = @client.query "SELECT 'a', 'b', 'c'" expect(result.fields).to eql(%w[a b c]) end + + it "should return field names as symbols if rows are structs" do + result = @client.query "SELECT 'a', 'b', 'c'", as: :struct + expect(result.fields.first).to be_an_instance_of(Symbol) + end end context "streaming" do diff --git a/spec/mysql2/statement_spec.rb b/spec/mysql2/statement_spec.rb index dbc185e6b..1a3216659 100644 --- a/spec/mysql2/statement_spec.rb +++ b/spec/mysql2/statement_spec.rb @@ -294,6 +294,13 @@ def stmt_count end end + it "should be able to return results as a struct" do + @result = @client.prepare("SELECT 1").execute(as: :struct) + @result.each do |row| + expect(row).to be_kind_of(Struct) + end + end + it "should cache previously yielded results by default" do @result = @client.prepare("SELECT 1").execute expect(@result.first.object_id).to eql(@result.first.object_id) @@ -329,6 +336,11 @@ def stmt_count expect(stmt.fields).to eql(%w[a b c]) end + it "should return field names as symbols if rows are structs" do + result = @client.prepare("SELECT 'a', 'b', 'c'").execute(as: :struct) + expect(result.fields.first).to be_an_instance_of(Symbol) + end + it "should return nil for statement with no result fields" do stmt = @client.prepare("INSERT INTO mysql2_test () VALUES ()") expect(stmt.fields).to eql(nil)