Skip to content

Commit

Permalink
Merge pull request #349 from Shopify/better-buffer-stack-fix
Browse files Browse the repository at this point in the history
Better fix for marking buffers held on the stack
  • Loading branch information
byroot committed Jul 17, 2023
2 parents c3d6367 + e4ca627 commit 6dcdb39
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 37 deletions.
73 changes: 72 additions & 1 deletion ext/msgpack/buffer_class.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
#include "buffer.h"
#include "buffer_class.h"

VALUE cMessagePack_Buffer;
VALUE cMessagePack_Buffer = Qnil;
VALUE cMessagePack_HeldBuffer = Qnil;

static ID s_read;
static ID s_readpartial;
Expand All @@ -34,6 +35,73 @@ static VALUE sym_read_reference_threshold;
static VALUE sym_write_reference_threshold;
static VALUE sym_io_buffer_size;

typedef struct msgpack_held_buffer_t msgpack_held_buffer_t;
struct msgpack_held_buffer_t {
size_t size;
VALUE mapped_strings[];
};

static void HeldBuffer_mark(void *data)
{
msgpack_held_buffer_t* held_buffer = (msgpack_held_buffer_t*)data;
for (size_t index = 0; index < held_buffer->size; index++) {
rb_gc_mark(held_buffer->mapped_strings[index]);
}
}

static size_t HeldBuffer_memsize(const void *data)
{
const msgpack_held_buffer_t* held_buffer = (msgpack_held_buffer_t*)data;
return sizeof(size_t) + sizeof(VALUE) * held_buffer->size;
}

static const rb_data_type_t held_buffer_data_type = {
.wrap_struct_name = "msgpack:held_buffer",
.function = {
.dmark = HeldBuffer_mark,
.dfree = RUBY_TYPED_DEFAULT_FREE,
.dsize = HeldBuffer_memsize,
},
.flags = RUBY_TYPED_FREE_IMMEDIATELY
};

VALUE MessagePack_Buffer_hold(msgpack_buffer_t* buffer)
{
size_t mapped_strings_count = 0;
msgpack_buffer_chunk_t* c = buffer->head;
while (c != &buffer->tail) {
if (c->mapped_string != NO_MAPPED_STRING) {
mapped_strings_count++;
}
c = c->next;
}
if (c->mapped_string != NO_MAPPED_STRING) {
mapped_strings_count++;
}

if (mapped_strings_count == 0) {
return Qnil;
}

msgpack_held_buffer_t* held_buffer = xmalloc(sizeof(msgpack_held_buffer_t) + mapped_strings_count * sizeof(VALUE));

c = buffer->head;
mapped_strings_count = 0;
while (c != &buffer->tail) {
if (c->mapped_string != NO_MAPPED_STRING) {
held_buffer->mapped_strings[mapped_strings_count] = c->mapped_string;
mapped_strings_count++;
}
c = c->next;
}
if (c->mapped_string != NO_MAPPED_STRING) {
held_buffer->mapped_strings[mapped_strings_count] = c->mapped_string;
mapped_strings_count++;
}
held_buffer->size = mapped_strings_count;
return TypedData_Wrap_Struct(cMessagePack_HeldBuffer, &held_buffer_data_type, held_buffer);
}


#define CHECK_STRING_TYPE(value) \
value = rb_check_string_type(value); \
Expand Down Expand Up @@ -520,6 +588,9 @@ void MessagePack_Buffer_module_init(VALUE mMessagePack)

msgpack_buffer_static_init();

cMessagePack_HeldBuffer = rb_define_class_under(mMessagePack, "HeldBuffer", rb_cBasicObject);
rb_undef_alloc_func(cMessagePack_HeldBuffer);

cMessagePack_Buffer = rb_define_class_under(mMessagePack, "Buffer", rb_cObject);

rb_define_alloc_func(cMessagePack_Buffer, Buffer_alloc);
Expand Down
1 change: 1 addition & 0 deletions ext/msgpack/buffer_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ extern VALUE cMessagePack_Buffer;
void MessagePack_Buffer_module_init(VALUE mMessagePack);

VALUE MessagePack_Buffer_wrap(msgpack_buffer_t* b, VALUE owner);
VALUE MessagePack_Buffer_hold(msgpack_buffer_t* b);

void MessagePack_Buffer_set_options(msgpack_buffer_t* b, VALUE io, VALUE options);

Expand Down
39 changes: 3 additions & 36 deletions ext/msgpack/packer.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/

#include "packer.h"
#include "buffer_class.h"

#if !defined(HAVE_RB_PROC_CALL_WITH_BLOCK)
#define rb_proc_call_with_block(recv, argc, argv, block) rb_funcallv(recv, rb_intern("call"), argc, argv)
Expand Down Expand Up @@ -114,39 +115,7 @@ bool msgpack_packer_try_write_with_ext_type_lookup(msgpack_packer_t* pk, VALUE v
}

if(ext_flags & MSGPACK_EXT_RECURSIVE) {
// HACK: While we call the proc, the current pk->buffer won't be reachable
// as it will be stored on the stack.
// To ensure all the `mapped_string` reference in that buffer are properly
// marked and pined, we copy them all on the stack.
VALUE* mapped_strings = NULL;
size_t mapped_strings_count = 0;
msgpack_buffer_chunk_t* c = pk->buffer.head;
while (c != &pk->buffer.tail) {
if (c->mapped_string != NO_MAPPED_STRING) {
mapped_strings_count++;
}
c = c->next;
}
if (c->mapped_string != NO_MAPPED_STRING) {
mapped_strings_count++;
}

if (mapped_strings_count > 0) {
mapped_strings = ALLOCA_N(VALUE, mapped_strings_count);
mapped_strings_count = 0;
c = pk->buffer.head;
while (c != &pk->buffer.tail) {
if (c->mapped_string != NO_MAPPED_STRING) {
mapped_strings[mapped_strings_count] = c->mapped_string;
mapped_strings_count++;
}
c = c->next;
}
if (c->mapped_string != NO_MAPPED_STRING) {
mapped_strings[mapped_strings_count] = c->mapped_string;
mapped_strings_count++;
}
}
VALUE held_buffer = MessagePack_Buffer_hold(&pk->buffer);

msgpack_buffer_t parent_buffer = pk->buffer;
msgpack_buffer_init(PACKER_BUFFER_(pk));
Expand All @@ -167,9 +136,7 @@ bool msgpack_packer_try_write_with_ext_type_lookup(msgpack_packer_t* pk, VALUE v
msgpack_packer_write_ext(pk, ext_type, payload);
}

if (mapped_strings_count > 0) {
RB_GC_GUARD(mapped_strings[0]);
}
RB_GC_GUARD(held_buffer);
} else {
VALUE payload = rb_proc_call_with_block(proc, 1, &v, Qnil);
StringValue(payload);
Expand Down

0 comments on commit 6dcdb39

Please sign in to comment.