Skip to content

Commit 28b5474

Browse files
committed
[GR-4454] Possible race condition in Truffle code installation.
2 parents b20d243 + 0d0a37d commit 28b5474

File tree

11 files changed

+156
-11
lines changed

11 files changed

+156
-11
lines changed

compiler/src/org.graalvm.compiler.asm.aarch64/src/org/graalvm/compiler/asm/aarch64/AArch64Assembler.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@
9494
import static org.graalvm.compiler.asm.aarch64.AArch64Assembler.Instruction.STXR;
9595
import static org.graalvm.compiler.asm.aarch64.AArch64Assembler.Instruction.SUB;
9696
import static org.graalvm.compiler.asm.aarch64.AArch64Assembler.Instruction.SUBS;
97+
import static org.graalvm.compiler.asm.aarch64.AArch64Assembler.Instruction.TBZ;
9798
import static org.graalvm.compiler.asm.aarch64.AArch64Assembler.Instruction.UBFM;
9899
import static org.graalvm.compiler.asm.aarch64.AArch64Assembler.Instruction.UDIV;
99100
import static org.graalvm.compiler.asm.aarch64.AArch64Assembler.InstructionType.FP32;
@@ -475,6 +476,7 @@ public enum Instruction {
475476
BCOND(0x54000000),
476477
CBNZ(0x01000000),
477478
CBZ(0x00000000),
479+
TBZ(0x36000000),
478480

479481
B(0x00000000),
480482
BL(0x80000000),
@@ -808,6 +810,25 @@ protected void cbz(int size, Register reg, int imm21, int pos) {
808810
conditionalBranchInstruction(reg, imm21, generalFromSize(size), Instruction.CBZ, pos);
809811
}
810812

813+
/**
814+
* Test a single bit and branch if the bit is zero.
815+
*
816+
* @param reg general purpose register. May not be null, zero-register or stackpointer.
817+
* @param size Instruction size in bits. Should be either 32 or 64.
818+
* @param uimm6 Unsigned 6-bit bit index.
819+
* @param pos Position at which instruction is inserted into buffer. -1 means insert at end.
820+
*/
821+
protected void tbz(int size, Register reg, int uimm6, int pos) {
822+
assert reg.getRegisterCategory().equals(CPU);
823+
InstructionType type = generalFromSize(size);
824+
int encoding = type.encoding | TBZ.encoding | (uimm6 << 18) | rd(reg);
825+
if (pos == -1) {
826+
emitInt(encoding);
827+
} else {
828+
emitInt(encoding, pos);
829+
}
830+
}
831+
811832
private void conditionalBranchInstruction(Register reg, int imm21, InstructionType type, Instruction instr, int pos) {
812833
assert reg.getRegisterCategory().equals(CPU);
813834
int instrEncoding = instr.encoding | CompareBranchOp;

compiler/src/org.graalvm.compiler.asm.aarch64/src/org/graalvm/compiler/asm/aarch64/AArch64MacroAssembler.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,6 +1294,26 @@ public void cbz(int size, Register cmp, Label label) {
12941294
}
12951295
}
12961296

1297+
/**
1298+
* Test a single bit and branch if the bit is zero.
1299+
*
1300+
* @param cmp general purpose register. May not be null, zero-register or stackpointer.
1301+
* @param size Instruction size in bits. Should be either 32 or 64.
1302+
* @param uimm6 Unsigned 6-bit bit index.
1303+
* @param label Can only handle 21-bit word-aligned offsets for now. May be unbound. Non null.
1304+
*/
1305+
public void tbz(int size, Register cmp, int uimm6, Label label) {
1306+
if (label.isBound()) {
1307+
int offset = label.position() - position();
1308+
super.tbz(size, cmp, uimm6, offset);
1309+
} else {
1310+
label.addPatchAt(position());
1311+
int regEncoding = cmp.encoding << (PatchLabelKind.INFORMATION_OFFSET + 1);
1312+
int sizeEncoding = (size == 64 ? 1 : 0) << PatchLabelKind.INFORMATION_OFFSET;
1313+
emitInt(PatchLabelKind.BRANCH_CONDITIONALLY.encoding | regEncoding | sizeEncoding);
1314+
}
1315+
}
1316+
12971317
/**
12981318
* Branches to label if condition is true.
12991319
*

compiler/src/org.graalvm.compiler.asm.amd64/src/org/graalvm/compiler/asm/amd64/AMD64Assembler.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3492,6 +3492,14 @@ public final void testq(Register dst, Register src) {
34923492
emitByte(0xC0 | encode);
34933493
}
34943494

3495+
public final void btrq(Register src, int imm8) {
3496+
int encode = prefixqAndEncode(src.encoding);
3497+
emitByte(0x0F);
3498+
emitByte(0xBA);
3499+
emitByte(0xF0 | encode);
3500+
emitByte(imm8);
3501+
}
3502+
34953503
public final void xaddl(AMD64Address dst, Register src) {
34963504
prefix(dst, src);
34973505
emitByte(0x0F);

compiler/src/org.graalvm.compiler.asm.sparc/src/org/graalvm/compiler/asm/sparc/SPARCMacroAssembler.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ public void jmp(Label l) {
8181
nop(); // delay slot
8282
}
8383

84+
public void bz(Label l) {
85+
BPCC.emit(this, Xcc, ConditionFlag.Zero, NOT_ANNUL, PREDICT_NOT_TAKEN, l);
86+
}
87+
8488
@Override
8589
protected final void patchJumpTarget(int branch, int branchTarget) {
8690
final int disp = (branchTarget - branch) / 4;

compiler/src/org.graalvm.compiler.truffle.hotspot.aarch64/src/org/graalvm/compiler/truffle/hotspot/aarch64/AArch64OptimizedCallTargetInstumentationFactory.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424

2525
import static jdk.vm.ci.hotspot.HotSpotCallingConventionType.JavaCall;
2626
import static jdk.vm.ci.meta.JavaKind.Object;
27+
import static org.graalvm.compiler.asm.aarch64.AArch64Assembler.BarrierKind.LOAD_LOAD;
28+
import static org.graalvm.compiler.asm.aarch64.AArch64Assembler.BarrierKind.LOAD_STORE;
2729

2830
import org.graalvm.compiler.asm.Assembler;
2931
import org.graalvm.compiler.asm.Label;
@@ -62,7 +64,11 @@ protected void injectTailCallCode() {
6264
AArch64Address entryPointAddress = AArch64Address.createPairUnscaledImmediateAddress(thisRegister, getFieldOffset("entryPoint", InstalledCode.class));
6365

6466
masm.ldr(64, spillRegister, entryPointAddress);
67+
masm.dmb(LOAD_LOAD);
68+
masm.dmb(LOAD_STORE);
6569
masm.cbz(64, spillRegister, doProlog);
70+
masm.tbz(64, spillRegister, 0, doProlog);
71+
masm.eor(64, spillRegister, spillRegister, 1);
6672
masm.jmp(spillRegister);
6773
masm.nop();
6874
masm.bind(doProlog);

compiler/src/org.graalvm.compiler.truffle.hotspot.amd64/src/org/graalvm/compiler/truffle/hotspot/amd64/AMD64OptimizedCallTargetInstrumentationFactory.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import jdk.vm.ci.amd64.AMD64;
2727
import jdk.vm.ci.code.CodeCacheProvider;
2828
import jdk.vm.ci.code.InstalledCode;
29+
import jdk.vm.ci.code.MemoryBarriers;
2930
import jdk.vm.ci.code.Register;
3031
import jdk.vm.ci.meta.JavaKind;
3132

@@ -69,8 +70,11 @@ protected void injectTailCallCode() {
6970
*/
7071
asm.movq(spillRegister, codeBlobAddress, true);
7172
assert asm.position() - pos >= AMD64HotSpotBackend.PATCHED_VERIFIED_ENTRY_POINT_INSTRUCTION_SIZE;
73+
asm.membar(MemoryBarriers.JMM_POST_VOLATILE_READ);
7274
asm.testq(spillRegister, spillRegister);
7375
asm.jcc(ConditionFlag.Equal, doProlog);
76+
asm.btrq(spillRegister, 0);
77+
asm.jcc(ConditionFlag.CarryClear, doProlog);
7478
asm.jmp(spillRegister);
7579

7680
asm.bind(doProlog);

compiler/src/org.graalvm.compiler.truffle.hotspot.sparc/src/org/graalvm/compiler/truffle/hotspot/sparc/SPARCOptimizedCallTargetInstumentationFactory.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@
3131
import jdk.vm.ci.code.InstalledCode;
3232
import jdk.vm.ci.code.Register;
3333

34+
import jdk.vm.ci.sparc.SPARC;
3435
import org.graalvm.compiler.asm.Assembler;
3536
import org.graalvm.compiler.asm.Label;
3637
import org.graalvm.compiler.asm.sparc.SPARCAddress;
38+
import org.graalvm.compiler.asm.sparc.SPARCAssembler;
3739
import org.graalvm.compiler.asm.sparc.SPARCMacroAssembler;
3840
import org.graalvm.compiler.asm.sparc.SPARCMacroAssembler.ScratchRegister;
3941
import org.graalvm.compiler.code.CompilationResult;
@@ -65,7 +67,11 @@ protected void injectTailCallCode() {
6567
SPARCAddress entryPointAddress = new SPARCAddress(thisRegister, getFieldOffset("entryPoint", InstalledCode.class));
6668

6769
asm.ldx(entryPointAddress, spillRegister);
70+
asm.membar(SPARCAssembler.MEMBAR_LOAD_LOAD | SPARCAssembler.MEMBAR_LOAD_STORE);
6871
asm.compareBranch(spillRegister, 0, Equal, Xcc, doProlog, PREDICT_NOT_TAKEN, null);
72+
asm.andcc(spillRegister, 1, SPARC.g0);
73+
asm.bz(doProlog);
74+
asm.xor(spillRegister, 1, spillRegister);
6975
asm.jmp(spillRegister);
7076
asm.nop();
7177
asm.bind(doProlog);

compiler/src/org.graalvm.compiler.truffle.test/src/org/graalvm/compiler/truffle/test/InstrumentBranchesPhaseTest.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ public void simpleIfTest() {
110110
Assert.assertTrue(target.isValid());
111111
target.call();
112112
String stackOutput = instrumentation.accessTableToList(getOptions()).get(0);
113-
Assert.assertTrue(stackOutput.contains("org.graalvm.compiler.truffle.test.InstrumentBranchesPhaseTest$SimpleIfTestNode.execute(InstrumentBranchesPhaseTest.java"));
114-
Assert.assertTrue(stackOutput.contains("[bci: 4]\n[0] state = ELSE(if=0#, else=1#)"));
113+
Assert.assertTrue(stackOutput, stackOutput.contains("org.graalvm.compiler.truffle.test.InstrumentBranchesPhaseTest$SimpleIfTestNode.execute(InstrumentBranchesPhaseTest.java"));
114+
Assert.assertTrue(stackOutput, stackOutput.contains("[bci: 4]\n[0] state = ELSE(if=0#, else=1#)"));
115115
String histogramOutput = instrumentation.accessTableToHistogram().get(0);
116116
Assert.assertEquals(" 0: ********************************************************************************", histogramOutput);
117117
}
@@ -128,10 +128,10 @@ public void twoIfsTest() {
128128
target.call();
129129
target.call();
130130
String stackOutput1 = instrumentation.accessTableToList(getOptions()).get(0);
131-
Assert.assertTrue(stackOutput1.contains("org.graalvm.compiler.truffle.test.InstrumentBranchesPhaseTest$TwoIfsTestNode.execute(InstrumentBranchesPhaseTest.java"));
132-
Assert.assertTrue(stackOutput1.contains("[bci: 4]\n[1] state = ELSE(if=0#, else=2#)"));
131+
Assert.assertTrue(stackOutput1, stackOutput1.contains("org.graalvm.compiler.truffle.test.InstrumentBranchesPhaseTest$TwoIfsTestNode.execute(InstrumentBranchesPhaseTest.java"));
132+
Assert.assertTrue(stackOutput1, stackOutput1.contains("[bci: 4]\n[1] state = ELSE(if=0#, else=2#)"));
133133
String stackOutput2 = instrumentation.accessTableToList(getOptions()).get(1);
134-
Assert.assertTrue(stackOutput2.contains("org.graalvm.compiler.truffle.test.InstrumentBranchesPhaseTest$TwoIfsTestNode.execute(InstrumentBranchesPhaseTest.java"));
135-
Assert.assertTrue(stackOutput2.contains("[bci: 18]\n[2] state = IF(if=2#, else=0#)"));
134+
Assert.assertTrue(stackOutput2, stackOutput2.contains("org.graalvm.compiler.truffle.test.InstrumentBranchesPhaseTest$TwoIfsTestNode.execute(InstrumentBranchesPhaseTest.java"));
135+
Assert.assertTrue(stackOutput2, stackOutput2.contains("[bci: 18]\n[2] state = IF(if=2#, else=0#)"));
136136
}
137137
}

compiler/src/org.graalvm.compiler.truffle/src/org/graalvm/compiler/truffle/OptimizedCallTarget.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,17 @@
7676
public class OptimizedCallTarget extends InstalledCode implements RootCallTarget, ReplaceObserver, com.oracle.truffle.api.LoopCountReceiver {
7777

7878
private static final String NODE_REWRITING_ASSUMPTION_NAME = "nodeRewritingAssumption";
79+
private static final long ENTRY_POINT_OFFSET;
7980
static final String CALL_BOUNDARY_METHOD_NAME = "callProxy";
8081

82+
static {
83+
try {
84+
ENTRY_POINT_OFFSET = UnsafeAccess.UNSAFE.objectFieldOffset(InstalledCode.class.getDeclaredField("entryPoint"));
85+
} catch (NoSuchFieldException e) {
86+
throw new RuntimeException(e);
87+
}
88+
}
89+
8190
/** The AST to be executed when this call target is called. */
8291
private final RootNode rootNode;
8392

@@ -565,6 +574,19 @@ CompilerOptions getCompilerOptions() {
565574
return DefaultCompilerOptions.INSTANCE;
566575
}
567576

577+
public void releaseEntryPoint() {
578+
long seenEntryPoint = entryPoint;
579+
if (seenEntryPoint == 0) {
580+
return;
581+
}
582+
// No need to retry, since a failure means that the entry point was reset to zero.
583+
// The reason is that the current thread is the only thread calling this method,
584+
// and the only other thread that is changing the entryPoint is the VM itself.
585+
// Furthermore, no other thread will reinstall the call target until the current thread
586+
// completes.
587+
UnsafeAccess.UNSAFE.compareAndSwapLong(this, ENTRY_POINT_OFFSET, seenEntryPoint, seenEntryPoint | 1);
588+
}
589+
568590
private static final class NonTrivialNodeCountVisitor implements NodeVisitor {
569591
public int nodeCount;
570592

compiler/src/org.graalvm.compiler.truffle/src/org/graalvm/compiler/truffle/TruffleCompiler.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@
6161
import com.oracle.truffle.api.nodes.UnexpectedResultException;
6262

6363
import jdk.vm.ci.code.CompilationRequest;
64-
import jdk.vm.ci.code.InstalledCode;
6564
import jdk.vm.ci.meta.MetaAccessProvider;
6665
import jdk.vm.ci.meta.ResolvedJavaMethod;
6766
import jdk.vm.ci.meta.ResolvedJavaType;
@@ -197,7 +196,7 @@ private static void dequeueInlinedCallSites(TruffleInlining inliningDecision, Op
197196
}
198197

199198
@SuppressWarnings("try")
200-
public CompilationResult compileMethodHelper(StructuredGraph graph, String name, PhaseSuite<HighTierContext> graphBuilderSuite, InstalledCode predefinedInstalledCode,
199+
public CompilationResult compileMethodHelper(StructuredGraph graph, String name, PhaseSuite<HighTierContext> graphBuilderSuite, OptimizedCallTarget predefinedInstalledCode,
201200
CompilationRequest compilationRequest) {
202201
try (Scope s = Debug.scope("TruffleFinal")) {
203202
Debug.dump(Debug.BASIC_LEVEL, graph, "After TruffleTier");
@@ -223,11 +222,11 @@ public CompilationResult compileMethodHelper(StructuredGraph graph, String name,
223222
throw Debug.handle(e);
224223
}
225224

226-
compilationNotify.notifyCompilationGraalTierFinished((OptimizedCallTarget) predefinedInstalledCode, graph);
225+
compilationNotify.notifyCompilationGraalTierFinished(predefinedInstalledCode, graph);
227226

228-
InstalledCode installedCode;
227+
OptimizedCallTarget installedCode;
229228
try (DebugCloseable a = CodeInstallationTime.start(); DebugCloseable c = CodeInstallationMemUse.start()) {
230-
installedCode = backend.createInstalledCode(graph.method(), compilationRequest, result, graph.getSpeculationLog(), predefinedInstalledCode, false);
229+
installedCode = (OptimizedCallTarget) backend.createInstalledCode(graph.method(), compilationRequest, result, graph.getSpeculationLog(), predefinedInstalledCode, false);
231230
} catch (Throwable e) {
232231
throw Debug.handle(e);
233232
}
@@ -236,6 +235,10 @@ public CompilationResult compileMethodHelper(StructuredGraph graph, String name,
236235
a.getAssumption().registerInstalledCode(installedCode);
237236
}
238237

238+
if (!providers.getCodeCache().getTarget().arch.getName().equals("aarch64")) {
239+
installedCode.releaseEntryPoint();
240+
}
241+
239242
return result;
240243
}
241244

0 commit comments

Comments
 (0)