Skip to content

Commit

Permalink
Merge pull request bytecodealliance#216 from near/move_codegen
Browse files Browse the repository at this point in the history
Move zkasm generation into zkasm_codegen.rs
  • Loading branch information
akashin committed Feb 13, 2024
2 parents e360122 + 4dbf674 commit f65f194
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 231 deletions.
1 change: 1 addition & 0 deletions cranelift/filetests/src/lib.rs
Expand Up @@ -22,6 +22,7 @@ mod match_directive;
mod runner;
mod runone;
mod subtest;
mod zkasm_codegen;
mod zkasm_runner;

mod test_alias_analysis;
Expand Down
234 changes: 3 additions & 231 deletions cranelift/filetests/src/test_zkasm.rs
@@ -1,248 +1,20 @@
#[cfg(test)]
mod tests {
use crate::zkasm_codegen;
use std::collections::HashMap;
use std::path::{Path, PathBuf};

use regex::Regex;
use std::fs::read_to_string;
use std::io::Error;

use cranelift_codegen::entity::EntityRef;
use cranelift_codegen::ir::function::FunctionParameters;
use cranelift_codegen::ir::ExternalName;
use cranelift_codegen::isa::zkasm;
use cranelift_codegen::{settings, FinalizedMachReloc, FinalizedRelocTarget};
use cranelift_wasm::{translate_module, ZkasmEnvironment};

use walkdir::WalkDir;
use wasmtime::*;

fn setup() {
let _ = env_logger::builder().is_test(true).try_init();
}

fn generate_preamble(
start_func_index: usize,
globals: &[(cranelift_wasm::GlobalIndex, cranelift_wasm::GlobalInit)],
data_segments: &[(u64, Vec<u8>)],
) -> Vec<String> {
let mut program: Vec<String> = Vec::new();

// Generate global variable definitions.
for (key, _) in globals {
program.push(format!("VAR GLOBAL global_{}", key.index()));
}

program.push("start:".to_string());
for (key, init) in globals {
match init {
cranelift_wasm::GlobalInit::I32Const(v) => {
// ZKASM stores constants in 2-complement form, so we need a cast to unsigned.
program.push(format!(
" {} :MSTORE(global_{}) ;; Global32({})",
*v as u32,
key.index(),
v
));
}
cranelift_wasm::GlobalInit::I64Const(v) => {
// ZKASM stores constants in 2-complement form, so we need a cast to unsigned.
program.push(format!(
" {} :MSTORE(global_{}) ;; Global64({})",
*v as u64,
key.index(),
v
));
}
_ => unimplemented!("Global type is not supported"),
}
}

// Generate const data segments definitions.
for (offset, data) in data_segments {
program.push(format!(" {} => E", offset / 8));
// Each slot stores 8 consecutive u8 numbers, with earlier addresses stored in lower
// bits.
for (i, chunk) in data.chunks(8).enumerate() {
let mut chunk_data = 0u64;
for c in chunk.iter().rev() {
chunk_data <<= 8;
chunk_data |= *c as u64;
}
program.push(format!(" {chunk_data}n :MSTORE(MEM:E + {i})"));
}
}

// The total amount of stack available on ZKASM processor is 2^16 of 8-byte words.
// Stack memory is a separate region that is independent from the heap.
program.push(" 0xffff => SP".to_string());
program.push(" zkPC + 2 => RR".to_string());
program.push(format!(" :JMP(function_{})", start_func_index));
program.push(" :JMP(finalizeExecution)".to_string());
program
}

fn generate_postamble() -> Vec<String> {
let mut program: Vec<String> = Vec::new();
// In the prover, the program always runs for a fixed number of steps (e.g. 2^23), so we
// need an infinite loop at the end of the program to fill the execution trace to the
// expected number of steps.
// In the future we might need to put zero in all registers here.
program.push("finalizeExecution:".to_string());
program.push(" ${beforeLast()} :JMPN(finalizeExecution)".to_string());
program.push(" :JMP(start)".to_string());
program.push("INCLUDE \"helpers/2-exp.zkasm\"".to_string());
program
}

// TODO: Relocations should be generated by a linker and/or clift itself.
fn fix_relocs(
code_buffer: &mut Vec<u8>,
params: &FunctionParameters,
relocs: &[FinalizedMachReloc],
) {
let mut delta = 0i32;
for reloc in relocs {
let start = (reloc.offset as i32 + delta) as usize;
let mut pos = start;
while code_buffer[pos] != b'\n' {
pos += 1;
delta -= 1;
}

let code = if let FinalizedRelocTarget::ExternalName(ExternalName::User(name)) =
reloc.target
{
let name = &params.user_named_funcs()[name];
if name.index == 0 {
b" B :ASSERT".to_vec()
} else {
format!(" zkPC + 2 => RR\n :JMP(function_{})", name.index)
.as_bytes()
.to_vec()
}
} else {
b" UNKNOWN".to_vec()
};
delta += code.len() as i32;

code_buffer.splice(start..pos, code);
}
}

// TODO: Labels optimization already happens in `MachBuffer`, we need to find a way to leverage
// it.
/// Label name is formatted as follows: <label_name>_<function_id>_<label_id>
/// Function id is unique through whole program while label id is unique only
/// inside given function.
/// Label name must begin from label_.
fn optimize_labels(code: &[&str], func_index: usize) -> Vec<String> {
let mut label_definition: HashMap<String, usize> = HashMap::new();
let mut label_uses: HashMap<String, Vec<usize>> = HashMap::new();
let mut lines = Vec::new();
for (index, line) in code.iter().enumerate() {
let mut line = line.to_string();
if line.starts_with(&"label_") {
// Handles lines with a label marker, e.g.:
// <label_name>_XXX:
let index_begin = line.rfind("_").expect("Failed to parse label index") + 1;
let label_name: String = line[..line.len() - 1].to_string();
line.insert_str(index_begin - 1, &format!("_{}", func_index));
label_definition.insert(label_name, index);
} else if line.contains(&"label_") {
// Handles lines with a jump to label, e.g.:
// A : JMPNZ(<label_name>_XXX)
let pos = line.rfind(&"_").unwrap() + 1;
let label_name = line[line
.find("label_")
.expect(&format!("Error parsing label line '{}'", line))
..line
.rfind(")")
.expect(&format!("Error parsing label line '{}'", line))]
.to_string();
line.insert_str(pos - 1, &format!("_{}", func_index));
label_uses.entry(label_name).or_default().push(index);
}
lines.push(line);
}

let mut lines_to_delete = Vec::new();
for (label, label_line) in label_definition {
match label_uses.entry(label) {
std::collections::hash_map::Entry::Occupied(uses) => {
if uses.get().len() == 1 {
let use_line = uses.get()[0];
if use_line + 1 == label_line {
lines_to_delete.push(use_line);
lines_to_delete.push(label_line);
}
}
}
std::collections::hash_map::Entry::Vacant(_) => {
lines_to_delete.push(label_line);
}
}
}
lines_to_delete.sort();
lines_to_delete.reverse();
for index in lines_to_delete {
lines.remove(index);
}
lines
}

fn generate_zkasm(wasm_module: &[u8]) -> String {
let flag_builder = settings::builder();
let isa_builder = zkasm::isa_builder("zkasm-unknown-unknown".parse().unwrap());
let isa = isa_builder
.finish(settings::Flags::new(flag_builder))
.unwrap();
let mut zkasm_environ = ZkasmEnvironment::new(isa.frontend_config());
translate_module(wasm_module, &mut zkasm_environ).unwrap();

let mut program: Vec<String> = Vec::new();

let start_func = zkasm_environ
.info
.start_func
.expect("Must have a start function");
// TODO: Preamble should be generated by a linker and/or clift itself.
program.append(&mut generate_preamble(
start_func.index(),
&zkasm_environ.info.global_inits,
&zkasm_environ.info.data_inits,
));

let num_func_imports = zkasm_environ.get_num_func_imports();
let mut context = cranelift_codegen::Context::new();
for (def_index, func) in zkasm_environ.info.function_bodies.iter() {
let func_index = num_func_imports + def_index.index();
program.push(format!("function_{}:", func_index));

let mut mem = vec![];
context.func = func.clone();
let compiled_code = context
.compile_and_emit(&*isa, &mut mem, &mut Default::default())
.unwrap();
let mut code_buffer = compiled_code.code_buffer().to_vec();
fix_relocs(
&mut code_buffer,
&func.params,
compiled_code.buffer.relocs(),
);

let code = std::str::from_utf8(&code_buffer).unwrap();
let lines: Vec<&str> = code.lines().collect();
let mut lines = optimize_labels(&lines, func_index);
program.append(&mut lines);

context.clear();
}

program.append(&mut generate_postamble());
program.join("\n")
}

fn run_wat_file(path: &Path) -> Result<(), Box<dyn std::error::Error>> {
let engine = Engine::default();
let binary = wat::parse_file(path)?;
Expand Down Expand Up @@ -287,7 +59,7 @@ mod tests {

fn test_module(name: &str) {
let module_binary = wat::parse_file(format!("../zkasm_data/{name}.wat")).unwrap();
let program = generate_zkasm(&module_binary);
let program = zkasm_codegen::generate_zkasm(&module_binary);
let expected =
expect_test::expect_file![format!("../../zkasm_data/generated/{name}.zkasm")];
expected.assert_eq(&program);
Expand Down Expand Up @@ -379,7 +151,7 @@ mod tests {
.join(path)
.join(format!("generated/{name}.zkasm"))];
let result = std::panic::catch_unwind(|| {
let program = generate_zkasm(&module_binary);
let program = zkasm_codegen::generate_zkasm(&module_binary);
expected.assert_eq(&program);
});
if let Err(err) = result {
Expand Down

0 comments on commit f65f194

Please sign in to comment.