17
17
import com .oracle .truffle .api .utilities .BranchProfile ;
18
18
import com .oracle .truffle .api .utilities .ConditionProfile ;
19
19
20
+ import org .jruby .truffle .nodes .cast .BooleanCastNode ;
21
+ import org .jruby .truffle .nodes .cast .BooleanCastNodeFactory ;
20
22
import org .jruby .truffle .nodes .dispatch .CallDispatchHeadNode ;
23
+ import org .jruby .truffle .nodes .dispatch .DispatchHeadNode ;
21
24
import org .jruby .truffle .nodes .dispatch .DispatchHeadNodeFactory ;
22
25
import org .jruby .truffle .runtime .RubyContext ;
23
26
import org .jruby .truffle .runtime .UndefinedPlaceholder ;
@@ -334,12 +337,16 @@ public boolean lessEqual(RubyBignum a, RubyBignum b) {
334
337
@ CoreMethod (names = {"==" , "eql?" }, required = 1 )
335
338
public abstract static class EqualNode extends CoreMethodNode {
336
339
340
+ @ Child private BooleanCastNode booleanCastNode ;
341
+ @ Child private CallDispatchHeadNode reverseCallNode ;
342
+
337
343
public EqualNode (RubyContext context , SourceSection sourceSection ) {
338
344
super (context , sourceSection );
339
345
}
340
346
341
347
public EqualNode (EqualNode prev ) {
342
348
super (prev );
349
+ reverseCallNode = prev .reverseCallNode ;
343
350
}
344
351
345
352
@ Specialization
@@ -361,6 +368,23 @@ public boolean equal(RubyBignum a, double b) {
361
368
public boolean equal (RubyBignum a , RubyBignum b ) {
362
369
return a .bigIntegerValue ().equals (b .bigIntegerValue ());
363
370
}
371
+
372
+ @ Specialization (guards = "!isRubyBignum(arguments[1])" )
373
+ public Object equal (VirtualFrame frame , RubyBignum a , RubyBasicObject b ) {
374
+ if (booleanCastNode == null ) {
375
+ CompilerDirectives .transferToInterpreter ();
376
+ booleanCastNode = insert (BooleanCastNodeFactory .create (getContext (), getSourceSection (), null ));
377
+ }
378
+
379
+ if (reverseCallNode == null ) {
380
+ CompilerDirectives .transferToInterpreter ();
381
+ reverseCallNode = insert (DispatchHeadNodeFactory .createMethodCall (getContext ()));
382
+ }
383
+
384
+ final Object reversedResult = reverseCallNode .call (frame , b , "==" , null , a );
385
+
386
+ return booleanCastNode .executeBoolean (frame , reversedResult );
387
+ }
364
388
}
365
389
366
390
@ CoreMethod (names = "<=>" , required = 1 )
0 commit comments