Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: decompose Instruction::Cast to have an explicit truncation instruction #3946

Merged
merged 16 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 3 additions & 38 deletions compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -472,9 +472,9 @@ impl Context {
self.acir_context.assert_eq_var(lhs, rhs, assert_message.clone())?;
}
}
Instruction::Cast(value_id, typ) => {
let result_acir_var = self.convert_ssa_cast(value_id, typ, dfg)?;
self.define_result_var(dfg, instruction_id, result_acir_var);
Instruction::Cast(value_id, _) => {
let acir_var = self.convert_numeric_value(*value_id, dfg)?;
self.define_result_var(dfg, instruction_id, acir_var);
}
Instruction::Call { func, arguments } => {
let result_ids = dfg.instruction_results(instruction_id);
Expand Down Expand Up @@ -1636,41 +1636,6 @@ impl Context {
}
}

/// Returns an `AcirVar` that is constrained to fit in the target type by truncating the input.
/// If the target cast is to a `NativeField`, no truncation is required so the cast becomes a
/// no-op.
fn convert_ssa_cast(
&mut self,
value_id: &ValueId,
typ: &Type,
dfg: &DataFlowGraph,
) -> Result<AcirVar, RuntimeError> {
let (variable, incoming_type) = match self.convert_value(*value_id, dfg) {
AcirValue::Var(variable, typ) => (variable, typ),
AcirValue::DynamicArray(_) | AcirValue::Array(_) => {
unreachable!("Cast is only applied to numerics")
}
};
let target_numeric = match typ {
Type::Numeric(numeric) => numeric,
_ => unreachable!("Can only cast to a numeric"),
};
match target_numeric {
NumericType::NativeField => {
// Casting into a Field as a no-op
Ok(variable)
}
NumericType::Unsigned { bit_size } | NumericType::Signed { bit_size } => {
let max_bit_size = incoming_type.bit_size();
if max_bit_size <= *bit_size {
// Incoming variable already fits into target bit size - this is a no-op
return Ok(variable);
}
self.acir_context.truncate_var(variable, *bit_size, max_bit_size)
}
}
}

/// Returns an `AcirVar`that is constrained to be result of the truncation.
fn convert_ssa_truncate(
&mut self,
Expand Down
5 changes: 1 addition & 4 deletions compiler/noirc_evaluator/src/ssa/ir/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,10 +236,7 @@ impl Instruction {
// In ACIR, a division with a false predicate outputs (0,0), so it cannot replace another instruction unless they have the same predicate
bin.operator != BinaryOp::Div
}
Cast(_, _) | Not(_) | ArrayGet { .. } | ArraySet { .. } => true,

// Unclear why this instruction causes problems.
Truncate { .. } => false,
TomAFrench marked this conversation as resolved.
Show resolved Hide resolved
Cast(_, _) | Truncate { .. } | Not(_) | ArrayGet { .. } | ArraySet { .. } => true,

// These either have side-effects or interact with memory
Constrain(..)
Expand Down
20 changes: 19 additions & 1 deletion compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{collections::VecDeque, rc::Rc};

use acvm::{acir::BlackBoxFunc, BlackBoxResolutionError, FieldElement};
use im::Vector;
use iter_extended::vecmap;
use num_bigint::BigUint;

Expand Down Expand Up @@ -242,7 +243,24 @@
SimplifyResult::SimplifiedToInstruction(instruction)
}
Intrinsic::FromField => {
let instruction = Instruction::Cast(arguments[0], ctrl_typevars.unwrap().remove(0));
let incoming_type = Type::field();
let target_type = ctrl_typevars.unwrap().remove(0);

let truncate = Instruction::Truncate {
value: arguments[0],
bit_size: target_type.bit_size(),
max_bit_size: incoming_type.bit_size(),
};
let truncated_value = dfg
.insert_instruction_and_results(
truncate,
block,
Some(vec![incoming_type]),
Vector::default(), // needs a callstack

Check warning on line 259 in compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (callstack)
)
.first();

let instruction = Instruction::Cast(truncated_value, target_type);
SimplifyResult::SimplifiedToInstruction(instruction)
}
}
Expand Down
47 changes: 33 additions & 14 deletions compiler/noirc_evaluator/src/ssa/ir/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,27 @@ pub enum NumericType {
NativeField,
}

impl NumericType {
pub(crate) fn bit_size(self: &NumericType) -> u32 {
TomAFrench marked this conversation as resolved.
Show resolved Hide resolved
match self {
NumericType::NativeField => FieldElement::max_num_bits(),
NumericType::Unsigned { bit_size } | NumericType::Signed { bit_size } => *bit_size,
}
}

/// Returns true if the given Field value is within the numeric limits
/// for the current NumericType.
pub(crate) fn value_is_within_limits(self, field: FieldElement) -> bool {
match self {
NumericType::Signed { bit_size } | NumericType::Unsigned { bit_size } => {
let max = 2u128.pow(bit_size) - 1;
field <= max.into()
}
NumericType::NativeField => true,
}
}
}

/// All types representable in the IR.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub(crate) enum Type {
Expand Down Expand Up @@ -68,6 +89,18 @@ impl Type {
Type::Numeric(NumericType::NativeField)
}

/// Returns the bit size of the provided numeric type.
///
/// # Panics
///
/// Panics if `self` is not a [`Type::Numeric`]
pub(crate) fn bit_size(&self) -> u32 {
match self {
Type::Numeric(numeric_type) => numeric_type.bit_size(),
other => panic!("bit_size: Expected numeric type, found {other}"),
}
}

