Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(experimental): comptime globals #4918

Merged
merged 17 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions compiler/noirc_frontend/src/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub enum StatementKind {
Break,
Continue,
/// This statement should be executed at compile-time
Comptime(Box<StatementKind>),
Comptime(Box<Statement>),
// This is an expression with a trailing semi-colon
Semi(Expression),
// This statement is the result of a recovered parse error.
Expand Down Expand Up @@ -685,7 +685,7 @@ impl Display for StatementKind {
StatementKind::For(for_loop) => for_loop.fmt(f),
StatementKind::Break => write!(f, "break"),
StatementKind::Continue => write!(f, "continue"),
StatementKind::Comptime(statement) => write!(f, "comptime {statement}"),
StatementKind::Comptime(statement) => write!(f, "comptime {}", statement.kind),
StatementKind::Semi(semi) => write!(f, "{semi};"),
StatementKind::Error => write!(f, "Error"),
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl StmtId {
HirStatement::Semi(expr) => StatementKind::Semi(expr.to_ast(interner)),
HirStatement::Error => StatementKind::Error,
HirStatement::Comptime(statement) => {
StatementKind::Comptime(Box::new(statement.to_ast(interner).kind))
StatementKind::Comptime(Box::new(statement.to_ast(interner)))
}
};

Expand Down
18 changes: 11 additions & 7 deletions compiler/noirc_frontend/src/hir/comptime/interpreter.rs
jfecher marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,11 @@ impl<'a> Interpreter<'a> {
/// `exit_function` is called.
pub(super) fn enter_function(&mut self) -> (bool, Vec<HashMap<DefinitionId, Value>>) {
// Drain every scope except the global scope
let scope = self.scopes.drain(1..).collect();
self.push_scope();
let mut scope = Vec::new();
if self.scopes.len() > 1 {
scope = self.scopes.drain(1..).collect();
self.push_scope();
}
(std::mem::take(&mut self.in_loop), scope)
}

Expand All @@ -160,7 +163,7 @@ impl<'a> Interpreter<'a> {
self.scopes.last_mut().unwrap()
}

fn define_pattern(
pub(super) fn define_pattern(
&mut self,
pattern: &HirPattern,
typ: &Type,
Expand Down Expand Up @@ -262,7 +265,7 @@ impl<'a> Interpreter<'a> {
Err(InterpreterError::NonComptimeVarReferenced { name, location })
}

fn lookup(&self, ident: &HirIdent) -> IResult<Value> {
pub(super) fn lookup(&self, ident: &HirIdent) -> IResult<Value> {
self.lookup_id(ident.id, ident.location)
}

Expand Down Expand Up @@ -291,7 +294,7 @@ impl<'a> Interpreter<'a> {
}

/// Evaluate an expression and return the result
fn evaluate(&mut self, id: ExprId) -> IResult<Value> {
pub(super) fn evaluate(&mut self, id: ExprId) -> IResult<Value> {
match self.interner.expression(&id) {
HirExpression::Ident(ident) => self.evaluate_ident(ident, id),
HirExpression::Literal(literal) => self.evaluate_literal(literal, id),
Expand Down Expand Up @@ -322,7 +325,7 @@ impl<'a> Interpreter<'a> {
}
}

fn evaluate_ident(&mut self, ident: HirIdent, id: ExprId) -> IResult<Value> {
pub(super) fn evaluate_ident(&mut self, ident: HirIdent, id: ExprId) -> IResult<Value> {
let definition = self.interner.definition(ident.id);

match &definition.kind {
Expand All @@ -332,6 +335,7 @@ impl<'a> Interpreter<'a> {
}
DefinitionKind::Local(_) => self.lookup(&ident),
DefinitionKind::Global(global_id) => {
// Don't need to check let_.comptime, we can evaluate non-comptime globals too.
let let_ = self.interner.get_global_let_statement(*global_id).unwrap();
self.evaluate_let(let_)?;
self.lookup(&ident)
Expand Down Expand Up @@ -1027,7 +1031,7 @@ impl<'a> Interpreter<'a> {
}
}

fn evaluate_let(&mut self, let_: HirLetStatement) -> IResult<Value> {
pub(super) fn evaluate_let(&mut self, let_: HirLetStatement) -> IResult<Value> {
let rhs = self.evaluate(let_.expression)?;
let location = self.interner.expr_location(&let_.expression);
self.define_pattern(&let_.pattern, &let_.r#type, rhs, location)?;
Expand Down
1 change: 1 addition & 0 deletions compiler/noirc_frontend/src/hir/comptime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ mod value;

pub use errors::InterpreterError;
pub use interpreter::Interpreter;
pub use value::Value;
50 changes: 47 additions & 3 deletions compiler/noirc_frontend/src/hir/comptime/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,19 @@ use crate::{
hir_def::{
expr::{
HirArrayLiteral, HirBlockExpression, HirCallExpression, HirConstructorExpression,
HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda,
HirIdent, HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda,
HirMethodCallExpression,
},
stmt::HirForStatement,
},
macros_api::{HirExpression, HirLiteral, HirStatement},
node_interner::{ExprId, FuncId, StmtId},
node_interner::{DefinitionKind, ExprId, FuncId, GlobalId, StmtId},
};

use super::{
errors::{IResult, InterpreterError},
interpreter::Interpreter,
Value,
};

#[allow(dead_code)]
Expand All @@ -48,9 +49,23 @@ impl<'interner> Interpreter<'interner> {
Ok(())
}

/// Evaluate this global if it is a comptime global.
/// Otherwise, scan through its expression for any comptime blocks to evaluate.
pub fn scan_global(&mut self, global: GlobalId) -> IResult<()> {
if let Some(let_) = self.interner.get_global_let_statement(global) {
if let_.comptime {
self.evaluate_let(let_)?;
} else {
self.scan_expression(let_.expression)?;
}
}

Ok(())
}

fn scan_expression(&mut self, expr: ExprId) -> IResult<()> {
match self.interner.expression(&expr) {
HirExpression::Ident(_) => Ok(()),
HirExpression::Ident(ident) => self.scan_ident(ident, expr),
HirExpression::Literal(literal) => self.scan_literal(literal),
HirExpression::Block(block) => self.scan_block(block),
HirExpression::Prefix(prefix) => self.scan_expression(prefix.rhs),
Expand Down Expand Up @@ -91,6 +106,27 @@ impl<'interner> Interpreter<'interner> {
}
}

// Identifiers have no code to execute but we may need to inline any values
// of comptime variables into runtime code.
fn scan_ident(&mut self, ident: HirIdent, id: ExprId) -> IResult<()> {
let definition = self.interner.definition(ident.id);

match &definition.kind {
DefinitionKind::Function(_) => Ok(()),
jfecher marked this conversation as resolved.
Show resolved Hide resolved
_ => {
// Opportunistically evaluate this identifier to see if it is compile-time known.
// If so, inline its value.
if let Ok(value) = self.evaluate_ident(ident, id) {
// TODO(#4922): Inlining closures is currently unimplemented
if !matches!(value, Value::Closure(..)) {
self.inline_expression(value, id)?;
}
}
Ok(())
}
}
}

fn scan_literal(&mut self, literal: HirLiteral) -> IResult<()> {
match literal {
HirLiteral::Array(elements) | HirLiteral::Slice(elements) => match elements {
Expand Down Expand Up @@ -210,4 +246,12 @@ impl<'interner> Interpreter<'interner> {
self.pop_scope();
Ok(())
}

fn inline_expression(&mut self, value: Value, expr: ExprId) -> IResult<()> {
let location = self.interner.expr_location(&expr);
let new_expr = value.into_expression(self.interner, location)?;
let new_expr = self.interner.expression(&new_expr);
self.interner.replace_expr(&expr, new_expr);
Ok(())
}
}
41 changes: 33 additions & 8 deletions compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,13 @@ pub enum CompilationError {
InterpreterError(InterpreterError),
}

impl CompilationError {
fn is_error(&self) -> bool {
let diagnostic = CustomDiagnostic::from(self);
diagnostic.is_error()
}
}

impl<'a> From<&'a CompilationError> for CustomDiagnostic {
fn from(value: &'a CompilationError) -> Self {
match value {
Expand Down Expand Up @@ -404,10 +411,15 @@ impl DefCollector {
);
}

resolved_module.errors.extend(context.def_interner.check_for_dependency_cycles());
let cycle_errors = context.def_interner.check_for_dependency_cycles();
let cycles_present = !cycle_errors.is_empty();
resolved_module.errors.extend(cycle_errors);

resolved_module.type_check(context);
resolved_module.evaluate_comptime(&mut context.def_interner);

if !cycles_present {
resolved_module.evaluate_comptime(&mut context.def_interner);
}

resolved_module.errors
}
Expand Down Expand Up @@ -503,13 +515,21 @@ impl ResolvedModule {

/// Evaluate all `comptime` expressions in this module
fn evaluate_comptime(&mut self, interner: &mut NodeInterner) {
let mut interpreter = Interpreter::new(interner);
if self.count_errors() == 0 {
let mut interpreter = Interpreter::new(interner);

for (_file, function) in &self.functions {
// The file returned by the error may be different than the file the
// function is in so only use the error's file id.
if let Err(error) = interpreter.scan_function(*function) {
self.errors.push(error.into_compilation_error_pair());
for (_file, global) in &self.globals {
if let Err(error) = interpreter.scan_global(*global) {
self.errors.push(error.into_compilation_error_pair());
}
}

for (_file, function) in &self.functions {
// The file returned by the error may be different than the file the
// function is in so only use the error's file id.
if let Err(error) = interpreter.scan_function(*function) {
self.errors.push(error.into_compilation_error_pair());
}
}
}
}
Expand All @@ -524,4 +544,9 @@ impl ResolvedModule {
self.globals.extend(globals.globals);
self.errors.extend(globals.errors);
}

/// Counts the number of errors (minus warnings) this program currently has
fn count_errors(&self) -> usize {
self.errors.iter().filter(|(error, _)| error.is_error()).count()
}
}
8 changes: 6 additions & 2 deletions compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1238,6 +1238,7 @@ impl<'a> Resolver<'a> {
r#type: self.resolve_type(let_stmt.r#type),
expression,
attributes: let_stmt.attributes,
comptime: let_stmt.comptime,
})
}

Expand All @@ -1251,6 +1252,7 @@ impl<'a> Resolver<'a> {
r#type: self.resolve_type(let_stmt.r#type),
expression,
attributes: let_stmt.attributes,
comptime: let_stmt.comptime,
})
}
StatementKind::Constrain(constrain_stmt) => {
Expand Down Expand Up @@ -1322,8 +1324,10 @@ impl<'a> Resolver<'a> {
}
StatementKind::Error => HirStatement::Error,
StatementKind::Comptime(statement) => {
let statement = self.resolve_stmt(*statement, span);
HirStatement::Comptime(self.interner.push_stmt(statement))
let hir_statement = self.resolve_stmt(statement.kind, statement.span);
let statement_id = self.interner.push_stmt(hir_statement);
self.interner.push_statement_location(statement_id, statement.span, self.file);
HirStatement::Comptime(statement_id)
}
}
}
Expand Down
1 change: 1 addition & 0 deletions compiler/noirc_frontend/src/hir/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@ pub mod test {
r#type: Type::FieldElement,
expression: expr_id,
attributes: vec![],
comptime: false,
};
let stmt_id = interner.push_stmt(HirStatement::Let(let_stmt));
let expr_id = interner
Expand Down
1 change: 1 addition & 0 deletions compiler/noirc_frontend/src/hir_def/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub struct HirLetStatement {
pub r#type: Type,
pub expression: ExprId,
pub attributes: Vec<SecondaryAttribute>,
pub comptime: bool,
}

impl HirLetStatement {
Expand Down
19 changes: 10 additions & 9 deletions compiler/noirc_frontend/src/parser/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -532,15 +532,16 @@ where
P2: ExprParser + 'a,
S: NoirParser<StatementKind> + 'a,
{
keyword(Keyword::Comptime)
.ignore_then(choice((
declaration(expr),
for_loop(expr_no_constructors, statement.clone()),
block(statement).map_with_span(|block, span| {
StatementKind::Expression(Expression::new(ExpressionKind::Block(block), span))
}),
)))
.map(|statement| StatementKind::Comptime(Box::new(statement)))
let comptime_statement = choice((
declaration(expr),
for_loop(expr_no_constructors, statement.clone()),
block(statement).map_with_span(|block, span| {
StatementKind::Expression(Expression::new(ExpressionKind::Block(block), span))
}),
))
.map_with_span(|kind, span| Box::new(Statement { kind, span }));

keyword(Keyword::Comptime).ignore_then(comptime_statement).map(StatementKind::Comptime)
}

/// Comptime in an expression position only accepts entire blocks
Expand Down
7 changes: 7 additions & 0 deletions test_programs/noir_test_success/comptime_globals/Nargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "comptime_globals"
type = "bin"
authors = [""]
compiler_version = ">=0.27.0"

[dependencies]
21 changes: 21 additions & 0 deletions test_programs/noir_test_success/comptime_globals/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Normal globals can be evaluated in a comptime context too,
// but comptime globals can only be evaluated in a comptime context.
comptime global FOO: Field = foo();

// Due to this function's mutability and branching, SSA currently fails
// to fold this function into a constant before the assert_constant check
// is evaluated before loop unrolling.
fn foo() -> Field {
let mut three = 3;
if three == 3 { 5 } else { 6 }
}

#[test]
fn foo_global_constant() {
assert_constant(FOO);
}

#[test(should_fail)]
fn foo_function_not_constant() {
assert_constant(foo());
}
2 changes: 1 addition & 1 deletion tooling/nargo_fmt/src/visitor/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ impl super::FmtVisitor<'_> {
StatementKind::Error => unreachable!(),
StatementKind::Break => self.push_rewrite("break;".into(), span),
StatementKind::Continue => self.push_rewrite("continue;".into(), span),
StatementKind::Comptime(statement) => self.visit_stmt(*statement, span, is_last),
StatementKind::Comptime(statement) => self.visit_stmt(statement.kind, span, is_last),
}
}
}
Loading