Skip to content

Commit

Permalink
Polkadot: Implement caller_is_root runtime API (#1620)
Browse files Browse the repository at this point in the history
Recent versions of the contracts pallet expose a runtime API for
contracts to check whether their caller is of root origin. The PR
exposes this API as a builtin.
  • Loading branch information
xermicus committed Feb 1, 2024
1 parent b6b2fc9 commit cda387d
Show file tree
Hide file tree
Showing 13 changed files with 156 additions and 15 deletions.
5 changes: 5 additions & 0 deletions docs/language/builtins.rst
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,11 @@ is_contract(address AccountId) returns (bool)

Only available on Polkadot. Checks whether the given address is a contract address.

caller_is_root() returns (bool)
+++++++++++++++++++++++++++++++

Only available on Polkadot. Returns true if the caller of the contract is `root <https://docs.substrate.io/build/origins/>`_.

set_code_hash(uint8[32] hash) returns (uint32)
++++++++++++++++++++++++++++++++++++++++++++++

Expand Down
14 changes: 14 additions & 0 deletions integration/polkadot/caller_is_root.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import "polkadot";

contract CallerIsRoot {
uint public balance;

function covert() public payable {
if (caller_is_root()) {
balance = 0xdeadbeef;
} else {
print("burn more gas");
balance = 1;
}
}
}
42 changes: 42 additions & 0 deletions integration/polkadot/caller_is_root.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import expect from 'expect';
import { createConnection, deploy, aliceKeypair, query, weight, transaction } from './index';
import { ContractPromise } from '@polkadot/api-contract';
import { ApiPromise } from '@polkadot/api';
import { KeyringPair } from '@polkadot/keyring/types';

describe('Deploy the caller_is_root contract and test it', () => {
let conn: ApiPromise;
let contract: ContractPromise;
let alice: KeyringPair;

before(async function () {
conn = await createConnection();
alice = aliceKeypair();
const instance = await deploy(conn, alice, 'CallerIsRoot.contract', 0n);
contract = new ContractPromise(conn, instance.abi, instance.address);
});

after(async function () {
await conn.disconnect();
});

it('is correct on a non-root caller', async function () {
// Without sudo the caller should not be root
const gasLimit = await weight(conn, contract, "covert");
await transaction(contract.tx.covert({ gasLimit }), alice);

// Calling `covert` as non-root sets the balance to 1
const balance = await query(conn, alice, contract, "balance", []);
expect(BigInt(balance.output?.toString() ?? "")).toStrictEqual(1n);
});

it('is correct on a root caller', async function () {
// Alice has sudo rights on --dev nodes
const gasLimit = await weight(conn, contract, "covert");
await transaction(conn.tx.sudo.sudo(contract.tx.covert({ gasLimit })), alice);

// Calling `covert` as root sets the balance to 0xdeadbeef
const balance = await query(conn, alice, contract, "balance", []);
expect(BigInt(balance.output?.toString() ?? "")).toStrictEqual(0xdeadbeefn);
});
});
12 changes: 8 additions & 4 deletions src/emit/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,10 +519,14 @@ pub(super) fn process_instruction<'a, T: TargetRuntime<'a> + ?Sized>(
}
}

