Skip to content

Commit

Permalink
feat: Constant fold binary expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
Marwes committed Aug 11, 2017
1 parent cc27f0d commit 10b77ac
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 38 deletions.
3 changes: 2 additions & 1 deletion vm/src/core/grammar.lalrpop
Expand Up @@ -16,7 +16,8 @@ Comma<Rule>: Vec<Rule> =


Identifier: Symbol = {
<r"[A-Za-z_][A-Za-z0-9_]*"> => symbols.symbol(<>)
<r"[A-Za-z_][A-Za-z0-9_]*"> => symbols.symbol(<>),
<r"\(#?[A-Za-z_]*[+\-*/]\)"> => symbols.symbol(&<>[1..<>.len() - 1])
};

Field: (Symbol, Option<Symbol>) = {
Expand Down
150 changes: 120 additions & 30 deletions vm/src/core/interpreter.rs
@@ -1,5 +1,5 @@
use std::ops::{Deref, DerefMut};
use base::ast::{Typed, TypedIdent};
use base::ast::{Literal, Typed, TypedIdent};
use base::fnv::FnvSet;
use base::resolve;
use base::kind::{ArcKind, KindEnv};
Expand Down Expand Up @@ -80,25 +80,39 @@ impl<'e> Visitor<'e, 'e> for FreeVars {
}
}

pub struct Pure(bool);
pub struct Pure<'a, 'l: 'a, 'g: 'a>(bool, &'a Compiler<'g, 'l>, &'a mut FunctionEnvs<'l, 'g>);

impl<'e> Visitor<'e, 'e> for Pure {
type Producer = SameLifetime<'e>;
impl<'a, 'l, 'g, 'expr> Visitor<'l, 'expr> for Pure<'a, 'l, 'g> {
type Producer = DifferentLifetime<'l, 'expr>;

