Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate debitGasFees execution in tx pool #2279

Merged
merged 6 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions contracts/erc20gas/erc20gas.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/celo-org/celo-blockchain/common"
"github.com/celo-org/celo-blockchain/common/hexutil"
"github.com/celo-org/celo-blockchain/contracts/internal/n"
"github.com/celo-org/celo-blockchain/core/types"
"github.com/celo-org/celo-blockchain/core/vm"
"github.com/celo-org/celo-blockchain/log"
)
Expand All @@ -17,14 +18,34 @@ const (
maxGasForCreditGasFeesTransactions uint64 = 1 * n.Million
)

var (
// Recalculate with `cast sig 'debitGasFees(address from, uint256 value)'`
debitGasFeesSelector = hexutil.MustDecode("0x58cf9672")
karlb marked this conversation as resolved.
Show resolved Hide resolved
// Recalculate with `cast sig 'creditGasFees(address,address,address,address,uint256,uint256,uint256,uint256)'`
creditGasFeesSelector = hexutil.MustDecode("0x6a30b253")
)

// Returns nil if debit is possible, used in tx pool validation
func TryDebitFees(tx *types.Transaction, from common.Address, currentVMRunner vm.EVMRunner) error {
cost := new(big.Int).SetUint64(tx.Gas())
cost.Mul(cost, tx.GasFeeCap())

// The following code is similar to DebitFees, but that function does not work on a vm.EVMRunner,
// so we have to adapt it instead of reusing.
transactionData := common.GetEncodedAbi(debitGasFeesSelector, [][]byte{common.AddressToAbi(from), common.AmountToAbi(cost)})

ret, err := currentVMRunner.ExecuteAndDiscardChanges(*tx.FeeCurrency(), transactionData, maxGasForDebitGasFeesTransactions, common.Big0)
if err != nil {
revertReason, err2 := abi.UnpackRevert(ret)
if err2 == nil {
return fmt.Errorf("TryDebitFees reverted: %s", revertReason)
}
}
return err
}

