diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 7ac605bed9..055008e862 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -17,6 +17,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint" "github.com/nspcc-dev/neo-go/pkg/io" @@ -221,6 +222,9 @@ func (bc *Blockchain) init() error { } bc.blockHeight = bHeight bc.persistedHeight = bHeight + if err = bc.dao.InitMPT(bHeight); err != nil { + return errors.Wrapf(err, "can't init MPT at height %d", bHeight) + } hashes, err := bc.dao.GetHeaderHashes() if err != nil { @@ -550,6 +554,11 @@ func (bc *Blockchain) processHeader(h *block.Header, batch storage.Batch, header return nil } +// GetStateRoot returns state root for a given height. +func (bc *Blockchain) GetStateRoot(height uint32) (*state.MPTRootState, error) { + return bc.dao.GetStateRoot(height) +} + // storeBlock performs chain update using the block given, it executes all // transactions with all appropriate side-effects and updates Blockchain state. // This is the only way to change Blockchain state. @@ -633,17 +642,44 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { } } + root := bc.dao.MPT.StateRoot() + var prevHash util.Uint256 + if block.Index > 0 { + prev, err := bc.dao.GetStateRoot(block.Index - 1) + if err != nil { + return errors.WithMessagef(err, "can't get previous state root") + } + prevHash = hash.DoubleSha256(prev.GetSignedPart()) + } + err := bc.AddStateRoot(&state.MPTRoot{ + MPTRootBase: state.MPTRootBase{ + Index: block.Index, + PrevHash: prevHash, + Root: root, + }, + }) + if err != nil { + return err + } + if bc.config.SaveStorageBatch { bc.lastBatch = cache.DAO.GetBatch() } bc.lock.Lock() - _, err := cache.Persist() + _, err = cache.Persist() if err != nil { bc.lock.Unlock() return err } bc.contracts.Policy.OnPersistEnd(bc.dao) + bc.dao.MPT.Flush() + // Every persist cycle we also compact our in-memory MPT. + persistedHeight := atomic.LoadUint32(&bc.persistedHeight) + if persistedHeight == block.Index-1 { + // 10 is good and roughly estimated to fit remaining trie into 1M of memory. + bc.dao.MPT.Collapse(10) + } bc.topBlock.Store(block) atomic.StoreUint32(&bc.blockHeight, block.Index) bc.memPool.RemoveStale(bc.isTxStillRelevant, bc) @@ -1194,6 +1230,82 @@ func (bc *Blockchain) isTxStillRelevant(t *transaction.Transaction) bool { } +// AddStateRoot add new (possibly unverified) state root to the blockchain. +func (bc *Blockchain) AddStateRoot(r *state.MPTRoot) error { + our, err := bc.GetStateRoot(r.Index) + if err == nil { + if our.Flag == state.Verified { + return bc.updateStateHeight(r.Index) + } else if r.Witness == nil && our.Witness != nil { + r.Witness = our.Witness + } + } + if err := bc.verifyStateRoot(r); err != nil { + return errors.WithMessage(err, "invalid state root") + } + if r.Index > bc.BlockHeight() { // just put it into the store for future checks + return bc.dao.PutStateRoot(&state.MPTRootState{ + MPTRoot: *r, + Flag: state.Unverified, + }) + } + + flag := state.Unverified + if r.Witness != nil { + if err := bc.verifyStateRootWitness(r); err != nil { + return errors.WithMessage(err, "can't verify signature") + } + flag = state.Verified + } + err = bc.dao.PutStateRoot(&state.MPTRootState{ + MPTRoot: *r, + Flag: flag, + }) + if err != nil { + return err + } + return bc.updateStateHeight(r.Index) +} + +func (bc *Blockchain) updateStateHeight(newHeight uint32) error { + h, err := bc.dao.GetCurrentStateRootHeight() + if err != nil { + return errors.WithMessage(err, "can't get current state root height") + } else if newHeight == h+1 { + updateStateHeightMetric(newHeight) + return bc.dao.PutCurrentStateRootHeight(h + 1) + } + return nil +} + +// verifyStateRoot checks if state root is valid. +func (bc *Blockchain) verifyStateRoot(r *state.MPTRoot) error { + if r.Index == 0 { + return nil + } + prev, err := bc.GetStateRoot(r.Index - 1) + if err != nil { + return errors.New("can't get previous state root") + } else if !r.PrevHash.Equals(hash.DoubleSha256(prev.GetSignedPart())) { + return errors.New("previous hash mismatch") + } else if prev.Version != r.Version { + return errors.New("version mismatch") + } + return nil +} + +// verifyStateRootWitness verifies that state root signature is correct. +func (bc *Blockchain) verifyStateRootWitness(r *state.MPTRoot) error { + b, err := bc.GetBlock(bc.GetHeaderHash(int(r.Index))) + if err != nil { + return err + } + interopCtx := bc.newInteropContext(trigger.Verification, bc.dao, nil, nil) + interopCtx.Container = r + return bc.verifyHashAgainstScript(b.NextConsensus, r.Witness, interopCtx, true, + bc.contracts.Policy.GetMaxVerificationGas(interopCtx.DAO)) +} + // VerifyTx verifies whether a transaction is bonafide or not. Block parameter // is used for easy interop access and can be omitted for transactions that are // not yet added into any block. diff --git a/pkg/core/blockchainer/blockchainer.go b/pkg/core/blockchainer/blockchainer.go index 9dcac9e337..1086c6beee 100644 --- a/pkg/core/blockchainer/blockchainer.go +++ b/pkg/core/blockchainer/blockchainer.go @@ -20,6 +20,7 @@ type Blockchainer interface { GetConfig() config.ProtocolConfiguration AddHeaders(...*block.Header) error AddBlock(*block.Block) error + AddStateRoot(r *state.MPTRoot) error BlockHeight() uint32 CalculateClaimable(value *big.Int, startHeight, endHeight uint32) *big.Int Close() @@ -42,6 +43,7 @@ type Blockchainer interface { GetValidators() ([]*keys.PublicKey, error) GetStandByValidators() keys.PublicKeys GetScriptHashesForVerifying(*transaction.Transaction) ([]util.Uint160, error) + GetStateRoot(height uint32) (*state.MPTRootState, error) GetStorageItem(id int32, key []byte) *state.StorageItem GetStorageItems(id int32) (map[string]*state.StorageItem, error) GetTestVM(tx *transaction.Transaction) *vm.VM diff --git a/pkg/core/dao/dao.go b/pkg/core/dao/dao.go index f24c7267c2..006b1ecdb3 100644 --- a/pkg/core/dao/dao.go +++ b/pkg/core/dao/dao.go @@ -8,6 +8,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/config/netmode" "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/transaction" @@ -29,10 +30,13 @@ type DAO interface { GetContractState(hash util.Uint160) (*state.Contract, error) GetCurrentBlockHeight() (uint32, error) GetCurrentHeaderHeight() (i uint32, h util.Uint256, err error) + GetCurrentStateRootHeight() (uint32, error) GetHeaderHashes() ([]util.Uint256, error) GetNEP5Balances(acc util.Uint160) (*state.NEP5Balances, error) GetNEP5TransferLog(acc util.Uint160, index uint32) (*state.NEP5TransferLog, error) GetAndUpdateNextContractID() (int32, error) + GetStateRoot(height uint32) (*state.MPTRootState, error) + PutStateRoot(root *state.MPTRootState) error GetStorageItem(id int32, key []byte) *state.StorageItem GetStorageItems(id int32) (map[string]*state.StorageItem, error) GetStorageItemsWithPrefix(id int32, prefix []byte) (map[string]*state.StorageItem, error) @@ -58,13 +62,15 @@ type DAO interface { // Simple is memCached wrapper around DB, simple DAO implementation. type Simple struct { + MPT *mpt.Trie Store *storage.MemCachedStore network netmode.Magic } // NewSimple creates new simple dao using provided backend store. func NewSimple(backend storage.Store, network netmode.Magic) *Simple { - return &Simple{Store: storage.NewMemCachedStore(backend), network: network} + st := storage.NewMemCachedStore(backend) + return &Simple{Store: st, network: network, MPT: mpt.NewTrie(nil, st)} } // GetBatch returns currently accumulated DB changeset. @@ -75,7 +81,9 @@ func (dao *Simple) GetBatch() *storage.MemBatch { // GetWrapped returns new DAO instance with another layer of wrapped // MemCachedStore around the current DAO Store. func (dao *Simple) GetWrapped() DAO { - return NewSimple(dao.Store, dao.network) + d := NewSimple(dao.Store, dao.network) + d.MPT = dao.MPT + return d } // GetAndDecode performs get operation and decoding with serializable structures. @@ -288,6 +296,63 @@ func (dao *Simple) PutAppExecResult(aer *state.AppExecResult) error { // -- start storage item. +func makeStateRootKey(height uint32) []byte { + key := make([]byte, 5) + key[0] = byte(storage.DataMPT) + binary.LittleEndian.PutUint32(key[1:], height) + return key +} + +// InitMPT initializes MPT at the given height. +func (dao *Simple) InitMPT(height uint32) error { + if height == 0 { + dao.MPT = mpt.NewTrie(nil, dao.Store) + return nil + } + r, err := dao.GetStateRoot(height) + if err != nil { + return err + } + dao.MPT = mpt.NewTrie(mpt.NewHashNode(r.Root), dao.Store) + return nil +} + +// GetCurrentStateRootHeight returns current state root height. +func (dao *Simple) GetCurrentStateRootHeight() (uint32, error) { + key := []byte{byte(storage.DataMPT)} + val, err := dao.Store.Get(key) + if err != nil { + if err == storage.ErrKeyNotFound { + err = nil + } + return 0, err + } + return binary.LittleEndian.Uint32(val), nil +} + +// PutCurrentStateRootHeight updates current state root height. +func (dao *Simple) PutCurrentStateRootHeight(height uint32) error { + key := []byte{byte(storage.DataMPT)} + val := make([]byte, 4) + binary.LittleEndian.PutUint32(val, height) + return dao.Store.Put(key, val) +} + +// GetStateRoot returns state root of a given height. +func (dao *Simple) GetStateRoot(height uint32) (*state.MPTRootState, error) { + r := new(state.MPTRootState) + err := dao.GetAndDecode(r, makeStateRootKey(height)) + if err != nil { + return nil, err + } + return r, nil +} + +// PutStateRoot puts state root of a given height into the store. +func (dao *Simple) PutStateRoot(r *state.MPTRootState) error { + return dao.Put(r, makeStateRootKey(r.Index)) +} + // GetStorageItem returns StorageItem if it exists in the given store. func (dao *Simple) GetStorageItem(id int32, key []byte) *state.StorageItem { b, err := dao.Store.Get(makeStorageItemKey(id, key)) @@ -308,13 +373,27 @@ func (dao *Simple) GetStorageItem(id int32, key []byte) *state.StorageItem { // PutStorageItem puts given StorageItem for given id with given // key into the given store. func (dao *Simple) PutStorageItem(id int32, key []byte, si *state.StorageItem) error { - return dao.Put(si, makeStorageItemKey(id, key)) + stKey := makeStorageItemKey(id, key) + buf := io.NewBufBinWriter() + si.EncodeBinary(buf.BinWriter) + if buf.Err != nil { + return buf.Err + } + v := buf.Bytes() + if err := dao.MPT.Put(stKey[1:], v); err != nil && err != mpt.ErrNotFound { + return err + } + return dao.Store.Put(stKey, v) } // DeleteStorageItem drops storage item for the given id with the // given key from the store. func (dao *Simple) DeleteStorageItem(id int32, key []byte) error { - return dao.Store.Delete(makeStorageItemKey(id, key)) + stKey := makeStorageItemKey(id, key) + if err := dao.MPT.Delete(stKey[1:]); err != nil && err != mpt.ErrNotFound { + return err + } + return dao.Store.Delete(stKey) } // GetStorageItems returns all storage items for a given id. diff --git a/pkg/core/mpt/base.go b/pkg/core/mpt/base.go new file mode 100644 index 0000000000..9f10cc3338 --- /dev/null +++ b/pkg/core/mpt/base.go @@ -0,0 +1,84 @@ +package mpt + +import ( + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// BaseNode implements basic things every node needs like caching hash and +// serialized representation. It's a basic node building block intended to be +// included into all node types. +type BaseNode struct { + hash util.Uint256 + bytes []byte + hashValid bool + bytesValid bool + + isFlushed bool +} + +// BaseNodeIface abstracts away basic Node functions. +type BaseNodeIface interface { + Hash() util.Uint256 + Type() NodeType + Bytes() []byte + IsFlushed() bool + SetFlushed() +} + +// getHash returns a hash of this BaseNode. +func (b *BaseNode) getHash(n Node) util.Uint256 { + if !b.hashValid { + b.updateHash(n) + } + return b.hash +} + +// getBytes returns a slice of bytes representing this node. +func (b *BaseNode) getBytes(n Node) []byte { + if !b.bytesValid { + b.updateBytes(n) + } + return b.bytes +} + +// updateHash updates hash field for this BaseNode. +func (b *BaseNode) updateHash(n Node) { + if n.Type() == HashT { + panic("can't update hash for hash node") + } + b.hash = hash.DoubleSha256(b.getBytes(n)) + b.hashValid = true +} + +// updateCache updates hash and bytes fields for this BaseNode. +func (b *BaseNode) updateBytes(n Node) { + buf := io.NewBufBinWriter() + encodeNodeWithType(n, buf.BinWriter) + b.bytes = buf.Bytes() + b.bytesValid = true +} + +// invalidateCache sets all cache fields to invalid state. +func (b *BaseNode) invalidateCache() { + b.bytesValid = false + b.hashValid = false + b.isFlushed = false +} + +// IsFlushed checks for node flush status. +func (b *BaseNode) IsFlushed() bool { + return b.isFlushed +} + +// SetFlushed sets 'flushed' flag to true for this node. +func (b *BaseNode) SetFlushed() { + b.isFlushed = true +} + +// encodeNodeWithType encodes node together with it's type. +func encodeNodeWithType(n Node, w *io.BinWriter) { + w.WriteB(byte(n.Type())) + n.EncodeBinary(w) +} diff --git a/pkg/core/mpt/branch.go b/pkg/core/mpt/branch.go new file mode 100644 index 0000000000..fbad5d29ed --- /dev/null +++ b/pkg/core/mpt/branch.go @@ -0,0 +1,91 @@ +package mpt + +import ( + "encoding/json" + "errors" + + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +const ( + // childrenCount represents a number of children of a branch node. + childrenCount = 17 + // lastChild is the index of the last child. + lastChild = childrenCount - 1 +) + +// BranchNode represents MPT's branch node. +type BranchNode struct { + BaseNode + Children [childrenCount]Node +} + +var _ Node = (*BranchNode)(nil) + +// NewBranchNode returns new branch node. +func NewBranchNode() *BranchNode { + b := new(BranchNode) + for i := 0; i < childrenCount; i++ { + b.Children[i] = new(HashNode) + } + return b +} + +// Type implements Node interface. +func (b *BranchNode) Type() NodeType { return BranchT } + +// Hash implements BaseNode interface. +func (b *BranchNode) Hash() util.Uint256 { + return b.getHash(b) +} + +// Bytes implements BaseNode interface. +func (b *BranchNode) Bytes() []byte { + return b.getBytes(b) +} + +// EncodeBinary implements io.Serializable. +func (b *BranchNode) EncodeBinary(w *io.BinWriter) { + for i := 0; i < childrenCount; i++ { + if hn, ok := b.Children[i].(*HashNode); ok { + hn.EncodeBinary(w) + continue + } + n := NewHashNode(b.Children[i].Hash()) + n.EncodeBinary(w) + } +} + +// DecodeBinary implements io.Serializable. +func (b *BranchNode) DecodeBinary(r *io.BinReader) { + for i := 0; i < childrenCount; i++ { + b.Children[i] = new(HashNode) + b.Children[i].DecodeBinary(r) + } +} + +// MarshalJSON implements json.Marshaler. +func (b *BranchNode) MarshalJSON() ([]byte, error) { + return json.Marshal(b.Children) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (b *BranchNode) UnmarshalJSON(data []byte) error { + var obj NodeObject + if err := obj.UnmarshalJSON(data); err != nil { + return err + } else if u, ok := obj.Node.(*BranchNode); ok { + *b = *u + return nil + } + return errors.New("expected branch node") +} + +// splitPath splits path for a branch node. +func splitPath(path []byte) (byte, []byte) { + if len(path) != 0 { + return path[0], path[1:] + } + return lastChild, path +} diff --git a/pkg/core/mpt/doc.go b/pkg/core/mpt/doc.go new file mode 100644 index 0000000000..c307665b36 --- /dev/null +++ b/pkg/core/mpt/doc.go @@ -0,0 +1,45 @@ +/* +Package mpt implements MPT (Merkle-Patricia Tree). + +MPT stores key-value pairs and is a trie over 16-symbol alphabet. https://en.wikipedia.org/wiki/Trie +Trie is a tree where values are stored in leafs and keys are paths from root to the leaf node. +MPT consists of 4 type of nodes: +- Leaf node contains only value. +- Extension node contains both key and value. +- Branch node contains 2 or more children. +- Hash node is a compressed node and contains only actual node's hash. + The actual node must be retrieved from storage or over the network. + +As an example here is a trie containing 3 pairs: +- 0x1201 -> val1 +- 0x1203 -> val2 +- 0x1224 -> val3 +- 0x12 -> val4 + +ExtensionNode(0x0102), Next + _______________________| + | +BranchNode [0, 1, 2, ...], Last -> Leaf(val4) + | | + | ExtensionNode [0x04], Next -> Leaf(val3) + | + BranchNode [0, 1, 2, 3, ...], Last -> HashNode(nil) + | | + | Leaf(val2) + | + Leaf(val1) + +There are 3 invariants that this implementation has: +- Branch node cannot have <= 1 children +- Extension node cannot have zero-length key +- Extension node cannot have another Extension node in it's next field + +Thank to these restrictions, there is a single root hash for every set of key-value pairs +irregardless of the order they were added/removed with. +The actual trie structure can vary because of node -> HashNode compressing. + +There is also one optimization which cost us almost nothing in terms of complexity but is very beneficial: +When we perform get/put/delete on a speficic path, every Hash node which was retreived from storage is +replaced by its uncompressed form, so that subsequent hits of this not don't use storage. +*/ +package mpt diff --git a/pkg/core/mpt/extension.go b/pkg/core/mpt/extension.go new file mode 100644 index 0000000000..8bcc11c248 --- /dev/null +++ b/pkg/core/mpt/extension.go @@ -0,0 +1,87 @@ +package mpt + +import ( + "encoding/hex" + "encoding/json" + "errors" + "fmt" + + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// MaxKeyLength is the max length of the extension node key. +const MaxKeyLength = 1125 + +// ExtensionNode represents MPT's extension node. +type ExtensionNode struct { + BaseNode + key []byte + next Node +} + +var _ Node = (*ExtensionNode)(nil) + +// NewExtensionNode returns hash node with the specified key and next node. +// Note: because it is a part of Trie, key must be mangled, i.e. must contain only bytes with high half = 0. +func NewExtensionNode(key []byte, next Node) *ExtensionNode { + return &ExtensionNode{ + key: key, + next: next, + } +} + +// Type implements Node interface. +func (e ExtensionNode) Type() NodeType { return ExtensionT } + +// Hash implements BaseNode interface. +func (e *ExtensionNode) Hash() util.Uint256 { + return e.getHash(e) +} + +// Bytes implements BaseNode interface. +func (e *ExtensionNode) Bytes() []byte { + return e.getBytes(e) +} + +// DecodeBinary implements io.Serializable. +func (e *ExtensionNode) DecodeBinary(r *io.BinReader) { + sz := r.ReadVarUint() + if sz > MaxKeyLength { + r.Err = fmt.Errorf("extension node key is too big: %d", sz) + return + } + e.key = make([]byte, sz) + r.ReadBytes(e.key) + e.next = new(HashNode) + e.next.DecodeBinary(r) + e.invalidateCache() +} + +// EncodeBinary implements io.Serializable. +func (e ExtensionNode) EncodeBinary(w *io.BinWriter) { + w.WriteVarBytes(e.key) + n := NewHashNode(e.next.Hash()) + n.EncodeBinary(w) +} + +// MarshalJSON implements json.Marshaler. +func (e *ExtensionNode) MarshalJSON() ([]byte, error) { + m := map[string]interface{}{ + "key": hex.EncodeToString(e.key), + "next": e.next, + } + return json.Marshal(m) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (e *ExtensionNode) UnmarshalJSON(data []byte) error { + var obj NodeObject + if err := obj.UnmarshalJSON(data); err != nil { + return err + } else if u, ok := obj.Node.(*ExtensionNode); ok { + *e = *u + return nil + } + return errors.New("expected extension node") +} diff --git a/pkg/core/mpt/hash.go b/pkg/core/mpt/hash.go new file mode 100644 index 0000000000..42519a1ace --- /dev/null +++ b/pkg/core/mpt/hash.go @@ -0,0 +1,88 @@ +package mpt + +import ( + "errors" + "fmt" + + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// HashNode represents MPT's hash node. +type HashNode struct { + BaseNode +} + +var _ Node = (*HashNode)(nil) + +// NewHashNode returns hash node with the specified hash. +func NewHashNode(h util.Uint256) *HashNode { + return &HashNode{ + BaseNode: BaseNode{ + hash: h, + hashValid: true, + }, + } +} + +// Type implements Node interface. +func (h *HashNode) Type() NodeType { return HashT } + +// Hash implements Node interface. +func (h *HashNode) Hash() util.Uint256 { + if !h.hashValid { + panic("can't get hash of an empty HashNode") + } + return h.hash +} + +// IsEmpty returns true iff h is an empty node i.e. contains no hash. +func (h *HashNode) IsEmpty() bool { return !h.hashValid } + +// Bytes returns serialized HashNode. +func (h *HashNode) Bytes() []byte { + return h.getBytes(h) +} + +// DecodeBinary implements io.Serializable. +func (h *HashNode) DecodeBinary(r *io.BinReader) { + sz := r.ReadVarUint() + switch sz { + case 0: + h.hashValid = false + case util.Uint256Size: + h.hashValid = true + r.ReadBytes(h.hash[:]) + default: + r.Err = fmt.Errorf("invalid hash node size: %d", sz) + } +} + +// EncodeBinary implements io.Serializable. +func (h HashNode) EncodeBinary(w *io.BinWriter) { + if !h.hashValid { + w.WriteVarUint(0) + return + } + w.WriteVarBytes(h.hash[:]) +} + +// MarshalJSON implements json.Marshaler. +func (h *HashNode) MarshalJSON() ([]byte, error) { + if !h.hashValid { + return []byte(`{}`), nil + } + return []byte(`{"hash":"` + h.hash.StringLE() + `"}`), nil +} + +// UnmarshalJSON implements json.Unmarshaler. +func (h *HashNode) UnmarshalJSON(data []byte) error { + var obj NodeObject + if err := obj.UnmarshalJSON(data); err != nil { + return err + } else if u, ok := obj.Node.(*HashNode); ok { + *h = *u + return nil + } + return errors.New("expected hash node") +} diff --git a/pkg/core/mpt/helpers.go b/pkg/core/mpt/helpers.go new file mode 100644 index 0000000000..1c67c6c59a --- /dev/null +++ b/pkg/core/mpt/helpers.go @@ -0,0 +1,35 @@ +package mpt + +// lcp returns longest common prefix of a and b. +// Note: it does no allocations. +func lcp(a, b []byte) []byte { + if len(a) < len(b) { + return lcp(b, a) + } + + var i int + for i = 0; i < len(b); i++ { + if a[i] != b[i] { + break + } + } + + return a[:i] +} + +// copySlice is a helper for copying slice if needed. +func copySlice(a []byte) []byte { + b := make([]byte, len(a)) + copy(b, a) + return b +} + +// toNibbles mangles path by splitting every byte into 2 containing low- and high- 4-byte part. +func toNibbles(path []byte) []byte { + result := make([]byte, len(path)*2) + for i := range path { + result[i*2] = path[i] >> 4 + result[i*2+1] = path[i] & 0x0F + } + return result +} diff --git a/pkg/core/mpt/leaf.go b/pkg/core/mpt/leaf.go new file mode 100644 index 0000000000..82dd8eef6e --- /dev/null +++ b/pkg/core/mpt/leaf.go @@ -0,0 +1,73 @@ +package mpt + +import ( + "encoding/hex" + "errors" + "fmt" + + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// MaxValueLength is a max length of a leaf node value. +const MaxValueLength = 1024 * 1024 + +// LeafNode represents MPT's leaf node. +type LeafNode struct { + BaseNode + value []byte +} + +var _ Node = (*LeafNode)(nil) + +// NewLeafNode returns hash node with the specified value. +func NewLeafNode(value []byte) *LeafNode { + return &LeafNode{value: value} +} + +// Type implements Node interface. +func (n LeafNode) Type() NodeType { return LeafT } + +// Hash implements BaseNode interface. +func (n *LeafNode) Hash() util.Uint256 { + return n.getHash(n) +} + +// Bytes implements BaseNode interface. +func (n *LeafNode) Bytes() []byte { + return n.getBytes(n) +} + +// DecodeBinary implements io.Serializable. +func (n *LeafNode) DecodeBinary(r *io.BinReader) { + sz := r.ReadVarUint() + if sz > MaxValueLength { + r.Err = fmt.Errorf("leaf node value is too big: %d", sz) + return + } + n.value = make([]byte, sz) + r.ReadBytes(n.value) + n.invalidateCache() +} + +// EncodeBinary implements io.Serializable. +func (n LeafNode) EncodeBinary(w *io.BinWriter) { + w.WriteVarBytes(n.value) +} + +// MarshalJSON implements json.Marshaler. +func (n *LeafNode) MarshalJSON() ([]byte, error) { + return []byte(`{"value":"` + hex.EncodeToString(n.value) + `"}`), nil +} + +// UnmarshalJSON implements json.Unmarshaler. +func (n *LeafNode) UnmarshalJSON(data []byte) error { + var obj NodeObject + if err := obj.UnmarshalJSON(data); err != nil { + return err + } else if u, ok := obj.Node.(*LeafNode); ok { + *n = *u + return nil + } + return errors.New("expected leaf node") +} diff --git a/pkg/core/mpt/node.go b/pkg/core/mpt/node.go new file mode 100644 index 0000000000..86e675a014 --- /dev/null +++ b/pkg/core/mpt/node.go @@ -0,0 +1,134 @@ +package mpt + +import ( + "encoding/hex" + "encoding/json" + "errors" + "fmt" + + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// NodeType represents node type.. +type NodeType byte + +// Node types definitions. +const ( + BranchT NodeType = 0x00 + ExtensionT NodeType = 0x01 + HashT NodeType = 0x02 + LeafT NodeType = 0x03 +) + +// NodeObject represents Node together with it's type. +// It is used for serialization/deserialization where type info +// is also expected. +type NodeObject struct { + Node +} + +// Node represents common interface of all MPT nodes. +type Node interface { + io.Serializable + json.Marshaler + json.Unmarshaler + BaseNodeIface +} + +// EncodeBinary implements io.Serializable. +func (n NodeObject) EncodeBinary(w *io.BinWriter) { + encodeNodeWithType(n.Node, w) +} + +// DecodeBinary implements io.Serializable. +func (n *NodeObject) DecodeBinary(r *io.BinReader) { + typ := NodeType(r.ReadB()) + switch typ { + case BranchT: + n.Node = new(BranchNode) + case ExtensionT: + n.Node = new(ExtensionNode) + case HashT: + n.Node = new(HashNode) + case LeafT: + n.Node = new(LeafNode) + default: + r.Err = fmt.Errorf("invalid node type: %x", typ) + return + } + n.Node.DecodeBinary(r) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (n *NodeObject) UnmarshalJSON(data []byte) error { + var m map[string]json.RawMessage + err := json.Unmarshal(data, &m) + if err != nil { // it can be a branch node + var nodes []NodeObject + if err := json.Unmarshal(data, &nodes); err != nil { + return err + } else if len(nodes) != childrenCount { + return errors.New("invalid length of branch node") + } + + b := NewBranchNode() + for i := range b.Children { + b.Children[i] = nodes[i].Node + } + n.Node = b + return nil + } + + switch len(m) { + case 0: + n.Node = new(HashNode) + case 1: + if v, ok := m["hash"]; ok { + var h util.Uint256 + if err := json.Unmarshal(v, &h); err != nil { + return err + } + n.Node = NewHashNode(h) + } else if v, ok = m["value"]; ok { + b, err := unmarshalHex(v) + if err != nil { + return err + } else if len(b) > MaxValueLength { + return errors.New("leaf value is too big") + } + n.Node = NewLeafNode(b) + } else { + return errors.New("invalid field") + } + case 2: + keyRaw, ok1 := m["key"] + nextRaw, ok2 := m["next"] + if !ok1 || !ok2 { + return errors.New("invalid field") + } + key, err := unmarshalHex(keyRaw) + if err != nil { + return err + } else if len(key) > MaxKeyLength { + return errors.New("extension key is too big") + } + + var next NodeObject + if err := json.Unmarshal(nextRaw, &next); err != nil { + return err + } + n.Node = NewExtensionNode(key, next.Node) + default: + return errors.New("0, 1 or 2 fields expected") + } + return nil +} + +func unmarshalHex(data json.RawMessage) ([]byte, error) { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return nil, err + } + return hex.DecodeString(s) +} diff --git a/pkg/core/mpt/node_test.go b/pkg/core/mpt/node_test.go new file mode 100644 index 0000000000..e3aab54d6c --- /dev/null +++ b/pkg/core/mpt/node_test.go @@ -0,0 +1,156 @@ +package mpt + +import ( + "encoding/json" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/internal/random" + "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func getTestFuncEncode(ok bool, expected, actual Node) func(t *testing.T) { + return func(t *testing.T) { + t.Run("IO", func(t *testing.T) { + bs, err := testserdes.EncodeBinary(expected) + require.NoError(t, err) + err = testserdes.DecodeBinary(bs, actual) + if !ok { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, expected.Type(), actual.Type()) + require.Equal(t, expected.Hash(), actual.Hash()) + }) + t.Run("JSON", func(t *testing.T) { + bs, err := json.Marshal(expected) + require.NoError(t, err) + err = json.Unmarshal(bs, actual) + if !ok { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, expected.Type(), actual.Type()) + require.Equal(t, expected.Hash(), actual.Hash()) + }) + } +} + +func TestNode_Serializable(t *testing.T) { + t.Run("Leaf", func(t *testing.T) { + t.Run("Good", func(t *testing.T) { + l := NewLeafNode(random.Bytes(123)) + t.Run("Raw", getTestFuncEncode(true, l, new(LeafNode))) + t.Run("WithType", getTestFuncEncode(true, &NodeObject{l}, new(NodeObject))) + }) + t.Run("BigValue", getTestFuncEncode(false, + NewLeafNode(random.Bytes(MaxValueLength+1)), new(LeafNode))) + }) + + t.Run("Extension", func(t *testing.T) { + t.Run("Good", func(t *testing.T) { + e := NewExtensionNode(random.Bytes(42), NewLeafNode(random.Bytes(10))) + t.Run("Raw", getTestFuncEncode(true, e, new(ExtensionNode))) + t.Run("WithType", getTestFuncEncode(true, &NodeObject{e}, new(NodeObject))) + }) + t.Run("BigKey", getTestFuncEncode(false, + NewExtensionNode(random.Bytes(MaxKeyLength+1), NewLeafNode(random.Bytes(10))), new(ExtensionNode))) + }) + + t.Run("Branch", func(t *testing.T) { + b := NewBranchNode() + b.Children[0] = NewLeafNode(random.Bytes(10)) + b.Children[lastChild] = NewHashNode(random.Uint256()) + t.Run("Raw", getTestFuncEncode(true, b, new(BranchNode))) + t.Run("WithType", getTestFuncEncode(true, &NodeObject{b}, new(NodeObject))) + }) + + t.Run("Hash", func(t *testing.T) { + t.Run("Good", func(t *testing.T) { + h := NewHashNode(random.Uint256()) + t.Run("Raw", getTestFuncEncode(true, h, new(HashNode))) + t.Run("WithType", getTestFuncEncode(true, &NodeObject{h}, new(NodeObject))) + }) + t.Run("Empty", func(t *testing.T) { // compare nodes, not hashes + testserdes.EncodeDecodeBinary(t, new(HashNode), new(HashNode)) + }) + t.Run("InvalidSize", func(t *testing.T) { + buf := io.NewBufBinWriter() + buf.BinWriter.WriteVarBytes(make([]byte, 13)) + require.Error(t, testserdes.DecodeBinary(buf.Bytes(), new(HashNode))) + }) + }) + + t.Run("Invalid", func(t *testing.T) { + require.Error(t, testserdes.DecodeBinary([]byte{0xFF}, new(NodeObject))) + }) +} + +// https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L198 +func TestJSONSharp(t *testing.T) { + tr := NewTrie(nil, newTestStore()) + require.NoError(t, tr.Put([]byte{0xac, 0x11}, []byte{0xac, 0x11})) + require.NoError(t, tr.Put([]byte{0xac, 0x22}, []byte{0xac, 0x22})) + require.NoError(t, tr.Put([]byte{0xac}, []byte{0xac})) + require.NoError(t, tr.Delete([]byte{0xac, 0x11})) + require.NoError(t, tr.Delete([]byte{0xac, 0x22})) + + js, err := tr.root.MarshalJSON() + require.NoError(t, err) + require.JSONEq(t, `{"key":"0a0c", "next":{"value":"ac"}}`, string(js)) +} + +func TestInvalidJSON(t *testing.T) { + t.Run("InvalidChildrenCount", func(t *testing.T) { + var cs [childrenCount + 1]Node + for i := range cs { + cs[i] = new(HashNode) + } + data, err := json.Marshal(cs) + require.NoError(t, err) + + var n NodeObject + require.Error(t, json.Unmarshal(data, &n)) + }) + + testCases := []struct { + name string + data []byte + }{ + {"WrongFieldCount", []byte(`{"key":"0102", "next": {}, "field": {}}`)}, + {"InvalidField1", []byte(`{"next":{}}`)}, + {"InvalidField2", []byte(`{"key":"0102", "hash":{}}`)}, + {"InvalidKey", []byte(`{"key":"xy", "next":{}}`)}, + {"InvalidNext", []byte(`{"key":"01", "next":[]}`)}, + {"InvalidHash", []byte(`{"hash":"01"}`)}, + {"InvalidValue", []byte(`{"value":1}`)}, + {"InvalidBranch", []byte(`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]`)}, + } + for _, tc := range testCases { + var n NodeObject + assert.Errorf(t, json.Unmarshal(tc.data, &n), "no error in "+tc.name) + } +} + +// C# interoperability test +// https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L135 +func TestRootHash(t *testing.T) { + b := NewBranchNode() + r := NewExtensionNode([]byte{0x0A, 0x0C}, b) + + v1 := NewLeafNode([]byte{0xAB, 0xCD}) + l1 := NewExtensionNode([]byte{0x01}, v1) + b.Children[0] = l1 + + v2 := NewLeafNode([]byte{0x22, 0x22}) + l2 := NewExtensionNode([]byte{0x09}, v2) + b.Children[9] = l2 + + r1 := NewExtensionNode([]byte{0x0A, 0x0C, 0x00, 0x01}, v1) + require.Equal(t, "dea3ab46e9461e885ed7091c1e533e0a8030b248d39cbc638962394eaca0fbb3", r1.Hash().StringLE()) + require.Equal(t, "93e8e1ffe2f83dd92fca67330e273bcc811bf64b8f8d9d1b25d5e7366b47d60d", r.Hash().StringLE()) +} diff --git a/pkg/core/mpt/proof.go b/pkg/core/mpt/proof.go new file mode 100644 index 0000000000..5f8fcdc84f --- /dev/null +++ b/pkg/core/mpt/proof.go @@ -0,0 +1,74 @@ +package mpt + +import ( + "bytes" + + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// GetProof returns a proof that key belongs to t. +// Proof consist of serialized nodes occuring on path from the root to the leaf of key. +func (t *Trie) GetProof(key []byte) ([][]byte, error) { + var proof [][]byte + path := toNibbles(key) + r, err := t.getProof(t.root, path, &proof) + if err != nil { + return proof, err + } + t.root = r + return proof, nil +} + +func (t *Trie) getProof(curr Node, path []byte, proofs *[][]byte) (Node, error) { + switch n := curr.(type) { + case *LeafNode: + if len(path) == 0 { + *proofs = append(*proofs, copySlice(n.Bytes())) + return n, nil + } + case *BranchNode: + *proofs = append(*proofs, copySlice(n.Bytes())) + i, path := splitPath(path) + r, err := t.getProof(n.Children[i], path, proofs) + if err != nil { + return nil, err + } + n.Children[i] = r + return n, nil + case *ExtensionNode: + if bytes.HasPrefix(path, n.key) { + *proofs = append(*proofs, copySlice(n.Bytes())) + r, err := t.getProof(n.next, path[len(n.key):], proofs) + if err != nil { + return nil, err + } + n.next = r + return n, nil + } + case *HashNode: + if !n.IsEmpty() { + r, err := t.getFromStore(n.Hash()) + if err != nil { + return nil, err + } + return t.getProof(r, path, proofs) + } + } + return nil, ErrNotFound +} + +// VerifyProof verifies that path indeed belongs to a MPT with the specified root hash. +// It also returns value for the key. +func VerifyProof(rh util.Uint256, key []byte, proofs [][]byte) ([]byte, bool) { + path := toNibbles(key) + tr := NewTrie(NewHashNode(rh), storage.NewMemCachedStore(storage.NewMemoryStore())) + for i := range proofs { + h := hash.DoubleSha256(proofs[i]) + // no errors in Put to memory store + _ = tr.Store.Put(makeStorageKey(h[:]), proofs[i]) + } + _, bs, err := tr.getWithPath(tr.root, path) + return bs, err == nil +} diff --git a/pkg/core/mpt/proof_test.go b/pkg/core/mpt/proof_test.go new file mode 100644 index 0000000000..17301af15e --- /dev/null +++ b/pkg/core/mpt/proof_test.go @@ -0,0 +1,73 @@ +package mpt + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func newProofTrie(t *testing.T) *Trie { + l := NewLeafNode([]byte("somevalue")) + e := NewExtensionNode([]byte{0x05, 0x06, 0x07}, l) + l2 := NewLeafNode([]byte("invalid")) + e2 := NewExtensionNode([]byte{0x05}, NewHashNode(l2.Hash())) + b := NewBranchNode() + b.Children[4] = NewHashNode(e.Hash()) + b.Children[5] = e2 + + tr := NewTrie(b, newTestStore()) + require.NoError(t, tr.Put([]byte{0x12, 0x31}, []byte("value1"))) + require.NoError(t, tr.Put([]byte{0x12, 0x32}, []byte("value2"))) + tr.putToStore(l) + tr.putToStore(e) + return tr +} + +func TestTrie_GetProof(t *testing.T) { + tr := newProofTrie(t) + + t.Run("MissingKey", func(t *testing.T) { + _, err := tr.GetProof([]byte{0x12}) + require.Error(t, err) + }) + + t.Run("Valid", func(t *testing.T) { + _, err := tr.GetProof([]byte{0x12, 0x31}) + require.NoError(t, err) + }) + + t.Run("MissingHashNode", func(t *testing.T) { + _, err := tr.GetProof([]byte{0x55}) + require.Error(t, err) + }) +} + +func TestVerifyProof(t *testing.T) { + tr := newProofTrie(t) + + t.Run("Simple", func(t *testing.T) { + proof, err := tr.GetProof([]byte{0x12, 0x32}) + require.NoError(t, err) + + t.Run("Good", func(t *testing.T) { + v, ok := VerifyProof(tr.root.Hash(), []byte{0x12, 0x32}, proof) + require.True(t, ok) + require.Equal(t, []byte("value2"), v) + }) + + t.Run("Bad", func(t *testing.T) { + _, ok := VerifyProof(tr.root.Hash(), []byte{0x12, 0x31}, proof) + require.False(t, ok) + }) + }) + + t.Run("InsideHash", func(t *testing.T) { + key := []byte{0x45, 0x67} + proof, err := tr.GetProof(key) + require.NoError(t, err) + + v, ok := VerifyProof(tr.root.Hash(), key, proof) + require.True(t, ok) + require.Equal(t, []byte("somevalue"), v) + }) +} diff --git a/pkg/core/mpt/trie.go b/pkg/core/mpt/trie.go new file mode 100644 index 0000000000..08d128d888 --- /dev/null +++ b/pkg/core/mpt/trie.go @@ -0,0 +1,390 @@ +package mpt + +import ( + "bytes" + "errors" + + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// Trie is an MPT trie storing all key-value pairs. +type Trie struct { + Store *storage.MemCachedStore + + root Node +} + +// ErrNotFound is returned when requested trie item is missing. +var ErrNotFound = errors.New("item not found") + +// NewTrie returns new MPT trie. It accepts a MemCachedStore to decouple storage errors from logic errors +// so that all storage errors are processed during `store.Persist()` at the caller. +// This also has the benefit, that every `Put` can be considered an atomic operation. +func NewTrie(root Node, store *storage.MemCachedStore) *Trie { + if root == nil { + root = new(HashNode) + } + + return &Trie{ + Store: store, + root: root, + } +} + +// Get returns value for the provided key in t. +func (t *Trie) Get(key []byte) ([]byte, error) { + path := toNibbles(key) + r, bs, err := t.getWithPath(t.root, path) + if err != nil { + return nil, err + } + t.root = r + return bs, nil +} + +// getWithPath returns value the provided path in a subtrie rooting in curr. +// It also returns a current node with all hash nodes along the path +// replaced to their "unhashed" counterparts. +func (t *Trie) getWithPath(curr Node, path []byte) (Node, []byte, error) { + switch n := curr.(type) { + case *LeafNode: + if len(path) == 0 { + return curr, copySlice(n.value), nil + } + case *BranchNode: + i, path := splitPath(path) + r, bs, err := t.getWithPath(n.Children[i], path) + if err != nil { + return nil, nil, err + } + n.Children[i] = r + return n, bs, nil + case *HashNode: + if !n.IsEmpty() { + if r, err := t.getFromStore(n.hash); err == nil { + return t.getWithPath(r, path) + } + } + case *ExtensionNode: + if bytes.HasPrefix(path, n.key) { + r, bs, err := t.getWithPath(n.next, path[len(n.key):]) + if err != nil { + return nil, nil, err + } + n.next = r + return curr, bs, err + } + default: + panic("invalid MPT node type") + } + return curr, nil, ErrNotFound +} + +// Put puts key-value pair in t. +func (t *Trie) Put(key, value []byte) error { + if len(key) > MaxKeyLength { + return errors.New("key is too big") + } else if len(value) > MaxValueLength { + return errors.New("value is too big") + } + if len(value) == 0 { + return t.Delete(key) + } + path := toNibbles(key) + n := NewLeafNode(value) + r, err := t.putIntoNode(t.root, path, n) + if err != nil { + return err + } + t.root = r + return nil +} + +// putIntoLeaf puts val to trie if current node is a Leaf. +// It returns Node if curr needs to be replaced and error if any. +func (t *Trie) putIntoLeaf(curr *LeafNode, path []byte, val Node) (Node, error) { + v := val.(*LeafNode) + if len(path) == 0 { + return v, nil + } + + b := NewBranchNode() + b.Children[path[0]] = newSubTrie(path[1:], v) + b.Children[lastChild] = curr + return b, nil +} + +// putIntoBranch puts val to trie if current node is a Branch. +// It returns Node if curr needs to be replaced and error if any. +func (t *Trie) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, error) { + i, path := splitPath(path) + r, err := t.putIntoNode(curr.Children[i], path, val) + if err != nil { + return nil, err + } + curr.Children[i] = r + curr.invalidateCache() + return curr, nil +} + +// putIntoExtension puts val to trie if current node is an Extension. +// It returns Node if curr needs to be replaced and error if any. +func (t *Trie) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Node, error) { + if bytes.HasPrefix(path, curr.key) { + r, err := t.putIntoNode(curr.next, path[len(curr.key):], val) + if err != nil { + return nil, err + } + curr.next = r + curr.invalidateCache() + return curr, nil + } + + pref := lcp(curr.key, path) + lp := len(pref) + keyTail := curr.key[lp:] + pathTail := path[lp:] + + s1 := newSubTrie(keyTail[1:], curr.next) + b := NewBranchNode() + b.Children[keyTail[0]] = s1 + + i, pathTail := splitPath(pathTail) + s2 := newSubTrie(pathTail, val) + b.Children[i] = s2 + + if lp > 0 { + return NewExtensionNode(copySlice(pref), b), nil + } + return b, nil +} + +// putIntoHash puts val to trie if current node is a HashNode. +// It returns Node if curr needs to be replaced and error if any. +func (t *Trie) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error) { + if curr.IsEmpty() { + return newSubTrie(path, val), nil + } + + result, err := t.getFromStore(curr.hash) + if err != nil { + return nil, err + } + return t.putIntoNode(result, path, val) +} + +// newSubTrie create new trie containing node at provided path. +func newSubTrie(path []byte, val Node) Node { + if len(path) == 0 { + return val + } + return NewExtensionNode(path, val) +} + +func (t *Trie) putIntoNode(curr Node, path []byte, val Node) (Node, error) { + switch n := curr.(type) { + case *LeafNode: + return t.putIntoLeaf(n, path, val) + case *BranchNode: + return t.putIntoBranch(n, path, val) + case *ExtensionNode: + return t.putIntoExtension(n, path, val) + case *HashNode: + return t.putIntoHash(n, path, val) + default: + panic("invalid MPT node type") + } +} + +// Delete removes key from trie. +// It returns no error on missing key. +func (t *Trie) Delete(key []byte) error { + path := toNibbles(key) + r, err := t.deleteFromNode(t.root, path) + if err != nil { + return err + } + t.root = r + return nil +} + +func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) { + i, path := splitPath(path) + r, err := t.deleteFromNode(b.Children[i], path) + if err != nil { + return nil, err + } + b.Children[i] = r + b.invalidateCache() + var count, index int + for i := range b.Children { + h, ok := b.Children[i].(*HashNode) + if !ok || !h.IsEmpty() { + index = i + count++ + } + } + // count is >= 1 because branch node had at least 2 children before deletion. + if count > 1 { + return b, nil + } + c := b.Children[index] + if index == lastChild { + return c, nil + } + if h, ok := c.(*HashNode); ok { + c, err = t.getFromStore(h.Hash()) + if err != nil { + return nil, err + } + } + if e, ok := c.(*ExtensionNode); ok { + e.key = append([]byte{byte(index)}, e.key...) + e.invalidateCache() + return e, nil + } + + return NewExtensionNode([]byte{byte(index)}, c), nil +} + +func (t *Trie) deleteFromExtension(n *ExtensionNode, path []byte) (Node, error) { + if !bytes.HasPrefix(path, n.key) { + return nil, ErrNotFound + } + r, err := t.deleteFromNode(n.next, path[len(n.key):]) + if err != nil { + return nil, err + } + switch nxt := r.(type) { + case *ExtensionNode: + n.key = append(n.key, nxt.key...) + n.next = nxt.next + case *HashNode: + if nxt.IsEmpty() { + return nxt, nil + } + default: + n.next = r + } + n.invalidateCache() + return n, nil +} + +func (t *Trie) deleteFromNode(curr Node, path []byte) (Node, error) { + switch n := curr.(type) { + case *LeafNode: + if len(path) == 0 { + return new(HashNode), nil + } + return nil, ErrNotFound + case *BranchNode: + return t.deleteFromBranch(n, path) + case *ExtensionNode: + return t.deleteFromExtension(n, path) + case *HashNode: + if n.IsEmpty() { + return nil, ErrNotFound + } + newNode, err := t.getFromStore(n.Hash()) + if err != nil { + return nil, err + } + return t.deleteFromNode(newNode, path) + default: + panic("invalid MPT node type") + } +} + +// StateRoot returns root hash of t. +func (t *Trie) StateRoot() util.Uint256 { + if hn, ok := t.root.(*HashNode); ok && hn.IsEmpty() { + return util.Uint256{} + } + return t.root.Hash() +} + +func makeStorageKey(mptKey []byte) []byte { + return append([]byte{byte(storage.DataMPT)}, mptKey...) +} + +// Flush puts every node in the trie except Hash ones to the storage. +// Because we care only about block-level changes, there is no need to put every +// new node to storage. Normally, flush should be called with every StateRoot persist, i.e. +// after every block. +func (t *Trie) Flush() { + t.flush(t.root) +} + +func (t *Trie) flush(node Node) { + if node.IsFlushed() { + return + } + switch n := node.(type) { + case *BranchNode: + for i := range n.Children { + t.flush(n.Children[i]) + } + case *ExtensionNode: + t.flush(n.next) + case *HashNode: + return + } + t.putToStore(node) +} + +func (t *Trie) putToStore(n Node) { + if n.Type() == HashT { + panic("can't put hash node in trie") + } + _ = t.Store.Put(makeStorageKey(n.Hash().BytesBE()), n.Bytes()) // put in MemCached returns no errors + n.SetFlushed() +} + +func (t *Trie) getFromStore(h util.Uint256) (Node, error) { + data, err := t.Store.Get(makeStorageKey(h.BytesBE())) + if err != nil { + return nil, err + } + + var n NodeObject + r := io.NewBinReaderFromBuf(data) + n.DecodeBinary(r) + if r.Err != nil { + return nil, r.Err + } + return n.Node, nil +} + +// Collapse compresses all nodes at depth n to the hash nodes. +// Note: this function does not perform any kind of storage flushing so +// `Flush()` should be called explicitly before invoking function. +func (t *Trie) Collapse(depth int) { + if depth < 0 { + panic("negative depth") + } + t.root = collapse(depth, t.root) +} + +func collapse(depth int, node Node) Node { + if _, ok := node.(*HashNode); ok { + return node + } else if depth == 0 { + return NewHashNode(node.Hash()) + } + + switch n := node.(type) { + case *BranchNode: + for i := range n.Children { + n.Children[i] = collapse(depth-1, n.Children[i]) + } + case *ExtensionNode: + n.next = collapse(depth-1, n.next) + case *LeafNode: + case *HashNode: + default: + panic("invalid MPT node type") + } + return node +} diff --git a/pkg/core/mpt/trie_test.go b/pkg/core/mpt/trie_test.go new file mode 100644 index 0000000000..d06e08168b --- /dev/null +++ b/pkg/core/mpt/trie_test.go @@ -0,0 +1,446 @@ +package mpt + +import ( + "testing" + + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/internal/random" + "github.com/stretchr/testify/require" +) + +func newTestStore() *storage.MemCachedStore { + return storage.NewMemCachedStore(storage.NewMemoryStore()) +} + +func newTestTrie(t *testing.T) *Trie { + b := NewBranchNode() + + l1 := NewLeafNode([]byte{0xAB, 0xCD}) + b.Children[0] = NewExtensionNode([]byte{0x01}, l1) + + l2 := NewLeafNode([]byte{0x22, 0x22}) + b.Children[9] = NewExtensionNode([]byte{0x09}, l2) + + v := NewLeafNode([]byte("hello")) + h := NewHashNode(v.Hash()) + b.Children[10] = NewExtensionNode([]byte{0x0e}, h) + + e := NewExtensionNode(toNibbles([]byte{0xAC}), b) + tr := NewTrie(e, newTestStore()) + + tr.putToStore(e) + tr.putToStore(b) + tr.putToStore(l1) + tr.putToStore(l2) + tr.putToStore(v) + tr.putToStore(b.Children[0]) + tr.putToStore(b.Children[9]) + tr.putToStore(b.Children[10]) + + return tr +} + +func TestTrie_PutIntoBranchNode(t *testing.T) { + b := NewBranchNode() + l := NewLeafNode([]byte{0x8}) + b.Children[0x7] = NewHashNode(l.Hash()) + b.Children[0x8] = NewHashNode(random.Uint256()) + tr := NewTrie(b, newTestStore()) + + // next + require.NoError(t, tr.Put([]byte{}, []byte{0x12, 0x34})) + tr.testHas(t, []byte{}, []byte{0x12, 0x34}) + + // empty hash node child + require.NoError(t, tr.Put([]byte{0x66}, []byte{0x56})) + tr.testHas(t, []byte{0x66}, []byte{0x56}) + require.True(t, isValid(tr.root)) + + // missing hash + require.Error(t, tr.Put([]byte{0x70}, []byte{0x42})) + require.True(t, isValid(tr.root)) + + // hash is in store + tr.putToStore(l) + require.NoError(t, tr.Put([]byte{0x70}, []byte{0x42})) + require.True(t, isValid(tr.root)) +} + +func TestTrie_PutIntoExtensionNode(t *testing.T) { + l := NewLeafNode([]byte{0x11}) + key := []byte{0x12} + e := NewExtensionNode(toNibbles(key), NewHashNode(l.Hash())) + tr := NewTrie(e, newTestStore()) + + // missing hash + require.Error(t, tr.Put(key, []byte{0x42})) + + tr.putToStore(l) + require.NoError(t, tr.Put(key, []byte{0x42})) + tr.testHas(t, key, []byte{0x42}) + require.True(t, isValid(tr.root)) +} + +func TestTrie_PutIntoHashNode(t *testing.T) { + b := NewBranchNode() + l := NewLeafNode(random.Bytes(5)) + e := NewExtensionNode([]byte{0x02}, l) + b.Children[1] = NewHashNode(e.Hash()) + b.Children[9] = NewHashNode(random.Uint256()) + tr := NewTrie(b, newTestStore()) + + tr.putToStore(e) + + t.Run("MissingLeafHash", func(t *testing.T) { + _, err := tr.Get([]byte{0x12}) + require.Error(t, err) + }) + + tr.putToStore(l) + + val := random.Bytes(3) + require.NoError(t, tr.Put([]byte{0x12, 0x34}, val)) + tr.testHas(t, []byte{0x12, 0x34}, val) + tr.testHas(t, []byte{0x12}, l.value) + require.True(t, isValid(tr.root)) +} + +func TestTrie_Put(t *testing.T) { + trExp := newTestTrie(t) + + trAct := NewTrie(nil, newTestStore()) + require.NoError(t, trAct.Put([]byte{0xAC, 0x01}, []byte{0xAB, 0xCD})) + require.NoError(t, trAct.Put([]byte{0xAC, 0x99}, []byte{0x22, 0x22})) + require.NoError(t, trAct.Put([]byte{0xAC, 0xAE}, []byte("hello"))) + + // Note: the exact tries differ because of ("acae":"hello") node is stored as Hash node in test trie. + require.Equal(t, trExp.root.Hash(), trAct.root.Hash()) + require.True(t, isValid(trAct.root)) +} + +func TestTrie_PutInvalid(t *testing.T) { + tr := NewTrie(nil, newTestStore()) + key, value := []byte("key"), []byte("value") + + // big key + require.Error(t, tr.Put(make([]byte, MaxKeyLength+1), value)) + + // big value + require.Error(t, tr.Put(key, make([]byte, MaxValueLength+1))) + + // this is ok though + require.NoError(t, tr.Put(key, value)) + tr.testHas(t, key, value) +} + +func TestTrie_BigPut(t *testing.T) { + tr := NewTrie(nil, newTestStore()) + items := []struct{ k, v string }{ + {"item with long key", "value1"}, + {"item with matching prefix", "value2"}, + {"another prefix", "value3"}, + {"another prefix 2", "value4"}, + {"another ", "value5"}, + } + + for i := range items { + require.NoError(t, tr.Put([]byte(items[i].k), []byte(items[i].v))) + } + + for i := range items { + tr.testHas(t, []byte(items[i].k), []byte(items[i].v)) + } + + t.Run("Rewrite", func(t *testing.T) { + k, v := []byte(items[0].k), []byte{0x01, 0x23} + require.NoError(t, tr.Put(k, v)) + tr.testHas(t, k, v) + }) + + t.Run("Remove", func(t *testing.T) { + k := []byte(items[1].k) + require.NoError(t, tr.Put(k, []byte{})) + tr.testHas(t, k, nil) + }) +} + +func (tr *Trie) testHas(t *testing.T, key, value []byte) { + v, err := tr.Get(key) + if value == nil { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, value, v) +} + +// isValid checks for 3 invariants: +// - BranchNode contains > 1 children +// - ExtensionNode do not contain another extension node +// - ExtensionNode do not have nil key +// It is used only during testing to catch possible bugs. +func isValid(curr Node) bool { + switch n := curr.(type) { + case *BranchNode: + var count int + for i := range n.Children { + if !isValid(n.Children[i]) { + return false + } + hn, ok := n.Children[i].(*HashNode) + if !ok || !hn.IsEmpty() { + count++ + } + } + return count > 1 + case *ExtensionNode: + _, ok := n.next.(*ExtensionNode) + return len(n.key) != 0 && !ok + default: + return true + } +} + +func TestTrie_Get(t *testing.T) { + t.Run("HashNode", func(t *testing.T) { + tr := newTestTrie(t) + tr.testHas(t, []byte{0xAC, 0xAE}, []byte("hello")) + }) + t.Run("UnfoldRoot", func(t *testing.T) { + tr := newTestTrie(t) + single := NewTrie(NewHashNode(tr.root.Hash()), tr.Store) + single.testHas(t, []byte{0xAC}, nil) + single.testHas(t, []byte{0xAC, 0x01}, []byte{0xAB, 0xCD}) + single.testHas(t, []byte{0xAC, 0x99}, []byte{0x22, 0x22}) + single.testHas(t, []byte{0xAC, 0xAE}, []byte("hello")) + }) +} + +func TestTrie_Flush(t *testing.T) { + pairs := map[string][]byte{ + "": []byte("value0"), + "key1": []byte("value1"), + "key2": []byte("value2"), + } + + tr := NewTrie(nil, newTestStore()) + for k, v := range pairs { + require.NoError(t, tr.Put([]byte(k), v)) + } + + tr.Flush() + tr = NewTrie(NewHashNode(tr.StateRoot()), tr.Store) + for k, v := range pairs { + actual, err := tr.Get([]byte(k)) + require.NoError(t, err) + require.Equal(t, v, actual) + } +} + +func TestTrie_Delete(t *testing.T) { + t.Run("Hash", func(t *testing.T) { + t.Run("FromStore", func(t *testing.T) { + l := NewLeafNode([]byte{0x12}) + tr := NewTrie(NewHashNode(l.Hash()), newTestStore()) + t.Run("NotInStore", func(t *testing.T) { + require.Error(t, tr.Delete([]byte{})) + }) + + tr.putToStore(l) + tr.testHas(t, []byte{}, []byte{0x12}) + require.NoError(t, tr.Delete([]byte{})) + tr.testHas(t, []byte{}, nil) + }) + + t.Run("Empty", func(t *testing.T) { + tr := NewTrie(nil, newTestStore()) + require.Error(t, tr.Delete([]byte{})) + }) + }) + + t.Run("Leaf", func(t *testing.T) { + l := NewLeafNode([]byte{0x12, 0x34}) + tr := NewTrie(l, newTestStore()) + t.Run("NonExistentKey", func(t *testing.T) { + require.Error(t, tr.Delete([]byte{0x12})) + tr.testHas(t, []byte{}, []byte{0x12, 0x34}) + }) + require.NoError(t, tr.Delete([]byte{})) + tr.testHas(t, []byte{}, nil) + }) + + t.Run("Extension", func(t *testing.T) { + t.Run("SingleKey", func(t *testing.T) { + l := NewLeafNode([]byte{0x12, 0x34}) + e := NewExtensionNode([]byte{0x0A, 0x0B}, l) + tr := NewTrie(e, newTestStore()) + + t.Run("NonExistentKey", func(t *testing.T) { + require.Error(t, tr.Delete([]byte{})) + tr.testHas(t, []byte{0xAB}, []byte{0x12, 0x34}) + }) + + require.NoError(t, tr.Delete([]byte{0xAB})) + require.True(t, tr.root.(*HashNode).IsEmpty()) + }) + + t.Run("MultipleKeys", func(t *testing.T) { + b := NewBranchNode() + b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0x12, 0x34})) + b.Children[6] = NewExtensionNode([]byte{0x07}, NewLeafNode([]byte{0x56, 0x78})) + e := NewExtensionNode([]byte{0x01, 0x02}, b) + tr := NewTrie(e, newTestStore()) + + h := e.Hash() + require.NoError(t, tr.Delete([]byte{0x12, 0x01})) + tr.testHas(t, []byte{0x12, 0x01}, nil) + tr.testHas(t, []byte{0x12, 0x67}, []byte{0x56, 0x78}) + + require.NotEqual(t, h, tr.root.Hash()) + require.Equal(t, toNibbles([]byte{0x12, 0x67}), e.key) + require.IsType(t, (*LeafNode)(nil), e.next) + }) + }) + + t.Run("Branch", func(t *testing.T) { + t.Run("3 Children", func(t *testing.T) { + b := NewBranchNode() + b.Children[lastChild] = NewLeafNode([]byte{0x12}) + b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0x34})) + b.Children[1] = NewExtensionNode([]byte{0x06}, NewLeafNode([]byte{0x56})) + tr := NewTrie(b, newTestStore()) + require.NoError(t, tr.Delete([]byte{0x16})) + tr.testHas(t, []byte{}, []byte{0x12}) + tr.testHas(t, []byte{0x01}, []byte{0x34}) + tr.testHas(t, []byte{0x16}, nil) + }) + t.Run("2 Children", func(t *testing.T) { + newt := func(t *testing.T) *Trie { + b := NewBranchNode() + b.Children[lastChild] = NewLeafNode([]byte{0x12}) + l := NewLeafNode([]byte{0x34}) + e := NewExtensionNode([]byte{0x06}, l) + b.Children[5] = NewHashNode(e.Hash()) + tr := NewTrie(b, newTestStore()) + tr.putToStore(l) + tr.putToStore(e) + return tr + } + + t.Run("DeleteLast", func(t *testing.T) { + t.Run("MergeExtension", func(t *testing.T) { + tr := newt(t) + require.NoError(t, tr.Delete([]byte{})) + tr.testHas(t, []byte{}, nil) + tr.testHas(t, []byte{0x56}, []byte{0x34}) + require.IsType(t, (*ExtensionNode)(nil), tr.root) + }) + + t.Run("LeaveLeaf", func(t *testing.T) { + c := NewBranchNode() + c.Children[5] = NewLeafNode([]byte{0x05}) + c.Children[6] = NewLeafNode([]byte{0x06}) + + b := NewBranchNode() + b.Children[lastChild] = NewLeafNode([]byte{0x12}) + b.Children[5] = c + tr := NewTrie(b, newTestStore()) + + require.NoError(t, tr.Delete([]byte{})) + tr.testHas(t, []byte{}, nil) + tr.testHas(t, []byte{0x55}, []byte{0x05}) + tr.testHas(t, []byte{0x56}, []byte{0x06}) + require.IsType(t, (*ExtensionNode)(nil), tr.root) + }) + }) + + t.Run("DeleteMiddle", func(t *testing.T) { + tr := newt(t) + require.NoError(t, tr.Delete([]byte{0x56})) + tr.testHas(t, []byte{}, []byte{0x12}) + tr.testHas(t, []byte{0x56}, nil) + require.IsType(t, (*LeafNode)(nil), tr.root) + }) + }) + }) +} + +func TestTrie_PanicInvalidRoot(t *testing.T) { + tr := &Trie{Store: newTestStore()} + require.Panics(t, func() { _ = tr.Put([]byte{1}, []byte{2}) }) + require.Panics(t, func() { _, _ = tr.Get([]byte{1}) }) + require.Panics(t, func() { _ = tr.Delete([]byte{1}) }) +} + +func TestTrie_Collapse(t *testing.T) { + t.Run("PanicNegative", func(t *testing.T) { + tr := newTestTrie(t) + require.Panics(t, func() { tr.Collapse(-1) }) + }) + t.Run("Depth=0", func(t *testing.T) { + tr := newTestTrie(t) + h := tr.root.Hash() + + _, ok := tr.root.(*HashNode) + require.False(t, ok) + + tr.Collapse(0) + _, ok = tr.root.(*HashNode) + require.True(t, ok) + require.Equal(t, h, tr.root.Hash()) + }) + t.Run("Branch,Depth=1", func(t *testing.T) { + b := NewBranchNode() + e := NewExtensionNode([]byte{0x01}, NewLeafNode([]byte("value1"))) + he := e.Hash() + b.Children[0] = e + hb := b.Hash() + + tr := NewTrie(b, newTestStore()) + tr.Collapse(1) + + newb, ok := tr.root.(*BranchNode) + require.True(t, ok) + require.Equal(t, hb, newb.Hash()) + require.IsType(t, (*HashNode)(nil), b.Children[0]) + require.Equal(t, he, b.Children[0].Hash()) + }) + t.Run("Extension,Depth=1", func(t *testing.T) { + l := NewLeafNode([]byte("value")) + hl := l.Hash() + e := NewExtensionNode([]byte{0x01}, l) + h := e.Hash() + tr := NewTrie(e, newTestStore()) + tr.Collapse(1) + + newe, ok := tr.root.(*ExtensionNode) + require.True(t, ok) + require.Equal(t, h, newe.Hash()) + require.IsType(t, (*HashNode)(nil), newe.next) + require.Equal(t, hl, newe.next.Hash()) + }) + t.Run("Leaf", func(t *testing.T) { + l := NewLeafNode([]byte("value")) + tr := NewTrie(l, newTestStore()) + tr.Collapse(10) + require.Equal(t, NewLeafNode([]byte("value")), tr.root) + }) + t.Run("Hash", func(t *testing.T) { + t.Run("Empty", func(t *testing.T) { + tr := NewTrie(new(HashNode), newTestStore()) + require.NotPanics(t, func() { tr.Collapse(1) }) + hn, ok := tr.root.(*HashNode) + require.True(t, ok) + require.True(t, hn.IsEmpty()) + }) + + h := random.Uint256() + hn := NewHashNode(h) + tr := NewTrie(hn, newTestStore()) + tr.Collapse(10) + + newRoot, ok := tr.root.(*HashNode) + require.True(t, ok) + require.Equal(t, NewHashNode(h), newRoot) + }) +} diff --git a/pkg/core/prometheus.go b/pkg/core/prometheus.go index b81fb847d6..c849e34591 100644 --- a/pkg/core/prometheus.go +++ b/pkg/core/prometheus.go @@ -30,6 +30,14 @@ var ( Namespace: "neogo", }, ) + //stateHeight prometheus metric. + stateHeight = prometheus.NewGauge( + prometheus.GaugeOpts{ + Help: "Current verified state height", + Name: "current_state_height", + Namespace: "neogo", + }, + ) ) func init() { @@ -51,3 +59,7 @@ func updateHeaderHeightMetric(hHeight int) { func updateBlockHeightMetric(bHeight uint32) { blockHeight.Set(float64(bHeight)) } + +func updateStateHeightMetric(sHeight uint32) { + stateHeight.Set(float64(sHeight)) +} diff --git a/pkg/core/state/mpt_root.go b/pkg/core/state/mpt_root.go new file mode 100644 index 0000000000..dea3f62fac --- /dev/null +++ b/pkg/core/state/mpt_root.go @@ -0,0 +1,146 @@ +package state + +import ( + "encoding/json" + "errors" + + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// MPTRootBase represents storage state root. +type MPTRootBase struct { + Version byte `json:"version"` + Index uint32 `json:"index"` + PrevHash util.Uint256 `json:"prehash"` + Root util.Uint256 `json:"stateroot"` +} + +// MPTRoot represents storage state root together with sign info. +type MPTRoot struct { + MPTRootBase + Witness *transaction.Witness `json:"witness,omitempty"` +} + +// MPTRootStateFlag represents verification state of the state root. +type MPTRootStateFlag byte + +// Possible verification states of MPTRoot. +const ( + Unverified MPTRootStateFlag = 0x00 + Verified MPTRootStateFlag = 0x01 + Invalid MPTRootStateFlag = 0x03 +) + +// MPTRootState represents state root together with its verification state. +type MPTRootState struct { + MPTRoot `json:"stateroot"` + Flag MPTRootStateFlag `json:"flag"` +} + +// EncodeBinary implements io.Serializable. +func (s *MPTRootState) EncodeBinary(w *io.BinWriter) { + w.WriteB(byte(s.Flag)) + s.MPTRoot.EncodeBinary(w) +} + +// DecodeBinary implements io.Serializable. +func (s *MPTRootState) DecodeBinary(r *io.BinReader) { + s.Flag = MPTRootStateFlag(r.ReadB()) + s.MPTRoot.DecodeBinary(r) +} + +// GetSignedPart returns part of MPTRootBase which needs to be signed. +func (s *MPTRootBase) GetSignedPart() []byte { + buf := io.NewBufBinWriter() + s.EncodeBinary(buf.BinWriter) + return buf.Bytes() +} + +// Equals checks if s == other. +func (s *MPTRootBase) Equals(other *MPTRootBase) bool { + return s.Version == other.Version && s.Index == other.Index && + s.PrevHash.Equals(other.PrevHash) && s.Root.Equals(other.Root) +} + +// Hash returns hash of s. +func (s *MPTRootBase) Hash() util.Uint256 { + return hash.DoubleSha256(s.GetSignedPart()) +} + +// DecodeBinary implements io.Serializable. +func (s *MPTRootBase) DecodeBinary(r *io.BinReader) { + s.Version = r.ReadB() + s.Index = r.ReadU32LE() + s.PrevHash.DecodeBinary(r) + s.Root.DecodeBinary(r) +} + +// EncodeBinary implements io.Serializable. +func (s *MPTRootBase) EncodeBinary(w *io.BinWriter) { + w.WriteB(s.Version) + w.WriteU32LE(s.Index) + s.PrevHash.EncodeBinary(w) + s.Root.EncodeBinary(w) +} + +// DecodeBinary implements io.Serializable. +func (s *MPTRoot) DecodeBinary(r *io.BinReader) { + s.MPTRootBase.DecodeBinary(r) + + var ws []transaction.Witness + r.ReadArray(&ws, 1) + if len(ws) == 1 { + s.Witness = &ws[0] + } +} + +// EncodeBinary implements io.Serializable. +func (s *MPTRoot) EncodeBinary(w *io.BinWriter) { + s.MPTRootBase.EncodeBinary(w) + if s.Witness == nil { + w.WriteVarUint(0) + } else { + w.WriteArray([]*transaction.Witness{s.Witness}) + } +} + +// String implements fmt.Stringer. +func (f MPTRootStateFlag) String() string { + switch f { + case Unverified: + return "Unverified" + case Verified: + return "Verified" + case Invalid: + return "Invalid" + default: + return "" + } +} + +// MarshalJSON implements json.Marshaler. +func (f MPTRootStateFlag) MarshalJSON() ([]byte, error) { + return []byte(`"` + f.String() + `"`), nil +} + +// UnmarshalJSON implements json.Unmarshaler. +func (f *MPTRootStateFlag) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + switch s { + case "Unverified": + *f = Unverified + case "Verified": + *f = Verified + case "Invalid": + *f = Invalid + default: + return errors.New("unknown flag") + } + return nil +} diff --git a/pkg/core/state/mpt_root_test.go b/pkg/core/state/mpt_root_test.go new file mode 100644 index 0000000000..f1c0b5c61e --- /dev/null +++ b/pkg/core/state/mpt_root_test.go @@ -0,0 +1,100 @@ +package state + +import ( + "encoding/json" + "math/rand" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/internal/random" + "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/stretchr/testify/require" +) + +func testStateRoot() *MPTRoot { + return &MPTRoot{ + MPTRootBase: MPTRootBase{ + Version: byte(rand.Uint32()), + Index: rand.Uint32(), + PrevHash: random.Uint256(), + Root: random.Uint256(), + }, + } +} + +func TestStateRoot_Serializable(t *testing.T) { + r := testStateRoot() + testserdes.EncodeDecodeBinary(t, r, new(MPTRoot)) + + t.Run("WithWitness", func(t *testing.T) { + r.Witness = &transaction.Witness{ + InvocationScript: random.Bytes(10), + VerificationScript: random.Bytes(11), + } + testserdes.EncodeDecodeBinary(t, r, new(MPTRoot)) + }) +} + +func TestStateRootEquals(t *testing.T) { + r1 := testStateRoot() + r2 := *r1 + require.True(t, r1.Equals(&r2.MPTRootBase)) + + r2.MPTRootBase.Index++ + require.False(t, r1.Equals(&r2.MPTRootBase)) +} + +func TestMPTRootState_Serializable(t *testing.T) { + rs := &MPTRootState{ + MPTRoot: *testStateRoot(), + Flag: 0x04, + } + rs.MPTRoot.Witness = &transaction.Witness{ + InvocationScript: random.Bytes(10), + VerificationScript: random.Bytes(11), + } + testserdes.EncodeDecodeBinary(t, rs, new(MPTRootState)) +} + +func TestMPTRootStateUnverifiedByDefault(t *testing.T) { + var r MPTRootState + require.Equal(t, Unverified, r.Flag) +} + +func TestMPTRoot_MarshalJSON(t *testing.T) { + t.Run("Good", func(t *testing.T) { + r := testStateRoot() + rs := &MPTRootState{ + MPTRoot: *r, + Flag: Verified, + } + testserdes.MarshalUnmarshalJSON(t, rs, new(MPTRootState)) + }) + + t.Run("Compatibility", func(t *testing.T) { + js := []byte(`{ + "flag": "Unverified", + "stateroot": { + "version": 1, + "index": 3000000, + "prehash": "0x4f30f43af8dd2262fc331c45bfcd9066ebbacda204e6e81371cbd884fe7d6c90", + "stateroot": "0xb2fd7e368a848ef70d27cf44940a35237333ed05f1d971c9408f0eb285e0b6f3" + }}`) + + rs := new(MPTRootState) + require.NoError(t, json.Unmarshal(js, &rs)) + + require.EqualValues(t, 1, rs.Version) + require.EqualValues(t, 3000000, rs.Index) + require.Nil(t, rs.Witness) + + u, err := util.Uint256DecodeStringLE("4f30f43af8dd2262fc331c45bfcd9066ebbacda204e6e81371cbd884fe7d6c90") + require.NoError(t, err) + require.Equal(t, u, rs.PrevHash) + + u, err = util.Uint256DecodeStringLE("b2fd7e368a848ef70d27cf44940a35237333ed05f1d971c9408f0eb285e0b6f3") + require.NoError(t, err) + require.Equal(t, u, rs.Root) + }) +} diff --git a/pkg/core/storage/store.go b/pkg/core/storage/store.go index 5e70334ed7..575c42ba37 100644 --- a/pkg/core/storage/store.go +++ b/pkg/core/storage/store.go @@ -9,6 +9,7 @@ import ( const ( DataBlock KeyPrefix = 0x01 DataTransaction KeyPrefix = 0x02 + DataMPT KeyPrefix = 0x03 STAccount KeyPrefix = 0x40 STNotification KeyPrefix = 0x4d STContract KeyPrefix = 0x50 diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index 6004c44a4e..61cf1939ef 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -49,6 +49,9 @@ func (chain *testChain) AddBlock(block *block.Block) error { } return nil } +func (chain *testChain) AddStateRoot(r *state.MPTRoot) error { + panic("TODO") +} func (chain *testChain) BlockHeight() uint32 { return atomic.LoadUint32(&chain.blockheight) } @@ -98,6 +101,9 @@ func (chain testChain) GetEnrollments() ([]state.Validator, error) { func (chain testChain) GetScriptHashesForVerifying(*transaction.Transaction) ([]util.Uint160, error) { panic("TODO") } +func (chain testChain) GetStateRoot(height uint32) (*state.MPTRootState, error) { + panic("TODO") +} func (chain testChain) GetStorageItem(id int32, key []byte) *state.StorageItem { panic("TODO") }