/// Returns the size of the element type for this array/slice.
/// The size of a type is defined as representing how many Fields are needed
/// to represent the type. This is 1 for every primitive type, and is the number of fields
Expand Down Expand Up @@ -122,20 +155,6 @@ impl Type {
}
}

impl NumericType {
/// Returns true if the given Field value is within the numeric limits
/// for the current NumericType.
pub(crate) fn value_is_within_limits(self, field: FieldElement) -> bool {
match self {
NumericType::Signed { bit_size } | NumericType::Unsigned { bit_size } => {
let max = 2u128.pow(bit_size) - 1;
field <= max.into()
}
NumericType::NativeField => true,
}
}
}

/// Composite Types are essentially flattened struct or tuple types.
/// Array types may have these as elements where each flattened field is
/// included in the array sequentially.
Expand Down
12 changes: 6 additions & 6 deletions compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ mod test {
instruction::{BinaryOp, Instruction, TerminatorInstruction},
map::Id,
types::Type,
value::{Value, ValueId},
value::Value,
},
};

Expand Down Expand Up @@ -293,7 +293,7 @@ mod test {
#[test]
fn instruction_deduplication() {
// fn main f0 {
// b0(v0: Field):
// b0(v0: u16):
// v1 = cast v0 as u32
// v2 = cast v0 as u32
// constrain v1 v2
Expand All @@ -308,7 +308,7 @@ mod test {

// Compiling main
let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir);
let v0 = builder.add_parameter(Type::field());
let v0 = builder.add_parameter(Type::unsigned(16));

let v1 = builder.insert_cast(v0, Type::unsigned(32));
let v2 = builder.insert_cast(v0, Type::unsigned(32));
Expand All @@ -322,8 +322,8 @@ mod test {
// Expected output:
//
// fn main f0 {
// b0(v0: Field):
// v1 = cast v0 as u32
// b0(v0: u16):
// v1 = cast v1 as u32
TomAFrench marked this conversation as resolved.
Show resolved Hide resolved
// }
let ssa = ssa.fold_constants();
let main = ssa.main();
Expand All @@ -332,6 +332,6 @@ mod test {
assert_eq!(instructions.len(), 1);
let instruction = &main.dfg[instructions[0]];

assert_eq!(instruction, &Instruction::Cast(ValueId::test_new(0), Type::unsigned(32)));
assert_eq!(instruction, &Instruction::Cast(v0, Type::unsigned(32)));
}
}
25 changes: 25 additions & 0 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,31 @@ impl<'a> FunctionContext<'a> {
reshaped_return_values
}

/// Inserts a cast instruction at the end of the current block and returns the results
/// of the cast.
///
/// Compared to `self.builder.insert_cast`, this version will automatically truncate `value` to be a valid `typ`.
pub(super) fn insert_safe_cast(
&mut self,
value: ValueId,
jfecher marked this conversation as resolved.
Show resolved Hide resolved
typ: Type,
location: Location,
) -> Result<Values, RuntimeError> {
self.builder.set_location(location);

// To ensure that `value` is a valid `typ`, we insert an `Instruction::Truncate` instruction beforehand if
// we're narrowing the type size.
let incoming_type_size = self.builder.type_of_value(value).bit_size();
let target_type_size = typ.bit_size();
let truncated_value = if target_type_size < incoming_type_size {
self.builder.insert_truncate(value, target_type_size, incoming_type_size)
} else {
value
};

Ok(self.builder.insert_cast(truncated_value, typ).into())
}

/// Create a const offset of an address for an array load or store
pub(super) fn make_offset(&mut self, mut address: ValueId, offset: u128) -> ValueId {
if offset != 0 {
Expand Down
4 changes: 2 additions & 2 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,8 @@
fn codegen_cast(&mut self, cast: &ast::Cast) -> Result<Values, RuntimeError> {
let lhs = self.codegen_non_tuple_expression(&cast.lhs)?;
let typ = Self::convert_non_tuple_type(&cast.r#type);
self.builder.set_location(cast.location);
Ok(self.builder.insert_cast(lhs, typ).into())

self.insert_safe_cast(lhs, typ, cast.location)
}

/// Codegens a for loop, creating three new blocks in the process.
Expand All @@ -448,7 +448,7 @@
/// br loop_entry(v0)
/// loop_entry(i: Field):
/// v2 = lt i v1
/// brif v2, then: loop_body, else: loop_end

Check warning on line 451 in compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (brif)
/// loop_body():
/// v3 = ... codegen body ...
/// v4 = add 1, i
Expand Down Expand Up @@ -502,7 +502,7 @@
/// For example, the expression `if cond { a } else { b }` is codegen'd as:
///
/// v0 = ... codegen cond ...
/// brif v0, then: then_block, else: else_block

Check warning on line 505 in compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (brif)
/// then_block():
/// v1 = ... codegen a ...
/// br end_if(v1)
Expand All @@ -515,7 +515,7 @@
/// As another example, the expression `if cond { a }` is codegen'd as:
///
/// v0 = ... codegen cond ...
/// brif v0, then: then_block, else: end_block

Check warning on line 518 in compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (brif)
/// then_block:
/// v1 = ... codegen a ...
/// br end_if()
Expand Down
Loading