Skip to content

Commit

Permalink
Implemented functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
iamsayantan committed Apr 16, 2022
1 parent 8ee64ea commit 22648e4
Show file tree
Hide file tree
Showing 7 changed files with 266 additions and 6 deletions.
11 changes: 11 additions & 0 deletions expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ type Visitor interface {
VisitAssignExpr(expr *Assign) (interface{}, error)
VisitLogicalExpr(expr *Logical) (interface{}, error)
VisitBinaryExpr(expr *Binary) (interface{}, error)
VisitCallExpr(expr *Call) (interface{}, error)
VisitGroupingExpr(expr *Grouping) (interface{}, error)
VisitLiteralExpr(expr *Literal) (interface{}, error)
VisitUnaryExpr(expr *Unary) (interface{}, error)
Expand Down Expand Up @@ -43,6 +44,16 @@ func (b *Binary) Accept(visitor Visitor) (interface{}, error) {
return visitor.VisitBinaryExpr(b)
}

type Call struct {
Callee Expr
Paren Token
Arguments []Expr
}

func (c *Call) Accept(visitor Visitor) (interface{}, error) {
return visitor.VisitCallExpr(c)
}

type Grouping struct {
Expression Expr
}
Expand Down
50 changes: 49 additions & 1 deletion interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@ import (

type Interpreter struct {
runtime *Runtime
globals *Environment
environment *Environment
}

func NewInterpreter(runtime *Runtime) *Interpreter {
return &Interpreter{runtime: runtime, environment: NewEnvironment(nil)}
global := NewEnvironment(nil)
global.Define("clock", Clock{})
return &Interpreter{runtime: runtime, environment: global, globals: global}
}

type RuntimeError struct {
Expand Down Expand Up @@ -296,6 +299,51 @@ func (i *Interpreter) VisitBinaryExpr(expr *Binary) (interface{}, error) {
return nil, nil
}

// VisitCallExpr interprts function call tree node. First we evaluate the expression for the
// callee, typically this expression is just an identifier that looks up the function by its
// name, but it could be anything. Then we evaluate each of the arguments in order and store
// them in a list. To call a function we cast the callee to the LoxCallable interface and call
// the Call() method on it. The go representation of any lox object that can be called like an
// function will implement this interface.
func (i *Interpreter) VisitCallExpr(expr *Call) (interface{}, error) {
callee, err := i.evaluate(expr.Callee)
if err != nil {
return nil, err
}

arguments := make([]interface{}, 0)
for _, argument := range expr.Arguments {
ag, err := i.evaluate(argument)
if err != nil {
return nil, err
}

arguments = append(arguments, ag)
}

function, ok := callee.(LoxCallable)
if !ok {
return nil, NewRuntimeError(expr.Paren, "Can only call function and classes")
}

if len(arguments) != function.Arity() {
return nil, NewRuntimeError(expr.Paren, fmt.Sprintf("Expected %d arguments but got %d", function.Arity(), len(arguments)))
}

return function.Call(i, arguments)
}

// VisitFunctionStmt interprets a function syntax node. We take FunctionStmt syntax node, which
// is a compile time representation of the function - and convert it to its runtime representation.
// Here that's LoxFunction that wraps the syntax node. Here we also bind the resulting object to
// a new variable. So after creating LoxFunction, we create a new binding in the current environment
// and store a reference to it there.
func (i *Interpreter) VisitFunctionStmt(stmt *FunctionStmt) error {
function := NewLoxFunction(stmt)
i.environment.Define(stmt.Name.Lexeme, function)
return nil
}

// VisitGroupingExpr evaluates the grouping expressions, the node that we get from
// using parenthesis around an expression. The grouping node has reference to the
// inner expression, so to evaluate it we recursively evaluate the inner subexpression.
Expand Down
15 changes: 15 additions & 0 deletions lox_callable.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package glox

// LoxCallable interface should be implemented by any lox object that can be called like
// a function.
type LoxCallable interface {
// Call is the method that is called to evaluate the function. We pass in the
// interpreter in case the implementing object needs it and the list of arguments.
// The implementing object should return the evaluated value as return parameter.
Call(interpreter *Interpreter, arguments []interface{}) (interface{}, error)

// Arity is the number of arguments a function expects. It's used to check if the
// number of arguments passed to the function matches the number of arguments the
// function expects.
Arity() int
}
36 changes: 36 additions & 0 deletions lox_function.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package glox

// LoxFunction is the representation of the lox function in terms of the interpreter.
// This struct also implements the LoxCallable interface so the runtime can call this
// function.
type LoxFunction struct {
declaration *FunctionStmt
}

func NewLoxFunction(declaration *FunctionStmt) LoxCallable {
return LoxFunction{declaration: declaration}
}

// Call will execute the function body with the arguments passed to it. The parameters are
// core to a function, a function encapsulates its parameters. No other code outside the
// function should see them. This means each function gets its own environment. And this
// environment is generated at runtime during the function call. Then it walks the parameters
// and argument lists and for each pair it creates a new variable with the parameter's name
// and binds it to the argument's value.
func (lf LoxFunction) Call(interpreter *Interpreter, arguments []interface{}) (interface{}, error) {
env := NewEnvironment(interpreter.globals)
for i, param := range lf.declaration.Params {
env.Define(param.Lexeme, arguments[i])
}

return nil, interpreter.executeBlock(lf.declaration.Body, env)
}


func (lf LoxFunction) Arity() int {
return len(lf.declaration.Params)
}

func (lf LoxFunction) String() string {
return "<fn " + lf.declaration.Name.Lexeme + ">"
}
17 changes: 17 additions & 0 deletions native_fn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package glox

import "time"

type Clock struct{}

func (c Clock) Call(interpreter *Interpreter, arguments []interface{}) (interface{}, error) {
return float64(time.Now().Unix()), nil
}

func (c Clock) Arity() int {
return 0
}

func (c Clock) String() string {
return "<native fn>"
}
125 changes: 123 additions & 2 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,18 @@ func (p *Parser) Parse() []Stmt {
// while parsing, the parser tries to recover using synchronize and continue parsing the next
// statements.
// declaration --> varDecl
// | funcDeclaration
// | statement
func (p *Parser) declaration() (Stmt, error) {
if p.match(Fun) {
stmt, err := p.function("function")
if err != nil {
return nil, err
}

return stmt, nil
}

if p.match(Var) {
stmt, err := p.varDeclaration()
if err != nil {
Expand All @@ -64,6 +74,61 @@ func (p *Parser) declaration() (Stmt, error) {
return p.statement()
}

// function parses grammar for function declaration. Since we already matched and consumed
// fun keyword, we don't need to do that again. Next we parse the list of parameters with
// the parentheses wrapped around them. The outer if condition handles the zero parameter
// case and the inner for loop parses parameters as long as we find commas to separate them.
// We consume the { at the beginning of the body before calling block, as block() assumes
// brace token has already been consumed. And this way we cal provide a more precise error
// message if the brace is not provided.
func (p *Parser) function(kind string) (Stmt, error) {
name, err := p.consume(Identifiers, "Expect " + kind + " name")
if err != nil {
return nil, err
}

_, err = p.consume(LeftParen, "Expect '(' after " + kind + " name")
if err != nil {
return nil, err
}

parameters := make([]Token, 0)
if !p.check(RightParen) {
for {
if len(parameters) > 255 {
p.error(p.peek(), "Can't have more than 255 parameters")
}

param, err := p.consume(Identifiers, "Expect parameter name")
if err != nil {
return nil, err
}

parameters = append(parameters, param)
if !p.match(Comma) {
break
}
}
}

_, err = p.consume(RightParen, "Expect ')' after parameters")
if err != nil {
return nil, err
}

_, err = p.consume(LeftBrace, "Expect '{' before " + kind + " body")
if err != nil {
return nil, err
}

body, err := p.block()
if err != nil {
return nil, err
}

return &FunctionStmt{Name: name, Body: body, Params: parameters}, nil
}

// varDeclaration parses variable declaration syntax. When the parser matches a var
// keyword, this method is used to parse that statement.
// varDecl → "var" IDENTIFIER ( "=" expression )? ";" ;
Expand Down Expand Up @@ -507,7 +572,7 @@ func (p *Parser) factor() (Expr, error) {

// unary parses an unary expression and primary expression.
// unary --> ( "!" | "-" ) unary
// | primary
// | call
func (p *Parser) unary() (Expr, error) {
if p.match(Bang, Minus) {
operator := p.previous()
Expand All @@ -519,7 +584,63 @@ func (p *Parser) unary() (Expr, error) {
return &Unary{Operator: operator, Right: right}, nil
}

return p.primary()
return p.call()
}

// call parses a function call grammar. This rule matches a primary expression followed by
// zero or more function calls. If there is no parenthesis this matches a bare primary expression.
// The * in the grammar allows calls like fn(1)(2)(3) function calls.
// call --> primary ( "(" arguments? ")")*;
func (p *Parser) call() (Expr, error) {
expr, err := p.primary()
if err != nil {
return nil, err
}

for {
if p.match(LeftParen) {
expr, err = p.finishCall(expr)
if err != nil {
return nil, err
}
} else {
break
}
}

return expr, nil
}

// finishCall is a helper that parses the function arguments. This is more or less
// the grammar for arguments. Except we also check the zero argument condition. If
// we find the ')' as the next token, we don't parse any expression.
// arguments --> expression ( "," expression )*;
func (p *Parser) finishCall(callee Expr) (Expr, error) {
arguments := make([]Expr, 0)
if !p.check(RightParen) {
for {
expr, err := p.expression()
if err != nil {
return nil, err
}

if len(arguments) >= 255 {
p.error(p.peek(), "Can't have more than 255 arguments.")
}

arguments = append(arguments, expr)
if !p.match(Comma) {
break
}
}
}

paren, err := p.consume(RightParen, "Expect ')' after arguments")
if err != nil {
return nil, err
}

return &Call{Callee: callee, Paren: paren, Arguments: arguments}, nil
}

// primary parses the primary expressions, these are of highest level of precedence.
Expand Down
18 changes: 15 additions & 3 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type StmtVisitor interface {
VisitVarStmt(expr *VarStmt) error
VisitIfStmt(stmt *IfStmt) error
VisitWhileStmt(stmt *WhileStmt) error
VisitFunctionStmt(stmt *FunctionStmt) error
}

type Block struct {
Expand All @@ -30,6 +31,20 @@ type Expression struct {
Expression Expr
}

func (e *Expression) Accept(visitor StmtVisitor) error {
return visitor.VisitExpressionExpr(e)
}

type FunctionStmt struct {
Name Token
Params []Token
Body []Stmt
}

func (f *FunctionStmt) Accept(visitor StmtVisitor) error {
return visitor.VisitFunctionStmt(f)
}

type IfStmt struct {
Condition Expr
ThenBranch Stmt
Expand All @@ -40,9 +55,6 @@ func (i *IfStmt) Accept(visitor StmtVisitor) error {
return visitor.VisitIfStmt(i)
}

func (e *Expression) Accept(visitor StmtVisitor) error {
return visitor.VisitExpressionExpr(e)
}

type Print struct {
Expression Expr
Expand Down

0 comments on commit 22648e4

Please sign in to comment.