Skip to content

Commit

Permalink
Give function implementations access to return type
Browse files Browse the repository at this point in the history
Now that we support computed return types for functions, it's convenient
to retain the computed return type and pass it in to the function's
implementation so that it doesn't need to repeat the logic to determine
the result type when working with collections generically.

To support this we add a new CallbackTyped attribute to Function which
can be used instead of Callback when knowing the return type is desired.
A new AST node CallTyped is used to augment a Call node with its
computed result type from the type check phase, which can then be used
in the eval phase.
  • Loading branch information
apparentlymart committed Feb 5, 2017
1 parent 296e9e1 commit 7cd522a
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 30 deletions.
33 changes: 33 additions & 0 deletions ast/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import (
)

// Call represents a function call.
//
// The type checker replaces Call nodes with CallTyped nodes in order to retain
// the type information for use in the evaluation phase.
type Call struct {
Func string
Args []Node
Expand Down Expand Up @@ -45,3 +48,33 @@ func (n *Call) Type(s Scope) (Type, error) {
func (n *Call) GoString() string {
return fmt.Sprintf("*%#v", *n)
}

// CallTyped represents a function call *after* type checking.
//
// The type check phase replaces any Call node with a CallTyped node in order to
// capture the type information that was determined so that it can be used during
// a subsequent evaluation.
type CallTyped struct {
// CallTyped embeds the Call it was created from.
Call

// ReturnType is the return type determined for the function during type checking.
// A well-behaved function implementation is bound by the interface contract to return
// a value that conforms to this type.
ReturnType Type
}

func (n *CallTyped) Accept(v Visitor) Node {
// Accept must be re-implemented on CallTyped to make sure we pass the full CallTyped
// value, rather than the embedded Call value that would result were we to inherit
// the implementation from Call.
for i, a := range n.Args {
n.Args[i] = a.Accept(v)
}

return v(n)
}

func (n *CallTyped) Type(s Scope) (Type, error) {
return n.ReturnType, nil
}
18 changes: 15 additions & 3 deletions ast/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,23 @@ type Function struct {
Variadic bool
VariadicType Type

// Callback is the function called for a function. The argument
// types are guaranteed to match the spec above by the type checker.
// Either Callback or CallbackTyped are called as the implementation of
// the function. Both recieve a slice interface values of an appropriate
// dynamic type for the call arguments, while CallbackTyped additionally
// recieves the required result type, for easier implementation of
// type-generic functions without duplicating the logic in ReturnTypeFunc.
//
// The argument types are guaranteed by the type checker to match what is
// described by ArgTypes, ReturnTypeFunc and VariadicType.
// The length of the args is strictly == len(ArgTypes) unless Varidiac
// is true, in which case its >= len(ArgTypes).
Callback func([]interface{}) (interface{}, error)
//
// The value returned MUST confirm to the function's return type, whether
// determined by ReturnType or ReturnTypeFunc.
//
// Setting both Callback and CallbackTyped is invalid usage.
Callback func([]interface{}) (interface{}, error)
CallbackTyped func(args []interface{}, returnType Type) (interface{}, error)
}

// ReturnTypeFunc is a function type used to decide the return type of a
Expand Down
65 changes: 44 additions & 21 deletions check_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ func (v *TypeCheck) visit(raw ast.Node) ast.Node {
case *ast.Call:
tc := &typeCheckCall{n}
result, err = tc.TypeCheck(v)
case *ast.CallTyped:
// We only enter this branch if the result of type checking is
// passed into a second pass of type checking. In that case
// we just re-check the original Call embedded inside.
tc := &typeCheckCall{&n.Call}
result, err = tc.TypeCheck(v)
case *ast.Conditional:
tc := &typeCheckConditional{n}
result, err = tc.TypeCheck(v)
Expand Down Expand Up @@ -178,10 +184,13 @@ func (tc *typeCheckArithmetic) checkNumeric(v *TypeCheck, exprs []ast.Type) (ast
Posx: tc.n.Pos(),
}
copy(args[1:], tc.n.Exprs)
return &ast.Call{
Func: mathFunc,
Args: args,
Posx: tc.n.Pos(),
return &ast.CallTyped{
Call: ast.Call{
Func: mathFunc,
Args: args,
Posx: tc.n.Pos(),
},
ReturnType: mathType,
}, nil
}

Expand Down Expand Up @@ -271,10 +280,13 @@ func (tc *typeCheckArithmetic) checkComparison(v *TypeCheck, exprs []ast.Type) (
Posx: tc.n.Pos(),
}
copy(args[1:], tc.n.Exprs)
return &ast.Call{
Func: compareFunc,
Args: args,
Posx: tc.n.Pos(),
return &ast.CallTyped{
Call: ast.Call{
Func: compareFunc,
Args: args,
Posx: tc.n.Pos(),
},
ReturnType: ast.TypeBool,
}, nil
}

Expand Down Expand Up @@ -303,10 +315,13 @@ func (tc *typeCheckArithmetic) checkLogical(v *TypeCheck, exprs []ast.Type) (ast
Posx: tc.n.Pos(),
}
copy(args[1:], tc.n.Exprs)
return &ast.Call{
Func: "__builtin_Logical",
Args: args,
Posx: tc.n.Pos(),
return &ast.CallTyped{
Call: ast.Call{
Func: "__builtin_Logical",
Args: args,
Posx: tc.n.Pos(),
},
ReturnType: ast.TypeBool,
}, nil
}

Expand Down Expand Up @@ -338,8 +353,8 @@ func (tc *typeCheckCall) TypeCheck(v *TypeCheck) (ast.Node, error) {

// If we're variadic, then verify the types there
if function.Variadic {
args = args[len(function.ArgTypes):]
for i, t := range args {
varArgs := args[len(function.ArgTypes):]
for i, t := range varArgs {
realI := i + len(function.ArgTypes)
cn, err := tc.compatibleArg(v, tc.n.Func, realI+1, tc.n.Args[realI], function.VariadicType, t)
if err != nil {
Expand All @@ -350,17 +365,22 @@ func (tc *typeCheckCall) TypeCheck(v *TypeCheck) (ast.Node, error) {
}

// Return type
var returnType ast.Type
if function.ReturnTypeFunc != nil {
rt, err := function.ReturnTypeFunc(args)
if err != nil {
return nil, err
}
v.StackPush(rt)
returnType = rt
} else {
v.StackPush(function.ReturnType)
returnType = function.ReturnType
}
v.StackPush(returnType)

return tc.n, nil
return &ast.CallTyped{
Call: *tc.n,
ReturnType: returnType,
}, nil
}

// compatibleTypes implements the type matching and conversion rules for
Expand Down Expand Up @@ -618,10 +638,13 @@ func (v *TypeCheck) ImplicitConversion(
return nil
}

return &ast.Call{
Func: toFunc,
Args: []ast.Node{n},
Posx: n.Pos(),
return &ast.CallTyped{
Call: ast.Call{
Func: toFunc,
Args: []ast.Node{n},
Posx: n.Pos(),
},
ReturnType: expected,
}
}

Expand Down
2 changes: 1 addition & 1 deletion check_types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ func TestTypeCheck_implicit(t *testing.T) {
visitor = &TypeCheck{Scope: tc.Scope}
err = visitor.Visit(node)
if err != nil {
t.Fatalf("Error: %s\n\nInput: %s", err, tc.Input)
t.Fatalf("Error on second pass: %s\n\nInput: %s", err, tc.Input)
}
})
}
Expand Down
19 changes: 14 additions & 5 deletions eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,9 @@ func evalNode(raw ast.Node) (EvalNode, error) {
switch n := raw.(type) {
case *ast.Index:
return &evalIndex{n}, nil
case *ast.Call:
case *ast.CallTyped:
// Type checker replaces ast.Call with ast.CallTyped to allow
// us to access the computed function return type.
return &evalCall{n}, nil
case *ast.Conditional:
return &evalConditional{n}, nil
Expand All @@ -253,7 +255,7 @@ func evalNode(raw ast.Node) (EvalNode, error) {
}
}

type evalCall struct{ *ast.Call }
type evalCall struct{ *ast.CallTyped }

func (v *evalCall) Eval(s ast.Scope, stack *ast.Stack) (interface{}, ast.Type, error) {
// Look up the function in the map
Expand All @@ -265,18 +267,25 @@ func (v *evalCall) Eval(s ast.Scope, stack *ast.Stack) (interface{}, ast.Type, e

// The arguments are on the stack in reverse order, so pop them off.
args := make([]interface{}, len(v.Args))
for i, _ := range v.Args {
for i := range v.Args {
node := stack.Pop().(*ast.LiteralNode)
args[len(v.Args)-1-i] = node.Value
}

// Call the function
result, err := function.Callback(args)
var result interface{}
var err error
if function.CallbackTyped != nil {
result, err = function.CallbackTyped(args, v.ReturnType)
} else {
result, err = function.Callback(args)
}

if err != nil {
return nil, ast.TypeInvalid, fmt.Errorf("%s: %s", v.Func, err)
}

return result, function.ReturnType, nil
return result, v.ReturnType, nil
}

type evalConditional struct{ *ast.Conditional }
Expand Down

0 comments on commit 7cd522a

Please sign in to comment.