fn visit_expr(&mut self, expr: CExpr<'e>) -> Option<CExpr<'e>> {
fn visit_expr(&mut self, expr: CExpr<'expr>) -> Option<CExpr<'l>> {
if !self.0 {
return None;
}
match *expr {
Expr::Call(..) => {
// FIXME Don't treat all function calls as impure
self.0 = false;
Expr::Ident(ref id, ..) => {
match self.1.find(&id.name, self.2) {
Some(variable) => match variable {
Variable::Stack(Some(expr)) => {
self.visit_expr(expr.as_ref());
}
Variable::Stack(None) => self.0 = false,
Variable::Global(expr) => {
self.visit_expr(expr);
}
Variable::Constructor(..) => (),
},
// If we can't resolve the identifier to an expression it is a primitive function
// which can be impure
// FIXME Let primitive functions mark themselves as pure somehow
None => self.0 = false,
}
None
}
_ => walk_expr_alloc(self, expr),
}
}
fn detach_allocator(&self) -> Option<&'e Allocator<'e>> {
fn detach_allocator(&self) -> Option<&'l Allocator<'l>> {
None
}
}
Expand Down Expand Up @@ -153,6 +167,25 @@ impl<'l, 'g> FunctionEnvs<'l, 'g> {
compiler.stack_constructors.exit_scope();
self.envs.pop().expect("FunctionEnv in scope")
}

fn push_stack_var(
&mut self,
compiler: &Compiler<'g, 'l>,
s: Symbol,
expr: ReducedExpr<'l, 'g>,
) {
let expr = {
let mut p = Pure(true, compiler, self);
p.visit_expr(expr.as_ref());
// Only allow pure expression to be folded
if p.0 {
Some(expr)
} else {
None
}
};
self.new_stack_var(s, expr)
}
}

impl<'l, 'g> FunctionEnv<'l, 'g> {
Expand All @@ -166,21 +199,7 @@ impl<'l, 'g> FunctionEnv<'l, 'g> {
self.new_stack_var(s, None)
}

fn push_stack_var(&mut self, s: Symbol, expr: ReducedExpr<'l, 'g>) {
self.new_stack_var(s, Some(expr))
}

fn new_stack_var(&mut self, s: Symbol, mut expr: Option<ReducedExpr<'l, 'g>>) {
expr = expr.and_then(|expr| {
// Only allow pure expression to be folded
let mut p = Pure(true);
p.visit_expr(expr.as_ref());
if p.0 {
Some(expr)
} else {
None
}
});
fn new_stack_var(&mut self, s: Symbol, expr: Option<ReducedExpr<'l, 'g>>) {
self.stack.insert(s, expr);
}

Expand Down Expand Up @@ -383,6 +402,7 @@ impl<'a, 'e> Compiler<'a, 'e> {
core::Named::Expr(ref bind_expr) => {
let reduced = self.compile(bind_expr, function)?;
function.push_stack_var(
self,
let_binding.name.name.clone(),
reduced.unwrap_or(ReducedExpr::Local(bind_expr)),
);
Expand All @@ -407,6 +427,62 @@ impl<'a, 'e> Compiler<'a, 'e> {
}
return Ok(TailCall::Tail(body));
}
Expr::Call(f, args) if args.len() == 2 => {
let f = self.compile(f, function)?.unwrap_or(ReducedExpr::Local(f));
match *f.as_ref() {
Expr::Ident(ref id, ..) if id.name.as_ref().starts_with("#") => {
macro_rules! binop {
($id: expr) => { {
let f: fn (_, _) -> _ = match $id.name.as_ref().chars().last().unwrap() {
'+' => |l, r| l + r,
'-' => |l, r| l - r,
'*' => |l, r| l * r,
'/' => |l, r| l / r,
_ => return Err(format!("Invalid binop `{}`", id.name).into()),
};
f
} }
}

let l = self.compile(&args[0], function)?;
let l = match l {
Some(l) => l,
None => return self.walk_expr(expr, function).map(TailCall::Value),
};
let r = self.compile(&args[1], function)?;
let r = match r {
Some(r) => r,
None => return self.walk_expr(expr, function).map(TailCall::Value),
};
match (l.as_ref(), r.as_ref()) {
(
&Expr::Const(Literal::Int(l), ..),
&Expr::Const(Literal::Int(r), ..),
) => {
let f = binop!(id);
Some(ReducedExpr::Local(
self.allocator
.arena
.alloc(Expr::Const(Literal::Int(f(l, r)), expr.span())),
))
}
(
&Expr::Const(Literal::Float(l), ..),
&Expr::Const(Literal::Float(r), ..),
) => {
let f = binop!(id);
Some(ReducedExpr::Local(
self.allocator
.arena
.alloc(Expr::Const(Literal::Float(f(l, r)), expr.span())),
))
}
_ => None,
}
}
_ => self.walk_expr(expr, function)?,
}
}
Expr::Call(..) => self.walk_expr(expr, function)?,
Expr::Match(expr, alts) => {
let expr = self.compile(expr, function)?
Expand All @@ -423,7 +499,7 @@ impl<'a, 'e> Compiler<'a, 'e> {
self.compile_let_pattern(&alt.pattern, expr, typ, function)?;
}
Pattern::Ident(ref id) => {
function.push_stack_var(id.name.clone(), expr);
function.push_stack_var(self, id.name.clone(), expr);
}
}
let new_expr = self.compile(&alt.expr, function)?
Expand Down Expand Up @@ -462,10 +538,9 @@ impl<'a, 'e> Compiler<'a, 'e> {
) -> Result<()> {
match *pattern {
Pattern::Ident(ref name) => {
function.push_stack_var(name.name.clone(), pattern_expr);
function.push_stack_var(self, name.name.clone(), pattern_expr);
}
Pattern::Record(ref fields) => {
let typ = resolve::remove_aliases(self, pattern_type.clone());
match_reduce!{
pattern_expr, wrap;

Expand All @@ -481,10 +556,10 @@ impl<'a, 'e> Compiler<'a, 'e> {
.as_ref()
.unwrap_or(&pattern_field.0.name)
.clone();
function.push_stack_var(field_name, wrap(&exprs[field]));
function.push_stack_var(self, field_name, wrap(&exprs[field]));
}
},
_ => panic!("Expected record, got {} at {:?}", typ, pattern)
_ => panic!("Expected record, got `{}` at {:?}", pattern_expr.as_ref(), pattern)
}
}
Pattern::Constructor(..) => panic!("constructor pattern in let"),
Expand Down Expand Up @@ -516,10 +591,15 @@ impl<'a, 'e> Compiler<'a, 'e> {

#[cfg(test)]
mod tests {
extern crate gluon_parser as parser;

use super::*;

use base::symbol::Symbols;
use base::symbol::{SymbolModule, Symbols};

use self::parser::parse_expr;

use thread::RootedThread;
use core::*;
use core::grammar::parse_Expr as parse_core_expr;

Expand Down Expand Up @@ -593,4 +673,14 @@ mod tests {
x"#;
assert_eq_expr!(expr, "1", move |_: &_| -> Option<CExpr> { Some(global) });
}

#[test]
fn fold_primitive_op() {
let _ = ::env_logger::init();

let expr = r#"
(#Int+) 1 2
"#;
assert_eq_expr!(expr, "3");
}
}
19 changes: 12 additions & 7 deletions vm/src/core/mod.rs
Expand Up @@ -33,6 +33,7 @@ mod grammar;
pub mod optimize;
pub mod interpreter;

use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt;
use std::iter::once;
Expand Down Expand Up @@ -782,13 +783,17 @@ fn get_return_type(env: &TypeEnv, alias_type: &ArcType, arg_count: usize) -> Arc
}
let function_type = remove_aliases_cow(env, alias_type);

let (_, ret) = function_type.as_function().unwrap_or_else(|| {
panic!(
"Call expression with a non function type `{}`",
function_type
)
});
get_return_type(env, ret, arg_count - 1)
let ret = function_type
.as_function()
.map(|t| Cow::Borrowed(t.1))
.unwrap_or_else(|| {
debug!(
"Call expression with a non function type `{}`",
function_type
);
Cow::Owned(Type::hole())
});
get_return_type(env, &ret, arg_count - 1)
}

pub struct PatternTranslator<'a, 'e: 'a>(&'a Translator<'a, 'e>);
Expand Down

0 comments on commit 10b77ac

Please sign in to comment.