Skip to content

Commit

Permalink
s/VMContext/Context/, and change from an interface with two implement…
Browse files Browse the repository at this point in the history
…ations to a struct with a lot of optional fields.
  • Loading branch information
Bob Glickstein committed Mar 23, 2017
1 parent 93fd5e5 commit df6946c
Show file tree
Hide file tree
Showing 11 changed files with 243 additions and 324 deletions.
215 changes: 72 additions & 143 deletions protocol/bc/vmcontext.go
Expand Up @@ -6,155 +6,84 @@ import (
"chain/protocol/vm"
)

// BlockVMContext satisfies the vm.Context interface
type BlockVMContext struct {
prog Program
args [][]byte
block *Block
}

func (b *BlockVMContext) VMVersion() uint64 { return b.prog.VMVersion }
func (b *BlockVMContext) Code() []byte { return b.prog.Code }
func (b *BlockVMContext) Arguments() [][]byte { return b.args }

func (b *BlockVMContext) BlockHash() ([]byte, error) {
h := b.block.Hash()
return h[:], nil
}
func (b *BlockVMContext) BlockTimeMS() (uint64, error) { return b.block.TimestampMS, nil }

func (b *BlockVMContext) NextConsensusProgram() ([]byte, error) {
return b.block.ConsensusProgram, nil
}

func (b *BlockVMContext) TxVersion() (uint64, bool) { return 0, false }
func (b *BlockVMContext) TxSigHash() ([]byte, error) { return nil, vm.ErrContext }
func (b *BlockVMContext) NumResults() (uint64, error) { return 0, vm.ErrContext }
func (b *BlockVMContext) AssetID() ([]byte, error) { return nil, vm.ErrContext }
func (b *BlockVMContext) Amount() (uint64, error) { return 0, vm.ErrContext }
func (b *BlockVMContext) MinTimeMS() (uint64, error) { return 0, vm.ErrContext }
func (b *BlockVMContext) MaxTimeMS() (uint64, error) { return 0, vm.ErrContext }
func (b *BlockVMContext) InputRefDataHash() ([]byte, error) { return nil, vm.ErrContext }
func (b *BlockVMContext) TxRefDataHash() ([]byte, error) { return nil, vm.ErrContext }
func (b *BlockVMContext) InputIndex() (uint64, error) { return 0, vm.ErrContext }
func (b *BlockVMContext) Nonce() ([]byte, error) { return nil, vm.ErrContext }
func (b *BlockVMContext) SpentOutputID() ([]byte, error) { return nil, vm.ErrContext }

func (b *BlockVMContext) CheckOutput(uint64, []byte, uint64, []byte, uint64, []byte) (bool, error) {
return false, vm.ErrContext
}

func NewBlockVMContext(block *Block, prog []byte, args [][]byte) *BlockVMContext {
return &BlockVMContext{
prog: Program{
VMVersion: 1,
Code: prog,
},
args: args,
block: block,
}
}

type TxVMContext struct {
prog Program
args [][]byte
tx *Tx
index uint32
}

