diff --git a/ast/ast.go b/ast/ast.go index ce9bc28..6d40680 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -110,3 +110,22 @@ type IntegerLiteral struct { func (il *IntegerLiteral) expressionNode() {} func (il *IntegerLiteral) TokenLiteral() string { return il.Token.Literal } func (il *IntegerLiteral) String() string { return il.Token.Literal } + +type PrefixExpression struct { + Token token.Token // The prefix token, e.g. ! + Operator string + Right Expression +} + +func (pe *PrefixExpression) expressionNode() {} +func (pe *PrefixExpression) TokenLiteral() string { return pe.Token.Literal } +func (pe *PrefixExpression) String() string { + var out bytes.Buffer + + out.WriteString("(") + out.WriteString(pe.Operator) + out.WriteString(pe.Right.String()) + out.WriteString(")") + + return out.String() +} diff --git a/parser/parser.go b/parser/parser.go index be358c2..291825f 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -49,6 +49,8 @@ func New(l *lexer.Lexer) *Parser { p.prefixParseFns = make(map[token.TokenType]prefixParseFn) p.registerPrefix(token.IDENT, p.parseIdentifier) p.registerPrefix(token.INT, p.parseIntegerLiteral) + p.registerPrefix(token.BANG, p.parsePrefixExpression) + p.registerPrefix(token.MINUS, p.parsePrefixExpression) return p } @@ -166,9 +168,15 @@ func (p *Parser) parseExpressionStatement() *ast.ExpressionStatement { return stmt } +func (p *Parser) noPrefixParseFnError(t token.TokenType) { + msg := fmt.Sprintf("no prefix parse function for %s found", t) + p.errors = append(p.errors, msg) +} + func (p *Parser) parseExpression(precedence int) ast.Expression { prefix := p.prefixParseFns[p.curToken.Type] if prefix == nil { + p.noPrefixParseFnError(p.curToken.Type) return nil } leftExp := prefix() @@ -197,3 +205,16 @@ func (p *Parser) parseIntegerLiteral() ast.Expression { return lit } + +func (p *Parser) parsePrefixExpression() ast.Expression { + expression := &ast.PrefixExpression{ + Token: p.curToken, + Operator: p.curToken.Literal, + } + + p.nextToken() + + expression.Right = p.parseExpression(PREFIX) + + return expression +} diff --git a/parser/parser_test.go b/parser/parser_test.go index ef3d0a8..d7584a8 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -1,6 +1,7 @@ package parser import ( + "fmt" "monkey_interpreter/ast" "monkey_interpreter/lexer" "testing" @@ -162,3 +163,65 @@ func TestIntegerLiteralExpression(t *testing.T) { literal.TokenLiteral()) } } + +func TestParsingPrefixExpressions(t *testing.T) { + prefixTests := []struct { + input string + operator string + integerValue int64 + }{ + {"!5;", "!", 5}, + {"-15;", "-", 15}, + } + + for _, tt := range prefixTests { + l := lexer.New(tt.input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("program.Statements does not contain %d statements. got=%d\n", + 1, len(program.Statements)) + } + + stmt, ok := program.Statements[0].(*ast.ExpressionStatement) + if !ok { + t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T", + program.Statements[0]) + } + + exp, ok := stmt.Expression.(*ast.PrefixExpression) + if !ok { + t.Fatalf("stmt is not ast.PrefixExpression. got=%T", stmt.Expression) + } + if exp.Operator != tt.operator { + t.Fatalf("exp.Operator is not '%s'. got=%s", + tt.operator, exp.Operator) + } + if !testIntegerLiteral(t, exp.Right, tt.integerValue) { + return + } + } +} + +func testIntegerLiteral(t *testing.T, il ast.Expression, value int64) bool { + integ, ok := il.(*ast.IntegerLiteral) + if !ok { + t.Errorf("il not *ast.IntegerLiteral. got=%T", il) + return false + } + + if integ.Value != value { + t.Errorf("integ.Value not %d. got=%d", value, integ.Value) + return false + } + + if integ.TokenLiteral() != fmt.Sprintf("%d", value) { + t.Errorf("integ.TokenLiteral not %d. got=%s", value, + integ.TokenLiteral()) + return false + } + + return true +}