Skip to content

Commit

Permalink
First pass on in-Wasm more-granular tracking of memory writes
Browse files Browse the repository at this point in the history
  • Loading branch information
kettle11 committed Feb 20, 2023
1 parent 137bb4a commit d3cbaa0
Showing 1 changed file with 110 additions and 82 deletions.
192 changes: 110 additions & 82 deletions wasm_guardian/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use walrus::ir::RefNull;

/// Transforms a WebAssembly binary to report to the host environment whenever it makes persistent state changes.
///
/// If memory is modified the imported function `on_store` will be called with an i32 of the
Expand All @@ -16,7 +18,11 @@ pub fn transform_wasm_to_track_changes(
let mut module = walrus::Module::from_buffer(&bytes).unwrap();

let walrus::Module {
exports, globals, ..
exports,
globals,
tables,
memories,
..
} = &mut module;

if export_globals {
Expand All @@ -33,6 +39,18 @@ pub fn transform_wasm_to_track_changes(
}

if track_changes {
const WASM_PAGE_SIZE: u32 = 2 ^ 16;
const PAGE_SIZE_POWER_OF_2: u32 = 16;
const PAGE_SIZE_BYTES: u32 = 2 ^ PAGE_SIZE_POWER_OF_2;

let initial_memory = memories.iter().next().unwrap();

let dirty_flags_table = tables.add_local(
(initial_memory.initial * WASM_PAGE_SIZE) / PAGE_SIZE_BYTES,
None,
walrus::ValType::Funcref,
);

// Create a unique local identifier, one for each type we'll need to temporarily store.
let local0 = module.locals.add(walrus::ValType::I32);
let local1_i32 = module.locals.add(walrus::ValType::I32);
Expand All @@ -42,15 +60,8 @@ pub fn transform_wasm_to_track_changes(
let local1_f64 = module.locals.add(walrus::ValType::F64);

// Used for 3 arg operations that are part of the bulk-memory extension.
let local2 = module.locals.add(walrus::ValType::I32);
let local3 = module.locals.add(walrus::ValType::I32);

let function_type = module
.types
.add(&[walrus::ValType::I32, walrus::ValType::I32], &[]);
let mem_log_function = module
.add_import_func("wasm_guardian", "on_store", function_type)
.0;
// let local2 = module.locals.add(walrus::ValType::I32);
// let local3 = module.locals.add(walrus::ValType::I32);

let function_type = module.types.add(&[walrus::ValType::I32], &[]);
let grow_function = module
Expand Down Expand Up @@ -92,68 +103,12 @@ pub fn transform_wasm_to_track_changes(
walrus::ir::Instr::MemoryCopy(_)
| walrus::ir::Instr::MemoryInit(_)
| walrus::ir::Instr::MemoryFill(_) => {
new_instructions.extend_from_slice(&[
// Push both args to the store to temporary locals.
// This isn't the most efficient approach but it is simple
// and works for now without more complex analysis.
(
walrus::ir::Instr::LocalSet(walrus::ir::LocalSet {
local: local3,
}),
walrus::InstrLocId::default(),
),
(
walrus::ir::Instr::LocalSet(walrus::ir::LocalSet {
local: local2,
}),
walrus::InstrLocId::default(),
),
(
walrus::ir::Instr::LocalSet(walrus::ir::LocalSet {
local: local0,
}),
walrus::InstrLocId::default(),
),
(
walrus::ir::Instr::LocalGet(walrus::ir::LocalGet {
local: local3,
}),
walrus::InstrLocId::default(),
),
(
walrus::ir::Instr::LocalGet(walrus::ir::LocalGet {
local: local0,
}),
walrus::InstrLocId::default(),
),
(
walrus::ir::Instr::Call(walrus::ir::Call {
func: mem_log_function,
}),
walrus::InstrLocId::default(),
),
(
walrus::ir::Instr::LocalGet(walrus::ir::LocalGet {
local: local0,
}),
walrus::InstrLocId::default(),
),
(
walrus::ir::Instr::LocalGet(walrus::ir::LocalGet {
local: local2,
}),
walrus::InstrLocId::default(),
),
(
walrus::ir::Instr::LocalGet(walrus::ir::LocalGet {
local: local3,
}),
walrus::InstrLocId::default(),
),
instruction.clone(),
]);
// These operations probably require looping to set the pages that were written to.
todo!()
}
walrus::ir::Instr::Store(s) => {
// 15-19 extra instructions per call to store. Certainly not ideal!

let (local1, size) = match s.kind {
walrus::ir::StoreKind::I32 { .. } => {
(local1_i32, std::mem::size_of::<i32>() as _)
Expand Down Expand Up @@ -187,10 +142,10 @@ pub fn transform_wasm_to_track_changes(
}
};

// Push both store args to temporary locals.
// This isn't the most efficient approach but it is simple
// and works for now without more complex analysis.
new_instructions.extend_from_slice(&[
// Push both args to the store to temporary locals.
// This isn't the most efficient approach but it is simple
// and works for now without more complex analysis.
(
walrus::ir::Instr::LocalSet(walrus::ir::LocalSet {
local: local1,
Expand All @@ -205,12 +160,55 @@ pub fn transform_wasm_to_track_changes(
),
]);

// If there is an offset then add that to the returned address.
// If there is an offset then add it to the address.
if s.arg.offset != 0 {
new_instructions.extend_from_slice(&[
(
walrus::ir::Instr::Const(walrus::ir::Const {
value: walrus::ir::Value::I32(s.arg.offset as _),
}),
walrus::InstrLocId::default(),
),
(
// This is operating on memory addresses, is this the correct type of add?
walrus::ir::Instr::Binop(walrus::ir::Binop {
op: walrus::ir::BinaryOp::I32Add,
}),
walrus::InstrLocId::default(),
),
]);
}
new_instructions.extend_from_slice(&[
// Mark dirty_flags for the start of the value being stored.
(
walrus::ir::Instr::Const(walrus::ir::Const {
value: walrus::ir::Value::I32(PAGE_SIZE_POWER_OF_2 as _),
}),
walrus::InstrLocId::default(),
),
(
walrus::ir::Instr::Binop(walrus::ir::Binop {
op: walrus::ir::BinaryOp::I32ShrU,
}),
walrus::InstrLocId::default(),
),
(
walrus::ir::Instr::RefNull(RefNull {
ty: walrus::ValType::Funcref,
}),
walrus::InstrLocId::default(),
),
(
walrus::ir::Instr::TableSet(walrus::ir::TableSet {
table: dirty_flags_table,
}),
walrus::InstrLocId::default(),
),
]);

// If there is an offset then add it to the address.
if s.arg.offset != 0 {
new_instructions.extend_from_slice(&[
// Push both args to the store to temporary locals.
// This isn't the most efficient approach but it is simple
// and works for now without more complex analysis.
(
walrus::ir::Instr::Const(walrus::ir::Const {
value: walrus::ir::Value::I32(s.arg.offset as _),
Expand All @@ -228,21 +226,51 @@ pub fn transform_wasm_to_track_changes(
}

new_instructions.extend_from_slice(&[
// Output the size of the memory being written.
// An alternative approach would be to implement a function export for each type,
// but this is simpler for now.
// Mark dirty flags for the end of the value being stored.
// It's unfortunate this is needed because it's non-trivial overhead.
(
walrus::ir::Instr::LocalGet(walrus::ir::LocalGet {
local: local0,
}),
walrus::InstrLocId::default(),
),
(
walrus::ir::Instr::Const(walrus::ir::Const {
value: walrus::ir::Value::I32(size),
}),
walrus::InstrLocId::default(),
),
(
walrus::ir::Instr::Call(walrus::ir::Call {
func: mem_log_function,
walrus::ir::Instr::Binop(walrus::ir::Binop {
op: walrus::ir::BinaryOp::I32Add,
}),
walrus::InstrLocId::default(),
),
(
walrus::ir::Instr::Const(walrus::ir::Const {
value: walrus::ir::Value::I32(PAGE_SIZE_POWER_OF_2 as _),
}),
walrus::InstrLocId::default(),
),
(
walrus::ir::Instr::Binop(walrus::ir::Binop {
op: walrus::ir::BinaryOp::I32ShrU,
}),
walrus::InstrLocId::default(),
),
(
walrus::ir::Instr::RefNull(RefNull {
ty: walrus::ValType::Funcref,
}),
walrus::InstrLocId::default(),
),
(
walrus::ir::Instr::TableSet(walrus::ir::TableSet {
table: dirty_flags_table,
}),
walrus::InstrLocId::default(),
),
// Restore the locals so the store OP can go ahead.
(
walrus::ir::Instr::LocalGet(walrus::ir::LocalGet {
local: local0,
Expand Down

0 comments on commit d3cbaa0

Please sign in to comment.