From e9909f97b70d5ed76f7bfdab8ec8cd9403b0aa76 Mon Sep 17 00:00:00 2001 From: Lieuwe Rooijakkers Date: Wed, 5 Feb 2020 17:25:20 +0100 Subject: [PATCH] pad operator --- src/ast.rs | 2 ++ src/executor/executor.rs | 30 ++++++++++++++++++++++++++++++ src/parser.rs | 6 ++++-- 3 files changed, 36 insertions(+), 2 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index 125d63f..ebab807 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -94,6 +94,7 @@ pub enum BinOp { In, Max, Min, + Pad, } impl fmt::Display for BinOp { @@ -117,6 +118,7 @@ impl fmt::Display for BinOp { BinOp::In => write!(f, "in"), BinOp::Max => write!(f, "max"), BinOp::Min => write!(f, "min"), + BinOp::Pad => write!(f, "pad"), } } } diff --git a/src/executor/executor.rs b/src/executor/executor.rs index 0bb288f..404987d 100644 --- a/src/executor/executor.rs +++ b/src/executor/executor.rs @@ -282,6 +282,36 @@ fn call_binary(op: BinOp, a: Matrix, b: Matrix) -> Result apply_ok!(|a: &Ratio, b: &Ratio| if b > a { b.clone() } else { a.clone() }), BinOp::Min => apply_ok!(|a: &Ratio, b: &Ratio| if b < a { b.clone() } else { a.clone() }), + + BinOp::Pad => { + let a = expect_vector(ExecutorResult::Value(Value::Matrix(a)))?; + if a.len() != 2 { + return Err(format!("expected 2 arguments on the left, got {}", a.len())); + } + + let mut it = a.into_iter(); + let amount = it.next().unwrap().to_integer(); + let number = it.next().unwrap(); + + let (amount, at_start) = match amount.sign() { + Sign::Minus => (amount.neg(), true), + _ => (amount, false), + }; + let mut values: Vec<_> = b.values; + let mut to_add = vec![number; amount.to_usize().unwrap()]; + if at_start { + to_add.append(&mut values); + values = to_add; + } else { + values.append(&mut to_add); + } + + Ok(Matrix { + values, + shape: b.shape, + } + .into()) + } } } diff --git a/src/parser.rs b/src/parser.rs index c91e738..b563788 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -51,6 +51,7 @@ fn binary_op<'a>() -> Parser<'a, BinOp> { | symbol_both(operator("in")).map(|_| BinOp::In) | symbol_both(operator("max")).map(|_| BinOp::Max) | symbol_both(operator("min")).map(|_| BinOp::Min) + | symbol_both(operator("pad")).map(|_| BinOp::Pad) | operator("**").map(|_| BinOp::Pow) | operator("*").map(|_| BinOp::Mul) | operator("/").map(|_| BinOp::Div) @@ -62,7 +63,7 @@ fn binary_op<'a>() -> Parser<'a, BinOp> { fn check_reserved(s: String) -> Result { let reserved = vec![ - "skip", "rho", "unpack", "pack", "log", "iota", "abs", "rev", "in", "max", "min", + "skip", "rho", "unpack", "pack", "log", "iota", "abs", "rev", "in", "max", "min", "pad", ]; if reserved.contains(&s.as_str()) { @@ -348,7 +349,8 @@ fn p_expr_9<'a>() -> Parser<'a, Expr> { | symbol_both(operator(",")).map(|_| BinOp::Concat) | symbol_both(operator("in")).map(|_| BinOp::In) | symbol_both(operator("max")).map(|_| BinOp::Max) - | symbol_both(operator("min")).map(|_| BinOp::Min); + | symbol_both(operator("min")).map(|_| BinOp::Min) + | symbol_both(operator("pad")).map(|_| BinOp::Pad); right_recurse(p_expr_8, op_bin, "special binary", |e1, op, e2| { Expr::Binary(Box::new(e1), op, Box::new(e2))