Skip to content

Commit

Permalink
Add tests for the failing cases of PyDictManager (#151)
Browse files Browse the repository at this point in the history
* Add test for invalid key on dict read and write

* Add test for get_tracker with invalid dict pointer
  • Loading branch information
MegaRedHand authored Nov 23, 2022
1 parent f10e775 commit 6128817
Showing 1 changed file with 211 additions and 0 deletions.
211 changes: 211 additions & 0 deletions src/dict_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,12 @@ impl PyDictTracker {
mod tests {
use crate::{ids::PyIds, memory::PyMemory, utils::to_vm_error, vm_core::PyVM};
use cairo_rs::{
bigint,
hint_processor::hint_processor_definition::HintReference,
serde::deserialize_program::{ApTracking, Member},
types::relocatable::Relocatable,
types::{instruction::Register, relocatable::MaybeRelocatable},
vm::errors::vm_errors::VirtualMachineError,
};
use num_bigint::{BigInt, Sign};
use pyo3::{types::PyDict, PyCell};
Expand Down Expand Up @@ -439,6 +441,114 @@ assert dict_tracker.data[1] == 22
});
}

#[test]
fn tracker_read_and_write_invalid_key() {
Python::with_gil(|py| {
let vm = PyVM::new(
BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]),
false,
);
for _ in 0..2 {
vm.vm.borrow_mut().add_memory_segment();
}

let dict_manager = PyDictManager::default();

let segment_manager = PySegmentManager::new(&vm, PyMemory::new(&vm));

//Create references
let mut references = HashMap::new();
references.insert(
String::from("dict"),
HintReference {
register: Some(Register::FP),
offset1: 0,
offset2: 0,
inner_dereference: false,
ap_tracking_data: None,
immediate: None,
dereference: true,
cairo_type: Some(String::from("DictAccess*")),
},
);
// Create ids.a
references.insert(String::from("a"), HintReference::new_simple(1));

//Insert ids.a into memory
vm.vm
.borrow_mut()
.insert_value(
&Relocatable::from((1, 1)),
&MaybeRelocatable::from((128, 64)),
)
.unwrap();

let mut struct_types: HashMap<String, HashMap<String, Member>> = HashMap::new();
struct_types.insert(String::from("DictAccess"), HashMap::new());

let ids = PyIds::new(
&vm,
&references,
&ApTracking::default(),
&HashMap::new(),
Rc::new(struct_types),
);

let globals = PyDict::new(py);
globals
.set_item("dict_manager", PyCell::new(py, dict_manager).unwrap())
.unwrap();
globals
.set_item("ids", PyCell::new(py, ids).unwrap())
.unwrap();
globals
.set_item("segments", PyCell::new(py, segment_manager).unwrap())
.unwrap();

let code = r#"
initial_dict = { 1: 2, 4: 8, 16: 32 }
ids.dict = dict_manager.new_dict(segments, initial_dict)
dict_tracker = dict_manager.get_tracker(ids.dict)
dict_tracker.data[3]
"#;

let py_result = py.run(code, Some(globals), None);

assert_eq!(
py_result.map_err(to_vm_error),
Err(to_vm_error(to_py_error(
VirtualMachineError::NoValueForKey(bigint!(3))
))),
);

let code = r#"
dict_tracker = dict_manager.get_tracker(ids.dict)
dict_tracker.data[ids.a]
"#;

let py_result = py.run(code, Some(globals), None);
let key = PyMaybeRelocatable::from(PyRelocatable::from((128, 64)));

assert_eq!(
py_result.map_err(to_vm_error),
Err(PyKeyError::new_err(key.to_object(py))).map_err(to_vm_error),
);

let code = r#"
dict_tracker = dict_manager.get_tracker(ids.dict)
dict_tracker.data[ids.a] = 5
"#;

let py_result = py.run(code, Some(globals), None);
let key = PyMaybeRelocatable::from(PyRelocatable::from((128, 64)));

assert_eq!(
py_result.map_err(to_vm_error),
Err(PyKeyError::new_err(key.to_object(py))).map_err(to_vm_error),
);
});
}

#[test]
fn tracker_get_and_set_current_ptr() {
Python::with_gil(|py| {
Expand Down Expand Up @@ -524,4 +634,105 @@ assert dict_tracker.current_ptr == ids.end_ptr
assert_eq!(py_result.map_err(to_vm_error), Ok(()));
});
}

#[test]
fn manager_get_tracker_invalid_dict_ptr() {
Python::with_gil(|py| {
let vm = PyVM::new(
BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]),
false,
);
for _ in 0..2 {
vm.vm.borrow_mut().add_memory_segment();
}

let dict_manager = PyDictManager::default();

let segment_manager = PySegmentManager::new(&vm, PyMemory::new(&vm));

//Create references
let mut references = HashMap::new();
references.insert(
String::from("dict"),
HintReference {
register: Some(Register::FP),
offset1: 0,
offset2: 0,
inner_dereference: false,
ap_tracking_data: None,
immediate: None,
dereference: true,
cairo_type: Some(String::from("DictAccess*")),
},
);
references.insert(
String::from("no_dict"),
HintReference {
register: Some(Register::FP),
offset1: 0,
offset2: 0,
inner_dereference: false,
ap_tracking_data: None,
immediate: None,
dereference: true,
cairo_type: Some(String::from("DictAccess")),
},
);

let mut struct_types: HashMap<String, HashMap<String, Member>> = HashMap::new();
struct_types.insert(String::from("DictAccess"), HashMap::new());

let ids = PyIds::new(
&vm,
&references,
&ApTracking::default(),
&HashMap::new(),
Rc::new(struct_types),
);

let globals = PyDict::new(py);
globals
.set_item("dict_manager", PyCell::new(py, dict_manager).unwrap())
.unwrap();
globals
.set_item("ids", PyCell::new(py, ids).unwrap())
.unwrap();
globals
.set_item("segments", PyCell::new(py, segment_manager).unwrap())
.unwrap();

let code = r#"
ids.dict = dict_manager.new_dict(segments, {})
dict_tracker = dict_manager.get_tracker(ids.no_dict)
"#;

let py_result = py.run(code, Some(globals), None);

assert_eq!(
py_result.map_err(to_vm_error),
Err(to_vm_error(to_py_error(
VirtualMachineError::NoDictTracker(vm.vm.borrow().get_fp().segment_index),
))),
);

let code = r#"
dict_tracker = dict_manager.get_tracker(ids.dict)
dict_tracker.current_ptr = dict_tracker.current_ptr + 3
dict_tracker = dict_manager.get_tracker(ids.dict)
"#;

let py_result = py.run(code, Some(globals), None);

assert_eq!(
py_result.map_err(to_vm_error),
Err(to_vm_error(to_py_error(
VirtualMachineError::MismatchedDictPtr(
Relocatable::from((2, 3)),
Relocatable::from((2, 0)),
),
))),
);
});
}
}

0 comments on commit 6128817

Please sign in to comment.