Skip to content

Commit

Permalink
Perform scope discovery on match cases too
Browse files Browse the repository at this point in the history
  • Loading branch information
kontheocharis committed Dec 20, 2022
1 parent 7550e13 commit 889a0b0
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 13 deletions.
43 changes: 34 additions & 9 deletions compiler/hash-typecheck/src/new/passes/scope_discovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,14 +495,13 @@ impl<'tc> ScopeDiscoveryPass<'tc> {
}
}

/// Add a declaration node `a := b` to the given `stack_id` (which is
/// Add a pattern node to the given `stack_id` (which is
/// "current").
///
/// This adds the declaration as a set of stack members, taking into account
/// all of the pattern bindings. It adds a set of tuples `(AstNodeId,
/// StackMemberData)`, one for each binding, where the `AstNodeId` is
/// the `AstNodeId` of the binding pattern node.
fn add_declaration_node_to_stack(&self, node: AstNodeRef<ast::Declaration>, stack_id: StackId) {
/// This adds the pattern binds as a set of stack members. It adds a set of
/// tuples `(AstNodeId, StackMemberData)`, one for each binding, where
/// the `AstNodeId` is the `AstNodeId` of the binding pattern node.
fn add_pat_node_binds_to_stack(&self, node: AstNodeRef<ast::Pat>, stack_id: StackId) {
self.stack_members.modify_fast(stack_id, |members| {
let members = match members {
Some(members) => members,
Expand All @@ -513,7 +512,7 @@ impl<'tc> ScopeDiscoveryPass<'tc> {

// Add each stack member to the stack_members vector
let mut found_members = smallvec![];
self.add_stack_members_in_pat_to_buf(node.pat.ast_ref(), &mut found_members);
self.add_stack_members_in_pat_to_buf(node, &mut found_members);
for (node_id, stack_member) in found_members {
members.push((node_id, stack_member));
}
Expand All @@ -533,7 +532,8 @@ impl<'tc> ast::AstVisitor for ScopeDiscoveryPass<'tc> {
FnDef,
TyFnDef,
BodyBlock,
MergeDeclaration
MergeDeclaration,
MatchCase
);

type DeclarationRet = ();
Expand Down Expand Up @@ -566,7 +566,7 @@ impl<'tc> ast::AstVisitor for ScopeDiscoveryPass<'tc> {
}
DefId::Stack(stack_id) => {
walk_with_name_hint()?;
self.add_declaration_node_to_stack(node, stack_id)
self.add_pat_node_binds_to_stack(node.pat.ast_ref(), stack_id)
}
DefId::Fn(_) => {
panic_on_span!(
Expand All @@ -580,6 +580,31 @@ impl<'tc> ast::AstVisitor for ScopeDiscoveryPass<'tc> {
Ok(())
}

type MatchCaseRet = ();
fn visit_match_case(
&self,
node: AstNodeRef<ast::MatchCase>,
) -> Result<Self::MatchCaseRet, Self::Error> {
match self.get_current_def() {
DefId::Stack(_) => {
// A match case creates its own stack scope.
let stack_id = self.stack_ops().create_stack();
self.enter_def(node, stack_id, || {
self.add_pat_node_binds_to_stack(node.pat.ast_ref(), stack_id);
walk::walk_match_case(self, node)
})?;
Ok(())
}
_ => {
panic_on_span!(
self.node_location(node),
self.source_map(),
"found match in non-stack scope"
)
}
}
}

type ModuleRet = ();
fn visit_module(
&self,
Expand Down
90 changes: 86 additions & 4 deletions compiler/hash-typecheck/src/new/passes/symbol_resolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use hash_ast::{
ast, ast_visitor_default_impl,
visitor::{walk, AstVisitor},
};
use hash_types::new::environment::{context::ScopeKind, env::AccessToEnv};
use hash_types::new::{
environment::{context::ScopeKind, env::AccessToEnv},
scopes::StackMemberId,
};

use super::ast_pass::AstPass;
use crate::{
Expand Down Expand Up @@ -48,6 +51,66 @@ impl<'tc> SymbolResolutionPass<'tc> {
}
}

impl<'tc> SymbolResolutionPass<'tc> {
/// Run a function for each stack member in the given pattern.
///
/// The stack members are found in the `AstInfo` store, specifically the
/// `stack_members` map. They are looked up using the IDs of the pattern
/// binds, as added by the `add_stack_members_in_pat_to_buf` method of the
/// `ScopeDiscoveryPass`.
fn for_each_stack_member_of_pat(
&self,
node: ast::AstNodeRef<ast::Pat>,
f: impl Fn(StackMemberId) + Copy,
) {
match node.body() {
ast::Pat::Binding(_) => {
if let Some(member_id) = self.ast_info().stack_members().get_data_by_node(node.id())
{
f(member_id);
}
}
ast::Pat::Tuple(tuple_pat) => {
for entry in tuple_pat.fields.ast_ref_iter() {
self.for_each_stack_member_of_pat(entry.pat.ast_ref(), f);
}
}
ast::Pat::Constructor(constructor_pat) => {
for field in constructor_pat.fields.ast_ref_iter() {
self.for_each_stack_member_of_pat(field.pat.ast_ref(), f);
}
}
ast::Pat::List(list_pat) => {
for pat in list_pat.fields.ast_ref_iter() {
self.for_each_stack_member_of_pat(pat, f);
}
}
ast::Pat::Or(or_pat) => {
if let Some(pat) = or_pat.variants.get(0) {
self.for_each_stack_member_of_pat(pat.ast_ref(), f)
}
}
ast::Pat::Spread(spread_pat) => {
if let Some(name) = spread_pat.name.as_ref() {
if let Some(member_id) =
self.ast_info().stack_members().get_data_by_node(name.ast_ref().id())
{
f(member_id);
}
}
}
ast::Pat::If(if_pat) => self.for_each_stack_member_of_pat(if_pat.pat.ast_ref(), f),
ast::Pat::Wild(_) => {
if let Some(member_id) = self.ast_info().stack_members().get_data_by_node(node.id())
{
f(member_id);
}
}
ast::Pat::Module(_) | ast::Pat::Access(_) | ast::Pat::Lit(_) | ast::Pat::Range(_) => {}
}
}
}

/// @@Temporary: for now this visitor just walks the AST and enters scopes. The
/// next step is to resolve symbols in these scopes!.
impl ast::AstVisitor for SymbolResolutionPass<'_> {
Expand All @@ -61,6 +124,7 @@ impl ast::AstVisitor for SymbolResolutionPass<'_> {
FnDef,
TyFnDef,
BodyBlock,
MatchCase,
);

type ModuleRet = ();
Expand Down Expand Up @@ -157,12 +221,30 @@ impl ast::AstVisitor for SymbolResolutionPass<'_> {
node: ast::AstNodeRef<ast::Declaration>,
) -> Result<Self::DeclarationRet, Self::Error> {
// If we are in a stack, then we need to add the declaration to the
// stack's scope.
// stack's scope. Otherwise the declaration is handled higher up.
if let ScopeKind::Stack(_) = self.context().get_scope_kind() {
let member = self.ast_info().stack_members().get_data_by_node(node.pat.id()).unwrap();
self.context_ops().add_stack_binding(member);
self.for_each_stack_member_of_pat(node.pat.ast_ref(), |member| {
self.context_ops().add_stack_binding(member);
});
}
walk::walk_declaration(self, node)?;
Ok(())
}

type MatchCaseRet = ();
fn visit_match_case(
&self,
node: ast::AstNodeRef<ast::MatchCase>,
) -> Result<Self::MatchCaseRet, Self::Error> {
let stack_id = self.ast_info().stacks().get_data_by_node(node.id()).unwrap();
// Each match case has its own scope, so we need to enter it, and add all the
// pattern bindings to the context.
self.context_ops().enter_scope(ScopeKind::Stack(stack_id), || {
self.for_each_stack_member_of_pat(node.pat.ast_ref(), |member| {
self.context_ops().add_stack_binding(member);
});
walk::walk_match_case(self, node)?;
Ok(())
})
}
}

0 comments on commit 889a0b0

Please sign in to comment.