From 0850ed63127297f3cff3532df2be1afc23ccb6d2 Mon Sep 17 00:00:00 2001 From: Neeilan Selvalingam Date: Sat, 4 Aug 2018 15:55:55 -0700 Subject: [PATCH] Added support for lambda functions --- include/ast_deleter.hpp | 1 + include/ast_printer.hpp | 1 + include/expr.hpp | 22 +++++++++++++++++++++ include/interpreter.h | 1 + include/parser.h | 1 + include/resolver.hpp | 3 ++- include/token.hpp | 2 +- include/visitable_types.hpp | 1 + include/visitor.h | 1 + src/ast_deleter.cpp | 7 +++++++ src/ast_printer.cpp | 4 ++++ src/interpreter.cpp | 12 +++++++++++- src/parser.cpp | 31 ++++++++++++++++++++++++++++++ src/resolver.cpp | 10 ++++++++++ src/scanner.cpp | 38 ++++++++++++++++++++----------------- src/token.cpp | 2 +- test/functional_tests.py | 5 +++++ test/simple_lambdas.lox | 30 +++++++++++++++++++++++++++++ 18 files changed, 151 insertions(+), 21 deletions(-) create mode 100644 test/simple_lambdas.lox diff --git a/include/ast_deleter.hpp b/include/ast_deleter.hpp index adeeb26..0431ae8 100644 --- a/include/ast_deleter.hpp +++ b/include/ast_deleter.hpp @@ -29,6 +29,7 @@ class AstDeleter : public ExprVisitor, void visit(const Get* expr); void visit(const Set* expr); void visit(const This* expr); + void visit(const Lambda* expr); void visit(const BlockStmt*); void visit(const ExprStmt*); diff --git a/include/ast_printer.hpp b/include/ast_printer.hpp index 454ef20..e50987f 100644 --- a/include/ast_printer.hpp +++ b/include/ast_printer.hpp @@ -20,6 +20,7 @@ class AstPrinter : public ExprVisitor { std::string visit(const BoolLiteral* expr); std::string visit(const Variable* expr); std::string visit(const Logical* expr); + std::string visit(const Lambda* expr); private: std::string parenthesize(std::string, const Expr*); // todo: change these to variadic template diff --git a/include/expr.hpp b/include/expr.hpp index a7111c2..cadeef1 100644 --- a/include/expr.hpp +++ b/include/expr.hpp @@ -313,5 +313,27 @@ class Set : public Expr { const Expr& value; }; +class Lambda : public Expr { +public: + explicit Lambda(std::vector parameters, std::vector body) + : parameters(parameters), + body(body) {} + + virtual void accept(ExprVisitor* visitor) const { + return visitor->visit(this); + } + + virtual std::string accept(ExprVisitor* visitor) const { + return visitor->visit(this); + } + + virtual shared_ptr accept(ExprVisitor >* visitor) const { + return visitor->visit(this); + } + + const std::vector parameters; + const std::vector body; +}; + #endif diff --git a/include/interpreter.h b/include/interpreter.h index 55d854f..f85effd 100644 --- a/include/interpreter.h +++ b/include/interpreter.h @@ -38,6 +38,7 @@ class Interpreter : public ExprVisitor>, shared_ptr visit(const Get* expr); shared_ptr visit(const Set* expr); shared_ptr visit(const This* expr); + shared_ptr visit(const Lambda* expr); void visit(const BlockStmt*); void visit(const ExprStmt*); diff --git a/include/parser.h b/include/parser.h index e38b6b3..38c0883 100644 --- a/include/parser.h +++ b/include/parser.h @@ -45,6 +45,7 @@ class Parser { Expr* call(); Expr *finish_call(Expr *caller); Expr* primary(); + Expr* lambda_expr(); Stmt* declaration(); Stmt* var_declaration(); diff --git a/include/resolver.hpp b/include/resolver.hpp index 591e939..0a70316 100644 --- a/include/resolver.hpp +++ b/include/resolver.hpp @@ -12,7 +12,7 @@ #include "visitable_types.hpp" enum FunctionType { - NOT_IN_FN, METHOD, FUNCTION, INITIALIZER + NOT_IN_FN, METHOD, FUNCTION, INITIALIZER, LAMBDA_FN }; enum ClassType { @@ -46,6 +46,7 @@ class Resolver : public ExprVisitor, public StmtVisitor{ void visit(const Get* expr); void visit(const Set* expr); void visit(const This* expr); + void visit(const Lambda* expr); void visit(const BlockStmt*); void visit(const ExprStmt*); diff --git a/include/token.hpp b/include/token.hpp index 7dc9adf..c7fde44 100644 --- a/include/token.hpp +++ b/include/token.hpp @@ -18,7 +18,7 @@ enum TokenType { IDENTIFIER, STRING, NUMBER, // Keywords. - AND, CLASS, ELSE, FALSE, FUN, FOR, IF, NIL, OR, + AND, CLASS, ELSE, FALSE, FUN, LAMBDA, FOR, IF, NIL, OR, PRINT, RETURN, SUPER, THIS, TRUE, VAR, WHILE, END_OF_FILE }; diff --git a/include/visitable_types.hpp b/include/visitable_types.hpp index f93add9..f7693a3 100644 --- a/include/visitable_types.hpp +++ b/include/visitable_types.hpp @@ -15,6 +15,7 @@ class Call; class Get; class Set; class This; +class Lambda; class Stmt; class ExprStmt; diff --git a/include/visitor.h b/include/visitor.h index 7f97da3..7aac999 100644 --- a/include/visitor.h +++ b/include/visitor.h @@ -38,6 +38,7 @@ class ExprVisitor { virtual T visit(const Get*) = 0; virtual T visit(const Set*) = 0; virtual T visit(const This*) = 0; + virtual T visit(const Lambda*) = 0; }; #endif //LOXPP_VISITOR_H diff --git a/src/ast_deleter.cpp b/src/ast_deleter.cpp index 82303b5..abb9a5a 100644 --- a/src/ast_deleter.cpp +++ b/src/ast_deleter.cpp @@ -139,4 +139,11 @@ void AstDeleter::visit(const Set *expr) { void AstDeleter::visit(const This *expr) { exprs_to_delete.insert(static_cast(expr)); +} + +void AstDeleter::visit(const Lambda *expr) { + for (Stmt* stmt : expr->body) { + stmt->accept(this); + } + exprs_to_delete.insert(static_cast(expr)); } \ No newline at end of file diff --git a/src/ast_printer.cpp b/src/ast_printer.cpp index 4438430..a0a10f7 100644 --- a/src/ast_printer.cpp +++ b/src/ast_printer.cpp @@ -46,6 +46,10 @@ std::string AstPrinter::visit(const Logical *expr) { return parenthesize(expr->op.lexeme, &expr->left, &expr->right); } +std::string AstPrinter::visit(const Lambda *expr) { + return ""; +} + std::string AstPrinter::parenthesize(std::string name, const Expr* expr) { return "(" + name + " " + expr->accept(this) + ")"; } diff --git a/src/interpreter.cpp b/src/interpreter.cpp index b77a256..7e399e9 100644 --- a/src/interpreter.cpp +++ b/src/interpreter.cpp @@ -366,7 +366,7 @@ shared_ptr Interpreter::visit(const Get *expr) { shared_ptr Interpreter::visit(const Set *expr) { shared_ptr object = evaluate(expr->callee); - if (!(object->kind == InterpreterResult::ResultType::INSTANCE)) { + if (object->kind != InterpreterResult::ResultType::INSTANCE) { throw RuntimeErr(expr->name, "Only instances have fields."); } @@ -380,6 +380,16 @@ shared_ptr Interpreter::visit(const This *expr) { return lookup_variable(expr->keyword, expr); } +shared_ptr Interpreter::visit(const Lambda *expr) { + shared_ptr result = std::make_shared(); + result->function = new FuncStmt(Token(LAMBDA, "lambda", "", 0), expr->parameters, expr->body); + result->kind = InterpreterResult::ResultType::FUNCTION; + result->arity = expr->parameters.size(); + result->closure = environment; + result->callable = true; + return result; +} + bool Interpreter::is_truthy(const InterpreterResult &expr) { if (expr.kind == InterpreterResult::ResultType::NIL) { return false; diff --git a/src/parser.cpp b/src/parser.cpp index 1dbf50a..56f8b20 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -403,7 +403,38 @@ Expr* Parser::finish_call(Expr* callee) { return new Call(*callee, paren, args); } +Expr* Parser::lambda_expr() { + /* + * Lambda syntax: + * lambda (args) { return expr; } + * ex : list_to_double.map(lambda (x) { return x * 2; }) + * */ + + consume(LEFT_PAREN, "Expect '(' after lambda keyword"); + + std::vector parameters; + if (!check(RIGHT_PAREN)) { + do { + if (parameters.size() >= 8) { + error(peek(), "Cannot have more than 8 parameters."); + } + + parameters.push_back(consume(IDENTIFIER, "Expect parameter name.")); + } while (match({COMMA})); + } + + consume(RIGHT_PAREN, "Expect ')' after parameters."); + + consume(LEFT_BRACE, "Expect '{' before lambda body."); + + std::vector body; + body.push_back(block_statement()); + + return new Lambda(parameters, body); +} + Expr* Parser::primary() { + if (match({LAMBDA})) return lambda_expr(); if (match({THIS})) return new This(previous()); if (match({FALSE})) return new BoolLiteral(false); if (match({TRUE})) return new BoolLiteral(true); diff --git a/src/resolver.cpp b/src/resolver.cpp index 30543b9..80bcda2 100644 --- a/src/resolver.cpp +++ b/src/resolver.cpp @@ -132,6 +132,16 @@ void Resolver::resolve_fn(FunctionType declaration, const FuncStmt *fn) { current_function = enclosing_function; } +void Resolver::visit(const Lambda *expr) { + auto fn = new FuncStmt( + Token(LAMBDA, "lambda", "", 0), + expr->parameters, + expr->body); + + resolve_fn(FUNCTION, fn); + delete fn; +} + void Resolver::begin_scope() { scopes.push_back(new std::map()); } diff --git a/src/scanner.cpp b/src/scanner.cpp index 6b955fa..0388bee 100644 --- a/src/scanner.cpp +++ b/src/scanner.cpp @@ -7,22 +7,23 @@ #include const std::map Scanner::keywords = { - { "and", AND }, - { "class", CLASS }, - { "else", ELSE }, - { "false", FALSE }, - { "for", FOR }, - { "fun", FUN }, - { "if", IF }, - { "nil", NIL }, - { "or", OR }, - { "print", PRINT }, - { "return",RETURN }, - { "super", SUPER }, - { "this", THIS }, - { "true", TRUE }, - { "var", VAR }, - { "while", WHILE } + { "and", AND }, + { "class", CLASS }, + { "else", ELSE }, + { "false", FALSE }, + { "for", FOR }, + { "fun", FUN }, + { "lambda", LAMBDA }, + { "if", IF }, + { "nil", NIL }, + { "or", OR }, + { "print", PRINT }, + { "return", RETURN }, + { "super", SUPER }, + { "this", THIS }, + { "true", TRUE }, + { "var", VAR }, + { "while", WHILE } }; @@ -68,7 +69,10 @@ void Scanner::scan_token() { } else if (match('*')) { // A /* multi-line comment bool in_comment = true; while (in_comment && !is_at_end()) { - while (!match('*') && !is_at_end()) advance(); + while (!match('*') && !is_at_end()) { + if (peek() == '\n') line++; + advance(); + } // Matched a * - comment ends if we match a / in_comment = !match('/'); } diff --git a/src/token.cpp b/src/token.cpp index c1977d8..1f14aae 100644 --- a/src/token.cpp +++ b/src/token.cpp @@ -16,7 +16,7 @@ static std::vector token_names{ "IDENTIFIER", "STRING", "NUMBER", - "AND", "CLASS", "ELSE", "FALSE", "FUN", "FOR", "IF", "NIL", "OR", + "AND", "CLASS", "ELSE", "FALSE", "FUN", "LAMBDA", "FOR", "IF", "NIL", "OR", "PRINT", "RETURN", "SUPER", "THIS", "TRUE", "VAR", "WHILE", "EOF" }; diff --git a/test/functional_tests.py b/test/functional_tests.py index 0d7f450..6f98441 100644 --- a/test/functional_tests.py +++ b/test/functional_tests.py @@ -82,6 +82,11 @@ def test_multiline_comments(self): output = run_file(absolute_path('multiline_comments.lox')) self.assertEqual(expected, output) + def test_simple_lambdas(self): + expected = 'Interpreter output:\n2\n4\n6\n6\n-1\n-2\n-3\n2\n4\n6\n-3\n-6\n-9\n1\n2\n3' + output = run_file(absolute_path('simple_lambdas.lox')) + self.assertEqual(expected, output) + class LoxppOutputErrorsTest(unittest.TestCase): diff --git a/test/simple_lambdas.lox b/test/simple_lambdas.lox new file mode 100644 index 0000000..a943041 --- /dev/null +++ b/test/simple_lambdas.lox @@ -0,0 +1,30 @@ +fun from1to3(fn) { + for (var i = 1; i <= 3; i = i + 1) { + fn(i); + } +} +// Simple lambda +from1to3(lambda (a) { print a * 2; }); + +// Modifying bound state in closure +var sum = 0; +from1to3(lambda (curr) { sum = sum + curr; }); +print sum; + +// Call a bound function in closure +fun getNegative(x) { return -x; } +from1to3(lambda (x) { print getNegative(x); }); + +// Return a lambda from a function +fun getDoublingLambda() { + return lambda (i) { print i * 2; }; +} +from1to3(getDoublingLambda()); + +// Declare variable inside a lambda +var scalingFactor = 3; +from1to3(lambda (x) { var scalingFactor = -3; print (x * scalingFactor); }); + +// Assign a lambda to a variable +var printLambda = lambda (x) { print x; }; +from1to3(printLambda);