Skip to content

Commit

Permalink
Added support for lambda functions
Browse files Browse the repository at this point in the history
  • Loading branch information
neeilan committed Aug 4, 2018
1 parent 621301f commit 0850ed6
Show file tree
Hide file tree
Showing 18 changed files with 151 additions and 21 deletions.
1 change: 1 addition & 0 deletions include/ast_deleter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class AstDeleter : public ExprVisitor<void>,
void visit(const Get* expr);
void visit(const Set* expr);
void visit(const This* expr);
void visit(const Lambda* expr);

void visit(const BlockStmt*);
void visit(const ExprStmt*);
Expand Down
1 change: 1 addition & 0 deletions include/ast_printer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class AstPrinter : public ExprVisitor<std::string> {
std::string visit(const BoolLiteral* expr);
std::string visit(const Variable* expr);
std::string visit(const Logical* expr);
std::string visit(const Lambda* expr);

private:
std::string parenthesize(std::string, const Expr*); // todo: change these to variadic template
Expand Down
22 changes: 22 additions & 0 deletions include/expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,5 +313,27 @@ class Set : public Expr {
const Expr& value;
};

class Lambda : public Expr {
public:
explicit Lambda(std::vector<Token> parameters, std::vector<Stmt*> body)
: parameters(parameters),
body(body) {}

virtual void accept(ExprVisitor<void>* visitor) const {
return visitor->visit(this);
}

virtual std::string accept(ExprVisitor<std::string>* visitor) const {
return visitor->visit(this);
}

virtual shared_ptr<InterpreterResult> accept(ExprVisitor<shared_ptr<InterpreterResult> >* visitor) const {
return visitor->visit(this);
}

const std::vector<Token> parameters;
const std::vector<Stmt*> body;
};

#endif

1 change: 1 addition & 0 deletions include/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class Interpreter : public ExprVisitor<shared_ptr<InterpreterResult>>,
shared_ptr<InterpreterResult> visit(const Get* expr);
shared_ptr<InterpreterResult> visit(const Set* expr);
shared_ptr<InterpreterResult> visit(const This* expr);
shared_ptr<InterpreterResult> visit(const Lambda* expr);

void visit(const BlockStmt*);
void visit(const ExprStmt*);
Expand Down
1 change: 1 addition & 0 deletions include/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class Parser {
Expr* call();
Expr *finish_call(Expr *caller);
Expr* primary();
Expr* lambda_expr();

Stmt* declaration();
Stmt* var_declaration();
Expand Down
3 changes: 2 additions & 1 deletion include/resolver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include "visitable_types.hpp"

enum FunctionType {
NOT_IN_FN, METHOD, FUNCTION, INITIALIZER
NOT_IN_FN, METHOD, FUNCTION, INITIALIZER, LAMBDA_FN
};

enum ClassType {
Expand Down Expand Up @@ -46,6 +46,7 @@ class Resolver : public ExprVisitor<void>, public StmtVisitor{
void visit(const Get* expr);
void visit(const Set* expr);
void visit(const This* expr);
void visit(const Lambda* expr);

void visit(const BlockStmt*);
void visit(const ExprStmt*);
Expand Down
2 changes: 1 addition & 1 deletion include/token.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ enum TokenType {
IDENTIFIER, STRING, NUMBER,

// Keywords.
AND, CLASS, ELSE, FALSE, FUN, FOR, IF, NIL, OR,
AND, CLASS, ELSE, FALSE, FUN, LAMBDA, FOR, IF, NIL, OR,
PRINT, RETURN, SUPER, THIS, TRUE, VAR, WHILE, END_OF_FILE
};

Expand Down
1 change: 1 addition & 0 deletions include/visitable_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Call;
class Get;
class Set;
class This;
class Lambda;

class Stmt;
class ExprStmt;
Expand Down
1 change: 1 addition & 0 deletions include/visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class ExprVisitor {
virtual T visit(const Get*) = 0;
virtual T visit(const Set*) = 0;
virtual T visit(const This*) = 0;
virtual T visit(const Lambda*) = 0;
};

#endif //LOXPP_VISITOR_H
7 changes: 7 additions & 0 deletions src/ast_deleter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,11 @@ void AstDeleter::visit(const Set *expr) {

void AstDeleter::visit(const This *expr) {
exprs_to_delete.insert(static_cast<const Expr*>(expr));
}

void AstDeleter::visit(const Lambda *expr) {
for (Stmt* stmt : expr->body) {
stmt->accept(this);
}
exprs_to_delete.insert(static_cast<const Expr*>(expr));
}
4 changes: 4 additions & 0 deletions src/ast_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ std::string AstPrinter::visit(const Logical *expr) {
return parenthesize(expr->op.lexeme, &expr->left, &expr->right);
}

std::string AstPrinter::visit(const Lambda *expr) {
return "<anonymous lambda>";
}

std::string AstPrinter::parenthesize(std::string name, const Expr* expr) {
return "(" + name + " " + expr->accept(this) + ")";
}
Expand Down
12 changes: 11 additions & 1 deletion src/interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ shared_ptr<InterpreterResult> Interpreter::visit(const Get *expr) {
shared_ptr<InterpreterResult> Interpreter::visit(const Set *expr) {
shared_ptr<InterpreterResult> object = evaluate(expr->callee);

if (!(object->kind == InterpreterResult::ResultType::INSTANCE)) {
if (object->kind != InterpreterResult::ResultType::INSTANCE) {
throw RuntimeErr(expr->name, "Only instances have fields.");
}

Expand All @@ -380,6 +380,16 @@ shared_ptr<InterpreterResult> Interpreter::visit(const This *expr) {
return lookup_variable(expr->keyword, expr);
}

shared_ptr<InterpreterResult> Interpreter::visit(const Lambda *expr) {
shared_ptr<InterpreterResult> result = std::make_shared<InterpreterResult>();
result->function = new FuncStmt(Token(LAMBDA, "lambda", "", 0), expr->parameters, expr->body);
result->kind = InterpreterResult::ResultType::FUNCTION;
result->arity = expr->parameters.size();
result->closure = environment;
result->callable = true;
return result;
}

bool Interpreter::is_truthy(const InterpreterResult &expr) {
if (expr.kind == InterpreterResult::ResultType::NIL) {
return false;
Expand Down
31 changes: 31 additions & 0 deletions src/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,38 @@ Expr* Parser::finish_call(Expr* callee) {
return new Call(*callee, paren, args);
}

Expr* Parser::lambda_expr() {
/*
* Lambda syntax:
* lambda (args) { return expr; }
* ex : list_to_double.map(lambda (x) { return x * 2; })
* */

consume(LEFT_PAREN, "Expect '(' after lambda keyword");

std::vector<Token> parameters;
if (!check(RIGHT_PAREN)) {
do {
if (parameters.size() >= 8) {
error(peek(), "Cannot have more than 8 parameters.");
}

parameters.push_back(consume(IDENTIFIER, "Expect parameter name."));
} while (match({COMMA}));
}

consume(RIGHT_PAREN, "Expect ')' after parameters.");

consume(LEFT_BRACE, "Expect '{' before lambda body.");

std::vector<Stmt*> body;
body.push_back(block_statement());

return new Lambda(parameters, body);
}

Expr* Parser::primary() {
if (match({LAMBDA})) return lambda_expr();
if (match({THIS})) return new This(previous());
if (match({FALSE})) return new BoolLiteral(false);
if (match({TRUE})) return new BoolLiteral(true);
Expand Down
10 changes: 10 additions & 0 deletions src/resolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,16 @@ void Resolver::resolve_fn(FunctionType declaration, const FuncStmt *fn) {
current_function = enclosing_function;
}

void Resolver::visit(const Lambda *expr) {
auto fn = new FuncStmt(
Token(LAMBDA, "lambda", "", 0),
expr->parameters,
expr->body);

resolve_fn(FUNCTION, fn);
delete fn;
}

void Resolver::begin_scope() {
scopes.push_back(new std::map<std::string, bool>());
}
Expand Down
38 changes: 21 additions & 17 deletions src/scanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,23 @@
#include <vector>

const std::map<std::string, TokenType> Scanner::keywords = {
{ "and", AND },
{ "class", CLASS },
{ "else", ELSE },
{ "false", FALSE },
{ "for", FOR },
{ "fun", FUN },
{ "if", IF },
{ "nil", NIL },
{ "or", OR },
{ "print", PRINT },
{ "return",RETURN },
{ "super", SUPER },
{ "this", THIS },
{ "true", TRUE },
{ "var", VAR },
{ "while", WHILE }
{ "and", AND },
{ "class", CLASS },
{ "else", ELSE },
{ "false", FALSE },
{ "for", FOR },
{ "fun", FUN },
{ "lambda", LAMBDA },
{ "if", IF },
{ "nil", NIL },
{ "or", OR },
{ "print", PRINT },
{ "return", RETURN },
{ "super", SUPER },
{ "this", THIS },
{ "true", TRUE },
{ "var", VAR },
{ "while", WHILE }
};


Expand Down Expand Up @@ -68,7 +69,10 @@ void Scanner::scan_token() {
} else if (match('*')) { // A /* multi-line comment
bool in_comment = true;
while (in_comment && !is_at_end()) {
while (!match('*') && !is_at_end()) advance();
while (!match('*') && !is_at_end()) {
if (peek() == '\n') line++;
advance();
}
// Matched a * - comment ends if we match a /
in_comment = !match('/');
}
Expand Down
2 changes: 1 addition & 1 deletion src/token.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ static std::vector<std::string> token_names{

"IDENTIFIER", "STRING", "NUMBER",

"AND", "CLASS", "ELSE", "FALSE", "FUN", "FOR", "IF", "NIL", "OR",
"AND", "CLASS", "ELSE", "FALSE", "FUN", "LAMBDA", "FOR", "IF", "NIL", "OR",
"PRINT", "RETURN", "SUPER", "THIS", "TRUE", "VAR", "WHILE", "EOF"
};

Expand Down
5 changes: 5 additions & 0 deletions test/functional_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ def test_multiline_comments(self):
output = run_file(absolute_path('multiline_comments.lox'))
self.assertEqual(expected, output)

def test_simple_lambdas(self):
expected = 'Interpreter output:\n2\n4\n6\n6\n-1\n-2\n-3\n2\n4\n6\n-3\n-6\n-9\n1\n2\n3'
output = run_file(absolute_path('simple_lambdas.lox'))
self.assertEqual(expected, output)


class LoxppOutputErrorsTest(unittest.TestCase):

Expand Down
30 changes: 30 additions & 0 deletions test/simple_lambdas.lox
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
fun from1to3(fn) {
for (var i = 1; i <= 3; i = i + 1) {
fn(i);
}
}
// Simple lambda
from1to3(lambda (a) { print a * 2; });

// Modifying bound state in closure
var sum = 0;
from1to3(lambda (curr) { sum = sum + curr; });
print sum;

// Call a bound function in closure
fun getNegative(x) { return -x; }
from1to3(lambda (x) { print getNegative(x); });

// Return a lambda from a function
fun getDoublingLambda() {
return lambda (i) { print i * 2; };
}
from1to3(getDoublingLambda());

// Declare variable inside a lambda
var scalingFactor = 3;
from1to3(lambda (x) { var scalingFactor = -3; print (x * scalingFactor); });

// Assign a lambda to a variable
var printLambda = lambda (x) { print x; };
from1to3(printLambda);

0 comments on commit 0850ed6

Please sign in to comment.