From 7fc81361c6481b9b1ed06ef0bbb5c7bdf5f444c5 Mon Sep 17 00:00:00 2001 From: Alexey Akhunov Date: Tue, 11 Jun 2019 12:21:34 +0100 Subject: [PATCH 1/5] Introduce trie_pruning --- core/blockchain.go | 12 +- core/blockchain_test.go | 4 +- core/chain_makers.go | 4 +- core/state/database.go | 76 ++++--- trie/resolver.go | 67 ++---- trie/resolver_test.go | 8 - trie/trie.go | 435 ++++++++++++++++++++++++++++---------- trie/trie_pruning.go | 303 +++++++++++++++++++++++++- trie/trie_pruning_test.go | 68 ++++++ 9 files changed, 755 insertions(+), 222 deletions(-) create mode 100644 trie/trie_pruning_test.go diff --git a/core/blockchain.go b/core/blockchain.go index c7355bd97e5..af6035713aa 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -225,12 +225,14 @@ func (bc *BlockChain) GetTrieDbState() *state.TrieDbState { log.Info("Creating StateDB from latest state", "block", currentBlockNr) var err error bc.trieDbState, err = state.NewTrieDbState(bc.CurrentBlock().Header().Root, bc.db, currentBlockNr) + if err != nil { + panic(err) + } bc.trieDbState.SetNoHistory(bc.noHistory) bc.trieDbState.SetResolveReads(bc.resolveReads) - if err != nil { + if err := bc.trieDbState.Rebuild(); err != nil { panic(err) } - bc.trieDbState.Rebuild() } return bc.trieDbState } @@ -1186,12 +1188,14 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, [] currentBlockNr := bc.CurrentBlock().NumberU64() log.Info("Creating StateDB from latest state", "block", currentBlockNr) bc.trieDbState, err = state.NewTrieDbState(bc.CurrentBlock().Header().Root, bc.db, currentBlockNr) + if err != nil { + return k, events, coalescedLogs, err + } bc.trieDbState.SetNoHistory(bc.noHistory) bc.trieDbState.SetResolveReads(bc.resolveReads) - if err != nil { + if err := bc.trieDbState.Rebuild(); err != nil { return k, events, coalescedLogs, err } - bc.trieDbState.Rebuild() } root = bc.trieDbState.LastRoot() var parentRoot common.Hash diff --git a/core/blockchain_test.go 
b/core/blockchain_test.go index 6e217e0e421..af78a5d101f 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -72,7 +72,7 @@ func newCanonical(engine consensus.Engine, n int, full bool) (ethdb.Database, *B // Test fork of length N starting from block i func testFork(t *testing.T, blockchain *BlockChain, i, n int, full bool, comparator func(td1, td2 *big.Int)) { // Copy old chain up to #i into a new db - db, blockchain2, err := newCanonical(ethash.NewFaker(), i, full) + db, blockchain2, err := newCanonical(ethash.NewFaker(), i, true) if err != nil { t.Fatal("could not make new canonical in testFork", err) } @@ -350,7 +350,7 @@ func TestBrokenBlockChain(t *testing.T) { testBrokenChain(t, true) } func testBrokenChain(t *testing.T, full bool) { // Make chain starting from genesis - db, blockchain, err := newCanonical(ethash.NewFaker(), 10, full) + db, blockchain, err := newCanonical(ethash.NewFaker(), 10, true) if err != nil { t.Fatalf("failed to make new canonical chain: %v", err) } diff --git a/core/chain_makers.go b/core/chain_makers.go index 4bee25594ab..0264a45f2b4 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -232,7 +232,9 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse if err != nil { panic(err) } - tds.Rebuild() + if err := tds.Rebuild(); err != nil { + panic(err) + } for i := 0; i < n; i++ { statedb := state.New(tds) block, receipt := genblock(i, parent, statedb, tds) diff --git a/core/state/database.go b/core/state/database.go index 40c35ee292e..d33359638cc 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -205,6 +205,7 @@ type TrieDbState struct { noHistory bool resolveReads bool pg *trie.ProofGenerator + tp *trie.TriePruning } func NewTrieDbState(root common.Hash, db ethdb.Database, blockNr uint64) (*TrieDbState, error) { @@ -217,6 +218,10 @@ func NewTrieDbState(root common.Hash, db ethdb.Database, blockNr uint64) (*TrieD return nil, err } t := trie.New(root, false) + 
tp, err := trie.NewTriePruning(blockNr) + if err != nil { + return nil, err + } tds := TrieDbState{ t: t, db: db, @@ -225,8 +230,11 @@ func NewTrieDbState(root common.Hash, db ethdb.Database, blockNr uint64) (*TrieD codeCache: cc, codeSizeCache: csc, pg: trie.NewProofGenerator(), + tp: tp, } - t.MakeListed(tds.joinGeneration, tds.leftGeneration) + t.SetTouchFunc(func(hex []byte, del bool) { + tp.Touch(nil, hex, del) + }) tds.generationCounts = make(map[uint64]int, 4096) tds.oldestGeneration = blockNr return &tds, nil @@ -550,13 +558,13 @@ func (tds *TrieDbState) clearUpdates() { tds.aggregateBuffer = nil } -func (tds *TrieDbState) Rebuild() { - tr := tds.AccountTrie() - tr.Rebuild(tds.db, tds.blockNr) +func (tds *TrieDbState) Rebuild() error { + return tds.AccountTrie().Rebuild(tds.db, tds.blockNr) } func (tds *TrieDbState) SetBlockNr(blockNr uint64) { tds.blockNr = blockNr + tds.tp.SetBlockNr(blockNr) } func (tds *TrieDbState) UnwindTo(blockNr uint64) error { @@ -799,7 +807,9 @@ func (tds *TrieDbState) getStorageTrie(address common.Address, create bool) (*tr } else { t = trie.New(account.Root, true) } - t.MakeListed(tds.joinGeneration, tds.leftGeneration) + t.SetTouchFunc(func(hex []byte, del bool) { + tds.tp.Touch(common.CopyBytes(address[:]), hex, del) + }) tds.storageTries[address] = t } return t, nil @@ -902,32 +912,40 @@ func (tds *TrieDbState) ReadAccountCodeSize(codeHash common.Hash) (codeSize int, var prevMemStats runtime.MemStats func (tds *TrieDbState) PruneTries(print bool) { - if tds.nodeCount > int(MaxTrieCacheGen) { - toRemove := 0 - excess := tds.nodeCount - int(MaxTrieCacheGen) - gen := tds.oldestGeneration - for excess > 0 { - excess -= tds.generationCounts[gen] - toRemove += tds.generationCounts[gen] - delete(tds.generationCounts, gen) - gen++ + if print { + mainPrunable := tds.t.CountPrunableNodes() + prunableNodes := mainPrunable + for _, storageTrie := range tds.storageTries { + prunableNodes += storageTrie.CountPrunableNodes() } - // Unload 
all nodes with touch timestamp < gen - for address, storageTrie := range tds.storageTries { - empty := storageTrie.UnloadOlderThan(gen, false) - if empty { - delete(tds.storageTries, address) - } - } - tds.t.UnloadOlderThan(gen, false) - tds.oldestGeneration = gen - tds.nodeCount -= toRemove - var m runtime.MemStats - runtime.ReadMemStats(&m) - log.Info("Memory", "nodes", tds.nodeCount, "alloc", int(m.Alloc/1024), "sys", int(m.Sys/1024), "numGC", int(m.NumGC)) - if print { - fmt.Printf("Pruning done. Nodes: %d, alloc: %d, sys: %d, numGC: %d\n", tds.nodeCount, int(m.Alloc/1024), int(m.Sys/1024), int(m.NumGC)) + fmt.Printf("[Before] Actual prunable nodes: %d (main %d), accounted: %d\n", prunableNodes, mainPrunable, tds.tp.NodeCount()) + } + pruned, emptyAddresses, err := tds.tp.PruneTo(tds.t, int(MaxTrieCacheGen), func(contract common.Address) (*trie.Trie, error) { + return tds.getStorageTrie(contract, false) + }) + if err != nil { + fmt.Printf("Error while pruning: %v\n", err) + } + if !pruned { + //return + } + if print { + mainPrunable := tds.t.CountPrunableNodes() + prunableNodes := mainPrunable + for _, storageTrie := range tds.storageTries { + prunableNodes += storageTrie.CountPrunableNodes() } + fmt.Printf("[After] Actual prunable nodes: %d (main %d), accounted: %d\n", prunableNodes, mainPrunable, tds.tp.NodeCount()) + } + // Storage tries that were completely pruned + for _, address := range emptyAddresses { + delete(tds.storageTries, address) + } + var m runtime.MemStats + runtime.ReadMemStats(&m) + log.Info("Memory", "nodes", tds.tp.NodeCount(), "alloc", int(m.Alloc/1024), "sys", int(m.Sys/1024), "numGC", int(m.NumGC)) + if print { + fmt.Printf("Pruning done. 
Nodes: %d, alloc: %d, sys: %d, numGC: %d\n", tds.tp.NodeCount(), int(m.Alloc/1024), int(m.Sys/1024), int(m.NumGC)) } } diff --git a/trie/resolver.go b/trie/resolver.go index 52d868c0d1d..3ca086f6283 100644 --- a/trie/resolver.go +++ b/trie/resolver.go @@ -17,26 +17,19 @@ import ( var emptyHash [32]byte -func (t *Trie) Rebuild(db ethdb.Database, blockNr uint64) hashNode { +func (t *Trie) Rebuild(db ethdb.Database, blockNr uint64) error { if t.root == nil { return nil } n, ok := t.root.(hashNode) if !ok { - panic("Expected hashNode") + return fmt.Errorf("Rebuild: Expected hashNode, got %T", t.root) } - root, roothash, err := t.rebuildHashes(db, nil, 0, blockNr, true, n) - if err != nil { - panic(err) + if err := t.rebuildHashes(db, nil, 0, blockNr, true, n); err != nil { + return err } - if bytes.Equal(roothash, n) { - t.root = root - log.Info("Rebuilt hashfile and verified", "root hash", roothash) - } else { - log.Error(fmt.Sprintf("Could not rebuild %s vs %s\n", roothash, n)) - } - t.timestampSubTree(t.root, blockNr) - return roothash + log.Info("Rebuilt hashfile and verified", "root hash", n) + return nil } const Levels = 104 @@ -263,7 +256,7 @@ func (tr *TrieResolver) finishPreviousKey(k []byte) error { tr.nodeStack[level].flags.dirty = true } tr.vertical[level].flags.dirty = true - if onResolvingPath || (tr.hashes && level == 5) { + if onResolvingPath || (tr.hashes && level <= 5) { var c node if tr.fillCount[level+1] == 2 { c = full.duoCopy() @@ -274,6 +267,7 @@ func (tr *TrieResolver) finishPreviousKey(k []byte) error { if tr.fillCount[level] == 0 { tr.nodeStack[level].Val = c } + req.t.touchFunc(hex[2*len(req.contract):level+1], false) } else { tr.vertical[level].Children[keynibble] = hashNode(storeHashTo[:]) if tr.fillCount[level] == 0 { @@ -298,8 +292,10 @@ func (tr *TrieResolver) finishPreviousKey(k []byte) error { if tr.fillCount[req.extResolvePos] == 1 { root = tr.nodeStack[req.extResolvePos].copy() } else if tr.fillCount[req.extResolvePos] == 2 { + 
req.t.touchFunc(req.resolveHex[:req.resolvePos], false) root = tr.vertical[req.extResolvePos].duoCopy() } else if tr.fillCount[req.extResolvePos] > 2 { + req.t.touchFunc(req.resolveHex[:req.resolvePos], false) root = tr.vertical[req.extResolvePos].copy() } if root == nil { @@ -325,7 +321,6 @@ func (tr *TrieResolver) finishPreviousKey(k []byte) error { req.resolveHash) } } - req.resolved = root for i := 0; i <= Levels; i++ { tr.nodeStack[i].Key = nil tr.nodeStack[i].Val = nil @@ -336,40 +331,7 @@ func (tr *TrieResolver) finishPreviousKey(k []byte) error { tr.vertical[i].flags.dirty = true tr.fillCount[i] = 0 } - req.t.timestampSubTree(root, tr.blockNr) - if req.resolveParent == nil { - if _, ok := req.t.root.(hashNode); ok { - req.t.root = root - } - } else { - switch parent := req.resolveParent.(type) { - case nil: - if _, ok := req.t.root.(hashNode); ok { - req.t.root = root - } - case *shortNode: - if _, ok := parent.Val.(hashNode); ok { - parent.Val = root - } - case *duoNode: - i1, i2 := parent.childrenIdx() - switch req.resolveHex[req.resolvePos-1] { - case i1: - if _, ok := parent.child1.(hashNode); ok { - parent.child1 = root - } - case i2: - if _, ok := parent.child2.(hashNode); ok { - parent.child2 = root - } - } - case *fullNode: - idx := req.resolveHex[req.resolvePos-1] - if _, ok := parent.Children[idx].(hashNode); ok { - parent.Children[idx] = root - } - } - } + req.t.hook(req.resolveHex[:req.resolvePos], root, tr.blockNr) } return nil } @@ -475,12 +437,9 @@ func (tr *TrieResolver) ResolveWithDb(db ethdb.Database, blockNr uint64) error { return err } -func (t *Trie) rebuildHashes(db ethdb.Database, key []byte, pos int, blockNr uint64, accounts bool, expected hashNode) (node, hashNode, error) { +func (t *Trie) rebuildHashes(db ethdb.Database, key []byte, pos int, blockNr uint64, accounts bool, expected hashNode) error { req := t.NewResolveRequest(nil, key, pos, expected) r := NewResolver(true, accounts, blockNr) r.AddRequest(req) - if err := 
r.ResolveWithDb(db, blockNr); err != nil { - return nil, nil, err - } - return req.resolved, expected, nil + return r.ResolveWithDb(db, blockNr) } diff --git a/trie/resolver_test.go b/trie/resolver_test.go index e86cd576625..a8a686816aa 100644 --- a/trie/resolver_test.go +++ b/trie/resolver_test.go @@ -64,7 +64,6 @@ func TestResolve1Embedded(t *testing.T) { resolveHex: keybytesToHex([]byte("abcdefghijklmnopqrstuvwxyz012345")), resolvePos: 10, // 5 bytes is 10 nibbles resolveHash: nil, - resolved: nil, } r := NewResolver(false, false, 0) r.AddRequest(req) @@ -83,7 +82,6 @@ func TestResolve1(t *testing.T) { resolveHex: keybytesToHex([]byte("aaaaabbbbbaaaaabbbbbaaaaabbbbbaa")), resolvePos: 10, // 5 bytes is 10 nibbles resolveHash: hashNode(common.HexToHash("741326629cbf4ba5d5afebd56dd714ba4a531ddb6b07b829aa85dee4d97d34a4").Bytes()), - resolved: nil, } r := NewResolver(false, false, 0) r.AddRequest(req) @@ -103,7 +101,6 @@ func TestResolve2(t *testing.T) { resolveHex: keybytesToHex([]byte("aaaaabbbbbaaaaabbbbbaaaaabbbbbaa")), resolvePos: 10, // 5 bytes is 10 nibbles resolveHash: hashNode(common.HexToHash("c9f98a7d966d37c7231d11910c72f01a213057111b8171f5f137269bb73e45e4").Bytes()), - resolved: nil, } r := NewResolver(false, false, 0) r.AddRequest(req) @@ -123,7 +120,6 @@ func TestResolve2Keep(t *testing.T) { resolveHex: keybytesToHex([]byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")), resolvePos: 10, // 5 bytes is 10 nibbles resolveHash: hashNode(common.HexToHash("c9f98a7d966d37c7231d11910c72f01a213057111b8171f5f137269bb73e45e4").Bytes()), - resolved: nil, } r := NewResolver(false, false, 0) r.AddRequest(req) @@ -144,7 +140,6 @@ func TestResolve3Keep(t *testing.T) { resolveHex: keybytesToHex([]byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")), resolvePos: 10, // 5 bytes is 10 nibbles resolveHash: hashNode(common.HexToHash("03e27bd9cc47c0a03a8480035f765a4ba242c40ae4badfd1628af5a1ca5fd57a").Bytes()), - resolved: nil, } r := NewResolver(false, false, 0) r.AddRequest(req) @@ -170,21 
+165,18 @@ func TestTrieResolver(t *testing.T) { resolveHex: keybytesToHex([]byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")), resolvePos: 10, // 5 bytes is 10 nibbles resolveHash: hashNode(common.HexToHash("c9f98a7d966d37c7231d11910c72f01a213057111b8171f5f137269bb73e45e4").Bytes()), - resolved: nil, } req2 := &ResolveRequest{ t: tr, resolveHex: keybytesToHex([]byte("bbaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")), resolvePos: 2, // 2 bytes is 4 nibbles resolveHash: hashNode(common.HexToHash("b183c6dd36a92675ab74e32008a41735f485d20df283be0f349a412c769fe6c9").Bytes()), - resolved: nil, } req3 := &ResolveRequest{ t: tr, resolveHex: keybytesToHex([]byte("bbbaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")), resolvePos: 2, // 3 bytes is 6 nibbles resolveHash: hashNode(common.HexToHash("b183c6dd36a92675ab74e32008a41735f485d20df283be0f349a412c769fe6c9").Bytes()), - resolved: nil, } resolver := NewResolver(false, false, 0) resolver.AddRequest(req3) diff --git a/trie/trie.go b/trie/trie.go index 4358c0c8c89..a302a6e757d 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -50,6 +50,7 @@ type Trie struct { joinGeneration func(gen uint64) leftGeneration func(gen uint64) + touchFunc func(hex []byte, del bool) } // New creates a trie with an existing root node from db. @@ -63,6 +64,7 @@ func New(root common.Hash, encodeToBytes bool) *Trie { encodeToBytes: encodeToBytes, joinGeneration: func(uint64) {}, leftGeneration: func(uint64) {}, + touchFunc: func([]byte, bool) {}, } if (root != common.Hash{}) && root != emptyRoot { trie.root = hashNode(root[:]) @@ -75,6 +77,10 @@ func (t *Trie) MakeListed(joinGeneration, leftGeneration func(gen uint64)) { t.leftGeneration = leftGeneration } +func (t *Trie) SetTouchFunc(touchFunc func(hex []byte, del bool)) { + t.touchFunc = touchFunc +} + // TryGet returns the value for key stored in the trie. 
func (t *Trie) Get(key []byte, blockNr uint64) (value []byte, gotValue bool) { hex := keybytesToHex(key) @@ -88,6 +94,7 @@ func (t *Trie) get(origNode node, key []byte, pos int, blockNr uint64) (value [] case valueNode: return n, true case *shortNode: + n.updateT(blockNr, t.joinGeneration, t.leftGeneration) var adjust bool nKey := compactToHex(n.Key) if len(key)-pos < len(nKey) || !bytes.Equal(nKey, key[pos:pos+len(nKey)]) { @@ -106,6 +113,7 @@ func (t *Trie) get(origNode node, key []byte, pos int, blockNr uint64) (value [] } return case *duoNode: + t.touchFunc(key[:pos], false) n.updateT(blockNr, t.joinGeneration, t.leftGeneration) var adjust bool i1, i2 := n.childrenIdx() @@ -125,6 +133,7 @@ func (t *Trie) get(origNode node, key []byte, pos int, blockNr uint64) (value [] } return case *fullNode: + t.touchFunc(key[:pos], false) n.updateT(blockNr, t.joinGeneration, t.leftGeneration) child := n.Children[key[pos]] adjust := child != nil && n.tod(blockNr) == child.tod(blockNr) @@ -167,8 +176,6 @@ type ResolveRequest struct { resolvePos int // Position in the key for which resolution is requested extResolvePos int resolveHash hashNode // Expected hash of the resolved node (for correctness checking) - resolved node // Node that has been resolved via Database access - resolveParent node // Parent node of the one needs to be resolved. 
nil if the root needs to be resolved } func (t *Trie) NewResolveRequest(contract []byte, hex []byte, pos int, resolveHash []byte) *ResolveRequest { @@ -184,7 +191,6 @@ func (rr *ResolveRequest) String() string { // In the case of "Yes", also returns a corresponding ResolveRequest func (t *Trie) NeedResolution(contract []byte, key []byte) (bool, *ResolveRequest) { var nd node = t.root - var parent node = nil hex := keybytesToHex(key) pos := 0 for { @@ -195,7 +201,6 @@ func (t *Trie) NeedResolution(contract []byte, key []byte) (bool, *ResolveReques nKey := compactToHex(n.Key) matchlen := prefixLen(hex[pos:], nKey) if matchlen == len(nKey) { - parent = nd nd = n.Val pos += matchlen } else { @@ -205,11 +210,9 @@ func (t *Trie) NeedResolution(contract []byte, key []byte) (bool, *ResolveReques i1, i2 := n.childrenIdx() switch hex[pos] { case i1: - parent = nd nd = n.child1 pos++ case i2: - parent = nd nd = n.child2 pos++ default: @@ -220,16 +223,13 @@ func (t *Trie) NeedResolution(contract []byte, key []byte) (bool, *ResolveReques if child == nil { return false, nil } else { - parent = nd nd = child pos++ } case valueNode: return false, nil case hashNode: - c := t.NewResolveRequest(contract, hex, pos, common.CopyBytes(n)) - c.resolveParent = parent - return true, c + return true, t.NewResolveRequest(contract, hex, pos, common.CopyBytes(n)) default: panic(fmt.Sprintf("Unknown node: %T", n)) } @@ -346,6 +346,8 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node, blockNr ui } switch n := origNode.(type) { case *shortNode: + t.leftGeneration(n.flags.t) + n.flags.t = blockNr nKey := compactToHex(n.Key) matchlen := prefixLen(key[pos:], nKey) // If the whole key matches, keep this short node as is @@ -357,6 +359,7 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node, blockNr ui n.flags.dirty = true } newNode = n + t.joinGeneration(blockNr) n.adjustTod(blockNr) } else { // Otherwise branch out at the index where they differ. 
@@ -397,15 +400,18 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node, blockNr ui // Replace this shortNode with the branch if it occurs at index 0. if matchlen == 0 { + t.touchFunc(key[:pos], false) newNode = branch // current node leaves the generation, but new node branch joins it } else { // Otherwise, replace it with a short node leading up to the branch. + t.touchFunc(key[:pos+matchlen], false) n.Key = hexToCompact(key[pos : pos+matchlen]) n.Val = branch t.joinGeneration(blockNr) // new branch node joins the generation n.flags.dirty = true n.flags.t = blockNr newNode = n + t.joinGeneration(blockNr) // n joins the generation too n.adjustTod(blockNr) } updated = true @@ -413,54 +419,25 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node, blockNr ui return case *duoNode: + t.touchFunc(key[:pos], false) n.updateT(blockNr, t.joinGeneration, t.leftGeneration) var adjust bool i1, i2 := n.childrenIdx() switch key[pos] { case i1: - adjust = n.child1 != nil && n.tod(blockNr) == n.child1.tod(blockNr) - if n.child1 == nil { - if len(key) == pos+1 { - n.child1 = value - } else { - short := &shortNode{Key: hexToCompact(key[pos+1:]), Val: value} - short.flags.dirty = true - short.flags.t = blockNr - short.adjustTod(blockNr) - t.joinGeneration(blockNr) - n.child1 = short - } - updated = true + adjust = n.tod(blockNr) == n.child1.tod(blockNr) + updated, nn = t.insert(n.child1, key, pos+1, value, blockNr) + if updated { + n.child1 = nn n.flags.dirty = true - } else { - updated, nn = t.insert(n.child1, key, pos+1, value, blockNr) - if updated { - n.child1 = nn - n.flags.dirty = true - } } newNode = n case i2: - adjust = n.child2 != nil && n.tod(blockNr) == n.child2.tod(blockNr) - if n.child2 == nil { - if len(key) == pos+1 { - n.child2 = value - } else { - short := &shortNode{Key: hexToCompact(key[pos+1:]), Val: value} - short.flags.dirty = true - short.flags.t = blockNr - short.adjustTod(blockNr) - t.joinGeneration(blockNr) - n.child2 = 
short - } - updated = true + adjust = n.tod(blockNr) == n.child2.tod(blockNr) + updated, nn = t.insert(n.child2, key, pos+1, value, blockNr) + if updated { + n.child2 = nn n.flags.dirty = true - } else { - updated, nn = t.insert(n.child2, key, pos+1, value, blockNr) - if updated { - n.child2 = nn - n.flags.dirty = true - } } newNode = n default: @@ -493,6 +470,7 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node, blockNr ui return case *fullNode: + t.touchFunc(key[:pos], false) n.updateT(blockNr, t.joinGeneration, t.leftGeneration) child := n.Children[key[pos]] adjust := child != nil && n.tod(blockNr) == child.tod(blockNr) @@ -522,11 +500,87 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node, blockNr ui } return default: - fmt.Printf("Key: %x, Prefix: %x\n", key[pos:], key[:pos]) + fmt.Printf("Key: %x, Pos: %d\n", key, pos) panic(fmt.Sprintf("%T: invalid node: %v", n, n)) } } +func (t *Trie) hook(hex []byte, n node, blockNr uint64) { + var nd node = t.root + var parent node + pos := 0 + for pos < len(hex) { + switch n := nd.(type) { + case nil: + return + case *shortNode: + n.flags.t = blockNr + nKey := compactToHex(n.Key) + matchlen := prefixLen(hex[pos:], nKey) + if matchlen == len(nKey) { + parent = n + nd = n.Val + pos += matchlen + } else { + return + } + case *duoNode: + t.touchFunc(hex[:pos], false) + n.flags.t = blockNr + i1, i2 := n.childrenIdx() + switch hex[pos] { + case i1: + parent = n + nd = n.child1 + pos++ + case i2: + parent = n + nd = n.child2 + pos++ + default: + return + } + case *fullNode: + t.touchFunc(hex[:pos], false) + n.flags.t = blockNr + child := n.Children[hex[pos]] + if child == nil { + return + } else { + parent = n + nd = child + pos++ + } + case valueNode: + return + case hashNode: + return + default: + panic(fmt.Sprintf("Unknown node: %T", n)) + } + } + if _, ok := nd.(hashNode); !ok { + return + } + switch p := parent.(type) { + case nil: + t.root = n + case *shortNode: + p.Val = n + case 
*duoNode: + i1, i2 := p.childrenIdx() + switch hex[len(hex)-1] { + case i1: + p.child1 = n + case i2: + p.child2 = n + } + case *fullNode: + idx := hex[len(hex)-1] + p.Children[idx] = n + } +} + // Delete removes any existing value for key from the trie. func (t *Trie) Delete(key []byte, blockNr uint64) { hex := keybytesToHex(key) @@ -574,13 +628,15 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, blockNr uint64) ( var nn node switch n := origNode.(type) { case *shortNode: + t.leftGeneration(n.flags.t) + n.flags.t = blockNr nKey := compactToHex(n.Key) matchlen := prefixLen(key[keyStart:], nKey) if matchlen < len(nKey) { updated = false + t.joinGeneration(blockNr) newNode = n // don't replace n on mismatch } else if matchlen == len(key)-keyStart { - t.leftGeneration(n.flags.t) updated = true newNode = nil // remove n entirely for whole matches } else { @@ -590,10 +646,10 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, blockNr uint64) ( // longer than n.Key. 
updated, nn = t.delete(n.Val, key, keyStart+len(nKey), blockNr) if !updated { + t.joinGeneration(blockNr) newNode = n } else { if nn == nil { - t.leftGeneration(n.flags.t) newNode = nil } else { if shortChild, ok := nn.(*shortNode); ok { @@ -609,6 +665,7 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, blockNr uint64) ( newnode.flags.dirty = true newnode.flags.t = blockNr newnode.adjustTod(blockNr) + t.joinGeneration(blockNr) // We do not increase generation count here, because one short node comes, but another one t.leftGeneration(shortChild.flags.t) // But shortChild goes away newNode = newnode @@ -616,6 +673,7 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, blockNr uint64) ( n.Val = nn n.flags.dirty = true n.adjustTod(blockNr) + t.joinGeneration(blockNr) newNode = n } } @@ -624,7 +682,7 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, blockNr uint64) ( return case *duoNode: - n.updateT(blockNr, t.joinGeneration, t.leftGeneration) + t.leftGeneration(n.flags.t) var adjust bool i1, i2 := n.childrenIdx() switch key[keyStart] { @@ -632,18 +690,19 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, blockNr uint64) ( adjust = n.child1 != nil && n.tod(blockNr) == n.child1.tod(blockNr) updated, nn = t.delete(n.child1, key, keyStart+1, blockNr) if !updated { + t.touchFunc(key[:keyStart], false) + t.joinGeneration(blockNr) + n.flags.t = blockNr newNode = n } else { - n.child1 = nn if nn == nil { - if n.child2 == nil { - adjust = false - t.leftGeneration(n.flags.t) - newNode = nil - } else { - newNode = t.convertToShortNode(key, keyStart, n.child2, uint(i2), blockNr) - } + t.touchFunc(key[:keyStart], true) + newNode = t.convertToShortNode(key, keyStart, n.child2, uint(i2), blockNr) } else { + t.touchFunc(key[:keyStart], false) + t.joinGeneration(blockNr) + n.flags.t = blockNr + n.child1 = nn n.flags.dirty = true newNode = n } @@ -652,23 +711,27 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, 
blockNr uint64) ( adjust = n.child2 != nil && n.tod(blockNr) == n.child2.tod(blockNr) updated, nn = t.delete(n.child2, key, keyStart+1, blockNr) if !updated { + t.touchFunc(key[:keyStart], false) + t.joinGeneration(blockNr) + n.flags.t = blockNr newNode = n } else { - n.child2 = nn if nn == nil { - if n.child1 == nil { - adjust = false - t.leftGeneration(n.flags.t) - newNode = nil - } else { - newNode = t.convertToShortNode(key, keyStart, n.child1, uint(i1), blockNr) - } + t.touchFunc(key[:keyStart], true) + newNode = t.convertToShortNode(key, keyStart, n.child1, uint(i1), blockNr) } else { + t.touchFunc(key[:keyStart], false) + t.joinGeneration(blockNr) + n.flags.t = blockNr + n.child2 = nn n.flags.dirty = true newNode = n } } default: + t.touchFunc(key[:keyStart], false) + t.joinGeneration(blockNr) + n.flags.t = blockNr adjust = false updated = false newNode = n @@ -679,11 +742,14 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, blockNr uint64) ( return case *fullNode: - n.updateT(blockNr, t.joinGeneration, t.leftGeneration) + t.leftGeneration(n.flags.t) child := n.Children[key[keyStart]] adjust := child != nil && n.tod(blockNr) == child.tod(blockNr) updated, nn = t.delete(child, key, keyStart+1, blockNr) if !updated { + t.touchFunc(key[:keyStart], false) + t.joinGeneration(blockNr) + n.flags.t = blockNr newNode = n } else { n.Children[key[keyStart]] = nn @@ -712,12 +778,11 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, blockNr uint64) ( } } } - if count == 0 { - t.leftGeneration(n.flags.t) - newNode = nil - } else if count == 1 { + if count == 1 { + t.touchFunc(key[:keyStart], true) newNode = t.convertToShortNode(key, keyStart, n.Children[pos1], uint(pos1), blockNr) } else if count == 2 { + t.touchFunc(key[:keyStart], false) duo := &duoNode{} if pos1 == int(key[keyStart]) { duo.child1 = nn @@ -732,10 +797,14 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, blockNr uint64) ( duo.flags.dirty = true duo.mask = (1 
<< uint(pos1)) | (uint32(1) << uint(pos2)) duo.flags.t = blockNr + t.joinGeneration(blockNr) duo.adjustTod(blockNr) adjust = false newNode = duo - } else { + } else if count > 2 { + t.touchFunc(key[:keyStart], false) + t.joinGeneration(blockNr) + n.flags.t = blockNr // n still contains at least three values and cannot be reduced. n.flags.dirty = true newNode = n @@ -762,55 +831,42 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, blockNr uint64) ( } func (t *Trie) PrepareToRemove() { - t.prepareToRemove(t.root) + t.prepareToRemove(t.root, []byte{}) } -func (t *Trie) prepareToRemove(n node) { +func (t *Trie) prepareToRemove(n node, hex []byte) { switch n := n.(type) { case *shortNode: + var hexVal []byte + if _, ok := n.Val.(valueNode); !ok { // Don't need to compute prefix for a leaf + nKey := compactToHex(n.Key) + hexVal = make([]byte, len(hex)+len(nKey)) + copy(hexVal, hex) + copy(hexVal[len(hex):], nKey) + } t.leftGeneration(n.flags.t) - t.prepareToRemove(n.Val) + t.prepareToRemove(n.Val, hexVal) case *duoNode: + t.touchFunc(hex, true) t.leftGeneration(n.flags.t) - t.prepareToRemove(n.child1) - t.prepareToRemove(n.child2) + i1, i2 := n.childrenIdx() + hex1 := make([]byte, len(hex)+1) + copy(hex1, hex) + hex1[len(hex)] = byte(i1) + hex2 := make([]byte, len(hex)+1) + copy(hex2, hex) + hex2[len(hex)] = byte(i2) + t.prepareToRemove(n.child1, hex1) + t.prepareToRemove(n.child2, hex2) case *fullNode: + t.touchFunc(hex, true) t.leftGeneration(n.flags.t) - for _, child := range n.Children { + for i, child := range n.Children { if child != nil { - t.prepareToRemove(child) - } - } - } -} - -// Timestamp given node and all descendants -func (t *Trie) timestampSubTree(n node, blockNr uint64) { - switch n := n.(type) { - case *shortNode: - if n.flags.t == 0 { - n.flags.t = blockNr - n.flags.tod = blockNr - t.joinGeneration(blockNr) - t.timestampSubTree(n.Val, blockNr) - } - case *duoNode: - if n.flags.t == 0 { - n.flags.t = blockNr - n.flags.tod = blockNr - 
t.joinGeneration(blockNr) - t.timestampSubTree(n.child1, blockNr) - t.timestampSubTree(n.child2, blockNr) - } - case *fullNode: - if n.flags.t == 0 { - n.flags.t = blockNr - n.flags.tod = blockNr - t.joinGeneration(blockNr) - for _, child := range n.Children { - if child != nil { - t.timestampSubTree(child, blockNr) - } + hexChild := make([]byte, len(hex)+1) + copy(hexChild, hex) + hexChild[len(hex)] = byte(i) + t.prepareToRemove(child, hexChild) } } } @@ -888,6 +944,159 @@ func unloadOlderThan(n node, gen uint64) (hashNode, bool) { return nil, false } +func (t *Trie) unload(hex []byte, h *hasher) { + var nd node = t.root + var parent node + pos := 0 + for pos < len(hex) { + switch n := nd.(type) { + case nil: + return + case *shortNode: + nKey := compactToHex(n.Key) + matchlen := prefixLen(hex[pos:], nKey) + if matchlen == len(nKey) { + parent = n + nd = n.Val + pos += matchlen + } else { + return + } + case *duoNode: + i1, i2 := n.childrenIdx() + switch hex[pos] { + case i1: + parent = n + nd = n.child1 + pos++ + case i2: + parent = n + nd = n.child2 + pos++ + default: + return + } + case *fullNode: + child := n.Children[hex[pos]] + if child == nil { + return + } else { + parent = n + nd = child + pos++ + } + case valueNode: + return + case hashNode: + return + default: + panic(fmt.Sprintf("Unknown node: %T", n)) + } + } + if _, ok := nd.(hashNode); ok { + return + } + var hn common.Hash + h.hash(nd, len(hex) == 0, hn[:]) + hnode := hashNode(hn[:]) + switch p := parent.(type) { + case nil: + t.root = hnode + case *shortNode: + p.Val = hnode + case *duoNode: + i1, i2 := p.childrenIdx() + switch hex[len(hex)-1] { + case i1: + p.child1 = hnode + case i2: + p.child2 = hnode + } + case *fullNode: + idx := hex[len(hex)-1] + p.Children[idx] = hnode + } +} + +func (t *Trie) CountPrunableNodes() int { + return t.countPrunableNodes(t.root, []byte{}, false) +} + +func (t *Trie) countPrunableNodes(nd node, hex []byte, print bool) int { + switch n := nd.(type) { + case nil: + 
return 0 + case valueNode: + return 0 + case hashNode: + return 0 + case *shortNode: + var hexVal []byte + if _, ok := n.Val.(valueNode); !ok { // Don't need to compute prefix for a leaf + nKey := compactToHex(n.Key) + hexVal = make([]byte, len(hex)+len(nKey)) + copy(hexVal, hex) + copy(hexVal[len(hex):], nKey) + } + return t.countPrunableNodes(n.Val, hexVal, print) + case *duoNode: + i1, i2 := n.childrenIdx() + hex1 := make([]byte, len(hex)+1) + copy(hex1, hex) + hex1[len(hex)] = byte(i1) + hex2 := make([]byte, len(hex)+1) + copy(hex2, hex) + hex2[len(hex)] = byte(i2) + if print { + fmt.Printf("%T node: %x, t: %d\n", n, hex, n.flags.t) + } + return 1 + t.countPrunableNodes(n.child1, hex1, print) + t.countPrunableNodes(n.child2, hex2, print) + case *fullNode: + if print { + fmt.Printf("%T node: %x, t: %d\n", n, hex, n.flags.t) + } + count := 0 + for i, child := range n.Children { + if child != nil { + hexChild := make([]byte, len(hex)+1) + copy(hexChild, hex) + hexChild[len(hex)] = byte(i) + count += t.countPrunableNodes(child, hexChild, print) + } + } + return 1 + count + default: + panic("") + } +} + +func (t *Trie) CountGenerations(m map[uint64]int) { + t.countGenerations(t.root, m) +} + +func (t *Trie) countGenerations(nd node, m map[uint64]int) { + switch n := nd.(type) { + case nil: + case valueNode: + case hashNode: + case *shortNode: + t.countGenerations(n.Val, m) + case *duoNode: + m[n.flags.t]++ + t.countGenerations(n.child1, m) + t.countGenerations(n.child2, m) + case *fullNode: + m[n.flags.t]++ + for _, child := range n.Children { + if child != nil { + t.countGenerations(child, m) + } + } + default: + panic("") + } +} + func (t *Trie) hashRoot() (node, error) { if t.root == nil { return hashNode(emptyRoot.Bytes()), nil diff --git a/trie/trie_pruning.go b/trie/trie_pruning.go index 15a634c5007..95fed96098e 100644 --- a/trie/trie_pruning.go +++ b/trie/trie_pruning.go @@ -19,6 +19,12 @@ package trie import ( + "bytes" + "encoding/binary" + "fmt" + "sort" + 
"strings" + "github.com/ledgerwatch/bolt" "github.com/ledgerwatch/turbo-geth/common" ) @@ -28,11 +34,11 @@ type TriePruning struct { // It maps prefixes to their corresponding timestamps (uint64) timestamps *bolt.DB - // Maps timestamp (uint64) to address of the contract to prefix of node (string) to parent node - storage map[uint64]map[common.Address]map[string]node + // Maps timestamp (uint64) to address of the contract to set of prefixes of nodes (string) + storage map[uint64]map[common.Address]map[string]struct{} - // Maps timestamp (uint64) to prefix of node (string) to parent node - accounts map[uint64]map[string]node + // Maps timestamp (uint64) to set of prefixes of nodees (string) + accounts map[uint64]map[string]struct{} // For each timestamp, keeps number of branch nodes belonging to it generationCounts map[uint64]int @@ -47,22 +53,297 @@ type TriePruning struct { blockNr uint64 } -func NewTriePruning() (*TriePruning, error) { +func NewTriePruning(oldestGeneration uint64) (*TriePruning, error) { db, err := bolt.Open("in-memory", 0600, &bolt.Options{MemOnly: true}) if err != nil { return nil, err } + // Pre-create the bucket so we can assume it is there + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket(abucket, false); err != nil { + return err + } + if _, err := tx.CreateBucket(sbucket, false); err != nil { + return err + } + return nil + }); err != nil { + db.Close() + return nil, err + } return &TriePruning{ - timestamps: db, - storage: make(map[uint64]map[common.Address]map[string]node), - accounts: make(map[uint64]map[string]node), + oldestGeneration: oldestGeneration, + blockNr: oldestGeneration, + timestamps: db, + storage: make(map[uint64]map[common.Address]map[string]struct{}), + accounts: make(map[uint64]map[string]struct{}), generationCounts: make(map[uint64]int), }, nil } +func (tp *TriePruning) SetBlockNr(blockNr uint64) { + tp.blockNr = blockNr +} + +var abucket = []byte("a") +var sbucket = []byte("s") + // 
Updates a node to the current timestamp // contract is effectively address of the smart contract -// hex is the prefix of the key -func (tp *TriePruning) touch(contract []byte, hex []byte, parent, n node) { +// hex is the prefix of the key +// parent is the node that needs to be modified to unload the touched node +// exists is true when the node existed before, and false if it is a new one +// prevTimestamp is the timestamp the node current has +func (tp *TriePruning) TouchFrom(contract []byte, hex []byte, exists bool, prevTimestamp uint64, del bool, newTimestamp uint64) { + //fmt.Printf("TouchFrom %x, exists: %t, prevTimestamp %d, del %t, newTimestamp %d\n", hex, exists, prevTimestamp, del, newTimestamp) + if exists && !del && prevTimestamp == newTimestamp { + return + } + if !del { + hexS := string(common.CopyBytes(hex)) + var newMap map[string]struct{} + if contract == nil { + if m, ok := tp.accounts[newTimestamp]; ok { + newMap = m + } else { + newMap = make(map[string]struct{}) + tp.accounts[newTimestamp] = newMap + } + } else { + contractAddress := common.BytesToAddress(contract) + if m, ok := tp.storage[newTimestamp]; ok { + if m1, ok1 := m[contractAddress]; ok1 { + newMap = m1 + } else { + newMap = make(map[string]struct{}) + m[contractAddress] = newMap + } + } else { + m = make(map[common.Address]map[string]struct{}) + newMap = make(map[string]struct{}) + m[contractAddress] = newMap + tp.storage[newTimestamp] = m + } + } + newMap[hexS] = struct{}{} + } + if exists { + if contract == nil { + if m, ok := tp.accounts[prevTimestamp]; ok { + delete(m, string(hex)) + if len(m) == 0 { + delete(tp.accounts, prevTimestamp) + } + } + } else { + contractAddress := common.BytesToAddress(contract) + if m, ok := tp.storage[prevTimestamp]; ok { + if m1, ok1 := m[contractAddress]; ok1 { + delete(m1, string(hex)) + if len(m1) == 0 { + delete(m, contractAddress) + if len(m) == 0 { + delete(tp.storage, prevTimestamp) + } + } + } + } + } + } + // Update generation count + if 
!del { + tp.generationCounts[newTimestamp]++ + tp.nodeCount++ + } + if exists { + tp.generationCounts[prevTimestamp]-- + if tp.generationCounts[prevTimestamp] == 0 { + delete(tp.generationCounts, prevTimestamp) + } + tp.nodeCount-- + } +} -} \ No newline at end of file +// Updates a node to the current timestamp +// contract is effectively address of the smart contract +// hex is the prefix of the key +// parent is the node that needs to be modified to unload the touched node +func (tp *TriePruning) Touch(contract []byte, hex []byte, del bool) error { + var exists = false + var timestampInput [8]byte + var timestampOutput [8]byte + // Now it is the current timestamp, but after the transaction, it will be replaced + // by the previously existing (if it existed) + binary.BigEndian.PutUint64(timestampInput[:], tp.blockNr) + var cKey []byte + var bucket []byte + if contract == nil { + cKey = make([]byte, len(hex)+1) + cKey[0] = 0xff + copy(cKey[1:], hex) + bucket = abucket + } else { + cKey = make([]byte, len(contract)+len(hex)) + copy(cKey, contract) + copy(cKey[len(contract):], hex) + bucket = sbucket + } + if err := tp.timestamps.Update(func(tx *bolt.Tx) error { + b := tx.Bucket(bucket) + if b == nil { + return fmt.Errorf("timestamp bucket %s did not exist", bucket) + } + if v, _ := b.Get(cKey); v != nil { + if del { + if err := b.Delete(cKey); err != nil { + return err + } + } else if !bytes.Equal(v, timestampInput[:]) { + if err := b.Put(cKey, timestampInput[:]); err != nil { + return err + } + } + copy(timestampOutput[:], v) + exists = true + } else { + if !del { + if err := b.Put(cKey, timestampInput[:]); err != nil { + return err + } + } + } + return nil + }); err != nil { + return err + } + var prevTimestamp uint64 + if exists { + prevTimestamp = binary.BigEndian.Uint64(timestampOutput[:]) + } + tp.TouchFrom(contract, hex, exists, prevTimestamp, del, tp.blockNr) + return nil +} + +func pruneMap(t *Trie, m map[string]struct{}, h *hasher) bool { + hexes := 
make([]string, len(m)) + i := 0 + for hexS := range m { + hexes[i] = hexS + i++ + } + var empty = false + sort.Strings(hexes) + for i, hex := range hexes { + if i == 0 || len(hex) == 0 || !strings.HasPrefix(hex, hexes[i-1]) { // If the parent nodes are pruned, there is no need to prune descendants + t.unload([]byte(hex), h) + if len(hex) == 0 { + empty = true + } + } + } + return empty +} + +func (tp *TriePruning) PruneTo( + t *Trie, + targetNodeCount int, + storageTrieFunc func(contract common.Address) (*Trie, error), +) (bool, []common.Address, error) { + if tp.nodeCount <= targetNodeCount { + return false, nil, nil + } + excess := tp.nodeCount - targetNodeCount + prunable := 0 + pruneGeneration := tp.oldestGeneration + for prunable < excess { + prunable += tp.generationCounts[pruneGeneration] + delete(tp.generationCounts, pruneGeneration) + pruneGeneration++ + } + //fmt.Printf("Will prune to generation %d, nodes to prune: %d, excess %d\n", pruneGeneration, prunable, excess) + // Remove (unload) nodes from storage tries and account trie + aggregateStorage := make(map[common.Address]map[string]struct{}) + aggregateAccounts := make(map[string]struct{}) + for gen := tp.oldestGeneration; gen < pruneGeneration; gen++ { + if m, ok := tp.storage[gen]; ok { + for address, m1 := range m { + var aggregateM map[string]struct{} + if m2, ok2 := aggregateStorage[address]; ok2 { + aggregateM = m2 + } else { + aggregateM = make(map[string]struct{}) + aggregateStorage[address] = aggregateM + } + for hexS := range m1 { + aggregateM[hexS] = struct{}{} + } + } + } + delete(tp.storage, gen) + if m, ok := tp.accounts[gen]; ok { + for hexS := range m { + aggregateAccounts[hexS] = struct{}{} + } + } + delete(tp.accounts, gen) + } + var emptyAddresses []common.Address + h := newHasher(true) // Create hasher appropriate for storage tries first + defer returnHasherToPool(h) + for address, m := range aggregateStorage { + storageTrie, err := storageTrieFunc(address) + if err != nil { + 
return false, nil, err + } + empty := pruneMap(storageTrie, m, h) + if empty { + emptyAddresses = append(emptyAddresses, address) + } + } + // Change hasher to be appropriate for the main trie + h.encodeToBytes = false + pruneMap(t, aggregateAccounts, h) + // Remove fom the timestamp structure + if err := tp.timestamps.Update(func(tx *bolt.Tx) error { + ab := tx.Bucket(abucket) + if ab == nil { + return fmt.Errorf("timestamp bucket %s did not exist", abucket) + } + for hexS := range aggregateAccounts { + cKey := make([]byte, 1+len(hexS)) + cKey[0] = 0xff + copy(cKey[1:], []byte(hexS)) + if err := ab.Delete(cKey); err != nil { + return err + } + } + sb := tx.Bucket(sbucket) + if sb == nil { + return fmt.Errorf("timestamp bucket %s did not exist", sbucket) + } + for address, m := range aggregateStorage { + for hexS := range m { + cKey := make([]byte, len(address)+len(hexS)) + copy(cKey, address[:]) + copy(cKey[len(address):], []byte(hexS)) + if err := sb.Delete(cKey); err != nil { + return err + } + } + } + return nil + }); err != nil { + return false, nil, err + } + tp.oldestGeneration = pruneGeneration + tp.nodeCount -= prunable + return true, emptyAddresses, nil +} + +func (tp *TriePruning) NodeCount() int { + return tp.nodeCount +} + +func (tp *TriePruning) GenCounts() map[uint64]int { + return tp.generationCounts +} diff --git a/trie/trie_pruning_test.go b/trie/trie_pruning_test.go new file mode 100644 index 00000000000..e8e3c016ff1 --- /dev/null +++ b/trie/trie_pruning_test.go @@ -0,0 +1,68 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. 
+// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +// Pruning of the Merkle Patricia trees + +package trie + +import ( + "encoding/binary" + "fmt" + "testing" + + "github.com/ledgerwatch/turbo-geth/common" +) + +func TestOnePerTimestamp(t *testing.T) { + tp, err := NewTriePruning(0) + if err != nil { + t.Errorf("Error creating trie pruning: %v", err) + } + tr := New(common.Hash{}, false) + tr.SetTouchFunc(func(hex []byte, del bool) { + tp.Touch(nil, hex, del) + }) + var key [4]byte + value := []byte("V") + var timestamp uint64 = 0 + for n := uint32(0); n < uint32(100); n++ { + tp.SetBlockNr(timestamp) + binary.BigEndian.PutUint32(key[:], n) + tr.Update(key[:], value, timestamp) // Each key is added within a new generation + timestamp++ + } + for n := uint32(50); n < uint32(60); n++ { + tp.SetBlockNr(timestamp) + binary.BigEndian.PutUint32(key[:], n) + tr.Delete(key[:], timestamp) // Each key is deleted within a new generation + timestamp++ + } + for n := uint32(30); n < uint32(59); n++ { + tp.SetBlockNr(timestamp) + binary.BigEndian.PutUint32(key[:], n) + tr.Get(key[:], timestamp) // Each key is read within a new generation + timestamp++ + } + prunableNodes := tr.CountPrunableNodes() + fmt.Printf("Actual prunable nodes: %d, accounted: %d\n", prunableNodes, tp.NodeCount()) + if _, _, err := tp.PruneTo(tr, 4, func(contract common.Address) (*Trie, error) { + return nil, nil + }); err != nil { + t.Errorf("Error while pruning: %v", err) + } + prunableNodes = tr.CountPrunableNodes() + fmt.Printf("Actual prunable nodes: %d, accounted: %d\n", prunableNodes, tp.NodeCount()) +} From e3a57cd7776aa03f21822061964d2437223f3e27
Mon Sep 17 00:00:00 2001 From: Alexey Akhunov Date: Tue, 11 Jun 2019 20:25:44 +0100 Subject: [PATCH 2/5] Integrate with stateless prototype, use maps instead of bolt for timestamps --- core/right_6.txt | 1 - core/root_6.txt | 1 - core/state/database.go | 4 +- core/state/stateless.go | 64 ++++++-- trie/proof_generator.go | 114 +++++++------- trie/trie.go | 20 +-- trie/trie_pruning.go | 318 +++++++++++++++++++------------------- trie/trie_pruning_test.go | 2 +- 8 files changed, 276 insertions(+), 248 deletions(-) delete mode 100644 core/right_6.txt delete mode 100644 core/root_6.txt diff --git a/core/right_6.txt b/core/right_6.txt deleted file mode 100644 index 5431ed096d1..00000000000 --- a/core/right_6.txt +++ /dev/null @@ -1 +0,0 @@ -f(1:s(0f000e000204090d0d0e000f0d0c070c040e04010c01020e06070d0e0b070d020e04020402000b0d020e07090b0e0c0f09030e0c0704000702050e0c03080d10:v(f84c80881bc16d674ec80000a056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421a0c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470))5:s(0e05030b06020d0303030e070b040b06060605000b0707020903020f0c0d080d0a000d0106090b0b000b0601070402010d0b0d020e07080805030b0308010d10:v(f84c80881bc16d674ec80000a056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421a0c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470))11:s(0f08070d060f0403090c080a05040c090d020a030202000b03010b050c0e0b0e0c04010a0c050d070d090c0f01080d05070700070c040a0a0502090c07010d10:v(f84c80881bc16d674ec80000a056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421a0c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470))13:s(0f0c0f0b0d020905000203030804090c000901050c070c0a08010508050907000407050d07010b0a060c0b0e06070307020e0e0b0604030c020207090c040710:v(f84c80881bc16d674ec80000a056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421a0c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470))14:s(0703090b080c020d04050208090f090708050b030e040a01070b010f000f000e090f010f0c
0e00020c0601040f0109000a000503020f06020f0e0e08060d0610:v(f84c80881bc16d674ec80000a056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421a0c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470))15:s(02020b0b04060e0d0f03010a0f0805050903080b0e0f0a0a0807000e0d030d08060a040a0d09030a090e0c0f070c06030c0e0a0a08000d0a0e0a090a0c040d10:v(f84c80881bc16d674ec80000a056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421a0c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470))) diff --git a/core/root_6.txt b/core/root_6.txt deleted file mode 100644 index cdd3fa962b8..00000000000 --- a/core/root_6.txt +++ /dev/null @@ -1 +0,0 @@ -f(1:s(0f000e000204090d0d0e000f0d0c070c040e04010c01020e06070d0e0b070d020e04020402000b0d020e07090b0e0c0f09030e0c0704000702050e0c03080d10:v(f84c80883782dace9d900000a056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421a0c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470))5:s(0e05030b06020d0303030e070b040b06060605000b0707020903020f0c0d080d0a000d0106090b0b000b0601070402010d0b0d020e07080805030b0308010d10:v(f84c80881bc16d674ec80000a056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421a0c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470))11:s(0f08070d060f0403090c080a05040c090d020a030202000b03010b050c0e0b0e0c04010a0c050d070d090c0f01080d05070700070c040a0a0502090c07010d10:v(f84c80881bc16d674ec80000a056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421a0c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470))13:s(0f0c0f0b0d020905000203030804090c000901050c070c0a08010508050907000407050d07010b0a060c0b0e06070307020e0e0b0604030c020207090c040710:v(f84c80881bc16d674ec80000a056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421a0c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470))14:s(0703090b080c020d04050208090f090708050b030e040a01070b010f000f000e090f010f0c0e00020c0601040f0109000a000503020f06020f0e0e08060d0610:v(f84c80881bc16d674ec8000
0a056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421a0c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470))15:s(02020b0b04060e0d0f03010a0f0805050903080b0e0f0a0a0807000e0d030d08060a040a0d09030a090e0c0f070c06030c0e0a0a08000d0a0e0a090a0c040d10:v(f84c80881bc16d674ec80000a056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421a0c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470))) diff --git a/core/state/database.go b/core/state/database.go index d33359638cc..7c806d27e8c 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -233,7 +233,7 @@ func NewTrieDbState(root common.Hash, db ethdb.Database, blockNr uint64) (*TrieD tp: tp, } t.SetTouchFunc(func(hex []byte, del bool) { - tp.Touch(nil, hex, del) + tp.Touch(hex, del) }) tds.generationCounts = make(map[uint64]int, 4096) tds.oldestGeneration = blockNr @@ -808,7 +808,7 @@ func (tds *TrieDbState) getStorageTrie(address common.Address, create bool) (*tr t = trie.New(account.Root, true) } t.SetTouchFunc(func(hex []byte, del bool) { - tds.tp.Touch(common.CopyBytes(address[:]), hex, del) + tds.tp.TouchContract(address, hex, del) }) tds.storageTries[address] = t } diff --git a/core/state/stateless.go b/core/state/stateless.go index 3b6f403fa2c..3d46294fe29 100644 --- a/core/state/stateless.go +++ b/core/state/stateless.go @@ -43,6 +43,7 @@ type Stateless struct { storageUpdates map[common.Address]map[common.Hash][]byte accountUpdates map[common.Hash]*Account deleted map[common.Hash]struct{} + tp *trie.TriePruning } func NewStateless(stateRoot common.Hash, @@ -52,10 +53,18 @@ func NewStateless(stateRoot common.Hash, ) (*Stateless, error) { h := newHasher() defer returnHasherToPool(h) + tp, err := trie.NewTriePruning(blockNr) + if err != nil { + return nil, err + } if trace { fmt.Printf("ACCOUNT TRIE ==============================================\n") } - t, _, _, _, _ := trie.NewFromProofs(blockNr, false, blockProof.Masks, blockProof.ShortKeys, blockProof.Values, 
blockProof.Hashes, trace) + touchFunc := func(hex []byte, del bool) { + tp.Touch(hex, del) + } + t, _, _, _, _ := trie.NewFromProofs(touchFunc, blockNr, false, blockProof.Masks, blockProof.ShortKeys, blockProof.Values, blockProof.Hashes, trace) + t.SetTouchFunc(touchFunc) if stateRoot != t.Hash() { filename := fmt.Sprintf("root_%d.txt", blockNr) f, err := os.Create(filename) @@ -71,8 +80,13 @@ func NewStateless(stateRoot common.Hash, if trace { fmt.Printf("TRIE %x ==============================================\n", contract) } - st, mIdx, hIdx, sIdx, vIdx := trie.NewFromProofs(blockNr, true, + contractCopy := contract + touchFunc := func(hex []byte, del bool) { + tp.TouchContract(contractCopy, hex, del) + } + st, mIdx, hIdx, sIdx, vIdx := trie.NewFromProofs(touchFunc, blockNr, true, blockProof.CMasks[maskIdx:], blockProof.CShortKeys[shortIdx:], blockProof.CValues[valueIdx:], blockProof.CHashes[hashIdx:], trace) + st.SetTouchFunc(touchFunc) h.sha.Reset() h.sha.Write(contract[:]) var addrHash common.Hash @@ -133,10 +147,14 @@ func NewStateless(stateRoot common.Hash, storageUpdates: make(map[common.Address]map[common.Hash][]byte), accountUpdates: make(map[common.Hash]*Account), deleted: make(map[common.Hash]struct{}), + tp: tp, }, nil } func (s *Stateless) ThinProof(blockProof trie.BlockProof, blockNr uint64, cuttime uint64, trace bool) trie.BlockProof { + if blockNr != s.tp.BlockNr() { + panic(fmt.Sprintf("blockNr %d != s.tp.BlockNr() %d", blockNr, s.tp.BlockNr())) + } h := newHasher() defer returnHasherToPool(h) if trace { @@ -146,7 +164,10 @@ func (s *Stateless) ThinProof(blockProof trie.BlockProof, blockNr uint64, cuttim var aShortKeys, acShortKeys [][]byte var aValues, acValues [][]byte var aHashes, acHashes []common.Hash - _, _, _, _, aMasks, aShortKeys, aValues, aHashes = s.t.AmmendProofs(cuttime, blockProof.Masks, blockProof.ShortKeys, blockProof.Values, blockProof.Hashes, + timeFunc := func(hex []byte) uint64 { + return s.tp.Timestamp(hex) + } + _, _, _, _, 
aMasks, aShortKeys, aValues, aHashes = s.t.AmmendProofs(timeFunc, cuttime, blockProof.Masks, blockProof.ShortKeys, blockProof.Values, blockProof.Hashes, aMasks, aShortKeys, aValues, aHashes, trace) var maskIdx, hashIdx, shortIdx, valueIdx int aContracts := []common.Address{} @@ -162,7 +183,8 @@ func (s *Stateless) ThinProof(blockProof trie.BlockProof, blockNr uint64, cuttim var ok bool var mIdx, hIdx, sIdx, vIdx int if st, ok = s.storageTries[contract]; !ok { - _, mIdx, hIdx, sIdx, vIdx = trie.NewFromProofs(blockNr, true, + touchFunc := func(hex []byte, del bool) {} + _, mIdx, hIdx, sIdx, vIdx = trie.NewFromProofs(touchFunc, blockNr, true, blockProof.CMasks[maskIdx:], blockProof.CShortKeys[shortIdx:], blockProof.CValues[valueIdx:], blockProof.CHashes[hashIdx:], trace) if mIdx > 0 { acMasks = append(acMasks, blockProof.CMasks[maskIdx:maskIdx+mIdx]...) @@ -178,7 +200,11 @@ func (s *Stateless) ThinProof(blockProof trie.BlockProof, blockNr uint64, cuttim } aContracts = append(aContracts, contract) } else { - mIdx, hIdx, sIdx, vIdx, acMasks, acShortKeys, acValues, acHashes = st.AmmendProofs(cuttime, + contractCopy := contract + timeFunc := func(hex []byte) uint64 { + return s.tp.TimestampContract(contractCopy, hex) + } + mIdx, hIdx, sIdx, vIdx, acMasks, acShortKeys, acValues, acHashes = st.AmmendProofs(timeFunc, cuttime, blockProof.CMasks[maskIdx:], blockProof.CShortKeys[shortIdx:], blockProof.CValues[valueIdx:], blockProof.CHashes[hashIdx:], acMasks, acShortKeys, acValues, acHashes, trace) @@ -234,6 +260,9 @@ func (s *Stateless) ApplyProof(stateRoot common.Hash, blockProof trie.BlockProof blockNr uint64, trace bool, ) error { + if blockNr != s.tp.BlockNr() { + panic(fmt.Sprintf("blockNr %d != s.tp.BlockNr() %d", blockNr, s.tp.BlockNr())) + } h := newHasher() defer returnHasherToPool(h) if trace { @@ -264,8 +293,13 @@ func (s *Stateless) ApplyProof(stateRoot common.Hash, blockProof trie.BlockProof var ok bool var mIdx, hIdx, sIdx, vIdx int if st, ok = 
s.storageTries[contract]; !ok { - st, mIdx, hIdx, sIdx, vIdx = trie.NewFromProofs(blockNr, true, + contractCopy := contract + touchFunc := func(hex []byte, del bool) { + s.tp.TouchContract(contractCopy, hex, del) + } + st, mIdx, hIdx, sIdx, vIdx = trie.NewFromProofs(touchFunc, blockNr, true, blockProof.CMasks[maskIdx:], blockProof.CShortKeys[shortIdx:], blockProof.CValues[valueIdx:], blockProof.CHashes[hashIdx:], trace) + st.SetTouchFunc(touchFunc) s.storageTries[contract] = st } else { mIdx, hIdx, sIdx, vIdx = st.ApplyProof(blockNr, blockProof.CMasks[maskIdx:], blockProof.CShortKeys[shortIdx:], blockProof.CValues[valueIdx:], blockProof.CHashes[hashIdx:], trace) @@ -304,6 +338,7 @@ func (s *Stateless) ApplyProof(stateRoot common.Hash, blockProof trie.BlockProof func (s *Stateless) SetBlockNr(blockNr uint64) { s.blockNr = blockNr + s.tp.SetBlockNr(blockNr) } func (s *Stateless) ReadAccountData(address common.Address) (*Account, error) { @@ -324,6 +359,9 @@ func (s *Stateless) getStorageTrie(address common.Address, create bool) (*trie.T t, ok := s.storageTries[address] if !ok && create { t = trie.New(common.Hash{}, true) + t.SetTouchFunc(func(hex []byte, del bool) { + s.tp.TouchContract(address, hex, del) + }) s.storageTries[address] = t } return t, nil @@ -525,12 +563,14 @@ func (s *Stateless) WriteAccountStorage(address common.Address, key, original, v } func (s *Stateless) Prune(oldest uint64, trace bool) { - s.t.UnloadOlderThan(oldest, trace) - for addrHash, st := range s.storageTries { - empty := st.UnloadOlderThan(oldest, trace) - if empty { - delete(s.storageTries, addrHash) - } + emptyAddresses, err := s.tp.PruneToTimestamp(s.t, oldest, func(contract common.Address) (*trie.Trie, error) { + return s.getStorageTrie(contract, false) + }) + if err != nil { + fmt.Printf("Error while pruning: %v\n", err) + } + for _, address := range emptyAddresses { + delete(s.storageTries, address) } if m, ok := s.timeToCodeHash[oldest-1]; ok { for codeHash, _ := range m { diff 
--git a/trie/proof_generator.go b/trie/proof_generator.go index 409f938a6cd..48e804cd261 100644 --- a/trie/proof_generator.go +++ b/trie/proof_generator.go @@ -42,22 +42,22 @@ type BlockProof struct { } type ProofGenerator struct { - proofMasks map[string]uint32 - sMasks map[string]map[string]uint32 - proofHashes map[string][16]common.Hash - sHashes map[string]map[string][16]common.Hash - soleHashes map[string]common.Hash - sSoleHashes map[string]map[string]common.Hash - createdProofs map[string]struct{} - sCreatedProofs map[string]map[string]struct{} - proofShorts map[string][]byte - sShorts map[string]map[string][]byte - createdShorts map[string]struct{} - sCreatedShorts map[string]map[string]struct{} - proofValues map[string][]byte - sValues map[string]map[string][]byte - proofCodes map[common.Hash][]byte - createdCodes map[common.Hash][]byte + proofMasks map[string]uint32 + sMasks map[string]map[string]uint32 + proofHashes map[string][16]common.Hash + sHashes map[string]map[string][16]common.Hash + soleHashes map[string]common.Hash + sSoleHashes map[string]map[string]common.Hash + createdProofs map[string]struct{} + sCreatedProofs map[string]map[string]struct{} + proofShorts map[string][]byte + sShorts map[string]map[string][]byte + createdShorts map[string]struct{} + sCreatedShorts map[string]map[string]struct{} + proofValues map[string][]byte + sValues map[string]map[string][]byte + proofCodes map[common.Hash][]byte + createdCodes map[common.Hash][]byte } func NewProofGenerator() *ProofGenerator { @@ -495,7 +495,8 @@ func (pg *ProofGenerator) CreateCode(codeHash common.Hash, code []byte) { } } -func constructFullNode(ctime uint64, pos int, +func constructFullNode(touchFunc func(hex []byte, del bool), ctime uint64, + hex []byte, masks []uint16, shortKeys [][]byte, values [][]byte, @@ -503,6 +504,7 @@ func constructFullNode(ctime uint64, pos int, maskIdx, shortIdx, valueIdx, hashIdx *int, trace bool, ) *fullNode { + pos := len(hex) hashmask := masks[*maskIdx] 
(*maskIdx)++ fullnodemask := masks[*maskIdx] @@ -516,6 +518,7 @@ func constructFullNode(ctime uint64, pos int, // Make a full node f := &fullNode{} f.flags.dirty = true + touchFunc(hex, false) f.flags.t = ctime for nibble := byte(0); nibble < 16; nibble++ { if (hashmask & (uint16(1) << nibble)) != 0 { @@ -540,19 +543,20 @@ func constructFullNode(ctime uint64, pos int, if trace { fmt.Printf("%sIn the loop at pos: %d, hashes: %16b, fullnodes: %16b, shortnodes: %16b, nibble %x\n", strings.Repeat(" ", pos), pos, hashmask, fullnodemask, shortnodemask, nibble) } - f.Children[nibble] = constructFullNode(ctime, pos+1, masks, shortKeys, values, hashes, maskIdx, shortIdx, valueIdx, hashIdx, trace) + f.Children[nibble] = constructFullNode(touchFunc, ctime, concat(hex, nibble), masks, shortKeys, values, hashes, maskIdx, shortIdx, valueIdx, hashIdx, trace) } else if (shortnodemask & (uint16(1) << nibble)) != 0 { if trace { fmt.Printf("%sIn the loop at pos: %d, hashes: %16b, fullnodes: %16b, shortnodes: %16b, nibble %x\n", strings.Repeat(" ", pos), pos, hashmask, fullnodemask, shortnodemask, nibble) } - f.Children[nibble] = constructShortNode(ctime, pos+1, masks, shortKeys, values, hashes, maskIdx, shortIdx, valueIdx, hashIdx, trace) + f.Children[nibble] = constructShortNode(touchFunc, ctime, concat(hex, nibble), masks, shortKeys, values, hashes, maskIdx, shortIdx, valueIdx, hashIdx, trace) } } f.adjustTod(ctime) return f } -func constructShortNode(ctime uint64, pos int, +func constructShortNode(touchFunc func(hex []byte, del bool), ctime uint64, + hex []byte, masks []uint16, shortKeys [][]byte, values [][]byte, @@ -560,6 +564,7 @@ func constructShortNode(ctime uint64, pos int, maskIdx, shortIdx, valueIdx, hashIdx *int, trace bool, ) *shortNode { + pos := len(hex) downmask := masks[*maskIdx] (*maskIdx)++ if trace { @@ -589,7 +594,7 @@ func constructShortNode(ctime uint64, pos int, s.Val = hashNode(hash[:]) (*hashIdx)++ } else if downmask == 1 || downmask == 6 { - s.Val = 
constructFullNode(ctime, pos+len(nKey), masks, shortKeys, values, hashes, maskIdx, shortIdx, valueIdx, hashIdx, trace) + s.Val = constructFullNode(touchFunc, ctime, concat(hex, nKey...), masks, shortKeys, values, hashes, maskIdx, shortIdx, valueIdx, hashIdx, trace) } } if s.Val == nil { @@ -599,7 +604,7 @@ func constructShortNode(ctime uint64, pos int, return s } -func NewFromProofs(ctime uint64, +func NewFromProofs(touchFunc func(hex []byte, del bool), ctime uint64, encodeToBytes bool, masks []uint16, shortKeys [][]byte, @@ -622,15 +627,15 @@ func NewFromProofs(ctime uint64, firstMask := masks[0] maskIdx = 1 if firstMask == 0 { - t.root = constructFullNode(ctime, 0, masks, shortKeys, values, hashes, &maskIdx, &shortIdx, &valueIdx, &hashIdx, trace) + t.root = constructFullNode(touchFunc, ctime, []byte{}, masks, shortKeys, values, hashes, &maskIdx, &shortIdx, &valueIdx, &hashIdx, trace) } else { - t.root = constructShortNode(ctime, 0, masks, shortKeys, values, hashes, &maskIdx, &shortIdx, &valueIdx, &hashIdx, trace) + t.root = constructShortNode(touchFunc, ctime, []byte{}, masks, shortKeys, values, hashes, &maskIdx, &shortIdx, &valueIdx, &hashIdx, trace) } return t, maskIdx, hashIdx, shortIdx, valueIdx } -func ammendFullNode(cuttime uint64, n node, - pos int, +func ammendFullNode(timeFunc func(hex []byte) uint64, cuttime uint64, n node, + hex []byte, masks []uint16, shortKeys [][]byte, values [][]byte, @@ -642,6 +647,7 @@ func ammendFullNode(cuttime uint64, n node, aHashes []common.Hash, trace bool, ) ([]uint16, [][]byte, [][]byte, []common.Hash) { + pos := len(hex) hashmask := masks[*maskIdx] (*maskIdx)++ fullnodemask := masks[*maskIdx] @@ -669,9 +675,9 @@ func ammendFullNode(cuttime uint64, n node, } } if ok && trace { - fmt.Printf("%sf.flags.t %d, cuttime %d\n", strings.Repeat(" ", pos), f.flags.t, cuttime) + fmt.Printf("%sf.flags.t %d, cuttime %d\n", strings.Repeat(" ", pos), timeFunc(hex), cuttime) } - if ok && f.flags.t < cuttime { + if ok && timeFunc(hex) < 
cuttime { f = nil ok = false } @@ -705,7 +711,7 @@ func ammendFullNode(cuttime uint64, n node, fmt.Printf("%sIn the loop at pos: %d, hashes: %16b, fullnodes: %16b, shortnodes: %16b, nibble %x, fchild %T\n", strings.Repeat(" ", pos), pos, hashmask, fullnodemask, shortnodemask, nibble, child) } - aMasks, aShortKeys, aValues, aHashes = ammendFullNode(cuttime, child, pos+1, masks, shortKeys, values, hashes, + aMasks, aShortKeys, aValues, aHashes = ammendFullNode(timeFunc, cuttime, child, concat(hex, nibble), masks, shortKeys, values, hashes, maskIdx, shortIdx, valueIdx, hashIdx, aMasks, aShortKeys, aValues, aHashes, trace) aFullnodemask |= (uint16(1) << nibble) @@ -714,7 +720,7 @@ func ammendFullNode(cuttime uint64, n node, fmt.Printf("%sIn the loop at pos: %d, hashes: %16b, fullnodes: %16b, shortnodes: %16b, nibble %x, schild %T\n", strings.Repeat(" ", pos), pos, hashmask, fullnodemask, shortnodemask, nibble, child) } - aMasks, aShortKeys, aValues, aHashes = ammendShortNode(cuttime, child, pos+1, masks, shortKeys, values, hashes, + aMasks, aShortKeys, aValues, aHashes = ammendShortNode(timeFunc, cuttime, child, concat(hex, nibble), masks, shortKeys, values, hashes, maskIdx, shortIdx, valueIdx, hashIdx, aMasks, aShortKeys, aValues, aHashes, trace) aShortnodemask |= (uint16(1) << nibble) @@ -727,8 +733,8 @@ func ammendFullNode(cuttime uint64, n node, return aMasks, aShortKeys, aValues, aHashes } -func ammendShortNode(cuttime uint64, n node, - pos int, +func ammendShortNode(timeFunc func(hex []byte) uint64, cuttime uint64, n node, + hex []byte, masks []uint16, shortKeys [][]byte, values [][]byte, @@ -740,6 +746,7 @@ func ammendShortNode(cuttime uint64, n node, aHashes []common.Hash, trace bool, ) ([]uint16, [][]byte, [][]byte, []common.Hash) { + pos := len(hex) downmask := masks[*maskIdx] (*maskIdx)++ // short node (leaf or extension) @@ -749,13 +756,6 @@ func ammendShortNode(cuttime uint64, n node, fmt.Printf("%spos: %d, down: %16b, nKey %x", strings.Repeat(" ", pos), 
pos, downmask, nKey) } s, ok := n.(*shortNode) - if ok && trace { - fmt.Printf("%ss.flags.t %d, cuttime %d\n", strings.Repeat(" ", pos), s.flags.t, cuttime) - } - if ok && s.flags.t < cuttime { - s = nil - ok = false - } if trace { fmt.Printf("\n") } @@ -795,8 +795,8 @@ func ammendShortNode(cuttime uint64, n node, val = s.Val aMasks = append(aMasks, 7) } - aMasks, aShortKeys, aValues, aHashes = ammendFullNode(cuttime, - val, pos+len(nKey), masks, shortKeys, values, hashes, + aMasks, aShortKeys, aValues, aHashes = ammendFullNode(timeFunc, cuttime, + val, concat(hex, nKey...), masks, shortKeys, values, hashes, maskIdx, shortIdx, valueIdx, hashIdx, aMasks, aShortKeys, aValues, aHashes, trace) @@ -806,6 +806,7 @@ func ammendShortNode(cuttime uint64, n node, } func (t *Trie) AmmendProofs( + timeFunc func(hex []byte) uint64, cuttime uint64, masks []uint16, shortKeys [][]byte, @@ -825,19 +826,19 @@ func (t *Trie) AmmendProofs( maskIdx = 1 aMasks = append(aMasks, firstMask) if firstMask == 0 { - aMasks_, aShortKeys_, aValues_, aHashes_ = ammendFullNode(cuttime, t.root, 0, masks, shortKeys, values, hashes, + aMasks_, aShortKeys_, aValues_, aHashes_ = ammendFullNode(timeFunc, cuttime, t.root, []byte{}, masks, shortKeys, values, hashes, &maskIdx, &shortIdx, &valueIdx, &hashIdx, aMasks, aShortKeys, aValues, aHashes, trace) } else { - aMasks_, aShortKeys_, aValues_, aHashes_ = ammendShortNode(cuttime, t.root, 0, masks, shortKeys, values, hashes, + aMasks_, aShortKeys_, aValues_, aHashes_ = ammendShortNode(timeFunc, cuttime, t.root, []byte{}, masks, shortKeys, values, hashes, &maskIdx, &shortIdx, &valueIdx, &hashIdx, aMasks, aShortKeys, aValues, aHashes, trace) } return maskIdx, hashIdx, shortIdx, valueIdx, aMasks_, aShortKeys_, aValues_, aHashes_ } -func applyFullNode(h *hasher, ctime uint64, n node, - pos int, +func applyFullNode(h *hasher, touchFunc func(hex []byte, del bool), ctime uint64, n node, + hex []byte, masks []uint16, shortKeys [][]byte, values [][]byte, @@ -845,6 
+846,7 @@ func applyFullNode(h *hasher, ctime uint64, n node, maskIdx, shortIdx, valueIdx, hashIdx *int, trace bool, ) *fullNode { + pos := len(hex) hashmask := masks[*maskIdx] (*maskIdx)++ fullnodemask := masks[*maskIdx] @@ -867,6 +869,7 @@ func applyFullNode(h *hasher, ctime uint64, n node, f.flags.dirty = true } } + touchFunc(hex, false) f.flags.t = ctime for nibble := byte(0); nibble < 16; nibble++ { if (hashmask & (uint16(1) << nibble)) != 0 { @@ -900,7 +903,7 @@ func applyFullNode(h *hasher, ctime uint64, n node, fmt.Printf("%sIn the loop at pos: %d, hashes: %16b, fullnodes: %16b, shortnodes: %16b, nibble %x, child %T\n", strings.Repeat(" ", pos), pos, hashmask, fullnodemask, shortnodemask, nibble, child) } - fn := applyFullNode(h, ctime, child, pos+1, masks, shortKeys, values, hashes, + fn := applyFullNode(h, touchFunc, ctime, child, concat(hex, nibble), masks, shortKeys, values, hashes, maskIdx, shortIdx, valueIdx, hashIdx, trace) f.Children[nibble] = fn } else if (shortnodemask & (uint16(1) << nibble)) != 0 { @@ -908,7 +911,7 @@ func applyFullNode(h *hasher, ctime uint64, n node, fmt.Printf("%sIn the loop at pos: %d, hashes: %16b, fullnodes: %16b, shortnodes: %16b, nibble %x, child %T\n", strings.Repeat(" ", pos), pos, hashmask, fullnodemask, shortnodemask, nibble, child) } - sn := applyShortNode(h, ctime, child, pos+1, masks, shortKeys, values, hashes, + sn := applyShortNode(h, touchFunc, ctime, child, concat(hex, nibble), masks, shortKeys, values, hashes, maskIdx, shortIdx, valueIdx, hashIdx, trace) f.Children[nibble] = sn } @@ -921,8 +924,8 @@ func applyFullNode(h *hasher, ctime uint64, n node, return f } -func applyShortNode(h *hasher, ctime uint64, n node, - pos int, +func applyShortNode(h *hasher, touchFunc func(hex []byte, del bool), ctime uint64, n node, + hex []byte, masks []uint16, shortKeys [][]byte, values [][]byte, @@ -930,6 +933,7 @@ func applyShortNode(h *hasher, ctime uint64, n node, maskIdx, shortIdx, valueIdx, hashIdx *int, trace bool, ) 
*shortNode { + pos := len(hex) downmask := masks[*maskIdx] (*maskIdx)++ // short node (leaf or extension) @@ -946,7 +950,6 @@ func applyShortNode(h *hasher, ctime uint64, n node, s = &shortNode{Key: hexToCompact(nKey)} s.flags.dirty = true } - s.flags.t = ctime if trace { fmt.Printf("%spos: %d, down: %16b, nKey: %x", strings.Repeat(" ", pos), pos, downmask, nKey) } @@ -956,6 +959,7 @@ func applyShortNode(h *hasher, ctime uint64, n node, fmt.Printf("%skeep existing short node %x\n", strings.Repeat(" ", pos), compactToHex(s.Key)) } } + s.flags.t = ctime switch downmask { case 0: if pos+len(nKey) == 65 { @@ -973,7 +977,7 @@ func applyShortNode(h *hasher, ctime uint64, n node, (*valueIdx)++ s.Val = valueNode(value) } else { - s.Val = applyFullNode(h, ctime, s.Val, pos+len(nKey), masks, shortKeys, values, hashes, + s.Val = applyFullNode(h, touchFunc, ctime, s.Val, concat(hex, nKey...), masks, shortKeys, values, hashes, maskIdx, shortIdx, valueIdx, hashIdx, trace) } case 2: @@ -993,10 +997,10 @@ func applyShortNode(h *hasher, ctime uint64, n node, fmt.Printf("%spos = %d, len(nKey) = %d, nKey = %x\n", strings.Repeat(" ", pos), pos, len(nKey), nKey) } case 6: - s.Val = applyFullNode(h, ctime, nil, pos+len(nKey), masks, shortKeys, values, hashes, + s.Val = applyFullNode(h, touchFunc, ctime, nil, concat(hex, nKey...), masks, shortKeys, values, hashes, maskIdx, shortIdx, valueIdx, hashIdx, trace) case 7: - s.Val = applyFullNode(h, ctime, s.Val, pos+len(nKey), masks, shortKeys, values, hashes, + s.Val = applyFullNode(h, touchFunc, ctime, s.Val, concat(hex, nKey...), masks, shortKeys, values, hashes, maskIdx, shortIdx, valueIdx, hashIdx, trace) } s.adjustTod(ctime) @@ -1027,10 +1031,10 @@ func (t *Trie) ApplyProof( h := newHasher(t.encodeToBytes) defer returnHasherToPool(h) if firstMask == 0 { - t.root = applyFullNode(h, ctime, t.root, 0, masks, shortKeys, values, hashes, + t.root = applyFullNode(h, t.touchFunc, ctime, t.root, []byte{}, masks, shortKeys, values, hashes, 
&maskIdx, &shortIdx, &valueIdx, &hashIdx, trace) } else { - t.root = applyShortNode(h, ctime, t.root, 0, masks, shortKeys, values, hashes, + t.root = applyShortNode(h, t.touchFunc, ctime, t.root, []byte{}, masks, shortKeys, values, hashes, &maskIdx, &shortIdx, &valueIdx, &hashIdx, trace) } return maskIdx, hashIdx, shortIdx, valueIdx diff --git a/trie/trie.go b/trie/trie.go index a302a6e757d..aa39a75bb58 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -839,10 +839,7 @@ func (t *Trie) prepareToRemove(n node, hex []byte) { case *shortNode: var hexVal []byte if _, ok := n.Val.(valueNode); !ok { // Don't need to compute prefix for a leaf - nKey := compactToHex(n.Key) - hexVal = make([]byte, len(hex)+len(nKey)) - copy(hexVal, hex) - copy(hexVal[len(hex):], nKey) + hexVal = concat(hex, compactToHex(n.Key)...) } t.leftGeneration(n.flags.t) t.prepareToRemove(n.Val, hexVal) @@ -863,10 +860,7 @@ func (t *Trie) prepareToRemove(n node, hex []byte) { t.leftGeneration(n.flags.t) for i, child := range n.Children { if child != nil { - hexChild := make([]byte, len(hex)+1) - copy(hexChild, hex) - hexChild[len(hex)] = byte(i) - t.prepareToRemove(child, hexChild) + t.prepareToRemove(child, concat(hex, byte(i))) } } } @@ -1033,10 +1027,7 @@ func (t *Trie) countPrunableNodes(nd node, hex []byte, print bool) int { case *shortNode: var hexVal []byte if _, ok := n.Val.(valueNode); !ok { // Don't need to compute prefix for a leaf - nKey := compactToHex(n.Key) - hexVal = make([]byte, len(hex)+len(nKey)) - copy(hexVal, hex) - copy(hexVal[len(hex):], nKey) + hexVal = concat(hex, compactToHex(n.Key)...) 
} return t.countPrunableNodes(n.Val, hexVal, print) case *duoNode: @@ -1058,10 +1049,7 @@ func (t *Trie) countPrunableNodes(nd node, hex []byte, print bool) int { count := 0 for i, child := range n.Children { if child != nil { - hexChild := make([]byte, len(hex)+1) - copy(hexChild, hex) - hexChild[len(hex)] = byte(i) - count += t.countPrunableNodes(child, hexChild, print) + count += t.countPrunableNodes(child, concat(hex, byte(i)), print) } } return 1 + count diff --git a/trie/trie_pruning.go b/trie/trie_pruning.go index 95fed96098e..96eeb5f0dc0 100644 --- a/trie/trie_pruning.go +++ b/trie/trie_pruning.go @@ -19,20 +19,15 @@ package trie import ( - "bytes" - "encoding/binary" - "fmt" "sort" "strings" - "github.com/ledgerwatch/bolt" "github.com/ledgerwatch/turbo-geth/common" ) type TriePruning struct { - // boltDB is used here for its B+tree implementation with prefix-compression. - // It maps prefixes to their corresponding timestamps (uint64) - timestamps *bolt.DB + storageTimestamps map[common.Address]map[string]uint64 + accountTimestamps map[string]uint64 // Maps timestamp (uint64) to address of the contract to set of prefixes of nodes (string) storage map[uint64]map[common.Address]map[string]struct{} @@ -54,30 +49,14 @@ type TriePruning struct { } func NewTriePruning(oldestGeneration uint64) (*TriePruning, error) { - db, err := bolt.Open("in-memory", 0600, &bolt.Options{MemOnly: true}) - if err != nil { - return nil, err - } - // Pre-create the bucket so we can assume it is there - if err := db.Update(func(tx *bolt.Tx) error { - if _, err := tx.CreateBucket(abucket, false); err != nil { - return err - } - if _, err := tx.CreateBucket(sbucket, false); err != nil { - return err - } - return nil - }); err != nil { - db.Close() - return nil, err - } return &TriePruning{ - oldestGeneration: oldestGeneration, - blockNr: oldestGeneration, - timestamps: db, - storage: make(map[uint64]map[common.Address]map[string]struct{}), - accounts: 
make(map[uint64]map[string]struct{}), - generationCounts: make(map[uint64]int), + oldestGeneration: oldestGeneration, + blockNr: oldestGeneration, + storageTimestamps: make(map[common.Address]map[string]uint64), + accountTimestamps: make(map[string]uint64), + storage: make(map[uint64]map[common.Address]map[string]struct{}), + accounts: make(map[uint64]map[string]struct{}), + generationCounts: make(map[uint64]int), }, nil } @@ -85,8 +64,9 @@ func (tp *TriePruning) SetBlockNr(blockNr uint64) { tp.blockNr = blockNr } -var abucket = []byte("a") -var sbucket = []byte("s") +func (tp *TriePruning) BlockNr() uint64 { + return tp.blockNr +} // Updates a node to the current timestamp // contract is effectively address of the smart contract @@ -94,57 +74,35 @@ var sbucket = []byte("s") // parent is the node that needs to be modified to unload the touched node // exists is true when the node existed before, and false if it is a new one // prevTimestamp is the timestamp the node current has -func (tp *TriePruning) TouchFrom(contract []byte, hex []byte, exists bool, prevTimestamp uint64, del bool, newTimestamp uint64) { - //fmt.Printf("TouchFrom %x, exists: %t, prevTimestamp %d, del %t, newTimestamp %d\n", hex, exists, prevTimestamp, del, newTimestamp) +func (tp *TriePruning) touchContract(contract common.Address, hexS string, exists bool, prevTimestamp uint64, del bool, newTimestamp uint64) { if exists && !del && prevTimestamp == newTimestamp { return } if !del { - hexS := string(common.CopyBytes(hex)) var newMap map[string]struct{} - if contract == nil { - if m, ok := tp.accounts[newTimestamp]; ok { - newMap = m + if m, ok := tp.storage[newTimestamp]; ok { + if m1, ok1 := m[contract]; ok1 { + newMap = m1 } else { newMap = make(map[string]struct{}) - tp.accounts[newTimestamp] = newMap + m[contract] = newMap } } else { - contractAddress := common.BytesToAddress(contract) - if m, ok := tp.storage[newTimestamp]; ok { - if m1, ok1 := m[contractAddress]; ok1 { - newMap = m1 - } else 
{ - newMap = make(map[string]struct{}) - m[contractAddress] = newMap - } - } else { - m = make(map[common.Address]map[string]struct{}) - newMap = make(map[string]struct{}) - m[contractAddress] = newMap - tp.storage[newTimestamp] = m - } + m = make(map[common.Address]map[string]struct{}) + newMap = make(map[string]struct{}) + m[contract] = newMap + tp.storage[newTimestamp] = m } newMap[hexS] = struct{}{} } if exists { - if contract == nil { - if m, ok := tp.accounts[prevTimestamp]; ok { - delete(m, string(hex)) - if len(m) == 0 { - delete(tp.accounts, prevTimestamp) - } - } - } else { - contractAddress := common.BytesToAddress(contract) - if m, ok := tp.storage[prevTimestamp]; ok { - if m1, ok1 := m[contractAddress]; ok1 { - delete(m1, string(hex)) - if len(m1) == 0 { - delete(m, contractAddress) - if len(m) == 0 { - delete(tp.storage, prevTimestamp) - } + if m, ok := tp.storage[prevTimestamp]; ok { + if m1, ok1 := m[contract]; ok1 { + delete(m1, hexS) + if len(m1) == 0 { + delete(m, contract) + if len(m) == 0 { + delete(tp.storage, prevTimestamp) } } } @@ -168,59 +126,102 @@ func (tp *TriePruning) TouchFrom(contract []byte, hex []byte, exists bool, prevT // contract is effectively address of the smart contract // hex is the prefix of the key // parent is the node that needs to be modified to unload the touched node -func (tp *TriePruning) Touch(contract []byte, hex []byte, del bool) error { - var exists = false - var timestampInput [8]byte - var timestampOutput [8]byte - // Now it is the current timestamp, but after the transaction, it will be replaced - // by the previously existing (if it existed) - binary.BigEndian.PutUint64(timestampInput[:], tp.blockNr) - var cKey []byte - var bucket []byte - if contract == nil { - cKey = make([]byte, len(hex)+1) - cKey[0] = 0xff - copy(cKey[1:], hex) - bucket = abucket - } else { - cKey = make([]byte, len(contract)+len(hex)) - copy(cKey, contract) - copy(cKey[len(contract):], hex) - bucket = sbucket +// exists is true when 
the node existed before, and false if it is a new one +// prevTimestamp is the timestamp the node current has +func (tp *TriePruning) touch(hexS string, exists bool, prevTimestamp uint64, del bool, newTimestamp uint64) { + //fmt.Printf("TouchFrom %x, exists: %t, prevTimestamp %d, del %t, newTimestamp %d\n", hex, exists, prevTimestamp, del, newTimestamp) + if exists && !del && prevTimestamp == newTimestamp { + return } - if err := tp.timestamps.Update(func(tx *bolt.Tx) error { - b := tx.Bucket(bucket) - if b == nil { - return fmt.Errorf("timestamp bucket %s did not exist", bucket) + if !del { + var newMap map[string]struct{} + if m, ok := tp.accounts[newTimestamp]; ok { + newMap = m + } else { + newMap = make(map[string]struct{}) + tp.accounts[newTimestamp] = newMap } - if v, _ := b.Get(cKey); v != nil { - if del { - if err := b.Delete(cKey); err != nil { - return err - } - } else if !bytes.Equal(v, timestampInput[:]) { - if err := b.Put(cKey, timestampInput[:]); err != nil { - return err - } + newMap[hexS] = struct{}{} + } + if exists { + if m, ok := tp.accounts[prevTimestamp]; ok { + delete(m, hexS) + if len(m) == 0 { + delete(tp.accounts, prevTimestamp) } - copy(timestampOutput[:], v) + } + } + // Update generation count + if !del { + tp.generationCounts[newTimestamp]++ + tp.nodeCount++ + } + if exists { + tp.generationCounts[prevTimestamp]-- + if tp.generationCounts[prevTimestamp] == 0 { + delete(tp.generationCounts, prevTimestamp) + } + tp.nodeCount-- + } +} + +func (tp *TriePruning) Timestamp(hex []byte) uint64 { + return tp.accountTimestamps[string(hex)] +} + +// Returns timestamp for the given prunable node +func (tp *TriePruning) TimestampContract(contract common.Address, hex []byte) uint64 { + if m, ok := tp.storageTimestamps[contract]; ok { + return m[string(hex)] + } + return 0 +} + +func (tp *TriePruning) TouchContract(contract common.Address, hex []byte, del bool) { + var exists = false + var prevTimestamp uint64 + hexS := string(common.CopyBytes(hex)) 
+ if m, ok := tp.storageTimestamps[contract]; ok { + if m1, ok1 := m[hexS]; ok1 { + prevTimestamp = m1 exists = true - } else { - if !del { - if err := b.Put(cKey, timestampInput[:]); err != nil { - return err + if del { + delete(m, hexS) + if len(m) == 0 { + delete(tp.storageTimestamps, contract) } } } - return nil - }); err != nil { - return err + if !del { + m[hexS] = tp.blockNr + } + } else if !del { + m = make(map[string]uint64) + tp.storageTimestamps[contract] = m + m[hexS] = tp.blockNr } + tp.touchContract(contract, hexS, exists, prevTimestamp, del, tp.blockNr) +} + +// Updates a node to the current timestamp +// contract is effectively address of the smart contract +// hex is the prefix of the key +// parent is the node that needs to be modified to unload the touched node +func (tp *TriePruning) Touch(hex []byte, del bool) error { + var exists = false var prevTimestamp uint64 - if exists { - prevTimestamp = binary.BigEndian.Uint64(timestampOutput[:]) + hexS := string(common.CopyBytes(hex)) + if m, ok := tp.accountTimestamps[hexS]; ok { + prevTimestamp = m + exists = true + if del { + delete(tp.accountTimestamps, hexS) + } + } + if !del { + tp.accountTimestamps[hexS] = tp.blockNr } - tp.TouchFrom(contract, hex, exists, prevTimestamp, del, tp.blockNr) + tp.touch(hexS, exists, prevTimestamp, del, tp.blockNr) return nil } @@ -244,27 +245,18 @@ func pruneMap(t *Trie, m map[string]struct{}, h *hasher) bool { return empty } -func (tp *TriePruning) PruneTo( - t *Trie, - targetNodeCount int, +// Prunes all nodes that are older than given timestamp +func (tp *TriePruning) PruneToTimestamp( + mainTrie *Trie, + targetTimestamp uint64, storageTrieFunc func(contract common.Address) (*Trie, error), -) (bool, []common.Address, error) { - if tp.nodeCount <= targetNodeCount { - return false, nil, nil - } - excess := tp.nodeCount - targetNodeCount - prunable := 0 - pruneGeneration := tp.oldestGeneration - for prunable < excess { - prunable += 
tp.generationCounts[pruneGeneration] - delete(tp.generationCounts, pruneGeneration) - pruneGeneration++ - } - //fmt.Printf("Will prune to generation %d, nodes to prune: %d, excess %d\n", pruneGeneration, prunable, excess) +) ([]common.Address, error) { // Remove (unload) nodes from storage tries and account trie aggregateStorage := make(map[common.Address]map[string]struct{}) aggregateAccounts := make(map[string]struct{}) - for gen := tp.oldestGeneration; gen < pruneGeneration; gen++ { + for gen := tp.oldestGeneration; gen < targetTimestamp; gen++ { + tp.nodeCount -= tp.generationCounts[gen] + delete(tp.generationCounts, gen) if m, ok := tp.storage[gen]; ok { for address, m1 := range m { var aggregateM map[string]struct{} @@ -293,7 +285,7 @@ func (tp *TriePruning) PruneTo( for address, m := range aggregateStorage { storageTrie, err := storageTrieFunc(address) if err != nil { - return false, nil, err + return nil, err } empty := pruneMap(storageTrie, m, h) if empty { @@ -302,41 +294,47 @@ func (tp *TriePruning) PruneTo( } // Change hasher to be appropriate for the main trie h.encodeToBytes = false - pruneMap(t, aggregateAccounts, h) + pruneMap(mainTrie, aggregateAccounts, h) // Remove fom the timestamp structure - if err := tp.timestamps.Update(func(tx *bolt.Tx) error { - ab := tx.Bucket(abucket) - if ab == nil { - return fmt.Errorf("timestamp bucket %s did not exist", abucket) - } - for hexS := range aggregateAccounts { - cKey := make([]byte, 1+len(hexS)) - cKey[0] = 0xff - copy(cKey[1:], []byte(hexS)) - if err := ab.Delete(cKey); err != nil { - return err - } - } - sb := tx.Bucket(sbucket) - if sb == nil { - return fmt.Errorf("timestamp bucket %s did not exist", sbucket) - } - for address, m := range aggregateStorage { + for hexS := range aggregateAccounts { + delete(tp.accountTimestamps, hexS) + } + for address, m := range aggregateStorage { + if m1, ok := tp.storageTimestamps[address]; ok { for hexS := range m { - cKey := make([]byte, len(address)+len(hexS)) - 
copy(cKey, address[:]) - copy(cKey[len(address):], []byte(hexS)) - if err := sb.Delete(cKey); err != nil { - return err - } + delete(m1, hexS) + } + if len(m1) == 0 { + delete(tp.storageTimestamps, address) } } - return nil - }); err != nil { + } + tp.oldestGeneration = targetTimestamp + return emptyAddresses, nil +} + +// Prunes mininum number of generations necessary so that the total +// number of prunable nodes is at most `targetNodeCount` +func (tp *TriePruning) PruneTo( + mainTrie *Trie, + targetNodeCount int, + storageTrieFunc func(contract common.Address) (*Trie, error), +) (bool, []common.Address, error) { + if tp.nodeCount <= targetNodeCount { + return false, nil, nil + } + excess := tp.nodeCount - targetNodeCount + prunable := 0 + pruneGeneration := tp.oldestGeneration + for prunable < excess { + prunable += tp.generationCounts[pruneGeneration] + pruneGeneration++ + } + //fmt.Printf("Will prune to generation %d, nodes to prune: %d, excess %d\n", pruneGeneration, prunable, excess) + emptyAddresses, err := tp.PruneToTimestamp(mainTrie, pruneGeneration, storageTrieFunc) + if err != nil { return false, nil, err } - tp.oldestGeneration = pruneGeneration - tp.nodeCount -= prunable return true, emptyAddresses, nil } diff --git a/trie/trie_pruning_test.go b/trie/trie_pruning_test.go index e8e3c016ff1..8a5a61f8d1e 100644 --- a/trie/trie_pruning_test.go +++ b/trie/trie_pruning_test.go @@ -33,7 +33,7 @@ func TestOnePerTimestamp(t *testing.T) { } tr := New(common.Hash{}, false) tr.SetTouchFunc(func(hex []byte, del bool) { - tp.Touch(nil, hex, del) + tp.Touch(hex, del) }) var key [4]byte value := []byte("V") From c436d854336682e394d4471b30e66cea45670923 Mon Sep 17 00:00:00 2001 From: Alexey Akhunov Date: Tue, 11 Jun 2019 20:49:37 +0100 Subject: [PATCH 3/5] Clean up code that is now freed --- core/state/database.go | 44 +++------ core/state/repair.go | 33 +------ trie/node.go | 77 --------------- trie/proof_generator.go | 12 +-- trie/trie.go | 213 
+--------------------------------------- 5 files changed, 21 insertions(+), 358 deletions(-) diff --git a/core/state/database.go b/core/state/database.go index 7c806d27e8c..208ba55904a 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -189,23 +189,20 @@ func (b *Buffer) merge(other *Buffer) { // Implements StateReader by wrapping a trie and a database, where trie acts as a cache for the database type TrieDbState struct { - t *trie.Trie - db ethdb.Database - blockNr uint64 - storageTries map[common.Address]*trie.Trie - buffers []*Buffer - aggregateBuffer *Buffer // Merge of all buffers - currentBuffer *Buffer - codeCache *lru.Cache - codeSizeCache *lru.Cache - historical bool - generationCounts map[uint64]int - nodeCount int - oldestGeneration uint64 - noHistory bool - resolveReads bool - pg *trie.ProofGenerator - tp *trie.TriePruning + t *trie.Trie + db ethdb.Database + blockNr uint64 + storageTries map[common.Address]*trie.Trie + buffers []*Buffer + aggregateBuffer *Buffer // Merge of all buffers + currentBuffer *Buffer + codeCache *lru.Cache + codeSizeCache *lru.Cache + historical bool + noHistory bool + resolveReads bool + pg *trie.ProofGenerator + tp *trie.TriePruning } func NewTrieDbState(root common.Hash, db ethdb.Database, blockNr uint64) (*TrieDbState, error) { @@ -235,8 +232,6 @@ func NewTrieDbState(root common.Hash, db ethdb.Database, blockNr uint64) (*TrieD t.SetTouchFunc(func(hex []byte, del bool) { tp.Touch(hex, del) }) - tds.generationCounts = make(map[uint64]int, 4096) - tds.oldestGeneration = blockNr return &tds, nil } @@ -721,17 +716,6 @@ func encodingToAccount(enc []byte) (*Account, error) { return &data, nil } -func (tds *TrieDbState) joinGeneration(gen uint64) { - tds.nodeCount++ - tds.generationCounts[gen]++ - -} - -func (tds *TrieDbState) leftGeneration(gen uint64) { - tds.nodeCount-- - tds.generationCounts[gen]-- -} - func (tds *TrieDbState) ReadAccountData(address common.Address) (*Account, error) { h := newHasher() defer 
returnHasherToPool(h) diff --git a/core/state/repair.go b/core/state/repair.go index d8e717e6bbb..ff302340741 100644 --- a/core/state/repair.go +++ b/core/state/repair.go @@ -235,7 +235,6 @@ func (rds *RepairDbState) getStorageTrie(address common.Address, create bool) (* } else { t = trie.New(account.Root, true) } - t.MakeListed(rds.joinGeneration, rds.leftGeneration) rds.storageTries[address] = t } return t, nil @@ -427,38 +426,8 @@ func (rds *RepairDbState) WriteAccountStorage(address common.Address, key, origi return nil } -func (rds *RepairDbState) joinGeneration(gen uint64) { - rds.nodeCount++ - rds.generationCounts[gen]++ - -} - -func (rds *RepairDbState) leftGeneration(gen uint64) { - rds.nodeCount-- - rds.generationCounts[gen]-- -} - func (rds *RepairDbState) PruneTries() { - if rds.nodeCount > int(MaxTrieCacheGen) { - toRemove := 0 - excess := rds.nodeCount - int(MaxTrieCacheGen) - gen := rds.oldestGeneration - for excess > 0 { - excess -= rds.generationCounts[gen] - toRemove += rds.generationCounts[gen] - delete(rds.generationCounts, gen) - gen++ - } - // Unload all nodes with touch timestamp < gen - for address, storageTrie := range rds.storageTries { - empty := storageTrie.UnloadOlderThan(gen, false) - if empty { - delete(rds.storageTries, address) - } - } - rds.oldestGeneration = gen - rds.nodeCount -= toRemove - } + // TODO Reintroduce pruning if necessary var m runtime.MemStats runtime.ReadMemStats(&m) fmt.Printf("Memory: nodes=%d, alloc=%d, sys=%d\n", rds.nodeCount, int(m.Alloc/1024), int(m.Sys/1024)) diff --git a/trie/node.go b/trie/node.go index 0a2c5f2ad38..875621bf4c7 100644 --- a/trie/node.go +++ b/trie/node.go @@ -31,7 +31,6 @@ type node interface { dirty() bool hash() []byte makedirty() - tod(def uint64) uint64 // Read Touch time of the Oldest Decendant } type ( @@ -161,8 +160,6 @@ func (n *fullNode) duoCopy() *duoNode { if !n.flags.dirty { copy(c.flags.hash[:], n.flags.hash[:]) } - c.flags.t = n.flags.t - c.flags.tod = n.flags.tod 
c.flags.dirty = n.flags.dirty return &c } @@ -175,8 +172,6 @@ func (n *duoNode) fullCopy() *fullNode { if !n.flags.dirty { copy(c.flags.hash[:], n.flags.hash[:]) } - c.flags.t = n.flags.t - c.flags.tod = n.flags.tod c.flags.dirty = n.flags.dirty return &c } @@ -254,8 +249,6 @@ func (n *shortNode) copy() *shortNode { // nodeFlag contains caching-related metadata about a node. type nodeFlag struct { - t uint64 // Touch time of the node - tod uint64 // Touch time of the Oldest Decendent hash common.Hash // cached hash of the node dirty bool // whether the hash field represent the true hash } @@ -298,73 +291,3 @@ func (n duoNode) String() string { return n.fstring("") } func (n shortNode) String() string { return n.fstring("") } func (n hashNode) String() string { return n.fstring("") } func (n valueNode) String() string { return n.fstring("") } - -func (n hashNode) tod(def uint64) uint64 { return def } -func (n valueNode) tod(def uint64) uint64 { return def } -func (n *fullNode) tod(def uint64) uint64 { return n.flags.tod } -func (n *duoNode) tod(def uint64) uint64 { return n.flags.tod } -func (n *shortNode) tod(def uint64) uint64 { return n.flags.tod } - -func (n *fullNode) updateT(t uint64, joinGeneration, leftGeneration func(uint64)) { - if n.flags.t != t { - leftGeneration(n.flags.t) - joinGeneration(t) - n.flags.t = t - } -} - -func (n *fullNode) adjustTod(def uint64) { - tod := def - for _, node := range &n.Children { - if node != nil { - nodeTod := node.tod(def) - if nodeTod < tod { - tod = nodeTod - } - } - } - n.flags.tod = tod -} - -func (n *duoNode) updateT(t uint64, joinGeneration, leftGeneration func(uint64)) { - if n.flags.t != t { - leftGeneration(n.flags.t) - joinGeneration(t) - n.flags.t = t - } -} - -func (n *duoNode) adjustTod(def uint64) { - tod := def - if n.child1 != nil { - nodeTod := n.child1.tod(def) - if nodeTod < tod { - tod = nodeTod - } - } - if n.child2 != nil { - nodeTod := n.child2.tod(def) - if nodeTod < tod { - tod = nodeTod - } - } - 
n.flags.tod = tod -} - -func (n *shortNode) updateT(t uint64, joinGeneration, leftGeneration func(uint64)) { - if n.flags.t != t { - leftGeneration(n.flags.t) - joinGeneration(t) - n.flags.t = t - } -} - -func (n *shortNode) adjustTod(def uint64) { - tod := def - nodeTod := n.Val.tod(def) - if nodeTod < tod { - tod = nodeTod - } - n.flags.tod = tod -} - diff --git a/trie/proof_generator.go b/trie/proof_generator.go index 48e804cd261..ffc94af0b0e 100644 --- a/trie/proof_generator.go +++ b/trie/proof_generator.go @@ -519,7 +519,6 @@ func constructFullNode(touchFunc func(hex []byte, del bool), ctime uint64, f := &fullNode{} f.flags.dirty = true touchFunc(hex, false) - f.flags.t = ctime for nibble := byte(0); nibble < 16; nibble++ { if (hashmask & (uint16(1) << nibble)) != 0 { hash := hashes[*hashIdx] @@ -551,7 +550,6 @@ func constructFullNode(touchFunc func(hex []byte, del bool), ctime uint64, f.Children[nibble] = constructShortNode(touchFunc, ctime, concat(hex, nibble), masks, shortKeys, values, hashes, maskIdx, shortIdx, valueIdx, hashIdx, trace) } } - f.adjustTod(ctime) return f } @@ -575,7 +573,6 @@ func constructShortNode(touchFunc func(hex []byte, del bool), ctime uint64, (*shortIdx)++ s := &shortNode{Key: hexToCompact(nKey)} s.flags.dirty = true - s.flags.t = ctime if trace { fmt.Printf("\n") } @@ -600,7 +597,6 @@ func constructShortNode(touchFunc func(hex []byte, del bool), ctime uint64, if s.Val == nil { fmt.Printf("s.Val is nil, pos %d, nKey %x, downmask %d\n", pos, nKey, downmask) } - s.adjustTod(ctime) return s } @@ -613,9 +609,7 @@ func NewFromProofs(touchFunc func(hex []byte, del bool), ctime uint64, trace bool, ) (t *Trie, mIdx, hIdx, sIdx, vIdx int) { t = &Trie{ - encodeToBytes: encodeToBytes, - joinGeneration: func(uint64) {}, - leftGeneration: func(uint64) {}, + encodeToBytes: encodeToBytes, } var maskIdx int var hashIdx int // index in the hashes @@ -870,7 +864,6 @@ func applyFullNode(h *hasher, touchFunc func(hex []byte, del bool), ctime uint64 } } 
touchFunc(hex, false) - f.flags.t = ctime for nibble := byte(0); nibble < 16; nibble++ { if (hashmask & (uint16(1) << nibble)) != 0 { hash := hashes[*hashIdx] @@ -916,7 +909,6 @@ func applyFullNode(h *hasher, touchFunc func(hex []byte, del bool), ctime uint64 f.Children[nibble] = sn } } - f.adjustTod(ctime) if f.flags.dirty { var hn common.Hash h.hash(f, pos == 0, hn[:]) @@ -959,7 +951,6 @@ func applyShortNode(h *hasher, touchFunc func(hex []byte, del bool), ctime uint6 fmt.Printf("%skeep existing short node %x\n", strings.Repeat(" ", pos), compactToHex(s.Key)) } } - s.flags.t = ctime switch downmask { case 0: if pos+len(nKey) == 65 { @@ -1003,7 +994,6 @@ func applyShortNode(h *hasher, touchFunc func(hex []byte, del bool), ctime uint6 s.Val = applyFullNode(h, touchFunc, ctime, s.Val, concat(hex, nKey...), masks, shortKeys, values, hashes, maskIdx, shortIdx, valueIdx, hashIdx, trace) } - s.adjustTod(ctime) if s.flags.dirty { var hn common.Hash h.hash(s, pos == 0, hn[:]) diff --git a/trie/trie.go b/trie/trie.go index aa39a75bb58..b193bb78ac5 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -48,9 +48,7 @@ type Trie struct { encodeToBytes bool - joinGeneration func(gen uint64) - leftGeneration func(gen uint64) - touchFunc func(hex []byte, del bool) + touchFunc func(hex []byte, del bool) } // New creates a trie with an existing root node from db. @@ -61,10 +59,8 @@ type Trie struct { // not exist in the database. Accessing the trie loads nodes from db on demand. 
func New(root common.Hash, encodeToBytes bool) *Trie { trie := &Trie{ - encodeToBytes: encodeToBytes, - joinGeneration: func(uint64) {}, - leftGeneration: func(uint64) {}, - touchFunc: func([]byte, bool) {}, + encodeToBytes: encodeToBytes, + touchFunc: func([]byte, bool) {}, } if (root != common.Hash{}) && root != emptyRoot { trie.root = hashNode(root[:]) @@ -72,11 +68,6 @@ func New(root common.Hash, encodeToBytes bool) *Trie { return trie } -func (t *Trie) MakeListed(joinGeneration, leftGeneration func(gen uint64)) { - t.joinGeneration = joinGeneration - t.leftGeneration = leftGeneration -} - func (t *Trie) SetTouchFunc(touchFunc func(hex []byte, del bool)) { t.touchFunc = touchFunc } @@ -94,53 +85,33 @@ func (t *Trie) get(origNode node, key []byte, pos int, blockNr uint64) (value [] case valueNode: return n, true case *shortNode: - n.updateT(blockNr, t.joinGeneration, t.leftGeneration) - var adjust bool nKey := compactToHex(n.Key) if len(key)-pos < len(nKey) || !bytes.Equal(nKey, key[pos:pos+len(nKey)]) { - adjust = false value, gotValue = nil, true } else { - adjust = true if v, ok := n.Val.(valueNode); ok { value, gotValue = v, true } else { value, gotValue = t.get(n.Val, key, pos+len(nKey), blockNr) } } - if adjust { - n.adjustTod(blockNr) - } return case *duoNode: t.touchFunc(key[:pos], false) - n.updateT(blockNr, t.joinGeneration, t.leftGeneration) - var adjust bool i1, i2 := n.childrenIdx() switch key[pos] { case i1: - adjust = n.tod(blockNr) == n.child1.tod(blockNr) value, gotValue = t.get(n.child1, key, pos+1, blockNr) case i2: - adjust = n.tod(blockNr) == n.child2.tod(blockNr) value, gotValue = t.get(n.child2, key, pos+1, blockNr) default: - adjust = false value, gotValue = nil, true } - if adjust { - n.adjustTod(blockNr) - } return case *fullNode: t.touchFunc(key[:pos], false) - n.updateT(blockNr, t.joinGeneration, t.leftGeneration) child := n.Children[key[pos]] - adjust := child != nil && n.tod(blockNr) == child.tod(blockNr) value, gotValue = 
t.get(child, key, pos+1, blockNr) - if adjust { - n.adjustTod(blockNr) - } return case hashNode: return nil, false @@ -160,9 +131,6 @@ func (t *Trie) Update(key, value []byte, blockNr uint64) { if t.root == nil { newnode := &shortNode{Key: hexToCompact(hex), Val: valueNode(value)} newnode.flags.dirty = true - newnode.flags.t = blockNr - newnode.adjustTod(blockNr) - t.joinGeneration(blockNr) t.root = newnode } else { _, t.root = t.insert(t.root, hex, 0, valueNode(value), blockNr) @@ -346,8 +314,6 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node, blockNr ui } switch n := origNode.(type) { case *shortNode: - t.leftGeneration(n.flags.t) - n.flags.t = blockNr nKey := compactToHex(n.Key) matchlen := prefixLen(key[pos:], nKey) // If the whole key matches, keep this short node as is @@ -359,8 +325,6 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node, blockNr ui n.flags.dirty = true } newNode = n - t.joinGeneration(blockNr) - n.adjustTod(blockNr) } else { // Otherwise branch out at the index where they differ. 
var c1 node @@ -369,10 +333,7 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node, blockNr ui } else { s1 := &shortNode{Key: hexToCompact(nKey[matchlen+1:]), Val: n.Val} s1.flags.dirty = true - s1.flags.t = blockNr - s1.adjustTod(blockNr) c1 = s1 - t.joinGeneration(blockNr) } var c2 node if len(key) == pos+matchlen+1 { @@ -380,10 +341,7 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node, blockNr ui } else { s2 := &shortNode{Key: hexToCompact(key[pos+matchlen+1:]), Val: value} s2.flags.dirty = true - s2.flags.t = blockNr - s2.adjustTod(blockNr) c2 = s2 - t.joinGeneration(blockNr) } branch := &duoNode{} if nKey[matchlen] < key[pos+matchlen] { @@ -395,8 +353,6 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node, blockNr ui } branch.mask = (1 << (nKey[matchlen])) | (1 << (key[pos+matchlen])) branch.flags.dirty = true - branch.flags.t = blockNr - branch.adjustTod(blockNr) // Replace this shortNode with the branch if it occurs at index 0. 
if matchlen == 0 { @@ -407,12 +363,8 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node, blockNr ui t.touchFunc(key[:pos+matchlen], false) n.Key = hexToCompact(key[pos : pos+matchlen]) n.Val = branch - t.joinGeneration(blockNr) // new branch node joins the generation n.flags.dirty = true - n.flags.t = blockNr newNode = n - t.joinGeneration(blockNr) // n joins the generation too - n.adjustTod(blockNr) } updated = true } @@ -420,12 +372,9 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node, blockNr ui case *duoNode: t.touchFunc(key[:pos], false) - n.updateT(blockNr, t.joinGeneration, t.leftGeneration) - var adjust bool i1, i2 := n.childrenIdx() switch key[pos] { case i1: - adjust = n.tod(blockNr) == n.child1.tod(blockNr) updated, nn = t.insert(n.child1, key, pos+1, value, blockNr) if updated { n.child1 = nn @@ -433,7 +382,6 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node, blockNr ui } newNode = n case i2: - adjust = n.tod(blockNr) == n.child2.tod(blockNr) updated, nn = t.insert(n.child2, key, pos+1, value, blockNr) if updated { n.child2 = nn @@ -447,42 +395,28 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node, blockNr ui } else { short := &shortNode{Key: hexToCompact(key[pos+1:]), Val: value} short.flags.dirty = true - short.flags.t = blockNr - short.adjustTod(blockNr) - t.joinGeneration(blockNr) child = short } newnode := &fullNode{} newnode.Children[i1] = n.child1 newnode.Children[i2] = n.child2 newnode.flags.dirty = true - newnode.flags.t = blockNr - newnode.adjustTod(blockNr) - adjust = false newnode.Children[key[pos]] = child updated = true // current node leaves the generation but newnode joins it newNode = newnode } - if adjust { - n.adjustTod(blockNr) - } return case *fullNode: t.touchFunc(key[:pos], false) - n.updateT(blockNr, t.joinGeneration, t.leftGeneration) child := n.Children[key[pos]] - adjust := child != nil && n.tod(blockNr) == child.tod(blockNr) if child == nil { 
if len(key) == pos+1 { n.Children[key[pos]] = value } else { short := &shortNode{Key: hexToCompact(key[pos+1:]), Val: value} short.flags.dirty = true - short.flags.t = blockNr - short.adjustTod(blockNr) - t.joinGeneration(blockNr) n.Children[key[pos]] = short } updated = true @@ -495,9 +429,6 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node, blockNr ui } } newNode = n - if adjust { - n.adjustTod(blockNr) - } return default: fmt.Printf("Key: %x, Pos: %d\n", key, pos) @@ -514,7 +445,6 @@ func (t *Trie) hook(hex []byte, n node, blockNr uint64) { case nil: return case *shortNode: - n.flags.t = blockNr nKey := compactToHex(n.Key) matchlen := prefixLen(hex[pos:], nKey) if matchlen == len(nKey) { @@ -526,7 +456,6 @@ func (t *Trie) hook(hex []byte, n node, blockNr uint64) { } case *duoNode: t.touchFunc(hex[:pos], false) - n.flags.t = blockNr i1, i2 := n.childrenIdx() switch hex[pos] { case i1: @@ -542,7 +471,6 @@ func (t *Trie) hook(hex []byte, n node, blockNr uint64) { } case *fullNode: t.touchFunc(hex[:pos], false) - n.flags.t = blockNr child := n.Children[hex[pos]] if child == nil { return @@ -602,11 +530,8 @@ func (t *Trie) convertToShortNode(key []byte, keyStart int, child node, pos uint k[0] = byte(pos) copy(k[1:], cnodeKey) newshort := &shortNode{Key: hexToCompact(k)} - t.leftGeneration(short.flags.t) newshort.Val = short.Val newshort.flags.dirty = true - newshort.flags.t = blockNr - newshort.adjustTod(blockNr) // cnode gets removed, but newshort gets added return newshort } @@ -616,8 +541,6 @@ func (t *Trie) convertToShortNode(key []byte, keyStart int, child node, pos uint newshort := &shortNode{Key: hexToCompact([]byte{byte(pos)})} newshort.Val = cnode newshort.flags.dirty = true - newshort.flags.t = blockNr - newshort.adjustTod(blockNr) return newshort } @@ -628,13 +551,10 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, blockNr uint64) ( var nn node switch n := origNode.(type) { case *shortNode: - 
t.leftGeneration(n.flags.t) - n.flags.t = blockNr nKey := compactToHex(n.Key) matchlen := prefixLen(key[keyStart:], nKey) if matchlen < len(nKey) { updated = false - t.joinGeneration(blockNr) newNode = n // don't replace n on mismatch } else if matchlen == len(key)-keyStart { updated = true @@ -646,7 +566,6 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, blockNr uint64) ( // longer than n.Key. updated, nn = t.delete(n.Val, key, keyStart+len(nKey), blockNr) if !updated { - t.joinGeneration(blockNr) newNode = n } else { if nn == nil { @@ -663,17 +582,10 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, blockNr uint64) ( newnode := &shortNode{Key: hexToCompact(concat(nKey, childKey...))} newnode.Val = shortChild.Val newnode.flags.dirty = true - newnode.flags.t = blockNr - newnode.adjustTod(blockNr) - t.joinGeneration(blockNr) - // We do not increase generation count here, because one short node comes, but another one - t.leftGeneration(shortChild.flags.t) // But shortChild goes away newNode = newnode } else { n.Val = nn n.flags.dirty = true - n.adjustTod(blockNr) - t.joinGeneration(blockNr) newNode = n } } @@ -682,17 +594,12 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, blockNr uint64) ( return case *duoNode: - t.leftGeneration(n.flags.t) - var adjust bool i1, i2 := n.childrenIdx() switch key[keyStart] { case i1: - adjust = n.child1 != nil && n.tod(blockNr) == n.child1.tod(blockNr) updated, nn = t.delete(n.child1, key, keyStart+1, blockNr) if !updated { t.touchFunc(key[:keyStart], false) - t.joinGeneration(blockNr) - n.flags.t = blockNr newNode = n } else { if nn == nil { @@ -700,20 +607,15 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, blockNr uint64) ( newNode = t.convertToShortNode(key, keyStart, n.child2, uint(i2), blockNr) } else { t.touchFunc(key[:keyStart], false) - t.joinGeneration(blockNr) - n.flags.t = blockNr n.child1 = nn n.flags.dirty = true newNode = n } } case i2: - adjust = 
n.child2 != nil && n.tod(blockNr) == n.child2.tod(blockNr) updated, nn = t.delete(n.child2, key, keyStart+1, blockNr) if !updated { t.touchFunc(key[:keyStart], false) - t.joinGeneration(blockNr) - n.flags.t = blockNr newNode = n } else { if nn == nil { @@ -721,8 +623,6 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, blockNr uint64) ( newNode = t.convertToShortNode(key, keyStart, n.child1, uint(i1), blockNr) } else { t.touchFunc(key[:keyStart], false) - t.joinGeneration(blockNr) - n.flags.t = blockNr n.child2 = nn n.flags.dirty = true newNode = n @@ -730,26 +630,16 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, blockNr uint64) ( } default: t.touchFunc(key[:keyStart], false) - t.joinGeneration(blockNr) - n.flags.t = blockNr - adjust = false updated = false newNode = n } - if adjust { - n.adjustTod(blockNr) - } return case *fullNode: - t.leftGeneration(n.flags.t) child := n.Children[key[keyStart]] - adjust := child != nil && n.tod(blockNr) == child.tod(blockNr) updated, nn = t.delete(child, key, keyStart+1, blockNr) if !updated { t.touchFunc(key[:keyStart], false) - t.joinGeneration(blockNr) - n.flags.t = blockNr newNode = n } else { n.Children[key[keyStart]] = nn @@ -796,23 +686,14 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, blockNr uint64) ( } duo.flags.dirty = true duo.mask = (1 << uint(pos1)) | (uint32(1) << uint(pos2)) - duo.flags.t = blockNr - t.joinGeneration(blockNr) - duo.adjustTod(blockNr) - adjust = false newNode = duo } else if count > 2 { t.touchFunc(key[:keyStart], false) - t.joinGeneration(blockNr) - n.flags.t = blockNr // n still contains at least three values and cannot be reduced. n.flags.dirty = true newNode = n } } - if adjust { - n.adjustTod(blockNr) - } return case valueNode: @@ -841,11 +722,9 @@ func (t *Trie) prepareToRemove(n node, hex []byte) { if _, ok := n.Val.(valueNode); !ok { // Don't need to compute prefix for a leaf hexVal = concat(hex, compactToHex(n.Key)...) 
} - t.leftGeneration(n.flags.t) t.prepareToRemove(n.Val, hexVal) case *duoNode: t.touchFunc(hex, true) - t.leftGeneration(n.flags.t) i1, i2 := n.childrenIdx() hex1 := make([]byte, len(hex)+1) copy(hex1, hex) @@ -857,7 +736,6 @@ func (t *Trie) prepareToRemove(n node, hex []byte) { t.prepareToRemove(n.child2, hex2) case *fullNode: t.touchFunc(hex, true) - t.leftGeneration(n.flags.t) for i, child := range n.Children { if child != nil { t.prepareToRemove(child, concat(hex, byte(i))) @@ -884,60 +762,6 @@ func (t *Trie) Hash() common.Hash { return common.BytesToHash(hash.(hashNode)) } -func (t *Trie) UnloadOlderThan(gen uint64, trace bool) bool { - if hn, unloaded := unloadOlderThan(t.root, gen); unloaded { - t.root = hn - return true - } - return false -} - -func unloadOlderThan(n node, gen uint64) (hashNode, bool) { - if n == nil { - return nil, false - } - switch n := (n).(type) { - case *shortNode: - if n.flags.tod < gen { - if hn, unloaded := unloadOlderThan(n.Val, gen); unloaded { - n.Val = hn - } - } - case *duoNode: - if n.flags.t < gen { - if n.flags.dirty { - panic(fmt.Sprintf("duoNode dirty: %s", n)) - } - return hashNode(common.CopyBytes(n.hash())), true - } - if n.flags.tod < gen { - if hn, unloaded := unloadOlderThan(n.child1, gen); unloaded { - n.child1 = hn - } - if hn, unloaded := unloadOlderThan(n.child2, gen); unloaded { - n.child2 = hn - } - } - case *fullNode: - if n.flags.t < gen { - if n.flags.dirty { - panic(fmt.Sprintf("fullNode dirty: %s", n)) - } - return hashNode(common.CopyBytes(n.hash())), true - } - if n.flags.tod < gen { - for i, child := range &n.Children { - if child != nil { - if hn, unloaded := unloadOlderThan(child, gen); unloaded { - n.Children[i] = hn - } - } - } - } - } - return nil, false -} - func (t *Trie) unload(hex []byte, h *hasher) { var nd node = t.root var parent node @@ -1039,12 +863,12 @@ func (t *Trie) countPrunableNodes(nd node, hex []byte, print bool) int { copy(hex2, hex) hex2[len(hex)] = byte(i2) if print { - 
fmt.Printf("%T node: %x, t: %d\n", n, hex, n.flags.t) + fmt.Printf("%T node: %x\n", n, hex) } return 1 + t.countPrunableNodes(n.child1, hex1, print) + t.countPrunableNodes(n.child2, hex2, print) case *fullNode: if print { - fmt.Printf("%T node: %x, t: %d\n", n, hex, n.flags.t) + fmt.Printf("%T node: %x\n", n, hex) } count := 0 for i, child := range n.Children { @@ -1058,33 +882,6 @@ func (t *Trie) countPrunableNodes(nd node, hex []byte, print bool) int { } } -func (t *Trie) CountGenerations(m map[uint64]int) { - t.countGenerations(t.root, m) -} - -func (t *Trie) countGenerations(nd node, m map[uint64]int) { - switch n := nd.(type) { - case nil: - case valueNode: - case hashNode: - case *shortNode: - t.countGenerations(n.Val, m) - case *duoNode: - m[n.flags.t]++ - t.countGenerations(n.child1, m) - t.countGenerations(n.child2, m) - case *fullNode: - m[n.flags.t]++ - for _, child := range n.Children { - if child != nil { - t.countGenerations(child, m) - } - } - default: - panic("") - } -} - func (t *Trie) hashRoot() (node, error) { if t.root == nil { return hashNode(emptyRoot.Bytes()), nil From 08fcb349f716fff14c53a9fe59a34702eeaf00b0 Mon Sep 17 00:00:00 2001 From: Alexey Akhunov Date: Wed, 12 Jun 2019 11:59:50 +0100 Subject: [PATCH 4/5] Fix stateless pruning, reduce initial rebuild --- cmd/state/stateless.go | 6 ++++-- core/state/stateless.go | 16 ++++++++++++++++ trie/proof_generator.go | 14 ++++++++++++-- trie/resolver.go | 2 +- trie/trie.go | 7 +------ 5 files changed, 34 insertions(+), 11 deletions(-) diff --git a/cmd/state/stateless.go b/cmd/state/stateless.go index 5f8017bb7bb..4ea2904e127 100644 --- a/cmd/state/stateless.go +++ b/cmd/state/stateless.go @@ -174,7 +174,7 @@ func stateless(genLag, consLag int) { var proofGen *state.Stateless // Generator of proofs var proofCons *state.Stateless // Consumer of proofs for !interrupt { - trace := false //blockNum == 318335 + trace := false // blockNum == 1807 if trace { filename := fmt.Sprintf("right_%d.txt", 
blockNum-1) f, err1 := os.Create(filename) @@ -279,7 +279,6 @@ func stateless(genLag, consLag int) { return } writeStats(wf, blockNum, pBlockProof) - proofCons.Prune(blockNum-uint64(consLag), false) } else { if err := proofCons.ApplyProof(preRoot, blockProof, block.NumberU64()-1, false); err != nil { fmt.Printf("Error applying proof to consumer: %v\n", err) @@ -289,6 +288,9 @@ func stateless(genLag, consLag int) { if err := runBlock(tds, proofCons, chainConfig, bcb, header, block, trace, false); err != nil { fmt.Printf("Error running block %d through proof consumer: %v\n", blockNum, err) } + if blockNum > uint64(consLag) { + proofCons.Prune(blockNum-uint64(consLag), false) + } } if proofGen != nil { if err := proofGen.ApplyProof(preRoot, blockProof, block.NumberU64()-1, false); err != nil { diff --git a/core/state/stateless.go b/core/state/stateless.go index 3d46294fe29..b14cba177e2 100644 --- a/core/state/stateless.go +++ b/core/state/stateless.go @@ -563,12 +563,28 @@ func (s *Stateless) WriteAccountStorage(address common.Address, key, original, v } func (s *Stateless) Prune(oldest uint64, trace bool) { + if trace { + mainPrunable := s.t.CountPrunableNodes() + prunableNodes := mainPrunable + for _, storageTrie := range s.storageTries { + prunableNodes += storageTrie.CountPrunableNodes() + } + fmt.Printf("[Before pruning to %d] Actual prunable nodes: %d (main %d), accounted: %d\n", oldest, prunableNodes, mainPrunable, s.tp.NodeCount()) + } emptyAddresses, err := s.tp.PruneToTimestamp(s.t, oldest, func(contract common.Address) (*trie.Trie, error) { return s.getStorageTrie(contract, false) }) if err != nil { fmt.Printf("Error while pruning: %v\n", err) } + if trace { + mainPrunable := s.t.CountPrunableNodes() + prunableNodes := mainPrunable + for _, storageTrie := range s.storageTries { + prunableNodes += storageTrie.CountPrunableNodes() + } + fmt.Printf("[After pruning to %d Actual prunable nodes: %d (main %d), accounted: %d\n", oldest, prunableNodes, mainPrunable, 
s.tp.NodeCount()) + } for _, address := range emptyAddresses { delete(s.storageTries, address) } diff --git a/trie/proof_generator.go b/trie/proof_generator.go index ffc94af0b0e..d5f968b39ca 100644 --- a/trie/proof_generator.go +++ b/trie/proof_generator.go @@ -929,7 +929,17 @@ func applyShortNode(h *hasher, touchFunc func(hex []byte, del bool), ctime uint6 downmask := masks[*maskIdx] (*maskIdx)++ // short node (leaf or extension) - s, ok := n.(*shortNode) + var s *shortNode + var ok bool + switch nt := n.(type) { + case *shortNode: + s = nt + ok = true + case *duoNode: + touchFunc(hex, true) // duoNode turned into shortNode - delete from prunable set + case *fullNode: + touchFunc(hex, true) // fullNode turned into shortNode - delete from prunable set + } var nKey []byte if (downmask <= 1) || downmask == 2 || downmask == 4 || downmask == 6 { nKey = shortKeys[*shortIdx] @@ -991,7 +1001,7 @@ func applyShortNode(h *hasher, touchFunc func(hex []byte, del bool), ctime uint6 s.Val = applyFullNode(h, touchFunc, ctime, nil, concat(hex, nKey...), masks, shortKeys, values, hashes, maskIdx, shortIdx, valueIdx, hashIdx, trace) case 7: - s.Val = applyFullNode(h, touchFunc, ctime, s.Val, concat(hex, nKey...), masks, shortKeys, values, hashes, + s.Val = applyFullNode(h, touchFunc, ctime, s.Val, concat(hex, compactToHex(s.Key)...), masks, shortKeys, values, hashes, maskIdx, shortIdx, valueIdx, hashIdx, trace) } if s.flags.dirty { diff --git a/trie/resolver.go b/trie/resolver.go index 3ca086f6283..50acde5a1ec 100644 --- a/trie/resolver.go +++ b/trie/resolver.go @@ -256,7 +256,7 @@ func (tr *TrieResolver) finishPreviousKey(k []byte) error { tr.nodeStack[level].flags.dirty = true } tr.vertical[level].flags.dirty = true - if onResolvingPath || (tr.hashes && level <= 5) { + if onResolvingPath || (tr.hashes && level < 5) { var c node if tr.fillCount[level+1] == 2 { c = full.duoCopy() diff --git a/trie/trie.go b/trie/trie.go index b193bb78ac5..1fbae77563c 100644 --- a/trie/trie.go +++ 
b/trie/trie.go @@ -33,11 +33,6 @@ var ( emptyState = crypto.Keccak256Hash(nil) ) -// LeafCallback is a callback type invoked when a trie operation reaches a leaf -// node. It's used by state sync and commit to allow handling external references -// between account and storage tries. -type LeafCallback func(leaf []byte, parent common.Hash) error - // Trie is a Merkle Patricia Trie. // The zero value is an empty trie with no database. // Use New to create a trie that sits on top of a database. @@ -654,7 +649,7 @@ func (t *Trie) delete(origNode node, key []byte, keyStart int, blockNr uint64) ( // values. var pos1, pos2 int count := 0 - for i, cld := range &n.Children { + for i, cld := range n.Children { if cld != nil { if count == 0 { pos1 = i From 19f80d88a9fcf2e870d4838175560738ab263ac0 Mon Sep 17 00:00:00 2001 From: Alexey Akhunov Date: Wed, 12 Jun 2019 23:29:05 +0100 Subject: [PATCH 5/5] Reduce default node count --- core/state/database.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/state/database.go b/core/state/database.go index 208ba55904a..5b6d543771a 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -35,7 +35,7 @@ import ( ) // Trie cache generation limit after which to evict trie nodes from memory. -var MaxTrieCacheGen = uint32(4 * 1024 * 1024) +var MaxTrieCacheGen = uint32(1024 * 1024) var AccountsBucket = []byte("AT") var AccountsHistoryBucket = []byte("hAT")