Skip to content

Commit

Permalink
Prevent from sending tokens when calling a non-payable function on Od… (
Browse files Browse the repository at this point in the history
#459)

* Prevent from sending tokens when calling a non-payable function on OdraVm
* Non-payable functions revert if tokens attached (#460)
  • Loading branch information
kpob committed May 29, 2024
1 parent fa7fdd4 commit 9708458
Show file tree
Hide file tree
Showing 13 changed files with 92 additions and 40 deletions.
9 changes: 7 additions & 2 deletions core/src/contract_container.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::entry_point_callback::EntryPointsCaller;
use crate::{prelude::*, OdraResult};
use crate::{prelude::*, ExecutionError, OdraResult};
use crate::{CallDef, OdraError, VmError};
use casper_types::bytesrepr::Bytes;
use casper_types::U512;

/// A wrapper struct for a EntryPointsCaller that is a layer of abstraction between the host and the entry points caller.
///
Expand All @@ -22,13 +23,17 @@ impl ContractContainer {
/// Calls the entry point with the given call definition.
pub fn call(&self, call_def: CallDef) -> OdraResult<Bytes> {
// find the entry point
self.entry_points_caller
let ep = self
.entry_points_caller
.entry_points()
.iter()
.find(|ep| ep.name == call_def.entry_point())
.ok_or_else(|| {
OdraError::VmError(VmError::NoSuchMethod(call_def.entry_point().to_owned()))
})?;
if !ep.is_payable && call_def.amount() > U512::zero() {
return Err(OdraError::ExecutionError(ExecutionError::NonPayable));
}
self.entry_points_caller.call(call_def)
}
}
Expand Down
19 changes: 17 additions & 2 deletions core/src/entry_point_callback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,28 @@ pub struct EntryPoint {
/// The name of the entry point.
pub name: String,
/// The collection of arguments to the entry point.
pub args: Vec<Argument>
pub args: Vec<Argument>,
/// A flag indicating whether the entry point is payable.
pub is_payable: bool
}

impl EntryPoint {
/// Creates a new instance of `EntryPoint`.
pub fn new(name: String, args: Vec<Argument>) -> Self {
Self { name, args }
Self {
name,
args,
is_payable: false
}
}

/// Creates a new instance of payable `EntryPoint`.
pub fn new_payable(name: String, args: Vec<Argument>) -> Self {
Self {
name,
args,
is_payable: true
}
}
}

Expand Down
20 changes: 20 additions & 0 deletions examples/src/features/native_token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,24 @@ mod tests {
let contract_balance = my_contract.balance();
assert_eq!(contract_balance, original_contract_balance + U512::from(75));
}

#[test]
fn test_call_non_payable_function_with_tokens() {
let test_env = odra_test::env();
let contract = PublicWalletHostRef::deploy(&test_env, NoArgs);
let caller_address = test_env.get_account(0);
let original_caller_balance = test_env.balance_of(&caller_address);

contract.with_tokens(U512::from(100)).deposit();
// call a non-payable function with tokens should fail and tokens should be refunded
assert!(contract
.with_tokens(U512::from(10))
.try_withdraw(&U512::from(25))
.is_err());
// only the `deposit` function should have an effect
assert_eq!(
test_env.balance_of(&caller_address),
original_caller_balance - U512::from(100)
);
}
}
8 changes: 5 additions & 3 deletions odra-casper/proxy-caller/bin/proxy_caller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,18 @@ compile_error!("This binary only supports wasm32 target architecture!");

extern crate alloc;

use odra_casper_proxy_caller::ProxyCall;
use odra_casper_wasm_env::casper_contract::contract_api::runtime::call_versioned_contract;
use odra_casper_proxy_caller::{ensure_cargo_purse_is_empty, ProxyCall};

use odra_casper_wasm_env::casper_contract::contract_api::runtime;

#[no_mangle]
fn call() {
let proxy_call = ProxyCall::load_from_args();
let _: () = call_versioned_contract(
let _: () = runtime::call_versioned_contract(
proxy_call.contract_package_hash,
None,
proxy_call.entry_point_name.as_str(),
proxy_call.runtime_args
);
ensure_cargo_purse_is_empty(proxy_call.attached_value);
}
5 changes: 4 additions & 1 deletion odra-casper/proxy-caller/bin/proxy_caller_with_return.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ compile_error!("This binary only supports wasm32 target architecture!");

extern crate alloc;

use odra_casper_proxy_caller::{call_versioned_contract_ret_bytes, set_key, ProxyCall};
use odra_casper_proxy_caller::{
call_versioned_contract_ret_bytes, ensure_cargo_purse_is_empty, set_key, ProxyCall
};
use odra_core::casper_types::bytesrepr::Bytes;
use odra_core::consts::RESULT_KEY;
use odra_core::prelude::*;
Expand All @@ -20,5 +22,6 @@ fn call() {
proxy_call.entry_point_name.as_str(),
proxy_call.runtime_args
);
ensure_cargo_purse_is_empty(proxy_call.attached_value);
set_key(RESULT_KEY, Bytes::from(result));
}
11 changes: 11 additions & 0 deletions odra-casper/proxy-caller/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,17 @@ pub fn call_versioned_contract_ret_bytes(
deserialize_contract_result(bytes_written)
}

/// Ensures that the cargo purse is empty. Reverts if it's not.
pub fn ensure_cargo_purse_is_empty(value: U512) {
if !value.is_zero() {
let cargo_purse = get_cargo_purse();
let balance = system::get_purse_balance(cargo_purse).unwrap_or_revert();
if !balance.is_zero() {
revert(ApiError::InvalidPurse);
}
}
}

/// Load or create cargo purse.
fn get_cargo_purse() -> URef {
match runtime::get_key(CARGO_PURSE_KEY) {
Expand Down
Binary file modified odra-casper/test-vm/resources/proxy_caller.wasm
Binary file not shown.
Binary file modified odra-casper/test-vm/resources/proxy_caller_with_return.wasm
Binary file not shown.
4 changes: 2 additions & 2 deletions odra-macros/src/ast/deployer_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ mod deployer_impl {
]
),
odra::entry_point_callback::EntryPoint::new(odra::prelude::string::String::from("total_supply"), odra::prelude::vec![]),
odra::entry_point_callback::EntryPoint::new(odra::prelude::string::String::from("pay_to_mint"), odra::prelude::vec![]),
odra::entry_point_callback::EntryPoint::new_payable(odra::prelude::string::String::from("pay_to_mint"), odra::prelude::vec![]),
odra::entry_point_callback::EntryPoint::new(
odra::prelude::string::String::from("approve"),
odra::prelude::vec![
Expand Down Expand Up @@ -274,7 +274,7 @@ mod deployer_impl {
fn entry_points_caller(env: &odra::host::HostEnv) -> odra::entry_point_callback::EntryPointsCaller {
let entry_points = odra::prelude::vec![
odra::entry_point_callback::EntryPoint::new(odra::prelude::string::String::from("total_supply"), odra::prelude::vec![]),
odra::entry_point_callback::EntryPoint::new(odra::prelude::string::String::from("pay_to_mint"), odra::prelude::vec![])
odra::entry_point_callback::EntryPoint::new_payable(odra::prelude::string::String::from("pay_to_mint"), odra::prelude::vec![])
];
odra::entry_point_callback::EntryPointsCaller::new(env.clone(), entry_points, |contract_env, call_def| {
match call_def.entry_point() {
Expand Down
3 changes: 2 additions & 1 deletion odra-macros/src/ast/deployer_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ impl TryFrom<&'_ ModuleImplIR> for EntrypointsInitExpr {

fn try_from(module: &'_ ModuleImplIR) -> Result<Self, Self::Error> {
let functions = module.functions()?;

let entry_points = functions
.iter()
.map(|f| utils::expr::new_entry_point(f.name_str(), f.raw_typed_args()))
.map(|f| utils::expr::new_entry_point(f.name_str(), f.raw_typed_args(), f.is_payable()))
.collect::<Punctuated<_, syn::Token![,]>>();
let value_expr = utils::expr::vec(entry_points);

Expand Down
4 changes: 2 additions & 2 deletions odra-macros/src/ast/test_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ mod test {
odra::entry_point_callback::Argument::new::<Option<U256> >(odra::prelude::string::String::from("total_supply"))
]),
odra::entry_point_callback::EntryPoint::new(odra::prelude::string::String::from("total_supply"), odra::prelude::vec![]),
odra::entry_point_callback::EntryPoint::new(odra::prelude::string::String::from("pay_to_mint"), odra::prelude::vec![]),
odra::entry_point_callback::EntryPoint::new_payable(odra::prelude::string::String::from("pay_to_mint"), odra::prelude::vec![]),
odra::entry_point_callback::EntryPoint::new(odra::prelude::string::String::from("approve"), odra::prelude::vec![
odra::entry_point_callback::Argument::new::<Address>(odra::prelude::string::String::from("to")),
odra::entry_point_callback::Argument::new::<U256>(odra::prelude::string::String::from("amount")),
Expand Down Expand Up @@ -454,7 +454,7 @@ mod test {
fn entry_points_caller(env: &odra::host::HostEnv) -> odra::entry_point_callback::EntryPointsCaller {
let entry_points = odra::prelude::vec![
odra::entry_point_callback::EntryPoint::new(odra::prelude::string::String::from("total_supply"), odra::prelude::vec![]),
odra::entry_point_callback::EntryPoint::new(odra::prelude::string::String::from("pay_to_mint"), odra::prelude::vec![])
odra::entry_point_callback::EntryPoint::new_payable(odra::prelude::string::String::from("pay_to_mint"), odra::prelude::vec![])
];
odra::entry_point_callback::EntryPointsCaller::new(env.clone(), entry_points, |contract_env, call_def| {
match call_def.entry_point() {
Expand Down
8 changes: 6 additions & 2 deletions odra-macros/src/utils/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,19 @@ pub fn default() -> syn::Expr {
parse_quote!(Default::default())
}

pub fn new_entry_point(name: String, args: Vec<syn::PatType>) -> syn::Expr {
pub fn new_entry_point(name: String, args: Vec<syn::PatType>, is_payable: bool) -> syn::Expr {
let ty = super::ty::odra_entry_point();
let name = string_from(name);
let args_stream = args
.iter()
.map(new_entry_point_arg)
.collect::<Punctuated<_, syn::Token![,]>>();
let args_vec = vec(args_stream);
parse_quote!(#ty::new(#name, #args_vec))
if is_payable {
parse_quote!(#ty::new_payable(#name, #args_vec))
} else {
parse_quote!(#ty::new(#name, #args_vec))
}
}

fn new_entry_point_arg(arg: &syn::PatType) -> syn::Expr {
Expand Down
41 changes: 16 additions & 25 deletions odra-vm/src/vm/odra_vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,10 @@ impl OdraVm {
.unwrap()
.call(&address, call_def);

self.handle_call_result(result)
match result {
Err(err) => self.revert(err),
Ok(bytes) => self.handle_call_result(bytes)
}
}

/// Stops the execution of the virtual machine and reverts all the changes.
Expand Down Expand Up @@ -355,32 +358,17 @@ impl OdraVm {
state.push_callstack_element(element);
}

fn handle_call_result(&self, result: OdraResult<Bytes>) -> Bytes {
fn handle_call_result(&self, result: Bytes) -> Bytes {
let mut state = self.state.write().unwrap();
let result = match result {
Ok(data) => data,
Err(err) => {
state.set_error(err);
Bytes::new()
}
};

// Drop the address from stack.
state.pop_callstack_element();

if state.error.is_none() {
// If only one address on the call_stack, drop the snapshot
if state.is_in_caller_context() {
state.drop_snapshot();
}
result
} else {
// If only one address on the call_stack an an error occurred, restore the snapshot
if state.is_in_caller_context() {
state.restore_snapshot();
};
Bytes::new()
// If only one address on the call_stack, drop the snapshot
if state.is_in_caller_context() {
state.drop_snapshot();
}
result
}

fn key_of_named_key(name: &str) -> String {
Expand Down Expand Up @@ -497,7 +485,9 @@ mod tests {

// when call a contract
let call_def = CallDef::new(TEST_ENTRY_POINT, false, RuntimeArgs::new());
instance.call_contract(address, call_def);
let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
instance.call_contract(address, call_def)
}));

// then the vm is in error state
assert_eq!(
Expand All @@ -515,8 +505,9 @@ mod tests {

// when call non-existing entrypoint
let call_def = CallDef::new(invalid_entry_point_name, false, RuntimeArgs::new());

instance.call_contract(contract_address, call_def);
let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
instance.call_contract(contract_address, call_def)
}));

// then the vm is in error state
assert_eq!(
Expand Down Expand Up @@ -715,7 +706,7 @@ mod tests {
let vm = OdraVm::new();
let host_env = OdraVmHost::new(vm);
let env = HostEnv::new(host_env);
let entry_point = EntryPoint::new(String::from(entry_point_name), vec![]);
let entry_point = EntryPoint::new_payable(String::from(entry_point_name), vec![]);
EntryPointsCaller::new(env, vec![entry_point], |_, _| Ok(test_call_result()))
}
}

0 comments on commit 9708458

Please sign in to comment.