Skip to content

Commit

Permalink
Merge pull request #200 from jhawthorn/opt_eq
Browse files Browse the repository at this point in the history
Specialize equality and support arbitrary types
  • Loading branch information
maximecb committed Sep 9, 2021
2 parents 7b8fc81 + db585c4 commit 3e08b8a
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 26 deletions.
26 changes: 26 additions & 0 deletions bootstraptest/test_yjit.rb
Original file line number Diff line number Diff line change
Expand Up @@ -1928,3 +1928,29 @@ def compiled(arg)
ractor.take
}

# Test equality with changing types
assert_equal '[true, false, false, false]', %q{
def eq(a, b)
a == b
end
[
eq("foo", "foo"),
eq("foo", "bar"),
eq(:foo, "bar"),
eq("foo", :bar)
]
}

# Redefined eq
assert_equal 'true', %q{
class String
def ==(other)
true
end
end
"foo" == "bar"
"foo" == "bar"
}
42 changes: 42 additions & 0 deletions test/ruby/test_yjit.rb
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,48 @@ def test_compile_eq_string
assert_compiles('-"foo" == -"bar"', insns: %i[opt_eq], result: false)
end

def test_compile_eq_symbol
assert_compiles(':foo == :foo', insns: %i[opt_eq], result: true)
assert_compiles(':foo == :bar', insns: %i[opt_eq], result: false)
assert_compiles(':foo == "foo".to_sym', insns: %i[opt_eq], result: true)
end

def test_compile_eq_object
assert_compiles(<<~RUBY, insns: %i[opt_eq], result: false)
def eq(a, b)
a == b
end
eq(Object.new, Object.new)
RUBY

assert_compiles(<<~RUBY, insns: %i[opt_eq], result: true)
def eq(a, b)
a == b
end
obj = Object.new
eq(obj, obj)
RUBY
end

def test_compile_eq_arbitrary_class
assert_compiles(<<~RUBY, insns: %i[opt_eq], result: "yes")
def eq(a, b)
a == b
end
class Foo
def ==(other)
"yes"
end
end
eq(Foo.new, Foo.new)
eq(Foo.new, Foo.new)
RUBY
end

def test_compile_set_and_get_global
assert_compiles('$foo = 123; $foo', insns: %i[setglobal], result: 123)
end
Expand Down
6 changes: 0 additions & 6 deletions vm_insnhelper.c
Original file line number Diff line number Diff line change
Expand Up @@ -2037,12 +2037,6 @@ opt_equality_specialized(VALUE recv, VALUE obj)
return RBOOL(recv == obj);
}

VALUE
rb_opt_equality_specialized(VALUE recv, VALUE obj)
{
return opt_equality_specialized(recv, obj);
}

static VALUE
opt_equality(const rb_iseq_t *cd_owner, VALUE recv, VALUE obj, CALL_DATA cd)
{
Expand Down
134 changes: 114 additions & 20 deletions yjit_codegen.c
Original file line number Diff line number Diff line change
Expand Up @@ -1989,36 +1989,103 @@ gen_opt_gt(jitstate_t* jit, ctx_t* ctx)
return gen_fixnum_cmp(jit, ctx, cmovg);
}

VALUE rb_opt_equality_specialized(VALUE recv, VALUE obj);

