diff --git a/process/smartContract/process.go b/process/smartContract/process.go index 72694c934e3..4c5856b95db 100644 --- a/process/smartContract/process.go +++ b/process/smartContract/process.go @@ -579,11 +579,15 @@ func (sc *scProcessor) ExecuteBuiltInFunction( return vmcommon.UserError, sc.ProcessIfError(acntSnd, txHash, tx, vmOutput.ReturnCode.String(), []byte(vmOutput.ReturnMessage), snapshot) } + createdAsyncCallback := false builtInFuncGasUsed := vmInput.GasProvided - vmOutput.GasRemaining scrResults := make([]data.TransactionHandler, 0, len(vmOutput.OutputAccounts)+1) outputAccounts := process.SortVMOutputInsideData(vmOutput) for _, outAcc := range outputAccounts { - scTxs := sc.createSmartContractResults(outAcc, tx, txHash) + tmpCreatedAsyncCallback, scTxs := sc.createSmartContractResults(vmOutput, vmInput.CallType, outAcc, tx, txHash) + if !createdAsyncCallback { + createdAsyncCallback = tmpCreatedAsyncCallback + } scrResults = append(scrResults, scTxs...) } @@ -599,10 +603,15 @@ func (sc *scProcessor) ExecuteBuiltInFunction( if isSCCall { outPutAccounts := process.SortVMOutputInsideData(newVMOutput) var newSCRTxs []data.TransactionHandler - newSCRTxs, err = sc.processSCOutputAccounts(outPutAccounts, tx, txHash) + tmpCreatedAsyncCallback := false + tmpCreatedAsyncCallback, newSCRTxs, err = sc.processSCOutputAccounts(newVMOutput, vmInput.CallType, outPutAccounts, tx, txHash) if err != nil { return 0, err } + if !createdAsyncCallback { + createdAsyncCallback = tmpCreatedAsyncCallback + } + scrResults = append(scrResults, newSCRTxs...) } @@ -611,6 +620,11 @@ func (sc *scProcessor) ExecuteBuiltInFunction( return 0, err } + if !createdAsyncCallback && vmInput.CallType == vmcommon.AsynchronousCall { + asyncCallBackSCR := createAsyncCallBackSCRFromVMOutput(newVMOutput, tx, txHash) + scrResults = append(scrResults, asyncCallBackSCR) + } + scrResults = append(scrResults, scrForSender) if !check.IfNil(scrForRelayer) { scrResults = append(scrResults, scrForRelayer) @@ -1102,7 +1116,7 @@ func (sc *scProcessor) processVMOutput( ) outPutAccounts := process.SortVMOutputInsideData(vmOutput) - scrTxs, err := sc.processSCOutputAccounts(outPutAccounts, tx, txHash) + createdAsyncCallback, scrTxs, err := sc.processSCOutputAccounts(vmOutput, callType, outPutAccounts, tx, txHash) if err != nil { return nil, err } @@ -1116,6 +1130,11 @@ func (sc *scProcessor) processVMOutput( } } + if !createdAsyncCallback && callType == vmcommon.AsynchronousCall { + asyncCallBackSCR := createAsyncCallBackSCRFromVMOutput(vmOutput, tx, txHash) + scrTxs = append(scrTxs, asyncCallBackSCR) + } + err = sc.addToBalanceIfInShard(scrForSender.RcvAddr, scrForSender.Value) if err != nil { return nil, err @@ -1315,12 +1334,55 @@ func createBaseSCR( return result } +func addVMOutputResultsToSCR(vmOutput *vmcommon.VMOutput, result *smartContractResult.SmartContractResult) { + result.CallType = vmcommon.AsynchronousCallBack + result.GasLimit = vmOutput.GasRemaining + result.Data = []byte("@" + core.ConvertToEvenHex(int(vmOutput.ReturnCode))) + addReturnDataToSCR(vmOutput, result) +} + +func createAsyncCallBackSCRFromVMOutput( + vmOutput *vmcommon.VMOutput, + tx data.TransactionHandler, + txHash []byte, +) *smartContractResult.SmartContractResult { + scr := &smartContractResult.SmartContractResult{ + Value: big.NewInt(0), + RcvAddr: tx.GetSndAddr(), + SndAddr: tx.GetRcvAddr(), + PrevTxHash: txHash, + GasPrice: tx.GetGasPrice(), + ReturnMessage: []byte(vmOutput.ReturnMessage), + OriginalSender: tx.GetSndAddr(), + } + setOriginalTxHash(scr, txHash, tx) + relayedTx, isRelayed := isRelayedTx(tx) + if isRelayed { + scr.RelayedValue = big.NewInt(0) + scr.RelayerAddr = relayedTx.RelayerAddr + } + + addVMOutputResultsToSCR(vmOutput, scr) + + return scr +} + func (sc *scProcessor) createSmartContractResults( + vmOutput *vmcommon.VMOutput, + callType vmcommon.CallType, outAcc *vmcommon.OutputAccount, tx data.TransactionHandler, txHash []byte, -) []data.TransactionHandler { - if len(outAcc.OutputTransfers) == 0 { +) (bool, []data.TransactionHandler) { + + lenOutTransfers := len(outAcc.OutputTransfers) + if lenOutTransfers == 0 { + if callType == vmcommon.AsynchronousCall && bytes.Equal(outAcc.Address, tx.GetSndAddr()) { + result := createBaseSCR(outAcc, tx, txHash) + addVMOutputResultsToSCR(vmOutput, result) + return true, []data.TransactionHandler{result} + } + if !sc.flagDeploy.IsSet() { result := createBaseSCR(outAcc, tx, txHash) result.Code = outAcc.Code @@ -1329,10 +1391,10 @@ func (sc *scProcessor) createSmartContractResults( result.OriginalSender = tx.GetSndAddr() } - return []data.TransactionHandler{result} + return false, []data.TransactionHandler{result} } - return nil + return false, nil } if bytes.Equal(outAcc.Address, vm.StakingSCAddress) { @@ -1340,28 +1402,55 @@ func (sc *scProcessor) createSmartContractResults( result := createBaseSCR(outAcc, tx, txHash) result.Data = append(result.Data, sc.argsParser.CreateDataFromStorageUpdate(storageUpdates)...) - return []data.TransactionHandler{result} + return false, []data.TransactionHandler{result} } + createdAsyncCallBack := false + var result *smartContractResult.SmartContractResult scResults := make([]data.TransactionHandler, 0, len(outAcc.OutputTransfers)) - for _, outPutTransfer := range outAcc.OutputTransfers { - result := createBaseSCR(outAcc, tx, txHash) + for i, outputTransfer := range outAcc.OutputTransfers { + result = createBaseSCR(outAcc, tx, txHash) - if outPutTransfer.Value != nil { - result.Value.Set(outPutTransfer.Value) + if outputTransfer.Value != nil { + result.Value.Set(outputTransfer.Value) } - result.Data = outPutTransfer.Data - result.GasLimit = outPutTransfer.GasLimit - result.CallType = outPutTransfer.CallType + result.Data = outputTransfer.Data + result.GasLimit = outputTransfer.GasLimit + result.CallType = outputTransfer.CallType setOriginalTxHash(result, txHash, tx) if result.Value.Cmp(zero) > 0 { result.OriginalSender = tx.GetSndAddr() } + isAsyncTransferBackToSender := callType == vmcommon.AsynchronousCall && + bytes.Equal(outAcc.Address, tx.GetSndAddr()) + isLastOutTransfer := i == lenOutTransfers-1 + if isLastOutTransfer && isAsyncTransferBackToSender && sc.isTransferWithNoDataOrBuiltInCall(outputTransfer.Data) { + addVMOutputResultsToSCR(vmOutput, result) + createdAsyncCallBack = true + } + scResults = append(scResults, result) } - return scResults + return createdAsyncCallBack, scResults +} + +func (sc *scProcessor) isTransferWithNoDataOrBuiltInCall(data []byte) bool { + if len(data) == 0 { + return true + } + function, _, err := sc.argsParser.ParseCallData(string(data)) + if err != nil { + return false + } + + _, err = sc.builtInFunctions.Get(function) + if err != nil { + return false + } + + return true } // createSCRForSender(vmOutput, tx, txHash, acntSnd) @@ -1376,20 +1465,22 @@ func (sc *scProcessor) createSCRForSenderAndRelayer( vmOutput.GasRefund = big.NewInt(0) } + gasRefund := core.SafeMul(vmOutput.GasRemaining, tx.GetGasPrice()) + gasRemaining := uint64(0) storageFreeRefund := big.NewInt(0) // backward compatibility - there should be no refund as the storage pay was already distributed among validators // this would only create additional inflation + // backward compatibility - direct smart contract results were created with gasLimit - there is no need for them if !sc.flagDeploy.IsSet() { storageFreeRefund = big.NewInt(0).Mul(vmOutput.GasRefund, big.NewInt(0).SetUint64(sc.economicsFee.MinGasPrice())) + gasRemaining = vmOutput.GasRemaining } - gasRefund := core.SafeMul(vmOutput.GasRemaining, tx.GetGasPrice()) rcvAddress := tx.GetSndAddr() if callType == vmcommon.AsynchronousCallBack { rcvAddress = tx.GetRcvAddr() } - gasRemaining := vmOutput.GasRemaining var refundGasToRelayerSCR *smartContractResult.SmartContractResult relayedSCR, isRelayed := isRelayedTx(tx) if isRelayed && callType != vmcommon.AsynchronousCall && gasRefund.Cmp(zero) > 0 { @@ -1421,41 +1512,50 @@ func (sc *scProcessor) createSCRForSenderAndRelayer( scTx.GasLimit = gasRemaining scTx.GasPrice = tx.GetGasPrice() scTx.ReturnMessage = []byte(vmOutput.ReturnMessage) + scTx.CallType = vmcommon.DirectCall setOriginalTxHash(scTx, txHash, tx) + scTx.Data = []byte("@" + hex.EncodeToString([]byte(vmOutput.ReturnCode.String()))) - if callType == vmcommon.AsynchronousCall { - scTx.CallType = vmcommon.AsynchronousCallBack - scTx.Data = []byte("@" + core.ConvertToEvenHex(int(vmOutput.ReturnCode))) - } else { - scTx.Data = []byte("@" + hex.EncodeToString([]byte(vmOutput.ReturnCode.String()))) + // when asynchronous call - the callback is created by combining the last output transfer with the returnData + if callType != vmcommon.AsynchronousCall { + addReturnDataToSCR(vmOutput, scTx) } + log.Trace("createSCRForSenderAndRelayer ", "data", string(scTx.Data), "snd", scTx.SndAddr, "rcv", scTx.RcvAddr) + return scTx, refundGasToRelayerSCR +} + +func addReturnDataToSCR(vmOutput *vmcommon.VMOutput, scTx *smartContractResult.SmartContractResult) { for _, retData := range vmOutput.ReturnData { scTx.Data = append(scTx.Data, []byte("@"+hex.EncodeToString(retData))...) } - - log.Trace("createSCRForSender ", "data", string(scTx.Data), "snd", scTx.SndAddr, "rcv", scTx.RcvAddr) - return scTx, refundGasToRelayerSCR } // save account changes in state from vmOutput - protected by VM - every output can be treated as is. func (sc *scProcessor) processSCOutputAccounts( + vmOutput *vmcommon.VMOutput, + callType vmcommon.CallType, outputAccounts []*vmcommon.OutputAccount, tx data.TransactionHandler, txHash []byte, -) ([]data.TransactionHandler, error) { +) (bool, []data.TransactionHandler, error) { scResults := make([]data.TransactionHandler, 0, len(outputAccounts)) sumOfAllDiff := big.NewInt(0) sumOfAllDiff.Sub(sumOfAllDiff, tx.GetValue()) + createdAsyncCallback := false for _, outAcc := range outputAccounts { acc, err := sc.getAccountFromAddress(outAcc.Address) if err != nil { - return nil, err + return false, nil, err + } + + tmpCreatedAsyncCallback, newScrs := sc.createSmartContractResults(vmOutput, callType, outAcc, tx, txHash) + if !createdAsyncCallback { + createdAsyncCallback = tmpCreatedAsyncCallback } - newScrs := sc.createSmartContractResults(outAcc, tx, txHash) scResults = append(scResults, newScrs...) if check.IfNil(acc) { if outAcc.BalanceDelta != nil { @@ -1475,7 +1575,7 @@ func (sc *scProcessor) processSCOutputAccounts( err = acc.DataTrieTracker().SaveKeyValue(storeUpdate.Offset, storeUpdate.Data) if err != nil { log.Warn("saveKeyValue", "error", err) - return nil, err + return false, nil, err } log.Trace("storeUpdate", "acc", outAcc.Address, "key", storeUpdate.Offset, "data", storeUpdate.Data) } @@ -1484,7 +1584,7 @@ func (sc *scProcessor) processSCOutputAccounts( // change nonce only if there is a change if outAcc.Nonce != acc.GetNonce() && outAcc.Nonce != 0 { if outAcc.Nonce < acc.GetNonce() { - return nil, process.ErrWrongNonceInVMOutput + return false, nil, process.ErrWrongNonceInVMOutput } nonceDifference := outAcc.Nonce - acc.GetNonce() @@ -1495,7 +1595,7 @@ func (sc *scProcessor) processSCOutputAccounts( if outAcc.BalanceDelta == nil || outAcc.BalanceDelta.Cmp(zero) == 0 { err = sc.accounts.SaveAccount(acc) if err != nil { - return nil, err + return false, nil, err } continue @@ -1505,20 +1605,20 @@ func (sc *scProcessor) processSCOutputAccounts( err = acc.AddToBalance(outAcc.BalanceDelta) if err != nil { - return nil, err + return false, nil, err } err = sc.accounts.SaveAccount(acc) if err != nil { - return nil, err + return false, nil, err } } if sumOfAllDiff.Cmp(zero) != 0 { - return nil, process.ErrOverallBalanceChangeFromSC + return false, nil, process.ErrOverallBalanceChangeFromSC } - return scResults, nil + return createdAsyncCallback, scResults, nil } // updateSmartContractCode upgrades code for "direct" deployments & upgrades and for "indirect" deployments & upgrades diff --git a/process/smartContract/process_test.go b/process/smartContract/process_test.go index c2320918af5..bc08428a801 100644 --- a/process/smartContract/process_test.go +++ b/process/smartContract/process_test.go @@ -1494,7 +1494,7 @@ func TestScProcessor_processSCOutputAccounts(t *testing.T) { tx := &transaction.Transaction{Value: big.NewInt(0)} outputAccounts := make([]*vmcommon.OutputAccount, 0) - _, err = sc.processSCOutputAccounts(outputAccounts, tx, []byte("hash")) + _, _, err = sc.processSCOutputAccounts(&vmcommon.VMOutput{}, vmcommon.DirectCall, outputAccounts, tx, []byte("hash")) require.Nil(t, err) outaddress := []byte("newsmartcontract") @@ -1520,13 +1520,13 @@ func TestScProcessor_processSCOutputAccounts(t *testing.T) { } tx.Value = big.NewInt(int64(5)) - _, err = sc.processSCOutputAccounts(outputAccounts, tx, []byte("hash")) + _, _, err = sc.processSCOutputAccounts(&vmcommon.VMOutput{}, vmcommon.DirectCall, outputAccounts, tx, []byte("hash")) require.Nil(t, err) outacc1.BalanceDelta = nil outacc1.Nonce++ tx.Value = big.NewInt(0) - _, err = sc.processSCOutputAccounts(outputAccounts, tx, []byte("hash")) + _, _, err = sc.processSCOutputAccounts(&vmcommon.VMOutput{}, vmcommon.DirectCall, outputAccounts, tx, []byte("hash")) require.Nil(t, err) outacc1.Nonce++ @@ -1535,7 +1535,7 @@ func TestScProcessor_processSCOutputAccounts(t *testing.T) { currentBalance := testAcc.Balance.Uint64() vmOutBalance := outacc1.BalanceDelta.Uint64() - _, err = sc.processSCOutputAccounts(outputAccounts, tx, []byte("hash")) + _, _, err = sc.processSCOutputAccounts(&vmcommon.VMOutput{}, vmcommon.DirectCall, outputAccounts, tx, []byte("hash")) require.Nil(t, err) require.Equal(t, currentBalance+vmOutBalance, testAcc.Balance.Uint64()) } @@ -1554,7 +1554,7 @@ func TestScProcessor_processSCOutputAccountsNotInShard(t *testing.T) { tx := &transaction.Transaction{Value: big.NewInt(0)} outputAccounts := make([]*vmcommon.OutputAccount, 0) - _, err = sc.processSCOutputAccounts(outputAccounts, tx, []byte("hash")) + _, _, err = sc.processSCOutputAccounts(&vmcommon.VMOutput{}, vmcommon.DirectCall, outputAccounts, tx, []byte("hash")) require.Nil(t, err) outaddress := []byte("newsmartcontract") @@ -1568,7 +1568,7 @@ func TestScProcessor_processSCOutputAccountsNotInShard(t *testing.T) { return shardCoordinator.SelfId() + 1 } - _, err = sc.processSCOutputAccounts(outputAccounts, tx, []byte("hash")) + _, _, err = sc.processSCOutputAccounts(&vmcommon.VMOutput{}, vmcommon.DirectCall, outputAccounts, tx, []byte("hash")) require.Nil(t, err) } @@ -1613,9 +1613,81 @@ func TestScProcessor_CreateCrossShardTransactions(t *testing.T) { tx.GasLimit = 15 txHash := []byte("txHash") - scTxs, err := sc.processSCOutputAccounts(outputAccounts, tx, txHash) + createdAsyncSCR, scTxs, err := sc.processSCOutputAccounts(&vmcommon.VMOutput{}, vmcommon.DirectCall, outputAccounts, tx, txHash) require.Nil(t, err) require.Equal(t, len(outputAccounts), len(scTxs)) + require.False(t, createdAsyncSCR) +} + +func TestScProcessor_CreateCrossShardTransactionsWithAsyncCalls(t *testing.T) { + t.Parallel() + + testAccounts, _ := state.NewUserAccount([]byte("address")) + accountsDB := &mock.AccountsStub{ + LoadAccountCalled: func(address []byte) (handler state.AccountHandler, err error) { + return testAccounts, nil + }, + SaveAccountCalled: func(accountHandler state.AccountHandler) error { + return nil + }, + } + shardCoordinator := mock.NewMultiShardsCoordinatorMock(5) + arguments := createMockSmartContractProcessorArguments() + arguments.AccountsDB = accountsDB + arguments.Coordinator = shardCoordinator + sc, err := NewSmartContractProcessor(arguments) + require.NotNil(t, sc) + require.Nil(t, err) + + outputAccounts := make([]*vmcommon.OutputAccount, 0) + outaddress := []byte("newsmartcontract") + outacc1 := &vmcommon.OutputAccount{} + outacc1.Address = outaddress + outacc1.Nonce = 0 + outacc1.Balance = big.NewInt(5) + outacc1.BalanceDelta = big.NewInt(15) + outTransfer := vmcommon.OutputTransfer{Value: big.NewInt(5)} + outacc1.OutputTransfers = append(outacc1.OutputTransfers, outTransfer) + outputAccounts = append(outputAccounts, outacc1, outacc1, outacc1) + + tx := &transaction.Transaction{} + tx.Nonce = 1 + tx.SndAddr = []byte("SRC") + tx.RcvAddr = []byte("DST") + + tx.Value = big.NewInt(45) + tx.GasPrice = 10 + tx.GasLimit = 15 + txHash := []byte("txHash") + + createdAsyncSCR, scTxs, err := sc.processSCOutputAccounts(&vmcommon.VMOutput{GasRemaining: 1000}, vmcommon.AsynchronousCall, outputAccounts, tx, txHash) + require.Nil(t, err) + require.Equal(t, len(outputAccounts), len(scTxs)) + require.False(t, createdAsyncSCR) + + outAccBackTransfer := &vmcommon.OutputAccount{ + Address: tx.SndAddr, + Nonce: 0, + Balance: big.NewInt(0), + BalanceDelta: big.NewInt(0), + OutputTransfers: []vmcommon.OutputTransfer{outTransfer}, + GasUsed: 0, + } + outputAccounts = append(outputAccounts, outAccBackTransfer) + createdAsyncSCR, scTxs, err = sc.processSCOutputAccounts(&vmcommon.VMOutput{GasRemaining: 1000}, vmcommon.AsynchronousCall, outputAccounts, tx, txHash) + require.Nil(t, err) + require.Equal(t, len(outputAccounts), len(scTxs)) + require.True(t, createdAsyncSCR) + + lastScTx := scTxs[len(scTxs)-1].(*smartContractResult.SmartContractResult) + require.Equal(t, vmcommon.AsynchronousCallBack, lastScTx.CallType) + + tx.Value = big.NewInt(0) + scTxs, err = sc.processVMOutput(&vmcommon.VMOutput{GasRemaining: 1000}, txHash, tx, vmcommon.AsynchronousCall, 10000) + require.Nil(t, err) + require.Equal(t, 2, len(scTxs)) + lastScTx = scTxs[len(scTxs)-1].(*smartContractResult.SmartContractResult) + require.Equal(t, vmcommon.AsynchronousCallBack, lastScTx.CallType) } func TestScProcessor_ProcessSmartContractResultNilScr(t *testing.T) {