Skip to content

Commit

Permalink
[wasm] More accurate jiterpreter cfg size estimation; generate smalle…
Browse files Browse the repository at this point in the history
…r dispatch tables (#83759)

* More accurate cfg size estimation
* Generate smaller dispatch tables for traces with backward branches
* Make sure we never actually can dispatch to the unreachable entries in the back branch table
* If we somehow generate a module bigger than 4KB, don't try to compile it. Just log a warning
* Better cfg logging for failed branches
* Add a separate runtime option that controls whether trace monitoring will print to the log
  • Loading branch information
kg committed Mar 23, 2023
1 parent 10abc29 commit 1e241c0
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 33 deletions.
8 changes: 4 additions & 4 deletions src/mono/mono/mini/interp/jiterpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -1342,7 +1342,6 @@ mono_jiterp_write_number_unaligned (void *dest, double value, int mode) {
}

#define TRACE_PENALTY_LIMIT 200
#define TRACE_MONITORING_DETAILED FALSE

ptrdiff_t
mono_jiterp_monitor_trace (const guint16 *ip, void *_frame, void *locals)
Expand Down Expand Up @@ -1377,7 +1376,8 @@ mono_jiterp_monitor_trace (const guint16 *ip, void *_frame, void *locals)
int penalty = MIN ((int)((1.0f - scaled) * TRACE_PENALTY_LIMIT), TRACE_PENALTY_LIMIT);
info->penalty_total += penalty;

// g_print ("trace #%d @%d '%s' bailout recorded at opcode #%d, penalty=%d\n", index, ip, frame->imethod->method->name, cinfo.bailout_opcode_count, penalty);
if (mono_opt_jiterpreter_trace_monitoring_log > 2)
g_print ("trace #%d @%d '%s' bailout recorded at opcode #%d, penalty=%d\n", index, ip, frame->imethod->method->name, cinfo.bailout_opcode_count, penalty);
}