func DebitFees(evm *vm.EVM, address common.Address, amount *big.Int, feeCurrency *common.Address) error {
// Function is "debitGasFees(address from, uint256 value)"
// selector is first 4 bytes of keccak256 of "debitGasFees(address,uint256)"
// Source:
// pip3 install pyethereum
// python3 -c 'from ethereum.utils import sha3; print(sha3("debitGasFees(address,uint256)")[0:4].hex())'
functionSelector := hexutil.MustDecode("0x58cf9672")
transactionData := common.GetEncodedAbi(functionSelector, [][]byte{common.AddressToAbi(address), common.AmountToAbi(amount)})
transactionData := common.GetEncodedAbi(debitGasFeesSelector, [][]byte{common.AddressToAbi(address), common.AmountToAbi(amount)})

// Run only primary evm.Call() with tracer
if evm.GetDebug() {
Expand Down Expand Up @@ -57,9 +78,7 @@ func CreditFees(
gatewayFee *big.Int,
baseTxFee *big.Int,
feeCurrency *common.Address) error {
// Function is "creditGasFees(address,address,address,address,uint256,uint256,uint256,uint256)"
functionSelector := hexutil.MustDecode("0x6a30b253")
transactionData := common.GetEncodedAbi(functionSelector, [][]byte{common.AddressToAbi(from), common.AddressToAbi(feeRecipient), common.AddressToAbi(*gatewayFeeRecipient), common.AddressToAbi(feeHandler), common.AmountToAbi(refund), common.AmountToAbi(tipTxFee), common.AmountToAbi(gatewayFee), common.AmountToAbi(baseTxFee)})
transactionData := common.GetEncodedAbi(creditGasFeesSelector, [][]byte{common.AddressToAbi(from), common.AddressToAbi(feeRecipient), common.AddressToAbi(*gatewayFeeRecipient), common.AddressToAbi(feeHandler), common.AmountToAbi(refund), common.AmountToAbi(tipTxFee), common.AmountToAbi(gatewayFee), common.AmountToAbi(baseTxFee)})

// Run only primary evm.Call() with tracer
if evm.GetDebug() {
Expand Down
4 changes: 4 additions & 0 deletions contracts/testutil/fail_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ func (fvm FailingVmRunner) ExecuteFrom(sender, recipient common.Address, input [
return nil, ErrFailingRunner
}

func (fvm FailingVmRunner) ExecuteAndDiscardChanges(recipient common.Address, input []byte, gas uint64, value *big.Int) (ret []byte, err error) {
return nil, ErrFailingRunner
}

func (fvm FailingVmRunner) Query(recipient common.Address, input []byte, gas uint64) (ret []byte, err error) {
return nil, ErrFailingRunner
}
Expand Down
4 changes: 4 additions & 0 deletions contracts/testutil/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ func (ev *MockEVMRunner) ExecuteFrom(sender, recipient common.Address, input []b
return ev.Execute(recipient, input, gas, value)
}

func (ev *MockEVMRunner) ExecuteAndDiscardChanges(recipient common.Address, input []byte, gas uint64, value *big.Int) (ret []byte, err error) {
return ev.Execute(recipient, input, gas, value)
}

func (ev *MockEVMRunner) Query(recipient common.Address, input []byte, gas uint64) (ret []byte, err error) {
mock, ok := ev.contracts[recipient]
if !ok {
Expand Down
73 changes: 19 additions & 54 deletions core/tx_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/celo-org/celo-blockchain/consensus"
"github.com/celo-org/celo-blockchain/contracts/blockchain_parameters"
"github.com/celo-org/celo-blockchain/contracts/currency"
"github.com/celo-org/celo-blockchain/contracts/erc20gas"
gpm "github.com/celo-org/celo-blockchain/contracts/gasprice_minimum"
"github.com/celo-org/celo-blockchain/core/state"
"github.com/celo-org/celo-blockchain/core/types"
Expand Down Expand Up @@ -1717,63 +1718,27 @@ func (pool *TxPool) demoteUnexecutables() {
// ValidateTransactorBalanceCoversTx validates transactor has enough funds to cover transaction cost, the rules are consistent with state_transition.
//
// For native token(CELO) as feeCurrency:
// - Pre-Espresso: it ensures balance >= GasPrice * gas + value + gatewayFee (1)
// - Post-Espresso: it ensures balance >= GasFeeCap * gas + value + gatewayFee (2)
// - it ensures balance >= GasFeeCap * gas + value
//
// For non-native tokens(cUSD, cEUR, ...) as feeCurrency:
// - Pre-Espresso: it ensures balance > GasPrice * gas + gatewayFee (3)
// - Post-Espresso: it ensures balance >= GasFeeCap * gas + gatewayFee (4)
// - It executes a static call on debitGasFees, implicitly ensuring balance >= GasFeeCap * gas and that `from` is not on the token's block list
func ValidateTransactorBalanceCoversTx(tx *types.Transaction, from common.Address, currentState *state.StateDB, currentVMRunner vm.EVMRunner, espresso bool) error {
if tx.FeeCurrency() == nil {
balance := currentState.GetBalance(from)
var cost *big.Int

if espresso {
// cost = GasFeeCap * gas + value + gatewayFee, as in (2)
cost = new(big.Int).SetUint64(tx.Gas())
cost.Mul(cost, tx.GasFeeCap())
cost.Add(cost, tx.Value())
if tx.GatewayFeeRecipient() != nil {
cost.Add(cost, tx.GatewayFee())
}
} else {
// cost = GasPrice * gas + value + gatewayFee, as in (1)
cost = tx.Cost()
}

if balance.Cmp(cost) < 0 {
log.Debug("ValidateTransactorBalanceCoversTx: insufficient CELO funds",
"from", from, "Transaction cost", cost, "to", tx.To(),
"gas", tx.Gas(), "gas price", tx.GasPrice(), "nonce", tx.Nonce(),
"value", tx.Value(), "fee currency", tx.FeeCurrency(), "balance", balance)
return ErrInsufficientFunds
}
} else {
balance, err := currency.GetBalanceOf(currentVMRunner, from, *tx.FeeCurrency())
if err != nil {
log.Debug("ValidateTransactorBalanceCoversTx: error in getting fee currency balance", "feeCurrency", tx.FeeCurrency())
return err
}

if espresso {
// cost = GasFeeCap * gas + gatewayFee, as in (4)
cost := new(big.Int).SetUint64(tx.Gas())
cost.Mul(cost, tx.GasFeeCap())
if tx.GatewayFeeRecipient() != nil {
cost.Add(cost, tx.GatewayFee())
}
if balance.Cmp(cost) < 0 {
log.Debug("ValidateTransactorBalanceCoversTx: insufficient funds", "feeCurrency", tx.FeeCurrency(), "balance", balance)
return ErrInsufficientFunds
}
} else {
// cost = GasPrice * gas + gatewayFee, as in (3)
cost := tx.Fee()
if balance.Cmp(cost) <= 0 {
log.Debug("ValidateTransactorBalanceCoversTx: insufficient funds", "feeCurrency", tx.FeeCurrency(), "balance", balance)
return ErrInsufficientFunds
}
}
if tx.FeeCurrency() != nil {
return erc20gas.TryDebitFees(tx, from, currentVMRunner)
Copy link
Contributor

@piersy piersy Mar 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than having an else statement we could simplify this function like this:

if tx.FeeCurrency() != nil {
  return erc20gas.TryDebitFees(tx, from, currentVMRunner)
}
...

}
balance := currentState.GetBalance(from)

// cost = GasFeeCap * gas + value
cost := new(big.Int).SetUint64(tx.Gas())
cost.Mul(cost, tx.GasFeeCap())
cost.Add(cost, tx.Value())

if balance.Cmp(cost) < 0 {
log.Debug("ValidateTransactorBalanceCoversTx: insufficient CELO funds",
"from", from, "Transaction cost", cost, "to", tx.To(),
"gas", tx.Gas(), "gas price", tx.GasPrice(), "nonce", tx.Nonce(),
"value", tx.Value(), "fee currency", tx.FeeCurrency(), "balance", balance)
return ErrInsufficientFunds
}

return nil
Expand Down
3 changes: 3 additions & 0 deletions core/vm/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ type EVMRunner interface {
// originally used the transaction's sender instead of the zero address.
ExecuteFrom(sender, recipient common.Address, input []byte, gas uint64, value *big.Int) (ret []byte, err error)

// Execute input after creating a snapshot and revert to that snapshot afterwards.
ExecuteAndDiscardChanges(recipient common.Address, input []byte, gas uint64, value *big.Int) (ret []byte, err error)

// Query performs a read operation over the runner's state
// It can be seen as a message (input,value) from sender to recipient that returns `ret`
Query(recipient common.Address, input []byte, gas uint64) (ret []byte, err error)
Expand Down
18 changes: 18 additions & 0 deletions core/vm/vmcontext/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ func (ev *evmRunner) ExecuteFrom(sender, recipient common.Address, input []byte,
return ret, err
}

func (ev *evmRunner) ExecuteAndDiscardChanges(recipient common.Address, input []byte, gas uint64, value *big.Int) (ret []byte, err error) {
evm := ev.newEVM(VMAddress)
var snapshot = evm.StateDB.Snapshot()
if ev.dontMeterGas {
evm.StopGasMetering()
}
ret, _, err = evm.Call(vm.AccountRef(evm.Origin), recipient, input, gas, value)
evm.StateDB.RevertToSnapshot(snapshot)
return ret, err
}

func (ev *evmRunner) Query(recipient common.Address, input []byte, gas uint64) (ret []byte, err error) {
evm := ev.newEVM(VMAddress)
if ev.dontMeterGas {
Expand Down Expand Up @@ -98,6 +109,13 @@ func (sev *SharedEVMRunner) ExecuteFrom(sender, recipient common.Address, input
return ret, err
}

func (sev *SharedEVMRunner) ExecuteAndDiscardChanges(recipient common.Address, input []byte, gas uint64, value *big.Int) (ret []byte, err error) {
var snapshot = sev.StateDB.Snapshot()
ret, _, err = sev.Call(vm.AccountRef(VMAddress), recipient, input, gas, value)
sev.StateDB.RevertToSnapshot(snapshot)
return ret, err
}

func (sev *SharedEVMRunner) Query(recipient common.Address, input []byte, gas uint64) (ret []byte, err error) {
ret, _, err = sev.StaticCall(vm.AccountRef(VMAddress), recipient, input, gas)
return ret, err
Expand Down
Loading