func NewTxVMContext(tx *Tx, index uint32, prog Program, args [][]byte) *TxVMContext {
return &TxVMContext{
prog: prog,
args: args,
tx: tx,
index: index,
func NewBlockVMContext(block *Block, prog []byte, args [][]byte) *vm.Context {
blockHash := block.Hash().Bytes()
return &vm.Context{
VMVersion: 1,
Code: prog,
Arguments: args,

BlockHash: &blockHash,
BlockTimeMS: &block.TimestampMS,
NextConsensusProgram: &block.ConsensusProgram,
}
}

func (t *TxVMContext) VMVersion() uint64 { return t.prog.VMVersion }
func (t *TxVMContext) Code() []byte { return t.prog.Code }
func (t *TxVMContext) Arguments() [][]byte { return t.args }

func (t *TxVMContext) BlockHash() ([]byte, error) { return nil, vm.ErrContext }
func (t *TxVMContext) BlockTimeMS() (uint64, error) { return 0, vm.ErrContext }

func (t *TxVMContext) NextConsensusProgram() ([]byte, error) { return nil, vm.ErrContext }

func (t *TxVMContext) TxVersion() (uint64, bool) { return t.tx.Version, true }

func (t *TxVMContext) TxSigHash() ([]byte, error) {
h := t.tx.SigHash(t.index)
return h[:], nil
}

func (t *TxVMContext) NumResults() (uint64, error) { return uint64(len(t.tx.Results)), nil }

func (t *TxVMContext) AssetID() ([]byte, error) {
a := t.tx.Inputs[t.index].AssetID()
return a[:], nil
}

func (t *TxVMContext) Amount() (uint64, error) {
return t.tx.Inputs[t.index].Amount(), nil
}

func (t *TxVMContext) MinTimeMS() (uint64, error) { return t.tx.MinTime, nil }
func (t *TxVMContext) MaxTimeMS() (uint64, error) { return t.tx.MaxTime, nil }

func (t *TxVMContext) InputRefDataHash() ([]byte, error) {
h := hashData(t.tx.Inputs[t.index].ReferenceData)
return h[:], nil
}

func (t *TxVMContext) TxRefDataHash() ([]byte, error) {
h := hashData(t.tx.ReferenceData)
return h[:], nil
}

func (t *TxVMContext) InputIndex() (uint64, error) {
return uint64(t.index), nil
}

func (t *TxVMContext) Nonce() ([]byte, error) {
if inp, ok := t.tx.Inputs[t.index].TypedInput.(*IssuanceInput); ok {
return inp.Nonce, nil
func NewTxVMContext(tx *Tx, index uint32, prog Program, args [][]byte) *vm.Context {
var (
txSigHash = tx.SigHash(index).Bytes()
numResults = uint64(len(tx.Results))
assetID = tx.Inputs[index].AssetID()
assetIDBytes = assetID[:]
amount = tx.Inputs[index].Amount()
inputRefDataHash = hashData(tx.Inputs[index].ReferenceData).Bytes()
txRefDataHash = hashData(tx.ReferenceData).Bytes()
)

checkOutput := func(index uint64, refdatahash []byte, amount uint64, assetID []byte, vmVersion uint64, code []byte) (bool, error) {
if index >= uint64(len(tx.Outputs)) {
return false, vm.ErrBadValue
}
o := tx.Outputs[index]
if o.AssetVersion != 1 {
return false, nil
}
if o.Amount != uint64(amount) {
return false, nil
}
if o.VMVersion != uint64(vmVersion) {
return false, nil
}
if !bytes.Equal(o.ControlProgram, code) {
return false, nil
}
if !bytes.Equal(o.AssetID[:], assetID) {
return false, nil
}
if len(refdatahash) > 0 {
h := hashData(o.ReferenceData)
if !bytes.Equal(h[:], refdatahash) {
return false, nil
}
}
return true, nil
}
return nil, vm.ErrContext
}

func (t *TxVMContext) SpentOutputID() ([]byte, error) {
if _, ok := t.tx.Inputs[t.index].TypedInput.(*SpendInput); ok {
return t.tx.TxHashes.SpentOutputIDs[t.index][:], nil
result := &vm.Context{
VMVersion: prog.VMVersion,
Code: prog.Code,
Arguments: args,

TxVersion: &tx.Version,

TxSigHash: &txSigHash,
NumResults: &numResults,
AssetID: &assetIDBytes,
Amount: &amount,
MinTimeMS: &tx.MinTime,
MaxTimeMS: &tx.MaxTime,
InputRefDataHash: &inputRefDataHash,
TxRefDataHash: &txRefDataHash,
InputIndex: &index,
CheckOutput: checkOutput,
}
return nil, vm.ErrContext
}

func (t *TxVMContext) CheckOutput(index uint64, refdatahash []byte, amount uint64, assetID []byte, vmVersion uint64, code []byte) (bool, error) {
if index >= uint64(len(t.tx.Outputs)) {
return false, vm.ErrBadValue
switch inp := tx.Inputs[index].TypedInput.(type) {
case *IssuanceInput:
result.Nonce = &inp.Nonce
case *SpendInput:
spentOutputID := tx.TxHashes.SpentOutputIDs[index][:]
result.SpentOutputID = &spentOutputID
}

o := t.tx.Outputs[index]
if o.AssetVersion != 1 {
return false, nil
}
if o.Amount != uint64(amount) {
return false, nil
}
if o.VMVersion != uint64(vmVersion) {
return false, nil
}
if !bytes.Equal(o.ControlProgram, code) {
return false, nil
}
if !bytes.Equal(o.AssetID[:], assetID) {
return false, nil
}
if len(refdatahash) > 0 {
h := hashData(o.ReferenceData)
if !bytes.Equal(h[:], refdatahash) {
return false, nil
}
}
return true, nil
return result
}
31 changes: 31 additions & 0 deletions protocol/vm/context.go
@@ -0,0 +1,31 @@
package vm

// Context contains the execution context for the virtual machine.
//
// By convention, variables of this type have the name context, _not_
// ctx (to avoid confusion with context.Context).
type Context struct {
VMVersion uint64
Code []byte
Arguments [][]byte

TxVersion *uint64

BlockHash *[]byte
BlockTimeMS *uint64
NextConsensusProgram *[]byte

TxSigHash *[]byte
NumResults *uint64
AssetID *[]byte
Amount *uint64
MinTimeMS *uint64
MaxTimeMS *uint64
InputRefDataHash *[]byte
TxRefDataHash *[]byte
InputIndex *uint32
Nonce *[]byte
SpentOutputID *[]byte

CheckOutput func(index uint64, data []byte, amount uint64, assetID []byte, vmVersion uint64, code []byte) (bool, error)
}
2 changes: 1 addition & 1 deletion protocol/vm/control.go
Expand Up @@ -61,7 +61,7 @@ func opCheckPredicate(vm *virtualMachine) error {
}

childVM := virtualMachine{
vmContext: vm.vmContext,
context: vm.context,
program: predicate,
runLimit: limit,
depth: vm.depth + 1,
Expand Down
18 changes: 6 additions & 12 deletions protocol/vm/crypto.go
Expand Up @@ -7,7 +7,6 @@ import (
"golang.org/x/crypto/sha3"

"chain/crypto/ed25519"
"chain/errors"
"chain/math/checked"
)

Expand Down Expand Up @@ -132,24 +131,19 @@ func opTxSigHash(vm *virtualMachine) error {
if err != nil {
return err
}

h, err := vm.vmContext.TxSigHash()
if err != nil {
return errors.Wrap(err, "computing txsighash")
if vm.context.TxSigHash == nil {
return ErrContext
}

return vm.push(h, false)
return vm.push(*vm.context.TxSigHash, false)
}

func opBlockHash(vm *virtualMachine) error {
err := vm.applyCost(1)
if err != nil {
return err
}
h, err := vm.vmContext.BlockHash()
if err != nil {
return errors.Wrap(err, "computing blockhash")
if vm.context.BlockHash == nil {
return ErrContext
}

return vm.push(h, false)
return vm.push(*vm.context.BlockHash, false)
}
27 changes: 16 additions & 11 deletions protocol/vm/crypto_test.go
Expand Up @@ -401,12 +401,12 @@ func TestCryptoOps(t *testing.T) {
}, {
op: OP_TXSIGHASH,
startVM: &VirtualMachine{
RunLimit: 50000,
VMContext: bc.NewTxVMContext(tx, 0, bc.Program{VMVersion: 1}, nil),
RunLimit: 50000,
Context: bc.NewTxVMContext(tx, 0, bc.Program{VMVersion: 1}, nil),
},
wantVM: &VirtualMachine{
RunLimit: 49704,
VMContext: bc.NewTxVMContext(tx, 0, bc.Program{VMVersion: 1}, nil),
RunLimit: 49704,
Context: bc.NewTxVMContext(tx, 0, bc.Program{VMVersion: 1}, nil),
DataStack: [][]byte{{
47, 0, 60, 221, 100, 66, 123, 94,
237, 214, 204, 181, 133, 71, 2, 11,
Expand All @@ -417,15 +417,15 @@ func TestCryptoOps(t *testing.T) {
}, {
op: OP_TXSIGHASH,
startVM: &VirtualMachine{
RunLimit: 0,
VMContext: bc.NewTxVMContext(tx, 0, bc.Program{VMVersion: 1}, nil),
RunLimit: 0,
Context: bc.NewTxVMContext(tx, 0, bc.Program{VMVersion: 1}, nil),
},
wantErr: ErrRunLimitExceeded,
}, {
op: OP_BLOCKHASH,
startVM: &VirtualMachine{
RunLimit: 50000,
VMContext: bc.NewBlockVMContext(&bc.Block{}, nil, nil),
RunLimit: 50000,
Context: bc.NewBlockVMContext(&bc.Block{}, nil, nil),
},
wantVM: &VirtualMachine{
RunLimit: 49959,
Expand All @@ -435,13 +435,13 @@ func TestCryptoOps(t *testing.T) {
157, 235, 138, 214, 147, 207, 55, 17,
254, 131, 9, 179, 144, 106, 90, 134,
}},
VMContext: bc.NewBlockVMContext(&bc.Block{}, nil, nil),
Context: bc.NewBlockVMContext(&bc.Block{}, nil, nil),
},
}, {
op: OP_BLOCKHASH,
startVM: &VirtualMachine{
RunLimit: 0,
VMContext: bc.NewBlockVMContext(&bc.Block{}, nil, nil),
RunLimit: 0,
Context: bc.NewBlockVMContext(&bc.Block{}, nil, nil),
},
wantErr: ErrRunLimitExceeded,
}}
Expand Down Expand Up @@ -478,6 +478,11 @@ func TestCryptoOps(t *testing.T) {
continue
}

// Hack: the context objects will otherwise compare unequal
// sometimes (because of the function pointer within?) and we
// don't care
c.wantVM.Context = gotVM.Context

if !testutil.DeepEqual(gotVM, c.wantVM) {
t.Errorf("case %d, op %s: unexpected vm result\n\tgot: %+v\n\twant: %+v\n", i, OpName(c.op), gotVM, c.wantVM)
}
Expand Down

0 comments on commit df6946c

Please sign in to comment.