static codegen_status_t
gen_opt_eq(jitstate_t* jit, ctx_t* ctx)
// Implements specialized equality for either two fixnum or two strings
// Returns true if code was generated, otherwise false
bool
gen_equality_specialized(jitstate_t* jit, ctx_t* ctx, uint8_t *side_exit)
{
uint8_t* side_exit = yjit_side_exit(jit, ctx);
VALUE comptime_a = jit_peek_at_stack(jit, ctx, 1);
VALUE comptime_b = jit_peek_at_stack(jit, ctx, 0);

// Get the operands from the stack
x86opnd_t arg1 = ctx_stack_pop(ctx, 1);
x86opnd_t arg0 = ctx_stack_pop(ctx, 1);
x86opnd_t a_opnd = ctx_stack_opnd(ctx, 1);
x86opnd_t b_opnd = ctx_stack_opnd(ctx, 0);

// Call rb_opt_equality_specialized(VALUE recv, VALUE obj)
// We know this method won't allocate or perform calls
mov(cb, C_ARG_REGS[0], arg0);
mov(cb, C_ARG_REGS[1], arg1);
call_ptr(cb, REG0, (void *)rb_opt_equality_specialized);
if (FIXNUM_P(comptime_a) && FIXNUM_P(comptime_b)) {
if (!assume_bop_not_redefined(jit->block, INTEGER_REDEFINED_OP_FLAG, BOP_EQ)) {
return YJIT_CANT_COMPILE;
}

// If val == Qundef, bail to do a method call
cmp(cb, RAX, imm_opnd(Qundef));
je_ptr(cb, side_exit);
guard_two_fixnums(ctx, side_exit);

// Push the return value onto the stack
x86opnd_t stack_ret = ctx_stack_push(ctx, TYPE_IMM);
mov(cb, stack_ret, RAX);
mov(cb, REG0, a_opnd);
cmp(cb, REG0, b_opnd);

return YJIT_KEEP_COMPILING;
mov(cb, REG0, imm_opnd(Qfalse));
mov(cb, REG1, imm_opnd(Qtrue));
cmove(cb, REG0, REG1);

// Push the output on the stack
ctx_stack_pop(ctx, 2);
x86opnd_t dst = ctx_stack_push(ctx, TYPE_IMM);
mov(cb, dst, REG0);

return true;
} else if (CLASS_OF(comptime_a) == rb_cString &&
CLASS_OF(comptime_b) == rb_cString) {
if (!assume_bop_not_redefined(jit->block, STRING_REDEFINED_OP_FLAG, BOP_EQ)) {
return YJIT_CANT_COMPILE;
}

// Load a and b in preparation for call later
mov(cb, C_ARG_REGS[0], a_opnd);
mov(cb, C_ARG_REGS[1], b_opnd);

// Guard that a is a String
mov(cb, REG0, C_ARG_REGS[0]);
jit_guard_known_klass(jit, ctx, rb_cString, OPND_STACK(1), comptime_a, SEND_MAX_DEPTH, side_exit);

uint32_t ret = cb_new_label(cb, "ret");

// If they are equal by identity, return true
cmp(cb, C_ARG_REGS[0], C_ARG_REGS[1]);
mov(cb, RAX, imm_opnd(Qtrue));
je_label(cb, ret);

// Otherwise guard that b is a T_STRING (from type info) or String (from runtime guard)
if (ctx_get_opnd_type(ctx, OPND_STACK(0)).type != ETYPE_STRING) {
mov(cb, REG0, C_ARG_REGS[1]);
// Note: any T_STRING is valid here, but we check for a ::String for simplicity
jit_guard_known_klass(jit, ctx, rb_cString, OPND_STACK(0), comptime_b, SEND_MAX_DEPTH, side_exit);
}

// Call rb_str_eql_internal(a, b)
call_ptr(cb, REG0, (void *)rb_str_eql_internal);

// Push the output on the stack
cb_write_label(cb, ret);
ctx_stack_pop(ctx, 2);
x86opnd_t dst = ctx_stack_push(ctx, TYPE_IMM);
mov(cb, dst, RAX);
cb_link_labels(cb);

return true;
} else {
return false;
}
}

static codegen_status_t gen_opt_send_without_block(jitstate_t *jit, ctx_t *ctx);

static codegen_status_t
gen_opt_eq(jitstate_t* jit, ctx_t* ctx)
{
// Defer compilation so we can specialize base on a runtime receiver
if (!jit_at_current_insn(jit)) {
defer_compilation(jit->block, jit->insn_idx, ctx);
return YJIT_END_BLOCK;
}

// Create a size-exit to fall back to the interpreter
uint8_t *side_exit = yjit_side_exit(jit, ctx);

if (gen_equality_specialized(jit, ctx, side_exit)) {
jit_jump_to_next_insn(jit, ctx);
return YJIT_END_BLOCK;
} else {
return gen_opt_send_without_block(jit, ctx);
}
}

static codegen_status_t gen_send_general(jitstate_t *jit, ctx_t *ctx, struct rb_call_data *cd, rb_iseq_t *block);

static codegen_status_t
Expand Down Expand Up @@ -2834,6 +2901,26 @@ jit_rb_false(jitstate_t *jit, ctx_t *ctx, const struct rb_callinfo *ci, const rb
return true;
}

// Codegen for rb_obj_equal()
// object identity comparison
static bool
jit_rb_obj_equal(jitstate_t *jit, ctx_t *ctx, const struct rb_callinfo *ci, const rb_callable_method_entry_t *cme, rb_iseq_t *block, const int32_t argc)
{
ADD_COMMENT(cb, "equal?");
x86opnd_t obj1 = ctx_stack_pop(ctx, 1);
x86opnd_t obj2 = ctx_stack_pop(ctx, 1);

mov(cb, REG0, obj1);
cmp(cb, REG0, obj2);
mov(cb, REG0, imm_opnd(Qtrue));
mov(cb, REG1, imm_opnd(Qfalse));
cmovne(cb, REG0, REG1);

x86opnd_t stack_ret = ctx_stack_push(ctx, TYPE_IMM);
mov(cb, stack_ret, REG0);
return true;
}

// Check if we know how to codegen for a particular cfunc method
static method_codegen_t
lookup_cfunc_codegen(const rb_method_definition_t *def)
Expand Down Expand Up @@ -4127,4 +4214,11 @@ yjit_init_codegen(void)

yjit_reg_method(rb_cNilClass, "nil?", jit_rb_true);
yjit_reg_method(rb_mKernel, "nil?", jit_rb_false);

yjit_reg_method(rb_cBasicObject, "==", jit_rb_obj_equal);
yjit_reg_method(rb_cBasicObject, "equal?", jit_rb_obj_equal);
yjit_reg_method(rb_mKernel, "eql?", jit_rb_obj_equal);
yjit_reg_method(rb_cModule, "==", jit_rb_obj_equal);
yjit_reg_method(rb_cSymbol, "==", jit_rb_obj_equal);
yjit_reg_method(rb_cSymbol, "===", jit_rb_obj_equal);
}

0 comments on commit 3e08b8a

Please sign in to comment.