/
batch.go
288 lines (260 loc) · 7.37 KB
/
batch.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
package mpt
import (
"bytes"
"sort"
)
// Batch is a set of storage changes meant to be applied to a Trie at once.
// Its items are kept sorted by key (MapToMPTBatch establishes the order).
type Batch struct {
	kv []keyValue
}

// keyValue is a single Batch item. key is the nibble path produced by
// strToNibbles; an item whose value is nil requests deletion of the key.
type keyValue struct {
	key, value []byte
}
// MapToMPTBatch builds a Batch from an unordered set of storage changes.
// Each map key is converted with strToNibbles, and the resulting items are
// sorted by key in ascending byte order, which the batch insertion code
// relies on.
func MapToMPTBatch(m map[string][]byte) Batch {
	items := make([]keyValue, 0, len(m))
	for key, val := range m {
		items = append(items, keyValue{key: strToNibbles(key), value: val})
	}
	sort.Slice(items, func(a, b int) bool {
		return bytes.Compare(items[a].key, items[b].key) < 0
	})
	return Batch{kv: items}
}
// PutBatch puts batch to trie.
// It is not atomic (and probably cannot be without substantial slow-down)
// and returns number of elements processed.
// If an error is returned, the trie may be in an inconsistent state in case of storage failures.
// This is due to the fact that we can remove multiple children from the branch node simultaneously
// and won't strip the resulting branch node.
// However it is used mostly after the block processing to update MPT and error is not expected.
//
// Insertion strips key prefixes by reslicing items in place (see stripPrefix),
// so PutBatch works on its own copy of the item list; the keys stored in b are
// therefore not shortened as a side effect of the call.
func (t *Trie) PutBatch(b Batch) (int, error) {
	if len(b.kv) == 0 {
		return 0, nil
	}
	// Copy the item headers (not the key bytes): b.kv shares its backing array
	// with the caller's Batch, and every level of the insertion reslices keys.
	kv := make([]keyValue, len(b.kv))
	copy(kv, b.kv)
	r, n, err := t.putBatch(kv)
	t.root = r
	return n, err
}
// putBatch inserts kv into the trie starting at the current root and returns
// the new root together with the number of processed items. It does not
// update t.root itself; PutBatch does that.
func (t *Trie) putBatch(kv []keyValue) (Node, int, error) {
	root := t.root
	return t.putBatchIntoNode(root, kv)
}
// putBatchIntoNode dispatches the insertion of kv into curr to the handler
// for curr's concrete node type. It panics on an unknown node type (including
// a nil Node), which indicates a programming error.
func (t *Trie) putBatchIntoNode(curr Node, kv []keyValue) (Node, int, error) {
	switch node := curr.(type) {
	case EmptyNode:
		return t.putBatchIntoEmpty(kv)
	case *HashNode:
		return t.putBatchIntoHash(node, kv)
	case *ExtensionNode:
		return t.putBatchIntoExtension(node, kv)
	case *BranchNode:
		return t.putBatchIntoBranch(node, kv)
	case *LeafNode:
		return t.putBatchIntoLeaf(node, kv)
	}
	panic("invalid MPT node type")
}
// putBatchIntoLeaf replaces leaf curr (dropping its reference) with a new
// subtrie built from kv, in which the leaf's value is kept at the empty key
// unless kv itself contains an item with an empty key.
func (t *Trie) putBatchIntoLeaf(curr *LeafNode, kv []keyValue) (Node, int, error) {
	t.removeRef(curr.Hash(), curr.Bytes())
	current := curr.value
	return t.newSubTrieMany(nil, kv, current)
}
// putBatchIntoBranch applies kv to branch curr, which is already part of the
// trie, so addToBranch drops its reference before changing it.
func (t *Trie) putBatchIntoBranch(curr *BranchNode, kv []keyValue) (Node, int, error) {
	const alreadyInTrie = true
	return t.addToBranch(curr, kv, alreadyInTrie)
}
// mergeExtension prepends prefix to the path that leads to sub and returns the
// node that should take sub's place.
//
// An extension node gets prefix prepended to its own key, and its reference
// moves from the old hash to the new one. An empty node is returned unchanged.
// A hash node is first resolved from the store and then merged; on a storage
// error the unresolved hash node is returned together with the error. Any
// other node is wrapped into a new, referenced extension node keyed by
// prefix, or returned as is when prefix is empty.
func (t *Trie) mergeExtension(prefix []byte, sub Node) (Node, error) {
	switch sn := sub.(type) {
	case *ExtensionNode:
		t.removeRef(sn.Hash(), sn.bytes)
		// Assemble the key in a freshly allocated slice. append(prefix, ...)
		// writes into prefix's spare capacity whenever it has some, and prefix
		// may share its backing array with a longer key (callers pass
		// lcp/lcpMany results), which would then be silently modified.
		key := make([]byte, 0, len(prefix)+len(sn.key))
		key = append(key, prefix...)
		sn.key = append(key, sn.key...)
		sn.invalidateCache()
		t.addRef(sn.Hash(), sn.bytes)
		return sn, nil
	case EmptyNode:
		return sn, nil
	case *HashNode:
		n, err := t.getFromStore(sn.Hash())
		if err != nil {
			return sn, err
		}
		return t.mergeExtension(prefix, n)
	default:
		if len(prefix) != 0 {
			e := NewExtensionNode(prefix, sub)
			t.addRef(e.Hash(), e.bytes)
			return e, nil
		}
		return sub, nil
	}
}
// putBatchIntoExtension applies kv to extension node curr, which is already
// part of the trie, so its reference is dropped first.
//
// shared is the part of curr.key covered by the common prefix of all items.
// If all of curr.key is shared, the items continue below curr.next and the
// result is re-prefixed with that key. Otherwise the extension is split: a
// new branch takes the place of the first diverging nibble, and a non-empty
// shared part is prepended to the result again.
func (t *Trie) putBatchIntoExtension(curr *ExtensionNode, kv []keyValue) (Node, int, error) {
	t.removeRef(curr.Hash(), curr.bytes)

	shared := lcp(lcpMany(kv), curr.key)
	var (
		sub Node
		n   int
		err error
	)
	switch {
	case len(shared) == len(curr.key):
		// Every item lies below the whole extension key.
		stripPrefix(len(shared), kv)
		sub, n, err = t.putBatchIntoNode(curr.next, kv)
	case len(shared) == 0:
		// Nothing in common: the extension is replaced by a branch directly.
		return t.putBatchIntoExtensionNoPrefix(curr.key, curr.next, kv)
	default:
		// Split after the shared part of the key.
		stripPrefix(len(shared), kv)
		sub, n, err = t.putBatchIntoExtensionNoPrefix(curr.key[len(shared):], curr.next, kv)
	}
	if err != nil {
		return sub, n, err
	}
	sub, err = t.mergeExtension(shared, sub)
	return sub, n, err
}
// putBatchIntoExtensionNoPrefix turns the extension path (key, next) into a
// fresh branch: the old path continues under child key[0] (through a new
// subtrie for the rest of key, if any), and kv is then added to that branch.
// key must not be empty.
func (t *Trie) putBatchIntoExtensionNoPrefix(key []byte, next Node, kv []keyValue) (Node, int, error) {
	b := NewBranchNode()
	idx, rest := key[0], key[1:]
	child := next
	if len(rest) != 0 {
		child = t.newSubTrie(rest, next, false)
	}
	b.Children[idx] = child
	return t.addToBranch(b, kv, false)
}
// isEmpty reports whether n holds an EmptyNode value.
func isEmpty(n Node) bool {
	switch n.(type) {
	case EmptyNode:
		return true
	default:
		return false
	}
}
// addToBranch applies kv to the children of branch b. It returns the node that
// should replace b, together with the number of processed items. That node is
// b itself, its only remaining child, or an EmptyNode, as decided by
// stripBranch.
//
// inTrie reports whether b is currently referenced from the trie. If it is,
// the reference is dropped up front and b's cached data is invalidated once at
// least one item has been processed.
func (t *Trie) addToBranch(b *BranchNode, kv []keyValue, inTrie bool) (Node, int, error) {
	if inTrie {
		t.removeRef(b.Hash(), b.bytes)
	}
	// An error from iterateBatch means a storage failure (e.g. some hash node
	// could not be fetched). That may leave the trie inconsistent, because the
	// branch can become impossible to strip after it has been changed.
	// Consider a branch with 10 children: the first 9 are deleted and the
	// remaining one is a leaf replaced by a hash node missing from storage.
	// Fixing this would require reverting reference-count changes for the
	// children that were updated successfully, and a storage failure means we
	// are in a bad state anyway.
	processed, err := t.iterateBatch(kv, func(idx byte, items []keyValue) (int, error) {
		newChild, cnt, childErr := t.putBatchIntoNode(b.Children[idx], items)
		b.Children[idx] = newChild
		return cnt, childErr
	})
	if processed != 0 && inTrie {
		b.invalidateCache()
	}
	// Strip the branch even after a failure so that refcounts stay as close
	// to correct as possible; the first error wins.
	stripped, stripErr := t.stripBranch(b)
	if err == nil {
		err = stripErr
	}
	return stripped, processed, err
}
// stripBranch normalizes branch b after a batch put. It assumes there is no
// reference to b in the trie.
//
// A branch without children becomes an EmptyNode. A branch with a single
// child is replaced by that child. If the child is not the value slot
// lastChild, its index is prepended to its path via mergeExtension. Otherwise
// b is kept and a reference to it is added.
func (t *Trie) stripBranch(b *BranchNode) (Node, error) {
	count := 0
	var idx byte
	for i, child := range b.Children {
		if isEmpty(child) {
			continue
		}
		count++
		idx = byte(i)
	}
	switch count {
	case 0:
		return EmptyNode{}, nil
	case 1:
		if idx == lastChild {
			return b.Children[idx], nil
		}
		return t.mergeExtension([]byte{idx}, b.Children[idx])
	}
	t.addRef(b.Hash(), b.bytes)
	return b, nil
}
// iterateBatch walks the sorted kv in consecutive runs that belong to the
// same branch child, as determined by getLastIndex. For a child other than
// lastChild, it first strips the leading element of every key in the run.
// It then calls f with the child index and the run. The counts reported by f
// are summed, and iteration stops at the first error.
func (t *Trie) iterateBatch(kv []keyValue, f func(c byte, kv []keyValue) (int, error)) (int, error) {
	total := 0
	for rest := kv; len(rest) > 0; {
		idx, end := getLastIndex(rest)
		run := rest[:end]
		if idx != lastChild {
			stripPrefix(1, run)
		}
		cnt, err := f(idx, run)
		total += cnt
		if err != nil {
			return total, err
		}
		rest = rest[end:]
	}
	return total, nil
}
// putBatchIntoEmpty builds a brand-new subtrie for kv in place of an empty
// node. The prefix shared by all keys (as computed by lcpMany) is stripped
// from the items and becomes the path leading to the new subtrie.
func (t *Trie) putBatchIntoEmpty(kv []keyValue) (Node, int, error) {
	shared := lcpMany(kv)
	stripPrefix(len(shared), kv)
	return t.newSubTrieMany(shared, kv, nil)
}
// putBatchIntoHash resolves hash node curr from the store and applies kv to
// the resolved node. If the store lookup fails, curr is returned unchanged
// with zero processed items and the error.
func (t *Trie) putBatchIntoHash(curr *HashNode, kv []keyValue) (Node, int, error) {
	resolved, err := t.getFromStore(curr.hash)
	if err == nil {
		return t.putBatchIntoNode(resolved, kv)
	}
	return curr, 0, err
}
// newSubTrieMany creates a new subtrie from the given key-value pairs and
// returns it together with the number of processed items.
//
// kv must be non-empty, sorted by key and have no common prefix. prefix is
// the path that is prepended to the resulting node, and value is the value
// currently stored at prefix (nil if none). An item with an empty key
// replaces value, or deletes it when the item's value is nil; such deletions
// are counted as processed.
func (t *Trie) newSubTrieMany(prefix []byte, kv []keyValue, value []byte) (Node, int, error) {
	deleted := 0
	for len(kv[0].key) == 0 {
		head := kv[0]
		if head.value != nil {
			if len(kv) == 1 {
				// Only a value at prefix: a single (possibly prefixed) leaf.
				return t.newSubTrie(prefix, NewLeafNode(head.value), true), deleted + 1, nil
			}
			// The new value is placed into lastChild below; head stays in kv.
			value = head.value
			break
		}
		// Deletion of the value stored at prefix.
		deleted++
		if len(kv) == 1 {
			return EmptyNode{}, deleted, nil
		}
		kv = kv[1:]
		value = nil
	}

	b := NewBranchNode()
	if value != nil {
		// The value for the empty key lives in the lastChild slot.
		leaf := NewLeafNode(value)
		t.addRef(leaf.Hash(), leaf.bytes)
		b.Children[lastChild] = leaf
	}
	nd, n, err := t.addToBranch(b, kv, false)
	if err == nil {
		nd, err = t.mergeExtension(prefix, nd)
	}
	return nd, n + deleted, err
}
// stripPrefix drops the first n elements from the key of every item in kv.
// Keys are resliced in place, so the change is visible through every slice
// that shares kv's backing array.
func stripPrefix(n int, kv []keyValue) {
	for i := 0; i < len(kv); i++ {
		item := &kv[i]
		item.key = item.key[n:]
	}
}
// getLastIndex looks at the head of the sorted kv and returns the branch child
// index it belongs to, together with the length of the leading run of items
// going to that child. An empty key maps to lastChild and always forms a run
// of one. Otherwise the run consists of the consecutive items whose first key
// element equals kv[0].key[0].
func getLastIndex(kv []keyValue) (byte, int) {
	head := kv[0].key
	if len(head) == 0 {
		return lastChild, 1
	}
	first := head[0]
	end := 1
	for end < len(kv) && kv[end].key[0] == first {
		end++
	}
	return first, end
}