Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(experimental): comptime globals #4918

Merged
merged 17 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/noirc_driver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ pub fn check_crate(
let mut errors = vec![];
let diagnostics = CrateDefMap::collect_defs(crate_id, context, macros);
errors.extend(diagnostics.into_iter().map(|(error, file_id)| {
let diagnostic: CustomDiagnostic = error.into();
let diagnostic = CustomDiagnostic::from(&error);
diagnostic.in_file(file_id)
}));

Expand Down
4 changes: 2 additions & 2 deletions compiler/noirc_frontend/src/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub enum StatementKind {
Break,
Continue,
/// This statement should be executed at compile-time
Comptime(Box<StatementKind>),
Comptime(Box<Statement>),
// This is an expression with a trailing semi-colon
Semi(Expression),
// This statement is the result of a recovered parse error.
Expand Down Expand Up @@ -685,7 +685,7 @@ impl Display for StatementKind {
StatementKind::For(for_loop) => for_loop.fmt(f),
StatementKind::Break => write!(f, "break"),
StatementKind::Continue => write!(f, "continue"),
StatementKind::Comptime(statement) => write!(f, "comptime {statement}"),
StatementKind::Comptime(statement) => write!(f, "comptime {}", statement.kind),
StatementKind::Semi(semi) => write!(f, "{semi};"),
StatementKind::Error => write!(f, "Error"),
}
Expand Down
16 changes: 8 additions & 8 deletions compiler/noirc_frontend/src/hir/comptime/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use super::value::Value;
/// The possible errors that can halt the interpreter.
#[derive(Debug, Clone)]
pub enum InterpreterError {
ArgumentCountMismatch { expected: usize, actual: usize, call_location: Location },
ArgumentCountMismatch { expected: usize, actual: usize, location: Location },
TypeMismatch { expected: Type, value: Value, location: Location },
NonComptimeVarReferenced { name: String, location: Location },
IntegerOutOfRangeForType { value: FieldElement, typ: Type, location: Location },
Expand Down Expand Up @@ -60,7 +60,7 @@ impl InterpreterError {

pub fn get_location(&self) -> Location {
match self {
InterpreterError::ArgumentCountMismatch { call_location: location, .. }
InterpreterError::ArgumentCountMismatch { location, .. }
| InterpreterError::TypeMismatch { location, .. }
| InterpreterError::NonComptimeVarReferenced { location, .. }
| InterpreterError::IntegerOutOfRangeForType { location, .. }
Expand Down Expand Up @@ -98,20 +98,20 @@ impl InterpreterError {
}
}

impl From<InterpreterError> for CustomDiagnostic {
fn from(error: InterpreterError) -> Self {
impl<'a> From<&'a InterpreterError> for CustomDiagnostic {
fn from(error: &'a InterpreterError) -> Self {
match error {
InterpreterError::ArgumentCountMismatch { expected, actual, call_location } => {
InterpreterError::ArgumentCountMismatch { expected, actual, location } => {
let only = if expected > actual { "only " } else { "" };
let plural = if expected == 1 { "" } else { "s" };
let was_were = if actual == 1 { "was" } else { "were" };
let plural = if *expected == 1 { "" } else { "s" };
let was_were = if *actual == 1 { "was" } else { "were" };
let msg = format!(
"Expected {expected} argument{plural}, but {only}{actual} {was_were} provided"
);

let few_many = if actual < expected { "few" } else { "many" };
let secondary = format!("Too {few_many} arguments");
CustomDiagnostic::simple_error(msg, secondary, call_location.span)
CustomDiagnostic::simple_error(msg, secondary, location.span)
}
InterpreterError::TypeMismatch { expected, value, location } => {
let typ = value.get_type();
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl StmtId {
HirStatement::Semi(expr) => StatementKind::Semi(expr.to_ast(interner)),
HirStatement::Error => StatementKind::Error,
HirStatement::Comptime(statement) => {
StatementKind::Comptime(Box::new(statement.to_ast(interner).kind))
StatementKind::Comptime(Box::new(statement.to_ast(interner)))
}
};

Expand Down
27 changes: 16 additions & 11 deletions compiler/noirc_frontend/src/hir/comptime/interpreter.rs
jfecher marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,21 @@ impl<'a> Interpreter<'a> {
&mut self,
function: FuncId,
arguments: Vec<(Value, Location)>,
call_location: Location,
location: Location,
) -> IResult<Value> {
let previous_state = self.enter_function();

let meta = self.interner.function_meta(&function);
if meta.kind != FunctionKind::Normal {
todo!("Evaluation for {:?} is unimplemented", meta.kind);
let item = "Evaluation for builtin functions";
return Err(InterpreterError::Unimplemented { item, location });
}

if meta.parameters.len() != arguments.len() {
return Err(InterpreterError::ArgumentCountMismatch {
expected: meta.parameters.len(),
actual: arguments.len(),
call_location,
location,
});
}

Expand Down Expand Up @@ -113,7 +114,7 @@ impl<'a> Interpreter<'a> {
return Err(InterpreterError::ArgumentCountMismatch {
expected: closure.parameters.len(),
actual: arguments.len(),
call_location,
location: call_location,
});
}

Expand All @@ -133,8 +134,11 @@ impl<'a> Interpreter<'a> {
/// `exit_function` is called.
pub(super) fn enter_function(&mut self) -> (bool, Vec<HashMap<DefinitionId, Value>>) {
// Drain every scope except the global scope
let scope = self.scopes.drain(1..).collect();
self.push_scope();
let mut scope = Vec::new();
if self.scopes.len() > 1 {
scope = self.scopes.drain(1..).collect();
self.push_scope();
}
(std::mem::take(&mut self.in_loop), scope)
}

Expand All @@ -159,7 +163,7 @@ impl<'a> Interpreter<'a> {
self.scopes.last_mut().unwrap()
}

fn define_pattern(
pub(super) fn define_pattern(
&mut self,
pattern: &HirPattern,
typ: &Type,
Expand Down Expand Up @@ -261,7 +265,7 @@ impl<'a> Interpreter<'a> {
Err(InterpreterError::NonComptimeVarReferenced { name, location })
}

fn lookup(&self, ident: &HirIdent) -> IResult<Value> {
pub(super) fn lookup(&self, ident: &HirIdent) -> IResult<Value> {
self.lookup_id(ident.id, ident.location)
}

Expand Down Expand Up @@ -290,7 +294,7 @@ impl<'a> Interpreter<'a> {
}

/// Evaluate an expression and return the result
fn evaluate(&mut self, id: ExprId) -> IResult<Value> {
pub(super) fn evaluate(&mut self, id: ExprId) -> IResult<Value> {
match self.interner.expression(&id) {
HirExpression::Ident(ident) => self.evaluate_ident(ident, id),
HirExpression::Literal(literal) => self.evaluate_literal(literal, id),
Expand Down Expand Up @@ -321,7 +325,7 @@ impl<'a> Interpreter<'a> {
}
}

fn evaluate_ident(&mut self, ident: HirIdent, id: ExprId) -> IResult<Value> {
pub(super) fn evaluate_ident(&mut self, ident: HirIdent, id: ExprId) -> IResult<Value> {
let definition = self.interner.definition(ident.id);

match &definition.kind {
Expand All @@ -331,6 +335,7 @@ impl<'a> Interpreter<'a> {
}
DefinitionKind::Local(_) => self.lookup(&ident),
DefinitionKind::Global(global_id) => {
// Don't need to check let_.comptime, we can evaluate non-comptime globals too.
let let_ = self.interner.get_global_let_statement(*global_id).unwrap();
self.evaluate_let(let_)?;
self.lookup(&ident)
Expand Down Expand Up @@ -1026,7 +1031,7 @@ impl<'a> Interpreter<'a> {
}
}

fn evaluate_let(&mut self, let_: HirLetStatement) -> IResult<Value> {
pub(super) fn evaluate_let(&mut self, let_: HirLetStatement) -> IResult<Value> {
let rhs = self.evaluate(let_.expression)?;
let location = self.interner.expr_location(&let_.expression);
self.define_pattern(&let_.pattern, &let_.r#type, rhs, location)?;
Expand Down
1 change: 1 addition & 0 deletions compiler/noirc_frontend/src/hir/comptime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ mod value;

pub use errors::InterpreterError;
pub use interpreter::Interpreter;
pub use value::Value;
50 changes: 47 additions & 3 deletions compiler/noirc_frontend/src/hir/comptime/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,19 @@ use crate::{
hir_def::{
expr::{
HirArrayLiteral, HirBlockExpression, HirCallExpression, HirConstructorExpression,
HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda,
HirIdent, HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda,
HirMethodCallExpression,
},
stmt::HirForStatement,
},
macros_api::{HirExpression, HirLiteral, HirStatement},
node_interner::{ExprId, FuncId, StmtId},
node_interner::{DefinitionKind, ExprId, FuncId, GlobalId, StmtId},
};

use super::{
errors::{IResult, InterpreterError},
interpreter::Interpreter,
Value,
};

#[allow(dead_code)]
Expand All @@ -48,9 +49,23 @@ impl<'interner> Interpreter<'interner> {
Ok(())
}

/// Evaluate this global if it is a comptime global.
/// Otherwise, scan through its expression for any comptime blocks to evaluate.
pub fn scan_global(&mut self, global: GlobalId) -> IResult<()> {
if let Some(let_) = self.interner.get_global_let_statement(global) {
if let_.comptime {
self.evaluate_let(let_)?;
} else {
self.scan_expression(let_.expression)?;
}
}

Ok(())
}

fn scan_expression(&mut self, expr: ExprId) -> IResult<()> {
match self.interner.expression(&expr) {
HirExpression::Ident(_) => Ok(()),
HirExpression::Ident(ident) => self.scan_ident(ident, expr),
HirExpression::Literal(literal) => self.scan_literal(literal),
HirExpression::Block(block) => self.scan_block(block),
HirExpression::Prefix(prefix) => self.scan_expression(prefix.rhs),
Expand Down Expand Up @@ -91,6 +106,27 @@ impl<'interner> Interpreter<'interner> {
}
}

// Identifiers have no code to execute but we may need to inline any values
// of comptime variables into runtime code.
fn scan_ident(&mut self, ident: HirIdent, id: ExprId) -> IResult<()> {
let definition = self.interner.definition(ident.id);

match &definition.kind {
DefinitionKind::Function(_) => Ok(()),
jfecher marked this conversation as resolved.
Show resolved Hide resolved
_ => {
// Opportunistically evaluate this identifier to see if it is compile-time known.
// If so, inline its value.
if let Ok(value) = self.evaluate_ident(ident, id) {
// TODO(#4922): Inlining closures is currently unimplemented
if !matches!(value, Value::Closure(..)) {
self.inline_expression(value, id)?;
}
}
Ok(())
}
}
}

fn scan_literal(&mut self, literal: HirLiteral) -> IResult<()> {
match literal {
HirLiteral::Array(elements) | HirLiteral::Slice(elements) => match elements {
Expand Down Expand Up @@ -210,4 +246,12 @@ impl<'interner> Interpreter<'interner> {
self.pop_scope();
Ok(())
}

fn inline_expression(&mut self, value: Value, expr: ExprId) -> IResult<()> {
let location = self.interner.expr_location(&expr);
let new_expr = value.into_expression(self.interner, location)?;
let new_expr = self.interner.expression(&new_expr);
self.interner.replace_expr(&expr, new_expr);
Ok(())
}
}
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/hir/comptime/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl Value {
}
Value::Closure(_lambda, _env, _typ) => {
// TODO: How should a closure's environment be inlined?
let item = "returning closures from a comptime fn";
let item = "Returning closures from a comptime fn";
return Err(InterpreterError::Unimplemented { item, location });
}
Value::Tuple(fields) => {
Expand Down
45 changes: 35 additions & 10 deletions compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,15 @@ pub enum CompilationError {
InterpreterError(InterpreterError),
}

impl From<CompilationError> for CustomDiagnostic {
fn from(value: CompilationError) -> Self {
impl CompilationError {
fn is_error(&self) -> bool {
let diagnostic = CustomDiagnostic::from(self);
diagnostic.is_error()
}
}

impl<'a> From<&'a CompilationError> for CustomDiagnostic {
fn from(value: &'a CompilationError) -> Self {
match value {
CompilationError::ParseError(error) => error.into(),
CompilationError::DefinitionError(error) => error.into(),
Expand Down Expand Up @@ -404,10 +411,15 @@ impl DefCollector {
);
}

resolved_module.errors.extend(context.def_interner.check_for_dependency_cycles());
let cycle_errors = context.def_interner.check_for_dependency_cycles();
let cycles_present = !cycle_errors.is_empty();
resolved_module.errors.extend(cycle_errors);

resolved_module.type_check(context);
resolved_module.evaluate_comptime(&mut context.def_interner);

if !cycles_present {
resolved_module.evaluate_comptime(&mut context.def_interner);
}

resolved_module.errors
}
Expand Down Expand Up @@ -503,13 +515,21 @@ impl ResolvedModule {

/// Evaluate all `comptime` expressions in this module
fn evaluate_comptime(&mut self, interner: &mut NodeInterner) {
let mut interpreter = Interpreter::new(interner);
if self.count_errors() == 0 {
let mut interpreter = Interpreter::new(interner);

for (_file, function) in &self.functions {
// The file returned by the error may be different than the file the
// function is in so only use the error's file id.
if let Err(error) = interpreter.scan_function(*function) {
self.errors.push(error.into_compilation_error_pair());
for (_file, global) in &self.globals {
if let Err(error) = interpreter.scan_global(*global) {
self.errors.push(error.into_compilation_error_pair());
}
}

for (_file, function) in &self.functions {
// The file returned by the error may be different than the file the
// function is in so only use the error's file id.
if let Err(error) = interpreter.scan_function(*function) {
self.errors.push(error.into_compilation_error_pair());
}
}
}
}
Expand All @@ -524,4 +544,9 @@ impl ResolvedModule {
self.globals.extend(globals.globals);
self.errors.extend(globals.errors);
}

/// Counts the number of errors (minus warnings) this program currently has
fn count_errors(&self) -> usize {
self.errors.iter().filter(|(error, _)| error.is_error()).count()
}
}
Loading
Loading