diff --git a/docs/language/builtins.rst b/docs/language/builtins.rst index c7a43b7c4..18f36fd8f 100644 --- a/docs/language/builtins.rst +++ b/docs/language/builtins.rst @@ -343,6 +343,11 @@ is_contract(address AccountId) returns (bool) Only available on Polkadot. Checks whether the given address is a contract address. +caller_is_root() returns (bool) ++++++++++++++++++++++++++++++++ + +Only available on Polkadot. Returns true if the caller of the contract is `root `_. + set_code_hash(uint8[32] hash) returns (uint32) ++++++++++++++++++++++++++++++++++++++++++++++ diff --git a/integration/polkadot/caller_is_root.sol b/integration/polkadot/caller_is_root.sol new file mode 100644 index 000000000..21e1e5c5d --- /dev/null +++ b/integration/polkadot/caller_is_root.sol @@ -0,0 +1,14 @@ +import "polkadot"; + +contract CallerIsRoot { + uint public balance; + + function covert() public payable { + if (caller_is_root()) { + balance = 0xdeadbeef; + } else { + print("burn more gas"); + balance = 1; + } + } +} diff --git a/integration/polkadot/caller_is_root.spec.ts b/integration/polkadot/caller_is_root.spec.ts new file mode 100644 index 000000000..814f22852 --- /dev/null +++ b/integration/polkadot/caller_is_root.spec.ts @@ -0,0 +1,42 @@ +import expect from 'expect'; +import { createConnection, deploy, aliceKeypair, query, weight, transaction } from './index'; +import { ContractPromise } from '@polkadot/api-contract'; +import { ApiPromise } from '@polkadot/api'; +import { KeyringPair } from '@polkadot/keyring/types'; + +describe('Deploy the caller_is_root contract and test it', () => { + let conn: ApiPromise; + let contract: ContractPromise; + let alice: KeyringPair; + + before(async function () { + conn = await createConnection(); + alice = aliceKeypair(); + const instance = await deploy(conn, alice, 'CallerIsRoot.contract', 0n); + contract = new ContractPromise(conn, instance.abi, instance.address); + }); + + after(async function () { + await conn.disconnect(); + }); + + it('is correct on a non-root caller', async function () { + // Without sudo the caller should not be root + const gasLimit = await weight(conn, contract, "covert"); + await transaction(contract.tx.covert({ gasLimit }), alice); + + // Calling `covert` as non-root sets the balance to 1 + const balance = await query(conn, alice, contract, "balance", []); + expect(BigInt(balance.output?.toString() ?? "")).toStrictEqual(1n); + }); + + it('is correct on a root caller', async function () { + // Alice has sudo rights on --dev nodes + const gasLimit = await weight(conn, contract, "covert"); + await transaction(conn.tx.sudo.sudo(contract.tx.covert({ gasLimit })), alice); + + // Calling `covert` as root sets the balance to 0xdeadbeef + const balance = await query(conn, alice, contract, "balance", []); + expect(BigInt(balance.output?.toString() ?? "")).toStrictEqual(0xdeadbeefn); + }); +}); diff --git a/src/emit/instructions.rs b/src/emit/instructions.rs index f148a06d5..869106e71 100644 --- a/src/emit/instructions.rs +++ b/src/emit/instructions.rs @@ -519,10 +519,14 @@ pub(super) fn process_instruction<'a, T: TargetRuntime<'a> + ?Sized>( } } - let first_arg_type = bin.llvm_type(&args[0].ty(), ns); - if let Some(ret) = - target.builtin_function(bin, function, callee, &parms, first_arg_type, ns) - { + if let Some(ret) = target.builtin_function( + bin, + function, + callee, + &parms, + args.first().map(|arg| bin.llvm_type(&arg.ty(), ns)), + ns, + ) { let success = bin.builder.build_int_compare( IntPredicate::EQ, ret.into_int_value(), diff --git a/src/emit/mod.rs b/src/emit/mod.rs index f8d2ccab0..09a5ea80b 100644 --- a/src/emit/mod.rs +++ b/src/emit/mod.rs @@ -231,7 +231,7 @@ pub trait TargetRuntime<'a> { function: FunctionValue<'a>, builtin_func: &Function, args: &[BasicMetadataValueEnum<'a>], - first_arg_type: BasicTypeEnum, + first_arg_type: Option, ns: &Namespace, ) -> Option>; diff --git a/src/emit/polkadot/mod.rs b/src/emit/polkadot/mod.rs index 42f2d0d3c..823dc620e 100644 --- a/src/emit/polkadot/mod.rs +++ b/src/emit/polkadot/mod.rs @@ -118,6 +118,7 @@ impl PolkadotTarget { "transfer", "is_contract", "set_code_hash", + "caller_is_root", ]); binary @@ -266,6 +267,7 @@ impl PolkadotTarget { external!("deposit_event", void_type, u8_ptr, u32_val, u8_ptr, u32_val); external!("is_contract", i32_type, u8_ptr); external!("set_code_hash", i32_type, u8_ptr); + external!("caller_is_root", i32_type,); } /// Emits the "deploy" function if `storage_initializer` is `Some`, otherwise emits the "call" function. diff --git a/src/emit/polkadot/target.rs b/src/emit/polkadot/target.rs index 5fb0fda30..a0adfe86d 100644 --- a/src/emit/polkadot/target.rs +++ b/src/emit/polkadot/target.rs @@ -1501,7 +1501,7 @@ impl<'a> TargetRuntime<'a> for PolkadotTarget { _function: FunctionValue<'a>, builtin_func: &Function, args: &[BasicMetadataValueEnum<'a>], - _first_arg_type: BasicTypeEnum, + _first_arg_type: Option, ns: &Namespace, ) -> Option> { emit_context!(binary); @@ -1579,6 +1579,17 @@ impl<'a> TargetRuntime<'a> for PolkadotTarget { .build_store(args[1].into_pointer_value(), ret); None } + "caller_is_root" => { + let is_root = call!("caller_is_root", &[], "seal_caller_is_root") + .try_as_basic_value() + .left() + .unwrap() + .into_int_value(); + binary + .builder + .build_store(args[0].into_pointer_value(), is_root); + None + } _ => unimplemented!(), } } diff --git a/src/emit/solana/target.rs b/src/emit/solana/target.rs index 5ce153a7b..5b4c36b34 100644 --- a/src/emit/solana/target.rs +++ b/src/emit/solana/target.rs @@ -1251,9 +1251,12 @@ impl<'a> TargetRuntime<'a> for SolanaTarget { function: FunctionValue<'a>, builtin_func: &ast::Function, args: &[BasicMetadataValueEnum<'a>], - first_arg_type: BasicTypeEnum, + first_arg_type: Option, ns: &ast::Namespace, ) -> Option> { + let first_arg_type = + first_arg_type.expect("solana does not have builtin without any parameter"); + if builtin_func.id.name == "create_program_address" { let func = binary .module diff --git a/src/emit/soroban/target.rs b/src/emit/soroban/target.rs index 3a2f3f8db..56fb02f87 100644 --- a/src/emit/soroban/target.rs +++ b/src/emit/soroban/target.rs @@ -218,7 +218,7 @@ impl<'a> TargetRuntime<'a> for SorobanTarget { function: FunctionValue<'a>, builtin_func: &Function, args: &[BasicMetadataValueEnum<'a>], - first_arg_type: BasicTypeEnum, + first_arg_type: Option, ns: &Namespace, ) -> Option> { unimplemented!() diff --git a/src/sema/builtin.rs b/src/sema/builtin.rs index 34baeecd3..cc3ee66ee 100644 --- a/src/sema/builtin.rs +++ b/src/sema/builtin.rs @@ -1826,6 +1826,33 @@ impl Namespace { }], self, ), + // caller_is_root API + Function::new( + loc, + loc, + pt::Identifier { + name: "caller_is_root".to_string(), + loc, + }, + None, + Vec::new(), + pt::FunctionTy::Function, + Some(pt::Mutability::View(loc)), + pt::Visibility::Public(Some(loc)), + vec![], + vec![Parameter { + loc, + id: Some(identifier("caller_is_root")), + ty: Type::Bool, + ty_loc: Some(loc), + readonly: false, + indexed: false, + infinite_size: false, + recursive: false, + annotation: None, + }], + self, + ), ] { func.has_body = true; let func_no = self.functions.len(); diff --git a/tests/lir_tests/convert_lir.rs b/tests/lir_tests/convert_lir.rs index 5e90155f9..63d171179 100644 --- a/tests/lir_tests/convert_lir.rs +++ b/tests/lir_tests/convert_lir.rs @@ -660,7 +660,7 @@ fn test_assertion_using_require() { assert_polkadot_lir_str_eq( src, 0, - r#"public function sol#3 Test::Test::function::test__int32 (int32): + r#"public function sol#4 Test::Test::function::test__int32 (int32): block#0 entry: int32 %num = int32(arg#0); bool %temp.ssa_ir.1 = int32(%num) > int32(10); @@ -690,7 +690,7 @@ fn test_call_1() { assert_polkadot_lir_str_eq( src, 0, - r#"public function sol#3 Test::Test::function::test__int32 (int32): + r#"public function sol#4 Test::Test::function::test__int32 (int32): block#0 entry: int32 %num = int32(arg#0); = call function#1(int32(%num)); @@ -754,7 +754,7 @@ fn test_value_transfer() { assert_polkadot_lir_str_eq( src, 0, - r#"public function sol#3 Test::Test::function::transfer__address_uint128 (uint8[32], uint128): + r#"public function sol#4 Test::Test::function::transfer__address_uint128 (uint8[32], uint128): block#0 entry: uint8[32] %addr = uint8[32](arg#0); uint128 %amount = uint128(arg#1); @@ -928,7 +928,7 @@ fn test_keccak256() { assert_polkadot_lir_str_eq( src, 0, - r#"public function sol#3 b::b::function::add__string_address (ptr>, uint8[32]): + r#"public function sol#4 b::b::function::add__string_address (ptr>, uint8[32]): block#0 entry: ptr> %name = ptr>(arg#0); uint8[32] %addr = uint8[32](arg#1); @@ -960,7 +960,7 @@ fn test_internal_function_cfg() { assert_polkadot_lir_str_eq( src, 1, - r#"public function sol#4 A::A::function::bar__uint256 (uint256) returns (uint256): + r#"public function sol#5 A::A::function::bar__uint256 (uint256) returns (uint256): block#0 entry: uint256 %b = uint256(arg#0); ptr %temp.ssa_ir.6 = function#0; @@ -1124,14 +1124,14 @@ fn test_constructor() { assert_polkadot_lir_str_eq( src, 0, - r#"public function sol#3 B::B::function::test__uint256 (uint256): + r#"public function sol#4 B::B::function::test__uint256 (uint256): block#0 entry: uint256 %a = uint256(arg#0); ptr> %abi_encoded.temp.18 = alloc ptr>[uint32(36)]; uint32 %temp.ssa_ir.20 = uint32 hex"58_16_c4_25"; write_buf ptr>(%abi_encoded.temp.18) offset:uint32(0) value:uint32(%temp.ssa_ir.20); write_buf ptr>(%abi_encoded.temp.18) offset:uint32(4) value:uint256(%a); - uint32 %success.temp.17, uint8[32] %temp.16 = constructor(no: 5, contract_no:1) salt:_ value:_ gas:uint64(0) address:_ seeds:_ encoded-buffer:ptr>(%abi_encoded.temp.18) accounts:absent + uint32 %success.temp.17, uint8[32] %temp.16 = constructor(no: 6, contract_no:1) salt:_ value:_ gas:uint64(0) address:_ seeds:_ encoded-buffer:ptr>(%abi_encoded.temp.18) accounts:absent switch uint32(%success.temp.17): case: uint32(0) => block#1, case: uint32(2) => block#2 diff --git a/tests/polkadot.rs b/tests/polkadot.rs index 159eeef32..068f879fd 100644 --- a/tests/polkadot.rs +++ b/tests/polkadot.rs @@ -357,6 +357,8 @@ fn read_hash(mem: &[u8], ptr: u32) -> Hash { /// Host functions mock the original implementation, refer to the [pallet docs][1] for more information. /// /// [1]: https://docs.rs/pallet-contracts/latest/pallet_contracts/api_doc/index.html +/// +/// Address `[0; u8]` is considered the root account. #[wasm_host] impl Runtime { #[seal(0)] @@ -787,6 +789,11 @@ impl Runtime { .into()) } + #[seal(0)] + fn caller_is_root() -> Result { + Ok((vm.accounts[vm.caller_account].address == [0; 32]).into()) + } + #[seal(0)] fn set_code_hash(code_hash_ptr: u32) -> Result { let hash = read_hash(mem, code_hash_ptr); @@ -818,6 +825,11 @@ impl MockSubstrate { Ok(()) } + /// Overwrites the address at asssociated `account` index with the given `address`. + pub fn set_account_address(&mut self, account: usize, address: [u8; 32]) { + self.0.data_mut().accounts[account].address = address; + } + /// Specify the caller account index for the next function or constructor call. pub fn set_account(&mut self, index: usize) { self.0.data_mut().account = index; diff --git a/tests/polkadot_tests/builtins.rs b/tests/polkadot_tests/builtins.rs index 7b01c6f61..78865b82a 100644 --- a/tests/polkadot_tests/builtins.rs +++ b/tests/polkadot_tests/builtins.rs @@ -845,3 +845,24 @@ fn set_code_hash() { runtime.function("count", vec![]); assert_eq!(runtime.output(), 1u32.encode()); } + +#[test] +fn caller_is_root() { + let mut runtime = build_solidity( + r#" + import { caller_is_root } from "polkadot"; + contract Test { + function test() public view returns (bool) { + return caller_is_root(); + } + }"#, + ); + + runtime.function("test", runtime.0.data().accounts[0].address.to_vec()); + assert_eq!(runtime.output(), false.encode()); + + // Set the caller address to [0; 32] which is the mock VM root account + runtime.set_account_address(0, [0; 32]); + runtime.function("test", [0; 32].to_vec()); + assert_eq!(runtime.output(), true.encode()); +}