Skip to content

Commit

Permalink
perf: reference management optimizations (#1214)
Browse files Browse the repository at this point in the history
* wip

* test passing

* more accurate bench

* fix

* perf: optimizations for reference management

- Store directly the reference list rather than the whole
  `ReferenceManager` structures
- Only process them once, when creating the `Program`
- Move them to the `SharedProgramData` member
- Convert to `Vec<HintReference>` and adapt the methods using it

---------

Co-authored-by: Pedro Fontana <fontana.pedro93@gmail.com>
  • Loading branch information
Oppen and pefontana committed Jun 15, 2023
1 parent 4221723 commit dbcf4c4
Show file tree
Hide file tree
Showing 11 changed files with 78 additions and 81 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
* perf: accumulate `min` and `max` instruction offsets during run to speed up range check [#1080](https://github.com/lambdaclass/cairo-rs/pull/)
BREAKING: `Cairo_runner::get_perm_range_check_limits` no longer returns an error when called without trace enabled, as it no longer depends on it

* perf: process reference list on `Program` creation only [#1214](https://github.com/lambdaclass/cairo-rs/pull/1214)
Also keep them in a `Vec<_>` instead of a `HashMap<_, _>` since it will be continuous anyway.
BREAKING:
* `HintProcessor::compile_hint` now receies a `&[HintReference]` rather than `&HashMap<usize, HintReference>`
* Public `CairoRunner::get_reference_list` has been removed

#### [0.5.2] - 2023-6-12

* BREAKING: Compute `ExecutionResources.n_steps` without requiring trace [#1222](https://github.com/lambdaclass/cairo-rs/pull/1222)
Expand Down
4 changes: 2 additions & 2 deletions bench/criterion_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ fn parse_program(c: &mut Criterion) {
//Picked the biggest one at the time of writing
let program = include_bytes!("../cairo_programs/benchmarks/keccak_integration_benchmark.json");
c.bench_function("parse program", |b| {
b.iter(|| {
b.iter_with_large_drop(|| {
_ = Program::from_bytes(black_box(program.as_slice()), black_box(Some("main")))
.unwrap();
})
Expand All @@ -29,7 +29,7 @@ fn build_many_runners(c: &mut Criterion) {
let program = include_bytes!("../cairo_programs/benchmarks/keccak_integration_benchmark.json");
let program = Program::from_bytes(program.as_slice(), Some("main")).unwrap();
c.bench_function("build runner", |b| {
b.iter(|| {
b.iter_with_large_drop(|| {
_ = black_box(
CairoRunner::new(
black_box(&program),
Expand Down
9 changes: 5 additions & 4 deletions bench/iai_benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use iai_callgrind::{black_box, main};
use core::hint::black_box;
use iai_callgrind::main;

use cairo_vm::{
types::program::Program,
Expand All @@ -18,7 +19,7 @@ fn parse_program() {
let program = include_bytes!("../cairo_programs/benchmarks/keccak_integration_benchmark.json");
let program =
Program::from_bytes(black_box(program.as_slice()), black_box(Some("main"))).unwrap();
let _ = black_box(program);
core::mem::drop(black_box(program));
}

#[export_name = "helper::parse_program"]
Expand All @@ -33,7 +34,7 @@ fn parse_program_helper() -> Program {
fn build_runner() {
let program = parse_program_helper();
let runner = CairoRunner::new(black_box(&program), "starknet_with_keccak", false).unwrap();
let _ = black_box(runner);
core::mem::drop(black_box(runner));
}

#[export_name = "helper::build_runner"]
Expand All @@ -54,6 +55,6 @@ fn load_program_data() {
}

main!(
callgrind_args = "toggle-collect=helper::*";
callgrind_args = "toggle-collect=helper::*,core::mem::drop";
functions = parse_program, build_runner, load_program_data
);
2 changes: 1 addition & 1 deletion hint_accountant/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ fn run() {
let (ap_tracking_data, reference_ids, references, mut exec_scopes, constants) = (
ApTracking::default(),
HashMap::new(),
HashMap::new(),
Vec::new(),
ExecutionScopes::new(),
HashMap::new(),
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1106,7 +1106,7 @@ impl HintProcessor for Cairo1HintProcessor {
//(may contain other variables aside from those used by the hint)
_reference_ids: &HashMap<String, usize>,
//List of all references (key corresponds to element of the previous dictionary)
_references: &HashMap<usize, HintReference>,
_references: &[HintReference],
) -> Result<Box<dyn Any>, VirtualMachineError> {
let data = hint_code.parse().ok().and_then(|x: usize| self.hints.get(&x).cloned())
.ok_or(VirtualMachineError::CompileHintFail(
Expand Down
6 changes: 3 additions & 3 deletions vm/src/hint_processor/hint_processor_definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub trait HintProcessor {
//(may contain other variables aside from those used by the hint)
reference_ids: &HashMap<String, usize>,
//List of all references (key corresponds to element of the previous dictionary)
references: &HashMap<usize, HintReference>,
references: &[HintReference],
) -> Result<Box<dyn Any>, VirtualMachineError> {
Ok(any_box!(HintProcessorData {
code: hint_code.to_string(),
Expand All @@ -52,7 +52,7 @@ pub trait HintProcessor {

fn get_ids_data(
reference_ids: &HashMap<String, usize>,
references: &HashMap<usize, HintReference>,
references: &[HintReference],
) -> Result<HashMap<String, HintReference>, VirtualMachineError> {
let mut ids_data = HashMap::<String, HintReference>::new();
for (path, ref_id) in reference_ids {
Expand All @@ -63,7 +63,7 @@ fn get_ids_data(
ids_data.insert(
name.to_string(),
references
.get(ref_id)
.get(*ref_id)
.ok_or(VirtualMachineError::Unexpected)?
.clone(),
);
Expand Down
4 changes: 2 additions & 2 deletions vm/src/serde/deserialize_program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ fn deserialize_scientific_notation(n: Number) -> Option<Felt252> {
}
}

#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone, Default)]
pub struct ReferenceManager {
pub references: Vec<Reference>,
}
Expand Down Expand Up @@ -432,11 +432,11 @@ pub fn parse_program_json(
.debug_info
.map(|debug_info| debug_info.instruction_locations),
identifiers: program_json.identifiers,
reference_manager: Program::get_reference_list(&program_json.reference_manager),
};
Ok(Program {
shared_program_data: Arc::new(shared_program_data),
constants,
reference_manager: program_json.reference_manager,
builtins: program_json.builtins,
})
}
Expand Down
43 changes: 33 additions & 10 deletions vm/src/types/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ use crate::stdlib::{collections::HashMap, prelude::*, sync::Arc};
#[cfg(feature = "cairo-1-hints")]
use crate::serde::deserialize_program::{ApTracking, FlowTrackingData};
use crate::{
hint_processor::hint_processor_definition::HintReference,
serde::deserialize_program::{
deserialize_and_parse_program, Attribute, BuiltinName, HintParams, Identifier,
InstructionLocation, ReferenceManager,
InstructionLocation, OffsetValue, ReferenceManager,
},
types::{
errors::program_errors::ProgramError, instruction::Register, relocatable::MaybeRelocatable,
},
types::{errors::program_errors::ProgramError, relocatable::MaybeRelocatable},
};
#[cfg(feature = "cairo-1-hints")]
use cairo_lang_starknet::casm_contract_class::CasmContractClass;
Expand Down Expand Up @@ -48,14 +51,14 @@ pub(crate) struct SharedProgramData {
pub(crate) error_message_attributes: Vec<Attribute>,
pub(crate) instruction_locations: Option<HashMap<usize, InstructionLocation>>,
pub(crate) identifiers: HashMap<String, Identifier>,
pub(crate) reference_manager: Vec<HintReference>,
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Program {
pub(crate) shared_program_data: Arc<SharedProgramData>,
pub(crate) constants: HashMap<String, Felt252>,
pub(crate) builtins: Vec<BuiltinName>,
pub(crate) reference_manager: ReferenceManager,
}

impl Program {
Expand Down Expand Up @@ -89,11 +92,11 @@ impl Program {
error_message_attributes,
instruction_locations,
identifiers,
reference_manager: Self::get_reference_list(&reference_manager),
};
Ok(Self {
shared_program_data: Arc::new(shared_program_data),
constants,
reference_manager,
builtins,
})
}
Expand Down Expand Up @@ -139,16 +142,36 @@ impl Program {
.iter()
.map(|(cairo_type, identifier)| (cairo_type.as_str(), identifier))
}

pub(crate) fn get_reference_list(reference_manager: &ReferenceManager) -> Vec<HintReference> {
reference_manager
.references
.iter()
.map(|r| {
HintReference {
offset1: r.value_address.offset1.clone(),
offset2: r.value_address.offset2.clone(),
dereference: r.value_address.dereference,
// only store `ap` tracking data if the reference is referred to it
ap_tracking_data: match (&r.value_address.offset1, &r.value_address.offset2) {
(OffsetValue::Reference(Register::AP, _, _), _)
| (_, OffsetValue::Reference(Register::AP, _, _)) => {
Some(r.ap_tracking_data.clone())
}
_ => None,
},
cairo_type: Some(r.value_address.value_type.clone()),
}
})
.collect()
}
}

impl Default for Program {
fn default() -> Self {
Self {
shared_program_data: Arc::new(SharedProgramData::default()),
constants: HashMap::new(),
reference_manager: ReferenceManager {
references: Vec::new(),
},
builtins: Vec::new(),
}
}
Expand Down Expand Up @@ -854,13 +877,13 @@ mod tests {
error_message_attributes: Vec::new(),
instruction_locations: None,
identifiers: HashMap::new(),
reference_manager: Program::get_reference_list(&ReferenceManager {
references: Vec::new(),
}),
};
let program = Program {
shared_program_data: Arc::new(shared_program_data),
constants: HashMap::new(),
reference_manager: ReferenceManager {
references: Vec::new(),
},
builtins: Vec::new(),
};

Expand Down
26 changes: 13 additions & 13 deletions vm/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,14 +265,14 @@ pub mod test_utils {
error_message_attributes: crate::stdlib::vec::Vec::new(),
instruction_locations: None,
identifiers: crate::stdlib::collections::HashMap::new(),
reference_manager: Program::get_reference_list(&ReferenceManager {
references: crate::stdlib::vec::Vec::new(),
}),
};
Program {
shared_program_data: Arc::new(shared_program_data),
constants: crate::stdlib::collections::HashMap::new(),
builtins: vec![$( $builtin_name ),*],
reference_manager: ReferenceManager {
references: crate::stdlib::vec::Vec::new(),
},
}
}};
($($field:ident = $value:expr),* $(,)?) => {{
Expand Down Expand Up @@ -352,10 +352,10 @@ pub mod test_utils {
error_message_attributes: val.error_message_attributes,
instruction_locations: val.instruction_locations,
identifiers: val.identifiers,
reference_manager: Program::get_reference_list(&val.reference_manager),
}),
constants: val.constants,
builtins: val.builtins,
reference_manager: val.reference_manager,
}
}
}
Expand Down Expand Up @@ -926,13 +926,13 @@ mod test {
error_message_attributes: Vec::new(),
instruction_locations: None,
identifiers: HashMap::new(),
reference_manager: Program::get_reference_list(&ReferenceManager {
references: Vec::new(),
}),
};
let program = Program {
shared_program_data: Arc::new(shared_data),
constants: HashMap::new(),
reference_manager: ReferenceManager {
references: Vec::new(),
},
builtins: Vec::new(),
};
assert_eq!(program, program!())
Expand All @@ -950,13 +950,13 @@ mod test {
error_message_attributes: Vec::new(),
instruction_locations: None,
identifiers: HashMap::new(),
reference_manager: Program::get_reference_list(&ReferenceManager {
references: Vec::new(),
}),
};
let program = Program {
shared_program_data: Arc::new(shared_data),
constants: HashMap::new(),
reference_manager: ReferenceManager {
references: Vec::new(),
},
builtins: vec![BuiltinName::range_check],
};

Expand All @@ -975,13 +975,13 @@ mod test {
error_message_attributes: Vec::new(),
instruction_locations: None,
identifiers: HashMap::new(),
reference_manager: Program::get_reference_list(&ReferenceManager {
references: Vec::new(),
}),
};
let program = Program {
shared_program_data: Arc::new(shared_data),
constants: HashMap::new(),
reference_manager: ReferenceManager {
references: Vec::new(),
},
builtins: vec![BuiltinName::range_check],
};

Expand Down
15 changes: 5 additions & 10 deletions vm/src/vm/errors/vm_exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@ use thiserror::Error;
use thiserror_no_std::Error;

use crate::{
hint_processor::{
hint_processor_definition::HintReference,
hint_processor_utils::get_maybe_relocatable_from_reference,
},
hint_processor::hint_processor_utils::get_maybe_relocatable_from_reference,
serde::deserialize_program::{ApTracking, Attribute, Location, OffsetValue},
types::{instruction::Register, relocatable::MaybeRelocatable},
vm::{runners::cairo_runner::CairoRunner, vm_core::VirtualMachine},
Expand Down Expand Up @@ -177,21 +174,19 @@ fn get_value_from_simple_reference(
runner: &CairoRunner,
vm: &VirtualMachine,
) -> Option<MaybeRelocatable> {
let reference: HintReference = runner
let reference = runner
.program
.shared_program_data
.reference_manager
.references
.get(ref_id)?
.clone()
.into();
.get(ref_id)?;
// Filter ap-based references
match reference.offset1 {
OffsetValue::Reference(Register::AP, _, _) => None,
_ => {
// Filer complex types (only felt/felt pointers)
match reference.cairo_type {
Some(ref cairo_type) if cairo_type.contains("felt") => Some(
get_maybe_relocatable_from_reference(vm, &reference, ap_tracking)?,
get_maybe_relocatable_from_reference(vm, reference, ap_tracking)?,
),
_ => None,
}
Expand Down
Loading

0 comments on commit dbcf4c4

Please sign in to comment.