Skip to content

Commit

Permalink
fixing issue with parsing multiple returns
Browse files Browse the repository at this point in the history
  • Loading branch information
dfirebaugh committed Jan 30, 2024
1 parent e5537f7 commit 4990e83
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 54 deletions.
11 changes: 0 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,6 @@ The goal is to build a simple language that compiles directly to WebAssembly.

Also, I'm just kind of toying around with making a language.

## Goals
* Target WebAssembly Text format (WAT) as an intermediate representation.
* Strict types.
* Enums.
* User-defined types.
* Built-in tools for testing.
* Easy-to-use build tools.
* Module/package management.
* Runtimes for different use cases.
* Functions are private by default and can easily be exported with the `pub` keyword.

The ideal syntax will look similar to below.
```rust
// addTwo is an exported function that adds two ints together and returns the result.
Expand Down
16 changes: 6 additions & 10 deletions internal/parser/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,10 @@ func (p *Parser) parseFunctionStatement() *ast.FunctionStatement {
return nil
}

body := p.parseBlockStatement()

if p.curTokenIs(token.RBRACE) {
if p.curTokenIs(token.RPAREN) {
p.nextToken()
}
body := p.parseBlockStatement()

stmt := &ast.FunctionStatement{
IsExported: isExported,
Expand Down Expand Up @@ -132,6 +131,7 @@ func (p *Parser) parseFunctionCall(function ast.Expression) ast.Expression {
Token: p.curToken,
Function: function,
}
p.nextToken()
exp.Arguments = p.parseFunctionCallArguments()
p.nextToken()
return exp
Expand All @@ -140,15 +140,15 @@ func (p *Parser) parseFunctionCall(function ast.Expression) ast.Expression {
func (p *Parser) parseFunctionCallArguments() []ast.Expression {
args := []ast.Expression{}

if p.curTokenIs(token.LPAREN) {
p.nextToken()
}
if p.peekTokenIs(token.RPAREN) {
p.nextToken()
return args
}

p.nextToken()
p.nextToken()
args = append(args, p.parseExpression(LOWEST))

for p.peekTokenIs(token.COMMA) {
p.nextToken()
p.nextToken()
Expand All @@ -167,7 +167,6 @@ func (p *Parser) parseReturnStatement() *ast.ReturnStatement {
}

for !p.curTokenIs(token.RPAREN) && !p.curTokenIs(token.EOF) {
println(p.curToken.Literal)
expr := p.parseExpression(LOWEST)
if expr == nil {
p.errors = append(p.errors, "expected expression in return statement")
Expand All @@ -179,9 +178,6 @@ func (p *Parser) parseReturnStatement() *ast.ReturnStatement {
if p.peekTokenIs(token.COMMA) {
p.nextToken()
}
if p.peekTokenIs(token.RBRACE) {
break
}
p.nextToken()
}

Expand Down
2 changes: 1 addition & 1 deletion internal/parser/function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
)

func TestParseMultipleReturnTypes(t *testing.T) {
input := `pub (i8, string) myFunction(i32 x, bool y) { return (5, "hello") }`
input := `pub (i8, string) myFunction(i32 x, bool y) { return 5, "hello" }`
l := lexer.New(input)
p := New(l)

Expand Down
15 changes: 3 additions & 12 deletions internal/parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,12 @@ func TestParseGroupedExpression(t *testing.T) {
}

func TestParseIfStatement(t *testing.T) {
input := "if (x < y) { x }"
input := "if (x < y) { return x - y }"
lexer := lexer.New(input)
parser := New(lexer)

program := parser.ParseProgram()
checkParserErrors(t, parser)

if len(program.Statements) != 1 {
t.Fatalf("program has wrong number of statements. got=%d", len(program.Statements))
}
Expand All @@ -269,17 +268,9 @@ func TestParseIfStatement(t *testing.T) {
return
}

cons, ok := stmt.Consequence.Statements[0].(*ast.BlockStatement)
_, ok = stmt.Consequence.Statements[0].(*ast.ReturnStatement)
if !ok {
t.Fatalf("Statements[0] is not ast.ExpressionStatement. got=%T", stmt.Consequence.Statements[0])
}

if cons.Statements[0].(*ast.ExpressionStatement).Expression.(*ast.Identifier).Token.Literal != "x" {
t.Fatalf("cons.Statements[0] is not ast.ExpressionStatement. got=%T", cons.Statements[0])
}

if stmt.Alternative != nil {
t.Errorf("exp.Alternative was not nil. got=%+v", stmt.Alternative)
t.Fatalf("Statements[0] is not *ast.Return. got=%T", stmt.Consequence.Statements[0])
}
}

Expand Down
13 changes: 7 additions & 6 deletions internal/parser/parsers.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,23 +207,21 @@ func (p *Parser) parseIfStatement() *ast.IfStatement {
return nil
}

if p.curTokenIs(token.RPAREN) {
p.nextToken()
}
expression.Consequence = p.parseBlockStatement()

if p.peekTokenIs(token.ELSE) {
if p.curTokenIs(token.RBRACE) {
p.nextToken()
}
p.nextToken()

if !p.expectPeek(token.LBRACE) {
return nil
}

expression.Alternative = p.parseBlockStatement()
if p.curTokenIs(token.RBRACE) {
p.nextToken()
}
}

return expression
}

Expand All @@ -240,6 +238,9 @@ func (p *Parser) parseBlockStatement() *ast.BlockStatement {
}
switch stmt.(type) {
case *ast.ReturnStatement:
if p.curTokenIs(token.RPAREN) {
p.nextToken()
}
return block
}
p.nextToken()
Expand Down
25 changes: 11 additions & 14 deletions internal/wat/wat.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,14 @@ func generateStringLiteral(str *ast.StringLiteral) string {

func generateFunctionCall(call *ast.FunctionCall) string {
var out strings.Builder
out.WriteString(fmt.Sprintf("(call $%s ", call.FunctionName))
out.WriteString(fmt.Sprintf("\n\t\t(call $%s", call.FunctionName))
for _, arg := range call.Arguments {
out.WriteString(generateExpression(arg))
out.WriteString(" ")
out.WriteString(generateExpression(arg))
}
out.WriteString(")")
return out.String()
}

func generateInfixExpression(infix *ast.InfixExpression) string {
left := generateExpression(infix.Left)
right := generateExpression(infix.Right)
Expand All @@ -162,7 +161,6 @@ func generateInfixExpression(infix *ast.InfixExpression) string {
case token.NOT_EQ:
return fmt.Sprintf("(i32.ne %s %s)", left, right)
default:
println("\t\t operator:", operator)
return fmt.Sprintf(";; unhandled operator: %s\n", operator)
}
}
Expand Down Expand Up @@ -196,16 +194,16 @@ func generateIfStatement(e *ast.IfStatement) string {
}
var out strings.Builder
out.WriteString("\t\t(if ")
out.WriteString(generateExpression(e.Condition)) // Generates the condition expression
out.WriteString(" (then ")
out.WriteString(generateBlockStatement(e.Consequence)) // Generates the consequent block
out.WriteString(" )")
out.WriteString(generateExpression(e.Condition))
out.WriteString("\n\t\t\t(then\n")
out.WriteString(generateBlockStatement(e.Consequence))
out.WriteString("\n\t\t\t)")
if e.Alternative != nil {
out.WriteString(" \t\t(else ")
out.WriteString(generateBlockStatement(e.Alternative)) // Generates the alternative block, if it exists
out.WriteString(" )")
out.WriteString("\n\t\t\t(else\n")
out.WriteString(generateBlockStatement(e.Alternative))
out.WriteString("\n\t\t\t)")
}
out.WriteString(")")
out.WriteString("\n\t\t)")
return out.String()
}

Expand All @@ -223,7 +221,7 @@ func generateFunctionStatement(s *ast.FunctionStatement) string {
}

for _, param := range s.Parameters {
out.WriteString(fmt.Sprintf("(param $%s i32) ", param.Identifier.Value)) // Assuming i32 for simplicity.
out.WriteString(fmt.Sprintf("(param $%s i32) ", param.Identifier.Value))
}
out.WriteString(fmt.Sprintf("(result %s)\n", returnType))
out.WriteString(generateBlockStatement(s.Body))
Expand All @@ -235,7 +233,6 @@ func generateFunctionStatement(s *ast.FunctionStatement) string {
func generateReturnStatement(s *ast.ReturnStatement) string {
if s == nil {
log.Println("Encountered nil *ast.ReturnStatement")
println("returned early from generateReturnStatement ")
return ""
}

Expand Down
35 changes: 35 additions & 0 deletions internal/wat/wat_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package wat_test

import (
"strings"
"testing"

"github.com/stretchr/testify/assert"

"github.com/dfirebaugh/punch/internal/ast"
"github.com/dfirebaugh/punch/internal/lexer"
"github.com/dfirebaugh/punch/internal/parser"
"github.com/dfirebaugh/punch/internal/token"
"github.com/dfirebaugh/punch/internal/wat"
)
Expand Down Expand Up @@ -113,3 +116,35 @@ func _TestGenerateWAT(t *testing.T) {
assert.Equal(t, expected, wat.GenerateWAT(ast, false))
})
}

func TestFunctionDeclaration(t *testing.T) {
input := `
bool is_eq(i32 a, i32 b) {
return (a == b)
}
pub i32 add_two(i32 x, i32 y) {
return x + y
}
pub i32 add_four(i32 a, i32 b, i32 c, i32 d) {
if !is_eq(a, c) {
return (a - b - c - d)
}
return a + b + c + d
}
`

l := lexer.New(input)
p := parser.New(l)
program := p.ParseProgram()

expected := "(module\n\t(func $is_eq (param $a i32) (param $b i32) (result i32)\n\t\t\t(return (i32.eq (local.get $a) (local.get $b)))\n\n)\n\t(func $add_two (export \"add_two\") (param $x i32) (param $y i32) (result i32)\n\t\t\t(return (i32.add (local.get $x) (local.get $y)))\n\n)\n\t(func $add_four (export \"add_four\") (param $a i32) (param $b i32) (param $c i32) (param $d i32) (result i32)\n\t\t\t(if (i32.eqz \n\t\t(call $is_eq (local.get $a) (local.get $c)))\n\t\t\t(then\n\t\t\t(return (i32.sub (i32.sub (i32.sub (local.get $a) (local.get $b)) (local.get $c)) (local.get $d)))\n\n\n\t\t\t)\n\t\t)\n\t\t\t(return (i32.add (i32.add (i32.add (local.get $a) (local.get $b)) (local.get $c)) (local.get $d)))\n\n)\n)"
w := wat.GenerateWAT(program, false)
t.Run("generate WAT code for a program with one function declaration", func(t *testing.T) {
assert.Equal(
t,
strings.Trim(strings.Trim(expected, "\n"), "\t"),
strings.Trim(strings.Trim(w, "\n"), "\t"))
})
}

0 comments on commit 4990e83

Please sign in to comment.