diff --git a/crates/postgresql-cst-parser/src/tree_sitter/assert_util.rs b/crates/postgresql-cst-parser/src/tree_sitter/assert_util.rs index dfbec1a..eac5fa0 100644 --- a/crates/postgresql-cst-parser/src/tree_sitter/assert_util.rs +++ b/crates/postgresql-cst-parser/src/tree_sitter/assert_util.rs @@ -45,6 +45,26 @@ pub fn assert_no_direct_nested_kind(root: &ResolvedNode, kind: SyntaxKind) { } } +/// Asserts that there is at least one directly nested node of the specified `SyntaxKind`. +/// In other words, there must be a node of `kind` that has another `kind` node as its immediate child. +pub fn assert_direct_nested_kind(root: &ResolvedNode, kind: SyntaxKind) { + let has_direct_nesting = root + .descendants() + .filter(|node| node.kind() == kind) + .any(|node| { + if let Some(parent) = node.parent() { + node.kind() == kind && parent.kind() == kind + } else { + false + } + }); + + assert!( + has_direct_nesting, + "Expected at least one directly nested {kind:?} node, but none was found." + ); +} + #[cfg(test)] mod tests { use super::*; @@ -118,4 +138,23 @@ mod tests { assert_no_direct_nested_kind(&root, SyntaxKind::target_list); } + + #[test] + fn test_direct_nested_kind_passes() { + let input = "select a,b,c;"; + let root = cst::parse(input).unwrap(); + + assert_direct_nested_kind(&root, SyntaxKind::target_list); + } + + #[test] + #[should_panic( + expected = "Expected at least one directly nested SelectStmt node, but none was found." + )] + fn test_direct_nested_kind_fails() { + let input = "select a;"; + let root = cst::parse(input).unwrap(); + + assert_direct_nested_kind(&root, SyntaxKind::SelectStmt); + } } diff --git a/crates/postgresql-cst-parser/src/tree_sitter/convert.rs b/crates/postgresql-cst-parser/src/tree_sitter/convert.rs index b57c2e6..624b5de 100644 --- a/crates/postgresql-cst-parser/src/tree_sitter/convert.rs +++ b/crates/postgresql-cst-parser/src/tree_sitter/convert.rs @@ -154,7 +154,8 @@ fn walk_and_build( | SyntaxKind::set_target_list | SyntaxKind::insert_column_list | SyntaxKind::index_params - | SyntaxKind::values_clause) => { + | SyntaxKind::values_clause + | SyntaxKind::TableFuncElementList) => { if parent_kind == child_kind { // [Node: Flatten] // @@ -325,7 +326,9 @@ FROM cst, syntax_kind::SyntaxKind, tree_sitter::{ - assert_util::{assert_no_direct_nested_kind, assert_node_count}, + assert_util::{ + assert_direct_nested_kind, assert_no_direct_nested_kind, assert_node_count, + }, convert::get_ts_tree_and_range_map, }, }; @@ -336,6 +339,7 @@ FROM let root = cst::parse(input).unwrap(); assert_node_count(&root, SyntaxKind::target_list, 3); + assert_direct_nested_kind(&root, SyntaxKind::target_list); let (new_root, _) = get_ts_tree_and_range_map(input, &root); assert_node_count(&new_root, SyntaxKind::target_list, 1); @@ -346,8 +350,9 @@ FROM fn no_nested_stmtmulti() { let input = "select a,b,c;\nselect d,e from t;"; let root = cst::parse(input).unwrap(); - let (new_root, _) = get_ts_tree_and_range_map(input, &root); + assert_direct_nested_kind(&root, SyntaxKind::stmtmulti); + let (new_root, _) = get_ts_tree_and_range_map(input, &root); assert_no_direct_nested_kind(&new_root, SyntaxKind::stmtmulti); } @@ -355,8 +360,9 @@ FROM fn no_nested_from_list() { let input = "select * from t1, t2;"; let root = cst::parse(input).unwrap(); - let (new_root, _) = get_ts_tree_and_range_map(input, &root); + assert_direct_nested_kind(&root, SyntaxKind::from_list); + let (new_root, _) = get_ts_tree_and_range_map(input, &root); assert_no_direct_nested_kind(&new_root, SyntaxKind::from_list); } @@ -365,8 +371,9 @@ FROM let input = "select t.a, t.b.c, t1.*, a[1], a[4][5], a[2:5], a[3].b, a[3][4].b, a[3:5].b;"; let root = cst::parse(input).unwrap(); - let (new_root, _) = get_ts_tree_and_range_map(input, &root); + assert_direct_nested_kind(&root, SyntaxKind::indirection); + let (new_root, _) = get_ts_tree_and_range_map(input, &root); assert_no_direct_nested_kind(&new_root, SyntaxKind::indirection); } @@ -374,8 +381,9 @@ FROM fn no_nested_expr_list() { let input = "select a from t where a in (1,2,3);"; let root = cst::parse(input).unwrap(); - let (new_root, _) = get_ts_tree_and_range_map(input, &root); + assert_direct_nested_kind(&root, SyntaxKind::expr_list); + let (new_root, _) = get_ts_tree_and_range_map(input, &root); assert_no_direct_nested_kind(&new_root, SyntaxKind::expr_list); } @@ -383,8 +391,9 @@ FROM fn no_nested_func_arg_list() { let input = "select func(1, 2, func2(3, 4), 5);"; let root = cst::parse(input).unwrap(); - let (new_root, _) = get_ts_tree_and_range_map(input, &root); + assert_direct_nested_kind(&root, SyntaxKind::func_arg_list); + let (new_root, _) = get_ts_tree_and_range_map(input, &root); assert_no_direct_nested_kind(&new_root, SyntaxKind::func_arg_list); } @@ -392,8 +401,9 @@ FROM fn no_nested_when_clause_list() { let input = "select case when a then b when c then d when e then f else g end;"; let root = cst::parse(input).unwrap(); - let (new_root, _) = get_ts_tree_and_range_map(input, &root); + assert_direct_nested_kind(&root, SyntaxKind::when_clause_list); + let (new_root, _) = get_ts_tree_and_range_map(input, &root); assert_no_direct_nested_kind(&new_root, SyntaxKind::when_clause_list); } @@ -401,8 +411,9 @@ FROM fn no_nested_sortby_list() { let input = "select * from t order by a, b, c;"; let root = cst::parse(input).unwrap(); - let (new_root, _) = get_ts_tree_and_range_map(input, &root); + assert_direct_nested_kind(&root, SyntaxKind::sortby_list); + let (new_root, _) = get_ts_tree_and_range_map(input, &root); assert_no_direct_nested_kind(&new_root, SyntaxKind::sortby_list); } @@ -410,8 +421,9 @@ FROM fn no_nested_groupby_list() { let input = "select a, b, c from t group by a, b, c;"; let root = cst::parse(input).unwrap(); - let (new_root, _) = get_ts_tree_and_range_map(input, &root); + assert_direct_nested_kind(&root, SyntaxKind::group_by_list); + let (new_root, _) = get_ts_tree_and_range_map(input, &root); assert_no_direct_nested_kind(&new_root, SyntaxKind::group_by_list); } @@ -419,8 +431,9 @@ FROM fn no_nested_for_locking_items() { let input = "select * from t1, t2 for update of t1 for update of t2;"; let root = cst::parse(input).unwrap(); - let (new_root, _) = get_ts_tree_and_range_map(input, &root); + assert_direct_nested_kind(&root, SyntaxKind::for_locking_items); + let (new_root, _) = get_ts_tree_and_range_map(input, &root); assert_no_direct_nested_kind(&new_root, SyntaxKind::for_locking_items); } @@ -428,8 +441,9 @@ FROM fn no_nested_qualified_name_list() { let input = "select a from t for update of t.a, t.b;"; let root = cst::parse(input).unwrap(); - let (new_root, _) = get_ts_tree_and_range_map(input, &root); + assert_direct_nested_kind(&root, SyntaxKind::qualified_name_list); + let (new_root, _) = get_ts_tree_and_range_map(input, &root); assert_no_direct_nested_kind(&new_root, SyntaxKind::qualified_name_list); } @@ -437,8 +451,9 @@ FROM fn no_nested_cte_list() { let input = "with a as (select 1), b as (select 2) select * from a, b;"; let root = cst::parse(input).unwrap(); - let (new_root, _) = get_ts_tree_and_range_map(input, &root); + assert_direct_nested_kind(&root, SyntaxKind::cte_list); + let (new_root, _) = get_ts_tree_and_range_map(input, &root); assert_no_direct_nested_kind(&new_root, SyntaxKind::cte_list); } @@ -446,8 +461,9 @@ FROM fn no_nested_name_list() { let input = "with t (a, b) as (select 1) select * from t;"; let root = cst::parse(input).unwrap(); - let (new_root, _) = get_ts_tree_and_range_map(input, &root); + assert_direct_nested_kind(&root, SyntaxKind::name_list); + let (new_root, _) = get_ts_tree_and_range_map(input, &root); assert_no_direct_nested_kind(&new_root, SyntaxKind::name_list); } @@ -455,8 +471,9 @@ FROM fn no_nested_set_clause_list() { let input = "update t set a = 1, b = 2, c = 3;"; let root = cst::parse(input).unwrap(); - let (new_root, _) = get_ts_tree_and_range_map(input, &root); + assert_direct_nested_kind(&root, SyntaxKind::set_clause_list); + let (new_root, _) = get_ts_tree_and_range_map(input, &root); assert_no_direct_nested_kind(&new_root, SyntaxKind::set_clause_list); } @@ -464,8 +481,9 @@ FROM fn no_nested_set_target_list() { let input = "update t set (a, b, c) = (1, 2, 3) where id = 1;"; let root = cst::parse(input).unwrap(); - let (new_root, _) = get_ts_tree_and_range_map(input, &root); + assert_direct_nested_kind(&root, SyntaxKind::set_target_list); + let (new_root, _) = get_ts_tree_and_range_map(input, &root); assert_no_direct_nested_kind(&new_root, SyntaxKind::set_target_list); } @@ -473,8 +491,9 @@ FROM fn no_nested_insert_column_list() { let input = "insert into t (a, b, c) values (1, 2, 3);"; let root = cst::parse(input).unwrap(); - let (new_root, _) = get_ts_tree_and_range_map(input, &root); + assert_direct_nested_kind(&root, SyntaxKind::insert_column_list); + let (new_root, _) = get_ts_tree_and_range_map(input, &root); assert_no_direct_nested_kind(&new_root, SyntaxKind::insert_column_list); } @@ -482,8 +501,9 @@ FROM fn no_nested_index_params() { let input = "insert into t (a, b, c) values (1, 2, 3) on conflict (a, b) do nothing;"; let root = cst::parse(input).unwrap(); - let (new_root, _) = get_ts_tree_and_range_map(input, &root); + assert_direct_nested_kind(&root, SyntaxKind::index_params); + let (new_root, _) = get_ts_tree_and_range_map(input, &root); assert_no_direct_nested_kind(&new_root, SyntaxKind::index_params); } @@ -491,9 +511,20 @@ FROM fn no_nested_values_clause() { let input = "values (1,2,3), (4,5,6), (7,8,9);"; let root = cst::parse(input).unwrap(); - let (new_root, _) = get_ts_tree_and_range_map(input, &root); + assert_direct_nested_kind(&root, SyntaxKind::values_clause); + let (new_root, _) = get_ts_tree_and_range_map(input, &root); assert_no_direct_nested_kind(&new_root, SyntaxKind::values_clause); } + + #[test] + fn no_nested_table_func_element_list() { + let input = "select * from unnest(a) as (x int, y text);"; + let root = cst::parse(input).unwrap(); + assert_direct_nested_kind(&root, SyntaxKind::TableFuncElementList); + + let (new_root, _) = get_ts_tree_and_range_map(input, &root); + assert_no_direct_nested_kind(&new_root, SyntaxKind::TableFuncElementList); + } } }