Skip to content

Refactor deref #398

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions ast/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ type NilNode struct {
type IdentifierNode struct {
base
Value string
Deref bool
FieldIndex []int
Method bool // true if method, false if field
MethodIndex int // index of method, set only if Method is true
Expand Down Expand Up @@ -106,10 +105,9 @@ type MemberNode struct {
Property Node
Name string // Name of the filed or method. Used for error reporting.
Optional bool
Deref bool
FieldIndex []int

// TODO: Replace with a single MethodIndex field of &int type.
// TODO: Combine Method and MethodIndex into a single MethodIndex field of &int type.
Method bool
MethodIndex int
}
Expand Down
26 changes: 10 additions & 16 deletions checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,18 +135,14 @@ func (v *visitor) IdentifierNode(node *ast.IdentifierNode) (reflect.Type, info)
// when the arguments are known in CallNode.
return anyType, info{fn: fn}
}
if v.config.Types == nil {
node.Deref = true
} else if t, ok := v.config.Types[node.Value]; ok {
if t, ok := v.config.Types[node.Value]; ok {
if t.Ambiguous {
return v.error(node, "ambiguous identifier %v", node.Value)
}
d, c := deref(t.Type)
node.Deref = c
node.Method = t.Method
node.MethodIndex = t.MethodIndex
node.FieldIndex = t.FieldIndex
return d, info{method: t.Method}
return t.Type, info{method: t.Method}
}
if v.config.Strict {
return v.error(node, "unknown name %v", node.Value)
Expand Down Expand Up @@ -180,6 +176,8 @@ func (v *visitor) ConstantNode(node *ast.ConstantNode) (reflect.Type, info) {
func (v *visitor) UnaryNode(node *ast.UnaryNode) (reflect.Type, info) {
t, _ := v.visit(node.Node)

t = deref(t)

switch node.Operator {

case "!", "not":
Expand Down Expand Up @@ -209,6 +207,9 @@ func (v *visitor) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) {
l, _ := v.visit(node.Left)
r, _ := v.visit(node.Right)

l = deref(l)
r = deref(r)

// check operator overloading
if fns, ok := v.config.Operators[node.Operator]; ok {
t, _, ok := conf.FindSuitableOperatorOverload(fns, v.config.Types, l, r)
Expand Down Expand Up @@ -427,34 +428,27 @@ func (v *visitor) MemberNode(node *ast.MemberNode) (reflect.Type, info) {

switch base.Kind() {
case reflect.Interface:
node.Deref = true
return anyType, info{}

case reflect.Map:
if prop != nil && !prop.AssignableTo(base.Key()) && !isAny(prop) {
return v.error(node.Property, "cannot use %v to get an element from %v", prop, base)
}
t, c := deref(base.Elem())
node.Deref = c
return t, info{}
return base.Elem(), info{}

case reflect.Array, reflect.Slice:
if !isInteger(prop) && !isAny(prop) {
return v.error(node.Property, "array elements can only be selected using an integer (got %v)", prop)
}
t, c := deref(base.Elem())
node.Deref = c
return t, info{}
return base.Elem(), info{}

case reflect.Struct:
if name, ok := node.Property.(*ast.StringNode); ok {
propertyName := name.Value
if field, ok := fetchField(base, propertyName); ok {
t, c := deref(field.Type)
node.Deref = c
node.FieldIndex = field.Index
node.Name = propertyName
return t, info{}
return field.Type, info{}
}
if len(v.parents) > 1 {
if _, ok := v.parents[len(v.parents)-2].(*ast.CallNode); ok {
Expand Down
12 changes: 5 additions & 7 deletions checker/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,25 +204,23 @@ func fetchField(t reflect.Type, name string) (reflect.StructField, bool) {
return reflect.StructField{}, false
}

func deref(t reflect.Type) (reflect.Type, bool) {
func deref(t reflect.Type) reflect.Type {
if t == nil {
return nil, false
return nil
}
if t.Kind() == reflect.Interface {
return t, true
return t
}
found := false
for t != nil && t.Kind() == reflect.Ptr {
e := t.Elem()
switch e.Kind() {
case reflect.Struct, reflect.Map, reflect.Array, reflect.Slice:
return t, false
return t
default:
found = true
t = e
}
}
return t, found
return t
}

func isIntegerOrArithmeticOperation(node ast.Node) bool {
Expand Down
72 changes: 52 additions & 20 deletions compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,6 @@ func (c *compiler) IdentifierNode(node *ast.IdentifierNode) {
} else {
c.emit(OpLoadConst, c.addConstant(node.Value))
}
if node.Deref {
c.emit(OpDeref)
} else if node.Type() == nil {
c.emit(OpDeref)
}
}

func (c *compiler) IntegerNode(node *ast.IntegerNode) {
Expand Down Expand Up @@ -289,6 +284,7 @@ func (c *compiler) ConstantNode(node *ast.ConstantNode) {

func (c *compiler) UnaryNode(node *ast.UnaryNode) {
c.compile(node.Node)
c.derefInNeeded(node.Node)

switch node.Operator {

Expand All @@ -313,7 +309,9 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) {
switch node.Operator {
case "==":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Left)

if l == r && l == reflect.Int {
c.emit(OpEqualInt)
Expand All @@ -325,114 +323,155 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) {

case "!=":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Left)
c.emit(OpEqual)
c.emit(OpNot)

case "or", "||":
c.compile(node.Left)
c.derefInNeeded(node.Left)
end := c.emit(OpJumpIfTrue, placeholder)
c.emit(OpPop)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.patchJump(end)

case "and", "&&":
c.compile(node.Left)
c.derefInNeeded(node.Left)
end := c.emit(OpJumpIfFalse, placeholder)
c.emit(OpPop)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.patchJump(end)

case "<":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpLess)

case ">":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpMore)

case "<=":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpLessOrEqual)

case ">=":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpMoreOrEqual)

case "+":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpAdd)

case "-":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpSubtract)

case "*":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpMultiply)

case "/":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpDivide)

case "%":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpModulo)

case "**", "^":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpExponent)

case "in":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpIn)

case "matches":
if node.Regexp != nil {
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.emit(OpMatchesConst, c.addConstant(node.Regexp))
} else {
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpMatches)
}

case "contains":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpContains)

case "startsWith":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpStartsWith)

case "endsWith":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpEndsWith)

case "..":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpRange)

case "??":
c.compile(node.Left)
c.derefInNeeded(node.Left)
end := c.emit(OpJumpIfNotNil, placeholder)
c.emit(OpPop)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.patchJump(end)

default:
Expand Down Expand Up @@ -461,7 +500,6 @@ func (c *compiler) MemberNode(node *ast.MemberNode) {
return
}
op := OpFetch
original := node
index := node.FieldIndex
path := []string{node.Name}
base := node.Node
Expand All @@ -470,21 +508,15 @@ func (c *compiler) MemberNode(node *ast.MemberNode) {
for !node.Optional {
ident, ok := base.(*ast.IdentifierNode)
if ok && len(ident.FieldIndex) > 0 {
if ident.Deref {
panic("IdentifierNode should not be dereferenced")
}
index = append(ident.FieldIndex, index...)
path = append([]string{ident.Value}, path...)
c.emitLocation(ident.Location(), OpLoadField, c.addConstant(
&runtime.Field{Index: index, Path: path},
))
goto deref
return
}
member, ok := base.(*ast.MemberNode)
if ok && len(member.FieldIndex) > 0 {
if member.Deref {
panic("MemberNode should not be dereferenced")
}
index = append(member.FieldIndex, index...)
path = append([]string{member.Name}, path...)
node = member
Expand All @@ -509,13 +541,6 @@ func (c *compiler) MemberNode(node *ast.MemberNode) {
&runtime.Field{Index: index, Path: path},
))
}

deref:
if original.Deref {
c.emit(OpDeref)
} else if original.Type() == nil {
c.emit(OpDeref)
}
}

func (c *compiler) SliceNode(node *ast.SliceNode) {
Expand Down Expand Up @@ -734,6 +759,13 @@ func (c *compiler) PairNode(node *ast.PairNode) {
c.compile(node.Value)
}

func (c *compiler) derefInNeeded(node ast.Node) {
switch kind(node) {
case reflect.Ptr, reflect.Interface:
c.emit(OpDeref)
}
}

func kind(node ast.Node) reflect.Kind {
t := node.Type()
if t == nil {
Expand Down
Loading