Skip to content

Commit

Permalink
feat: decompose Instruction::Cast to have an explicit truncation in…
Browse files Browse the repository at this point in the history
…struction (#3946)

# Description

## Problem\*

Resolves #3749

## Summary\*

This PR removes the implicit `Instruction::Truncate` which lives inside
`Instruction::Cast` so that it's purely responsible for casting the type
of the value in SSA rather than massaging the input to fit in the
desired type.

## Additional Context



## Documentation\*

Check one:
- [ ] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[Exceptional Case]** Documentation to be submitted in a separate
PR.

# PR Checklist\*

- [ ] I have tested the changes locally.
- [ ] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: jfecher <jake@aztecprotocol.com>
  • Loading branch information
TomAFrench and jfecher committed Jan 8, 2024
1 parent c2acdf1 commit 35f18ef
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 76 deletions.
41 changes: 3 additions & 38 deletions compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -472,9 +472,9 @@ impl Context {
self.acir_context.assert_eq_var(lhs, rhs, assert_message.clone())?;
}
}
Instruction::Cast(value_id, typ) => {
let result_acir_var = self.convert_ssa_cast(value_id, typ, dfg)?;
self.define_result_var(dfg, instruction_id, result_acir_var);
Instruction::Cast(value_id, _) => {
let acir_var = self.convert_numeric_value(*value_id, dfg)?;
self.define_result_var(dfg, instruction_id, acir_var);
}
Instruction::Call { func, arguments } => {
let result_ids = dfg.instruction_results(instruction_id);
Expand Down Expand Up @@ -1636,41 +1636,6 @@ impl Context {
}
}

/// Returns an `AcirVar` that is constrained to fit in the target type by truncating the input.
/// If the target cast is to a `NativeField`, no truncation is required so the cast becomes a
/// no-op.
fn convert_ssa_cast(
&mut self,
value_id: &ValueId,
typ: &Type,
dfg: &DataFlowGraph,
) -> Result<AcirVar, RuntimeError> {
let (variable, incoming_type) = match self.convert_value(*value_id, dfg) {
AcirValue::Var(variable, typ) => (variable, typ),
AcirValue::DynamicArray(_) | AcirValue::Array(_) => {
unreachable!("Cast is only applied to numerics")
}
};
let target_numeric = match typ {
Type::Numeric(numeric) => numeric,
_ => unreachable!("Can only cast to a numeric"),
};
match target_numeric {
NumericType::NativeField => {
// Casting into a Field as a no-op
Ok(variable)
}
NumericType::Unsigned { bit_size } | NumericType::Signed { bit_size } => {
let max_bit_size = incoming_type.bit_size();
if max_bit_size <= *bit_size {
// Incoming variable already fits into target bit size - this is a no-op
return Ok(variable);
}
self.acir_context.truncate_var(variable, *bit_size, max_bit_size)
}
}
}

/// Returns an `AcirVar`that is constrained to be result of the truncation.
fn convert_ssa_truncate(
&mut self,
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/ssa/ir/dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ impl DataFlowGraph {
call_stack: CallStack,
) -> InsertInstructionResult {
use InsertInstructionResult::*;
match instruction.simplify(self, block, ctrl_typevars.clone()) {
match instruction.simplify(self, block, ctrl_typevars.clone(), &call_stack) {
SimplifyResult::SimplifiedTo(simplification) => SimplifiedTo(simplification),
SimplifyResult::SimplifiedToMultiple(simplification) => {
SimplifiedToMultiple(simplification)
Expand Down
8 changes: 3 additions & 5 deletions compiler/noirc_evaluator/src/ssa/ir/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,10 +236,7 @@ impl Instruction {
// In ACIR, a division with a false predicate outputs (0,0), so it cannot replace another instruction unless they have the same predicate
bin.operator != BinaryOp::Div
}
Cast(_, _) | Not(_) | ArrayGet { .. } | ArraySet { .. } => true,

// Unclear why this instruction causes problems.
Truncate { .. } => false,
Cast(_, _) | Truncate { .. } | Not(_) | ArrayGet { .. } | ArraySet { .. } => true,

// These either have side-effects or interact with memory
Constrain(..)
Expand Down Expand Up @@ -408,6 +405,7 @@ impl Instruction {
dfg: &mut DataFlowGraph,
block: BasicBlockId,
ctrl_typevars: Option<Vec<Type>>,
call_stack: &CallStack,
) -> SimplifyResult {
use SimplifyResult::*;
match self {
Expand Down Expand Up @@ -551,7 +549,7 @@ impl Instruction {
}
}
Instruction::Call { func, arguments } => {
simplify_call(*func, arguments, dfg, block, ctrl_typevars)
simplify_call(*func, arguments, dfg, block, ctrl_typevars, call_stack)
}
Instruction::EnableSideEffects { condition } => {
if let Some(last) = dfg[block].instructions().last().copied() {
Expand Down
20 changes: 19 additions & 1 deletion compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub(super) fn simplify_call(
dfg: &mut DataFlowGraph,
block: BasicBlockId,
ctrl_typevars: Option<Vec<Type>>,
call_stack: &CallStack,
) -> SimplifyResult {
let intrinsic = match &dfg[func] {
Value::Intrinsic(intrinsic) => *intrinsic,
Expand Down Expand Up @@ -242,7 +243,24 @@ pub(super) fn simplify_call(
SimplifyResult::SimplifiedToInstruction(instruction)
}
Intrinsic::FromField => {
let instruction = Instruction::Cast(arguments[0], ctrl_typevars.unwrap().remove(0));
let incoming_type = Type::field();
let target_type = ctrl_typevars.unwrap().remove(0);

let truncate = Instruction::Truncate {
value: arguments[0],
bit_size: target_type.bit_size(),
max_bit_size: incoming_type.bit_size(),
};
let truncated_value = dfg
.insert_instruction_and_results(
truncate,
block,
Some(vec![incoming_type]),
call_stack.clone(),
)
.first();

let instruction = Instruction::Cast(truncated_value, target_type);
SimplifyResult::SimplifiedToInstruction(instruction)
}
}
Expand Down
48 changes: 34 additions & 14 deletions compiler/noirc_evaluator/src/ssa/ir/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,28 @@ pub enum NumericType {
NativeField,
}

impl NumericType {
/// Returns the bit size of the provided numeric type.
pub(crate) fn bit_size(self: &NumericType) -> u32 {
match self {
NumericType::NativeField => FieldElement::max_num_bits(),
NumericType::Unsigned { bit_size } | NumericType::Signed { bit_size } => *bit_size,
}
}

/// Returns true if the given Field value is within the numeric limits
/// for the current NumericType.
pub(crate) fn value_is_within_limits(self, field: FieldElement) -> bool {
match self {
NumericType::Signed { bit_size } | NumericType::Unsigned { bit_size } => {
let max = 2u128.pow(bit_size) - 1;
field <= max.into()
}
NumericType::NativeField => true,
}
}
}

/// All types representable in the IR.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub(crate) enum Type {
Expand Down Expand Up @@ -68,6 +90,18 @@ impl Type {
Type::Numeric(NumericType::NativeField)
}

/// Returns the bit size of the provided numeric type.
///
/// # Panics
///
/// Panics if `self` is not a [`Type::Numeric`]
pub(crate) fn bit_size(&self) -> u32 {
match self {
Type::Numeric(numeric_type) => numeric_type.bit_size(),
other => panic!("bit_size: Expected numeric type, found {other}"),
}
}

/// Returns the size of the element type for this array/slice.
/// The size of a type is defined as representing how many Fields are needed
/// to represent the type. This is 1 for every primitive type, and is the number of fields
Expand Down Expand Up @@ -122,20 +156,6 @@ impl Type {
}
}

impl NumericType {
/// Returns true if the given Field value is within the numeric limits
/// for the current NumericType.
pub(crate) fn value_is_within_limits(self, field: FieldElement) -> bool {
match self {
NumericType::Signed { bit_size } | NumericType::Unsigned { bit_size } => {
let max = 2u128.pow(bit_size) - 1;
field <= max.into()
}
NumericType::NativeField => true,
}
}
}

/// Composite Types are essentially flattened struct or tuple types.
/// Array types may have these as elements where each flattened field is
/// included in the array sequentially.
Expand Down
10 changes: 5 additions & 5 deletions compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ mod test {
instruction::{BinaryOp, Instruction, TerminatorInstruction},
map::Id,
types::Type,
value::{Value, ValueId},
value::Value,
},
};

Expand Down Expand Up @@ -293,7 +293,7 @@ mod test {
#[test]
fn instruction_deduplication() {
// fn main f0 {
// b0(v0: Field):
// b0(v0: u16):
// v1 = cast v0 as u32
// v2 = cast v0 as u32
// constrain v1 v2
Expand All @@ -308,7 +308,7 @@ mod test {

// Compiling main
let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir);
let v0 = builder.add_parameter(Type::field());
let v0 = builder.add_parameter(Type::unsigned(16));

let v1 = builder.insert_cast(v0, Type::unsigned(32));
let v2 = builder.insert_cast(v0, Type::unsigned(32));
Expand All @@ -322,7 +322,7 @@ mod test {
// Expected output:
//
// fn main f0 {
// b0(v0: Field):
// b0(v0: u16):
// v1 = cast v0 as u32
// }
let ssa = ssa.fold_constants();
Expand All @@ -332,6 +332,6 @@ mod test {
assert_eq!(instructions.len(), 1);
let instruction = &main.dfg[instructions[0]];

assert_eq!(instruction, &Instruction::Cast(ValueId::test_new(0), Type::unsigned(32)));
assert_eq!(instruction, &Instruction::Cast(v0, Type::unsigned(32)));
}
}
48 changes: 38 additions & 10 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,8 @@ impl<'a> FunctionContext<'a> {
let bit_width =
self.builder.numeric_constant(FieldElement::from(2_i128.pow(bit_size)), Type::field());
let sign_not = self.builder.insert_binary(one, BinaryOp::Sub, sign);

// We use unsafe casts here, this is fine as we're casting to a `field` type.
let as_field = self.builder.insert_cast(input, Type::field());
let sign_field = self.builder.insert_cast(sign, Type::field());
let positive_predicate = self.builder.insert_binary(sign_field, BinaryOp::Mul, as_field);
Expand Down Expand Up @@ -310,12 +312,12 @@ impl<'a> FunctionContext<'a> {
match operator {
BinaryOpKind::Add | BinaryOpKind::Subtract => {
// Result is computed modulo the bit size
let mut result =
self.builder.insert_truncate(result, bit_size, bit_size + 1);
result = self.builder.insert_cast(result, Type::unsigned(bit_size));
let result = self.builder.insert_truncate(result, bit_size, bit_size + 1);
let result =
self.insert_safe_cast(result, Type::unsigned(bit_size), location);

self.check_signed_overflow(result, lhs, rhs, operator, bit_size, location);
self.builder.insert_cast(result, result_type)
self.insert_safe_cast(result, result_type, location)
}
BinaryOpKind::Multiply => {
// Result is computed modulo the bit size
Expand All @@ -324,7 +326,7 @@ impl<'a> FunctionContext<'a> {
result = self.builder.insert_truncate(result, bit_size, 2 * bit_size);

self.check_signed_overflow(result, lhs, rhs, operator, bit_size, location);
self.builder.insert_cast(result, result_type)
self.insert_safe_cast(result, result_type, location)
}
BinaryOpKind::ShiftLeft | BinaryOpKind::ShiftRight => {
self.check_shift_overflow(result, rhs, bit_size, location, true)
Expand Down Expand Up @@ -374,8 +376,11 @@ impl<'a> FunctionContext<'a> {
is_signed: bool,
) -> ValueId {
let one = self.builder.numeric_constant(FieldElement::one(), Type::bool());
let rhs =
if is_signed { self.builder.insert_cast(rhs, Type::unsigned(bit_size)) } else { rhs };
let rhs = if is_signed {
self.insert_safe_cast(rhs, Type::unsigned(bit_size), location)
} else {
rhs
};
// Bit-shift with a negative number is an overflow
if is_signed {
// We compute the sign of rhs.
Expand Down Expand Up @@ -431,8 +436,8 @@ impl<'a> FunctionContext<'a> {
Type::unsigned(bit_size),
);
// We compute the sign of the operands. The overflow checks for signed integers depends on these signs
let lhs_as_unsigned = self.builder.insert_cast(lhs, Type::unsigned(bit_size));
let rhs_as_unsigned = self.builder.insert_cast(rhs, Type::unsigned(bit_size));
let lhs_as_unsigned = self.insert_safe_cast(lhs, Type::unsigned(bit_size), location);
let rhs_as_unsigned = self.insert_safe_cast(rhs, Type::unsigned(bit_size), location);
let lhs_sign = self.builder.insert_binary(lhs_as_unsigned, BinaryOp::Lt, half_width);
let mut rhs_sign = self.builder.insert_binary(rhs_as_unsigned, BinaryOp::Lt, half_width);
let message = if is_sub {
Expand Down Expand Up @@ -473,7 +478,7 @@ impl<'a> FunctionContext<'a> {
// Then we check the signed product fits in a signed integer of bit_size-bits
let not_same = self.builder.insert_binary(one, BinaryOp::Sub, same_sign);
let not_same_sign_field =
self.builder.insert_cast(not_same, Type::unsigned(bit_size));
self.insert_safe_cast(not_same, Type::unsigned(bit_size), location);
let positive_maximum_with_offset =
self.builder.insert_binary(half_width, BinaryOp::Add, not_same_sign_field);
let product_overflow_check =
Expand Down Expand Up @@ -663,6 +668,29 @@ impl<'a> FunctionContext<'a> {
reshaped_return_values
}

/// Inserts a cast instruction at the end of the current block and returns the results
/// of the cast.
///
/// Compared to `self.builder.insert_cast`, this version will automatically truncate `value` to be a valid `typ`.
pub(super) fn insert_safe_cast(
&mut self,
mut value: ValueId,
typ: Type,
location: Location,
) -> ValueId {
self.builder.set_location(location);

// To ensure that `value` is a valid `typ`, we insert an `Instruction::Truncate` instruction beforehand if
// we're narrowing the type size.
let incoming_type_size = self.builder.type_of_value(value).bit_size();
let target_type_size = typ.bit_size();
if target_type_size < incoming_type_size {
value = self.builder.insert_truncate(value, target_type_size, incoming_type_size);
}

self.builder.insert_cast(value, typ)
}

/// Create a const offset of an address for an array load or store
pub(super) fn make_offset(&mut self, mut address: ValueId, offset: u128) -> ValueId {
if offset != 0 {
Expand Down
4 changes: 2 additions & 2 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,8 @@ impl<'a> FunctionContext<'a> {
fn codegen_cast(&mut self, cast: &ast::Cast) -> Result<Values, RuntimeError> {
let lhs = self.codegen_non_tuple_expression(&cast.lhs)?;
let typ = Self::convert_non_tuple_type(&cast.r#type);
self.builder.set_location(cast.location);
Ok(self.builder.insert_cast(lhs, typ).into())

Ok(self.insert_safe_cast(lhs, typ, cast.location).into())
}

/// Codegens a for loop, creating three new blocks in the process.
Expand Down

0 comments on commit 35f18ef

Please sign in to comment.