diff --git a/neo.UnitTests/Ledger/UT_MemoryPool.cs b/neo.UnitTests/Ledger/UT_MemoryPool.cs index 27c4d37715..6f9022bef1 100644 --- a/neo.UnitTests/Ledger/UT_MemoryPool.cs +++ b/neo.UnitTests/Ledger/UT_MemoryPool.cs @@ -13,6 +13,7 @@ using System.Collections; using System.Collections.Generic; using System.Linq; +using System.Numerics; namespace Neo.UnitTests.Ledger { @@ -73,8 +74,8 @@ private Transaction CreateTransactionWithFee(long fee) var randomBytes = new byte[16]; random.NextBytes(randomBytes); Mock mock = new Mock(); - mock.Setup(p => p.Reverify(It.IsAny(), It.IsAny>())).Returns(true); - mock.Setup(p => p.Verify(It.IsAny(), It.IsAny>())).Returns(true); + mock.Setup(p => p.Reverify(It.IsAny(), It.IsAny())).Returns(true); + mock.Setup(p => p.Verify(It.IsAny(), It.IsAny())).Returns(true); mock.Object.Script = randomBytes; mock.Object.Sender = UInt160.Zero; mock.Object.NetworkFee = fee; diff --git a/neo.UnitTests/Ledger/UT_SendersFeeMonitor.cs b/neo.UnitTests/Ledger/UT_SendersFeeMonitor.cs new file mode 100644 index 0000000000..81fb0d9354 --- /dev/null +++ b/neo.UnitTests/Ledger/UT_SendersFeeMonitor.cs @@ -0,0 +1,56 @@ +using FluentAssertions; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; +using Neo.Ledger; +using Neo.Network.P2P.Payloads; +using Neo.Persistence; +using System; +using System.Numerics; + +namespace Neo.UnitTests.Ledger +{ + [TestClass] + public class UT_SendersFeeMonitor + { + private Transaction CreateTransactionWithFee(long networkFee, long systemFee) + { + Random random = new Random(); + var randomBytes = new byte[16]; + random.NextBytes(randomBytes); + Mock mock = new Mock(); + mock.Setup(p => p.Reverify(It.IsAny(), It.IsAny())).Returns(true); + mock.Setup(p => p.Verify(It.IsAny(), It.IsAny())).Returns(true); + mock.Object.Script = randomBytes; + mock.Object.Sender = UInt160.Zero; + mock.Object.NetworkFee = networkFee; + mock.Object.SystemFee = systemFee; + mock.Object.Attributes = new TransactionAttribute[0]; + mock.Object.Cosigners = new Cosigner[0]; + mock.Object.Witnesses = new[] + { + new Witness + { + InvocationScript = new byte[0], + VerificationScript = new byte[0] + } + }; + return mock.Object; + } + + [TestMethod] + public void TestMemPoolSenderFee() + { + Transaction transaction = CreateTransactionWithFee(1, 2); + SendersFeeMonitor sendersFeeMonitor = new SendersFeeMonitor(); + sendersFeeMonitor.GetSenderFee(transaction.Sender).Should().Be(0); + sendersFeeMonitor.AddSenderFee(transaction); + sendersFeeMonitor.GetSenderFee(transaction.Sender).Should().Be(3); + sendersFeeMonitor.AddSenderFee(transaction); + sendersFeeMonitor.GetSenderFee(transaction.Sender).Should().Be(6); + sendersFeeMonitor.RemoveSenderFee(transaction); + sendersFeeMonitor.GetSenderFee(transaction.Sender).Should().Be(3); + sendersFeeMonitor.RemoveSenderFee(transaction); + sendersFeeMonitor.GetSenderFee(transaction.Sender).Should().Be(0); + } + } +} diff --git a/neo.UnitTests/Network/P2P/Payloads/UT_Transaction.cs b/neo.UnitTests/Network/P2P/Payloads/UT_Transaction.cs index 217e207f83..570151d941 100644 --- a/neo.UnitTests/Network/P2P/Payloads/UT_Transaction.cs +++ b/neo.UnitTests/Network/P2P/Payloads/UT_Transaction.cs @@ -798,7 +798,7 @@ public void Transaction_Reverify_Hashes_Length_Unequal_To_Witnesses_Length() }; UInt160[] hashes = txSimple.GetScriptHashesForVerifying(snapshot); Assert.AreEqual(2, hashes.Length); - Assert.IsFalse(txSimple.Reverify(snapshot, new Transaction[0])); + Assert.IsFalse(txSimple.Reverify(snapshot, BigInteger.Zero)); } [TestMethod] diff --git a/neo/Consensus/ConsensusContext.cs b/neo/Consensus/ConsensusContext.cs index 463bdc0407..e7c2649782 100644 --- a/neo/Consensus/ConsensusContext.cs +++ b/neo/Consensus/ConsensusContext.cs @@ -37,6 +37,11 @@ internal class ConsensusContext : IDisposable, ISerializable // if this node never heard from validator i, LastSeenMessage[i] will be -1. public int[] LastSeenMessage; + /// + /// Store all verified unsorted transactions' senders' fee currently in the consensus context. + /// + public SendersFeeMonitor SendersFeeMonitor = new SendersFeeMonitor(); + public Snapshot Snapshot { get; private set; } private KeyPair keyPair; private int _witnessSize; @@ -110,6 +115,12 @@ public void Deserialize(BinaryReader reader) if (TransactionHashes.Length == 0 && !RequestSentOrReceived) TransactionHashes = null; Transactions = transactions.Length == 0 && !RequestSentOrReceived ? null : transactions.ToDictionary(p => p.Hash); + SendersFeeMonitor = new SendersFeeMonitor(); + if (Transactions != null) + { + foreach (Transaction tx in Transactions.Values) + SendersFeeMonitor.AddSenderFee(tx); + } } public void Dispose() @@ -245,6 +256,7 @@ internal void EnsureMaxBlockSize(IEnumerable txs) txs = txs.Take((int)maxTransactionsPerBlock); List hashes = new List(); Transactions = new Dictionary(); + SendersFeeMonitor = new SendersFeeMonitor(); // Expected block size var blockSize = GetExpectedBlockSizeWithoutTransactions(txs.Count()); @@ -258,6 +270,7 @@ internal void EnsureMaxBlockSize(IEnumerable txs) hashes.Add(tx.Hash); Transactions.Add(tx.Hash, tx); + SendersFeeMonitor.AddSenderFee(tx); } TransactionHashes = hashes.ToArray(); diff --git a/neo/Consensus/ConsensusService.cs b/neo/Consensus/ConsensusService.cs index f403f7d3cc..574082043a 100644 --- a/neo/Consensus/ConsensusService.cs +++ b/neo/Consensus/ConsensusService.cs @@ -61,7 +61,7 @@ internal ConsensusService(IActorRef localNode, IActorRef taskManager, ConsensusC private bool AddTransaction(Transaction tx, bool verify) { - if (verify && !tx.Verify(context.Snapshot, context.Transactions.Values)) + if (verify && !tx.Verify(context.Snapshot, context.SendersFeeMonitor.GetSenderFee(tx.Sender))) { Log($"Invalid transaction: {tx.Hash}{Environment.NewLine}{tx.ToArray().ToHexString()}", LogLevel.Warning); RequestChangeView(ChangeViewReason.TxInvalid); @@ -74,6 +74,7 @@ private bool AddTransaction(Transaction tx, bool verify) return false; } context.Transactions[tx.Hash] = tx; + context.SendersFeeMonitor.AddSenderFee(tx); return CheckPrepareResponse(); } @@ -423,6 +424,7 @@ private void OnPrepareRequestReceived(ConsensusPayload payload, PrepareRequest m context.Block.ConsensusData.Nonce = message.Nonce; context.TransactionHashes = message.TransactionHashes; context.Transactions = new Dictionary(); + context.SendersFeeMonitor = new SendersFeeMonitor(); for (int i = 0; i < context.PreparationPayloads.Length; i++) if (context.PreparationPayloads[i] != null) if (!context.PreparationPayloads[i].GetDeserializedMessage().PreparationHash.Equals(payload.Hash)) diff --git a/neo/Ledger/Blockchain.cs b/neo/Ledger/Blockchain.cs index 50393d4c8e..dea46ae3e8 100644 --- a/neo/Ledger/Blockchain.cs +++ b/neo/Ledger/Blockchain.cs @@ -244,7 +244,7 @@ private void OnFillMemoryPool(IEnumerable transactions) // First remove the tx if it is unverified in the pool. MemPool.TryRemoveUnVerified(tx.Hash, out _); // Verify the the transaction - if (!tx.Verify(currentSnapshot, MemPool.GetVerifiedTransactions())) + if (!tx.Verify(currentSnapshot, MemPool.SendersFeeMonitor.GetSenderFee(tx.Sender))) continue; // Add to the memory pool MemPool.TryAdd(tx.Hash, tx); @@ -370,7 +370,7 @@ private RelayResultReason OnNewTransaction(Transaction transaction, bool relay) return RelayResultReason.AlreadyExists; if (!MemPool.CanTransactionFitInPool(transaction)) return RelayResultReason.OutOfMemory; - if (!transaction.Verify(currentSnapshot, MemPool.GetVerifiedTransactions())) + if (!transaction.Verify(currentSnapshot, MemPool.SendersFeeMonitor.GetSenderFee(transaction.Sender))) return RelayResultReason.Invalid; if (!NativeContract.Policy.CheckPolicy(transaction, currentSnapshot)) return RelayResultReason.PolicyFail; diff --git a/neo/Ledger/MemoryPool.cs b/neo/Ledger/MemoryPool.cs index f31e07e56c..8f37855a5c 100644 --- a/neo/Ledger/MemoryPool.cs +++ b/neo/Ledger/MemoryPool.cs @@ -69,6 +69,11 @@ public class MemoryPool : IReadOnlyCollection /// public int Capacity { get; } + /// + /// Store all verified unsorted transactions' senders' fee currently in the memory pool. + /// + public SendersFeeMonitor SendersFeeMonitor = new SendersFeeMonitor(); + /// /// Total count of transactions in the pool. /// @@ -268,6 +273,7 @@ internal bool TryAdd(UInt256 hash, Transaction tx) try { _unsortedTransactions.Add(hash, poolItem); + SendersFeeMonitor.AddSenderFee(tx); _sortedTransactions.Add(poolItem); if (Count > Capacity) @@ -310,6 +316,7 @@ private bool TryRemoveVerified(UInt256 hash, out PoolItem item) return false; _unsortedTransactions.Remove(hash); + SendersFeeMonitor.RemoveSenderFee(item.Tx); _sortedTransactions.Remove(item); return true; @@ -337,6 +344,7 @@ internal void InvalidateVerifiedTransactions() // Clear the verified transactions now, since they all must be reverified. _unsortedTransactions.Clear(); + SendersFeeMonitor = new SendersFeeMonitor(); _sortedTransactions.Clear(); } @@ -409,7 +417,7 @@ internal void InvalidateAllTransactions() // Since unverifiedSortedTxPool is ordered in an ascending manner, we take from the end. foreach (PoolItem item in unverifiedSortedTxPool.Reverse().Take(count)) { - if (item.Tx.Reverify(snapshot, _unsortedTransactions.Select(p => p.Value.Tx))) + if (item.Tx.Reverify(snapshot, SendersFeeMonitor.GetSenderFee(item.Tx.Sender))) reverifiedItems.Add(item); else // Transaction no longer valid -- it will be removed from unverifiedTxPool. invalidItems.Add(item); @@ -432,6 +440,7 @@ internal void InvalidateAllTransactions() { if (_unsortedTransactions.TryAdd(item.Tx.Hash, item)) { + SendersFeeMonitor.AddSenderFee(item.Tx); verifiedSortedTxPool.Add(item); if (item.LastBroadcastTimestamp < rebroadcastCutOffTime) diff --git a/neo/Ledger/SendersFeeMonitor.cs b/neo/Ledger/SendersFeeMonitor.cs new file mode 100644 index 0000000000..efe3ac6ebe --- /dev/null +++ b/neo/Ledger/SendersFeeMonitor.cs @@ -0,0 +1,43 @@ +using Neo.Network.P2P.Payloads; +using System.Collections.Generic; +using System.Numerics; +using System.Threading; + +namespace Neo.Ledger +{ + public class SendersFeeMonitor + { + private readonly ReaderWriterLockSlim _senderFeeRwLock = new ReaderWriterLockSlim(LockRecursionPolicy.SupportsRecursion); + + /// + /// Store all verified unsorted transactions' senders' fee currently in the memory pool. + /// + private readonly Dictionary _senderFee = new Dictionary(); + + public BigInteger GetSenderFee(UInt160 sender) + { + _senderFeeRwLock.EnterReadLock(); + if (!_senderFee.TryGetValue(sender, out var value)) + value = BigInteger.Zero; + _senderFeeRwLock.ExitReadLock(); + return value; + } + + public void AddSenderFee(Transaction tx) + { + _senderFeeRwLock.EnterWriteLock(); + if (_senderFee.TryGetValue(tx.Sender, out var value)) + _senderFee[tx.Sender] = value + tx.SystemFee + tx.NetworkFee; + else + _senderFee.Add(tx.Sender, tx.SystemFee + tx.NetworkFee); + _senderFeeRwLock.ExitWriteLock(); + } + + public void RemoveSenderFee(Transaction tx) + { + _senderFeeRwLock.EnterWriteLock(); + if ((_senderFee[tx.Sender] -= tx.SystemFee + tx.NetworkFee) == 0) _senderFee.Remove(tx.Sender); + _senderFeeRwLock.ExitWriteLock(); + } + } +} diff --git a/neo/Network/P2P/Payloads/Transaction.cs b/neo/Network/P2P/Payloads/Transaction.cs index c135185272..de81275877 100644 --- a/neo/Network/P2P/Payloads/Transaction.cs +++ b/neo/Network/P2P/Payloads/Transaction.cs @@ -130,16 +130,14 @@ public UInt160[] GetScriptHashesForVerifying(Snapshot snapshot) return hashes.OrderBy(p => p).ToArray(); } - public virtual bool Reverify(Snapshot snapshot, IEnumerable mempool) + public virtual bool Reverify(Snapshot snapshot, BigInteger totalSenderFeeFromPool) { if (ValidUntilBlock <= snapshot.Height || ValidUntilBlock > snapshot.Height + MaxValidUntilBlockIncrement) return false; if (NativeContract.Policy.GetBlockedAccounts(snapshot).Intersect(GetScriptHashesForVerifying(snapshot)).Count() > 0) return false; BigInteger balance = NativeContract.GAS.BalanceOf(snapshot, Sender); - BigInteger fee = SystemFee + NetworkFee; - if (balance < fee) return false; - fee += mempool.Where(p => p != this && p.Sender.Equals(Sender)).Select(p => (BigInteger)(p.SystemFee + p.NetworkFee)).Sum(); + BigInteger fee = SystemFee + NetworkFee + totalSenderFeeFromPool; if (balance < fee) return false; UInt160[] hashes = GetScriptHashesForVerifying(snapshot); if (hashes.Length != Witnesses.Length) return false; @@ -206,12 +204,12 @@ public static Transaction FromJson(JObject json) bool IInventory.Verify(Snapshot snapshot) { - return Verify(snapshot, Enumerable.Empty()); + return Verify(snapshot, BigInteger.Zero); } - public virtual bool Verify(Snapshot snapshot, IEnumerable mempool) + public virtual bool Verify(Snapshot snapshot, BigInteger totalSenderFeeFromPool) { - if (!Reverify(snapshot, mempool)) return false; + if (!Reverify(snapshot, totalSenderFeeFromPool)) return false; int size = Size; if (size > MaxTransactionSize) return false; long net_fee = NetworkFee - size * NativeContract.Policy.GetFeePerByte(snapshot);