diff --git a/jit/src/instructions.rs b/jit/src/instructions.rs index cd75b5e4fe..7b3ec63ab1 100644 --- a/jit/src/instructions.rs +++ b/jit/src/instructions.rs @@ -1,7 +1,7 @@ use cranelift::prelude::*; use num_traits::cast::ToPrimitive; use rustpython_bytecode::bytecode::{ - BinaryOperator, CodeObject, ComparisonOperator, Constant, Instruction, Label, NameScope, + UnaryOperator, BinaryOperator, CodeObject, ComparisonOperator, Constant, Instruction, Label, NameScope, }; use std::collections::HashMap; @@ -233,6 +233,36 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { _ => Err(JitCompileError::NotSupported), } } + Instruction::UnaryOperation { op, .. } => { + let a = self.stack.pop().ok_or(JitCompileError::BadBytecode)?; + + match a.ty { + JitType::Int => match op { + UnaryOperator::Minus => { + // Compile minus as 0 - a. + let zero = self.builder.ins().iconst(types::I64, 0); + let (out, carry) = self.builder.ins().isub_ifbout(zero, a.val); + self.builder.ins().trapif( + IntCC::Overflow, + carry, + TrapCode::IntegerOverflow, + ); + self.stack.push(JitValue { + val: out, + ty: JitType::Int, + }); + Ok(()) + } + UnaryOperator::Plus => { + // Nothing to do + self.stack.push(a); + Ok(()) + } + _ => Err(JitCompileError::NotSupported), + }, + _ => Err(JitCompileError::NotSupported), + } + } Instruction::BinaryOperation { op, .. } => { // the rhs is popped off first let b = self.stack.pop().ok_or(JitCompileError::BadBytecode)?; diff --git a/jit/tests/int_tests.rs b/jit/tests/int_tests.rs index 0725ed064a..3b94b7a78b 100644 --- a/jit/tests/int_tests.rs +++ b/jit/tests/int_tests.rs @@ -54,3 +54,31 @@ fn test_gt() { assert_eq!(gt(-1, -10), Ok(1)); assert_eq!(gt(1, -1), Ok(1)); } + +#[test] +fn test_minus() { + let minus = jit_function! { minus(a:i64) -> i64 => r##" + def minus(a: int): + return -a + "## }; + + assert_eq!(minus(5), Ok(-5)); + assert_eq!(minus(12), Ok(-12)); + assert_eq!(minus(-7), Ok(7)); + assert_eq!(minus(-3), Ok(3)); + assert_eq!(minus(0), Ok(0)); +} + +#[test] +fn test_plus() { + let plus = jit_function! { plus(a:i64) -> i64 => r##" + def plus(a: int): + return +a + "## }; + + assert_eq!(plus(5), Ok(5)); + assert_eq!(plus(12), Ok(12)); + assert_eq!(plus(-7), Ok(-7)); + assert_eq!(plus(-3), Ok(-3)); + assert_eq!(plus(0), Ok(0)); +}