From 1e241c0044042409b8567c5ab9a72b7525ad128a Mon Sep 17 00:00:00 2001 From: Katelyn Gadd Date: Wed, 22 Mar 2023 20:52:33 -0700 Subject: [PATCH] [wasm] More accurate jiterpreter cfg size estimation; generate smaller dispatch tables (#83759) * More accurate cfg size estimation * Generate smaller dispatch tables for traces with backward branches * Make sure we never actually can dispatch to the unreachable entries in the back branch table * If we somehow generate a module bigger than 4KB, don't try to compile it. Just log a warning * Better cfg logging for failed branches * Add a separate runtime option that controls whether trace monitoring will print to the log --- src/mono/mono/mini/interp/jiterpreter.c | 8 +- src/mono/mono/utils/options-def.h | 2 + src/mono/wasm/runtime/jiterpreter-support.ts | 77 +++++++++++++------ .../runtime/jiterpreter-trace-generator.ts | 4 +- src/mono/wasm/runtime/jiterpreter.ts | 10 ++- 5 files changed, 68 insertions(+), 33 deletions(-) diff --git a/src/mono/mono/mini/interp/jiterpreter.c b/src/mono/mono/mini/interp/jiterpreter.c index 31ba66e601aae..c6374da03ead8 100644 --- a/src/mono/mono/mini/interp/jiterpreter.c +++ b/src/mono/mono/mini/interp/jiterpreter.c @@ -1342,7 +1342,6 @@ mono_jiterp_write_number_unaligned (void *dest, double value, int mode) { } #define TRACE_PENALTY_LIMIT 200 -#define TRACE_MONITORING_DETAILED FALSE ptrdiff_t mono_jiterp_monitor_trace (const guint16 *ip, void *_frame, void *locals) @@ -1377,7 +1376,8 @@ mono_jiterp_monitor_trace (const guint16 *ip, void *_frame, void *locals) int penalty = MIN ((int)((1.0f - scaled) * TRACE_PENALTY_LIMIT), TRACE_PENALTY_LIMIT); info->penalty_total += penalty; - // g_print ("trace #%d @%d '%s' bailout recorded at opcode #%d, penalty=%d\n", index, ip, frame->imethod->method->name, cinfo.bailout_opcode_count, penalty); + if (mono_opt_jiterpreter_trace_monitoring_log > 2) + g_print ("trace #%d @%d '%s' bailout recorded at opcode #%d, penalty=%d\n", index, ip, frame->imethod->method->name, cinfo.bailout_opcode_count, penalty); } gint64 hit_count = info->hit_count++ - mono_opt_jiterpreter_minimum_trace_hit_count; @@ -1394,11 +1394,11 @@ mono_jiterp_monitor_trace (const guint16 *ip, void *_frame, void *locals) *(volatile JiterpreterThunk*)(ip + 1) = thunk; mono_memory_barrier (); *mutable_ip = MINT_TIER_ENTER_JITERPRETER; - if (mono_opt_jiterpreter_stats_enabled && TRACE_MONITORING_DETAILED) + if (mono_opt_jiterpreter_trace_monitoring_log > 1) g_print ("trace #%d @%d '%s' accepted; average_penalty %f <= %f\n", index, ip, frame->imethod->method->name, average_penalty, threshold); } else { traces_rejected++; - if (mono_opt_jiterpreter_stats_enabled) { + if (mono_opt_jiterpreter_trace_monitoring_log > 0) { char * full_name = mono_method_get_full_name (frame->imethod->method); g_print ("trace #%d @%d '%s' rejected; average_penalty %f > %f\n", index, ip, full_name, average_penalty, threshold); g_free (full_name); diff --git a/src/mono/mono/utils/options-def.h b/src/mono/mono/utils/options-def.h index 9637e0541db11..b9454c0073305 100644 --- a/src/mono/mono/utils/options-def.h +++ b/src/mono/mono/utils/options-def.h @@ -127,6 +127,8 @@ DEFINE_INT(jiterpreter_trace_monitoring_short_distance, "jiterpreter-trace-monit DEFINE_INT(jiterpreter_trace_monitoring_long_distance, "jiterpreter-trace-monitoring-long-distance", 10, "Traces that exit after processing this many opcodes have no exit penalty") // the average penalty value for a trace is compared against this threshold / 100 to decide whether to discard it DEFINE_INT(jiterpreter_trace_monitoring_max_average_penalty, "jiterpreter-trace-monitoring-max-average-penalty", 75, "If the average penalty value for a trace is above this value it will be rejected") +// 0 = no monitoring, 1 = log when rejecting a trace, 2 = log when accepting or rejecting a trace, 3 = log every recorded bailout +DEFINE_INT(jiterpreter_trace_monitoring_log, "jiterpreter-trace-monitoring-log", 0, "Logging detail level for trace monitoring") // After a do_jit_call call site is hit this many times, we will queue it to be jitted DEFINE_INT(jiterpreter_jit_call_trampoline_hit_count, "jiterpreter-jit-call-hit-count", 1000, "Queue specialized do_jit_call trampoline for JIT after this many hits") // After a do_jit_call call site is hit this many times without being jitted, we will flush the JIT queue diff --git a/src/mono/wasm/runtime/jiterpreter-support.ts b/src/mono/wasm/runtime/jiterpreter-support.ts index d9ee083ced0c6..148680428641b 100644 --- a/src/mono/wasm/runtime/jiterpreter-support.ts +++ b/src/mono/wasm/runtime/jiterpreter-support.ts @@ -1007,13 +1007,13 @@ class Cfg { entryBlob!: CfgBlob; blockStack: Array = []; dispatchTable = new Map(); - trace = false; + trace = 0; constructor (builder: WasmBuilder) { this.builder = builder; } - initialize (startOfBody: MintOpcodePtr, backBranchTargets: Uint16Array | null, trace: boolean) { + initialize (startOfBody: MintOpcodePtr, backBranchTargets: Uint16Array | null, trace: number) { this.segments.length = 0; this.blockStack.length = 0; this.startOfBody = startOfBody; @@ -1034,9 +1034,11 @@ class Cfg { mono_assert(this.segments[0].type === "blob", "expected blob"); this.entryBlob = this.segments[0]; this.segments.length = 0; - this.overheadBytes += 9; // entry eip init + block + optional loop - if (this.backBranchTargets) - this.overheadBytes += 24; // some extra padding for the dispatch br_table + this.overheadBytes += 9; // entry disp init + block + optional loop + if (this.backBranchTargets) { + this.overheadBytes += 20; // some extra padding for the dispatch br_table + this.overheadBytes += this.backBranchTargets.length; // one byte for each target in the table + } } appendBlob () { @@ -1051,6 +1053,8 @@ class Cfg { }); this.lastSegmentStartIp = this.ip; this.lastSegmentEnd = this.builder.current.size; + // each segment generates a block + this.overheadBytes += 2; } startBranchBlock (ip: MintOpcodePtr, isBackBranchTarget: boolean) { @@ -1060,9 +1064,7 @@ class Cfg { ip, isBackBranchTarget, }); - this.overheadBytes += 3; // each branch block just costs us a block (2 bytes) and an end - if (this.backBranchTargets) - this.overheadBytes += 3; // size of the br_table entry for this branch target + this.overheadBytes += 1; // each branch block just costs us an end } branch (target: MintOpcodePtr, isBackward: boolean, isConditional: boolean) { @@ -1074,9 +1076,17 @@ class Cfg { isBackward, isConditional, }); - this.overheadBytes += 3; // forward branches are a constant br + depth (optimally 2 bytes) - if (isBackward) - this.overheadBytes += 4; // back branches are more complex + // some branches will generate bailouts instead so we allocate 4 bytes per branch + // to try and balance this out and avoid underestimating too much + this.overheadBytes += 4; // forward branches are a constant br + depth (optimally 2 bytes) + if (isBackward) { + // get_local + // i32_const 1 + // i32_store 0 0 + // i32.const + // set_local + this.overheadBytes += 11; + } } emitBlob (segment: CfgBlob, source: Uint8Array) { @@ -1135,11 +1145,21 @@ class Cfg { // br_table // we have to assign disp==0 to fallthrough so that we start at the top of the fn body, then // assign disp values starting from 1 to branch targets - this.builder.appendULeb(this.blockStack.length + 1); + // FIXME: Only include back branch targets that are *also* in the block stack. This is necessary + // when starting a trace in the middle of a method to make the table smaller + this.builder.appendULeb(this.backBranchTargets.length + 1); this.builder.appendULeb(1); // br depth of 1 = skip the unreachable and fall through to the start - for (let i = 0; i < this.blockStack.length; i++) { - this.dispatchTable.set(this.blockStack[i], i + 1); - this.builder.appendULeb(i + 2); // add 2 to the depth because of the double block around it + for (let i = 0; i < this.backBranchTargets.length; i++) { + const offset = (this.backBranchTargets[i] * 2) + this.startOfBody; + const breakDepth = this.blockStack.indexOf(offset); + if (breakDepth >= 0) { + this.dispatchTable.set(offset, i + 1); + this.builder.appendULeb(breakDepth + 2); // add 2 to the depth because of the double block around it + } else { + // This means the back branch target is outside of the trace. It shouldn't be possible to reach this + // and we didn't add it to the dispatch table anyway + this.builder.appendULeb(0); + } } this.builder.appendULeb(0); // for unrecognized value we br 0, which causes us to trap this.builder.endBlock(); @@ -1150,7 +1170,7 @@ class Cfg { this.blockStack.push(dispatchIp); } - if (this.trace) + if (this.trace > 1) console.log(`blockStack=${this.blockStack}`); for (let i = 0; i < this.segments.length; i++) { @@ -1173,14 +1193,15 @@ class Cfg { } case "branch": { const lookupTarget = segment.isBackward ? dispatchIp : segment.target; - let indexInStack = this.blockStack.indexOf(lookupTarget); + let indexInStack = this.blockStack.indexOf(lookupTarget), + successfulBackBranch = false; // Back branches will target the dispatcher loop so we need to update the dispatch index // which will be used by the loop dispatch br_table to jump to the correct location - if (segment.isBackward && (indexInStack >= 0)) { + if (segment.isBackward) { if (this.dispatchTable.has(segment.target)) { const disp = this.dispatchTable.get(segment.target)!; - if (this.trace) + if (this.trace > 1) console.log(`backward br from ${(segment.from).toString(16)} to ${(segment.target).toString(16)}: disp=${disp}`); // set the backward branch taken flag in the cinfo so that the monitoring phase @@ -1195,23 +1216,29 @@ class Cfg { // set the dispatch index for the br_table this.builder.i32_const(disp); this.builder.local("disp", WasmOpcode.set_local); + successfulBackBranch = true; } else { - if (this.trace) + if (this.trace > 0) console.log(`br from ${(segment.from).toString(16)} to ${(segment.target).toString(16)} failed: back branch target not in dispatch table`); indexInStack = -1; } } - if (indexInStack >= 0) { + if ((indexInStack >= 0) || successfulBackBranch) { // Conditional branches are nested in an extra block, so the depth is +1 const offset = segment.isConditional ? 1 : 0; this.builder.appendU8(WasmOpcode.br); this.builder.appendULeb(offset + indexInStack); - if (this.trace) + if (this.trace > 1) console.log(`br from ${(segment.from).toString(16)} to ${(segment.target).toString(16)} breaking out ${offset + indexInStack + 1} level(s)`); } else { - if (this.trace) - console.log(`br from ${(segment.from).toString(16)} to ${(segment.target).toString(16)} failed`); + if (this.trace > 0) { + const base = this.base; + if ((segment.target >= base) && (segment.target < this.exitIp)) + console.log(`br from ${(segment.from).toString(16)} to ${(segment.target).toString(16)} failed (inside of trace!)`); + else if (this.trace > 1) + console.log(`br from ${(segment.from).toString(16)} to ${(segment.target).toString(16)} failed (outside of trace 0x${base.toString(16)} - 0x${(this.exitIp).toString(16)})`); + } append_bailout(this.builder, segment.target, BailoutReason.Branch); } break; @@ -1289,7 +1316,7 @@ export function append_bailout (builder: WasmBuilder, ip: MintOpcodePtr, reason: // generate a bailout that is recorded for the monitoring phase as a possible early exit. export function append_exit (builder: WasmBuilder, ip: MintOpcodePtr, opcodeCounter: number, reason: BailoutReason) { - if (opcodeCounter <= (builder.options.monitoringLongDistance + 1)) { + if (opcodeCounter <= (builder.options.monitoringLongDistance + 2)) { builder.local("cinfo"); builder.i32_const(opcodeCounter); builder.appendU8(WasmOpcode.i32_store); diff --git a/src/mono/wasm/runtime/jiterpreter-trace-generator.ts b/src/mono/wasm/runtime/jiterpreter-trace-generator.ts index e58843349d17f..fe14f1bd79ddf 100644 --- a/src/mono/wasm/runtime/jiterpreter-trace-generator.ts +++ b/src/mono/wasm/runtime/jiterpreter-trace-generator.ts @@ -172,8 +172,8 @@ export function generate_wasm_body ( // HACK: Browsers set a limit of 4KB, we lower it slightly since a single opcode // might generate a ton of code and we generate a bit of an epilogue after // we finish - const maxModuleSize = 3850, - spaceLeft = maxModuleSize - builder.bytesGeneratedSoFar - builder.cfg.overheadBytes; + const maxBytesGenerated = 3840, + spaceLeft = maxBytesGenerated - builder.bytesGeneratedSoFar - builder.cfg.overheadBytes; if (builder.size >= spaceLeft) { // console.log(`trace too big, estimated size is ${builder.size + builder.bytesGeneratedSoFar}`); record_abort(traceIp, ip, traceName, "trace-too-big"); diff --git a/src/mono/wasm/runtime/jiterpreter.ts b/src/mono/wasm/runtime/jiterpreter.ts index 852199a2404fb..d9384788a397c 100644 --- a/src/mono/wasm/runtime/jiterpreter.ts +++ b/src/mono/wasm/runtime/jiterpreter.ts @@ -69,7 +69,9 @@ export const // Always grab method full names useFullNames = false, // Use the mono_debug_count() API (set the COUNT=n env var) to limit the number of traces to compile - useDebugCount = false; + useDebugCount = false, + // Web browsers limit synchronous module compiles to 4KB + maxModuleSize = 4080; export const callTargetCounts : { [method: number] : number } = {}; @@ -718,7 +720,7 @@ function generate_wasm ( if (getU16(ip) !== MintOpcode.MINT_TIER_PREPARE_JITERPRETER) throw new Error(`Expected *ip to be MINT_TIER_PREPARE_JITERPRETER but was ${getU16(ip)}`); - builder.cfg.initialize(startOfBody, backwardBranchTable, !!instrument); + builder.cfg.initialize(startOfBody, backwardBranchTable, instrument ? 1 : 0); // TODO: Call generate_wasm_body before generating any of the sections and headers. // This will allow us to do things like dynamically vary the number of locals, in addition @@ -754,6 +756,10 @@ function generate_wasm ( if (trace > 0) console.log(`${((builder.base)).toString(16)} ${methodFullName || traceName} generated ${buffer.length} byte(s) of wasm`); counters.bytesGenerated += buffer.length; + if (buffer.length >= maxModuleSize) { + console.warn(`MONO_WASM: Jiterpreter generated too much code (${buffer.length} bytes) for trace ${traceName}. Please report this issue.`); + return 0; + } const traceModule = new WebAssembly.Module(buffer); const traceInstance = new WebAssembly.Instance(traceModule, {