Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Perform scope discovery on match cases too #639

Merged
merged 1 commit into from
Dec 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions compiler/hash-typecheck/src/new/passes/scope_discovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -510,14 +510,13 @@ impl<'tc> ScopeDiscoveryPass<'tc> {
}
}

/// Add a declaration node `a := b` to the given `stack_id` (which is
/// Add a pattern node to the given `stack_id` (which is
/// "current").
///
/// This adds the declaration as a set of stack members, taking into account
/// all of the pattern bindings. It adds a set of tuples `(AstNodeId,
/// StackMemberData)`, one for each binding, where the `AstNodeId` is
/// the `AstNodeId` of the binding pattern node.
fn add_declaration_node_to_stack(&self, node: AstNodeRef<ast::Declaration>, stack_id: StackId) {
/// This adds the pattern binds as a set of stack members. It adds a set of
/// tuples `(AstNodeId, StackMemberData)`, one for each binding, where
/// the `AstNodeId` is the `AstNodeId` of the binding pattern node.
fn add_pat_node_binds_to_stack(&self, node: AstNodeRef<ast::Pat>, stack_id: StackId) {
self.stack_members.modify_fast(stack_id, |members| {
let members = match members {
Some(members) => members,
Expand All @@ -528,7 +527,7 @@ impl<'tc> ScopeDiscoveryPass<'tc> {

// Add each stack member to the stack_members vector
let mut found_members = smallvec![];
self.add_stack_members_in_pat_to_buf(node.pat.ast_ref(), &mut found_members);
self.add_stack_members_in_pat_to_buf(node, &mut found_members);
for (node_id, stack_member) in found_members {
members.push((node_id, stack_member));
}
Expand All @@ -548,7 +547,8 @@ impl<'tc> ast::AstVisitor for ScopeDiscoveryPass<'tc> {
FnDef,
TyFnDef,
BodyBlock,
MergeDeclaration
MergeDeclaration,
MatchCase
);

type DeclarationRet = ();
Expand Down Expand Up @@ -581,7 +581,7 @@ impl<'tc> ast::AstVisitor for ScopeDiscoveryPass<'tc> {
}
DefId::Stack(stack_id) => {
walk_with_name_hint()?;
self.add_declaration_node_to_stack(node, stack_id)
self.add_pat_node_binds_to_stack(node.pat.ast_ref(), stack_id)
}
DefId::Fn(_) => {
panic_on_span!(
Expand All @@ -595,6 +595,31 @@ impl<'tc> ast::AstVisitor for ScopeDiscoveryPass<'tc> {
Ok(())
}

type MatchCaseRet = ();
fn visit_match_case(
&self,
node: AstNodeRef<ast::MatchCase>,
) -> Result<Self::MatchCaseRet, Self::Error> {
match self.get_current_def() {
DefId::Stack(_) => {
// A match case creates its own stack scope.
let stack_id = self.stack_ops().create_stack();
self.enter_def(node, stack_id, || {
self.add_pat_node_binds_to_stack(node.pat.ast_ref(), stack_id);
walk::walk_match_case(self, node)
})?;
Ok(())
}
_ => {
panic_on_span!(
self.node_location(node),
self.source_map(),
"found match in non-stack scope"
)
}
}
}

type ModuleRet = ();
fn visit_module(
&self,
Expand Down
99 changes: 95 additions & 4 deletions compiler/hash-typecheck/src/new/passes/symbol_resolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use hash_ast::{
ast, ast_visitor_default_impl,
visitor::{walk, AstVisitor},
};
use hash_types::new::environment::{context::ScopeKind, env::AccessToEnv};
use hash_types::new::{
environment::{context::ScopeKind, env::AccessToEnv},
scopes::StackMemberId,
};

use super::ast_pass::AstPass;
use crate::{
Expand Down Expand Up @@ -48,6 +51,75 @@ impl<'tc> SymbolResolutionPass<'tc> {
}
}

impl<'tc> SymbolResolutionPass<'tc> {
/// Run a function for each stack member in the given pattern.
///
/// The stack members are found in the `AstInfo` store, specifically the
/// `stack_members` map. They are looked up using the IDs of the pattern
/// binds, as added by the `add_stack_members_in_pat_to_buf` method of the
/// `ScopeDiscoveryPass`.
fn for_each_stack_member_of_pat(
&self,
node: ast::AstNodeRef<ast::Pat>,
f: impl Fn(StackMemberId) + Copy,
) {
let for_spread_pat = |spread: &ast::AstNode<ast::SpreadPat>| {
if let Some(name) = &spread.name {
if let Some(member_id) =
self.ast_info().stack_members().get_data_by_node(name.ast_ref().id())
{
Comment on lines +67 to +70
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could use a let_chain here:

Suggested change
if let Some(name) = &spread.name {
if let Some(member_id) =
self.ast_info().stack_members().get_data_by_node(name.ast_ref().id())
{
if let Some(name) = &spread.name && let Some(member_id) =
self.ast_info().stack_members().get_data_by_node(name.ast_ref().id())
{

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like let chains cause they are not yet supported by rustfmt rust-lang/rustfmt#5203. So I'll just keep it like this for now..

f(member_id);
}
}
};
match node.body() {
ast::Pat::Binding(_) => {
if let Some(member_id) = self.ast_info().stack_members().get_data_by_node(node.id())
{
f(member_id);
}
}
ast::Pat::Tuple(tuple_pat) => {
for (index, entry) in tuple_pat.fields.ast_ref_iter().enumerate() {
if let Some(spread_node) = &tuple_pat.spread && spread_node.position == index {
for_spread_pat(spread_node);
}
self.for_each_stack_member_of_pat(entry.pat.ast_ref(), f);
}
}
ast::Pat::Constructor(constructor_pat) => {
for (index, field) in constructor_pat.fields.ast_ref_iter().enumerate() {
if let Some(spread_node) = &constructor_pat.spread && spread_node.position == index {
for_spread_pat(spread_node);
}
self.for_each_stack_member_of_pat(field.pat.ast_ref(), f);
}
}
ast::Pat::List(list_pat) => {
for (index, pat) in list_pat.fields.ast_ref_iter().enumerate() {
if let Some(spread_node) = &list_pat.spread && spread_node.position == index {
for_spread_pat(spread_node);
}
self.for_each_stack_member_of_pat(pat, f);
}
}
ast::Pat::Or(or_pat) => {
if let Some(pat) = or_pat.variants.get(0) {
self.for_each_stack_member_of_pat(pat.ast_ref(), f)
}
}
ast::Pat::If(if_pat) => self.for_each_stack_member_of_pat(if_pat.pat.ast_ref(), f),
ast::Pat::Wild(_) => {
if let Some(member_id) = self.ast_info().stack_members().get_data_by_node(node.id())
{
f(member_id);
}
}
ast::Pat::Module(_) | ast::Pat::Access(_) | ast::Pat::Lit(_) | ast::Pat::Range(_) => {}
}
}
}

/// @@Temporary: for now this visitor just walks the AST and enters scopes. The
/// next step is to resolve symbols in these scopes!.
impl ast::AstVisitor for SymbolResolutionPass<'_> {
Expand All @@ -61,6 +133,7 @@ impl ast::AstVisitor for SymbolResolutionPass<'_> {
FnDef,
TyFnDef,
BodyBlock,
MatchCase,
);

type ModuleRet = ();
Expand Down Expand Up @@ -157,12 +230,30 @@ impl ast::AstVisitor for SymbolResolutionPass<'_> {
node: ast::AstNodeRef<ast::Declaration>,
) -> Result<Self::DeclarationRet, Self::Error> {
// If we are in a stack, then we need to add the declaration to the
// stack's scope.
// stack's scope. Otherwise the declaration is handled higher up.
if let ScopeKind::Stack(_) = self.context().get_scope_kind() {
let member = self.ast_info().stack_members().get_data_by_node(node.pat.id()).unwrap();
self.context_ops().add_stack_binding(member);
self.for_each_stack_member_of_pat(node.pat.ast_ref(), |member| {
self.context_ops().add_stack_binding(member);
});
}
walk::walk_declaration(self, node)?;
Ok(())
}

type MatchCaseRet = ();
fn visit_match_case(
&self,
node: ast::AstNodeRef<ast::MatchCase>,
) -> Result<Self::MatchCaseRet, Self::Error> {
let stack_id = self.ast_info().stacks().get_data_by_node(node.id()).unwrap();
// Each match case has its own scope, so we need to enter it, and add all the
// pattern bindings to the context.
self.context_ops().enter_scope(ScopeKind::Stack(stack_id), || {
self.for_each_stack_member_of_pat(node.pat.ast_ref(), |member| {
self.context_ops().add_stack_binding(member);
});
walk::walk_match_case(self, node)?;
Ok(())
})
}
}