let first_arg_type = bin.llvm_type(&args[0].ty(), ns);
if let Some(ret) =
target.builtin_function(bin, function, callee, &parms, first_arg_type, ns)
{
if let Some(ret) = target.builtin_function(
bin,
function,
callee,
&parms,
args.first().map(|arg| bin.llvm_type(&arg.ty(), ns)),
ns,
) {
let success = bin.builder.build_int_compare(
IntPredicate::EQ,
ret.into_int_value(),
Expand Down
2 changes: 1 addition & 1 deletion src/emit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ pub trait TargetRuntime<'a> {
function: FunctionValue<'a>,
builtin_func: &Function,
args: &[BasicMetadataValueEnum<'a>],
first_arg_type: BasicTypeEnum,
first_arg_type: Option<BasicTypeEnum>,
ns: &Namespace,
) -> Option<BasicValueEnum<'a>>;

Expand Down
2 changes: 2 additions & 0 deletions src/emit/polkadot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ impl PolkadotTarget {
"transfer",
"is_contract",
"set_code_hash",
"caller_is_root",
]);

binary
Expand Down Expand Up @@ -266,6 +267,7 @@ impl PolkadotTarget {
external!("deposit_event", void_type, u8_ptr, u32_val, u8_ptr, u32_val);
external!("is_contract", i32_type, u8_ptr);
external!("set_code_hash", i32_type, u8_ptr);
external!("caller_is_root", i32_type,);
}

/// Emits the "deploy" function if `storage_initializer` is `Some`, otherwise emits the "call" function.
Expand Down
13 changes: 12 additions & 1 deletion src/emit/polkadot/target.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1501,7 +1501,7 @@ impl<'a> TargetRuntime<'a> for PolkadotTarget {
_function: FunctionValue<'a>,
builtin_func: &Function,
args: &[BasicMetadataValueEnum<'a>],
_first_arg_type: BasicTypeEnum,
_first_arg_type: Option<BasicTypeEnum>,
ns: &Namespace,
) -> Option<BasicValueEnum<'a>> {
emit_context!(binary);
Expand Down Expand Up @@ -1579,6 +1579,17 @@ impl<'a> TargetRuntime<'a> for PolkadotTarget {
.build_store(args[1].into_pointer_value(), ret);
None
}
"caller_is_root" => {
let is_root = call!("caller_is_root", &[], "seal_caller_is_root")
.try_as_basic_value()
.left()
.unwrap()
.into_int_value();
binary
.builder
.build_store(args[0].into_pointer_value(), is_root);
None
}
_ => unimplemented!(),
}
}
Expand Down
5 changes: 4 additions & 1 deletion src/emit/solana/target.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1251,9 +1251,12 @@ impl<'a> TargetRuntime<'a> for SolanaTarget {
function: FunctionValue<'a>,
builtin_func: &ast::Function,
args: &[BasicMetadataValueEnum<'a>],
first_arg_type: BasicTypeEnum,
first_arg_type: Option<BasicTypeEnum>,
ns: &ast::Namespace,
) -> Option<BasicValueEnum<'a>> {
let first_arg_type =
first_arg_type.expect("solana does not have builtin without any parameter");

if builtin_func.id.name == "create_program_address" {
let func = binary
.module
Expand Down
2 changes: 1 addition & 1 deletion src/emit/soroban/target.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ impl<'a> TargetRuntime<'a> for SorobanTarget {
function: FunctionValue<'a>,
builtin_func: &Function,
args: &[BasicMetadataValueEnum<'a>],
first_arg_type: BasicTypeEnum,
first_arg_type: Option<BasicTypeEnum>,
ns: &Namespace,
) -> Option<BasicValueEnum<'a>> {
unimplemented!()
Expand Down
27 changes: 27 additions & 0 deletions src/sema/builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1826,6 +1826,33 @@ impl Namespace {
}],
self,
),
// caller_is_root API
Function::new(
loc,
loc,
pt::Identifier {
name: "caller_is_root".to_string(),
loc,
},
None,
Vec::new(),
pt::FunctionTy::Function,
Some(pt::Mutability::View(loc)),
pt::Visibility::Public(Some(loc)),
vec![],
vec![Parameter {
loc,
id: Some(identifier("caller_is_root")),
ty: Type::Bool,
ty_loc: Some(loc),
readonly: false,
indexed: false,
infinite_size: false,
recursive: false,
annotation: None,
}],
self,
),
] {
func.has_body = true;
let func_no = self.functions.len();
Expand Down
14 changes: 7 additions & 7 deletions tests/lir_tests/convert_lir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ fn test_assertion_using_require() {
assert_polkadot_lir_str_eq(
src,
0,
r#"public function sol#3 Test::Test::function::test__int32 (int32):
r#"public function sol#4 Test::Test::function::test__int32 (int32):
block#0 entry:
int32 %num = int32(arg#0);
bool %temp.ssa_ir.1 = int32(%num) > int32(10);
Expand Down Expand Up @@ -690,7 +690,7 @@ fn test_call_1() {
assert_polkadot_lir_str_eq(
src,
0,
r#"public function sol#3 Test::Test::function::test__int32 (int32):
r#"public function sol#4 Test::Test::function::test__int32 (int32):
block#0 entry:
int32 %num = int32(arg#0);
= call function#1(int32(%num));
Expand Down Expand Up @@ -754,7 +754,7 @@ fn test_value_transfer() {
assert_polkadot_lir_str_eq(
src,
0,
r#"public function sol#3 Test::Test::function::transfer__address_uint128 (uint8[32], uint128):
r#"public function sol#4 Test::Test::function::transfer__address_uint128 (uint8[32], uint128):
block#0 entry:
uint8[32] %addr = uint8[32](arg#0);
uint128 %amount = uint128(arg#1);
Expand Down Expand Up @@ -928,7 +928,7 @@ fn test_keccak256() {
assert_polkadot_lir_str_eq(
src,
0,
r#"public function sol#3 b::b::function::add__string_address (ptr<struct.vector<uint8>>, uint8[32]):
r#"public function sol#4 b::b::function::add__string_address (ptr<struct.vector<uint8>>, uint8[32]):
block#0 entry:
ptr<struct.vector<uint8>> %name = ptr<struct.vector<uint8>>(arg#0);
uint8[32] %addr = uint8[32](arg#1);
Expand Down Expand Up @@ -960,7 +960,7 @@ fn test_internal_function_cfg() {
assert_polkadot_lir_str_eq(
src,
1,
r#"public function sol#4 A::A::function::bar__uint256 (uint256) returns (uint256):
r#"public function sol#5 A::A::function::bar__uint256 (uint256) returns (uint256):
block#0 entry:
uint256 %b = uint256(arg#0);
ptr<function (uint256) returns (uint256)> %temp.ssa_ir.6 = function#0;
Expand Down Expand Up @@ -1124,14 +1124,14 @@ fn test_constructor() {
assert_polkadot_lir_str_eq(
src,
0,
r#"public function sol#3 B::B::function::test__uint256 (uint256):
r#"public function sol#4 B::B::function::test__uint256 (uint256):
block#0 entry:
uint256 %a = uint256(arg#0);
ptr<struct.vector<uint8>> %abi_encoded.temp.18 = alloc ptr<struct.vector<uint8>>[uint32(36)];
uint32 %temp.ssa_ir.20 = uint32 hex"58_16_c4_25";
write_buf ptr<struct.vector<uint8>>(%abi_encoded.temp.18) offset:uint32(0) value:uint32(%temp.ssa_ir.20);
write_buf ptr<struct.vector<uint8>>(%abi_encoded.temp.18) offset:uint32(4) value:uint256(%a);
uint32 %success.temp.17, uint8[32] %temp.16 = constructor(no: 5, contract_no:1) salt:_ value:_ gas:uint64(0) address:_ seeds:_ encoded-buffer:ptr<struct.vector<uint8>>(%abi_encoded.temp.18) accounts:absent
uint32 %success.temp.17, uint8[32] %temp.16 = constructor(no: 6, contract_no:1) salt:_ value:_ gas:uint64(0) address:_ seeds:_ encoded-buffer:ptr<struct.vector<uint8>>(%abi_encoded.temp.18) accounts:absent
switch uint32(%success.temp.17):
case: uint32(0) => block#1,
case: uint32(2) => block#2
Expand Down
12 changes: 12 additions & 0 deletions tests/polkadot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ fn read_hash(mem: &[u8], ptr: u32) -> Hash {
/// Host functions mock the original implementation, refer to the [pallet docs][1] for more information.
///
/// [1]: https://docs.rs/pallet-contracts/latest/pallet_contracts/api_doc/index.html
///
/// Address `[0; u8]` is considered the root account.
#[wasm_host]
impl Runtime {
#[seal(0)]
Expand Down Expand Up @@ -787,6 +789,11 @@ impl Runtime {
.into())
}

#[seal(0)]
fn caller_is_root() -> Result<u32, Trap> {
Ok((vm.accounts[vm.caller_account].address == [0; 32]).into())
}

#[seal(0)]
fn set_code_hash(code_hash_ptr: u32) -> Result<u32, Trap> {
let hash = read_hash(mem, code_hash_ptr);
Expand Down Expand Up @@ -818,6 +825,11 @@ impl MockSubstrate {
Ok(())
}

/// Overwrites the address at asssociated `account` index with the given `address`.
pub fn set_account_address(&mut self, account: usize, address: [u8; 32]) {
self.0.data_mut().accounts[account].address = address;
}

/// Specify the caller account index for the next function or constructor call.
pub fn set_account(&mut self, index: usize) {
self.0.data_mut().account = index;
Expand Down
21 changes: 21 additions & 0 deletions tests/polkadot_tests/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -845,3 +845,24 @@ fn set_code_hash() {
runtime.function("count", vec![]);
assert_eq!(runtime.output(), 1u32.encode());
}

#[test]
fn caller_is_root() {
let mut runtime = build_solidity(
r#"
import { caller_is_root } from "polkadot";
contract Test {
function test() public view returns (bool) {
return caller_is_root();
}
}"#,
);

runtime.function("test", runtime.0.data().accounts[0].address.to_vec());
assert_eq!(runtime.output(), false.encode());

// Set the caller address to [0; 32] which is the mock VM root account
runtime.set_account_address(0, [0; 32]);
runtime.function("test", [0; 32].to_vec());
assert_eq!(runtime.output(), true.encode());
}

0 comments on commit cda387d

Please sign in to comment.