Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: set default values for complex struct fields (2.x) #952

Merged
merged 3 commits into from
Jun 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pkg/compiler/analysis.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ func typeAndValueForField(fld *types.Var) (types.TypeAndValue, error) {
default:
return types.TypeAndValue{}, fmt.Errorf("could not initialize struct field %s to zero, type: %s", fld.Name(), t)
}
default:
return types.TypeAndValue{Type: t}, nil
}
return types.TypeAndValue{}, nil
}

// countGlobals counts the global variables in the program to add
Expand Down
71 changes: 50 additions & 21 deletions pkg/compiler/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,6 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
case *ast.SwitchStmt:
ast.Walk(c, n.Tag)

eqOpcode := c.getEqualityOpcode(n.Tag)
switchEnd, label := c.generateLabel(labelEnd)

lastSwitch := c.currentSwitch
Expand All @@ -490,7 +489,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
for j := range cc.List {
emit.Opcode(c.prog.BinWriter, opcode.DUP)
ast.Walk(c, cc.List[j])
emit.Opcode(c.prog.BinWriter, eqOpcode)
c.emitEquality(n.Tag, token.EQL)
if j == l-1 {
emit.Jmp(c.prog.BinWriter, opcode.JMPIFNOT, lEnd)
} else {
Expand Down Expand Up @@ -533,6 +532,8 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
c.emitLoadConst(value)
} else if tv := c.typeInfo.Types[n]; tv.Value != nil {
c.emitLoadConst(tv)
} else if n.Name == "nil" {
c.emitDefault(new(types.Slice))
} else {
c.emitLoadLocal(n.Name)
}
Expand Down Expand Up @@ -615,26 +616,20 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
ast.Walk(c, n.X)
ast.Walk(c, n.Y)

switch {
case n.Op == token.ADD:
switch n.Op {
case token.ADD:
// VM has separate opcodes for number and string concatenation
if isStringType(tinfo.Type) {
emit.Opcode(c.prog.BinWriter, opcode.CAT)
} else {
emit.Opcode(c.prog.BinWriter, opcode.ADD)
}
case n.Op == token.EQL:
// VM has separate opcodes for number and string equality
op := c.getEqualityOpcode(n.X)
emit.Opcode(c.prog.BinWriter, op)
case n.Op == token.NEQ:
// VM has separate opcodes for number and string equality
if isStringType(c.typeInfo.Types[n.X].Type) {
emit.Opcode(c.prog.BinWriter, opcode.EQUAL)
emit.Opcode(c.prog.BinWriter, opcode.NOT)
} else {
emit.Opcode(c.prog.BinWriter, opcode.NUMNOTEQUAL)
case token.EQL, token.NEQ:
if isExprNil(n.X) || isExprNil(n.Y) {
c.prog.Err = errors.New("comparison with `nil` is not supported, use `len(..) == 0` instead")
return nil
}
c.emitEquality(n.X, n.Op)
default:
c.convertToken(n.Op)
}
Expand Down Expand Up @@ -980,13 +975,26 @@ func (c *codegen) getLabelOffset(typ labelOffsetType, name string) uint16 {
return c.labels[labelWithType{name: name, typ: typ}]
}

func (c *codegen) getEqualityOpcode(expr ast.Expr) opcode.Opcode {
func (c *codegen) emitEquality(expr ast.Expr, op token.Token) {
fyrchik marked this conversation as resolved.
Show resolved Hide resolved
t, ok := c.typeInfo.Types[expr].Type.Underlying().(*types.Basic)
if ok && t.Info()&types.IsNumeric != 0 {
return opcode.NUMEQUAL
isNum := ok && t.Info()&types.IsNumeric != 0
switch op {
case token.EQL:
if isNum {
emit.Opcode(c.prog.BinWriter, opcode.NUMEQUAL)
} else {
emit.Opcode(c.prog.BinWriter, opcode.EQUAL)
}
case token.NEQ:
if isNum {
emit.Opcode(c.prog.BinWriter, opcode.NUMNOTEQUAL)
} else {
emit.Opcode(c.prog.BinWriter, opcode.EQUAL)
emit.Opcode(c.prog.BinWriter, opcode.NOT)
}
default:
panic("invalid token in emitEqual()")
}

return opcode.EQUAL
}

// getByteArray returns byte array value from constant expr.
Expand Down Expand Up @@ -1230,11 +1238,32 @@ func (c *codegen) convertStruct(lit *ast.CompositeLit) {

emit.Opcode(c.prog.BinWriter, opcode.DUP)
emit.Int(c.prog.BinWriter, int64(i))
c.emitLoadConst(typeAndVal)
c.emitDefault(typeAndVal.Type)
emit.Opcode(c.prog.BinWriter, opcode.SETITEM)
}
}

func (c *codegen) emitDefault(typ types.Type) {
switch t := c.scTypeFromGo(typ); t {
case "Integer":
emit.Int(c.prog.BinWriter, 0)
case "Boolean":
emit.Bool(c.prog.BinWriter, false)
case "String":
emit.String(c.prog.BinWriter, "")
case "Map":
emit.Opcode(c.prog.BinWriter, opcode.NEWMAP)
case "Struct":
emit.Int(c.prog.BinWriter, int64(typ.(*types.Struct).NumFields()))
emit.Opcode(c.prog.BinWriter, opcode.NEWSTRUCT)
case "Array":
emit.Int(c.prog.BinWriter, 0)
emit.Opcode(c.prog.BinWriter, opcode.NEWARRAY)
case "ByteArray":
emit.Bytes(c.prog.BinWriter, []byte{})
}
}

func (c *codegen) convertToken(tok token.Token) {
switch tok {
case token.ADD_ASSIGN:
Expand Down
6 changes: 5 additions & 1 deletion pkg/compiler/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,11 @@ func (c *codegen) scReturnTypeFromScope(scope *funcScope) string {
}

func (c *codegen) scTypeFromExpr(typ ast.Expr) string {
switch t := c.typeInfo.Types[typ].Type.(type) {
return c.scTypeFromGo(c.typeInfo.Types[typ].Type)
}

func (c *codegen) scTypeFromGo(typ types.Type) string {
switch t := typ.(type) {
case *types.Basic:
info := t.Info()
switch {
Expand Down
29 changes: 29 additions & 0 deletions pkg/compiler/slice_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package compiler_test

import (
"fmt"
"math/big"
"strings"
"testing"

"github.com/nspcc-dev/neo-go/pkg/compiler"
"github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/stretchr/testify/require"
)

var sliceTestCases = []testCase{
Expand Down Expand Up @@ -175,6 +179,31 @@ func TestSliceOperations(t *testing.T) {
runTestCases(t, sliceTestCases)
}

func TestSliceEmpty(t *testing.T) {
srcTmpl := `package foo
func Main() int {
var a []int
%s
if %s {
return 1
}
return 2
}`
t.Run("WithNil", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, "", "a == nil")
_, err := compiler.Compile(strings.NewReader(src))
require.Error(t, err)
})
t.Run("WithLen", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, "", "len(a) == 0")
eval(t, src, big.NewInt(1))
})
t.Run("NonEmpty", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, "a = []int{1}", "len(a) == 0")
eval(t, src, big.NewInt(2))
})
}

func TestJumps(t *testing.T) {
src := `
package foo
Expand Down
46 changes: 46 additions & 0 deletions pkg/compiler/struct_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package compiler_test

import (
"fmt"
"math/big"
"testing"

Expand Down Expand Up @@ -338,8 +339,53 @@ var structTestCases = []testCase{
}`,
big.NewInt(2),
},
{
"uninitialized struct fields",
`package foo
type Foo struct {
i int
m map[string]int
b []byte
a []int
s struct { ii int }
}
func NewFoo() Foo { return Foo{} }
func Main() int {
foo := NewFoo()
if foo.i != 0 { return 1 }
if len(foo.m) != 0 { return 1 }
if len(foo.b) != 0 { return 1 }
if len(foo.a) != 0 { return 1 }
s := foo.s
if s.ii != 0 { return 1 }
return 2
}`,
big.NewInt(2),
},
}

func TestStructs(t *testing.T) {
runTestCases(t, structTestCases)
}

func TestStructCompare(t *testing.T) {
srcTmpl := `package testcase
type T struct { f int }
func Main() int {
a := T{f: %d}
b := T{f: %d}
if a != b {
return 2
}
return 1
}`
t.Run("Equal", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, 4, 4)
eval(t, src, big.NewInt(1))
})
t.Run("NotEqual", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, 4, 5)
eval(t, src, big.NewInt(2))
})

}