diff --git a/vm/src/core/grammar.lalrpop b/vm/src/core/grammar.lalrpop index 967e422910..820829b403 100644 --- a/vm/src/core/grammar.lalrpop +++ b/vm/src/core/grammar.lalrpop @@ -16,7 +16,8 @@ Comma: Vec = Identifier: Symbol = { - => symbols.symbol(<>) + => symbols.symbol(<>), + => symbols.symbol(&<>[1..<>.len() - 1]) }; Field: (Symbol, Option) = { diff --git a/vm/src/core/interpreter.rs b/vm/src/core/interpreter.rs index 77ced25110..cb8beb1c04 100644 --- a/vm/src/core/interpreter.rs +++ b/vm/src/core/interpreter.rs @@ -1,5 +1,5 @@ use std::ops::{Deref, DerefMut}; -use base::ast::{Typed, TypedIdent}; +use base::ast::{Literal, Typed, TypedIdent}; use base::fnv::FnvSet; use base::resolve; use base::kind::{ArcKind, KindEnv}; @@ -80,25 +80,39 @@ impl<'e> Visitor<'e, 'e> for FreeVars { } } -pub struct Pure(bool); +pub struct Pure<'a, 'l: 'a, 'g: 'a>(bool, &'a Compiler<'g, 'l>, &'a mut FunctionEnvs<'l, 'g>); -impl<'e> Visitor<'e, 'e> for Pure { - type Producer = SameLifetime<'e>; +impl<'a, 'l, 'g, 'expr> Visitor<'l, 'expr> for Pure<'a, 'l, 'g> { + type Producer = DifferentLifetime<'l, 'expr>; - fn visit_expr(&mut self, expr: CExpr<'e>) -> Option> { + fn visit_expr(&mut self, expr: CExpr<'expr>) -> Option> { if !self.0 { return None; } match *expr { - Expr::Call(..) => { - // FIXME Don't treat all function calls as impure - self.0 = false; + Expr::Ident(ref id, ..) => { + match self.1.find(&id.name, self.2) { + Some(variable) => match variable { + Variable::Stack(Some(expr)) => { + self.visit_expr(expr.as_ref()); + } + Variable::Stack(None) => self.0 = false, + Variable::Global(expr) => { + self.visit_expr(expr); + } + Variable::Constructor(..) => (), + }, + // If we can't resolve the identifier to an expression it is a primitive function + // which can be impure + // FIXME Let primitive functions mark themselves as pure somehow + None => self.0 = false, + } None } _ => walk_expr_alloc(self, expr), } } - fn detach_allocator(&self) -> Option<&'e Allocator<'e>> { + fn detach_allocator(&self) -> Option<&'l Allocator<'l>> { None } } @@ -153,6 +167,25 @@ impl<'l, 'g> FunctionEnvs<'l, 'g> { compiler.stack_constructors.exit_scope(); self.envs.pop().expect("FunctionEnv in scope") } + + fn push_stack_var( + &mut self, + compiler: &Compiler<'g, 'l>, + s: Symbol, + expr: ReducedExpr<'l, 'g>, + ) { + let expr = { + let mut p = Pure(true, compiler, self); + p.visit_expr(expr.as_ref()); + // Only allow pure expression to be folded + if p.0 { + Some(expr) + } else { + None + } + }; + self.new_stack_var(s, expr) + } } impl<'l, 'g> FunctionEnv<'l, 'g> { @@ -166,21 +199,7 @@ impl<'l, 'g> FunctionEnv<'l, 'g> { self.new_stack_var(s, None) } - fn push_stack_var(&mut self, s: Symbol, expr: ReducedExpr<'l, 'g>) { - self.new_stack_var(s, Some(expr)) - } - - fn new_stack_var(&mut self, s: Symbol, mut expr: Option>) { - expr = expr.and_then(|expr| { - // Only allow pure expression to be folded - let mut p = Pure(true); - p.visit_expr(expr.as_ref()); - if p.0 { - Some(expr) - } else { - None - } - }); + fn new_stack_var(&mut self, s: Symbol, expr: Option>) { self.stack.insert(s, expr); } @@ -383,6 +402,7 @@ impl<'a, 'e> Compiler<'a, 'e> { core::Named::Expr(ref bind_expr) => { let reduced = self.compile(bind_expr, function)?; function.push_stack_var( + self, let_binding.name.name.clone(), reduced.unwrap_or(ReducedExpr::Local(bind_expr)), ); @@ -407,6 +427,62 @@ impl<'a, 'e> Compiler<'a, 'e> { } return Ok(TailCall::Tail(body)); } + Expr::Call(f, args) if args.len() == 2 => { + let f = self.compile(f, function)?.unwrap_or(ReducedExpr::Local(f)); + match *f.as_ref() { + Expr::Ident(ref id, ..) if id.name.as_ref().starts_with("#") => { + macro_rules! binop { + ($id: expr) => { { + let f: fn (_, _) -> _ = match $id.name.as_ref().chars().last().unwrap() { + '+' => |l, r| l + r, + '-' => |l, r| l - r, + '*' => |l, r| l * r, + '/' => |l, r| l / r, + _ => return Err(format!("Invalid binop `{}`", id.name).into()), + }; + f + } } + } + + let l = self.compile(&args[0], function)?; + let l = match l { + Some(l) => l, + None => return self.walk_expr(expr, function).map(TailCall::Value), + }; + let r = self.compile(&args[1], function)?; + let r = match r { + Some(r) => r, + None => return self.walk_expr(expr, function).map(TailCall::Value), + }; + match (l.as_ref(), r.as_ref()) { + ( + &Expr::Const(Literal::Int(l), ..), + &Expr::Const(Literal::Int(r), ..), + ) => { + let f = binop!(id); + Some(ReducedExpr::Local( + self.allocator + .arena + .alloc(Expr::Const(Literal::Int(f(l, r)), expr.span())), + )) + } + ( + &Expr::Const(Literal::Float(l), ..), + &Expr::Const(Literal::Float(r), ..), + ) => { + let f = binop!(id); + Some(ReducedExpr::Local( + self.allocator + .arena + .alloc(Expr::Const(Literal::Float(f(l, r)), expr.span())), + )) + } + _ => None, + } + } + _ => self.walk_expr(expr, function)?, + } + } Expr::Call(..) => self.walk_expr(expr, function)?, Expr::Match(expr, alts) => { let expr = self.compile(expr, function)? @@ -423,7 +499,7 @@ impl<'a, 'e> Compiler<'a, 'e> { self.compile_let_pattern(&alt.pattern, expr, typ, function)?; } Pattern::Ident(ref id) => { - function.push_stack_var(id.name.clone(), expr); + function.push_stack_var(self, id.name.clone(), expr); } } let new_expr = self.compile(&alt.expr, function)? @@ -462,10 +538,9 @@ impl<'a, 'e> Compiler<'a, 'e> { ) -> Result<()> { match *pattern { Pattern::Ident(ref name) => { - function.push_stack_var(name.name.clone(), pattern_expr); + function.push_stack_var(self, name.name.clone(), pattern_expr); } Pattern::Record(ref fields) => { - let typ = resolve::remove_aliases(self, pattern_type.clone()); match_reduce!{ pattern_expr, wrap; @@ -481,10 +556,10 @@ impl<'a, 'e> Compiler<'a, 'e> { .as_ref() .unwrap_or(&pattern_field.0.name) .clone(); - function.push_stack_var(field_name, wrap(&exprs[field])); + function.push_stack_var(self, field_name, wrap(&exprs[field])); } }, - _ => panic!("Expected record, got {} at {:?}", typ, pattern) + _ => panic!("Expected record, got `{}` at {:?}", pattern_expr.as_ref(), pattern) } } Pattern::Constructor(..) => panic!("constructor pattern in let"), @@ -516,10 +591,15 @@ impl<'a, 'e> Compiler<'a, 'e> { #[cfg(test)] mod tests { + extern crate gluon_parser as parser; + use super::*; - use base::symbol::Symbols; + use base::symbol::{SymbolModule, Symbols}; + + use self::parser::parse_expr; + use thread::RootedThread; use core::*; use core::grammar::parse_Expr as parse_core_expr; @@ -593,4 +673,14 @@ mod tests { x"#; assert_eq_expr!(expr, "1", move |_: &_| -> Option { Some(global) }); } + + #[test] + fn fold_primitive_op() { + let _ = ::env_logger::init(); + + let expr = r#" + (#Int+) 1 2 + "#; + assert_eq_expr!(expr, "3"); + } } diff --git a/vm/src/core/mod.rs b/vm/src/core/mod.rs index b007869afd..875c31131a 100644 --- a/vm/src/core/mod.rs +++ b/vm/src/core/mod.rs @@ -33,6 +33,7 @@ mod grammar; pub mod optimize; pub mod interpreter; +use std::borrow::Cow; use std::collections::HashMap; use std::fmt; use std::iter::once; @@ -782,13 +783,17 @@ fn get_return_type(env: &TypeEnv, alias_type: &ArcType, arg_count: usize) -> Arc } let function_type = remove_aliases_cow(env, alias_type); - let (_, ret) = function_type.as_function().unwrap_or_else(|| { - panic!( - "Call expression with a non function type `{}`", - function_type - ) - }); - get_return_type(env, ret, arg_count - 1) + let ret = function_type + .as_function() + .map(|t| Cow::Borrowed(t.1)) + .unwrap_or_else(|| { + debug!( + "Call expression with a non function type `{}`", + function_type + ); + Cow::Owned(Type::hole()) + }); + get_return_type(env, &ret, arg_count - 1) } pub struct PatternTranslator<'a, 'e: 'a>(&'a Translator<'a, 'e>);