gint64 hit_count = info->hit_count++ - mono_opt_jiterpreter_minimum_trace_hit_count;
Expand All @@ -1394,11 +1394,11 @@ mono_jiterp_monitor_trace (const guint16 *ip, void *_frame, void *locals)
*(volatile JiterpreterThunk*)(ip + 1) = thunk;
mono_memory_barrier ();
*mutable_ip = MINT_TIER_ENTER_JITERPRETER;
if (mono_opt_jiterpreter_stats_enabled && TRACE_MONITORING_DETAILED)
if (mono_opt_jiterpreter_trace_monitoring_log > 1)
g_print ("trace #%d @%d '%s' accepted; average_penalty %f <= %f\n", index, ip, frame->imethod->method->name, average_penalty, threshold);
} else {
traces_rejected++;
if (mono_opt_jiterpreter_stats_enabled) {
if (mono_opt_jiterpreter_trace_monitoring_log > 0) {
char * full_name = mono_method_get_full_name (frame->imethod->method);
g_print ("trace #%d @%d '%s' rejected; average_penalty %f > %f\n", index, ip, full_name, average_penalty, threshold);
g_free (full_name);
Expand Down
2 changes: 2 additions & 0 deletions src/mono/mono/utils/options-def.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ DEFINE_INT(jiterpreter_trace_monitoring_short_distance, "jiterpreter-trace-monit
DEFINE_INT(jiterpreter_trace_monitoring_long_distance, "jiterpreter-trace-monitoring-long-distance", 10, "Traces that exit after processing this many opcodes have no exit penalty")
// the average penalty value for a trace is compared against this threshold / 100 to decide whether to discard it
DEFINE_INT(jiterpreter_trace_monitoring_max_average_penalty, "jiterpreter-trace-monitoring-max-average-penalty", 75, "If the average penalty value for a trace is above this value it will be rejected")
// 0 = no monitoring, 1 = log when rejecting a trace, 2 = log when accepting or rejecting a trace, 3 = log every recorded bailout
DEFINE_INT(jiterpreter_trace_monitoring_log, "jiterpreter-trace-monitoring-log", 0, "Logging detail level for trace monitoring")
// After a do_jit_call call site is hit this many times, we will queue it to be jitted
DEFINE_INT(jiterpreter_jit_call_trampoline_hit_count, "jiterpreter-jit-call-hit-count", 1000, "Queue specialized do_jit_call trampoline for JIT after this many hits")
// After a do_jit_call call site is hit this many times without being jitted, we will flush the JIT queue
Expand Down
77 changes: 52 additions & 25 deletions src/mono/wasm/runtime/jiterpreter-support.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1007,13 +1007,13 @@ class Cfg {
entryBlob!: CfgBlob;
blockStack: Array<MintOpcodePtr> = [];
dispatchTable = new Map<MintOpcodePtr, number>();
trace = false;
trace = 0;

constructor (builder: WasmBuilder) {
this.builder = builder;
}

initialize (startOfBody: MintOpcodePtr, backBranchTargets: Uint16Array | null, trace: boolean) {
initialize (startOfBody: MintOpcodePtr, backBranchTargets: Uint16Array | null, trace: number) {
this.segments.length = 0;
this.blockStack.length = 0;
this.startOfBody = startOfBody;
Expand All @@ -1034,9 +1034,11 @@ class Cfg {
mono_assert(this.segments[0].type === "blob", "expected blob");
this.entryBlob = <CfgBlob>this.segments[0];
this.segments.length = 0;
this.overheadBytes += 9; // entry eip init + block + optional loop
if (this.backBranchTargets)
this.overheadBytes += 24; // some extra padding for the dispatch br_table
this.overheadBytes += 9; // entry disp init + block + optional loop
if (this.backBranchTargets) {
this.overheadBytes += 20; // some extra padding for the dispatch br_table
this.overheadBytes += this.backBranchTargets.length; // one byte for each target in the table
}
}

appendBlob () {
Expand All @@ -1051,6 +1053,8 @@ class Cfg {
});
this.lastSegmentStartIp = this.ip;
this.lastSegmentEnd = this.builder.current.size;
// each segment generates a block
this.overheadBytes += 2;
}

startBranchBlock (ip: MintOpcodePtr, isBackBranchTarget: boolean) {
Expand All @@ -1060,9 +1064,7 @@ class Cfg {
ip,
isBackBranchTarget,
});
this.overheadBytes += 3; // each branch block just costs us a block (2 bytes) and an end
if (this.backBranchTargets)
this.overheadBytes += 3; // size of the br_table entry for this branch target
this.overheadBytes += 1; // each branch block just costs us an end
}

branch (target: MintOpcodePtr, isBackward: boolean, isConditional: boolean) {
Expand All @@ -1074,9 +1076,17 @@ class Cfg {
isBackward,
isConditional,
});
this.overheadBytes += 3; // forward branches are a constant br + depth (optimally 2 bytes)
if (isBackward)
this.overheadBytes += 4; // back branches are more complex
// some branches will generate bailouts instead so we allocate 4 bytes per branch
// to try and balance this out and avoid underestimating too much
this.overheadBytes += 4; // forward branches are a constant br + depth (optimally 2 bytes)
if (isBackward) {
// get_local <cinfo>
// i32_const 1
// i32_store 0 0
// i32.const <n>
// set_local <disp>
this.overheadBytes += 11;
}
}

emitBlob (segment: CfgBlob, source: Uint8Array) {
Expand Down Expand Up @@ -1135,11 +1145,21 @@ class Cfg {
// br_table <number of values starting from 0> <labels for values starting from 0> <default>
// we have to assign disp==0 to fallthrough so that we start at the top of the fn body, then
// assign disp values starting from 1 to branch targets
this.builder.appendULeb(this.blockStack.length + 1);
// FIXME: Only include back branch targets that are *also* in the block stack. This is necessary
// when starting a trace in the middle of a method to make the table smaller
this.builder.appendULeb(this.backBranchTargets.length + 1);
this.builder.appendULeb(1); // br depth of 1 = skip the unreachable and fall through to the start
for (let i = 0; i < this.blockStack.length; i++) {
this.dispatchTable.set(this.blockStack[i], i + 1);
this.builder.appendULeb(i + 2); // add 2 to the depth because of the double block around it
for (let i = 0; i < this.backBranchTargets.length; i++) {
const offset = (this.backBranchTargets[i] * 2) + <any>this.startOfBody;
const breakDepth = this.blockStack.indexOf(offset);
if (breakDepth >= 0) {
this.dispatchTable.set(offset, i + 1);
this.builder.appendULeb(breakDepth + 2); // add 2 to the depth because of the double block around it
} else {
// This means the back branch target is outside of the trace. It shouldn't be possible to reach this
// and we didn't add it to the dispatch table anyway
this.builder.appendULeb(0);
}
}
this.builder.appendULeb(0); // for unrecognized value we br 0, which causes us to trap
this.builder.endBlock();
Expand All @@ -1150,7 +1170,7 @@ class Cfg {
this.blockStack.push(dispatchIp);
}

if (this.trace)
if (this.trace > 1)
console.log(`blockStack=${this.blockStack}`);

for (let i = 0; i < this.segments.length; i++) {
Expand All @@ -1173,14 +1193,15 @@ class Cfg {
}
case "branch": {
const lookupTarget = segment.isBackward ? dispatchIp : segment.target;
let indexInStack = this.blockStack.indexOf(lookupTarget);
let indexInStack = this.blockStack.indexOf(lookupTarget),
successfulBackBranch = false;

// Back branches will target the dispatcher loop so we need to update the dispatch index
// which will be used by the loop dispatch br_table to jump to the correct location
if (segment.isBackward && (indexInStack >= 0)) {
if (segment.isBackward) {
if (this.dispatchTable.has(segment.target)) {
const disp = this.dispatchTable.get(segment.target)!;
if (this.trace)
if (this.trace > 1)
console.log(`backward br from ${(<any>segment.from).toString(16)} to ${(<any>segment.target).toString(16)}: disp=${disp}`);

// set the backward branch taken flag in the cinfo so that the monitoring phase
Expand All @@ -1195,23 +1216,29 @@ class Cfg {
// set the dispatch index for the br_table
this.builder.i32_const(disp);
this.builder.local("disp", WasmOpcode.set_local);
successfulBackBranch = true;
} else {
if (this.trace)
if (this.trace > 0)
console.log(`br from ${(<any>segment.from).toString(16)} to ${(<any>segment.target).toString(16)} failed: back branch target not in dispatch table`);
indexInStack = -1;
}
}

if (indexInStack >= 0) {
if ((indexInStack >= 0) || successfulBackBranch) {
// Conditional branches are nested in an extra block, so the depth is +1
const offset = segment.isConditional ? 1 : 0;
this.builder.appendU8(WasmOpcode.br);
this.builder.appendULeb(offset + indexInStack);
if (this.trace)
if (this.trace > 1)
console.log(`br from ${(<any>segment.from).toString(16)} to ${(<any>segment.target).toString(16)} breaking out ${offset + indexInStack + 1} level(s)`);
} else {
if (this.trace)
console.log(`br from ${(<any>segment.from).toString(16)} to ${(<any>segment.target).toString(16)} failed`);
if (this.trace > 0) {
const base = <any>this.base;
if ((segment.target >= base) && (segment.target < this.exitIp))
console.log(`br from ${(<any>segment.from).toString(16)} to ${(<any>segment.target).toString(16)} failed (inside of trace!)`);
else if (this.trace > 1)
console.log(`br from ${(<any>segment.from).toString(16)} to ${(<any>segment.target).toString(16)} failed (outside of trace 0x${base.toString(16)} - 0x${(<any>this.exitIp).toString(16)})`);
}
append_bailout(this.builder, segment.target, BailoutReason.Branch);
}
break;
Expand Down Expand Up @@ -1289,7 +1316,7 @@ export function append_bailout (builder: WasmBuilder, ip: MintOpcodePtr, reason:

// generate a bailout that is recorded for the monitoring phase as a possible early exit.
export function append_exit (builder: WasmBuilder, ip: MintOpcodePtr, opcodeCounter: number, reason: BailoutReason) {
if (opcodeCounter <= (builder.options.monitoringLongDistance + 1)) {
if (opcodeCounter <= (builder.options.monitoringLongDistance + 2)) {
builder.local("cinfo");
builder.i32_const(opcodeCounter);
builder.appendU8(WasmOpcode.i32_store);
Expand Down
4 changes: 2 additions & 2 deletions src/mono/wasm/runtime/jiterpreter-trace-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ export function generate_wasm_body (
// HACK: Browsers set a limit of 4KB, we lower it slightly since a single opcode
// might generate a ton of code and we generate a bit of an epilogue after
// we finish
const maxModuleSize = 3850,
spaceLeft = maxModuleSize - builder.bytesGeneratedSoFar - builder.cfg.overheadBytes;
const maxBytesGenerated = 3840,
spaceLeft = maxBytesGenerated - builder.bytesGeneratedSoFar - builder.cfg.overheadBytes;
if (builder.size >= spaceLeft) {
// console.log(`trace too big, estimated size is ${builder.size + builder.bytesGeneratedSoFar}`);
record_abort(traceIp, ip, traceName, "trace-too-big");
Expand Down
10 changes: 8 additions & 2 deletions src/mono/wasm/runtime/jiterpreter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ export const
// Always grab method full names
useFullNames = false,
// Use the mono_debug_count() API (set the COUNT=n env var) to limit the number of traces to compile
useDebugCount = false;
useDebugCount = false,
// Web browsers limit synchronous module compiles to 4KB
maxModuleSize = 4080;

export const callTargetCounts : { [method: number] : number } = {};

Expand Down Expand Up @@ -718,7 +720,7 @@ function generate_wasm (
if (getU16(ip) !== MintOpcode.MINT_TIER_PREPARE_JITERPRETER)
throw new Error(`Expected *ip to be MINT_TIER_PREPARE_JITERPRETER but was ${getU16(ip)}`);

builder.cfg.initialize(startOfBody, backwardBranchTable, !!instrument);
builder.cfg.initialize(startOfBody, backwardBranchTable, instrument ? 1 : 0);

// TODO: Call generate_wasm_body before generating any of the sections and headers.
// This will allow us to do things like dynamically vary the number of locals, in addition
Expand Down Expand Up @@ -754,6 +756,10 @@ function generate_wasm (
if (trace > 0)
console.log(`${(<any>(builder.base)).toString(16)} ${methodFullName || traceName} generated ${buffer.length} byte(s) of wasm`);
counters.bytesGenerated += buffer.length;
if (buffer.length >= maxModuleSize) {
console.warn(`MONO_WASM: Jiterpreter generated too much code (${buffer.length} bytes) for trace ${traceName}. Please report this issue.`);
return 0;
}
const traceModule = new WebAssembly.Module(buffer);

const traceInstance = new WebAssembly.Instance(traceModule, {
Expand Down

0 comments on commit 1e241c0

Please sign in to comment.